File size: 5,261 Bytes
73c205a
 
 
 
 
1243127
1c181b2
73c205a
 
1c181b2
652302c
0bf3001
 
 
 
1c181b2
 
73c205a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c181b2
 
 
 
 
 
 
652302c
d47b370
652302c
 
1c181b2
73c205a
 
 
0bf3001
1c181b2
 
 
 
 
 
 
73c205a
1c181b2
0bf3001
1c181b2
 
 
 
 
 
10bc91f
 
 
 
1c181b2
 
 
 
 
 
 
652302c
 
d47b370
 
 
 
 
 
 
 
652302c
 
 
 
 
 
 
 
 
d47b370
 
 
 
 
 
 
 
 
 
 
73c205a
d47b370
1c181b2
 
1243127
1c181b2
1243127
652302c
 
1243127
73c205a
 
 
1c181b2
652302c
73c205a
 
652302c
1c181b2
652302c
1c181b2
73c205a
 
1c181b2
 
73c205a
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from __future__ import annotations

import logging
import os
import re
import time
from pathlib import Path
from typing import Any

from app._paths import submission_root
from app.gemini_client import gemini_generate_chat, gemini_generate_text, use_gemini
from app.shared_models import (
    embedding_model_name_task_a,
    get_embedder,
)
from app.task_a_rag import TaskAReviewRagIndex
from app.user_modeling_prompt import build_prompt_parts_with_rag

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


def parse_model_output(raw: str) -> tuple[int | None, str]:
    """Extract the star rating and review text from a raw model reply.

    The expected reply shape is::

        Stars: <1-5>
        Review:
        <text>

    Labels are matched case-insensitively.  A review that starts on the same
    line as its ``Review:`` label is accepted as well, since models frequently
    emit that variant.

    Args:
        raw: The unprocessed text returned by the model.

    Returns:
        A ``(stars, review)`` tuple.  ``stars`` is ``None`` when no
        ``Stars:`` label followed by a number is found; otherwise it is the
        parsed number clamped into ``1..5``.  ``review`` is the stripped
        review text, or ``""`` when none could be located.
    """
    text = raw.strip()

    stars: int | None = None
    # Capture the whole number so that e.g. "10" clamps to 5 rather than being
    # read as its first digit ("1").
    stars_match = re.search(r"Stars:\s*(\d+)", text, re.IGNORECASE)
    if stars_match:
        stars = max(1, min(5, int(stars_match.group(1))))

    review = ""
    # Preferred layout: the label on its own line, the text on following lines.
    block_match = re.search(r"Review:\s*\n([\s\S]*)", text, re.IGNORECASE)
    if block_match:
        review = block_match.group(1).strip()

    if not review:
        # Fallback layout: "Review: <text>" with the text starting on the
        # label's own line.  Anchored at a line start so that words merely
        # ending in "review:" are not mistaken for the label.
        inline_match = re.search(
            r"^[ \t]*Review:[ \t]*(\S[\s\S]*)",
            text,
            re.IGNORECASE | re.MULTILINE,
        )
        if inline_match:
            review = inline_match.group(1).strip()

    return stars, review


def _resolve_path(raw: str) -> Path:
    p = Path(raw)
    if p.is_absolute():
        return p
    return submission_root() / p


def _task_a_gen_step() -> str:
    return "gemini_generate"


class UserModelingService:
    """Produce a star rating and review for a persona/product pair via Gemini.

    The flow implemented by :meth:`generate` is: optionally retrieve example
    reviews from an embedded-reviews file (only when that file exists), ask
    Gemini for a reply, parse ``Stars:`` / ``Review:`` out of it, and retry
    once with a format nudge if parsing came back incomplete.

    Settings are read from the environment when the instance is created:
    ``TASK_A_MAX_TOKENS`` (default ``1024``), ``TASK_A_TEMPERATURE`` (default
    ``0.35``), ``TASK_A_REVIEWS_EMBEDDED`` (default
    ``data/task_a_reviews_embedded.jsonl``) and ``TASK_A_RAG_TOP_K`` (default
    ``5``).
    """

    def __init__(self) -> None:
        """Load generation and retrieval settings from environment variables."""
        self._max_tokens = int(os.getenv("TASK_A_MAX_TOKENS", "1024"))
        self._temperature = float(os.getenv("TASK_A_TEMPERATURE", "0.35"))
        self._embedding_model_name = embedding_model_name_task_a()
        embedded_reviews = os.getenv(
            "TASK_A_REVIEWS_EMBEDDED", "data/task_a_reviews_embedded.jsonl"
        )
        self._rag_path = _resolve_path(embedded_reviews)
        self._rag_top_k = int(os.getenv("TASK_A_RAG_TOP_K", "5"))
        # Created on first use by _rag().
        self._rag_index: TaskAReviewRagIndex | None = None

    def _ensure_query_embedder(self) -> Any:
        """Return the embedder for the configured Task A embedding model."""
        model_name = self._embedding_model_name
        return get_embedder(model_name)

    def _rag(self) -> TaskAReviewRagIndex:
        """Return the retrieval index, constructing it on the first call."""
        index = self._rag_index
        if index is None:
            index = TaskAReviewRagIndex(self._rag_path)
            self._rag_index = index
        return index

    def warm(self) -> None:
        """Load the retrieval index ahead of time if its backing file exists."""
        if not self._rag_path.is_file():
            return
        self._rag().load()

    def _retrieve_examples(self, persona: str, product: str) -> list[dict[str, Any]]:
        """Fetch up to ``_rag_top_k`` example snippets for the given inputs.

        Returns an empty list when the embedded-reviews file is missing.
        """
        if self._rag_path.is_file():
            index = self._rag()
            query_embedder = self._ensure_query_embedder()
            return index.retrieve(persona, product, query_embedder, self._rag_top_k)
        return []

    def _output_token_cap(self) -> int:
        """Return the configured output-token limit, never more than 1024."""
        return min(int(self._max_tokens), 1024)

    def _generate(self, persona: str, product: str, examples: list[dict[str, Any]]) -> str:
        """Run the primary generation call and return the raw model text.

        Raises:
            RuntimeError: If Gemini is not enabled.
        """
        system_instruction, user_text = build_prompt_parts_with_rag(
            persona, product, examples
        )
        if use_gemini():
            return gemini_generate_text(
                system_instruction=system_instruction,
                user_text=user_text,
                temperature=self._temperature,
                max_output_tokens=self._output_token_cap(),
            )
        raise RuntimeError("Task 1 requires Gemini for generation.")

    def _generate_fix(
        self, persona: str, product: str, prior_raw: str, examples: list[dict[str, Any]]
    ) -> str:
        """Ask the model to repair a badly formatted earlier reply.

        The original prompt and the model's previous answer are replayed as a
        chat, followed by a user turn restating the required output format.
        A fixed temperature of 0.2 is used for this call.

        Raises:
            RuntimeError: If Gemini is not enabled.
        """
        system_instruction, user_text = build_prompt_parts_with_rag(
            persona, product, examples
        )
        format_nudge = (
            "Your answer must follow exactly:\nStars: <1-5>\nReview:\n<text>\n\n"
            "The Review must be first person (I/my/me), as the user who visited — not third person. Fix strictly."
        )
        if use_gemini():
            conversation = [
                {"role": "system", "content": system_instruction},
                {"role": "user", "content": user_text},
                {"role": "assistant", "content": prior_raw},
                {"role": "user", "content": format_nudge},
            ]
            return gemini_generate_chat(
                conversation,
                temperature=0.2,
                max_output_tokens=self._output_token_cap(),
            )
        raise RuntimeError("Task 1 requires Gemini for generation.")

    def generate(self, persona: str, product: str, *, include_raw: bool = False) -> dict[str, Any]:
        """Generate and parse a rating plus review for *persona* and *product*.

        Args:
            persona: Persona text passed to retrieval and to the prompt builder.
            product: Product text passed to retrieval and to the prompt builder.
            include_raw: When true, the last raw model reply is added to the
                result under the ``"raw"`` key.

        Returns:
            A dict with the keys ``task``, ``agent_steps``,
            ``rag_snippets_used``, ``stars``, ``review`` and ``parse_ok``
            (plus ``raw`` when requested).  ``parse_ok`` is true only when
            both a star rating and a non-empty review were parsed.
        """
        started = time.perf_counter()
        examples = self._retrieve_examples(persona, product)
        logger.info("Task A RAG: %d snippets (%.2fs)", len(examples), time.perf_counter() - started)
        gen_step = _task_a_gen_step()
        raw = self._generate(persona, product, examples)
        # Elapsed time here is measured from before retrieval started.
        logger.info("Task A generation done in %.2fs", time.perf_counter() - started)
        stars, review = parse_model_output(raw)

        incomplete = stars is None or not review
        if incomplete:
            logger.warning("Parse incomplete — retrying once with format nudge (Task 1).")
            raw = self._generate_fix(persona, product, raw, examples)
            stars, review = parse_model_output(raw)

        # The retrieval step is only reported when snippets were actually used.
        steps = [gen_step, "parse_stars_review"]
        if examples:
            steps.insert(0, "yelp_rag_retrieve")

        result: dict[str, Any] = {
            "task": "1_user_modeling",
            "agent_steps": steps,
            "rag_snippets_used": len(examples),
            "stars": stars,
            "review": review,
            "parse_ok": not (stars is None or not review),
        }
        if include_raw:
            result["raw"] = raw
        return result