File size: 5,061 Bytes
8871df9 cc2ed2f 8871df9 cc2ed2f 8871df9 cc2ed2f 8871df9 cc2ed2f 8871df9 cc2ed2f 8871df9 cc2ed2f 8871df9 f12c26c 8871df9 cc2ed2f 8871df9 cc2ed2f 8871df9 cc2ed2f 8871df9 cc2ed2f 8871df9 cc2ed2f 8871df9 cc2ed2f 8871df9 cc2ed2f 8871df9 cc2ed2f 8871df9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 | """Загрузка модели + LoRA-адаптера и инференс.
На десктопе/ноутбуке без GPU работает на CPU. Медленно, но достаточно для
разработки и демо. На Kaggle/Colab — на GPU, быстрее.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from pathlib import Path
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from src.business.vocabulary import BusinessVocabulary
from src.config import settings
from src.data.prompt import build_chat_messages
from src.models.postprocess import postprocess
logger = logging.getLogger(__name__)
@dataclass
class GenerationResult:
sql: str
raw_output: str
class InferenceEngine:
"""Singleton-обёртка над моделью. Загружается один раз при старте API."""
def __init__(
self,
base_model_name: str | None = None,
lora_adapter_path: str | None = None,
device: str | None = None,
):
self.base_model_name = base_model_name or settings.base_model_name
self.lora_adapter_path = lora_adapter_path or settings.lora_adapter_path
self.device = device or settings.device
self.tokenizer = None
self.model = None
self._loaded = False
@property
def loaded(self) -> bool:
"""Публичное свойство — статус загрузки модели."""
return self._loaded
def load(self) -> None:
"""Лениво грузим модель. На CPU без квантизации."""
if self._loaded:
return
logger.info("Загрузка базовой модели %s на устройство %s",
self.base_model_name, self.device)
self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name)
# bfloat16 вдвое меньше float32 (~6 ГБ vs ~12 ГБ) и поддерживается на CPU
self.model = AutoModelForCausalLM.from_pretrained(
self.base_model_name,
torch_dtype=torch.bfloat16,
device_map=self.device if self.device != "cpu" else None,
)
# Подцепляем LoRA-адаптер: сначала ищем локально, потом на HF Hub
adapter_path = Path(self.lora_adapter_path)
adapter_id = str(adapter_path) if adapter_path.exists() else self.lora_adapter_path
try:
from peft import PeftModel
logger.info("Подключение LoRA-адаптера %s", adapter_id)
self.model = PeftModel.from_pretrained(self.model, adapter_id)
except ImportError:
logger.warning("peft не установлен, используется базовая модель без LoRA")
except Exception as e: # noqa: BLE001 — лог достаточен, без падения
logger.warning("Не удалось подгрузить LoRA-адаптер %s: %s",
adapter_id, e)
self.model.eval()
self._loaded = True
logger.info("InferenceEngine готов к работе")
def generate(
self,
schema: str,
question: str,
vocabulary: BusinessVocabulary | None = None,
max_new_tokens: int | None = None,
) -> GenerationResult:
"""Принимает schema (текст DDL) и вопрос, возвращает SQL.
Если передан непустой ``vocabulary``, бизнес-термины компании
подмешиваются в системное сообщение через PromptBuilder.
Это соответствует разделу 3.6 пояснительной записки.
"""
if not self._loaded:
self.load()
messages = build_chat_messages(schema, question, vocabulary=vocabulary)
prompt = self.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
# Параметры сэмплинга. При do_sample=False temperature игнорируется,
# поэтому не передаём её — иначе transformers выводит warning.
gen_kwargs = {
"max_new_tokens": max_new_tokens or settings.max_new_tokens,
"do_sample": settings.do_sample,
"pad_token_id": self.tokenizer.eos_token_id,
}
if settings.do_sample:
gen_kwargs["temperature"] = settings.temperature
with torch.no_grad():
output_ids = self.model.generate(**inputs, **gen_kwargs)
new_tokens = output_ids[0][inputs["input_ids"].shape[1]:]
raw = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
return GenerationResult(sql=postprocess(raw), raw_output=raw)
|