"""Reward components for the exploration phase.""" from __future__ import annotations try: from ..constants import MAX_EXPLORE_REWARD, clamp_action_reward from ..research.retrieval import tokenize from ..research.types import ResearchResult except ImportError: # pragma: no cover - supports direct test execution from constants import MAX_EXPLORE_REWARD, clamp_action_reward from research.retrieval import tokenize from research.types import ResearchResult # Weights. Keep the visible reward compact: each component maps to a skill. W_QUERY_QUALITY = 0.20 W_EVIDENCE_QUALITY = 0.25 W_INFORMATION_GAIN = 0.40 W_EFFICIENCY = 0.15 # Flat per-step penalty — the agent must expect enough gain to justify each search STEP_COST = 0.05 # --------------------------------------------------------------------------- # Individual scorers # --------------------------------------------------------------------------- def query_relevance(query: str, topic: str, keywords_csv: str, intent: str = "") -> float: """Score how relevant and specific the search query is to the task (0-1).""" if not query or not query.strip(): return 0.0 query_lower = f"{query} {intent}".strip().lower() score = 0.0 if topic.lower() in query_lower: score += 0.35 keywords = [k.strip().lower() for k in keywords_csv.split(",") if k.strip()] if keywords: hits = sum(1 for kw in keywords if kw in query_lower) score += 0.35 * (hits / len(keywords)) if len(query_lower.split()) >= 3: score += 0.15 if any(term in query_lower for term in ("equation", "intuition", "visual", "example", "code")): score += 0.15 return min(1.0, score) def result_novelty(new_content: str, previous_context: list[str]) -> float: """Score how much new information this result adds (0-1).""" if not new_content or not new_content.strip(): return 0.0 if not previous_context: return 1.0 new_words = set(tokenize(new_content)) seen_words: set[str] = set() for ctx in previous_context: seen_words.update(tokenize(ctx)) if not new_words: return 0.0 novel = new_words - seen_words return min(1.0, len(novel) / max(len(new_words), 1)) def research_breadth(accumulated_context: list[str], min_sources: int = 2) -> float: """Score whether the agent gathered enough sources (0-1).""" n = len(accumulated_context) return 1.0 if n >= min_sources else n / min_sources def content_sufficiency( task_content: str, keywords_csv: str, accumulated_context: list[str], ) -> float: """Fraction of task keywords covered in task content + research (0-1).""" keywords = [k.strip().lower() for k in keywords_csv.split(",") if k.strip()] if not keywords: return 1.0 combined = _normalized_text(task_content) for ctx in accumulated_context: combined += " " + _normalized_text(ctx) return sum(1 for kw in keywords if _normalized_text(kw) in combined) / len(keywords) def tool_choice_score(tool: str, difficulty: str, query: str, intent: str = "") -> float: """Reward selecting a tool that fits the task/source need.""" text = f"{query} {intent}".lower() codeish = any(term in text for term in ("api", "code", "library", "plot", "chart", "lint", "marimo", "manim")) paperish = any(term in text for term in ("paper", "research", "recent", "citation", "state of the art")) hubish = any(term in text for term in ("model", "dataset", "space", "hugging face", "hf hub")) if codeish: return 1.0 if tool == "fetch_docs" else 0.65 if hubish: return 1.0 if tool == "search_hf_hub" else 0.65 if difficulty == "hard" or paperish: return 1.0 if tool in {"search_arxiv", "search_hf_papers", "search_scholar"} else 0.55 if difficulty == "medium": return 1.0 if tool in {"search_wikipedia", "search_arxiv", "search_scholar"} else 0.75 return 1.0 if tool in {"search_wikipedia", "fetch_docs"} else 0.65 def source_quality(result: ResearchResult) -> float: """Measure whether retrieved chunks are usable, not whether the source is trusted.""" if result.error or not result.chunks: return 0.0 scores = [] for chunk in result.chunks: metadata = 0.0 if chunk.title: metadata += 0.35 if chunk.url: metadata += 0.35 if chunk.metadata: metadata += 0.30 word_count = len(tokenize(chunk.text)) content = min(1.0, word_count / 80) relevance = 1.0 if chunk.score > 0 else 0.5 scores.append(0.35 * metadata + 0.45 * content + 0.20 * relevance) return min(1.0, sum(scores) / max(len(scores), 1)) def coverage_delta( keywords_csv: str, task_content: str, previous_context: list[str], new_content: str, ) -> float: """Score newly covered keywords/concepts added by this result.""" keywords = [k.strip().lower() for k in keywords_csv.split(",") if k.strip()] if not keywords: return 1.0 before = _normalized_text(task_content + " " + " ".join(previous_context)) after = before + " " + _normalized_text(new_content) missing_before = [kw for kw in keywords if _normalized_text(kw) not in before] if not missing_before: return 0.0 newly_covered = sum(1 for kw in missing_before if _normalized_text(kw) in after) return newly_covered / len(missing_before) def diversity_score(tool: str, used_tools: set[str], result: ResearchResult) -> float: """Reward useful source diversity without rewarding empty calls.""" if result.error or not result.chunks: return 0.0 if tool not in used_tools: return 1.0 unique_urls = {chunk.url for chunk in result.chunks if chunk.url} return 0.5 if len(unique_urls) > 1 else 0.25 def action_novelty(tool: str, query: str, intent: str, previous_actions: list[str]) -> float: """Score whether this explore action asks for genuinely new information.""" if not previous_actions: return 1.0 current = _action_text(tool, query, intent) max_similarity = max(_jaccard(current, previous) for previous in previous_actions) return max(0.0, 1.0 - max_similarity) # --------------------------------------------------------------------------- # Gating # --------------------------------------------------------------------------- def _exploration_gate(sufficiency: float) -> float: """Multiplier based on information need. High sufficiency → low multiplier (exploration has little value). Low sufficiency → high multiplier (exploration has high value). Range: [0.3, 1.0]. """ info_need = max(0.0, 1.0 - sufficiency) return 0.3 + 0.7 * info_need # --------------------------------------------------------------------------- # Main reward function # --------------------------------------------------------------------------- def compute_explore_reward( query: str, tool: str, intent: str, result: ResearchResult, topic: str, keywords_csv: str, task_content: str, difficulty: str, previous_context: list[str], accumulated_context: list[str], used_tools: set[str] | None = None, previous_actions: list[str] | None = None, ) -> tuple[float, dict]: """Compute per-step exploration reward. Returns (total, components).""" used_tools = used_tools or set() previous_actions = previous_actions or [] result_text = result.text result_ok = result.ok t_choice = tool_choice_score(tool, difficulty, query, intent) q_rel = query_relevance(query, topic, keywords_csv, intent) query_quality = 0.65 * q_rel + 0.35 * t_choice src_quality = source_quality(result) if result_ok else 0.0 diversity = diversity_score(tool, used_tools, result) if result_ok else 0.0 evidence_quality = 0.75 * src_quality + 0.25 * diversity delta = coverage_delta(keywords_csv, task_content, previous_context, result_text) novelty = result_novelty(result_text, previous_context) if result_ok else 0.0 information_gain = 0.70 * delta + 0.30 * novelty if result_ok else 0.0 act_novelty = action_novelty(tool, query, intent, previous_actions) sufficiency_before = content_sufficiency(task_content, keywords_csv, previous_context) sufficiency_after = content_sufficiency(task_content, keywords_csv, accumulated_context) info_need = max(0.0, 1.0 - sufficiency_before) efficiency = act_novelty * (0.35 + 0.65 * info_need) if result_ok else 0.0 raw = ( W_QUERY_QUALITY * query_quality + W_EVIDENCE_QUALITY * evidence_quality + W_INFORMATION_GAIN * information_gain + W_EFFICIENCY * efficiency ) gate = _exploration_gate(sufficiency_after) if result_ok else 0.0 total = raw * gate + 0.08 * info_need - STEP_COST total = min(MAX_EXPLORE_REWARD, clamp_action_reward(total)) components = { "query_quality": round(query_quality, 3), "evidence_quality": round(evidence_quality, 3), "information_gain": round(information_gain, 3), "efficiency": round(efficiency, 3), "explore_total": round(total, 4), } return total, components def _normalized_text(text: str) -> str: return " ".join(tokenize(text)) def _action_text(tool: str, query: str, intent: str) -> str: return " ".join(tokenize(f"{tool} {query} {intent}")) def _jaccard(left: str, right: str) -> float: left_tokens = set(left.split()) right_tokens = set(right.split()) if not left_tokens or not right_tokens: return 0.0 return len(left_tokens & right_tokens) / len(left_tokens | right_tokens)