Commit 9198e06 by Eric Xu · unverified · 1 parent: 2934f67

Add goal-weighted gradient (VJP) — optimize toward objectives, not universal appeal

The semantic gradient is now a vector-Jacobian product: each evaluator
is weighted by their relevance to the user's stated goal.

Uniform: nabla_j = (1/n) sum_i J_ij (what pleases everyone)
VJP: nabla_j = sum_i v_i * J_ij (what moves toward goal)

Without a goal, behavior is unchanged (uniform weights). With a goal,
the LLM scores each evaluator's relevance (0-1) and the gradient
prioritizes changes that matter to the right audience.

- counterfactual.py: add --goal flag, compute_goal_weights(), VJP in analyze_gradient()
- web: add goal input field, pass through to counterfactual endpoint
- README: update math section with VJP formulation
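
The two weighting schemes above can be sketched in a few lines of plain Python; the Jacobian entries and relevance weights below are illustrative numbers, not values from the repo:

```python
# A 3-evaluator x 2-change Jacobian of score deltas J[i][j] (made-up numbers).
J = [
    [+2.0, -1.0],   # evaluator 0, e.g. an enterprise CTO
    [+1.0, +0.5],   # evaluator 1
    [-2.0, +2.0],   # evaluator 2, e.g. a solo hobbyist
]
n = len(J)

# Uniform: nabla_j = (1/n) sum_i J_ij  (what pleases everyone)
uniform = [sum(row[j] for row in J) / n for j in range(2)]

# VJP: nabla_j = sum_i v_i * J_ij  (hypothetical LLM-assigned relevance weights)
v = [1.0, 0.8, 0.1]
vjp = [sum(v[i] * J[i][j] for i in range(n)) for j in range(2)]

print("uniform:", uniform)  # uniform weighting ranks change 1 first
print("vjp:", vjp)          # goal weighting flips the ranking to change 0
```

Down-weighting evaluators who are irrelevant to the goal is what flips the ranking: the change the target audience likes wins even though it averages worse across the whole panel.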

Files changed (4):

1. README.md (+15 −7)
2. scripts/counterfactual.py (+115 −16)
3. web/app.py (+26 −5)
4. web/static/index.html (+14 −1)
README.md CHANGED

```diff
@@ -261,17 +261,23 @@ The gap between SGO and real expert panels has three components:
 
 ## The Semantic Gradient
 
-For evaluators in the "movable middle" (scores 4–7), SGO asks: *"if this changed, what's your new score?"*
-
-This produces a Jacobian matrix where each cell is a score delta:
+SGO computes a Jacobian matrix of score deltas, showing how each evaluator's score would shift for each hypothetical change:
 
 $$J_{ij} = f(\theta + \Delta_j, \; x_i) - f(\theta, \; x_i)$$
 
-The semantic gradient is the column mean — the average impact of each change across the panel:
-
-$$\nabla_j = \frac{1}{n}\sum_{i} J_{ij}$$
-
-Rank by this value descending: that's your priority list.
+### Goal-weighted gradient (VJP)
+
+The key insight: not all evaluators matter equally. A luxury brand shouldn't optimize for budget shoppers. A dating profile shouldn't optimize for incompatible matches.
+
+SGO uses a **goal vector** `v` that weights each evaluator by their relevance to your objective. The gradient is a vector-Jacobian product:
+
+$$\nabla_j = \sum_{i} v_i \cdot J_{ij}$$
+
+Where `v_i` is the goal-relevance weight for evaluator `i` (0 = irrelevant, 1 = ideal target).
+
+Without a goal, `v = [1/n, ...]` — uniform weights, optimizing for universal appeal. With a goal like *"close enterprise deals"*, enterprise CTOs get `v ≈ 1` and solo hobbyists get `v ≈ 0`.
+
+The LLM assigns goal-relevance weights automatically by evaluating each persona against your stated objective. This means the gradient tells you *"what changes move you toward your goal"*, not *"what changes make everyone like you more"*.
 
 ### What to probe
 
@@ -290,10 +296,12 @@ Only probe changes you'd actually make:
 |--------|---------|
 | θ | Entity you control |
 | x | Evaluator persona |
+| g | Goal — what you're optimizing for |
 | f(θ, x) | LLM evaluation → score + reasoning |
+| v_i | Goal-relevance weight for evaluator *i* |
 | Δⱼ | Hypothetical change |
 | Jᵢⱼ | Score delta: evaluator *i*, change *j* |
-| ∇ⱼ | Semantic gradient: mean impact of change *j* |
+| ∇ⱼ | Goal-weighted gradient (VJP): impact of change *j* toward goal *g* |
 
 ## Project Structure
 
```
scripts/counterfactual.py CHANGED

```diff
@@ -128,29 +128,94 @@ def probe_one(client, model, eval_result, cohort_map, all_changes):
         return {"error": str(e), "_evaluator": ev}
 
 
-def analyze_gradient(results, all_changes):
+GOAL_RELEVANCE_PROMPT = """You are scoring how relevant an evaluator is to a specific goal.
+
+## Goal
+{goal}
+
+## Evaluator
+Name: {name}, Age: {age}, Occupation: {occupation}
+Their evaluation: {score}/10 — "{summary}"
+
+## Task
+On a scale of 0.0 to 1.0, how relevant is this evaluator's opinion to the stated goal?
+- 1.0 = this is exactly the kind of person whose opinion matters for this goal
+- 0.5 = somewhat relevant
+- 0.0 = completely irrelevant to this goal
+
+Return JSON only: {{"relevance": <0.0-1.0>, "reasoning": "<1 sentence>"}}"""
+
+
+def compute_goal_weights(client, model, eval_results, cohort_map, goal, parallel=5):
+    """Score each evaluator's relevance to the goal. Returns {name: weight}."""
+    weights = {}
+
+    def score_one(r):
+        ev = r.get("_evaluator", {})
+        name = ev.get("name", "")
+        persona = cohort_map.get(name, {})
+        prompt = GOAL_RELEVANCE_PROMPT.format(
+            goal=goal, name=name, age=ev.get("age", ""),
+            occupation=ev.get("occupation", ""),
+            score=r.get("score", "?"),
+            summary=r.get("summary", r.get("reasoning", "")),
+        )
+        try:
+            resp = client.chat.completions.create(
+                model=model,
+                messages=[{"role": "user", "content": prompt}],
+                response_format={"type": "json_object"},
+                max_tokens=256, temperature=0.3,
+            )
+            content = resp.choices[0].message.content
+            content = re.sub(r'<think>[\s\S]*?</think>', '', content).strip()
+            data = json.loads(content)
+            return name, float(data.get("relevance", 0.5)), data.get("reasoning", "")
+        except Exception:
+            return name, 0.5, "default"
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=parallel) as pool:
+        futs = [pool.submit(score_one, r) for r in eval_results]
+        for fut in concurrent.futures.as_completed(futs):
+            name, weight, reasoning = fut.result()
+            weights[name] = {"weight": weight, "reasoning": reasoning}
+
+    return weights
+
+
+def analyze_gradient(results, all_changes, goal_weights=None):
     valid = [r for r in results if "counterfactuals" in r]
     if not valid:
         return "No valid results."
 
+    has_goal = goal_weights is not None
     labels = {c["id"]: c["label"] for c in all_changes}
     jacobian = defaultdict(list)
 
     for r in valid:
+        name = r["_evaluator"].get("name", "")
+        w = goal_weights.get(name, {}).get("weight", 1.0) if has_goal else 1.0
         for cf in r.get("counterfactuals", []):
             jacobian[cf.get("change_id", "")].append({
                 "delta": cf.get("delta", 0),
-                "name": r["_evaluator"].get("name", ""),
+                "weighted_delta": cf.get("delta", 0) * w,
+                "weight": w,
+                "name": name,
                 "age": r["_evaluator"].get("age", ""),
                 "reasoning": cf.get("reasoning", ""),
             })
 
     ranked = []
     for cid, deltas in jacobian.items():
-        avg = sum(d["delta"] for d in deltas) / len(deltas)
+        total_weight = sum(d["weight"] for d in deltas)
+        if total_weight == 0:
+            total_weight = 1
+        weighted_avg = sum(d["weighted_delta"] for d in deltas) / total_weight
+        raw_avg = sum(d["delta"] for d in deltas) / len(deltas)
         ranked.append({
             "id": cid, "label": labels.get(cid, cid),
-            "avg_delta": avg,
+            "avg_delta": weighted_avg,
+            "raw_avg_delta": raw_avg,
             "max_delta": max(d["delta"] for d in deltas),
             "min_delta": min(d["delta"] for d in deltas),
             "positive": sum(1 for d in deltas if d["delta"] > 0),
@@ -159,29 +224,46 @@ def analyze_gradient(results, all_changes):
         })
     ranked.sort(key=lambda x: x["avg_delta"], reverse=True)
 
-    lines = [f"# Semantic Gradient\n\nProbed {len(valid)} evaluators across {len(all_changes)} changes.\n"]
-    lines.append(f"{'Rank':<5} {'Avg Δ':>6} {'Max':>5} {'Min':>5} {'👍':>4} {'👎':>4} Change")
+    mode = "Goal-Weighted (VJP)" if has_goal else "Uniform"
+    lines = [f"# Semantic Gradient ({mode})\n\nProbed {len(valid)} evaluators across {len(all_changes)} changes.\n"]
+    if has_goal:
+        header = f"{'Rank':<5} {'VJP Δ':>6} {'Raw Δ':>6} {'Max':>5} {'Min':>5} Change"
+    else:
+        header = f"{'Rank':<5} {'Avg Δ':>6} {'Max':>5} {'Min':>5} {'👍':>4} {'👎':>4} Change"
+    lines.append(header)
     lines.append("-" * 75)
     for i, r in enumerate(ranked, 1):
-        lines.append(
-            f"{i:<5} {r['avg_delta']:>+5.1f} {r['max_delta']:>+4} {r['min_delta']:>+4} "
-            f"{r['positive']:>3} {r['negative']:>3} {r['label']}"
-        )
+        if has_goal:
+            lines.append(
+                f"{i:<5} {r['avg_delta']:>+5.1f} {r['raw_avg_delta']:>+5.1f} "
+                f"{r['max_delta']:>+4} {r['min_delta']:>+4} {r['label']}"
+            )
+        else:
+            lines.append(
+                f"{i:<5} {r['avg_delta']:>+5.1f} {r['max_delta']:>+4} {r['min_delta']:>+4} "
+                f"{r['positive']:>3} {r['negative']:>3} {r['label']}"
+            )
 
     lines.append(f"\n## Top 3 — Detail\n")
     for r in ranked[:3]:
-        lines.append(f"### {r['label']} (avg Δ {r['avg_delta']:+.1f})\n")
+        label = f"### {r['label']} (Δ {r['avg_delta']:+.1f})"
+        if has_goal and abs(r['avg_delta'] - r['raw_avg_delta']) > 0.2:
+            label += f" ← was {r['raw_avg_delta']:+.1f} without goal weighting"
+        lines.append(label + "\n")
         positive = sorted([d for d in r["details"] if d["delta"] > 0],
-                          key=lambda x: x["delta"], reverse=True)
+                          key=lambda x: x["weighted_delta"] if has_goal else x["delta"],
+                          reverse=True)
         if positive:
             lines.append("**Helps:**")
             for d in positive[:5]:
-                lines.append(f"  +{d['delta']} {d['name']} ({d['age']}): {d['reasoning']}")
+                w_label = f" [w={d['weight']:.1f}]" if has_goal else ""
+                lines.append(f"  +{d['delta']} {d['name']} ({d['age']}){w_label}: {d['reasoning']}")
         negative = [d for d in r["details"] if d["delta"] < 0]
         if negative:
             lines.append("**Hurts:**")
             for d in sorted(negative, key=lambda x: x["delta"])[:3]:
-                lines.append(f"  {d['delta']} {d['name']} ({d['age']}): {d['reasoning']}")
+                w_label = f" [w={d['weight']:.1f}]" if has_goal else ""
+                lines.append(f"  {d['delta']} {d['name']} ({d['age']}){w_label}: {d['reasoning']}")
         lines.append("")
 
     return "\n".join(lines)
@@ -191,6 +273,8 @@ def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--tag", required=True)
    parser.add_argument("--changes", required=True, help="JSON file with changes to probe")
+    parser.add_argument("--goal", default=None,
+                        help="Goal to optimize toward (enables VJP weighting)")
    parser.add_argument("--min-score", type=int, default=4)
    parser.add_argument("--max-score", type=int, default=7)
    parser.add_argument("--parallel", type=int, default=5)
@@ -223,7 +307,11 @@ def main():
     model = os.getenv("LLM_MODEL_NAME")
 
     print(f"Movable middle (score {args.min_score}-{args.max_score}): {len(movable)}")
-    print(f"Changes: {len(all_changes)} | Model: {model}\n")
+    print(f"Changes: {len(all_changes)} | Model: {model}")
+    if args.goal:
+        print(f"Goal: {args.goal} (VJP mode)\n")
+    else:
+        print("No goal — uniform weighting\n")
 
     results = [None] * len(movable)
     done = [0]
@@ -255,7 +343,18 @@ def main():
     with open(out_dir / "raw_probes.json", "w") as f:
         json.dump(results, f, ensure_ascii=False, indent=2)
 
-    gradient = analyze_gradient(results, all_changes)
+    # Compute goal weights if goal is specified (VJP)
+    goal_weights = None
+    if args.goal:
+        print("Computing goal-relevance weights...")
+        goal_weights = compute_goal_weights(
+            client, model, eval_results, cohort_map, args.goal,
+            parallel=args.parallel,
+        )
+        relevant = sum(1 for v in goal_weights.values() if v["weight"] >= 0.5)
+        print(f"  {relevant}/{len(goal_weights)} evaluators relevant to goal\n")
+
+    gradient = analyze_gradient(results, all_changes, goal_weights=goal_weights)
     with open(out_dir / "gradient.md", "w") as f:
         f.write(gradient)
 
```
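A detail worth noting from the counterfactual.py hunk above: `analyze_gradient` divides the weighted sum by the total weight, so the reported ∇ⱼ is in practice a weighted mean (with a guard against an all-zero weight vector) rather than the raw sum in the README formula. A standalone sketch of just that computation, with a hypothetical helper name and made-up numbers:

```python
def weighted_avg_delta(deltas):
    """deltas: list of {"delta": float, "weight": float} for one change.

    Mirrors the normalization in analyze_gradient: divide the
    vector-Jacobian product by the total weight, falling back to 1
    when every weight is zero (same guard, written with `or`).
    """
    total_weight = sum(d["weight"] for d in deltas) or 1
    return sum(d["delta"] * d["weight"] for d in deltas) / total_weight

# A high-relevance +2 and a low-relevance -1:
deltas = [
    {"delta": 2.0, "weight": 1.0},
    {"delta": -1.0, "weight": 0.1},
]
print(weighted_avg_delta(deltas))  # (2.0 - 0.1) / 1.1 ≈ 1.727
```

The normalization keeps goal-weighted and uniform runs on the same scale, which is what makes the "VJP Δ" and "Raw Δ" columns in the gradient report directly comparable.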
web/app.py CHANGED

```diff
@@ -419,10 +419,10 @@ async def evaluate_stream(sid: str, parallel: int = 5, bias_calibration: bool =
 
 @app.get("/api/counterfactual/stream/{sid}")
 async def counterfactual_stream(
-    sid: str, changes_json: str, min_score: int = 4,
-    max_score: int = 7, parallel: int = 5
+    sid: str, changes_json: str, goal: str = "",
+    min_score: int = 4, max_score: int = 7, parallel: int = 5
 ):
-    """Run counterfactual probes with SSE progress."""
+    """Run counterfactual probes with SSE progress. Goal enables VJP weighting."""
     if sid not in sessions:
         raise HTTPException(404, "Session not found")
     session = sessions[sid]
@@ -442,8 +442,10 @@ async def counterfactual_stream(
                    if "score" in r and min_score <= r["score"] <= max_score]
 
         total = len(movable)
+        has_goal = bool(goal.strip())
         yield {"event": "start", "data": json.dumps({
-            "total": total, "changes": len(all_changes), "model": model
+            "total": total, "changes": len(all_changes), "model": model,
+            "goal": goal if has_goal else None,
         })}
 
         if total == 0:
@@ -454,6 +456,23 @@ async def counterfactual_stream(
             })}
             return
 
+        # Compute goal-relevance weights (VJP) if goal is set
+        goal_weights = None
+        if has_goal:
+            yield {"event": "goal_weights", "data": json.dumps({
+                "status": "computing", "message": "Scoring evaluator relevance to goal..."
+            })}
+            goal_weights = compute_goal_weights(
+                client, model, eval_results, cohort_map, goal, parallel=parallel,
+            )
+            relevant = sum(1 for v in goal_weights.values() if v["weight"] >= 0.5)
+            yield {"event": "goal_weights", "data": json.dumps({
+                "status": "done",
+                "relevant": relevant,
+                "total": len(goal_weights),
+                "message": f"{relevant}/{len(goal_weights)} evaluators relevant to goal",
+            })}
+
         results = [None] * total
         done = 0
         t0 = time.time()
@@ -484,13 +503,15 @@ async def counterfactual_stream(
             yield {"event": "progress", "data": json.dumps(progress)}
 
         elapsed = time.time() - t0
-        gradient_text = analyze_gradient(results, all_changes)
+        gradient_text = analyze_gradient(results, all_changes,
+                                         goal_weights=goal_weights)
         session["gradient"] = gradient_text
 
         yield {"event": "complete", "data": json.dumps({
             "elapsed": round(elapsed, 1),
             "gradient": gradient_text,
             "results": results,
+            "goal": goal if has_goal else None,
         })}
 
     return EventSourceResponse(event_generator())
```
web/static/index.html CHANGED

```diff
@@ -348,6 +348,11 @@
       <textarea id="entityText" placeholder="Paste your entity here..."></textarea>
     </div>
 
+    <div class="field">
+      <label>What's your goal?</label>
+      <input type="text" id="goalText" placeholder="e.g. 'Get hired at a Series B startup' or 'Close enterprise deals'">
+    </div>
+
     <details class="mb-8">
       <summary style="cursor:pointer;color:var(--text2);font-size:0.85rem">Advanced options</summary>
       <div style="padding:12px 0">
@@ -861,8 +866,10 @@ function runCounterfactual() {
   document.getElementById('cfResults').classList.add('hidden');
   document.getElementById('cfLog').innerHTML = '';
 
+  const goal = document.getElementById('goalText').value.trim();
   const params = new URLSearchParams({
     changes_json: JSON.stringify(changes),
+    goal: goal,
     min_score: minScore,
     max_score: maxScore,
     parallel: 5,
@@ -872,8 +879,14 @@ function runCounterfactual() {
 
   es.addEventListener('start', (e) => {
     const d = JSON.parse(e.data);
+    const goalLabel = d.goal ? ` toward "${d.goal}"` : '';
     document.getElementById('cfProgressText').textContent =
-      `Probing ${d.total} evaluators across ${d.changes} changes...`;
+      `Probing ${d.total} evaluators across ${d.changes} changes${goalLabel}...`;
+  });
+
+  es.addEventListener('goal_weights', (e) => {
+    const d = JSON.parse(e.data);
+    document.getElementById('cfProgressText').textContent = d.message;
   });
 
   es.addEventListener('progress', (e) => {
```