| """Загрузка модели + LoRA-адаптера и инференс. |
| |
| На десктопе/ноутбуке без GPU работает на CPU. Медленно, но достаточно для разработки и демо. |
| На Kaggle/Colab — на GPU, быстрее. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
| from pathlib import Path |
|
|
| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
| from src.config import settings |
| from src.data.prompt import build_chat_messages |
| from src.models.postprocess import postprocess |
|
|
|
|
@dataclass
class GenerationResult:
    """Outcome of a single generation call.

    Holds both the decoded model output and the SQL derived from it, so
    callers can inspect what the model actually produced.
    """

    # SQL text obtained by post-processing the raw model output.
    sql: str
    # Decoded model output exactly as produced, before post-processing.
    raw_output: str
|
|
|
|
class InferenceEngine:
    """Wrapper around a causal LM plus an optional LoRA adapter.

    Intended to be created once and reused (e.g. at API startup); the class
    itself does not enforce single instantiation. Weights are loaded lazily
    on the first call to :meth:`load` or :meth:`generate`.
    """

    def __init__(
        self,
        base_model_name: str | None = None,
        lora_adapter_path: str | None = None,
        device: str | None = None,
    ):
        """Store configuration; no model weights are loaded here.

        Args:
            base_model_name: Model identifier passed to ``from_pretrained``.
                Falls back to ``settings.base_model_name`` when falsy.
            lora_adapter_path: Adapter location. Falls back to
                ``settings.lora_adapter_path`` when falsy.
            device: Target device string. Falls back to ``settings.device``
                when falsy.
        """
        # Explicit arguments win; falsy values fall back to project settings.
        self.base_model_name = base_model_name or settings.base_model_name
        self.lora_adapter_path = lora_adapter_path or settings.lora_adapter_path
        self.device = device or settings.device
        self.tokenizer = None
        self.model = None
        self._loaded = False

    def _resolve_adapter_id(self) -> str | None:
        """Return the adapter identifier to load, or ``None`` if none is set.

        An existing filesystem path is returned in its ``str(Path(...))``
        form; any other value is passed through unchanged so it is not
        rewritten by path normalisation.
        """
        # Guard first: Path(None) raises TypeError and Path("") resolves to
        # ".", which exists and would be mistaken for an adapter directory.
        if not self.lora_adapter_path:
            return None
        adapter_path = Path(self.lora_adapter_path)
        if adapter_path.exists():
            return str(adapter_path)
        return self.lora_adapter_path

    def load(self) -> None:
        """Load tokenizer, base model and (if available) the LoRA adapter.

        Idempotent: returns immediately once a previous call has completed.
        On CPU the model is loaded without a ``device_map``; no quantization
        is applied here. Adapter loading is best-effort: if no adapter is
        configured, or an ``ImportError`` occurs while loading it, a warning
        is logged and the base model is used as-is.
        """
        if self._loaded:
            return

        # Function-scope import keeps this block self-contained.
        import logging

        logger = logging.getLogger(__name__)

        self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.base_model_name,
            dtype=torch.bfloat16,
            # For "cpu" no device_map is passed; otherwise the configured
            # device string is forwarded as the device_map.
            device_map=self.device if self.device != "cpu" else None,
        )

        adapter_id = self._resolve_adapter_id()
        if adapter_id is None:
            logger.warning(
                "No LoRA adapter path configured; using base model %s as-is.",
                self.base_model_name,
            )
        else:
            try:
                from peft import PeftModel

                self.model = PeftModel.from_pretrained(self.model, adapter_id)
            except ImportError as exc:
                # Still best-effort, but no longer silent: without the
                # adapter the engine serves plain base-model output.
                logger.warning(
                    "Could not load LoRA adapter %s (%s); using base model as-is.",
                    adapter_id,
                    exc,
                )

        self.model.eval()
        self._loaded = True

    def generate(
        self,
        schema: str,
        question: str,
        max_new_tokens: int | None = None,
    ) -> GenerationResult:
        """Generate SQL for a question over the given schema.

        Args:
            schema: Schema text (DDL) forwarded to ``build_chat_messages``.
            question: The question forwarded to ``build_chat_messages``.
            max_new_tokens: Generation length cap; falls back to
                ``settings.max_new_tokens`` when falsy.

        Returns:
            A ``GenerationResult`` whose ``raw_output`` is the decoded
            continuation and whose ``sql`` is ``postprocess(raw_output)``.
        """
        if not self._loaded:
            self.load()

        messages = build_chat_messages(schema, question)
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # NOTE(review): the templated prompt is tokenized with default
        # special-token handling; for tokenizers that prepend BOS this may
        # duplicate it. Confirm against how the adapter was trained before
        # switching to add_special_tokens=False.
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens or settings.max_new_tokens,
                do_sample=settings.do_sample,
                # The configured temperature is only passed when sampling.
                temperature=settings.temperature if settings.do_sample else 1.0,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Skip the first prompt_len positions of the first returned sequence
        # so that only the continuation is decoded.
        prompt_len = inputs["input_ids"].shape[1]
        new_tokens = output_ids[0][prompt_len:]
        raw = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
        return GenerationResult(sql=postprocess(raw), raw_output=raw)
|
|