from __future__ import annotations import json import logging import os import re from typing import Iterable, Tuple import torch from transformers import AutoModelForCausalLM, AutoTokenizer from .models import BlockKey, ClientProfile, SessionState logger = logging.getLogger(__name__) class LocalLLM: def __init__(self, model_name: str, local_path: str | None = None) -> None: self._model_name = model_name self._local_path = local_path or "/models/llm" self._device = torch.device("cpu") self._tokenizer: AutoTokenizer | None = None self._model: AutoModelForCausalLM | None = None self._available = False self._load_model() def _load_model(self) -> None: path = self._local_path if os.path.exists(self._local_path) else self._model_name try: self._tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True) self._model = AutoModelForCausalLM.from_pretrained( path, torch_dtype=torch.float32, trust_remote_code=True, ) if self._tokenizer.pad_token_id is None: self._tokenizer.pad_token_id = self._tokenizer.eos_token_id self._model.to(self._device) self._available = True except Exception as exc: # pragma: no cover - model may fail locally logger.warning("Failed to load LLM %s: %s", path, exc) self._available = False @property def available(self) -> bool: return self._available and self._tokenizer is not None and self._model is not None def _generate( self, prompt: str, *, temperature: float, max_tokens: int, do_sample: bool = True, ) -> str: if not self.available or not self._tokenizer or not self._model: raise RuntimeError("LLM is not available") inputs = self._tokenizer(prompt, return_tensors="pt").to(self._device) with torch.no_grad(): output = self._model.generate( **inputs, max_new_tokens=max_tokens, temperature=temperature, do_sample=do_sample, eos_token_id=self._tokenizer.eos_token_id, pad_token_id=self._tokenizer.pad_token_id, ) completion = output[0][inputs["input_ids"].shape[-1]:] text = self._tokenizer.decode(completion, skip_special_tokens=True) return text.strip() def _sanitize_question(self, text: str) -> str: cleaned = text.strip() if not cleaned: return "" cleaned = re.sub( r"^(вопрос|запрос|пример|например)\s*[:\-]\s*", "", cleaned, flags=re.IGNORECASE, ) cleaned = cleaned.strip(" \"'") cleaned = re.sub(r"\s+", " ", cleaned).strip() match = re.search(r".+?\?", cleaned) if match: cleaned = match.group(0) else: match = re.search(r".+?[.!]", cleaned) if match: cleaned = match.group(0).rstrip(".!") + "?" if len(cleaned) > 160: cleaned = cleaned[:157].rstrip() + "?" if cleaned and not cleaned.endswith("?"): cleaned += "?" if len(cleaned) < 5: return "" return cleaned def generate_question( self, *, block: BlockKey, client: ClientProfile, answered_pairs: Iterable[Tuple[str, str]], ) -> str: prompt = self._question_prompt(block, client, answered_pairs) raw = self._generate(prompt, temperature=0.2, max_tokens=80) return self._sanitize_question(raw) def generate_report(self, state: SessionState) -> str: prompt = self._report_prompt(state) return self._generate(prompt, temperature=0.0, max_tokens=600, do_sample=False) def classify_topics(self, *, block: BlockKey, text: str) -> dict[str, bool]: topics = self._topic_keys(block) if not topics: return {} prompt = self._topics_prompt(block, text, topics) raw = self._generate(prompt, temperature=0.0, max_tokens=160, do_sample=False) return self._parse_topics(raw, topics) def _question_prompt( self, block: BlockKey, client: ClientProfile, answered_pairs: Iterable[Tuple[str, str]], ) -> str: block_titles = { BlockKey.health: "здоровье, травмы и опыт", BlockKey.goals: "цели клиента", BlockKey.readiness: "готовность к режиму и формат", } answers = "\n".join( f"Вопрос: {question}\nОтвет клиента: {answer}" for question, answer in answered_pairs ) or "Ответов пока нет." return ( "Ты - кроссфит-тренер. Общайся уважительно, по-русски, и задавай один короткий вопрос."\ + f"\nКлиент: {client.name}, формат: {client.preferred_format}."\ + f"\nТекущий блок: {block_titles.get(block, block.value)}."\ + f"\nИзвестные ответы:\n{answers}"\ + "\nСформулируй один вопрос (до 120 символов) без пояснений, списков и слов 'Вопрос'/'Запрос'." ) def _report_prompt(self, state: SessionState) -> str: sections = [] for question in state.questions: answer = state.transcripts.get(question.id, "Нет ответа") sections.append(f"- {question.prompt}\n Ответ: {answer}") convo = "\n".join(sections) return ( "Ты пишешь отчёт для клиента, который заполнил фитнес-опрос."\ + "\nИспользуй только факты из ответов. Ничего не выдумывай."\ + "\nЗапрещено добавлять возраст, профессию, диагнозы или личные данные, если их нет в ответах."\ + "\nЕсли данных недостаточно, пиши: 'Не указано' или 'Нужно уточнить'."\ + "\nНе обращайся к тренеру и не давай советов тренеру — обращайся к клиенту на 'вы'."\ + "\nНе пересказывай ответы дословно, делай выводы и рекомендации."\ + "\nНе повторяй разделы."\ + "\nФормат: заголовки и списки."\ + "\nРазделы: Краткий профиль, Риски и ограничения, Цели и метрики, Рекомендации, Следующие шаги."\ + "\nТон деловой, по-русски. Без медицинских диагнозов и запугивания."\ + f"\nИмя клиента: {state.client.name}. Электронная почта: {state.client.email}. Формат: {state.client.preferred_format}."\ + f"\nОтветы клиента:\n{convo}"\ + "\nВерни только отчёт, без пояснений." ) def _topic_keys(self, block: BlockKey) -> list[str]: if block == BlockKey.health: return ["injuries", "training_history", "technique_limitations"] if block == BlockKey.goals: return ["primary_goal", "timeline", "success_metrics"] if block == BlockKey.readiness: return ["weekly_frequency", "equipment_location", "format_preference"] return [] def _topics_prompt(self, block: BlockKey, text: str, topics: list[str]) -> str: block_labels = { BlockKey.health: "здоровье и опыт", BlockKey.goals: "цели", BlockKey.readiness: "готовность и формат", } topics_list = ", ".join(topics) return ( "Ты помощник тренера. Определи, какие темы уже упомянуты в ответах клиента."\ + "\nОтветь строго JSON-объектом с ключами из списка, значения true/false."\ + f"\nБлок: {block_labels.get(block, block.value)}."\ + f"\nТемы: {topics_list}."\ + f"\nОтветы клиента: {text}"\ + "\nВерни только JSON без пояснений." ) def _parse_topics(self, raw: str, topics: list[str]) -> dict[str, bool]: payload = self._extract_json(raw) if not payload: return {} data = {} for key in topics: value = payload.get(key) if isinstance(value, bool): data[key] = value return data @staticmethod def _extract_json(raw: str) -> dict: if not raw: return {} match = re.search(r"\{.*\}", raw, flags=re.DOTALL) if not match: return {} snippet = match.group(0) try: return json.loads(snippet) except json.JSONDecodeError: return {}