explainer-env / rewards /exploration.py
kgdrathan's picture
Upload folder using huggingface_hub
5869d56 verified
"""Reward components for the exploration phase."""
from __future__ import annotations
try:
from ..constants import MAX_EXPLORE_REWARD, clamp_action_reward
from ..research.retrieval import tokenize
from ..research.types import ResearchResult
except ImportError: # pragma: no cover - supports direct test execution
from constants import MAX_EXPLORE_REWARD, clamp_action_reward
from research.retrieval import tokenize
from research.types import ResearchResult
# Weights. Keep the visible reward compact: each component maps to a skill.
# The four weights sum to 1.0; compute_explore_reward combines them into its
# "raw" score before gating.
W_QUERY_QUALITY = 0.20  # query relevance blended with tool choice
W_EVIDENCE_QUALITY = 0.25  # usable retrieved chunks blended with source diversity
W_INFORMATION_GAIN = 0.40  # newly covered keywords blended with lexical novelty
W_EFFICIENCY = 0.15  # action novelty scaled by remaining information need
# Flat per-step penalty — the agent must expect enough gain to justify each search.
# Subtracted once per explore step in compute_explore_reward.
STEP_COST = 0.05
# ---------------------------------------------------------------------------
# Individual scorers
# ---------------------------------------------------------------------------
def query_relevance(query: str, topic: str, keywords_csv: str, intent: str = "") -> float:
    """Score how relevant and specific the search query is to the task (0-1).

    The query and the optional ``intent`` are lower-cased and searched as one
    string. Points are awarded additively:

    * 0.35 if the (non-empty) topic appears as a substring,
    * up to 0.35 for the fraction of comma-separated keywords that appear,
    * 0.15 if the combined text has at least three whitespace-separated words
      (intent words count toward this),
    * 0.15 if it mentions a facet such as "equation", "intuition", "visual",
      "example" or "code".

    Args:
        query: The search query issued by the agent.
        topic: Task topic; matched case-insensitively as a substring.
        keywords_csv: Comma-separated task keywords; blanks are ignored.
        intent: Optional free-text intent searched together with the query.

    Returns:
        A score in [0, 1]. An empty or whitespace-only ``query`` scores 0.0
        regardless of intent.
    """
    if not query or not query.strip():
        return 0.0
    query_lower = f"{query} {intent}".strip().lower()
    score = 0.0
    # "" is a substring of every string, so an empty/blank topic would
    # otherwise earn the topic bonus for free.
    topic_lower = topic.strip().lower()
    if topic_lower and topic_lower in query_lower:
        score += 0.35
    keywords = [k.strip().lower() for k in keywords_csv.split(",") if k.strip()]
    if keywords:
        hits = sum(1 for kw in keywords if kw in query_lower)
        score += 0.35 * (hits / len(keywords))
    if len(query_lower.split()) >= 3:
        score += 0.15
    if any(term in query_lower for term in ("equation", "intuition", "visual", "example", "code")):
        score += 0.15
    return min(1.0, score)
def result_novelty(new_content: str, previous_context: list[str]) -> float:
    """Return the share (0-1) of this result's distinct tokens not seen before.

    Tokens come from ``tokenize``. Empty or whitespace-only content scores
    0.0. With no prior context, everything is new and the score is 1.0.
    """
    if not (new_content and new_content.strip()):
        return 0.0
    if not previous_context:
        return 1.0
    fresh_tokens = set(tokenize(new_content))
    known_tokens: set[str] = set()
    for earlier in previous_context:
        known_tokens |= set(tokenize(earlier))
    if not fresh_tokens:
        return 0.0
    unseen_count = len(fresh_tokens.difference(known_tokens))
    return min(1.0, unseen_count / max(len(fresh_tokens), 1))
def research_breadth(accumulated_context: list[str], min_sources: int = 2) -> float:
    """Return 1.0 once at least ``min_sources`` sources exist, else the fraction reached."""
    source_count = len(accumulated_context)
    if source_count < min_sources:
        return source_count / min_sources
    return 1.0
def content_sufficiency(
    task_content: str, keywords_csv: str, accumulated_context: list[str],
) -> float:
    """Fraction of task keywords covered in task content + research (0-1).

    Args:
        task_content: Task text that counts as already-known content.
        keywords_csv: Comma-separated keywords; blank entries are ignored.
        accumulated_context: Research texts gathered so far.

    Returns:
        The covered fraction of keywords. A keyword counts as covered when its
        normalized form is a substring of the normalized combined text. When
        there are no keywords, the result is 1.0 (nothing left to cover).
    """
    keywords = [k.strip().lower() for k in keywords_csv.split(",") if k.strip()]
    if not keywords:
        return 1.0
    # Build the haystack with a single join instead of repeated `+=` in a loop.
    combined = " ".join(
        [_normalized_text(task_content)]
        + [_normalized_text(ctx) for ctx in accumulated_context]
    )
    covered = sum(1 for kw in keywords if _normalized_text(kw) in combined)
    return covered / len(keywords)
def tool_choice_score(tool: str, difficulty: str, query: str, intent: str = "") -> float:
    """Rate how well ``tool`` fits the need expressed by the query/intent and difficulty.

    Rules are checked in priority order: code/library wording prefers
    "fetch_docs", Hub wording prefers "search_hf_hub", hard or paper-style
    requests prefer scholarly search, medium difficulty prefers encyclopedic or
    scholarly search, and everything else prefers Wikipedia or docs. The
    preferred tool scores 1.0; any other tool gets the rule's fallback score.
    """
    haystack = f"{query} {intent}".lower()

    def mentions(*terms: str) -> bool:
        # Plain substring test against the lower-cased query + intent.
        return any(term in haystack for term in terms)

    if mentions("api", "code", "library", "plot", "chart", "lint", "marimo", "manim"):
        return 1.0 if tool == "fetch_docs" else 0.65
    if mentions("model", "dataset", "space", "hugging face", "hf hub"):
        return 1.0 if tool == "search_hf_hub" else 0.65
    if difficulty == "hard" or mentions("paper", "research", "recent", "citation", "state of the art"):
        scholarly = {"search_arxiv", "search_hf_papers", "search_scholar"}
        return 1.0 if tool in scholarly else 0.55
    if difficulty == "medium":
        reference = {"search_wikipedia", "search_arxiv", "search_scholar"}
        return 1.0 if tool in reference else 0.75
    general = {"search_wikipedia", "fetch_docs"}
    return 1.0 if tool in general else 0.65
def source_quality(result: ResearchResult) -> float:
    """Rate how usable the retrieved chunks are (0-1), not how trusted the source is.

    Each chunk is scored as 0.35 * metadata + 0.45 * length + 0.20 * relevance:

    * metadata: 0.35 for a title, 0.35 for a URL, 0.30 for extra metadata;
    * length: token count / 80, capped at 1.0;
    * relevance: 1.0 when the chunk score is positive, else 0.5.

    The chunk scores are averaged. An errored or empty result scores 0.0.
    """
    if result.error or not result.chunks:
        return 0.0
    per_chunk: list[float] = []
    for chunk in result.chunks:
        meta = 0.0
        for present, weight in ((chunk.title, 0.35), (chunk.url, 0.35), (chunk.metadata, 0.30)):
            if present:
                meta += weight
        length = min(1.0, len(tokenize(chunk.text)) / 80)
        relevance = 1.0 if chunk.score > 0 else 0.5
        per_chunk.append(0.35 * meta + 0.45 * length + 0.20 * relevance)
    return min(1.0, sum(per_chunk) / max(len(per_chunk), 1))
def coverage_delta(
    keywords_csv: str,
    task_content: str,
    previous_context: list[str],
    new_content: str,
) -> float:
    """Return the fraction of still-missing keywords that this result newly covers.

    Coverage means the normalized keyword is a substring of the normalized
    text. Returns 1.0 when there are no keywords, and 0.0 when every keyword
    was already covered by the task content plus earlier context.
    """
    keywords = [piece.strip().lower() for piece in keywords_csv.split(",") if piece.strip()]
    if not keywords:
        return 1.0
    baseline = _normalized_text(task_content + " " + " ".join(previous_context))
    extended = baseline + " " + _normalized_text(new_content)
    still_missing = [
        norm for norm in (_normalized_text(kw) for kw in keywords) if norm not in baseline
    ]
    if not still_missing:
        return 0.0
    gained = sum(norm in extended for norm in still_missing)
    return gained / len(still_missing)
def diversity_score(tool: str, used_tools: set[str], result: ResearchResult) -> float:
    """Reward trying a new tool (1.0) or, for a repeat tool, multiple distinct URLs.

    A repeat tool scores 0.5 when its chunks span more than one non-empty URL,
    otherwise 0.25. Errored or empty results score 0.0 so empty calls earn nothing.
    """
    has_evidence = not result.error and bool(result.chunks)
    if not has_evidence:
        return 0.0
    if tool in used_tools:
        distinct_urls = {piece.url for piece in result.chunks if piece.url}
        return 0.5 if len(distinct_urls) > 1 else 0.25
    return 1.0
def action_novelty(tool: str, query: str, intent: str, previous_actions: list[str]) -> float:
    """Return 1 minus the highest Jaccard similarity to any earlier action (floored at 0).

    The current action is reduced to its normalized "tool query intent" token
    text before comparison. With no earlier actions, the score is 1.0.
    """
    if not previous_actions:
        return 1.0
    signature = _action_text(tool, query, intent)
    closest = None
    for past in previous_actions:
        similarity = _jaccard(signature, past)
        if closest is None or similarity > closest:
            closest = similarity
    return max(0.0, 1.0 - closest)
# ---------------------------------------------------------------------------
# Gating
# ---------------------------------------------------------------------------
def _exploration_gate(sufficiency: float) -> float:
"""Multiplier based on information need.
High sufficiency → low multiplier (exploration has little value).
Low sufficiency → high multiplier (exploration has high value).
Range: [0.3, 1.0].
"""
info_need = max(0.0, 1.0 - sufficiency)
return 0.3 + 0.7 * info_need
# ---------------------------------------------------------------------------
# Main reward function
# ---------------------------------------------------------------------------
def compute_explore_reward(
    query: str,
    tool: str,
    intent: str,
    result: ResearchResult,
    topic: str,
    keywords_csv: str,
    task_content: str,
    difficulty: str,
    previous_context: list[str],
    accumulated_context: list[str],
    used_tools: set[str] | None = None,
    previous_actions: list[str] | None = None,
) -> tuple[float, dict]:
    """Compute per-step exploration reward. Returns (total, components).

    Four component scores are weighted by the W_* constants into ``raw``.
    ``raw`` is multiplied by a sufficiency gate, a small information-need
    bonus is added, and STEP_COST is subtracted:

        total = raw * gate + 0.08 * info_need - STEP_COST

    The result is passed through ``clamp_action_reward`` and then capped at
    ``MAX_EXPLORE_REWARD``.

    Args:
        query: Search query issued by the agent.
        tool: Name of the tool used (e.g. "fetch_docs", "search_arxiv").
        intent: Free-text intent; scored together with ``query``.
        result: Research result for this step. ``result.ok`` gates most
            components, and ``result.text`` is scored for gain and novelty.
        topic: Task topic, matched against query + intent.
        keywords_csv: Comma-separated task keywords.
        task_content: Task text that counts as already-known content.
        difficulty: Task difficulty label. "hard" and "medium" change the
            preferred tools in ``tool_choice_score``.
        previous_context: Research texts gathered before this step.
        accumulated_context: Research texts used for the post-step gate.
            NOTE(review): presumably ``previous_context`` plus this result —
            confirm with the caller.
        used_tools: Tools already used; ``None`` is treated as empty.
        previous_actions: Earlier action strings, compared by whitespace-token
            Jaccard similarity against this action's normalized text.
            NOTE(review): assumes they were built the same way — confirm.

    Returns:
        ``(total, components)``. ``total`` is the clamped, capped reward and
        is not rounded. ``components`` holds the four component scores
        rounded to 3 decimals, plus ``explore_total`` rounded to 4.
    """
    used_tools = used_tools or set()
    previous_actions = previous_actions or []
    result_text = result.text
    result_ok = result.ok
    # Query quality: relevance of the query (65%) blended with tool fit (35%).
    t_choice = tool_choice_score(tool, difficulty, query, intent)
    q_rel = query_relevance(query, topic, keywords_csv, intent)
    query_quality = 0.65 * q_rel + 0.35 * t_choice
    # Evidence quality: usable chunks (75%) + source diversity (25%); zero on a failed call.
    src_quality = source_quality(result) if result_ok else 0.0
    diversity = diversity_score(tool, used_tools, result) if result_ok else 0.0
    evidence_quality = 0.75 * src_quality + 0.25 * diversity
    # Information gain: newly covered keywords (70%) + lexical novelty (30%).
    delta = coverage_delta(keywords_csv, task_content, previous_context, result_text)
    novelty = result_novelty(result_text, previous_context) if result_ok else 0.0
    # The conditional expression binds loosest, so the whole weighted sum is zeroed when not ok.
    information_gain = 0.70 * delta + 0.30 * novelty if result_ok else 0.0
    act_novelty = action_novelty(tool, query, intent, previous_actions)
    sufficiency_before = content_sufficiency(task_content, keywords_csv, previous_context)
    sufficiency_after = content_sufficiency(task_content, keywords_csv, accumulated_context)
    # How much keyword coverage was still missing before this step.
    info_need = max(0.0, 1.0 - sufficiency_before)
    # Efficiency: novel actions are worth more while information is still needed (factor 0.35..1.0).
    efficiency = act_novelty * (0.35 + 0.65 * info_need) if result_ok else 0.0
    raw = (
        W_QUERY_QUALITY * query_quality
        + W_EVIDENCE_QUALITY * evidence_quality
        + W_INFORMATION_GAIN * information_gain
        + W_EFFICIENCY * efficiency
    )
    # Gate on sufficiency *after* this step: the closer coverage is to complete, the smaller the multiplier.
    gate = _exploration_gate(sufficiency_after) if result_ok else 0.0
    # NOTE(review): on a failed call gate == 0, so total == 0.08 * info_need - STEP_COST,
    # which is positive whenever info_need > 0.625. Confirm a failed call should be able
    # to earn positive reward.
    total = raw * gate + 0.08 * info_need - STEP_COST
    total = min(MAX_EXPLORE_REWARD, clamp_action_reward(total))
    components = {
        "query_quality": round(query_quality, 3),
        "evidence_quality": round(evidence_quality, 3),
        "information_gain": round(information_gain, 3),
        "efficiency": round(efficiency, 3),
        "explore_total": round(total, 4),
    }
    return total, components
def _normalized_text(text: str) -> str:
    """Return *text* rebuilt from its ``tokenize`` tokens, joined by single spaces."""
    tokens = tokenize(text)
    return " ".join(tokens)
def _action_text(tool: str, query: str, intent: str) -> str:
    """Return the normalized token text for a "tool query intent" action."""
    raw_action = f"{tool} {query} {intent}"
    action_tokens = tokenize(raw_action)
    return " ".join(action_tokens)
def _jaccard(left: str, right: str) -> float:
left_tokens = set(left.split())
right_tokens = set(right.split())
if not left_tokens or not right_tokens:
return 0.0
return len(left_tokens & right_tokens) / len(left_tokens | right_tokens)