File size: 5,061 Bytes
8871df9
 
cc2ed2f
 
8871df9
 
 
 
cc2ed2f
8871df9
 
 
 
 
 
cc2ed2f
8871df9
 
 
 
cc2ed2f
 
8871df9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc2ed2f
 
 
 
 
8871df9
 
 
 
 
cc2ed2f
 
8871df9
 
 
 
f12c26c
8871df9
 
 
 
 
 
 
 
cc2ed2f
8871df9
 
cc2ed2f
 
 
 
8871df9
 
 
cc2ed2f
8871df9
 
 
 
 
cc2ed2f
8871df9
 
cc2ed2f
 
 
 
 
 
8871df9
 
 
cc2ed2f
8871df9
 
 
 
 
cc2ed2f
 
 
 
 
 
 
 
 
 
8871df9
cc2ed2f
 
 
8871df9
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
"""Загрузка модели + LoRA-адаптера и инференс.

На десктопе/ноутбуке без GPU работает на CPU. Медленно, но достаточно для
разработки и демо. На Kaggle/Colab — на GPU, быстрее.
"""

from __future__ import annotations

import logging
from dataclasses import dataclass
from pathlib import Path

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from src.business.vocabulary import BusinessVocabulary
from src.config import settings
from src.data.prompt import build_chat_messages
from src.models.postprocess import postprocess

logger = logging.getLogger(__name__)


@dataclass
class GenerationResult:
    sql: str
    raw_output: str


class InferenceEngine:
    """Singleton-обёртка над моделью. Загружается один раз при старте API."""

    def __init__(
        self,
        base_model_name: str | None = None,
        lora_adapter_path: str | None = None,
        device: str | None = None,
    ):
        self.base_model_name = base_model_name or settings.base_model_name
        self.lora_adapter_path = lora_adapter_path or settings.lora_adapter_path
        self.device = device or settings.device
        self.tokenizer = None
        self.model = None
        self._loaded = False

    @property
    def loaded(self) -> bool:
        """Публичное свойство — статус загрузки модели."""
        return self._loaded

    def load(self) -> None:
        """Лениво грузим модель. На CPU без квантизации."""
        if self._loaded:
            return

        logger.info("Загрузка базовой модели %s на устройство %s",
                    self.base_model_name, self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name)
        # bfloat16 вдвое меньше float32 (~6 ГБ vs ~12 ГБ) и поддерживается на CPU
        self.model = AutoModelForCausalLM.from_pretrained(
            self.base_model_name,
            torch_dtype=torch.bfloat16,
            device_map=self.device if self.device != "cpu" else None,
        )

        # Подцепляем LoRA-адаптер: сначала ищем локально, потом на HF Hub
        adapter_path = Path(self.lora_adapter_path)
        adapter_id = str(adapter_path) if adapter_path.exists() else self.lora_adapter_path
        try:
            from peft import PeftModel
            logger.info("Подключение LoRA-адаптера %s", adapter_id)
            self.model = PeftModel.from_pretrained(self.model, adapter_id)
        except ImportError:
            logger.warning("peft не установлен, используется базовая модель без LoRA")
        except Exception as e:  # noqa: BLE001 — лог достаточен, без падения
            logger.warning("Не удалось подгрузить LoRA-адаптер %s: %s",
                           adapter_id, e)

        self.model.eval()
        self._loaded = True
        logger.info("InferenceEngine готов к работе")

    def generate(
        self,
        schema: str,
        question: str,
        vocabulary: BusinessVocabulary | None = None,
        max_new_tokens: int | None = None,
    ) -> GenerationResult:
        """Принимает schema (текст DDL) и вопрос, возвращает SQL.

        Если передан непустой ``vocabulary``, бизнес-термины компании
        подмешиваются в системное сообщение через PromptBuilder.
        Это соответствует разделу 3.6 пояснительной записки.
        """
        if not self._loaded:
            self.load()

        messages = build_chat_messages(schema, question, vocabulary=vocabulary)
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        # Параметры сэмплинга. При do_sample=False temperature игнорируется,
        # поэтому не передаём её — иначе transformers выводит warning.
        gen_kwargs = {
            "max_new_tokens": max_new_tokens or settings.max_new_tokens,
            "do_sample": settings.do_sample,
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        if settings.do_sample:
            gen_kwargs["temperature"] = settings.temperature

        with torch.no_grad():
            output_ids = self.model.generate(**inputs, **gen_kwargs)

        new_tokens = output_ids[0][inputs["input_ids"].shape[1]:]
        raw = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
        return GenerationResult(sql=postprocess(raw), raw_output=raw)