Spaces:
Sleeping
Sleeping
File size: 4,132 Bytes
1b83e76 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 | import type { ExplainerObservation, RewardEntry } from "./types";
/**
 * Metadata keys that `rewardComponents` never treats as reward components,
 * even when their value is numeric.
 */
const NON_REWARD_KEYS = new Set<string>([
  "step", "phase", "tool", "source_count", "error",
  "explore_steps_used", "repair_steps_used", "sandbox_message", "error_codes",
]);
/**
 * Per-phase whitelist of reward components surfaced by `rewardComponents`.
 * List order is the order in which keys appear in its result; each list ends
 * with that phase's `<phase>_total` key.
 */
const VISIBLE: Record<string, string[]> = {
  explore: [
    "query_quality",
    "evidence_quality",
    "information_gain",
    "efficiency",
    "explore_total",
  ],
  generate: [
    "validity",
    "task_alignment",
    "structure",
    "research_usage",
    "generate_total",
  ],
  repair: [
    "repair_success",
    "fixed_prior_errors",
    "changed_code",
    "repair_total",
  ],
};
/**
 * Per-phase component weights that `weightedFallbackTotal` uses when no
 * explicit `<phase>_total` is available. Each phase's weights sum to 1.0.
 * Entry order is significant: it is the order in which weighted terms are summed.
 */
const FALLBACK_WEIGHTS: Record<string, Record<string, number>> = {
  explore: { query_quality: 0.2, evidence_quality: 0.25, information_gain: 0.4, efficiency: 0.15 },
  generate: { validity: 0.15, task_alignment: 0.3, structure: 0.3, research_usage: 0.25 },
  repair: { repair_success: 0.6, fixed_prior_errors: 0.2, changed_code: 0.2 },
};
/**
 * Extracts the reward components to display for an observation.
 *
 * 1. Collects every numeric, non-NaN `metadata` entry whose key is not in
 *    `NON_REWARD_KEYS`.
 * 2. Resolves the phase: a non-empty string `metadata.phase` wins, otherwise
 *    `obs.phase` is used.
 * 3. If that phase has a `VISIBLE` whitelist and at least one whitelisted key
 *    was collected, returns only those keys (in whitelist order).
 * 4. Otherwise returns all collected entries or, when there are none, the
 *    components parsed from `obs.feedback`.
 *
 * @param obs - Observation whose `metadata`, `phase` and `feedback` are read.
 * @returns Component name to value; values parsed from feedback may be strings.
 */
export function rewardComponents(obs: ExplainerObservation): Record<string, number | string> {
  const meta = obs.metadata || {};
  const all: Record<string, number> = {};
  for (const [k, v] of Object.entries(meta)) {
    if (NON_REWARD_KEYS.has(k)) continue;
    if (typeof v === "number" && !Number.isNaN(v)) all[k] = v;
  }
  // Narrow instead of casting: a non-string metadata.phase must fall back to
  // obs.phase rather than be used as a lookup key.
  const metaPhase = meta.phase;
  const phase = typeof metaPhase === "string" && metaPhase !== "" ? metaPhase : obs.phase;
  // Own-property lookup so phases like "constructor"/"toString" don't resolve
  // to an inherited Object.prototype function (which is not iterable).
  const allowed = Object.prototype.hasOwnProperty.call(VISIBLE, phase) ? VISIBLE[phase] : undefined;
  if (allowed) {
    const visible: Record<string, number> = {};
    for (const k of allowed) {
      const value = all[k];
      if (value !== undefined) visible[k] = value;
    }
    if (Object.keys(visible).length) return visible;
  }
  return Object.keys(all).length ? all : parseRewardComponents(obs.feedback);
}
/**
 * Fallback parser that recovers reward components from free-text `feedback`.
 *
 * Two formats are recognised after a `Reward:` label:
 *  1. A dict-like literal such as `Reward: {'validity': 0.5, 'structure': 1e-05}`.
 *     Every `key: number` pair inside the braces becomes a numeric entry.
 *  2. A single line of comma-separated `key=value` pairs, e.g.
 *     `Reward: validity=0.5, status=ok`. Finite numeric values are stored as
 *     numbers; anything else is stored as the trimmed string.
 *
 * @param feedback - Raw feedback text.
 * @returns Parsed components, or `{}` when no `Reward:` section is present.
 */
function parseRewardComponents(feedback: string): Record<string, number | string> {
  // The closing brace may be followed by trailing spaces and/or CRLF before the line ends.
  const rewardMatch = feedback.match(/Reward:\s*(\{[\s\S]*?\})[ \t]*(?:\r?\n|$)/);
  if (rewardMatch?.[1]) {
    const components: Record<string, number> = {};
    // Accept exponent notation: Python's repr renders small floats as e.g.
    // `1e-05`, which a mantissa-only pattern would misread as 1.
    for (const match of rewardMatch[1].matchAll(
      /['"]?([A-Za-z0-9_]+)['"]?\s*:\s*(-?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)/g,
    )) {
      components[match[1]] = Number(match[2]);
    }
    if (Object.keys(components).length) return components;
  }
  const kvMatch = feedback.match(/Reward:\s*([^\n]+)/);
  if (!kvMatch?.[1]) return {};
  const components: Record<string, number | string> = {};
  for (const part of kvMatch[1].split(",")) {
    const eq = part.indexOf("=");
    if (eq < 0) continue;
    const key = part.slice(0, eq).trim();
    // Keep everything after the first "=" so values that contain "=" survive intact.
    const value = part.slice(eq + 1).trim();
    // Skip `=x` and `key=`: Number("") is 0, so an empty value would
    // otherwise be reported as a 0 reward.
    if (!key || !value) continue;
    const numeric = Number(value);
    components[key] = Number.isFinite(numeric) ? numeric : value;
  }
  return components;
}
/**
 * Returns the scalar reward for a phase.
 *
 * If `components` holds a numeric `<phase>_total` (`explore_total`,
 * `generate_total` or `repair_total`), that value is returned as-is. Otherwise
 * the result comes from `weightedFallbackTotal`, which may be `null`.
 */
export function totalForPhase(
  phase: string,
  components: Record<string, number | string>,
): number | null {
  let totalKey: string | null;
  switch (phase) {
    case "explore":
      totalKey = "explore_total";
      break;
    case "generate":
      totalKey = "generate_total";
      break;
    case "repair":
      totalKey = "repair_total";
      break;
    default:
      totalKey = null;
  }
  if (totalKey !== null) {
    const explicit = components[totalKey];
    if (typeof explicit === "number") return explicit;
  }
  return weightedFallbackTotal(phase, components);
}
/**
 * Weighted mean of a phase's components using `FALLBACK_WEIGHTS`.
 *
 * Only components whose value is a number contribute, and the mean is
 * renormalised by the sum of their weights. The result is clamped to [0, 1].
 * Returns `null` when the phase has no weight table or when none of its
 * weighted components is numeric.
 */
function weightedFallbackTotal(
  phase: string,
  components: Record<string, number | string>,
): number | null {
  const weights = FALLBACK_WEIGHTS[phase];
  if (!weights) return null;
  const { weighted, weightSum } = Object.entries(weights).reduce(
    (acc, [name, w]) => {
      const v = components[name];
      return typeof v === "number"
        ? { weighted: acc.weighted + v * w, weightSum: acc.weightSum + w }
        : acc;
    },
    { weighted: 0, weightSum: 0 },
  );
  if (weightSum === 0) return null;
  const mean = weighted / weightSum;
  return Math.max(0, Math.min(1, mean));
}
// Score budget consumed by `normalizedEpisodeScore`: up to MAX_EXPLORE_STEPS
// explore rewards (each at most MAX_EXPLORE_REWARD) plus one generate reward.
const MAX_EXPLORE_STEPS = 6;
const MAX_EXPLORE_REWARD = 1.0;
const MAX_GENERATE_REWARD = 1.0;
/**
 * Score threshold exported for callers. Presumably compared against
 * `normalizedEpisodeScore` to decide success — NOTE(review): confirm at call site.
 */
export const SUCCESS_SCORE_THRESHOLD = 0.3;
/**
 * Normalises an episode's rewards into [0, 1].
 *
 * Sums every numeric `total` (non-numeric totals are ignored) and divides by
 * the maximum attainable sum: MAX_EXPLORE_STEPS explore rewards of
 * MAX_EXPLORE_REWARD each, plus one MAX_GENERATE_REWARD. The ratio is then
 * clamped to [0, 1].
 */
export function normalizedEpisodeScore(rewards: RewardEntry[]): number {
  let sum = 0;
  for (const { total } of rewards) {
    if (typeof total === "number") sum += total;
  }
  const ceiling = MAX_EXPLORE_STEPS * MAX_EXPLORE_REWARD + MAX_GENERATE_REWARD;
  const ratio = sum / ceiling;
  return Math.min(1, Math.max(0, ratio));
}
|