andrewchernish1-ui
feat: tighten llm report prompt
17de0ca
from __future__ import annotations
import json
import logging
import os
import re
from typing import Iterable, Tuple
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from .models import BlockKey, ClientProfile, SessionState
logger = logging.getLogger(__name__)
class LocalLLM:
def __init__(self, model_name: str, local_path: str | None = None) -> None:
self._model_name = model_name
self._local_path = local_path or "/models/llm"
self._device = torch.device("cpu")
self._tokenizer: AutoTokenizer | None = None
self._model: AutoModelForCausalLM | None = None
self._available = False
self._load_model()
def _load_model(self) -> None:
path = self._local_path if os.path.exists(self._local_path) else self._model_name
try:
self._tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
self._model = AutoModelForCausalLM.from_pretrained(
path,
torch_dtype=torch.float32,
trust_remote_code=True,
)
if self._tokenizer.pad_token_id is None:
self._tokenizer.pad_token_id = self._tokenizer.eos_token_id
self._model.to(self._device)
self._available = True
except Exception as exc: # pragma: no cover - model may fail locally
logger.warning("Failed to load LLM %s: %s", path, exc)
self._available = False
@property
def available(self) -> bool:
return self._available and self._tokenizer is not None and self._model is not None
def _generate(
self,
prompt: str,
*,
temperature: float,
max_tokens: int,
do_sample: bool = True,
) -> str:
if not self.available or not self._tokenizer or not self._model:
raise RuntimeError("LLM is not available")
inputs = self._tokenizer(prompt, return_tensors="pt").to(self._device)
with torch.no_grad():
output = self._model.generate(
**inputs,
max_new_tokens=max_tokens,
temperature=temperature,
do_sample=do_sample,
eos_token_id=self._tokenizer.eos_token_id,
pad_token_id=self._tokenizer.pad_token_id,
)
completion = output[0][inputs["input_ids"].shape[-1]:]
text = self._tokenizer.decode(completion, skip_special_tokens=True)
return text.strip()
def _sanitize_question(self, text: str) -> str:
cleaned = text.strip()
if not cleaned:
return ""
cleaned = re.sub(
r"^(вопрос|запрос|пример|например)\s*[:\-]\s*",
"",
cleaned,
flags=re.IGNORECASE,
)
cleaned = cleaned.strip(" \"'")
cleaned = re.sub(r"\s+", " ", cleaned).strip()
match = re.search(r".+?\?", cleaned)
if match:
cleaned = match.group(0)
else:
match = re.search(r".+?[.!]", cleaned)
if match:
cleaned = match.group(0).rstrip(".!") + "?"
if len(cleaned) > 160:
cleaned = cleaned[:157].rstrip() + "?"
if cleaned and not cleaned.endswith("?"):
cleaned += "?"
if len(cleaned) < 5:
return ""
return cleaned
def generate_question(
self,
*,
block: BlockKey,
client: ClientProfile,
answered_pairs: Iterable[Tuple[str, str]],
) -> str:
prompt = self._question_prompt(block, client, answered_pairs)
raw = self._generate(prompt, temperature=0.2, max_tokens=80)
return self._sanitize_question(raw)
def generate_report(self, state: SessionState) -> str:
prompt = self._report_prompt(state)
return self._generate(prompt, temperature=0.0, max_tokens=600, do_sample=False)
def classify_topics(self, *, block: BlockKey, text: str) -> dict[str, bool]:
topics = self._topic_keys(block)
if not topics:
return {}
prompt = self._topics_prompt(block, text, topics)
raw = self._generate(prompt, temperature=0.0, max_tokens=160, do_sample=False)
return self._parse_topics(raw, topics)
def _question_prompt(
self,
block: BlockKey,
client: ClientProfile,
answered_pairs: Iterable[Tuple[str, str]],
) -> str:
block_titles = {
BlockKey.health: "здоровье, травмы и опыт",
BlockKey.goals: "цели клиента",
BlockKey.readiness: "готовность к режиму и формат",
}
answers = "\n".join(
f"Вопрос: {question}\nОтвет клиента: {answer}"
for question, answer in answered_pairs
) or "Ответов пока нет."
return (
"Ты - кроссфит-тренер. Общайся уважительно, по-русски, и задавай один короткий вопрос."\
+ f"\nКлиент: {client.name}, формат: {client.preferred_format}."\
+ f"\nТекущий блок: {block_titles.get(block, block.value)}."\
+ f"\nИзвестные ответы:\n{answers}"\
+ "\nСформулируй один вопрос (до 120 символов) без пояснений, списков и слов 'Вопрос'/'Запрос'."
)
def _report_prompt(self, state: SessionState) -> str:
sections = []
for question in state.questions:
answer = state.transcripts.get(question.id, "Нет ответа")
sections.append(f"- {question.prompt}\n Ответ: {answer}")
convo = "\n".join(sections)
return (
"Ты пишешь отчёт для клиента, который заполнил фитнес-опрос."\
+ "\nИспользуй только факты из ответов. Ничего не выдумывай."\
+ "\nЗапрещено добавлять возраст, профессию, диагнозы или личные данные, если их нет в ответах."\
+ "\nЕсли данных недостаточно, пиши: 'Не указано' или 'Нужно уточнить'."\
+ "\nНе обращайся к тренеру и не давай советов тренеру — обращайся к клиенту на 'вы'."\
+ "\nНе пересказывай ответы дословно, делай выводы и рекомендации."\
+ "\nНе повторяй разделы."\
+ "\nФормат: заголовки и списки."\
+ "\nРазделы: Краткий профиль, Риски и ограничения, Цели и метрики, Рекомендации, Следующие шаги."\
+ "\nТон деловой, по-русски. Без медицинских диагнозов и запугивания."\
+ f"\nИмя клиента: {state.client.name}. Электронная почта: {state.client.email}. Формат: {state.client.preferred_format}."\
+ f"\nОтветы клиента:\n{convo}"\
+ "\nВерни только отчёт, без пояснений."
)
def _topic_keys(self, block: BlockKey) -> list[str]:
if block == BlockKey.health:
return ["injuries", "training_history", "technique_limitations"]
if block == BlockKey.goals:
return ["primary_goal", "timeline", "success_metrics"]
if block == BlockKey.readiness:
return ["weekly_frequency", "equipment_location", "format_preference"]
return []
def _topics_prompt(self, block: BlockKey, text: str, topics: list[str]) -> str:
block_labels = {
BlockKey.health: "здоровье и опыт",
BlockKey.goals: "цели",
BlockKey.readiness: "готовность и формат",
}
topics_list = ", ".join(topics)
return (
"Ты помощник тренера. Определи, какие темы уже упомянуты в ответах клиента."\
+ "\nОтветь строго JSON-объектом с ключами из списка, значения true/false."\
+ f"\nБлок: {block_labels.get(block, block.value)}."\
+ f"\nТемы: {topics_list}."\
+ f"\nОтветы клиента: {text}"\
+ "\nВерни только JSON без пояснений."
)
def _parse_topics(self, raw: str, topics: list[str]) -> dict[str, bool]:
payload = self._extract_json(raw)
if not payload:
return {}
data = {}
for key in topics:
value = payload.get(key)
if isinstance(value, bool):
data[key] = value
return data
@staticmethod
def _extract_json(raw: str) -> dict:
if not raw:
return {}
match = re.search(r"\{.*\}", raw, flags=re.DOTALL)
if not match:
return {}
snippet = match.group(0)
try:
return json.loads(snippet)
except json.JSONDecodeError:
return {}