File size: 4,132 Bytes
1b83e76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import type { ExplainerObservation, RewardEntry } from "./types";

/**
 * Metadata keys that are never treated as reward components.
 *
 * `rewardComponents` skips any metadata entry whose key is in this set, even
 * when the value is numeric (e.g. `step` or `source_count`).
 */
const NON_REWARD_KEYS = new Set<string>([
  "step", "phase", "tool", "source_count", "error",
  "explore_steps_used", "repair_steps_used",
  "sandbox_message", "error_codes",
]);

/**
 * Per-phase allowlist of reward component keys, in the order they are copied
 * into the result of `rewardComponents`. Phases without an entry here have no
 * allowlist and fall through to the unfiltered component set.
 */
const VISIBLE: Record<string, string[]> = {
  explore: [
    "query_quality",
    "evidence_quality",
    "information_gain",
    "efficiency",
    "explore_total",
  ],
  generate: [
    "validity",
    "task_alignment",
    "structure",
    "research_usage",
    "generate_total",
  ],
  repair: [
    "repair_success",
    "fixed_prior_errors",
    "changed_code",
    "repair_total",
  ],
};

/**
 * Per-phase component weights used by `weightedFallbackTotal` when no explicit
 * `<phase>_total` value is available. The weights listed for each phase add up
 * to 1; the consumer re-normalises by the weights actually used, so a missing
 * component does not drag the total down.
 */
const FALLBACK_WEIGHTS: Record<string, Record<string, number>> = {
  explore: { query_quality: 0.2, evidence_quality: 0.25, information_gain: 0.4, efficiency: 0.15 },
  generate: { validity: 0.15, task_alignment: 0.3, structure: 0.3, research_usage: 0.25 },
  repair: { repair_success: 0.6, fixed_prior_errors: 0.2, changed_code: 0.2 },
};

/**
 * Collects the reward components to report for a single observation.
 *
 * Sources are tried in this order:
 *  1. Numeric, non-NaN metadata values whose keys are not in `NON_REWARD_KEYS`,
 *     narrowed to the phase's `VISIBLE` allowlist when at least one allowlisted
 *     key is present.
 *  2. The same numeric metadata values, unfiltered, when the phase has no
 *     allowlist or none of its keys are present.
 *  3. Components parsed from the observation's feedback text when the metadata
 *     holds no usable numeric values at all.
 *
 * @param obs - Observation whose `metadata`, `phase` and `feedback` are read.
 * @returns A map from component name to value; may be empty.
 */
export function rewardComponents(obs: ExplainerObservation): Record<string, number | string> {
  const metadata = obs.metadata || {};

  // Keep only real numbers that are not step bookkeeping.
  const numeric: Record<string, number> = {};
  Object.entries(metadata).forEach(([key, value]) => {
    if (NON_REWARD_KEYS.has(key)) return;
    if (typeof value !== "number" || Number.isNaN(value)) return;
    numeric[key] = value;
  });

  // The phase recorded in metadata wins over the observation's own phase.
  const phase = (metadata.phase as string) || obs.phase;
  const allowlist = VISIBLE[phase];
  if (allowlist) {
    const shown = allowlist.reduce<Record<string, number>>((acc, key) => {
      if (key in numeric) acc[key] = numeric[key];
      return acc;
    }, {});
    if (Object.keys(shown).length > 0) return shown;
  }

  if (Object.keys(numeric).length > 0) return numeric;
  return parseRewardComponents(obs.feedback);
}

/**
 * Parses reward components out of free-form feedback text.
 *
 * Two layouts following a `Reward:` label are recognised, tried in order:
 *  1. A brace-delimited literal such as `Reward: {'validity': 0.5, 'gain': 1e-05}`.
 *     Only `key: number` pairs are extracted; keys may be bare or quoted, and
 *     numbers may be negative, fractional, or in scientific notation.
 *  2. A single-line comma-separated list such as `Reward: validity=0.5, note=ok`.
 *     Values that are finite numbers are returned as numbers, everything else
 *     (including an empty value) is returned as the trimmed string.
 *
 * @param feedback - Feedback text to scan.
 * @returns A map from component name to value, or `{}` when nothing is found.
 */
function parseRewardComponents(feedback: string): Record<string, number | string> {
  // Defensive: callers may hand over a missing or empty feedback at runtime.
  if (typeof feedback !== "string" || feedback === "") return {};

  const dictMatch = feedback.match(/Reward:\s*(\{[\s\S]*?\})(?:\n|$)/);
  if (dictMatch?.[1]) {
    const parsed: Record<string, number> = {};
    // The optional exponent keeps values like `1e-05` from being read as `1`.
    const pairPattern = /['"]?([A-Za-z0-9_]+)['"]?\s*:\s*(-?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?)/g;
    for (const pair of dictMatch[1].matchAll(pairPattern)) {
      parsed[pair[1]] = Number(pair[2]);
    }
    if (Object.keys(parsed).length) return parsed;
  }

  const lineMatch = feedback.match(/Reward:\s*([^\n]+)/);
  if (!lineMatch?.[1]) return {};

  const components: Record<string, number | string> = {};
  for (const part of lineMatch[1].split(",")) {
    // Split on the first "=" only so a value containing "=" is kept whole.
    const eq = part.indexOf("=");
    if (eq < 0) continue;
    const key = part.slice(0, eq).trim();
    if (key === "") continue;
    const value = part.slice(eq + 1).trim();
    // Number("") is 0, so an empty value must be excluded from numeric coercion.
    const numeric = value === "" ? Number.NaN : Number(value);
    components[key] = Number.isFinite(numeric) ? numeric : value;
  }
  return components;
}

/**
 * Returns the overall reward for a phase.
 *
 * For the `explore`, `generate` and `repair` phases the explicit
 * `<phase>_total` component is returned as-is when it is a number. In every
 * other case the result comes from `weightedFallbackTotal`.
 *
 * @param phase - Phase name.
 * @param components - Reward components for the step.
 * @returns The phase total, or `null` when the fallback cannot compute one.
 */
export function totalForPhase(
  phase: string,
  components: Record<string, number | string>,
): number | null {
  let totalKey: string | null;
  switch (phase) {
    case "explore":
      totalKey = "explore_total";
      break;
    case "generate":
      totalKey = "generate_total";
      break;
    case "repair":
      totalKey = "repair_total";
      break;
    default:
      totalKey = null;
  }

  if (totalKey !== null) {
    const reported = components[totalKey];
    if (typeof reported === "number") return reported;
  }

  return weightedFallbackTotal(phase, components);
}

/**
 * Computes a phase total as the weighted average of the components present.
 *
 * Only components that have a weight in `FALLBACK_WEIGHTS[phase]` and a numeric
 * value contribute. The sum is divided by the weights actually used, then
 * clamped to the range [0, 1].
 *
 * @param phase - Phase name used to look up the weight table.
 * @param components - Reward components for the step.
 * @returns The clamped weighted average, or `null` when the phase has no
 *   weight table or none of its weighted components is numeric.
 */
function weightedFallbackTotal(
  phase: string,
  components: Record<string, number | string>,
): number | null {
  const phaseWeights = FALLBACK_WEIGHTS[phase];
  if (!phaseWeights) return null;

  let weightedSum = 0;
  let weightSeen = 0;
  Object.entries(phaseWeights).forEach(([name, weight]) => {
    const score = components[name];
    if (typeof score === "number") {
      weightedSum += score * weight;
      weightSeen += weight;
    }
  });

  if (weightSeen === 0) return null;

  const average = weightedSum / weightSeen;
  return Math.min(1, Math.max(0, average));
}

/** Number of explore steps budgeted in the `normalizedEpisodeScore` denominator. */
const MAX_EXPLORE_STEPS = 6;
/** Largest reward assumed for a single explore step. */
const MAX_EXPLORE_REWARD = 1.0;
/** Largest reward assumed for the generate step. */
const MAX_GENERATE_REWARD = 1.0;

/**
 * Score cutoff exported for callers.
 * NOTE(review): presumably compared against `normalizedEpisodeScore` to decide
 * whether an episode succeeded — confirm at the call sites.
 */
export const SUCCESS_SCORE_THRESHOLD = 0.3;

/**
 * Normalises an episode's accumulated reward into the range [0, 1].
 *
 * Every entry whose `total` is a number is summed, regardless of phase. The
 * sum is divided by `MAX_EXPLORE_STEPS * MAX_EXPLORE_REWARD +
 * MAX_GENERATE_REWARD` and the quotient is clamped to [0, 1].
 *
 * @param rewards - Reward entries recorded for the episode.
 * @returns The clamped, normalised score; `0` for an empty list.
 */
export function normalizedEpisodeScore(rewards: RewardEntry[]): number {
  let earned = 0;
  for (const entry of rewards) {
    const stepTotal = entry.total;
    // Entries without a numeric total contribute nothing.
    if (typeof stepTotal === "number") earned += stepTotal;
  }

  const ceiling = MAX_EXPLORE_STEPS * MAX_EXPLORE_REWARD + MAX_GENERATE_REWARD;
  return Math.min(1, Math.max(0, earned / ceiling));
}