Spaces:
Sleeping
Sleeping
"""Task B agent — the Recommender.

Given a UserPersona, return ranked items with per-item reasoning.

The workflow:
  1. build_query(persona, mode)
     → construct a semantic search query. Warm/cross-domain users are
       described by snippets of their highly rated past reviews plus their
       voice and themes; cold-start users by voice, themes and complaints.
  2. retrieve_candidates(retriever, persona, mode)
     → ChromaDB search for the top candidates (50 by default); cold start
       can go through HyDE instead
     → exclude items already in the user's history
     → in cross_domain mode, restrict to domains the user hasn't engaged
       with; in warm mode, restrict to the domains they have
  3. rerank(persona, candidates)
     → the LLM picks the top k (10 by default) from the candidates, optionally
       reading real reader reviews, and writes one-sentence reasoning for
       each. This is where most of the per-item intelligence lives — every
       recommendation comes with grounded justification.
  4. format_response(ranked_items)
     → return Recommendation objects with item details + reasoning
  5. (optional) self-reflection
     → critique the list and, if needed, re-run rerank with the critique

The system handles three explicit modes:
  - "warm"        : user has history → blend behavioral signals + persona
                    description into the query
  - "cold_start"  : no history → use persona description + themes only
  - "cross_domain": user has history → recommend in domains they haven't
                    engaged with, using their voice as a bridge
"""
| from __future__ import annotations | |
| import logging | |
| from dataclasses import dataclass, asdict | |
| from typing import Literal | |
| from pydantic import BaseModel, Field | |
| from core.llm import LLMClient | |
| from core.persona import UserPersona | |
| from core.retrieval import ItemRetriever, RetrievedItem | |
| from core.review_aggregator import ReviewAggregator, ReviewSnippet | |
| from core.hyde import HydeRetriever | |
| from core.reflection import reflect_on_recommendations, ReflectionTrace | |
# Module-level logger, named after this module's import path.
log = logging.getLogger(name=__name__)

# The three workflow modes; detect_mode() chooses one per request.
Mode = Literal[
    "warm",
    "cold_start",
    "cross_domain",
]
# ──────────────────────────────────────────────────────────────────────────────
# Output schemas
# ──────────────────────────────────────────────────────────────────────────────
# Pydantic models used as the structured-output schema of the rerank LLM call
# (see `rerank`). NOTE(review): the class docstrings and Field descriptions are
# likely surfaced to the LLM as schema descriptions, so treat them as prompt
# text rather than plain documentation.
class RerankedItem(BaseModel):
    """One item in the LLM's reranked output."""
    # rerank() keeps only ids that appear in the candidate list and drops repeats.
    item_id: str = Field(description="The parent_asin of the chosen item")
    # Carried through unchanged into Recommendation.reasoning by format_response().
    reasoning: str = Field(description="One sentence on why this fits this user, citing 1-2 specific persona signals")
class RerankedList(BaseModel):
    """LLM's structured output: a ranked list of items with reasons."""
    # The item count is intentionally not fixed here. build_rerank_prompt asks
    # for exactly k_final items, and a hard-coded number in the schema would
    # contradict the prompt whenever k_final != 10. rerank() trims or tops up
    # the list to k_final regardless of how many come back.
    items: list[RerankedItem] = Field(
        description="Items ranked best-first, as many as the prompt asks for. Each must reference real item_ids from the candidate list."
    )
@dataclass
class Recommendation:
    """Final output unit returned to the API caller.

    Built by format_response() / _rebuild_recommendations() from a reranked
    item joined with its retrieved candidate. The @dataclass decorator is
    required: callers construct it with keyword arguments, and as_dict()
    relies on dataclasses.asdict().
    """
    rank: int                   # 1-based position in the final list
    item_id: str                # candidate item id
    title: str
    domain: str
    categories: str
    reasoning: str              # one-sentence justification from the reranker
    retrieval_distance: float   # candidate's distance reported by retrieval

    def as_dict(self) -> dict:
        """Return all fields as a plain dict (used to hand recs to reflection)."""
        return asdict(self)
# ──────────────────────────────────────────────────────────────────────────────
# Workflow steps
# ──────────────────────────────────────────────────────────────────────────────
def detect_mode(persona: UserPersona, requested_cross_domain: bool) -> Mode:
    """Choose the workflow mode for this persona.

    "cross_domain" wins whenever the caller asks for it and the persona has
    at least one review (n_reviews > 0), even if history_samples is empty.
    Otherwise a persona with zero reviews or no history samples is
    "cold_start", and everyone else is "warm".
    """
    wants_bridge = requested_cross_domain and persona.n_reviews > 0
    if wants_bridge:
        return "cross_domain"
    lacks_history = persona.n_reviews == 0 or not persona.history_samples
    return "cold_start" if lacks_history else "warm"
def build_query(persona: UserPersona, mode: Mode) -> str:
    """Build the text used for the semantic-search step of retrieval.

    Warm / cross-domain, when ``history_samples`` is non-empty:
        The query is anchored on what the user rated highly. It uses up to
        four samples rated 4.0+ (highest first), or the three best-rated
        samples if none reach 4.0. Each contributes its first 30 words. The
        voice one-liner and preferred themes are appended when present, and
        cross-domain queries also ask for adjacent content in new domains.

    Everything else (cold start, or no history to draw on):
        The query is assembled from the voice one-liner, preferred themes and
        common complaints. A generic phrase is used as a fallback so the
        result is never empty.
    """
    themes = ", ".join(persona.preferred_themes) if persona.preferred_themes else ""
    voice = persona.voice_one_liner or ""

    # ── Profile-only query: cold start, or any mode with no history ──────
    if mode not in ("warm", "cross_domain") or not persona.history_samples:
        pieces = [voice] if voice else []
        if themes:
            pieces.append(f"Interested in: {themes}.")
        complaints = persona.common_complaints
        if complaints:
            # NOTE(review): embedding search generally ignores negation, so
            # "Avoid: X" may pull X-like items closer. Worth evaluating.
            pieces.append(f"Avoid: {', '.join(complaints)}.")
        # Last resort so the retriever never receives an empty query.
        return " ".join(pieces) or "popular well-reviewed items"

    # ── History-anchored query: taste lives in what they rated highly ────
    favourites = sorted(
        (s for s in persona.history_samples if s["rating"] >= 4.0),
        key=lambda s: -s["rating"],
    )[:4]
    if not favourites:
        # Nothing rated 4.0+: settle for the three best-rated samples.
        favourites = sorted(persona.history_samples, key=lambda s: -s["rating"])[:3]

    # About 30 leading words per review is enough to carry its gist.
    liked_signal = " | ".join(
        " ".join(s["text"].strip().split()[:30]) for s in favourites
    )

    header = "Find items similar to ones this reader rated highly."
    if mode == "cross_domain":
        header += " They want to discover ADJACENT content in new domains."
    if voice:
        header += f" Their voice: {voice}"
    if themes:
        header += f" Themes they care about: {themes}."
    return f"{header}\nWhat they've liked before: {liked_signal}"
def retrieve_candidates(retriever: ItemRetriever, persona: UserPersona,
                        mode: Mode, k_candidates: int = 50,
                        hyde: HydeRetriever | None = None) -> list[RetrievedItem]:
    """Fetch up to ``k_candidates`` semantically similar items for ``mode``.

    Cold start with a HydeRetriever: ask HyDE for candidates across Books,
    Movies_and_TV and Kindle_Store (hypothetical items → embed → match the
    real catalog). A non-empty HyDE result is returned as-is. An empty result
    falls through to the normal path below.

    Normal path: build_query() supplies the query, and items in the persona's
    history_samples are excluded.
      * cross_domain restricts to the domains the persona has not touched.
        If none remain, no domain filter is applied.
      * warm restricts to the persona's own domains, when it has any.
      * cold start (no HyDE, or HyDE fallback) applies no domain filter.
    """
    # ── Cold start: prefer HyDE when available ──────────────────────────
    if mode == "cold_start" and hyde is not None:
        via_hyde = hyde.retrieve(
            persona,
            k_candidates=k_candidates,
            allowed_domains=["Books", "Movies_and_TV", "Kindle_Store"],
        )
        if via_hyde:
            log.info(f"Cold-start via HyDE: {len(via_hyde)} candidates")
            return via_hyde
        log.warning("HyDE returned no candidates; falling back to normal retrieval")

    # ── Normal retrieval (warm, cross_domain, or HyDE fallback) ─────────
    query = build_query(persona, mode)
    already_seen = {sample["parent_asin"] for sample in persona.history_samples}

    domain_filter: list[str] | None = None
    if mode == "cross_domain":
        untouched = sorted({"Books", "Movies_and_TV", "Kindle_Store"}.difference(persona.domains))
        if not untouched:
            log.info("Cross-domain requested but user has touched all domains; falling back to warm mode")
        else:
            domain_filter = untouched
    elif mode == "warm" and persona.domains:
        # Stay inside the user's own domains. retrieve() balances across them
        # when several are given. Discovering new domains is cross_domain's job.
        domain_filter = list(persona.domains)
        log.info(f"Warm mode: restricting retrieval to user's domains {domain_filter}")

    log.info(f"Retrieving {k_candidates} candidates for mode={mode}, query={query[:120]}")
    found = retriever.retrieve(
        query=query,
        k=k_candidates,
        domains=domain_filter,
        exclude_ids=already_seen,
    )
    log.info(f"Retrieved {len(found)} candidates")
    return found
def build_rerank_prompt(persona: UserPersona, candidates: list[RetrievedItem],
                        k_final: int, mode: Mode,
                        reviews_by_item: dict[str, list[ReviewSnippet]] | None = None,
                        ) -> str:
    """Render the rerank prompt: persona, candidate list, and task instructions.

    Args:
        persona: provides ``to_prompt_block()`` (the profile text) and
            ``history_samples`` (dicts with ``domain``/``rating``/``text``).
            Up to five samples are listed unless ``mode`` is "cold_start".
        candidates: listed in the given order with their ``item_id`` in
            brackets, so the LLM can cite them exactly.
        k_final: number of items the LLM is asked to return.
        mode: "warm", "cold_start" or "cross_domain"; selects the task hint.
            Any other value raises KeyError.
        reviews_by_item: optional map of item_id to review snippets (objects
            with ``rating`` and ``text``). When given, each candidate is
            enriched with real reader reviews (Stage 2c), so the LLM judges
            items by what readers said instead of by noisy catalog titles.

    Returns:
        The whole prompt as one newline-joined string.
    """
    rule = "=" * 60
    parts = ["You are a thoughtful recommendation agent. Your job is to pick the best items for this specific user from a candidate list, and explain each pick with reference to the user's signals.\n"]
    parts += [rule, "THE USER", rule, persona.to_prompt_block()]

    # Past items give taste context. Cold start has no usable history.
    if persona.history_samples and mode != "cold_start":
        parts.append("\nRecent things this user engaged with (do NOT recommend these — they've already seen them):")
        for h in persona.history_samples[:5]:
            parts.append(f"  - [{h['domain']}] {h['rating']}★: {h['text'][:120]}")

    parts.append("\n" + rule)
    parts.append(f"CANDIDATE ITEMS ({len(candidates)} retrieved by semantic search)")
    if reviews_by_item:
        parts.append("Each candidate includes a sample of REAL reader reviews — judge each item by what readers said, not by the title alone (many titles in this catalog are review headlines, not real product titles).")
    parts.append(rule)

    reviews = reviews_by_item or {}
    for c in candidates:
        line = f"[{c.item_id}] ({c.domain}) {c.title}"
        if c.categories:
            line += f" | categories: {c.categories[:80]}"
        # Only show a description that says noticeably more than the title.
        if c.description and len(c.description) > len(c.title) + 10:
            line += f"\n Description: {c.description[:200]}"
        # Stage 2c: real reader reviews, when fetched for this candidate.
        snippets = reviews.get(c.item_id, [])
        if snippets:
            line += "\n Reader reviews:"
            for s in snippets:
                line += f"\n {s.rating}★: {s.text}"
        parts.append(line)

    parts += ["\n" + rule, "YOUR TASK", rule]
    mode_hint = {
        "warm": "Pick items that match this user's established tastes — themes, voice, rating patterns. Use the reader reviews to confirm tone/pacing/style fit.",
        "cold_start": "Pick items that match the user's stated preferences. Be conservative — favor well-rated, widely-appealing items in the requested domain. The reader reviews are your most reliable signal here — titles in this catalog are noisy.",
        "cross_domain": "This user has tastes in some domains but you're recommending in OTHER domains. Find items in the candidate list that bridge their known tastes to the new domain — explain the bridge in each reasoning. Use the reader reviews to find genuine thematic bridges.",
    }[mode]
    parts.append(mode_hint)
    parts.append("")
    parts.append(
        f"Output the top {k_final} items as a ranked list (best first).\n"
        "For each:\n"
        "  - 'item_id' must be one of the bracketed IDs above (exactly as written, e.g. 'B0073UKXBE')\n"
        "  - 'reasoning' is one sentence citing 1-2 specific signals from the persona, "
        "ideally referencing what the reader reviews revealed about the item "
        "(e.g. 'Reviews call it tight and fast-paced — matches their dislike of padding')\n"
        "Do not invent item_ids. Do not repeat items. Order matters — best first."
    )
    return "\n".join(parts)
def rerank(llm: LLMClient, persona: UserPersona, candidates: list[RetrievedItem],
           k_final: int, mode: Mode,
           aggregator: ReviewAggregator | None = None,
           reviews_per_item: int = 4,
           enrich_top_n: int = 25,
           critique_feedback: str | None = None,
           conversation_context: str | None = None) -> list[RerankedItem]:
    """LLM rerank step. Returns up to k_final items with reasoning.

    If ``aggregator`` is provided, the first ``enrich_top_n`` candidates are
    enriched with real reader reviews before reranking (Stage 2c), so the LLM
    judges items by human language instead of catalog metadata.

    If ``conversation_context`` is provided, it is appended as multi-turn
    dialogue context: prior turns, what was recommended, and what the user
    rejected and why. The reranker then reasons over the whole conversation
    rather than a single static persona.

    If ``critique_feedback`` is provided, it is appended as feedback from a
    previous self-critique pass (Stage 3c), and the reranker is told to fix
    the flagged problems.

    Returns [] without calling the LLM when there are no candidates or
    ``k_final <= 0``. If the LLM call raises, the first ``k_final`` candidates
    are returned in retrieval order with a generic reason. Otherwise, ids
    that are not candidates and repeated ids are dropped, and a short list is
    topped up from retrieval order.
    """
    # Nothing to rank or nothing requested: skip the LLM call entirely. The
    # k_final guard also matters for the loops below, whose ">= k_final"
    # checks would otherwise still let one item through for k_final == 0.
    if not candidates or k_final <= 0:
        return []

    # Stage 2c: fetch reviews only for the top N candidates, to save prompt
    # tokens on items unlikely to be picked.
    reviews_by_item: dict[str, list[ReviewSnippet]] = {}
    if aggregator is not None:
        top_ids = [c.item_id for c in candidates[:enrich_top_n]]
        log.info(f"Fetching {reviews_per_item} reviews each for top {len(top_ids)} candidates")
        reviews_by_item = aggregator.get_reviews_for_items(
            top_ids, k=reviews_per_item,
            exclude_user_id=persona.user_id,
        )
        n_enriched = sum(1 for r in reviews_by_item.values() if r)
        log.info(f"  → {n_enriched}/{len(top_ids)} candidates have reviews")

    prompt = build_rerank_prompt(persona, candidates, k_final, mode,
                                 reviews_by_item=reviews_by_item)
    rule = "=" * 60
    if conversation_context:
        prompt += (
            f"\n\n{rule}\n"
            "MULTI-TURN CONVERSATION CONTEXT\n"
            f"{rule}\n"
            f"{conversation_context}\n"
            "This is an ongoing conversation. Treat the request above as the "
            "current turn, building on everything before it. Carry forward "
            "the preferences the user has expressed; honour every rejection "
            "and the stated reason for it; and narrow from earlier "
            "recommendations rather than starting over. The picks should feel "
            "like a continuation of this conversation."
        )
    if critique_feedback:
        prompt += (
            f"\n\n{rule}\n"
            "FEEDBACK FROM A PREVIOUS ATTEMPT — FIX THESE ISSUES\n"
            f"{rule}\n"
            f"{critique_feedback}\n"
            f"Re-pick the top {k_final} addressing this feedback. Avoid the "
            "problems flagged above."
        )

    try:
        result = llm.structured(
            prompt,
            schema=RerankedList,
            model="reasoning",
            system="You are an expert recommendation agent that explains every pick.",
        )
    except Exception as e:
        # Deliberate best-effort boundary: a failed LLM call degrades to
        # retrieval order instead of failing the whole request.
        log.exception(f"Rerank LLM call failed: {e}; falling back to retrieval order")
        return [
            RerankedItem(item_id=c.item_id,
                         reasoning="Matched semantic search for this user's profile.")
            for c in candidates[:k_final]
        ]

    # Keep only real, unique candidate ids (the LLM occasionally hallucinates).
    valid_ids = {c.item_id for c in candidates}
    picked: list[RerankedItem] = []
    seen: set[str] = set()
    for item in result.items:
        if len(picked) >= k_final:
            break
        if item.item_id in valid_ids and item.item_id not in seen:
            picked.append(item)
            seen.add(item.item_id)

    # Too few valid picks: top up from retrieval order.
    for c in candidates:
        if len(picked) >= k_final:
            break
        if c.item_id not in seen:
            picked.append(RerankedItem(
                item_id=c.item_id,
                reasoning="Strong semantic match for the user's profile.",
            ))
            seen.add(c.item_id)
    return picked
def format_response(ranked: list[RerankedItem],
                    candidates_by_id: dict[str, RetrievedItem]) -> list[Recommendation]:
    """Join reranked items with their retrieved candidates into Recommendations.

    Title, domain, categories and distance come from the candidate. Only the
    reasoning comes from the reranked item. Items whose id is not in
    ``candidates_by_id`` are dropped. Ranks are numbered 1..n over the items
    actually kept, so they stay contiguous even when something is dropped.
    """
    out: list[Recommendation] = []
    for item in ranked:
        cand = candidates_by_id.get(item.item_id)
        if cand is None:
            # Defensive: rerank() already filters to candidate ids.
            continue
        out.append(Recommendation(
            rank=len(out) + 1,
            item_id=cand.item_id,
            title=cand.title,
            domain=cand.domain,
            categories=cand.categories,
            reasoning=item.reasoning,
            retrieval_distance=cand.distance,
        ))
    return out
# ──────────────────────────────────────────────────────────────────────────────
# Agent
# ──────────────────────────────────────────────────────────────────────────────
class RecommendationAgent:
    """The Task B agent: retrieve → rerank → (optionally) self-reflect.

    Usage:
        agent = RecommendationAgent()
        recs = agent.run(persona, k=10, cross_domain=False)
        # recs is list[Recommendation]

    Collaborators that are not passed in are constructed with their defaults.
    Review enrichment, HyDE cold-start retrieval and self-reflection can each
    be switched off via constructor flags.
    """

    def __init__(self, llm: LLMClient | None = None,
                 retriever: ItemRetriever | None = None,
                 aggregator: ReviewAggregator | None = None,
                 candidates_k: int = 50,
                 use_review_enrichment: bool = True,
                 use_hyde: bool = True,
                 use_reflection: bool = True,
                 reflection_max_iterations: int = 2):
        # Core collaborators; construction order is llm, retriever, aggregator, hyde.
        self.llm = llm or LLMClient()
        self.retriever = retriever or ItemRetriever()

        # Stage 2c review enrichment. A supplied aggregator is ignored when disabled.
        self.use_review_enrichment = use_review_enrichment
        self.aggregator = (aggregator or ReviewAggregator()) if use_review_enrichment else None

        # HyDE is consulted only for cold-start retrieval.
        self.use_hyde = use_hyde
        self.hyde = HydeRetriever(self.llm, self.retriever) if use_hyde else None

        # Stage 3c self-reflection settings.
        self.use_reflection = use_reflection
        self.reflection_max_iterations = reflection_max_iterations

        self.candidates_k = candidates_k

        # Introspection state for the demo UI; written by run(), never read back.
        self.last_mode: str | None = None
        self.last_candidate_count: int = 0
        self.last_reflection_trace = None

    def run(self, persona: UserPersona, k: int = 10,
            cross_domain: bool = False,
            conversation_context: str | None = None) -> list[Recommendation]:
        """Produce up to ``k`` recommendations for ``persona``.

        Steps: pick the mode → retrieve candidates (HyDE for cold start when
        enabled) → LLM rerank with optional review enrichment → optional
        critique-and-refine loop. Returns [] when retrieval finds nothing.
        """
        mode = detect_mode(persona, requested_cross_domain=cross_domain)
        log.info(f"Recommendation mode for user {persona.user_id}: {mode}")

        # Reset the UI introspection hooks at the start of every run.
        self.last_mode = mode
        self.last_candidate_count = 0
        self.last_reflection_trace = None

        # Steps 1+2: candidate retrieval.
        candidates = retrieve_candidates(
            self.retriever, persona, mode, k_candidates=self.candidates_k,
            hyde=self.hyde,
        )
        if not candidates:
            log.warning("No candidates retrieved; returning empty list")
            return []
        self.last_candidate_count = len(candidates)
        by_id = {cand.item_id: cand for cand in candidates}

        def rank_once(critique: str | None = None) -> list[Recommendation]:
            """Step 3: rerank the shared candidates, optionally guided by a critique."""
            ranked = rerank(
                self.llm, persona, candidates, k_final=k, mode=mode,
                aggregator=self.aggregator,
                critique_feedback=critique,
                conversation_context=conversation_context,
            )
            return format_response(ranked, by_id)

        recs = rank_once()
        if not (self.use_reflection and recs):
            return recs

        # Step 4: self-reflection (Stage 3c).
        log.info("Running self-reflection on recommendations")

        def _refine(critique_issues: str) -> list[dict]:
            """Re-run rerank with the critique injected; return rec dicts."""
            return [rec.as_dict() for rec in rank_once(critique_issues)]

        final_dicts, trace = reflect_on_recommendations(
            self.llm, persona, [rec.as_dict() for rec in recs], mode,
            refine_fn=_refine,
            max_iterations=self.reflection_max_iterations,
        )
        self.last_reflection_trace = trace
        # Only a refined result replaces the first-pass recommendations.
        return _rebuild_recommendations(final_dicts, by_id) if trace.refined else recs
| def _rebuild_recommendations(rec_dicts: list[dict], | |
| cand_by_id: dict[str, RetrievedItem], | |
| ) -> list[Recommendation]: | |
| """Reconstruct Recommendation objects from dicts after a refinement pass.""" | |
| out: list[Recommendation] = [] | |
| for i, d in enumerate(rec_dicts, 1): | |
| cand = cand_by_id.get(d.get("item_id", "")) | |
| if cand is None: | |
| continue | |
| out.append(Recommendation( | |
| rank=i, | |
| item_id=cand.item_id, | |
| title=cand.title, | |
| domain=cand.domain, | |
| categories=cand.categories, | |
| reasoning=d.get("reasoning", ""), | |
| retrieval_distance=cand.distance, | |
| )) | |
| return out | |