Ru2SQL / src /models /inference.py
Tyycha's picture
initial commit
8871df9
raw
history blame
3.56 kB
"""Загрузка модели + LoRA-адаптера и инференс.
На десктопе/ноутбуке без GPU работает на CPU. Медленно, но достаточно для разработки и демо.
На Kaggle/Colab — на GPU, быстрее.
"""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from src.config import settings
from src.data.prompt import build_chat_messages
from src.models.postprocess import postprocess
@dataclass
class GenerationResult:
sql: str
raw_output: str
class InferenceEngine:
"""Singleton-обёртка над моделью. Загружается один раз при старте API."""
def __init__(
self,
base_model_name: str | None = None,
lora_adapter_path: str | None = None,
device: str | None = None,
):
self.base_model_name = base_model_name or settings.base_model_name
self.lora_adapter_path = lora_adapter_path or settings.lora_adapter_path
self.device = device or settings.device
self.tokenizer = None
self.model = None
self._loaded = False
def load(self) -> None:
"""Лениво грузим модель. На CPU без квантизации."""
if self._loaded:
return
self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name)
# bfloat16 вдвое меньше float32 (~6 ГБ vs ~12 ГБ) и поддерживается на CPU
self.model = AutoModelForCausalLM.from_pretrained(
self.base_model_name,
dtype=torch.bfloat16,
device_map=self.device if self.device != "cpu" else None,
)
# Подцепляем LoRA-адаптер: сначала ищем локально, потом на HF Hub
adapter_path = Path(self.lora_adapter_path)
adapter_id = str(adapter_path) if adapter_path.exists() else self.lora_adapter_path
try:
from peft import PeftModel
self.model = PeftModel.from_pretrained(self.model, adapter_id)
except ImportError:
pass # peft не установлен — работаем на базовой модели
self.model.eval()
self._loaded = True
def generate(
self,
schema: str,
question: str,
max_new_tokens: int | None = None,
) -> GenerationResult:
"""Принимает schema (текст DDL) и вопрос, возвращает SQL."""
if not self._loaded:
self.load()
messages = build_chat_messages(schema, question)
prompt = self.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
with torch.no_grad():
output_ids = self.model.generate(
**inputs,
max_new_tokens=max_new_tokens or settings.max_new_tokens,
do_sample=settings.do_sample,
temperature=settings.temperature if settings.do_sample else 1.0,
pad_token_id=self.tokenizer.eos_token_id,
)
new_tokens = output_ids[0][inputs["input_ids"].shape[1] :]
raw = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
return GenerationResult(sql=postprocess(raw), raw_output=raw)