Spaces:
Sleeping
Sleeping
| """Reward components for the exploration phase.""" | |
| from __future__ import annotations | |
| try: | |
| from ..constants import MAX_EXPLORE_REWARD, clamp_action_reward | |
| from ..research.retrieval import tokenize | |
| from ..research.types import ResearchResult | |
| except ImportError: # pragma: no cover - supports direct test execution | |
| from constants import MAX_EXPLORE_REWARD, clamp_action_reward | |
| from research.retrieval import tokenize | |
| from research.types import ResearchResult | |
# Weights of the four raw-reward components; each component maps to one skill.
# They sum to 1.0, so the raw (pre-gate) reward stays in [0, 1] whenever every
# component does.
W_QUERY_QUALITY = 0.20      # on-topic query + fitting tool choice
W_EVIDENCE_QUALITY = 0.25   # usable retrieved chunks + source diversity
W_INFORMATION_GAIN = 0.40   # newly covered keywords + novel vocabulary
W_EFFICIENCY = 0.15         # non-redundant action, scaled by information need

# Flat per-step penalty: a search must be expected to earn more than this to
# be worth issuing.
STEP_COST = 0.05
| # --------------------------------------------------------------------------- | |
| # Individual scorers | |
| # --------------------------------------------------------------------------- | |
def query_relevance(query: str, topic: str, keywords_csv: str, intent: str = "") -> float:
    """Score how relevant and specific the search query is to the task (0-1).

    The query and its stated intent are scored together as one lower-cased text:

    * +0.35 if the task topic occurs in the text,
    * +0.35 scaled by the fraction of task keywords that occur in the text,
    * +0.15 if the text has at least three words,
    * +0.15 if it asks for a specific kind of material
      (equation, intuition, visual, example, code).

    Args:
        query: The search query issued by the agent.
        topic: The task topic. A blank topic earns no topic bonus.
        keywords_csv: Comma-separated task keywords.
        intent: Optional free-text statement of what the agent is looking for.

    Returns:
        A score in [0, 1]; 0.0 for a blank query.
    """
    import re  # local import: the module header does not import ``re``

    if not query or not query.strip():
        return 0.0
    text = f"{query} {intent}".strip().lower()
    score = 0.0
    # "" is a substring of every string, so without this guard a blank topic
    # would hand out the topic bonus unconditionally.
    topic_norm = topic.strip().lower()
    if topic_norm and topic_norm in text:
        score += 0.35
    keywords = [k.strip().lower() for k in keywords_csv.split(",") if k.strip()]
    if keywords:
        hits = sum(1 for kw in keywords if kw in text)
        score += 0.35 * (hits / len(keywords))
    if len(text.split()) >= 3:
        score += 0.15
    # Match specificity cues only at the start of a word so that e.g.
    # "encoder"/"decoder" do not count as asking for "code"; plural or
    # suffixed forms ("examples", "visualization") still match.
    specificity_terms = ("equation", "intuition", "visual", "example", "code")
    if any(re.search(r"\b" + re.escape(term), text) for term in specificity_terms):
        score += 0.15
    return min(1.0, score)
def result_novelty(new_content: str, previous_context: list[str]) -> float:
    """Score how much new vocabulary this result adds (0-1).

    Novelty is the fraction of distinct ``tokenize`` tokens in ``new_content``
    that do not appear in any of the earlier context strings.

    Args:
        new_content: Text of the newly retrieved result.
        previous_context: Context strings gathered before this result.

    Returns:
        0.0 for blank content or content that yields no tokens; 1.0 when there
        is no earlier context; otherwise the novel-token fraction.
    """
    if not new_content or not new_content.strip():
        return 0.0
    new_words = set(tokenize(new_content))
    # Check for an empty token set before the no-context shortcut, so content
    # that tokenizes to nothing scores 0.0 on the first search too (it
    # previously scored 1.0 there, but 0.0 once any context existed).
    if not new_words:
        return 0.0
    if not previous_context:
        return 1.0
    seen_words: set[str] = set()
    for ctx in previous_context:
        seen_words.update(tokenize(ctx))
    # new_words - seen_words is a subset of new_words, so the ratio is <= 1.
    return len(new_words - seen_words) / len(new_words)
def research_breadth(accumulated_context: list[str], min_sources: int = 2) -> float:
    """Return the share of the required source count gathered so far, capped at 1.0.

    Args:
        accumulated_context: One entry per gathered source.
        min_sources: Number of sources that counts as enough.

    Returns:
        ``len(accumulated_context) / min_sources`` while below the target,
        otherwise 1.0.
    """
    gathered = len(accumulated_context)
    if gathered < min_sources:
        return gathered / min_sources
    return 1.0
def content_sufficiency(
    task_content: str, keywords_csv: str, accumulated_context: list[str],
) -> float:
    """Fraction of task keywords covered in task content + research (0-1).

    The task content and every research text are normalized with
    ``_normalized_text`` and joined into one string. A keyword counts as
    covered when its normalized form occurs as a substring of that string.

    Args:
        task_content: The task's own text.
        keywords_csv: Comma-separated task keywords.
        accumulated_context: Research texts gathered so far.

    Returns:
        Covered fraction in [0, 1]; 1.0 when there are no keywords.
    """
    keywords = [k.strip().lower() for k in keywords_csv.split(",") if k.strip()]
    if not keywords:
        return 1.0
    # Build the haystack with one join instead of repeated ``+=``, which
    # copies the growing string on every iteration.
    combined = " ".join(
        _normalized_text(text) for text in (task_content, *accumulated_context)
    )
    covered = sum(1 for kw in keywords if _normalized_text(kw) in combined)
    return covered / len(keywords)
def tool_choice_score(tool: str, difficulty: str, query: str, intent: str = "") -> float:
    """Reward selecting a tool that fits the task/source need (0-1).

    The lower-cased query + intent text is classified by cue words. The first
    matching rule decides:

    1. code/library cues       -> ``fetch_docs`` scores 1.0, others 0.65
    2. Hugging Face Hub cues   -> ``search_hf_hub`` scores 1.0, others 0.65
    3. "hard" task or paper cues -> arxiv/hf_papers/scholar score 1.0, others 0.55
    4. "medium" task           -> wikipedia/arxiv/scholar score 1.0, others 0.75
    5. otherwise               -> wikipedia/fetch_docs score 1.0, others 0.65

    Args:
        tool: Name of the tool the agent chose.
        difficulty: Task difficulty label; only "hard" and "medium" are special-cased.
        query: The search query.
        intent: Optional statement of what the agent is looking for.

    Returns:
        1.0 for a well-matched tool, otherwise the applicable fallback score.
    """
    import re  # local import: the module header does not import ``re``

    text = f"{query} {intent}".lower()

    def mentions(*terms: str) -> bool:
        # Cue words must start at a word boundary. Plain substring search made
        # "encoder"/"decoder" look like "code", "rapid" like "api" and
        # "namespace" like "space". Plural/suffixed forms ("models",
        # "papers", "plotting") still match.
        return any(re.search(r"\b" + re.escape(term), text) for term in terms)

    codeish = mentions("api", "code", "library", "plot", "chart", "lint", "marimo", "manim")
    paperish = mentions("paper", "research", "recent", "citation", "state of the art")
    hubish = mentions("model", "dataset", "space", "hugging face", "hf hub")
    if codeish:
        return 1.0 if tool == "fetch_docs" else 0.65
    if hubish:
        return 1.0 if tool == "search_hf_hub" else 0.65
    if difficulty == "hard" or paperish:
        return 1.0 if tool in {"search_arxiv", "search_hf_papers", "search_scholar"} else 0.55
    if difficulty == "medium":
        return 1.0 if tool in {"search_wikipedia", "search_arxiv", "search_scholar"} else 0.75
    return 1.0 if tool in {"search_wikipedia", "fetch_docs"} else 0.65
def source_quality(result: ResearchResult) -> float:
    """Score how usable the retrieved chunks are (0-1), not how trusted the source is.

    Each chunk gets a weighted mix of:

    * metadata completeness (weight 0.35): title 0.35 + url 0.35 + metadata 0.30;
    * content length (weight 0.45): token count / 80, capped at 1.0;
    * retrieval relevance (weight 0.20): 1.0 if ``chunk.score > 0``, else 0.5.

    The result scores the mean over its chunks. An errored result, or one
    without chunks, scores 0.0.
    """
    if result.error or not result.chunks:
        return 0.0

    def chunk_score(chunk) -> float:
        # Partial credit for each piece of descriptive metadata present.
        meta = (
            (0.35 if chunk.title else 0.0)
            + (0.35 if chunk.url else 0.0)
            + (0.30 if chunk.metadata else 0.0)
        )
        length = min(1.0, len(tokenize(chunk.text)) / 80)
        relevance = 1.0 if chunk.score > 0 else 0.5
        return 0.35 * meta + 0.45 * length + 0.20 * relevance

    per_chunk = [chunk_score(chunk) for chunk in result.chunks]
    return min(1.0, sum(per_chunk) / max(len(per_chunk), 1))
def coverage_delta(
    keywords_csv: str,
    task_content: str,
    previous_context: list[str],
    new_content: str,
) -> float:
    """Score the share of still-missing keywords that this result covers (0-1).

    Keywords whose normalized form already occurs in the normalized task
    content plus previous context are ignored. Of the remaining keywords, the
    score is the fraction that occur once the normalized ``new_content`` is
    appended.

    Args:
        keywords_csv: Comma-separated task keywords.
        task_content: The task's own text.
        previous_context: Context strings gathered before this result.
        new_content: Text of the newly retrieved result.

    Returns:
        1.0 when there are no keywords, 0.0 when every keyword was already
        covered, otherwise newly covered / previously missing.
    """
    keywords = [k.strip().lower() for k in keywords_csv.split(",") if k.strip()]
    if not keywords:
        return 1.0
    baseline = _normalized_text(task_content + " " + " ".join(previous_context))
    extended = baseline + " " + _normalized_text(new_content)
    # Normalize each keyword once and keep only those not yet covered.
    still_missing = [
        norm for norm in map(_normalized_text, keywords) if norm not in baseline
    ]
    if not still_missing:
        return 0.0
    gained = sum(norm in extended for norm in still_missing)
    return gained / len(still_missing)
def diversity_score(tool: str, used_tools: set[str], result: ResearchResult) -> float:
    """Reward useful source diversity without rewarding empty calls.

    Returns:
        0.0 for an errored result or one without chunks; 1.0 if ``tool`` is not
        in ``used_tools``; otherwise 0.5 when the chunks span at least two
        distinct non-empty URLs, else 0.25.
    """
    if result.error or not result.chunks:
        return 0.0
    if tool in used_tools:
        distinct_urls = set()
        for chunk in result.chunks:
            if chunk.url:
                distinct_urls.add(chunk.url)
        return 0.5 if len(distinct_urls) >= 2 else 0.25
    return 1.0
def action_novelty(tool: str, query: str, intent: str, previous_actions: list[str]) -> float:
    """Score whether this explore action asks for genuinely new information (0-1).

    The action's normalized text (tool + query + intent) is compared with each
    previous action string by Jaccard overlap of whitespace-separated tokens.
    The score is one minus the largest overlap, floored at 0.0; it is 1.0 when
    there are no previous actions.
    """
    if not previous_actions:
        return 1.0
    this_action = _action_text(tool, query, intent)
    similarities = [_jaccard(this_action, earlier) for earlier in previous_actions]
    return max(0.0, 1.0 - max(similarities))
| # --------------------------------------------------------------------------- | |
| # Gating | |
| # --------------------------------------------------------------------------- | |
| def _exploration_gate(sufficiency: float) -> float: | |
| """Multiplier based on information need. | |
| High sufficiency → low multiplier (exploration has little value). | |
| Low sufficiency → high multiplier (exploration has high value). | |
| Range: [0.3, 1.0]. | |
| """ | |
| info_need = max(0.0, 1.0 - sufficiency) | |
| return 0.3 + 0.7 * info_need | |
| # --------------------------------------------------------------------------- | |
| # Main reward function | |
| # --------------------------------------------------------------------------- | |
def compute_explore_reward(
    query: str,
    tool: str,
    intent: str,
    result: ResearchResult,
    topic: str,
    keywords_csv: str,
    task_content: str,
    difficulty: str,
    previous_context: list[str],
    accumulated_context: list[str],
    used_tools: set[str] | None = None,
    previous_actions: list[str] | None = None,
) -> tuple[float, dict]:
    """Compute the per-step exploration reward.

    ``total = raw * gate + need_bonus - STEP_COST``, where:

    * ``raw`` is the weighted sum (``W_*`` constants) of query quality,
      evidence quality, information gain and efficiency;
    * ``gate`` is ``_exploration_gate`` of the sufficiency after this step,
      or 0.0 when the call did not succeed;
    * ``need_bonus`` is ``0.08 * info_need`` (info_need = 1 - sufficiency
      before this step, floored at 0), paid only when the call succeeded.

    The total is then passed through ``clamp_action_reward`` and capped at
    ``MAX_EXPLORE_REWARD``.

    Args:
        query: Search query issued by the agent.
        tool: Name of the tool the agent chose.
        intent: The agent's stated purpose for the search.
        result: Tool result; ``.ok`` and ``.text`` are read here, and helpers
            read ``.error`` and ``.chunks``.
        topic: Task topic.
        keywords_csv: Comma-separated task keywords.
        task_content: The task's own text.
        difficulty: Task difficulty label.
        previous_context: Context texts gathered before this step.
        accumulated_context: Context texts used for the after-step sufficiency
            (presumably ``previous_context`` plus this result; confirm with caller).
        used_tools: Tools used in earlier steps; ``None`` means none.
        previous_actions: Action strings of earlier steps, compared by token
            overlap; ``None`` means none.

    Returns:
        ``(total, components)``. ``components`` maps each component name and
        ``"explore_total"`` to its rounded value.
    """
    used_tools = used_tools or set()
    previous_actions = previous_actions or []
    result_text = result.text
    result_ok = result.ok

    # Query quality: is the query on-topic, and was a fitting tool chosen?
    t_choice = tool_choice_score(tool, difficulty, query, intent)
    q_rel = query_relevance(query, topic, keywords_csv, intent)
    query_quality = 0.65 * q_rel + 0.35 * t_choice

    # Evidence quality: usable chunks plus source diversity.
    src_quality = source_quality(result) if result_ok else 0.0
    diversity = diversity_score(tool, used_tools, result) if result_ok else 0.0
    evidence_quality = 0.75 * src_quality + 0.25 * diversity

    # Information gain: newly covered keywords plus novel vocabulary.
    if result_ok:
        delta = coverage_delta(keywords_csv, task_content, previous_context, result_text)
        novelty = result_novelty(result_text, previous_context)
        information_gain = 0.70 * delta + 0.30 * novelty
    else:
        information_gain = 0.0

    # Efficiency: a non-redundant action is worth more while information is lacking.
    act_novelty = action_novelty(tool, query, intent, previous_actions)
    sufficiency_before = content_sufficiency(task_content, keywords_csv, previous_context)
    sufficiency_after = content_sufficiency(task_content, keywords_csv, accumulated_context)
    info_need = max(0.0, 1.0 - sufficiency_before)
    efficiency = act_novelty * (0.35 + 0.65 * info_need) if result_ok else 0.0

    raw = (
        W_QUERY_QUALITY * query_quality
        + W_EVIDENCE_QUALITY * evidence_quality
        + W_INFORMATION_GAIN * information_gain
        + W_EFFICIENCY * efficiency
    )
    if result_ok:
        gate = _exploration_gate(sufficiency_after)
        need_bonus = 0.08 * info_need
    else:
        # A failed or empty call earns nothing beyond the step cost. Previously
        # the need bonus was paid regardless of outcome, so whenever
        # info_need > 0.625 a failing call netted a positive reward. A failed
        # call presumably adds no context, so that bonus could be collected
        # again and again.
        gate = 0.0
        need_bonus = 0.0
    total = raw * gate + need_bonus - STEP_COST
    total = min(MAX_EXPLORE_REWARD, clamp_action_reward(total))
    components = {
        "query_quality": round(query_quality, 3),
        "evidence_quality": round(evidence_quality, 3),
        "information_gain": round(information_gain, 3),
        "efficiency": round(efficiency, 3),
        "explore_total": round(total, 4),
    }
    return total, components
def _normalized_text(text: str) -> str:
    """Return the ``tokenize`` tokens of *text* joined by single spaces."""
    tokens = tokenize(text)
    return " ".join(tokens)
def _action_text(tool: str, query: str, intent: str) -> str:
    """Return the normalized token string describing one explore action."""
    return _normalized_text(f"{tool} {query} {intent}")
| def _jaccard(left: str, right: str) -> float: | |
| left_tokens = set(left.split()) | |
| right_tokens = set(right.split()) | |
| if not left_tokens or not right_tokens: | |
| return 0.0 | |
| return len(left_tokens & right_tokens) / len(left_tokens | right_tokens) | |