# DSN / app/user_modeling.py
# Author: nexusbert — "Enhance user modeling and recommendation services
# with async support and logging" (commit 1243127)
from __future__ import annotations
import logging
import os
import re
import time
from pathlib import Path
from typing import Any
from app._paths import submission_root
from app.shared_models import (
causal_lm_model_id_task_a,
embedding_model_name_task_a,
get_causal_lm,
get_embedder,
inference_lock,
)
from app.task_a_rag import TaskAReviewRagIndex
from app.user_modeling_prompt import build_prompt_parts_with_rag
logger = logging.getLogger(__name__)
def parse_model_output(raw: str) -> tuple[int | None, str]:
    """Parse an LLM completion of the form ``Stars: <n>`` / ``Review:`` + text.

    Returns ``(stars, review)``: *stars* is clamped into 1-5, or ``None`` when
    no ``Stars:`` line is found; *review* is the stripped review text (``""``
    when absent). The review is accepted either on the line(s) following the
    ``Review:`` marker or inline on the same line as the marker.
    """
    text = raw.strip()
    stars: int | None = None
    m = re.search(r"Stars:\s*(\d)", text, re.IGNORECASE)
    if m:
        # Clamp out-of-range digits (e.g. "Stars: 9") into the valid 1-5 band.
        stars = max(1, min(5, int(m.group(1))))
    # Capture everything after the "Review:" marker. Allowing an optional
    # newline means inline reviews ("Review: Loved it") are captured too —
    # the previous newline-only regex returned "" for that shape.
    rev_m = re.search(r"Review:[ \t]*\n?([\s\S]*)", text, re.IGNORECASE)
    review = rev_m.group(1).strip() if rev_m else ""
    return stars, review
def _resolve_path(raw: str) -> Path:
p = Path(raw)
if p.is_absolute():
return p
return submission_root() / p
class UserModelingService:
    """Task 1 (user modeling): generate a first-person star rating + review.

    Pipeline: retrieve example reviews from a pre-embedded RAG index (when
    the index file exists), build a chat prompt with the persona/product and
    retrieved examples, generate with a local HF causal LM, and parse the
    ``Stars:`` / ``Review:`` structure out of the raw completion. A single
    strict-format retry is issued when parsing fails.
    """

    def __init__(self) -> None:
        # Generation knobs, overridable via environment variables.
        self._max_tokens = int(os.environ.get("TASK_A_MAX_TOKENS", "1024"))
        self._temperature = float(os.environ.get("TASK_A_TEMPERATURE", "0.35"))
        self._local_llm_model_id = causal_lm_model_id_task_a()
        self._embedding_model_name = embedding_model_name_task_a()
        rag_raw = os.environ.get(
            "TASK_A_REVIEWS_EMBEDDED",
            "data/task_a_reviews_embedded.jsonl",
        )
        self._rag_path = _resolve_path(rag_raw)
        self._rag_top_k = int(os.environ.get("TASK_A_RAG_TOP_K", "5"))
        # Built lazily on first use; stays None until then.
        self._rag_index: TaskAReviewRagIndex | None = None

    def _ensure_query_embedder(self) -> Any:
        """Return the shared embedder used to embed RAG queries."""
        return get_embedder(self._embedding_model_name)

    def _rag(self) -> TaskAReviewRagIndex:
        """Lazily construct (and cache) the RAG index for the configured path."""
        if self._rag_index is None:
            self._rag_index = TaskAReviewRagIndex(self._rag_path)
        return self._rag_index

    def warm(self) -> None:
        """Pre-load the RAG index; no-op when the index file is absent."""
        if self._rag_path.is_file():
            self._rag().load()

    def _ensure_local_llm(self) -> tuple[Any, Any, Any]:
        """Return ``(tokenizer, model, device)`` for the configured causal LM."""
        return get_causal_lm(self._local_llm_model_id)

    def _retrieve_examples(self, persona: str, product: str) -> list[dict[str, Any]]:
        """Retrieve up to ``self._rag_top_k`` example reviews.

        Returns an empty list when the index file does not exist, so the
        service degrades gracefully to pure generation.
        """
        if not self._rag_path.is_file():
            return []
        rag = self._rag()
        embedder = self._ensure_query_embedder()
        return rag.retrieve(persona, product, embedder, self._rag_top_k)

    def _chat_generate(self, messages: list[dict[str, str]], *, do_sample: bool) -> str:
        """Run one chat-template generation pass and return the decoded new text.

        Shared plumbing for the first (sampled) attempt and the greedy
        format-fix retry; previously this logic was duplicated in both
        generation methods with inconsistent torch-import error handling.

        Raises:
            RuntimeError: when torch is not installed.
        """
        try:
            import torch  # type: ignore[import-untyped]
        except ImportError as e:
            raise RuntimeError("Task 1 needs torch.") from e
        tok, mdl, device = self._ensure_local_llm()
        prompt_txt = tok.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        inputs = tok(prompt_txt, return_tensors="pt").to(device)
        if tok.pad_token_id is None:
            # Some checkpoints ship without a pad token; fall back to EOS.
            tok.pad_token_id = tok.eos_token_id
        # Hard cap to bound latency regardless of the env-configured budget.
        max_new = min(int(self._max_tokens), 768)
        sampling_kwargs: dict[str, Any] = (
            {"do_sample": True, "temperature": self._temperature, "top_p": 0.9}
            if do_sample
            else {"do_sample": False}
        )
        with inference_lock(), torch.no_grad():
            out = mdl.generate(
                **inputs,
                max_new_tokens=max_new,
                pad_token_id=tok.pad_token_id,
                **sampling_kwargs,
            )
        # Decode only the newly generated continuation (drop prompt tokens).
        gen_ids = out[0][inputs["input_ids"].shape[1] :]
        return tok.decode(gen_ids, skip_special_tokens=True).strip()

    def _generate_local(self, persona: str, product: str, examples: list[dict[str, Any]]) -> str:
        """First attempt: sampled decoding over the RAG-augmented prompt."""
        inst, user_body = build_prompt_parts_with_rag(persona, product, examples)
        messages = [
            {"role": "system", "content": inst},
            {"role": "user", "content": user_body},
        ]
        return self._chat_generate(messages, do_sample=True)

    def _generate_local_fix(self, persona: str, product: str, prior_raw: str, examples: list[dict[str, Any]]) -> str:
        """Retry pass: replay the prior answer and demand strict format, greedy decoding."""
        inst, user_body = build_prompt_parts_with_rag(persona, product, examples)
        fix_user = (
            "Your answer must follow exactly:\nStars: <1-5>\nReview:\n<text>\n\n"
            "The Review must be first person (I/my/me), as the user who visited — not third person. Fix strictly."
        )
        messages = [
            {"role": "system", "content": inst},
            {"role": "user", "content": user_body},
            {"role": "assistant", "content": prior_raw},
            {"role": "user", "content": fix_user},
        ]
        return self._chat_generate(messages, do_sample=False)

    def generate(self, persona: str, product: str, *, include_raw: bool = False) -> dict[str, Any]:
        """Produce the Task 1 result dict for a persona/product pair.

        Returns a dict with ``task``, ``agent_steps``, ``rag_snippets_used``,
        ``stars``, ``review`` and ``parse_ok`` keys; ``raw`` is included only
        when *include_raw* is True. Retries generation once with a
        strict-format nudge when the first completion fails to parse.
        """
        t0 = time.perf_counter()
        examples = self._retrieve_examples(persona, product)
        logger.info("Task A RAG: %d snippets (%.2fs)", len(examples), time.perf_counter() - t0)
        raw = self._generate_local(persona, product, examples)
        logger.info("Task A generation done in %.2fs", time.perf_counter() - t0)
        stars, review = parse_model_output(raw)
        if stars is None or not review:
            logger.warning("Parse incomplete — retrying once with format nudge (Task 1).")
            raw = self._generate_local_fix(persona, product, raw, examples)
            stars, review = parse_model_output(raw)
        steps = ["yelp_rag_retrieve", "local_hf_causal_lm", "parse_stars_review"]
        if not examples:
            steps = ["local_hf_causal_lm", "parse_stars_review"]
        out: dict[str, Any] = {
            "task": "1_user_modeling",
            "agent_steps": steps,
            "rag_snippets_used": len(examples),
            "stars": stars,
            "review": review,
            "parse_ok": stars is not None and bool(review),
        }
        if include_raw:
            out["raw"] = raw
        return out