File size: 5,261 Bytes
73c205a 1243127 1c181b2 73c205a 1c181b2 652302c 0bf3001 1c181b2 73c205a 1c181b2 652302c d47b370 652302c 1c181b2 73c205a 0bf3001 1c181b2 73c205a 1c181b2 0bf3001 1c181b2 10bc91f 1c181b2 652302c d47b370 652302c d47b370 73c205a d47b370 1c181b2 1243127 1c181b2 1243127 652302c 1243127 73c205a 1c181b2 652302c 73c205a 652302c 1c181b2 652302c 1c181b2 73c205a 1c181b2 73c205a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 | from __future__ import annotations
import logging
import os
import re
import time
from pathlib import Path
from typing import Any
from app._paths import submission_root
from app.gemini_client import gemini_generate_chat, gemini_generate_text, use_gemini
from app.shared_models import (
embedding_model_name_task_a,
get_embedder,
)
from app.task_a_rag import TaskAReviewRagIndex
from app.user_modeling_prompt import build_prompt_parts_with_rag
logger = logging.getLogger(__name__)
def parse_model_output(raw: str) -> tuple[int | None, str]:
    """Extract the star rating and review text from a raw model reply.

    The expected reply layout is::

        Stars: <1-5>
        Review:
        <text>

    Labels are matched case-insensitively. Review text that starts on the
    same line as the label (``Review: I loved it``) is accepted as well, so
    such replies no longer parse to an empty review.

    Args:
        raw: Unprocessed text returned by the model.

    Returns:
        A ``(stars, review)`` tuple. ``stars`` is the first digit following
        ``Stars:`` clamped into the range 1..5, or ``None`` when no such
        label/digit pair is present. ``review`` is the stripped text that
        follows ``Review:``, or ``""`` when there is none.
    """
    text = raw.strip()

    stars: int | None = None
    stars_match = re.search(r"Stars:\s*(\d)", text, re.IGNORECASE)
    if stars_match:
        # Only a single digit is captured; clamp it so 0 and 6-9 stay in range.
        stars = max(1, min(5, int(stars_match.group(1))))

    # Preferred layout: the label ends its line and the review follows below.
    # Trying this first keeps results identical for replies that already parsed.
    review_match = re.search(r"Review:\s*\n([\s\S]*)", text, re.IGNORECASE)
    if review_match is None:
        # Fallback: the review text begins on the label's own line.
        review_match = re.search(r"Review:\s*([\s\S]*)", text, re.IGNORECASE)

    review = review_match.group(1).strip() if review_match else ""
    return stars, review
def _resolve_path(raw: str) -> Path:
    """Turn a configured path string into a ``Path``.

    Absolute paths are returned unchanged; anything else is joined onto the
    directory returned by ``submission_root()``.
    """
    candidate = Path(raw)
    return candidate if candidate.is_absolute() else submission_root() / candidate
def _task_a_gen_step() -> str:
    """Return the step label recorded for the generation stage."""
    step_label = "gemini_generate"
    return step_label
class UserModelingService:
    """User-modeling pipeline: retrieve examples, call Gemini, parse the reply.

    Settings are read from environment variables when an instance is built:

    * ``TASK_A_MAX_TOKENS`` (default ``"1024"``)
    * ``TASK_A_TEMPERATURE`` (default ``"0.35"``)
    * ``TASK_A_REVIEWS_EMBEDDED`` (default
      ``"data/task_a_reviews_embedded.jsonl"``)
    * ``TASK_A_RAG_TOP_K`` (default ``"5"``)
    """

    def __init__(self) -> None:
        """Load configuration; the RAG index object is created on first use."""
        env = os.environ
        self._max_tokens = int(env.get("TASK_A_MAX_TOKENS", "1024"))
        self._temperature = float(env.get("TASK_A_TEMPERATURE", "0.35"))
        self._embedding_model_name = embedding_model_name_task_a()
        configured_rag_path = env.get(
            "TASK_A_REVIEWS_EMBEDDED",
            "data/task_a_reviews_embedded.jsonl",
        )
        self._rag_path = _resolve_path(configured_rag_path)
        self._rag_top_k = int(env.get("TASK_A_RAG_TOP_K", "5"))
        self._rag_index: TaskAReviewRagIndex | None = None

    def _ensure_query_embedder(self) -> Any:
        """Return the embedder for the configured Task A embedding model."""
        model_name = self._embedding_model_name
        return get_embedder(model_name)

    def _rag(self) -> TaskAReviewRagIndex:
        """Return the RAG index, constructing it on the first call."""
        index = self._rag_index
        if index is None:
            index = TaskAReviewRagIndex(self._rag_path)
            self._rag_index = index
        return index

    def warm(self) -> None:
        """Call ``load()`` on the RAG index when the embedded-reviews file exists."""
        if not self._rag_path.is_file():
            return
        self._rag().load()

    def _retrieve_examples(self, persona: str, product: str) -> list[dict[str, Any]]:
        """Retrieve up to ``TASK_A_RAG_TOP_K`` examples for the persona/product.

        Returns an empty list when the embedded-reviews file is not present.
        """
        if self._rag_path.is_file():
            index = self._rag()
            query_embedder = self._ensure_query_embedder()
            return index.retrieve(persona, product, query_embedder, self._rag_top_k)
        return []

    def _generate(self, persona: str, product: str, examples: list[dict[str, Any]]) -> str:
        """Build the prompt and request a single completion from Gemini.

        Raises:
            RuntimeError: If ``use_gemini()`` reports Gemini is not enabled.
        """
        system_instruction, user_text = build_prompt_parts_with_rag(persona, product, examples)
        if not use_gemini():
            raise RuntimeError("Task 1 requires Gemini for generation.")
        # The output budget never exceeds 1024 tokens, whatever is configured.
        token_budget = min(int(self._max_tokens), 1024)
        return gemini_generate_text(
            system_instruction=system_instruction,
            user_text=user_text,
            temperature=self._temperature,
            max_output_tokens=token_budget,
        )

    def _generate_fix(
        self, persona: str, product: str, prior_raw: str, examples: list[dict[str, Any]]
    ) -> str:
        """Ask Gemini to redo ``prior_raw`` in the required Stars/Review format.

        The earlier reply is replayed as the assistant turn, followed by a
        corrective user message. A fixed temperature of 0.2 is used here
        instead of the configured one.

        Raises:
            RuntimeError: If ``use_gemini()`` reports Gemini is not enabled.
        """
        system_instruction, user_text = build_prompt_parts_with_rag(persona, product, examples)
        correction_request = (
            "Your answer must follow exactly:\nStars: <1-5>\nReview:\n<text>\n\n"
            "The Review must be first person (I/my/me), as the user who visited — not third person. Fix strictly."
        )
        if not use_gemini():
            raise RuntimeError("Task 1 requires Gemini for generation.")
        conversation = [
            {"role": "system", "content": system_instruction},
            {"role": "user", "content": user_text},
            {"role": "assistant", "content": prior_raw},
            {"role": "user", "content": correction_request},
        ]
        token_budget = min(int(self._max_tokens), 1024)
        return gemini_generate_chat(
            conversation,
            temperature=0.2,
            max_output_tokens=token_budget,
        )

    def generate(self, persona: str, product: str, *, include_raw: bool = False) -> dict[str, Any]:
        """Run retrieval, generation and parsing for one persona/product pair.

        When the first reply lacks either the star rating or the review text,
        one corrective retry is made and its reply is parsed instead.

        Args:
            persona: Persona text passed to retrieval and the prompt builder.
            product: Product text passed to retrieval and the prompt builder.
            include_raw: When true, the last raw model reply is added under
                the ``"raw"`` key.

        Returns:
            A dict with the keys ``task``, ``agent_steps``,
            ``rag_snippets_used``, ``stars``, ``review`` and ``parse_ok``
            (plus ``raw`` when requested).
        """
        started = time.perf_counter()
        examples = self._retrieve_examples(persona, product)
        logger.info("Task A RAG: %d snippets (%.2fs)", len(examples), time.perf_counter() - started)

        gen_step = _task_a_gen_step()
        raw = self._generate(persona, product, examples)
        # Elapsed time here is counted from the start of retrieval.
        logger.info("Task A generation done in %.2fs", time.perf_counter() - started)

        stars, review = parse_model_output(raw)
        if not (stars is not None and review):
            logger.warning("Parse incomplete — retrying once with format nudge (Task 1).")
            raw = self._generate_fix(persona, product, raw, examples)
            stars, review = parse_model_output(raw)

        # The retrieval step is reported only when it produced snippets.
        steps = [gen_step, "parse_stars_review"]
        if examples:
            steps.insert(0, "yelp_rag_retrieve")

        result: dict[str, Any] = {
            "task": "1_user_modeling",
            "agent_steps": steps,
            "rag_snippets_used": len(examples),
            "stars": stars,
            "review": review,
            "parse_ok": stars is not None and bool(review),
        }
        if include_raw:
            result["raw"] = raw
        return result
|