import type { ExplainerObservation, RewardEntry } from "./types";

/**
 * Reward components keyed by name. Values are numbers whenever they could be
 * read as finite numbers; the `key=value` feedback format may also yield the
 * raw string for values that are not numeric.
 *
 * NOTE(review): the value type is reconstructed from how the functions below
 * read and write components — confirm against callers.
 */
type RewardComponents = Record<string, number | string>;

/** Metadata keys that are skipped when collecting reward components. */
const NON_REWARD_KEYS: ReadonlySet<string> = new Set([
  "step",
  "phase",
  "tool",
  "source_count",
  "error",
  "explore_steps_used",
  "repair_steps_used",
  "sandbox_message",
  "error_codes",
]);

/** Component names surfaced for each phase, in display order. */
const VISIBLE: Record<string, readonly string[]> = {
  explore: ["query_quality", "evidence_quality", "information_gain", "efficiency", "explore_total"],
  generate: ["validity", "task_alignment", "structure", "research_usage", "generate_total"],
  repair: ["repair_success", "fixed_prior_errors", "changed_code", "repair_total"],
};

/** Name of the component holding the reported total for each phase. */
const TOTAL_KEYS: Record<string, string> = {
  explore: "explore_total",
  generate: "generate_total",
  repair: "repair_total",
};

/** Per-phase weights used to derive a total when none was reported. */
const FALLBACK_WEIGHTS: Record<string, Record<string, number>> = {
  explore: {
    query_quality: 0.2,
    evidence_quality: 0.25,
    information_gain: 0.4,
    efficiency: 0.15,
  },
  generate: {
    validity: 0.15,
    task_alignment: 0.3,
    structure: 0.3,
    research_usage: 0.25,
  },
  repair: {
    repair_success: 0.6,
    fixed_prior_errors: 0.2,
    changed_code: 0.2,
  },
};

/** Matches `Reward: { ... }` where the closing brace ends a line or the text. */
const REWARD_OBJECT_RE = /Reward:\s*(\{[\s\S]*?\})(?:\n|$)/;

/**
 * Matches one `key: number` pair inside the braces. Keys may be quoted with
 * either quote style; numbers may carry a fraction and an exponent (`1e-05`).
 */
const REWARD_PAIR_RE =
  /['"]?([A-Za-z0-9_]+)['"]?\s*:\s*(-?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?)/g;

/** Matches `Reward: <rest of line>` for the `key=value, key=value` format. */
const REWARD_LINE_RE = /Reward:\s*([^\n]+)/;

/** Narrows `value` to a number that is neither `NaN` nor infinite. */
function isFiniteNumber(value: unknown): value is number {
  return typeof value === "number" && Number.isFinite(value);
}

/**
 * Looks `key` up among the table's own properties only, so that names such as
 * `"constructor"` or `"toString"` never resolve to `Object.prototype` members.
 */
function ownEntry<T>(table: Readonly<Record<string, T>>, key: string): T | undefined {
  return Object.prototype.hasOwnProperty.call(table, key) ? table[key] : undefined;
}

/** Clamps `value` into the closed interval [0, 1]. */
function clamp01(value: number): number {
  return Math.max(0, Math.min(1, value));
}

/**
 * Extracts the reward components to show for an observation.
 *
 * Finite numeric metadata entries (minus the bookkeeping keys) are collected
 * first. When the phase has a visibility list and at least one listed
 * component is present, only those are returned. Otherwise every collected
 * component is returned, and if there are none the `Reward:` section of the
 * feedback text is parsed instead.
 *
 * @param obs - Observation whose `metadata`, `phase` and `feedback` are read.
 * @returns The selected components; an empty object when nothing was found.
 */
export function rewardComponents(obs: ExplainerObservation): RewardComponents {
  const meta = obs.metadata ?? {};

  const all: Record<string, number> = {};
  for (const [key, value] of Object.entries(meta)) {
    if (NON_REWARD_KEYS.has(key)) continue;
    if (isFiniteNumber(value)) all[key] = value;
  }

  // The phase recorded in metadata wins over the observation's own phase, but
  // only when it really is a non-empty string.
  const metaPhase: unknown = meta.phase;
  const phase = typeof metaPhase === "string" && metaPhase ? metaPhase : obs.phase;
  const allowed = typeof phase === "string" ? ownEntry(VISIBLE, phase) : undefined;

  if (allowed) {
    const visible: Record<string, number> = {};
    for (const name of allowed) {
      const value = ownEntry(all, name);
      if (value !== undefined) visible[name] = value;
    }
    if (Object.keys(visible).length > 0) return visible;
  }

  return Object.keys(all).length > 0 ? all : parseRewardComponents(obs.feedback);
}

/**
 * Parses reward components out of free-form feedback text.
 *
 * Two formats are tried in order: a brace-delimited mapping
 * (`Reward: {'a': 0.5, "b": 1e-3}`) and a comma-separated list
 * (`Reward: a=0.5, b=high`).
 *
 * @param feedback - Feedback text; missing or empty text yields no components.
 * @returns Parsed components, or an empty object when nothing matched.
 */
function parseRewardComponents(feedback: string | null | undefined): RewardComponents {
  if (!feedback) return {};

  const body = REWARD_OBJECT_RE.exec(feedback)?.[1];
  if (body) {
    const parsed: RewardComponents = {};
    for (const [, key, raw] of body.matchAll(REWARD_PAIR_RE)) {
      if (key === undefined || raw === undefined) continue;
      parsed[key] = Number(raw);
    }
    if (Object.keys(parsed).length > 0) return parsed;
  }

  const line = REWARD_LINE_RE.exec(feedback)?.[1];
  if (!line) return {};

  const parsed: RewardComponents = {};
  for (const part of line.split(",")) {
    const [rawKey, rawValue] = part.split("=");
    if (rawKey === undefined || rawValue === undefined) continue;
    const key = rawKey.trim();
    if (!key) continue;
    const value = rawValue.trim();
    const numeric = Number(value);
    // Number("") is 0, so an empty value must not be reported as a zero reward.
    parsed[key] = value !== "" && Number.isFinite(numeric) ? numeric : value;
  }
  return parsed;
}

/**
 * Returns the total reward for a phase.
 *
 * The phase's reported `*_total` component is used when it is a finite
 * number; otherwise the total is derived from the weighted components.
 *
 * @param phase - Phase name (`"explore"`, `"generate"` or `"repair"`).
 * @param components - Reward components for that phase.
 * @returns The total, or `null` when it can be neither read nor derived.
 */
export function totalForPhase(
  phase: string,
  components: Record<string, unknown>,
): number | null {
  const totalKey = ownEntry(TOTAL_KEYS, phase);
  if (totalKey !== undefined) {
    const reported = components[totalKey];
    if (isFiniteNumber(reported)) return reported;
  }
  return weightedFallbackTotal(phase, components);
}

/**
 * Derives a phase total as the weighted mean of whichever weighted components
 * are present as finite numbers, clamped to [0, 1].
 *
 * @param phase - Phase name used to select the weight table.
 * @param components - Reward components for that phase.
 * @returns The clamped weighted mean, or `null` when the phase has no weight
 *   table or none of its weighted components are present.
 */
function weightedFallbackTotal(
  phase: string,
  components: Record<string, unknown>,
): number | null {
  const weights = ownEntry(FALLBACK_WEIGHTS, phase);
  if (!weights) return null;

  let weightedSum = 0;
  let usedWeight = 0;
  for (const [key, weight] of Object.entries(weights)) {
    const value = components[key];
    if (!isFiniteNumber(value)) continue;
    weightedSum += value * weight;
    usedWeight += weight;
  }

  // Renormalise by the weight actually used so missing components do not
  // drag the total towards zero.
  if (usedWeight === 0) return null;
  return clamp01(weightedSum / usedWeight);
}

/** Score threshold exported for callers deciding whether an episode succeeded. */
export const SUCCESS_SCORE_THRESHOLD = 0.3;

const MAX_EXPLORE_STEPS = 6;
const MAX_EXPLORE_REWARD = 1.0;
const MAX_GENERATE_REWARD = 1.0;

/**
 * Normalises the summed reward of an episode into [0, 1].
 *
 * Entries whose `total` is not a finite number are ignored. The sum is divided
 * by `MAX_EXPLORE_STEPS * MAX_EXPLORE_REWARD + MAX_GENERATE_REWARD` and then
 * clamped.
 *
 * NOTE(review): every entry's total is summed, but the denominator budgets
 * only explore and generate rewards — confirm that is intended for repair.
 *
 * @param rewards - Reward entries recorded during the episode.
 * @returns The clamped, normalised score.
 */
export function normalizedEpisodeScore(rewards: RewardEntry[]): number {
  let totalReward = 0;
  for (const entry of rewards) {
    const total: unknown = entry.total;
    if (isFiniteNumber(total)) totalReward += total;
  }
  const maxPossible = MAX_EXPLORE_STEPS * MAX_EXPLORE_REWARD + MAX_GENERATE_REWARD;
  return clamp01(totalReward / maxPossible);
}