"""Task B agent — the Recommender. Given a UserPersona, return ranked items with per-item reasoning. The workflow: 1. build_query(persona, mode) → construct a semantic search query from the user's themes, voice, and vocabulary fingerprint. Different query for cold-start. 2. retrieve_candidates(query, persona, mode) → ChromaDB search for top 50 items → exclude items already in the user's history → if cross_domain_only, filter to domains the user hasn't engaged 3. rerank(persona, candidates) → LLM picks top 10 from the 50 and writes one-sentence reasoning for each. This is where intelligence per feature lives — every recommendation comes with grounded justification. 4. format_response(ranked_items) → return RankedRecommendation objects with item details + reasoning The system handles three explicit modes: - "warm" : user has history → blend behavioral signals + persona description into the query - "cold_start" : no history → use persona description + themes only - "cross_domain": user has history → recommend in domains they haven't engaged with, using their voice as a bridge """ from __future__ import annotations import logging from dataclasses import dataclass, asdict from typing import Literal from pydantic import BaseModel, Field from core.llm import LLMClient from core.persona import UserPersona from core.retrieval import ItemRetriever, RetrievedItem from core.review_aggregator import ReviewAggregator, ReviewSnippet from core.hyde import HydeRetriever from core.reflection import reflect_on_recommendations, ReflectionTrace log = logging.getLogger(__name__) Mode = Literal["warm", "cold_start", "cross_domain"] # ────────────────────────────────────────────────────────────────────────────── # Output schemas # ────────────────────────────────────────────────────────────────────────────── class RerankedItem(BaseModel): """One item in the LLM's reranked output.""" item_id: str = Field(description="The parent_asin of the chosen item") reasoning: str = Field(description="One sentence on why this fits this user, citing 1-2 specific persona signals") class RerankedList(BaseModel): """LLM's structured output: a ranked list of 10 items with reasons.""" items: list[RerankedItem] = Field( description="Top 10 items ranked best-first. Each must reference real item_ids from the candidate list." ) @dataclass class Recommendation: """Final output unit returned to the API caller.""" rank: int item_id: str title: str domain: str categories: str reasoning: str retrieval_distance: float def as_dict(self) -> dict: return asdict(self) # ────────────────────────────────────────────────────────────────────────────── # Workflow steps # ────────────────────────────────────────────────────────────────────────────── def detect_mode(persona: UserPersona, requested_cross_domain: bool) -> Mode: """Pick the appropriate workflow mode based on persona state.""" if requested_cross_domain and persona.n_reviews > 0: return "cross_domain" if persona.n_reviews == 0 or not persona.history_samples: return "cold_start" return "warm" def build_query(persona: UserPersona, mode: Mode) -> str: """Construct a semantic search query for the candidate retrieval step. Strategy: - For warm/cross_domain users, derive the query from their HIGH-RATED history items. Their actual taste is in what they like, not in their average vocabulary. We extract a short snippet from each 4+ star review and combine with themes/voice if available. - For cold_start, lean on themes + voice + stated preferences. - Always include a fallback so the query is never empty. """ themes = ", ".join(persona.preferred_themes) if persona.preferred_themes else "" voice = persona.voice_one_liner or "" if mode in ("warm", "cross_domain") and persona.history_samples: # Take up to 4 of the user's highest-rated past items as taste signal liked = sorted( [s for s in persona.history_samples if s["rating"] >= 4.0], key=lambda s: -s["rating"], )[:4] # If they have very few 4+ stars, fall back to top 3 by rating regardless if not liked: liked = sorted(persona.history_samples, key=lambda s: -s["rating"])[:3] snippets: list[str] = [] for s in liked: text = s["text"].strip() # Take the first ~30 words of each review — captures the gist snippet = " ".join(text.split()[:30]) snippets.append(snippet) liked_signal = " | ".join(snippets) prefix = "Find items similar to ones this reader rated highly." bridge = f" Their voice: {voice}" if voice else "" themes_bit = f" Themes they care about: {themes}." if themes else "" if mode == "cross_domain": return ( f"{prefix} They want to discover ADJACENT content in new domains." f"{bridge}{themes_bit}" f"\nWhat they've liked before: {liked_signal}" ) return ( f"{prefix}{bridge}{themes_bit}" f"\nWhat they've liked before: {liked_signal}" ) # cold_start path (or any mode where history is empty) parts = [] if voice: parts.append(voice) if themes: parts.append(f"Interested in: {themes}.") if persona.common_complaints: parts.append(f"Avoid: {', '.join(persona.common_complaints)}.") if not parts: # Last-resort fallback so we never send an empty query parts.append("popular well-reviewed items") return " ".join(parts) def retrieve_candidates(retriever: ItemRetriever, persona: UserPersona, mode: Mode, k_candidates: int = 50, hyde: HydeRetriever | None = None) -> list[RetrievedItem]: """Pull top-k semantically-similar items, respecting mode constraints. For cold_start mode, if a HydeRetriever is supplied, retrieval goes through HyDE (generate hypothetical items → embed → match real catalog). This sidesteps catalog-pollution that hurts naive cold-start retrieval. If HyDE fails or isn't supplied, falls back to the normal query path. """ # ── Cold-start: try HyDE first ─────────────────────────────────────── if mode == "cold_start" and hyde is not None: # Domains the persona's interests suggest — default to all three allowed = ["Books", "Movies_and_TV", "Kindle_Store"] hyde_candidates = hyde.retrieve(persona, k_candidates=k_candidates, allowed_domains=allowed) if hyde_candidates: log.info(f"Cold-start via HyDE: {len(hyde_candidates)} candidates") return hyde_candidates log.warning("HyDE returned no candidates; falling back to normal retrieval") # ── Normal retrieval path (warm, cross_domain, or HyDE fallback) ───── query = build_query(persona, mode) # Exclude items the user has already engaged with exclude_ids = {s["parent_asin"] for s in persona.history_samples} domains = None if mode == "cross_domain": all_domains = {"Books", "Movies_and_TV", "Kindle_Store"} unknown = sorted(all_domains - set(persona.domains)) if unknown: domains = unknown else: log.info("Cross-domain requested but user has touched all domains; falling back to warm mode") elif mode == "warm" and persona.domains: # Warm mode: retrieve within the domains the user actually engages with. # A single-domain user gets that one domain; a multi-domain user gets # all of theirs (retrieve() balances across them). Discovery in NEW # domains is the job of cross_domain mode, not warm mode. domains = list(persona.domains) log.info(f"Warm mode: restricting retrieval to user's domains {domains}") log.info(f"Retrieving {k_candidates} candidates for mode={mode}, query={query[:120]}") candidates = retriever.retrieve( query=query, k=k_candidates, domains=domains, exclude_ids=exclude_ids, ) log.info(f"Retrieved {len(candidates)} candidates") return candidates def build_rerank_prompt(persona: UserPersona, candidates: list[RetrievedItem], k_final: int, mode: Mode, reviews_by_item: dict[str, list[ReviewSnippet]] | None = None, ) -> str: """Render the rerank prompt — persona + candidate list + instructions. If reviews_by_item is provided, each candidate is enriched with 3-5 real reader reviews. This is the Stage 2c improvement: instead of judging by polluted titles, the LLM sees what humans actually said. """ parts = ["You are a thoughtful recommendation agent. Your job is to pick the best items for this specific user from a candidate list, and explain each pick with reference to the user's signals.\n"] parts.append("=" * 60) parts.append("THE USER") parts.append("=" * 60) parts.append(persona.to_prompt_block()) if persona.history_samples and mode != "cold_start": parts.append("\nRecent things this user engaged with (do NOT recommend these — they've already seen them):") for h in persona.history_samples[:5]: parts.append(f" - [{h['domain']}] {h['rating']}★: {h['text'][:120]}") parts.append("\n" + "=" * 60) parts.append(f"CANDIDATE ITEMS ({len(candidates)} retrieved by semantic search)") if reviews_by_item: parts.append("Each candidate includes a sample of REAL reader reviews — judge each item by what readers said, not by the title alone (many titles in this catalog are review headlines, not real product titles).") parts.append("=" * 60) for c in candidates: line = f"[{c.item_id}] ({c.domain}) {c.title}" if c.categories: line += f" | categories: {c.categories[:80]}" if c.description and len(c.description) > len(c.title) + 10: line += f"\n Description: {c.description[:200]}" # Stage 2c: include real reader reviews if available snippets = (reviews_by_item or {}).get(c.item_id, []) if snippets: line += "\n Reader reviews:" for s in snippets: line += f"\n {s.rating}★: {s.text}" parts.append(line) parts.append("\n" + "=" * 60) parts.append("YOUR TASK") parts.append("=" * 60) mode_hint = { "warm": "Pick items that match this user's established tastes — themes, voice, rating patterns. Use the reader reviews to confirm tone/pacing/style fit.", "cold_start": "Pick items that match the user's stated preferences. Be conservative — favor well-rated, widely-appealing items in the requested domain. The reader reviews are your most reliable signal here — titles in this catalog are noisy.", "cross_domain": "This user has tastes in some domains but you're recommending in OTHER domains. Find items in the candidate list that bridge their known tastes to the new domain — explain the bridge in each reasoning. Use the reader reviews to find genuine thematic bridges.", }[mode] parts.append(mode_hint) parts.append("") parts.append( f"Output the top {k_final} items as a ranked list (best first).\n" f"For each:\n" f" - 'item_id' must be one of the bracketed IDs above (exactly as written, e.g. 'B0073UKXBE')\n" f" - 'reasoning' is one sentence citing 1-2 specific signals from the persona, " f"ideally referencing what the reader reviews revealed about the item " f"(e.g. 'Reviews call it tight and fast-paced — matches their dislike of padding')\n" f"Do not invent item_ids. Do not repeat items. Order matters — best first." ) return "\n".join(parts) def rerank(llm: LLMClient, persona: UserPersona, candidates: list[RetrievedItem], k_final: int, mode: Mode, aggregator: ReviewAggregator | None = None, reviews_per_item: int = 4, enrich_top_n: int = 25, critique_feedback: str | None = None, conversation_context: str | None = None) -> list[RerankedItem]: """LLM rerank step. Returns up to k_final items with reasoning. If `aggregator` is provided, the top `enrich_top_n` candidates are enriched with real reader reviews before reranking. This is the Stage 2c improvement — it lets the LLM judge items by human language, not catalog metadata. If `critique_feedback` is provided, it is injected into the prompt as feedback from a previous self-critique pass (Stage 3c) — the reranker is told to fix the flagged problems. If `conversation_context` is provided, it is injected as multi-turn dialogue context — prior turns, what was recommended, what the user rejected and why — so the reranker reasons over the whole conversation rather than a single static persona. """ if not candidates: return [] # Stage 2c: fetch reviews for the top N candidates reviews_by_item: dict[str, list[ReviewSnippet]] = {} if aggregator is not None: # Only enrich the top N — saves prompt tokens for items unlikely to be picked top_ids = [c.item_id for c in candidates[:enrich_top_n]] log.info(f"Fetching {reviews_per_item} reviews each for top {len(top_ids)} candidates") reviews_by_item = aggregator.get_reviews_for_items( top_ids, k=reviews_per_item, exclude_user_id=persona.user_id, ) n_enriched = sum(1 for r in reviews_by_item.values() if r) log.info(f" → {n_enriched}/{len(top_ids)} candidates have reviews") prompt = build_rerank_prompt(persona, candidates, k_final, mode, reviews_by_item=reviews_by_item) if conversation_context: prompt += ( f"\n\n{'=' * 60}\n" f"MULTI-TURN CONVERSATION CONTEXT\n" f"{'=' * 60}\n" f"{conversation_context}\n" f"This is an ongoing conversation. Treat the request above as the " f"current turn, building on everything before it. Carry forward " f"the preferences the user has expressed; honour every rejection " f"and the stated reason for it; and narrow from earlier " f"recommendations rather than starting over. The picks should feel " f"like a continuation of this conversation." ) if critique_feedback: prompt += ( f"\n\n{'=' * 60}\n" f"FEEDBACK FROM A PREVIOUS ATTEMPT — FIX THESE ISSUES\n" f"{'=' * 60}\n" f"{critique_feedback}\n" f"Re-pick the top {k_final} addressing this feedback. Avoid the " f"problems flagged above." ) try: result = llm.structured( prompt, schema=RerankedList, model="reasoning", system="You are an expert recommendation agent that explains every pick.", ) except Exception as e: log.error(f"Rerank LLM call failed: {e}; falling back to retrieval order") return [ RerankedItem(item_id=c.item_id, reasoning=f"Matched semantic search for this user's profile.") for c in candidates[:k_final] ] # Filter to valid item_ids only (LLM occasionally hallucinates) valid_ids = {c.item_id for c in candidates} cleaned: list[RerankedItem] = [] seen: set[str] = set() for item in result.items: if item.item_id in valid_ids and item.item_id not in seen: cleaned.append(item) seen.add(item.item_id) if len(cleaned) >= k_final: break # If LLM returned fewer than k_final valid items, top up from retrieval order if len(cleaned) < k_final: for c in candidates: if c.item_id not in seen: cleaned.append(RerankedItem( item_id=c.item_id, reasoning="Strong semantic match for the user's profile.", )) seen.add(c.item_id) if len(cleaned) >= k_final: break return cleaned def format_response(ranked: list[RerankedItem], candidates_by_id: dict[str, RetrievedItem]) -> list[Recommendation]: """Wrap reranked items in the final Recommendation dataclass.""" out: list[Recommendation] = [] for i, item in enumerate(ranked, 1): cand = candidates_by_id.get(item.item_id) if cand is None: continue out.append(Recommendation( rank=i, item_id=cand.item_id, title=cand.title, domain=cand.domain, categories=cand.categories, reasoning=item.reasoning, retrieval_distance=cand.distance, )) return out # ────────────────────────────────────────────────────────────────────────────── # Agent # ────────────────────────────────────────────────────────────────────────────── class RecommendationAgent: """The Task B agent. Usage: agent = RecommendationAgent() recs = agent.run(persona, k=10, cross_domain=False) # recs is list[Recommendation] """ def __init__(self, llm: LLMClient | None = None, retriever: ItemRetriever | None = None, aggregator: ReviewAggregator | None = None, candidates_k: int = 50, use_review_enrichment: bool = True, use_hyde: bool = True, use_reflection: bool = True, reflection_max_iterations: int = 2): self.llm = llm or LLMClient() self.retriever = retriever or ItemRetriever() self.use_review_enrichment = use_review_enrichment if use_review_enrichment: self.aggregator = aggregator or ReviewAggregator() else: self.aggregator = None # HyDE for cold-start retrieval self.use_hyde = use_hyde if use_hyde: self.hyde = HydeRetriever(self.llm, self.retriever) else: self.hyde = None # Self-reflection self.use_reflection = use_reflection self.reflection_max_iterations = reflection_max_iterations self.candidates_k = candidates_k # Introspection hooks for the demo UI (purely additive) self.last_mode: str | None = None self.last_candidate_count: int = 0 self.last_reflection_trace = None def run(self, persona: UserPersona, k: int = 10, cross_domain: bool = False, conversation_context: str | None = None) -> list[Recommendation]: mode = detect_mode(persona, requested_cross_domain=cross_domain) log.info(f"Recommendation mode for user {persona.user_id}: {mode}") # Introspection hooks for the demo UI (purely additive — core logic # never reads these). Reset at the start of each run. self.last_mode = mode self.last_candidate_count = 0 self.last_reflection_trace = None # Step 1+2: retrieve candidates (HyDE for cold-start if enabled) candidates = retrieve_candidates( self.retriever, persona, mode, k_candidates=self.candidates_k, hyde=self.hyde, ) if not candidates: log.warning("No candidates retrieved; returning empty list") return [] cand_by_id = {c.item_id: c for c in candidates} self.last_candidate_count = len(candidates) # Step 3: LLM rerank with optional review enrichment (Stage 2c) ranked = rerank(self.llm, persona, candidates, k_final=k, mode=mode, aggregator=self.aggregator, conversation_context=conversation_context) recs = format_response(ranked, cand_by_id) # Step 4: self-reflection — critique + refine (Stage 3c) if self.use_reflection and recs: log.info("Running self-reflection on recommendations") def _refine(critique_issues: str) -> list[dict]: """Re-run rerank with the critique injected, return rec dicts.""" refined_ranked = rerank( self.llm, persona, candidates, k_final=k, mode=mode, aggregator=self.aggregator, critique_feedback=critique_issues, conversation_context=conversation_context, ) refined_recs = format_response(refined_ranked, cand_by_id) return [r.as_dict() for r in refined_recs] rec_dicts = [r.as_dict() for r in recs] final_dicts, trace = reflect_on_recommendations( self.llm, persona, rec_dicts, mode, refine_fn=_refine, max_iterations=self.reflection_max_iterations, ) self.last_reflection_trace = trace # introspection hook for the UI # Rebuild Recommendation objects from the (possibly refined) dicts if trace.refined: recs = _rebuild_recommendations(final_dicts, cand_by_id) return recs def _rebuild_recommendations(rec_dicts: list[dict], cand_by_id: dict[str, RetrievedItem], ) -> list[Recommendation]: """Reconstruct Recommendation objects from dicts after a refinement pass.""" out: list[Recommendation] = [] for i, d in enumerate(rec_dicts, 1): cand = cand_by_id.get(d.get("item_id", "")) if cand is None: continue out.append(Recommendation( rank=i, item_id=cand.item_id, title=cand.title, domain=cand.domain, categories=cand.categories, reasoning=d.get("reasoning", ""), retrieval_distance=cand.distance, )) return out