Israelbliz's picture
Upload agent.py
69a9870 verified
"""Task B agent β€” the Recommender.
Given a UserPersona, return ranked items with per-item reasoning.
The workflow:
  1. build_query(persona, mode)
     → construct a semantic search query from snippets of the user's
       highest-rated reviews plus their themes and voice. A different
       query (voice, themes, complaints) is built for cold-start.
  2. retrieve_candidates(retriever, persona, mode)
     → for cold-start, try HyDE retrieval first when it is enabled
     → otherwise ChromaDB search for the top 50 items using build_query
     → exclude items already in the user's history
     → in cross_domain mode, restrict to domains the user hasn't engaged
  3. rerank(llm, persona, candidates, ...)
     → LLM picks the top k (default 10) from the 50, optionally seeing
       real reader reviews, and writes one-sentence reasoning for each.
       This is where the system's intelligence lives — every
       recommendation comes with a grounded justification.
  4. format_response(ranked_items, candidates_by_id)
     → return Recommendation objects with item details + reasoning
  5. (optional) self-reflection
     → critique the list and, if needed, re-run rerank with the feedback
The system handles three explicit modes:
  - "warm"         : user has history → blend behavioral signals + persona
                     description into the query
  - "cold_start"   : no history → use persona description + themes only
  - "cross_domain" : user has history → recommend in domains they haven't
                     engaged with, using their voice as a bridge
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, asdict
from typing import Literal
from pydantic import BaseModel, Field
from core.llm import LLMClient
from core.persona import UserPersona
from core.retrieval import ItemRetriever, RetrievedItem
from core.review_aggregator import ReviewAggregator, ReviewSnippet
from core.hyde import HydeRetriever
from core.reflection import reflect_on_recommendations, ReflectionTrace
# Module logger: the pipeline steps below report progress and fallbacks here.
log: logging.Logger = logging.getLogger(__name__)
# The three workflow modes that `detect_mode` can select (see module docstring).
Mode = Literal["warm", "cold_start", "cross_domain"]
# ──────────────────────────────────────────────────────────────────────────────
# Output schemas
# ──────────────────────────────────────────────────────────────────────────────
class RerankedItem(BaseModel):
    """One item in the LLM's reranked output."""
    # NOTE(review): the docstring and the Field descriptions below are runtime
    # text, not just documentation — pydantic includes them in the model's JSON
    # schema, which is presumably what `llm.structured(schema=RerankedList)`
    # hands to the model (confirm in core.llm). Edit them as prompt wording.

    # Must match one of the candidate ids; `rerank` drops ids that are not in
    # the candidate list or that were already picked.
    item_id: str = Field(description="The parent_asin of the chosen item")
    # Per-pick justification; `format_response` copies it into
    # Recommendation.reasoning.
    reasoning: str = Field(description="One sentence on why this fits this user, citing 1-2 specific persona signals")
class RerankedList(BaseModel):
    """LLM's structured output: a ranked list of 10 items with reasons."""
    # Schema that `rerank` passes to `llm.structured(...)`; `result.items` is
    # then validated against the real candidate list.
    # NOTE(review): the docstring and the description hard-code "10", while the
    # rerank prompt asks for `k_final` items. When k != 10 the schema text and
    # the prompt disagree. Changing this text presumably changes the schema the
    # LLM sees, so treat it as a prompt change.
    items: list[RerankedItem] = Field(
        description="Top 10 items ranked best-first. Each must reference real item_ids from the candidate list."
    )
@dataclass()
class Recommendation:
    """A single ranked pick as delivered to the API caller.

    Built from a reranked item plus the metadata of the retrieved candidate it
    refers to; `as_dict` exposes the same fields as a plain dict.
    """

    rank: int                   # 1-based position in the final list
    item_id: str                # catalog id (parent_asin) of the item
    title: str
    domain: str
    categories: str
    reasoning: str              # one-sentence justification for this pick
    retrieval_distance: float   # distance reported by the retriever

    def as_dict(self) -> dict:
        """Return every field of this recommendation as a plain dict."""
        plain = asdict(self)
        return plain
# ──────────────────────────────────────────────────────────────────────────────
# Workflow steps
# ──────────────────────────────────────────────────────────────────────────────
def detect_mode(persona: UserPersona, requested_cross_domain: bool) -> Mode:
    """Choose the workflow mode for this persona.

    Cross-domain is honoured only when it is requested and the persona has at
    least one review. Otherwise the persona is "warm" when it has a non-zero
    review count and some history samples, and "cold_start" in every other
    case.
    """
    if requested_cross_domain and persona.n_reviews > 0:
        return "cross_domain"
    # Warm needs both a review count and concrete history to build a query from.
    if persona.n_reviews != 0 and persona.history_samples:
        return "warm"
    return "cold_start"
def build_query(persona: UserPersona, mode: Mode) -> str:
    """Build the text used for semantic candidate retrieval.

    Warm and cross_domain personas with history are described by what they
    rated highly: up to four 4+-star reviews, or the top three reviews by
    rating if none reach 4 stars. Each review is cut to its first 30 words and
    combined with the persona's voice and themes. Cross_domain additionally
    asks for adjacent content in new domains.

    Every other case (cold_start, or no history) uses voice, themes and
    complaints. If none of those exist, it uses a generic fallback so the
    query is never empty.
    """
    theme_text = ", ".join(persona.preferred_themes) if persona.preferred_themes else ""
    voice_text = persona.voice_one_liner or ""
    uses_history = mode in ("warm", "cross_domain") and persona.history_samples

    if not uses_history:
        pieces: list[str] = []
        if voice_text:
            pieces.append(voice_text)
        if theme_text:
            pieces.append(f"Interested in: {theme_text}.")
        if persona.common_complaints:
            # NOTE(review): dense embeddings often don't model negation, so an
            # "Avoid: X" phrase may pull X-like items closer rather than push
            # them away — worth checking against actual retrieval results.
            pieces.append(f"Avoid: {', '.join(persona.common_complaints)}.")
        # Generic last resort so retrieval never receives an empty query.
        return " ".join(pieces or ["popular well-reviewed items"])

    def _desc_rating(sample):
        return -sample["rating"]

    history = persona.history_samples
    top_rated = sorted((s for s in history if s["rating"] >= 4.0), key=_desc_rating)[:4]
    if not top_rated:
        # No 4+ star reviews at all: use the best three the user has.
        top_rated = sorted(history, key=_desc_rating)[:3]
    # The first ~30 words of a review usually carry its gist.
    liked_signal = " | ".join(" ".join(s["text"].split()[:30]) for s in top_rated)

    head = "Find items similar to ones this reader rated highly."
    if mode == "cross_domain":
        head += " They want to discover ADJACENT content in new domains."
    voice_bit = f" Their voice: {voice_text}" if voice_text else ""
    theme_bit = f" Themes they care about: {theme_text}." if theme_text else ""
    return f"{head}{voice_bit}{theme_bit}\nWhat they've liked before: {liked_signal}"
# Every catalog domain the retriever serves. HyDE cold-start searches all of
# them; cross-domain mode recommends from the ones the user hasn't touched.
_ALL_DOMAINS = ("Books", "Movies_and_TV", "Kindle_Store")


def retrieve_candidates(retriever: ItemRetriever, persona: UserPersona,
                        mode: Mode, k_candidates: int = 50,
                        hyde: HydeRetriever | None = None) -> list[RetrievedItem]:
    """Pull top-k semantically-similar items, respecting mode constraints.

    For cold_start mode with a HydeRetriever supplied, retrieval goes through
    HyDE first (generate hypothetical items → embed → match real catalog).
    This sidesteps the catalog pollution that hurts naive cold-start retrieval.
    The normal query path is used instead in three cases: HyDE raises, HyDE
    returns nothing, or no HydeRetriever is supplied.

    Args:
        retriever: catalog retriever used for the normal query path.
        persona: the user being recommended for.
        mode: workflow mode from `detect_mode`.
        k_candidates: number of candidates to request.
        hyde: optional HyDE retriever, only consulted in cold_start mode.

    Returns:
        The retrieved candidates. The normal path excludes items in the
        persona's history samples.
    """
    # ── Cold-start: try HyDE first ───────────────────────────────────────
    if mode == "cold_start" and hyde is not None:
        try:
            # Domains are not inferred from the persona yet: search them all.
            hyde_candidates = hyde.retrieve(persona, k_candidates=k_candidates,
                                            allowed_domains=list(_ALL_DOMAINS))
        except Exception:
            # HyDE is an optional improvement. A failure should degrade to
            # plain retrieval, as promised above, rather than abort the run.
            log.exception("HyDE retrieval failed; falling back to normal retrieval")
        else:
            if hyde_candidates:
                log.info(f"Cold-start via HyDE: {len(hyde_candidates)} candidates")
                return hyde_candidates
            log.warning("HyDE returned no candidates; falling back to normal retrieval")

    # ── Normal retrieval path (warm, cross_domain, or HyDE fallback) ─────
    query = build_query(persona, mode)
    # Exclude items the user has already engaged with
    exclude_ids = {s["parent_asin"] for s in persona.history_samples}

    domains = None
    if mode == "cross_domain":
        unknown = sorted(set(_ALL_DOMAINS) - set(persona.domains))
        if unknown:
            domains = unknown
        else:
            log.info("Cross-domain requested but user has touched all domains; falling back to warm mode")
    elif mode == "warm" and persona.domains:
        # Warm mode stays within the domains the user already engages with.
        # retrieve() balances across them when there are several. Discovery
        # in NEW domains is cross_domain's job.
        domains = list(persona.domains)
        log.info(f"Warm mode: restricting retrieval to user's domains {domains}")

    log.info(f"Retrieving {k_candidates} candidates for mode={mode}, query={query[:120]}")
    candidates = retriever.retrieve(
        query=query,
        k=k_candidates,
        domains=domains,
        exclude_ids=exclude_ids,
    )
    log.info(f"Retrieved {len(candidates)} candidates")
    return candidates
def build_rerank_prompt(persona: UserPersona, candidates: list[RetrievedItem],
                        k_final: int, mode: Mode,
                        reviews_by_item: dict[str, list[ReviewSnippet]] | None = None,
                        ) -> str:
    """Render the rerank prompt — persona + candidate list + instructions.

    If `reviews_by_item` is provided, each candidate that has entries in it is
    shown with those real reader reviews. This is the Stage 2c improvement:
    instead of judging by polluted titles, the LLM sees what humans actually
    said.

    Args:
        persona: the user; its prompt block and history samples are rendered.
        candidates: retrieved items to choose from, listed in the given order.
        k_final: how many picks the LLM is asked to return.
        mode: selects the mode-specific instruction; history is omitted in
            cold_start mode.
        reviews_by_item: optional item_id → review snippets mapping.

    Returns:
        The full prompt, sections joined with newlines.

    Raises:
        KeyError: if `mode` is not one of the three known modes.
    """
    rule = "=" * 60
    reviews = reviews_by_item or {}

    parts = ["You are a thoughtful recommendation agent. Your job is to pick the best items for this specific user from a candidate list, and explain each pick with reference to the user's signals.\n"]

    # ── The user ──
    parts += [rule, "THE USER", rule, persona.to_prompt_block()]
    if persona.history_samples and mode != "cold_start":
        parts.append("\nRecent things this user engaged with (do NOT recommend these — they've already seen them):")
        for h in persona.history_samples[:5]:
            parts.append(f" - [{h['domain']}] {h['rating']}★: {h['text'][:120]}")

    # ── Candidates ──
    parts += ["\n" + rule, f"CANDIDATE ITEMS ({len(candidates)} retrieved by semantic search)"]
    if reviews:
        # Only some candidates may be enriched (rerank fetches reviews for the
        # top N only), so don't claim every candidate has reviews.
        parts.append("Where available, candidates include a sample of REAL reader reviews — judge each item by what readers said, not by the title alone (many titles in this catalog are review headlines, not real product titles).")
    parts.append(rule)
    for c in candidates:
        line = f"[{c.item_id}] ({c.domain}) {c.title}"
        if c.categories:
            line += f" | categories: {c.categories[:80]}"
        # Skip descriptions that add little beyond the title itself.
        if c.description and len(c.description) > len(c.title) + 10:
            line += f"\n Description: {c.description[:200]}"
        snippets = reviews.get(c.item_id, [])
        if snippets:
            line += "\n Reader reviews:"
            for s in snippets:
                line += f"\n {s.rating}★: {s.text}"
        parts.append(line)

    # ── Task ──
    parts += ["\n" + rule, "YOUR TASK", rule]
    mode_hint = {
        "warm": "Pick items that match this user's established tastes — themes, voice, rating patterns. Use the reader reviews to confirm tone/pacing/style fit.",
        "cold_start": "Pick items that match the user's stated preferences. Be conservative — favor well-rated, widely-appealing items in the requested domain. The reader reviews are your most reliable signal here — titles in this catalog are noisy.",
        "cross_domain": "This user has tastes in some domains but you're recommending in OTHER domains. Find items in the candidate list that bridge their known tastes to the new domain — explain the bridge in each reasoning. Use the reader reviews to find genuine thematic bridges.",
    }[mode]
    parts.append(mode_hint)
    parts.append("")
    parts.append(
        f"Output the top {k_final} items as a ranked list (best first).\n"
        f"For each:\n"
        f" - 'item_id' must be one of the bracketed IDs above (exactly as written, e.g. 'B0073UKXBE')\n"
        f" - 'reasoning' is one sentence citing 1-2 specific signals from the persona, "
        f"ideally referencing what the reader reviews revealed about the item "
        f"(e.g. 'Reviews call it tight and fast-paced — matches their dislike of padding')\n"
        f"Do not invent item_ids. Do not repeat items. Order matters — best first."
    )
    return "\n".join(parts)
def _fetch_review_snippets(aggregator: ReviewAggregator, persona: UserPersona,
                           candidates: list[RetrievedItem], reviews_per_item: int,
                           enrich_top_n: int) -> dict[str, list[ReviewSnippet]]:
    """Fetch reader reviews for the first `enrich_top_n` candidates (best-effort)."""
    # Only enrich the top N — saves prompt tokens for items unlikely to be picked
    top_ids = [c.item_id for c in candidates[:enrich_top_n]]
    log.info(f"Fetching {reviews_per_item} reviews each for top {len(top_ids)} candidates")
    try:
        reviews_by_item = aggregator.get_reviews_for_items(
            top_ids, k=reviews_per_item,
            exclude_user_id=persona.user_id,
        ) or {}
    except Exception:
        # Enrichment is an optional quality boost. If it fails, rerank on
        # catalog metadata alone instead of failing the whole request.
        log.exception("Review enrichment failed; reranking without reader reviews")
        return {}
    n_enriched = sum(1 for r in reviews_by_item.values() if r)
    log.info(f" → {n_enriched}/{len(top_ids)} candidates have reviews")
    return reviews_by_item


def _select_valid_items(proposed: list[RerankedItem], candidates: list[RetrievedItem],
                        k_final: int) -> list[RerankedItem]:
    """Keep real, unique LLM picks (max k_final), then top up in retrieval order."""
    valid_ids = {c.item_id for c in candidates}
    cleaned: list[RerankedItem] = []
    seen: set[str] = set()
    # The limit is checked before appending so k_final <= 0 yields nothing.
    for item in proposed:
        if len(cleaned) >= k_final:
            break
        # The LLM occasionally hallucinates or repeats ids.
        if item.item_id in valid_ids and item.item_id not in seen:
            cleaned.append(item)
            seen.add(item.item_id)
    for c in candidates:
        if len(cleaned) >= k_final:
            break
        if c.item_id not in seen:
            cleaned.append(RerankedItem(
                item_id=c.item_id,
                reasoning="Strong semantic match for the user's profile.",
            ))
            seen.add(c.item_id)
    return cleaned


def rerank(llm: LLMClient, persona: UserPersona, candidates: list[RetrievedItem],
           k_final: int, mode: Mode,
           aggregator: ReviewAggregator | None = None,
           reviews_per_item: int = 4,
           enrich_top_n: int = 25,
           critique_feedback: str | None = None,
           conversation_context: str | None = None) -> list[RerankedItem]:
    """LLM rerank step. Returns up to k_final items with reasoning.

    If `aggregator` is provided, the top `enrich_top_n` candidates are enriched
    with `reviews_per_item` real reader reviews before reranking (Stage 2c), so
    the LLM judges items by human language rather than catalog metadata. A
    failure while fetching reviews is logged and reranking proceeds without
    them.

    If `critique_feedback` is provided, it is appended to the prompt as
    feedback from a previous self-critique pass (Stage 3c), and the reranker is
    told to fix the flagged problems.

    If `conversation_context` is provided, it is appended as multi-turn dialogue
    context: prior turns, what was recommended, and what the user rejected and
    why. The reranker then reasons over the whole conversation rather than a
    single static persona.

    Returns:
        At most `k_final` items, each with an id from `candidates` and no
        duplicates. Returns an empty list when there are no candidates or
        `k_final <= 0`. If the LLM call fails, the first `k_final` candidates
        are returned in retrieval order. If the LLM returns too few valid
        picks, the list is topped up in retrieval order.
    """
    if not candidates or k_final <= 0:
        return []

    # Stage 2c: fetch reviews for the top N candidates
    reviews_by_item: dict[str, list[ReviewSnippet]] = {}
    if aggregator is not None:
        reviews_by_item = _fetch_review_snippets(aggregator, persona, candidates,
                                                 reviews_per_item, enrich_top_n)

    prompt = build_rerank_prompt(persona, candidates, k_final, mode,
                                 reviews_by_item=reviews_by_item)
    rule = "=" * 60
    if conversation_context:
        prompt += (
            f"\n\n{rule}\n"
            f"MULTI-TURN CONVERSATION CONTEXT\n"
            f"{rule}\n"
            f"{conversation_context}\n"
            f"This is an ongoing conversation. Treat the request above as the "
            f"current turn, building on everything before it. Carry forward "
            f"the preferences the user has expressed; honour every rejection "
            f"and the stated reason for it; and narrow from earlier "
            f"recommendations rather than starting over. The picks should feel "
            f"like a continuation of this conversation."
        )
    if critique_feedback:
        prompt += (
            f"\n\n{rule}\n"
            f"FEEDBACK FROM A PREVIOUS ATTEMPT — FIX THESE ISSUES\n"
            f"{rule}\n"
            f"{critique_feedback}\n"
            f"Re-pick the top {k_final} addressing this feedback. Avoid the "
            f"problems flagged above."
        )

    try:
        result = llm.structured(
            prompt,
            schema=RerankedList,
            model="reasoning",
            system="You are an expert recommendation agent that explains every pick.",
        )
    except Exception as e:
        # Keep the traceback in the logs; the caller still gets a usable list.
        log.exception(f"Rerank LLM call failed: {e}; falling back to retrieval order")
        return [
            RerankedItem(item_id=c.item_id,
                         reasoning="Matched semantic search for this user's profile.")
            for c in candidates[:k_final]
        ]

    return _select_valid_items(result.items, candidates, k_final)
def format_response(ranked: list[RerankedItem],
                    candidates_by_id: dict[str, RetrievedItem]) -> list[Recommendation]:
    """Wrap reranked items in the final Recommendation dataclass.

    Title, domain, categories and distance come from the matching retrieved
    candidate. The reasoning comes from the reranked item. Items whose id is not
    in `candidates_by_id` are dropped. Ranks are numbered 1, 2, 3, … over the
    items actually kept, so a dropped item never leaves a gap.
    """
    out: list[Recommendation] = []
    for item in ranked:
        cand = candidates_by_id.get(item.item_id)
        if cand is None:
            continue
        out.append(Recommendation(
            rank=len(out) + 1,  # contiguous even when earlier items were skipped
            item_id=cand.item_id,
            title=cand.title,
            domain=cand.domain,
            categories=cand.categories,
            reasoning=item.reasoning,
            retrieval_distance=cand.distance,
        ))
    return out
# ──────────────────────────────────────────────────────────────────────────────
# Agent
# ──────────────────────────────────────────────────────────────────────────────
class RecommendationAgent:
    """The Task B agent.

    Usage:
        agent = RecommendationAgent()
        recs = agent.run(persona, k=10, cross_domain=False)
        # recs is list[Recommendation]

    Each `run` performs these steps:
      1. Detect the mode.
      2. Retrieve candidates, using HyDE for cold-start when enabled.
      3. LLM rerank, optionally enriched with reader reviews.
      4. Optionally run a self-reflection critique/refine pass.
    """

    def __init__(self, llm: LLMClient | None = None,
                 retriever: ItemRetriever | None = None,
                 aggregator: ReviewAggregator | None = None,
                 candidates_k: int = 50,
                 use_review_enrichment: bool = True,
                 use_hyde: bool = True,
                 use_reflection: bool = True,
                 reflection_max_iterations: int = 2):
        """Wire up collaborators; omitted clients are default-constructed.

        Args:
            llm: LLM client; a default LLMClient() is created when not given.
            retriever: catalog retriever; defaults to ItemRetriever().
            aggregator: review source. Used only when `use_review_enrichment`
                is true; defaults to ReviewAggregator() in that case.
            candidates_k: how many candidates to retrieve before reranking.
            use_review_enrichment: enrich rerank candidates with reader reviews.
            use_hyde: route cold-start retrieval through HyDE.
            use_reflection: run the self-critique/refine pass after reranking.
            reflection_max_iterations: forwarded as `max_iterations` to
                `reflect_on_recommendations`.
        """
        self.llm = llm or LLMClient()
        self.retriever = retriever or ItemRetriever()
        self.use_review_enrichment = use_review_enrichment
        if use_review_enrichment:
            self.aggregator = aggregator or ReviewAggregator()
        else:
            self.aggregator = None
        # HyDE for cold-start retrieval
        self.use_hyde = use_hyde
        self.hyde = HydeRetriever(self.llm, self.retriever) if use_hyde else None
        # Self-reflection
        self.use_reflection = use_reflection
        self.reflection_max_iterations = reflection_max_iterations
        self.candidates_k = candidates_k
        # Introspection hooks for the demo UI. They are only written here,
        # never read by the pipeline itself.
        self.last_mode: Mode | None = None
        self.last_candidate_count: int = 0
        self.last_reflection_trace: ReflectionTrace | None = None

    def run(self, persona: UserPersona, k: int = 10,
            cross_domain: bool = False,
            conversation_context: str | None = None) -> list[Recommendation]:
        """Produce up to `k` ranked, explained recommendations for `persona`.

        Args:
            persona: the user to recommend for.
            k: number of recommendations requested (rerank's `k_final`).
            cross_domain: request cross-domain mode. `detect_mode` honours it
                only when the persona has at least one review.
            conversation_context: optional multi-turn context forwarded to
                every rerank call.

        Returns:
            Recommendations, possibly refined by self-reflection. Empty when
            retrieval finds no candidates.
        """
        mode = detect_mode(persona, requested_cross_domain=cross_domain)
        log.info(f"Recommendation mode for user {persona.user_id}: {mode}")
        # Reset the introspection hooks at the start of each run.
        self.last_mode = mode
        self.last_candidate_count = 0
        self.last_reflection_trace = None

        # Steps 1+2: retrieve candidates (HyDE for cold-start if enabled)
        candidates = retrieve_candidates(
            self.retriever, persona, mode, k_candidates=self.candidates_k,
            hyde=self.hyde,
        )
        if not candidates:
            log.warning("No candidates retrieved; returning empty list")
            return []
        cand_by_id = {c.item_id: c for c in candidates}
        self.last_candidate_count = len(candidates)

        # Step 3: LLM rerank with optional review enrichment (Stage 2c)
        ranked = rerank(self.llm, persona, candidates, k_final=k, mode=mode,
                        aggregator=self.aggregator,
                        conversation_context=conversation_context)
        recs = format_response(ranked, cand_by_id)

        # Step 4: self-reflection — critique + refine (Stage 3c)
        if self.use_reflection and recs:
            recs = self._reflect(persona, mode, k, candidates, cand_by_id,
                                 recs, conversation_context)
        return recs

    def _reflect(self, persona: UserPersona, mode: Mode, k: int,
                 candidates: list[RetrievedItem],
                 cand_by_id: dict[str, RetrievedItem],
                 recs: list[Recommendation],
                 conversation_context: str | None) -> list[Recommendation]:
        """Critique `recs` and return the refined list, or `recs` unchanged.

        Reflection is an optional polish step. If it raises, or if a
        refinement leaves no usable items, the already-valid `recs` are kept
        rather than failing the run or returning nothing.
        """
        log.info("Running self-reflection on recommendations")

        def _refine(critique_issues: str) -> list[dict]:
            """Re-run rerank with the critique injected, return rec dicts."""
            refined_ranked = rerank(
                self.llm, persona, candidates, k_final=k, mode=mode,
                aggregator=self.aggregator,
                critique_feedback=critique_issues,
                conversation_context=conversation_context,
            )
            return [r.as_dict() for r in format_response(refined_ranked, cand_by_id)]

        try:
            final_dicts, trace = reflect_on_recommendations(
                self.llm, persona, [r.as_dict() for r in recs], mode,
                refine_fn=_refine,
                max_iterations=self.reflection_max_iterations,
            )
        except Exception:
            log.exception("Self-reflection failed; keeping the un-refined recommendations")
            return recs

        self.last_reflection_trace = trace  # introspection hook for the UI
        if not trace.refined:
            return recs
        # Rebuild Recommendation objects from the refined dicts
        rebuilt = _rebuild_recommendations(final_dicts, cand_by_id)
        if not rebuilt:
            log.warning("Refinement produced no usable recommendations; keeping the originals")
            return recs
        return rebuilt
def _rebuild_recommendations(rec_dicts: list[dict],
                             cand_by_id: dict[str, RetrievedItem],
                             ) -> list[Recommendation]:
    """Reconstruct Recommendation objects from dicts after a refinement pass.

    Only `item_id` and `reasoning` are read from each dict. Title, domain,
    categories and distance are taken from the retrieved candidate, so a dict
    cannot alter item metadata. Dicts whose item_id is not a known candidate
    are dropped. Ranks are numbered 1, 2, 3, … over the kept items, so a
    dropped dict never leaves a gap.
    """
    out: list[Recommendation] = []
    for d in rec_dicts:
        cand = cand_by_id.get(d.get("item_id", ""))
        if cand is None:
            continue
        out.append(Recommendation(
            rank=len(out) + 1,  # contiguous even when earlier dicts were skipped
            item_id=cand.item_id,
            title=cand.title,
            domain=cand.domain,
            categories=cand.categories,
            reasoning=d.get("reasoning", ""),
            retrieval_distance=cand.distance,
        ))
    return out