codemaverick2 commited on
Commit
116a4b1
·
1 Parent(s): 78f3eb2

Add GRPO batch endpoint, task replay in curriculum, update to v2.1.0

Browse files
Files changed (2) hide show
  1. openenv.yaml +11 -3
  2. server/app.py +89 -2
openenv.yaml CHANGED
@@ -1,12 +1,13 @@
1
  spec_version: 1
2
  name: code_review_env
3
- version: "2.0.0"
4
  description: >
5
  A code review and security audit RL environment for training AI agents.
6
  The agent identifies bugs, security vulnerabilities, and performance issues
7
  across 7 tasks of increasing difficulty (easy → medium → medium-hard → hard).
8
  Features: PBRS reward shaping, graduated near-miss rewards, flood protection,
9
- CAMRL curriculum selector, VL return normalization, and cross-language tasks.
 
10
  type: space
11
  runtime: fastapi
12
  app: server.app:app
@@ -53,6 +54,13 @@ tasks:
53
  reward_design:
54
  terminal: "0.70 * F1 + 0.30 * severity_accuracy"
55
  shaping: "PBRS (Ng et al. 1999): phi(s) = (tp/total_gt) * 0.5"
56
- near_miss: "exponential decay: 0.10 * exp(-0.6 * (line_diff - 2))"
 
 
57
  flood_protection: "escalating FP penalty after 3rd false positive"
58
  normalization: "VL Norm (2025): normalized_return = cumulative / steps_used"
 
 
 
 
 
 
1
  spec_version: 1
2
  name: code_review_env
3
+ version: "2.1.0"
4
  description: >
5
  A code review and security audit RL environment for training AI agents.
6
  The agent identifies bugs, security vulnerabilities, and performance issues
7
  across 7 tasks of increasing difficulty (easy → medium → medium-hard → hard).
8
  Features: PBRS reward shaping, graduated near-miss rewards, flood protection,
9
+ CAMRL curriculum with task replay, VL return normalization, GRPO batch endpoint,
10
+ diversity/exploration bonuses, and cross-language tasks (Python + JavaScript).
11
  type: space
12
  runtime: fastapi
13
  app: server.app:app
 
54
  reward_design:
55
  terminal: "0.70 * F1 + 0.30 * severity_accuracy"
56
  shaping: "PBRS (Ng et al. 1999): phi(s) = (tp/total_gt) * 0.5"
57
+ near_miss: "exponential decay: 0.10 * exp(-0.6 * (line_diff - 2)), requires compatible type"
58
+ diversity_bonus: "+0.02 for first TP in a new issue category"
59
+ exploration_bonus: "+0.01 for first TP in a new file (multi-file tasks)"
60
  flood_protection: "escalating FP penalty after 3rd false positive"
61
  normalization: "VL Norm (2025): normalized_return = cumulative / steps_used"
62
+
63
+ training:
64
+ grpo_endpoint: "/grpo_batch — group-relative advantages A_i = (r_i - mean) / std"
65
+ curriculum: "CAMRL with 20% task replay to prevent forgetting"
66
+ rollout: "/trl_rollout — TRL GRPOTrainer compatible batch rollout"
server/app.py CHANGED
@@ -303,6 +303,7 @@ class CurriculumRequest(BaseModel):
303
  agent_performance: Optional[Dict[str, Any]] = None
304
  easy_threshold: float = 0.30
305
  hard_threshold: float = 0.70
 
306
 
307
 
308
  @app.post("/curriculum")
@@ -341,7 +342,14 @@ async def curriculum_task_selector(request: CurriculumRequest):
341
  else:
342
  avg_success = 0.0
343
 
344
- if avg_success < easy_thresh:
 
 
 
 
 
 
 
345
  phase = "easy"
346
  # Focus on task with lowest ground truth issue count (most approachable)
347
  recommended = min(ALL_TASKS.keys(), key=lambda t: len(ALL_TASKS[t]["ground_truth_issues"]))
@@ -352,7 +360,6 @@ async def curriculum_task_selector(request: CurriculumRequest):
352
  else:
353
  phase = "medium"
354
  # Mix: pick a task proportional to difficulty (harder = more likely)
355
- import random
356
  weights = list(task_difficulty.values())
357
  total_w = sum(weights) or 1.0
358
  probs = [w / total_w for w in weights]
@@ -473,6 +480,86 @@ async def trl_rollout(request: TRLRolloutRequest):
473
  }
474
 
475
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
476
  def main():
477
  import uvicorn
478
  port = int(os.environ.get("PORT", 7860))
 
303
  agent_performance: Optional[Dict[str, Any]] = None
304
  easy_threshold: float = 0.30
305
  hard_threshold: float = 0.70
306
+ replay_fraction: float = 0.20 # fraction of time to replay earlier tasks (prevents forgetting)
307
 
308
 
309
  @app.post("/curriculum")
 
342
  else:
343
  avg_success = 0.0
344
 
345
+ # Task replay (prevents catastrophic forgetting, arxiv 2506.06632):
346
+ # With replay_fraction probability, pick an easy/mastered task instead
347
+ replay_frac = request.replay_fraction
348
+ if perf and random.random() < replay_frac:
349
+ # Replay: pick easiest task (lowest GT count) to maintain baseline skills
350
+ phase = "replay"
351
+ recommended = min(ALL_TASKS.keys(), key=lambda t: len(ALL_TASKS[t]["ground_truth_issues"]))
352
+ elif avg_success < easy_thresh:
353
  phase = "easy"
354
  # Focus on task with lowest ground truth issue count (most approachable)
355
  recommended = min(ALL_TASKS.keys(), key=lambda t: len(ALL_TASKS[t]["ground_truth_issues"]))
 
360
  else:
361
  phase = "medium"
362
  # Mix: pick a task proportional to difficulty (harder = more likely)
 
363
  weights = list(task_difficulty.values())
364
  total_w = sum(weights) or 1.0
365
  probs = [w / total_w for w in weights]
 
480
  }
481
 
482
 
483
class GRPOBatchRequest(BaseModel):
    """Request body for POST /grpo_batch.

    Carries G action sequences that are all rolled out on the same task
    (same ``task_id`` and ``seed``) so their final returns can be compared
    group-relatively, GRPO style.
    """

    # Forwarded to the environment reset; None means no explicit task pinned here.
    task_id: Optional[str] = None
    # Forwarded to the environment reset — presumably fixes the task state so
    # all G rollouts see identical conditions (verify against env.reset).
    seed: Optional[int] = None
    # G action sequences for group-relative comparison.
    group: List[List[Dict[str, Any]]]
487
+
488
+
489
+ @app.post("/grpo_batch")
490
+ async def grpo_batch(request: GRPOBatchRequest):
491
+ """
492
+ GRPO group-relative rollout batch (DeepSeek-R1 / DeepSeekMath style).
493
+
494
+ Runs G action sequences on the SAME task, computes group-relative advantages:
495
+ A_i = (r_i - mean(r_1..r_G)) / std(r_1..r_G)
496
+
497
+ This replaces the PPO critic entirely β€” no value network needed.
498
+ Recommended group size G=64 (DeepSeekMath), G=8-16 for faster iteration.
499
+
500
+ Body:
501
+ task_id: str (optional)
502
+ seed: int (optional, ensures same task state for all rollouts)
503
+ group: [[actions_1], [actions_2], ..., [actions_G]]
504
+
505
+ Returns:
506
+ rollouts: [{episode_return, final_score, advantage, ...}]
507
+ group_stats: {mean, std, G}
508
+ """
509
+ G = len(request.group)
510
+ if G < 2:
511
+ raise HTTPException(400, "GRPO requires at least 2 rollouts in the group")
512
+
513
+ returns = []
514
+ rollout_results = []
515
+
516
+ for action_seq in request.group:
517
+ rollout_env = CodeReviewEnvironment()
518
+ rollout_env.reset(task_id=request.task_id, seed=request.seed)
519
+
520
+ episode_return = 0.0
521
+ final_score = 0.0
522
+ n_steps = 0
523
+
524
+ for action_dict in action_seq:
525
+ action = ReviewAction.from_dict(action_dict)
526
+ obs_step = rollout_env.step(action)
527
+ step_data = _serialize(obs_step)
528
+ reward = step_data.get("reward") or 0.0
529
+ episode_return += reward
530
+ n_steps += 1
531
+
532
+ if step_data.get("done", False):
533
+ final_score = step_data.get("reward", step_data.get("current_score", 0.0)) or 0.0
534
+ break
535
+
536
+ returns.append(final_score)
537
+ rollout_results.append({
538
+ "episode_return": round(episode_return, 4),
539
+ "final_score": round(final_score, 4),
540
+ "num_steps": n_steps,
541
+ })
542
+
543
+ # Compute group-relative advantages: A_i = (r_i - mean) / std
544
+ mean_r = sum(returns) / G
545
+ variance = sum((r - mean_r) ** 2 for r in returns) / G
546
+ std_r = max(variance ** 0.5, 1e-8)
547
+
548
+ for i, result in enumerate(rollout_results):
549
+ result["advantage"] = round((returns[i] - mean_r) / std_r, 4)
550
+
551
+ return {
552
+ "task_id": request.task_id,
553
+ "rollouts": rollout_results,
554
+ "group_stats": {
555
+ "mean": round(mean_r, 4),
556
+ "std": round(std_r, 4),
557
+ "G": G,
558
+ },
559
+ "method": "GRPO",
560
+ }
561
+
562
+
563
  def main():
564
  import uvicorn
565
  port = int(os.environ.get("PORT", 7860))