"""Task A user-modeling service: retrieve similar review snippets from a local
RAG index, then have a local Hugging Face causal LM write a star rating and a
first-person review for the given persona and product."""

from __future__ import annotations

import logging
import os
import re
import time
from pathlib import Path
from typing import Any

from app._paths import submission_root
from app.shared_models import (
    causal_lm_model_id_task_a,
    embedding_model_name_task_a,
    get_causal_lm,
    get_embedder,
    inference_lock,
)
from app.task_a_rag import TaskAReviewRagIndex
from app.user_modeling_prompt import build_prompt_parts_with_rag

logger = logging.getLogger(__name__)


def parse_model_output(raw: str) -> tuple[int | None, str]:
    """Pull the star rating and review body out of the model's raw completion.

    Expects ``Stars: <1-5>`` followed by a ``Review:`` block; returns ``None``
    for the stars and ``""`` for the review when a part cannot be found.
    """
    text = raw.strip()
    stars: int | None = None
    m = re.search(r"Stars:\s*(\d)", text, re.IGNORECASE | re.MULTILINE)
    if m:
        # Clamp to the valid 1-5 range in case the model emits a stray digit.
        stars = max(1, min(5, int(m.group(1))))
    rev_m = re.search(r"Review:\s*\n([\s\S]*)", text, re.IGNORECASE | re.MULTILINE)
    review = rev_m.group(1).strip() if rev_m else ""
    if not review and "Review:" in text:
        # Fallback: "Review:" is present but not followed by a newline-anchored body.
        parts = re.split(r"(?i)^Review:\s*$", text, maxsplit=1, flags=re.MULTILINE)
        if len(parts) > 1:
            review = parts[-1].strip()
    return stars, review
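
# Quick reference for the output format parse_model_output expects; the values
# below are illustrative only:
#
#     Stars: 4
#     Review:
#     Great coffee and friendly staff; I'd go back.
#
# parse_model_output on that text returns
# (4, "Great coffee and friendly staff; I'd go back.").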


def _resolve_path(raw: str) -> Path:
    """Absolute paths pass through; relative paths are anchored at the submission root."""
    p = Path(raw)
    if p.is_absolute():
        return p
    return submission_root() / p


class UserModelingService:
    def __init__(self) -> None:
        # Generation and retrieval knobs, all overridable via environment variables.
        self._max_tokens = int(os.environ.get("TASK_A_MAX_TOKENS", "1024"))
        self._temperature = float(os.environ.get("TASK_A_TEMPERATURE", "0.35"))
        self._local_llm_model_id = causal_lm_model_id_task_a()
        self._embedding_model_name = embedding_model_name_task_a()
        rag_raw = os.environ.get(
            "TASK_A_REVIEWS_EMBEDDED",
            "data/task_a_reviews_embedded.jsonl",
        )
        self._rag_path = _resolve_path(rag_raw)
        self._rag_top_k = int(os.environ.get("TASK_A_RAG_TOP_K", "5"))
        # Built lazily on first use (see _rag).
        self._rag_index: TaskAReviewRagIndex | None = None

    def _ensure_query_embedder(self) -> Any:
        return get_embedder(self._embedding_model_name)

    def _rag(self) -> TaskAReviewRagIndex:
        if self._rag_index is None:
            self._rag_index = TaskAReviewRagIndex(self._rag_path)
        return self._rag_index

    def warm(self) -> None:
        # Preload the RAG index so the first request does not pay the load cost.
        if self._rag_path.is_file():
            self._rag().load()

    def _ensure_local_llm(self) -> tuple[Any, Any, Any]:
        return get_causal_lm(self._local_llm_model_id)

    def _retrieve_examples(self, persona: str, product: str) -> list[dict[str, Any]]:
        # Without an embedded-reviews file, generation simply runs RAG-free.
        if not self._rag_path.is_file():
            return []
        rag = self._rag()
        embedder = self._ensure_query_embedder()
        return rag.retrieve(persona, product, embedder, self._rag_top_k)

    def _generate_local(self, persona: str, product: str, examples: list[dict[str, Any]]) -> str:
        tok, mdl, device = self._ensure_local_llm()
        inst, user_body = build_prompt_parts_with_rag(persona, product, examples)
        messages = [
            {"role": "system", "content": inst},
            {"role": "user", "content": user_body},
        ]
        prompt_txt = tok.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        try:
            import torch
        except ImportError as e:
            raise RuntimeError("Task 1 needs torch.") from e

        inputs = tok(prompt_txt, return_tensors="pt").to(device)
        if tok.pad_token_id is None:
            tok.pad_token_id = tok.eos_token_id

        # Cap new tokens so one runaway generation cannot blow the latency budget.
        max_new = min(int(self._max_tokens), 768)
        with inference_lock(), torch.no_grad():
            out = mdl.generate(
                **inputs,
                max_new_tokens=max_new,
                do_sample=True,
                temperature=self._temperature,
                top_p=0.9,
                pad_token_id=tok.pad_token_id,
            )
        # Decode only the newly generated tokens, not the echoed prompt.
        gen_ids = out[0][inputs["input_ids"].shape[1] :]
        return tok.decode(gen_ids, skip_special_tokens=True).strip()

    def _generate_local_fix(
        self, persona: str, product: str, prior_raw: str, examples: list[dict[str, Any]]
    ) -> str:
        """Second pass: replay the bad answer and ask the model to fix its format."""
        tok, mdl, device = self._ensure_local_llm()
        inst, user_body = build_prompt_parts_with_rag(persona, product, examples)
        fix_user = (
            "Your answer must follow exactly:\nStars: <1-5>\nReview:\n<text>\n\n"
            "The Review must be first person (I/my/me), as the user who visited — not third person. Fix strictly."
        )
        messages = [
            {"role": "system", "content": inst},
            {"role": "user", "content": user_body},
            {"role": "assistant", "content": prior_raw},
            {"role": "user", "content": fix_user},
        ]
        prompt_txt = tok.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        import torch  # Already imported successfully by _generate_local at this point.

        inputs = tok(prompt_txt, return_tensors="pt").to(device)
        if tok.pad_token_id is None:
            tok.pad_token_id = tok.eos_token_id
        max_new = min(int(self._max_tokens), 768)
        # Greedy decoding here: the retry is about format compliance, not variety.
        with inference_lock(), torch.no_grad():
            out = mdl.generate(
                **inputs,
                max_new_tokens=max_new,
                do_sample=False,
                pad_token_id=tok.pad_token_id,
            )
        gen_ids = out[0][inputs["input_ids"].shape[1] :]
        return tok.decode(gen_ids, skip_special_tokens=True).strip()

    def generate(self, persona: str, product: str, *, include_raw: bool = False) -> dict[str, Any]:
        t0 = time.perf_counter()
        examples = self._retrieve_examples(persona, product)
        logger.info("Task A RAG: %d snippets (%.2fs)", len(examples), time.perf_counter() - t0)
        raw = self._generate_local(persona, product, examples)
        logger.info("Task A generation done in %.2fs", time.perf_counter() - t0)
        stars, review = parse_model_output(raw)

        if stars is None or not review:
            # One retry with an explicit format nudge before returning a partial result.
            logger.warning("Parse incomplete — retrying once with format nudge (Task 1).")
            raw = self._generate_local_fix(persona, product, raw, examples)
            stars, review = parse_model_output(raw)

        steps = ["yelp_rag_retrieve", "local_hf_causal_lm", "parse_stars_review"]
        if not examples:
            steps = ["local_hf_causal_lm", "parse_stars_review"]

        out: dict[str, Any] = {
            "task": "1_user_modeling",
            "agent_steps": steps,
            "rag_snippets_used": len(examples),
            "stars": stars,
            "review": review,
            "parse_ok": stars is not None and bool(review),
        }
        if include_raw:
            out["raw"] = raw
        return out
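

if __name__ == "__main__":
    # Minimal smoke-test sketch, assuming the app package, the embedded-reviews
    # file, and the model weights are available locally; the persona/product
    # strings below are illustrative only.
    logging.basicConfig(level=logging.INFO)
    svc = UserModelingService()
    svc.warm()
    result = svc.generate(
        "a budget-conscious student who loves spicy food",
        "a late-night ramen shop",
        include_raw=False,
    )
    print(result["stars"], result["review"])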