from __future__ import annotations

import json
import logging
import os
import re
from typing import Iterable, Tuple

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from .models import BlockKey, ClientProfile, SessionState

logger = logging.getLogger(__name__)


class LocalLLM:
    def __init__(self, model_name: str, local_path: str | None = None) -> None:
        self._model_name = model_name
        self._local_path = local_path or "/models/llm"
        self._device = torch.device("cpu")
        self._tokenizer: AutoTokenizer | None = None
        self._model: AutoModelForCausalLM | None = None
        self._available = False
        self._load_model()

    def _load_model(self) -> None:
        # Prefer the local snapshot if it exists; otherwise fall back to the hub name.
        path = self._local_path if os.path.exists(self._local_path) else self._model_name
        try:
            self._tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
            self._model = AutoModelForCausalLM.from_pretrained(
                path,
                torch_dtype=torch.float32,
                trust_remote_code=True,
            )
            if self._tokenizer.pad_token_id is None:
                self._tokenizer.pad_token_id = self._tokenizer.eos_token_id
            self._model.to(self._device)
            self._available = True
        except Exception as exc:  # pragma: no cover - model may fail locally
            logger.warning("Failed to load LLM %s: %s", path, exc)
            self._available = False

    def available(self) -> bool:
        return self._available and self._tokenizer is not None and self._model is not None

    def _generate(
        self,
        prompt: str,
        *,
        temperature: float,
        max_tokens: int,
        do_sample: bool = True,
    ) -> str:
        if not self.available() or not self._tokenizer or not self._model:
            raise RuntimeError("LLM is not available")
        inputs = self._tokenizer(prompt, return_tensors="pt").to(self._device)
        with torch.no_grad():
            output = self._model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                do_sample=do_sample,
                eos_token_id=self._tokenizer.eos_token_id,
                pad_token_id=self._tokenizer.pad_token_id,
            )
        # Drop the prompt tokens and decode only the newly generated completion.
        completion = output[0][inputs["input_ids"].shape[-1]:]
        text = self._tokenizer.decode(completion, skip_special_tokens=True)
        return text.strip()

    def _sanitize_question(self, text: str) -> str:
        cleaned = text.strip()
        if not cleaned:
            return ""
        # Strip leading Russian prefixes such as "вопрос:" / "например:" ("question:" / "for example:").
        cleaned = re.sub(
            r"^(вопрос|запрос|пример|например)\s*[:\-]\s*",
            "",
            cleaned,
            flags=re.IGNORECASE,
        )
        cleaned = cleaned.strip(" \"'")
        cleaned = re.sub(r"\s+", " ", cleaned).strip()
        # Keep only the first question; if there is none, turn the first sentence into one.
        match = re.search(r".+?\?", cleaned)
        if match:
            cleaned = match.group(0)
        else:
            match = re.search(r".+?[.!]", cleaned)
            if match:
                cleaned = match.group(0).rstrip(".!") + "?"
        if len(cleaned) > 160:
            cleaned = cleaned[:157].rstrip() + "?"
        if cleaned and not cleaned.endswith("?"):
            cleaned += "?"
        if len(cleaned) < 5:
            return ""
        return cleaned

    def generate_question(
        self,
        *,
        block: BlockKey,
        client: ClientProfile,
        answered_pairs: Iterable[Tuple[str, str]],
    ) -> str:
        prompt = self._question_prompt(block, client, answered_pairs)
        raw = self._generate(prompt, temperature=0.2, max_tokens=80)
        return self._sanitize_question(raw)

    def generate_report(self, state: SessionState) -> str:
        prompt = self._report_prompt(state)
        return self._generate(prompt, temperature=0.0, max_tokens=600, do_sample=False)

    def classify_topics(self, *, block: BlockKey, text: str) -> dict[str, bool]:
        topics = self._topic_keys(block)
        if not topics:
            return {}
        prompt = self._topics_prompt(block, text, topics)
        raw = self._generate(prompt, temperature=0.0, max_tokens=160, do_sample=False)
        return self._parse_topics(raw, topics)

    def _question_prompt(
        self,
        block: BlockKey,
        client: ClientProfile,
        answered_pairs: Iterable[Tuple[str, str]],
    ) -> str:
        # Russian-language prompt: act as a CrossFit coach and ask one short follow-up question.
        block_titles = {
            BlockKey.health: "здоровье, травмы и опыт",
            BlockKey.goals: "цели клиента",
            BlockKey.readiness: "готовность к режиму и формат",
        }
        answers = "\n".join(
            f"Вопрос: {question}\nОтвет клиента: {answer}"
            for question, answer in answered_pairs
        ) or "Ответов пока нет."
        return (
            "Ты - кроссфит-тренер. Общайся уважительно, по-русски, и задавай один короткий вопрос."
            + f"\nКлиент: {client.name}, формат: {client.preferred_format}."
            + f"\nТекущий блок: {block_titles.get(block, block.value)}."
            + f"\nИзвестные ответы:\n{answers}"
            + "\nСформулируй один вопрос (до 120 символов) без пояснений, списков и слов 'Вопрос'/'Запрос'."
        )

    def _report_prompt(self, state: SessionState) -> str:
        # Russian-language prompt: write a client-facing report strictly from the recorded answers.
        sections = []
        for question in state.questions:
            answer = state.transcripts.get(question.id, "Нет ответа")
            sections.append(f"- {question.prompt}\n  Ответ: {answer}")
        convo = "\n".join(sections)
        return (
            "Ты пишешь отчёт для клиента, который заполнил фитнес-опрос."
            + "\nИспользуй только факты из ответов. Ничего не выдумывай."
            + "\nЗапрещено добавлять возраст, профессию, диагнозы или личные данные, если их нет в ответах."
            + "\nЕсли данных недостаточно, пиши: 'Не указано' или 'Нужно уточнить'."
            + "\nНе обращайся к тренеру и не давай советов тренеру — обращайся к клиенту на 'вы'."
            + "\nНе пересказывай ответы дословно, делай выводы и рекомендации."
            + "\nНе повторяй разделы."
            + "\nФормат: заголовки и списки."
            + "\nРазделы: Краткий профиль, Риски и ограничения, Цели и метрики, Рекомендации, Следующие шаги."
            + "\nТон деловой, по-русски. Без медицинских диагнозов и запугивания."
            + f"\nИмя клиента: {state.client.name}. Электронная почта: {state.client.email}. Формат: {state.client.preferred_format}."
            + f"\nОтветы клиента:\n{convo}"
            + "\nВерни только отчёт, без пояснений."
        )

    def _topic_keys(self, block: BlockKey) -> list[str]:
        if block == BlockKey.health:
            return ["injuries", "training_history", "technique_limitations"]
        if block == BlockKey.goals:
            return ["primary_goal", "timeline", "success_metrics"]
        if block == BlockKey.readiness:
            return ["weekly_frequency", "equipment_location", "format_preference"]
        return []

    def _topics_prompt(self, block: BlockKey, text: str, topics: list[str]) -> str:
        # Russian-language prompt: return a strict JSON object mapping each topic key to true/false.
        block_labels = {
            BlockKey.health: "здоровье и опыт",
            BlockKey.goals: "цели",
            BlockKey.readiness: "готовность и формат",
        }
        topics_list = ", ".join(topics)
        return (
            "Ты помощник тренера. Определи, какие темы уже упомянуты в ответах клиента."
            + "\nОтветь строго JSON-объектом с ключами из списка, значения true/false."
            + f"\nБлок: {block_labels.get(block, block.value)}."
            + f"\nТемы: {topics_list}."
            + f"\nОтветы клиента: {text}"
            + "\nВерни только JSON без пояснений."
        )

    def _parse_topics(self, raw: str, topics: list[str]) -> dict[str, bool]:
        payload = self._extract_json(raw)
        if not payload:
            return {}
        data: dict[str, bool] = {}
        for key in topics:
            value = payload.get(key)
            if isinstance(value, bool):
                data[key] = value
        return data

    def _extract_json(self, raw: str) -> dict:
        # Pull the first {...} block out of the model output and parse it as JSON.
        if not raw:
            return {}
        match = re.search(r"\{.*\}", raw, flags=re.DOTALL)
        if not match:
            return {}
        snippet = match.group(0)
        try:
            return json.loads(snippet)
        except json.JSONDecodeError:
            return {}
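

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): shows one way the class above might be
# driven. Because of the relative import of .models, run it as a module inside
# its package rather than as a standalone script. The model name
# "sshleifer/tiny-gpt2" and the ClientProfile keyword arguments (name, email,
# preferred_format) are assumptions inferred from the field accesses above;
# substitute the real model and the actual ClientProfile constructor.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    llm = LocalLLM("sshleifer/tiny-gpt2")  # assumed stand-in model name
    if llm.available():
        client = ClientProfile(  # constructor signature assumed from field usage above
            name="Иван",
            email="ivan@example.com",
            preferred_format="онлайн",
        )
        question = llm.generate_question(
            block=BlockKey.health,
            client=client,
            answered_pairs=[("Какой у вас опыт тренировок?", "Год занимался в зале")],
        )
        print(question)
        # classify_topics returns a dict of booleans keyed by the block's topic names.
        print(llm.classify_topics(block=BlockKey.health, text="Была травма колена"))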