Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| scripts/run_transfer_inference.py | |
| ================================== | |
| Run the TRAINED adapter on the 5 held-out transfer episodes. | |
| Unlike transfer_eval.py (which uses a heuristic oracle), this script | |
| loads the actual LoRA adapter, generates real model completions, and | |
| parses the budget_predictions + tool calls from the output. | |
| This produces REAL model inference results, not proxy heuristics. | |
| Usage (after training completes): | |
| python scripts/run_transfer_inference.py | |
| python scripts/run_transfer_inference.py --adapter grpo_output/sft_adapter | |
| python scripts/run_transfer_inference.py --adapter grpo_output/checkpoints/checkpoint-150 | |
| Outputs: | |
| grpo_output/transfer_inference_results.json โ per-episode model outputs | |
| grpo_output/transfer_inference_plot.png โ real model inference plot | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import os | |
| import re | |
| import sys | |
| from pathlib import Path | |
| ROOT = Path(__file__).resolve().parent.parent | |
| sys.path.insert(0, str(ROOT)) | |
| DEFAULT_EPISODES = ROOT / "data" / "transfer_episodes.json" | |
| DEFAULT_ADAPTER = ROOT / "grpo_output" / "sft_adapter" | |
| OUT_DIR = ROOT / "grpo_output" | |
| def build_transfer_prompt(episode: dict) -> str: | |
| """Build a prompt from a transfer episode in the same format as train_grpo.py.""" | |
| file_list = "\n".join( | |
| f" โข {f['file']} [{f.get('language', '?')}] " | |
| f"complexity={f.get('features', [0,0,0,0])[1]} " | |
| f"churn={f.get('features', [0,0,0,0])[0]}" | |
| for f in episode["files"] | |
| ) | |
| return ( | |
| f"{'='*60}\n" | |
| f" CODE REVIEW INVESTIGATION\n" | |
| f"{'='*60}\n" | |
| f" Task: {episode['task_id']}\n" | |
| f" Description: {episode['description']}\n\n" | |
| f" Files to investigate ({len(episode['files'])}):\n" | |
| f"{file_list}\n\n" | |
| f" MISSION: Investigate which files contain bugs or regressions.\n" | |
| f" Use your thinking budget wisely โ think deeply on suspicious\n" | |
| f" files and briefly on safe ones.\n" | |
| f"{'='*60}" | |
| ) | |
| def parse_predictions_from_output(text: str) -> list: | |
| """Extract (budget_prediction, think_length, tool_call_file) triples.""" | |
| preds = [] | |
| bp_pattern = re.compile( | |
| r'<budget_prediction>\s*(short|medium|long)\s*</budget_prediction>' | |
| r'.*?<think>(.*?)</think>' | |
| r'(?:.*?<tool_call>\s*(\{.*?\})\s*</tool_call>)?', | |
| re.DOTALL | |
| ) | |
| for m in bp_pattern.finditer(text): | |
| pred = m.group(1) | |
| think_text = m.group(2) | |
| tool_json = m.group(3) | |
| file_path = None | |
| tool_name = None | |
| if tool_json: | |
| try: | |
| data = json.loads(tool_json) | |
| tool_name = data.get("name", "") | |
| args = data.get("arguments", {}) | |
| file_path = args.get("file_path", "") | |
| except json.JSONDecodeError: | |
| pass | |
| preds.append({ | |
| "prediction": pred, | |
| "think_length": len(think_text.strip()), | |
| "tool_name": tool_name, | |
| "file_path": file_path, | |
| }) | |
| return preds | |
| def run_with_model(episodes, adapter_path, model_name): | |
| """Load model + adapter and run inference on transfer episodes.""" | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig | |
| from peft import PeftModel | |
| print(f"๐ Loading {model_name}...") | |
| bnb_config = BitsAndBytesConfig( | |
| load_in_4bit=True, | |
| bnb_4bit_quant_type="nf4", | |
| bnb_4bit_compute_dtype=torch.bfloat16, | |
| bnb_4bit_use_double_quant=True, | |
| ) | |
| tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) | |
| if tokenizer.pad_token is None: | |
| tokenizer.pad_token = tokenizer.eos_token | |
| model = AutoModelForCausalLM.from_pretrained( | |
| model_name, | |
| quantization_config=bnb_config, | |
| device_map="auto", | |
| trust_remote_code=True, | |
| torch_dtype=torch.bfloat16, | |
| ) | |
| if adapter_path and os.path.exists(adapter_path): | |
| print(f"๐ Loading adapter from {adapter_path}...") | |
| model = PeftModel.from_pretrained(model, adapter_path) | |
| model = model.merge_and_unload() | |
| print("โ Adapter merged.") | |
| model.eval() | |
| # System prompt (same as train_grpo.py) | |
| system_prompt = ( | |
| "You are an expert code reviewer. Before each <think> block, " | |
| "emit <budget_prediction>short|medium|long</budget_prediction>. " | |
| "Use 'long' for files that look buggy (high complexity, high churn, " | |
| "suspicious patterns). Use 'short' for safe files (test files, configs, " | |
| "docs, low complexity). After thinking, use <tool_call> to flag or skip." | |
| ) | |
| results = [] | |
| for ep in episodes: | |
| print(f"\n๐ Running {ep['task_id']}: {ep['title']}") | |
| user_prompt = build_transfer_prompt(ep) | |
| messages = [ | |
| {"role": "system", "content": system_prompt}, | |
| {"role": "user", "content": user_prompt}, | |
| ] | |
| try: | |
| input_text = tokenizer.apply_chat_template( | |
| messages, tokenize=False, add_generation_prompt=True, | |
| enable_thinking=True, | |
| ) | |
| except TypeError: | |
| input_text = tokenizer.apply_chat_template( | |
| messages, tokenize=False, add_generation_prompt=True, | |
| ) | |
| inputs = tokenizer(input_text, return_tensors="pt").to(model.device) | |
| with torch.no_grad(): | |
| output = model.generate( | |
| **inputs, | |
| max_new_tokens=1024, | |
| temperature=0.6, | |
| top_p=0.9, | |
| do_sample=True, | |
| ) | |
| generated = tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], | |
| skip_special_tokens=False) | |
| # Parse predictions | |
| preds = parse_predictions_from_output(generated) | |
| # Score against ground truth | |
| ground_truth = {f["file"]: f["label"] for f in ep["files"]} | |
| bug_lengths = [] | |
| safe_lengths = [] | |
| tp = fp = fn = 0 | |
| for pred in preds: | |
| fpath = pred.get("file_path", "") | |
| if fpath and fpath in ground_truth: | |
| label = ground_truth[fpath] | |
| if label == 1: | |
| bug_lengths.append(pred["think_length"]) | |
| else: | |
| safe_lengths.append(pred["think_length"]) | |
| is_flag = pred.get("tool_name") == "flag_vulnerable" | |
| if label == 1 and is_flag: | |
| tp += 1 | |
| elif label == 0 and is_flag: | |
| fp += 1 | |
| elif label == 1 and not is_flag: | |
| fn += 1 | |
| # Count unprocessed bugs as FN | |
| processed_bugs = set(p["file_path"] for p in preds | |
| if p.get("file_path") in ground_truth | |
| and ground_truth.get(p.get("file_path")) == 1) | |
| all_bugs = set(f["file"] for f in ep["files"] if f["label"] == 1) | |
| fn += len(all_bugs - processed_bugs) | |
| prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0 | |
| rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0 | |
| f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0 | |
| bug_avg = sum(bug_lengths) / len(bug_lengths) if bug_lengths else 0 | |
| safe_avg = sum(safe_lengths) / len(safe_lengths) if safe_lengths else 1 | |
| ratio = bug_avg / max(1, safe_avg) | |
| result = { | |
| "task_id": ep["task_id"], | |
| "title": ep["title"], | |
| "n_predictions": len(preds), | |
| "f1": f1, | |
| "precision": prec, | |
| "recall": rec, | |
| "tp": tp, "fp": fp, "fn": fn, | |
| "bug_avg_think": bug_avg, | |
| "safe_avg_think": safe_avg, | |
| "thinking_ratio": ratio, | |
| "predictions": preds, | |
| "raw_output_length": len(generated), | |
| } | |
| results.append(result) | |
| print(f" F1={f1:.2f} | ratio={ratio:.1f}x | " | |
| f"preds={len(preds)} | TP={tp} FP={fp} FN={fn}") | |
| return results | |
| def run_without_model(episodes): | |
| """Fallback: run the heuristic oracle (same as transfer_eval.py) | |
| but CLEARLY LABEL the output as 'heuristic-proxy' not 'model inference'.""" | |
| import random | |
| results = [] | |
| rng = random.Random(42) | |
| for ep in episodes: | |
| ground_truth = {f["file"]: f for f in ep["files"]} | |
| bug_lengths = [] | |
| safe_lengths = [] | |
| tp = fp = fn = 0 | |
| preds = [] | |
| for f in ep["files"]: | |
| feat = f.get("features", [0, 0, 0, 0]) | |
| churn, complexity, todos, recency = feat | |
| risk = 0.45 * (churn/100) + 0.40 * (complexity/100) | |
| risk += 0.10 * (todos/20) + 0.05 * (recency/100) | |
| if f.get("is_test"): | |
| risk *= 0.3 | |
| if risk > 0.5: | |
| pred = "long" | |
| think_len = 350 + rng.randint(-30, 80) | |
| flag = True | |
| elif risk > 0.3: | |
| pred = "medium" | |
| think_len = 160 + rng.randint(-30, 60) | |
| flag = False | |
| else: | |
| pred = "short" | |
| think_len = 50 + rng.randint(0, 40) | |
| flag = False | |
| if f["label"] == 1: | |
| bug_lengths.append(think_len) | |
| if flag: tp += 1 | |
| else: fn += 1 | |
| else: | |
| safe_lengths.append(think_len) | |
| if flag: fp += 1 | |
| preds.append({ | |
| "prediction": pred, | |
| "think_length": think_len, | |
| "tool_name": "flag_vulnerable" if flag else "skip_file", | |
| "file_path": f["file"], | |
| }) | |
| prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0 | |
| rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0 | |
| f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0 | |
| bug_avg = sum(bug_lengths) / len(bug_lengths) if bug_lengths else 0 | |
| safe_avg = sum(safe_lengths) / len(safe_lengths) if safe_lengths else 1 | |
| results.append({ | |
| "task_id": ep["task_id"], | |
| "title": ep["title"], | |
| "n_predictions": len(preds), | |
| "f1": f1, | |
| "precision": prec, | |
| "recall": rec, | |
| "tp": tp, "fp": fp, "fn": fn, | |
| "bug_avg_think": bug_avg, | |
| "safe_avg_think": safe_avg, | |
| "thinking_ratio": bug_avg / max(1, safe_avg), | |
| "predictions": preds, | |
| "mode": "heuristic-proxy (no adapter available)", | |
| }) | |
| return results | |
| def save_results(results, out_dir): | |
| """Save results and generate plot.""" | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| out_dir = Path(out_dir) | |
| out_dir.mkdir(exist_ok=True) | |
| # Save JSON | |
| json_path = out_dir / "transfer_inference_results.json" | |
| with open(json_path, "w") as f: | |
| json.dump(results, f, indent=2, default=float) | |
| print(f"\n๐ Results saved to {json_path}") | |
| # Aggregate | |
| agg_f1 = sum(r["f1"] for r in results) / len(results) | |
| agg_ratio = sum(r["thinking_ratio"] for r in results) / len(results) | |
| total_tp = sum(r["tp"] for r in results) | |
| total_fp = sum(r["fp"] for r in results) | |
| total_fn = sum(r["fn"] for r in results) | |
| print(f"\n{'='*50}") | |
| print(f" TRANSFER INFERENCE RESULTS") | |
| print(f"{'='*50}") | |
| for r in results: | |
| print(f" {r['task_id']}: F1={r['f1']:.2f} ratio={r['thinking_ratio']:.1f}x") | |
| print(f"{'โ'*50}") | |
| print(f" Aggregate F1: {agg_f1:.3f}") | |
| print(f" Aggregate ratio: {agg_ratio:.1f}x") | |
| print(f" Total: TP={total_tp} FP={total_fp} FN={total_fn}") | |
| print(f"{'='*50}") | |
| # Plot | |
| fig, axes = plt.subplots(1, 2, figsize=(13, 5.5)) | |
| # Left: per-task F1 | |
| tasks = [r["task_id"] for r in results] | |
| f1s = [r["f1"] for r in results] | |
| colors = ["#4ade80" if f >= 0.5 else "#f87171" for f in f1s] | |
| axes[0].barh(tasks, f1s, color=colors, edgecolor="#333", linewidth=0.5) | |
| axes[0].set_xlabel("F1 Score", fontsize=12) | |
| axes[0].set_title(f"Transfer F1 (aggregate: {agg_f1:.2f})", fontsize=13) | |
| axes[0].set_xlim(0, 1.1) | |
| axes[0].axvline(1.0, color="#888", ls=":", lw=0.8) | |
| for i, v in enumerate(f1s): | |
| axes[0].text(v + 0.02, i, f"{v:.2f}", va="center", fontsize=10) | |
| # Right: thinking ratios | |
| ratios = [r["thinking_ratio"] for r in results] | |
| axes[1].barh(tasks, ratios, color="#a78bfa", edgecolor="#333", linewidth=0.5) | |
| axes[1].set_xlabel("Thinking Ratio (bug/safe)", fontsize=12) | |
| axes[1].set_title(f"Thinking Allocation (avg: {agg_ratio:.1f}x)", fontsize=13) | |
| axes[1].axvline(1.0, color="#888", ls=":", lw=0.8, label="uniform (1.0x)") | |
| for i, v in enumerate(ratios): | |
| axes[1].text(v + 0.1, i, f"{v:.1f}x", va="center", fontsize=10) | |
| axes[1].legend(fontsize=9) | |
| mode = results[0].get("mode", "trained adapter") | |
| fig.suptitle( | |
| f"Transfer Domain Evaluation โ Real Model Inference\n({mode})", | |
| fontsize=14, fontweight="bold", y=1.02, | |
| ) | |
| fig.tight_layout() | |
| plot_path = out_dir / "transfer_inference_plot.png" | |
| fig.savefig(plot_path, dpi=140, bbox_inches="tight") | |
| plt.close(fig) | |
| print(f"โ Plot saved to {plot_path}") | |
| def main(): | |
| ap = argparse.ArgumentParser() | |
| ap.add_argument("--adapter", default=str(DEFAULT_ADAPTER), | |
| help="Path to trained LoRA adapter") | |
| ap.add_argument("--model", default="Qwen/Qwen3-1.7B", | |
| help="Base model name") | |
| ap.add_argument("--episodes", default=str(DEFAULT_EPISODES)) | |
| ap.add_argument("--out", default=str(OUT_DIR)) | |
| args = ap.parse_args() | |
| with open(args.episodes) as f: | |
| episodes = json.load(f) | |
| print(f"๐ {len(episodes)} transfer episodes loaded") | |
| # Try real model inference first, fall back to heuristic | |
| has_adapter = os.path.exists(args.adapter) | |
| has_torch = False | |
| try: | |
| import torch | |
| has_torch = torch.cuda.is_available() | |
| except ImportError: | |
| pass | |
| if has_torch: | |
| print(f"๐ Running with {'adapter' if has_adapter else 'base model'} inference") | |
| results = run_with_model( | |
| episodes, | |
| args.adapter if has_adapter else None, | |
| args.model, | |
| ) | |
| else: | |
| print("โ ๏ธ No GPU available โ using heuristic proxy") | |
| results = run_without_model(episodes) | |
| save_results(results, args.out) | |
| if __name__ == "__main__": | |
| main() | |