from __future__ import annotations

import logging
import os
import re
import time
from pathlib import Path
from typing import Any

from app._paths import submission_root
from app.shared_models import (
    causal_lm_model_id_task_a,
    embedding_model_name_task_a,
    get_causal_lm,
    get_embedder,
    inference_lock,
)
from app.task_a_rag import TaskAReviewRagIndex
from app.user_modeling_prompt import build_prompt_parts_with_rag

logger = logging.getLogger(__name__)


def parse_model_output(raw: str) -> tuple[int | None, str]:
    """Extract (stars, review) from raw model output in the
    "Stars: <n>" / "Review:" format. Stars are clamped to 1-5; if the
    primary regex misses, fall back to splitting on a bare "Review:" line.
    """
    text = raw.strip()
    stars: int | None = None
    m = re.search(r"Stars:\s*(\d)", text, re.IGNORECASE | re.MULTILINE)
    if m:
        stars = max(1, min(5, int(m.group(1))))
    rev_m = re.search(r"Review:\s*\n([\s\S]*)", text, re.IGNORECASE | re.MULTILINE)
    review = rev_m.group(1).strip() if rev_m else ""
    if not review and "Review:" in text:
        parts = re.split(r"(?i)^Review:\s*$", text, maxsplit=1, flags=re.MULTILINE)
        if len(parts) > 1:
            review = parts[-1].strip()
    return stars, review


def _resolve_path(raw: str) -> Path:
    """Resolve a configured path: absolute paths pass through, relative
    paths are anchored at the submission root."""
    p = Path(raw)
    if p.is_absolute():
        return p
    return submission_root() / p


class UserModelingService:
    """Generates persona-conditioned star ratings and reviews with a local
    causal LM, optionally grounded by examples retrieved from an embedded
    review RAG index."""

    def __init__(self) -> None:
        self._max_tokens = int(os.environ.get("TASK_A_MAX_TOKENS", "1024"))
        self._temperature = float(os.environ.get("TASK_A_TEMPERATURE", "0.35"))
        self._local_llm_model_id = causal_lm_model_id_task_a()
        self._embedding_model_name = embedding_model_name_task_a()
        rag_raw = os.environ.get(
            "TASK_A_REVIEWS_EMBEDDED",
            "data/task_a_reviews_embedded.jsonl",
        )
        self._rag_path = _resolve_path(rag_raw)
        self._rag_top_k = int(os.environ.get("TASK_A_RAG_TOP_K", "5"))
        self._rag_index: TaskAReviewRagIndex | None = None

    def _ensure_query_embedder(self) -> Any:
        return get_embedder(self._embedding_model_name)

    def _rag(self) -> TaskAReviewRagIndex:
        # Construct the RAG index lazily so imports stay cheap.
        if self._rag_index is None:
            self._rag_index = TaskAReviewRagIndex(self._rag_path)
        return self._rag_index

    def warm(self) -> None:
        # Preload the index at startup when the embedded-reviews file exists.
        if self._rag_path.is_file():
            self._rag().load()

    def _ensure_local_llm(self) -> tuple[Any, Any, Any]:
        return get_causal_lm(self._local_llm_model_id)

    def _retrieve_examples(self, persona: str, product: str) -> list[dict[str, Any]]:
        if not self._rag_path.is_file():
            return []
        rag = self._rag()
        embedder = self._ensure_query_embedder()
        return rag.retrieve(persona, product, embedder, self._rag_top_k)

    def _generate_local(
        self, persona: str, product: str, examples: list[dict[str, Any]]
    ) -> str:
        tok, mdl, device = self._ensure_local_llm()
        inst, user_body = build_prompt_parts_with_rag(persona, product, examples)
        messages = [
            {"role": "system", "content": inst},
            {"role": "user", "content": user_body},
        ]
        prompt_txt = tok.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        # torch is imported lazily so the module can be imported (and the
        # service constructed) on machines without it.
        try:
            import torch  # type: ignore[import-untyped]
        except ImportError as e:
            raise RuntimeError("Task 1 needs torch.") from e
        inputs = tok(prompt_txt, return_tensors="pt").to(device)
        if tok.pad_token_id is None:
            tok.pad_token_id = tok.eos_token_id
        # Cap new tokens to keep generation latency bounded.
        max_new = min(int(self._max_tokens), 768)
        # inference_lock() serializes access to the shared model.
        with inference_lock(), torch.no_grad():
            out = mdl.generate(
                **inputs,
                max_new_tokens=max_new,
                do_sample=True,
                temperature=self._temperature,
                top_p=0.9,
                pad_token_id=tok.pad_token_id,
            )
        gen_ids = out[0][inputs["input_ids"].shape[1] :]
        return tok.decode(gen_ids, skip_special_tokens=True).strip()

    def _generate_local_fix(
        self,
        persona: str,
        product: str,
        prior_raw: str,
        examples: list[dict[str, Any]],
    ) -> str:
        """Retry pass: replay the conversation with the prior answer and a
        format nudge appended, decoding greedily for a deterministic fix."""
        tok, mdl, device = self._ensure_local_llm()
        inst, user_body = build_prompt_parts_with_rag(persona, product, examples)
        fix_user = (
            "Your answer must follow exactly:\nStars: <1-5>\nReview:\n\n\n"
            "The Review must be first person (I/my/me), as the user who "
            "visited — not third person. Fix strictly."
        )
        messages = [
            {"role": "system", "content": inst},
            {"role": "user", "content": user_body},
            {"role": "assistant", "content": prior_raw},
            {"role": "user", "content": fix_user},
        ]
        prompt_txt = tok.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        import torch  # type: ignore[import-untyped]

        inputs = tok(prompt_txt, return_tensors="pt").to(device)
        if tok.pad_token_id is None:
            tok.pad_token_id = tok.eos_token_id
        max_new = min(int(self._max_tokens), 768)
        with inference_lock(), torch.no_grad():
            out = mdl.generate(
                **inputs,
                max_new_tokens=max_new,
                do_sample=False,
                pad_token_id=tok.pad_token_id,
            )
        gen_ids = out[0][inputs["input_ids"].shape[1] :]
        return tok.decode(gen_ids, skip_special_tokens=True).strip()

    def generate(
        self, persona: str, product: str, *, include_raw: bool = False
    ) -> dict[str, Any]:
        t0 = time.perf_counter()
        examples = self._retrieve_examples(persona, product)
        logger.info(
            "Task A RAG: %d snippets (%.2fs)", len(examples), time.perf_counter() - t0
        )
        raw = self._generate_local(persona, product, examples)
        logger.info("Task A generation done in %.2fs", time.perf_counter() - t0)
        stars, review = parse_model_output(raw)
        if stars is None or not review:
            # One retry with an explicit format nudge before giving up.
            logger.warning("Parse incomplete — retrying once with format nudge (Task 1).")
            raw = self._generate_local_fix(persona, product, raw, examples)
            stars, review = parse_model_output(raw)
        steps = ["yelp_rag_retrieve", "local_hf_causal_lm", "parse_stars_review"]
        if not examples:
            steps = ["local_hf_causal_lm", "parse_stars_review"]
        out: dict[str, Any] = {
            "task": "1_user_modeling",
            "agent_steps": steps,
            "rag_snippets_used": len(examples),
            "stars": stars,
            "review": review,
            "parse_ok": stars is not None and bool(review),
        }
        if include_raw:
            out["raw"] = raw
        return out
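
# Minimal usage sketch (not part of the service; the persona/product strings
# below are illustrative, and it assumes the model weights plus the optional
# embedded-reviews JSONL are available where shared_models expects them):
#
#     service = UserModelingService()
#     service.warm()  # optional: preload the RAG index if the file exists
#     result = service.generate(
#         "budget-conscious vegetarian who dislikes long waits",
#         "a new downtown ramen shop",
#         include_raw=True,
#     )
#     print(result["stars"], result["parse_ok"])
#     print(result["review"])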