| """Загрузка модели + LoRA-адаптера и инференс. |
| |
| На десктопе/ноутбуке без GPU работает на CPU. Медленно, но достаточно для |
| разработки и демо. На Kaggle/Colab — на GPU, быстрее. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import logging |
| from dataclasses import dataclass |
| from pathlib import Path |
|
|
| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
| from src.business.vocabulary import BusinessVocabulary |
| from src.config import settings |
| from src.data.prompt import build_chat_messages |
| from src.models.postprocess import postprocess |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
| @dataclass |
| class GenerationResult: |
| sql: str |
| raw_output: str |
|
|
|
|
| class InferenceEngine: |
| """Singleton-обёртка над моделью. Загружается один раз при старте API.""" |
|
|
| def __init__( |
| self, |
| base_model_name: str | None = None, |
| lora_adapter_path: str | None = None, |
| device: str | None = None, |
| ): |
| self.base_model_name = base_model_name or settings.base_model_name |
| self.lora_adapter_path = lora_adapter_path or settings.lora_adapter_path |
| self.device = device or settings.device |
| self.tokenizer = None |
| self.model = None |
| self._loaded = False |
|
|
| @property |
| def loaded(self) -> bool: |
| """Публичное свойство — статус загрузки модели.""" |
| return self._loaded |
|
|
| def load(self) -> None: |
| """Лениво грузим модель. На CPU без квантизации.""" |
| if self._loaded: |
| return |
|
|
| logger.info("Загрузка базовой модели %s на устройство %s", |
| self.base_model_name, self.device) |
| self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name) |
| |
| self.model = AutoModelForCausalLM.from_pretrained( |
| self.base_model_name, |
| torch_dtype=torch.bfloat16, |
| device_map=self.device if self.device != "cpu" else None, |
| ) |
|
|
| |
| adapter_path = Path(self.lora_adapter_path) |
| adapter_id = str(adapter_path) if adapter_path.exists() else self.lora_adapter_path |
| try: |
| from peft import PeftModel |
| logger.info("Подключение LoRA-адаптера %s", adapter_id) |
| self.model = PeftModel.from_pretrained(self.model, adapter_id) |
| except ImportError: |
| logger.warning("peft не установлен, используется базовая модель без LoRA") |
| except Exception as e: |
| logger.warning("Не удалось подгрузить LoRA-адаптер %s: %s", |
| adapter_id, e) |
|
|
| self.model.eval() |
| self._loaded = True |
| logger.info("InferenceEngine готов к работе") |
|
|
| def generate( |
| self, |
| schema: str, |
| question: str, |
| vocabulary: BusinessVocabulary | None = None, |
| max_new_tokens: int | None = None, |
| ) -> GenerationResult: |
| """Принимает schema (текст DDL) и вопрос, возвращает SQL. |
| |
| Если передан непустой ``vocabulary``, бизнес-термины компании |
| подмешиваются в системное сообщение через PromptBuilder. |
| Это соответствует разделу 3.6 пояснительной записки. |
| """ |
| if not self._loaded: |
| self.load() |
|
|
| messages = build_chat_messages(schema, question, vocabulary=vocabulary) |
| prompt = self.tokenizer.apply_chat_template( |
| messages, tokenize=False, add_generation_prompt=True |
| ) |
| inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device) |
|
|
| |
| |
| gen_kwargs = { |
| "max_new_tokens": max_new_tokens or settings.max_new_tokens, |
| "do_sample": settings.do_sample, |
| "pad_token_id": self.tokenizer.eos_token_id, |
| } |
| if settings.do_sample: |
| gen_kwargs["temperature"] = settings.temperature |
|
|
| with torch.no_grad(): |
| output_ids = self.model.generate(**inputs, **gen_kwargs) |
|
|
| new_tokens = output_ids[0][inputs["input_ids"].shape[1]:] |
| raw = self.tokenizer.decode(new_tokens, skip_special_tokens=True) |
| return GenerationResult(sql=postprocess(raw), raw_output=raw) |
|
|