from __future__ import annotations

import json
import logging
import math
import re
from collections import defaultdict
from pathlib import Path
from typing import Any

import numpy as np

logger = logging.getLogger(__name__)

# Patterns are tried in order; the first one that matches wins.
_USER_ID_PATTERNS: tuple[re.Pattern[str], ...] = tuple(
    re.compile(pattern)
    for pattern in (
        r"(?im)^user_id:\s*(\S+)",
        r"(?im)^User ID:\s*(\S+)",
        r"(?i)\buser_id\s*=\s*(\S+)",
    )
)


def extract_user_id(persona: str) -> str | None:
    """Return the user id embedded in *persona*, or ``None`` if none is found.

    Three spellings are recognised, checked in this order: a line starting
    with ``user_id:``, a line starting with ``User ID:``, and an inline
    ``user_id = value`` assignment. Matching is case-insensitive and the id
    is the first run of non-whitespace characters after the marker.
    """
    for pattern in _USER_ID_PATTERNS:
        match = pattern.search(persona)
        if match:
            return match.group(1).strip()
    return None


class TaskAReviewRagIndex:
    """In-memory cosine-similarity index over review snippets in a JSONL file.

    Each non-blank line of the file is parsed as a JSON object. Objects with
    a truthy ``"embedding"`` value are indexed; the remaining keys are kept
    as the snippet's metadata. ``retrieve`` reads the ``"user_id"``,
    ``"stars"``, ``"review_excerpt"`` and ``"business_context"`` keys.
    """

    def __init__(self, path: Path) -> None:
        """Remember *path*; nothing is read until :meth:`load` is called."""
        self.path = path
        # Snippet metadata (row minus its embedding), aligned with _mat rows.
        self._rows: list[dict[str, Any]] = []
        # L2-normalised embedding matrix, or None when nothing was loaded.
        self._mat: np.ndarray | None = None
        # str(user_id) -> indices into _rows / _mat.
        self._by_user: dict[str, list[int]] = defaultdict(list)
        self._loaded = False

    @property
    def loaded(self) -> bool:
        """Whether :meth:`load` has completed (even if it found no data)."""
        return self._loaded

    @staticmethod
    def _coerce_embedding(raw: Any) -> np.ndarray | None:
        """Return *raw* as a finite, non-empty 1-D float32 vector, else ``None``."""
        try:
            vec = np.asarray(raw, dtype=np.float32)
        except (TypeError, ValueError):
            return None
        if vec.ndim != 1 or vec.size == 0 or not np.all(np.isfinite(vec)):
            return None
        return vec

    def load(self) -> None:
        """Read the JSONL file and build the normalised embedding matrix.

        The call is idempotent: once loading has completed, later calls
        return immediately. A missing file, or a file with no usable rows,
        leaves the index empty and only logs a warning.

        Lines that are not valid JSON objects, or whose embedding is not a
        finite numeric vector of the same length as the first accepted
        embedding, are skipped and counted in a warning. Rows with no
        embedding at all are skipped silently.

        State is committed only after the whole file has been read, so an
        I/O error part-way through leaves the index untouched and a retry
        starts from a clean slate.
        """
        if self._loaded:
            return

        if not self.path.is_file():
            logger.warning(
                "Task A RAG file missing — run scripts/build_task_a_review_rag.py or Docker build (%s)",
                self.path,
            )
            self._rows = []
            self._mat = None
            self._loaded = True
            return

        rows: list[dict[str, Any]] = []
        vectors: list[np.ndarray] = []
        by_user: dict[str, list[int]] = defaultdict(list)
        dim: int | None = None
        skipped = 0

        with self.path.open(encoding="utf-8", errors="replace") as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    continue
                try:
                    row = json.loads(line)
                except json.JSONDecodeError:
                    skipped += 1
                    continue
                if not isinstance(row, dict):
                    skipped += 1
                    continue

                emb = row.get("embedding")
                if not emb:
                    continue
                vec = self._coerce_embedding(emb)
                if vec is None:
                    skipped += 1
                    continue
                # The first accepted embedding fixes the index dimension.
                if dim is None:
                    dim = vec.shape[0]
                elif vec.shape[0] != dim:
                    skipped += 1
                    continue

                uid = row.get("user_id") or ""
                idx = len(rows)
                rows.append({k: v for k, v in row.items() if k != "embedding"})
                vectors.append(vec)
                if uid:
                    by_user[str(uid)].append(idx)

        if skipped:
            logger.warning(
                "Task A RAG skipped %d malformed line(s) in %s", skipped, self.path
            )

        self._rows = rows
        self._by_user = by_user

        if not vectors:
            logger.warning("Task A RAG file empty or invalid: %s", self.path)
            self._mat = None
            self._loaded = True
            return

        mat = np.stack(vectors).astype(np.float32, copy=False)
        norms = np.linalg.norm(mat, axis=1, keepdims=True)
        # Leave all-zero rows as zeros instead of dividing by zero.
        norms[norms == 0] = 1.0
        self._mat = mat / norms
        self._loaded = True
        logger.info(
            "Task A RAG loaded %d snippets from %s", len(self._rows), self.path
        )

    def _encode_query(self, embedder: Any, text: str) -> np.ndarray:
        """Embed *text* with *embedder* and return a unit-length float32 vector.

        Newlines are replaced by spaces and the text is cut to 8000
        characters before encoding. *embedder* is called as
        ``embedder.encode([text], convert_to_numpy=True,
        normalize_embeddings=False)`` and the first result is used. An
        all-zero embedding is replaced by a uniform unit vector so the
        result always has norm 1.
        """
        t = text.replace("\n", " ")[:8000]
        vec = embedder.encode([t], convert_to_numpy=True, normalize_embeddings=False)[0]
        q = np.asarray(vec, dtype=np.float32)
        nq = np.linalg.norm(q)
        if nq == 0:
            return np.ones_like(q) / math.sqrt(len(q))
        return q / nq

    def retrieve(
        self,
        persona: str,
        product: str,
        embedder: Any,
        top_k: int,
        user_examples_cap: int = 3,
    ) -> list[dict[str, Any]]:
        """Return up to ``max(1, top_k)`` snippets most similar to the query.

        The query is the embedding of ``persona`` and ``product`` (each
        stripped) joined by a newline. If *persona* contains a user id that
        has indexed snippets, up to ``user_examples_cap`` of that user's
        best-scoring snippets come first. The remaining slots are filled
        with the globally best-scoring snippets, without duplicates.

        Args:
            persona: Persona text; may contain a user id marker.
            product: Product / business description text.
            embedder: Object with the ``encode`` method used by
                :meth:`_encode_query`.
            top_k: Number of snippets wanted; values below 1 are treated as 1.
            user_examples_cap: Maximum number of same-user snippets to put
                first; values below 0 are treated as 0.

        Returns:
            A list of dicts with the keys ``stars``, ``review_excerpt``,
            ``business_context`` and ``same_user_as_persona``. The list is
            empty when the index holds no data.

        Raises:
            ValueError: If the query embedding's dimension differs from the
                index's.
        """
        self.load()
        if self._mat is None or not self._rows:
            return []

        k = max(1, top_k)
        uid = extract_user_id(persona)
        q = self._encode_query(embedder, f"{persona.strip()}\n{product.strip()}")
        if q.ndim != 1 or q.shape[0] != self._mat.shape[1]:
            raise ValueError(
                f"query embedding shape {q.shape} does not match index "
                f"dimension {self._mat.shape[1]}"
            )

        scores = self._mat @ q
        # A stable sort makes the order of equal-scoring snippets deterministic.
        order = np.argsort(-scores, kind="stable")

        picked_idx: list[int] = []
        seen: set[int] = set()

        if uid and uid in self._by_user:
            # Clamp at 0: a negative cap would act as a slice from the end.
            cap = max(0, min(user_examples_cap, k))
            ranked_u = sorted(
                self._by_user[uid], key=lambda i: float(scores[i]), reverse=True
            )
            for i in ranked_u[:cap]:
                picked_idx.append(i)
                seen.add(i)

        for i in order.tolist():
            if len(picked_idx) >= k:
                break
            if i in seen:
                continue
            picked_idx.append(i)
            seen.add(i)

        out: list[dict[str, Any]] = []
        for i in picked_idx[:k]:
            r = self._rows[i]
            row_uid = r.get("user_id")
            out.append(
                {
                    "stars": r.get("stars"),
                    "review_excerpt": r.get("review_excerpt", ""),
                    "business_context": r.get("business_context", ""),
                    # Compare as strings, matching how _by_user is keyed.
                    "same_user_as_persona": bool(uid)
                    and row_uid is not None
                    and str(row_uid) == uid,
                }
            )
        return out