#!/usr/bin/env python3 """ scripts/run_transfer_inference.py ================================== Run the TRAINED adapter on the 5 held-out transfer episodes. Unlike transfer_eval.py (which uses a heuristic oracle), this script loads the actual LoRA adapter, generates real model completions, and parses the budget_predictions + tool calls from the output. This produces REAL model inference results, not proxy heuristics. Usage (after training completes): python scripts/run_transfer_inference.py python scripts/run_transfer_inference.py --adapter grpo_output/sft_adapter python scripts/run_transfer_inference.py --adapter grpo_output/checkpoints/checkpoint-150 Outputs: grpo_output/transfer_inference_results.json — per-episode model outputs grpo_output/transfer_inference_plot.png — real model inference plot """ from __future__ import annotations import argparse import json import os import re import sys from pathlib import Path ROOT = Path(__file__).resolve().parent.parent sys.path.insert(0, str(ROOT)) DEFAULT_EPISODES = ROOT / "data" / "transfer_episodes.json" DEFAULT_ADAPTER = ROOT / "grpo_output" / "sft_adapter" OUT_DIR = ROOT / "grpo_output" def build_transfer_prompt(episode: dict) -> str: """Build a prompt from a transfer episode in the same format as train_grpo.py.""" file_list = "\n".join( f" • {f['file']} [{f.get('language', '?')}] " f"complexity={f.get('features', [0,0,0,0])[1]} " f"churn={f.get('features', [0,0,0,0])[0]}" for f in episode["files"] ) return ( f"{'='*60}\n" f" CODE REVIEW INVESTIGATION\n" f"{'='*60}\n" f" Task: {episode['task_id']}\n" f" Description: {episode['description']}\n\n" f" Files to investigate ({len(episode['files'])}):\n" f"{file_list}\n\n" f" MISSION: Investigate which files contain bugs or regressions.\n" f" Use your thinking budget wisely — think deeply on suspicious\n" f" files and briefly on safe ones.\n" f"{'='*60}" ) def parse_predictions_from_output(text: str) -> list: """Extract (budget_prediction, think_length, tool_call_file) triples.""" preds = [] bp_pattern = re.compile( r'\s*(short|medium|long)\s*' r'.*?(.*?)' r'(?:.*?\s*(\{.*?\})\s*)?', re.DOTALL ) for m in bp_pattern.finditer(text): pred = m.group(1) think_text = m.group(2) tool_json = m.group(3) file_path = None tool_name = None if tool_json: try: data = json.loads(tool_json) tool_name = data.get("name", "") args = data.get("arguments", {}) file_path = args.get("file_path", "") except json.JSONDecodeError: pass preds.append({ "prediction": pred, "think_length": len(think_text.strip()), "tool_name": tool_name, "file_path": file_path, }) return preds def run_with_model(episodes, adapter_path, model_name): """Load model + adapter and run inference on transfer episodes.""" import torch from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig from peft import PeftModel print(f"šŸ”„ Loading {model_name}...") bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, ) tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token model = AutoModelForCausalLM.from_pretrained( model_name, quantization_config=bnb_config, device_map="auto", trust_remote_code=True, torch_dtype=torch.bfloat16, ) if adapter_path and os.path.exists(adapter_path): print(f"šŸ”„ Loading adapter from {adapter_path}...") model = PeftModel.from_pretrained(model, adapter_path) model = model.merge_and_unload() print("āœ… Adapter merged.") model.eval() # System prompt (same as train_grpo.py) system_prompt = ( "You are an expert code reviewer. Before each block, " "emit short|medium|long. " "Use 'long' for files that look buggy (high complexity, high churn, " "suspicious patterns). Use 'short' for safe files (test files, configs, " "docs, low complexity). After thinking, use to flag or skip." ) results = [] for ep in episodes: print(f"\nšŸ“‹ Running {ep['task_id']}: {ep['title']}") user_prompt = build_transfer_prompt(ep) messages = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}, ] try: input_text = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, enable_thinking=True, ) except TypeError: input_text = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, ) inputs = tokenizer(input_text, return_tensors="pt").to(model.device) with torch.no_grad(): output = model.generate( **inputs, max_new_tokens=1024, temperature=0.6, top_p=0.9, do_sample=True, ) generated = tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=False) # Parse predictions preds = parse_predictions_from_output(generated) # Score against ground truth ground_truth = {f["file"]: f["label"] for f in ep["files"]} bug_lengths = [] safe_lengths = [] tp = fp = fn = 0 for pred in preds: fpath = pred.get("file_path", "") if fpath and fpath in ground_truth: label = ground_truth[fpath] if label == 1: bug_lengths.append(pred["think_length"]) else: safe_lengths.append(pred["think_length"]) is_flag = pred.get("tool_name") == "flag_vulnerable" if label == 1 and is_flag: tp += 1 elif label == 0 and is_flag: fp += 1 elif label == 1 and not is_flag: fn += 1 # Count unprocessed bugs as FN processed_bugs = set(p["file_path"] for p in preds if p.get("file_path") in ground_truth and ground_truth.get(p.get("file_path")) == 1) all_bugs = set(f["file"] for f in ep["files"] if f["label"] == 1) fn += len(all_bugs - processed_bugs) prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0 rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0 f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0 bug_avg = sum(bug_lengths) / len(bug_lengths) if bug_lengths else 0 safe_avg = sum(safe_lengths) / len(safe_lengths) if safe_lengths else 1 ratio = bug_avg / max(1, safe_avg) result = { "task_id": ep["task_id"], "title": ep["title"], "n_predictions": len(preds), "f1": f1, "precision": prec, "recall": rec, "tp": tp, "fp": fp, "fn": fn, "bug_avg_think": bug_avg, "safe_avg_think": safe_avg, "thinking_ratio": ratio, "predictions": preds, "raw_output_length": len(generated), } results.append(result) print(f" F1={f1:.2f} | ratio={ratio:.1f}x | " f"preds={len(preds)} | TP={tp} FP={fp} FN={fn}") return results def run_without_model(episodes): """Fallback: run the heuristic oracle (same as transfer_eval.py) but CLEARLY LABEL the output as 'heuristic-proxy' not 'model inference'.""" import random results = [] rng = random.Random(42) for ep in episodes: ground_truth = {f["file"]: f for f in ep["files"]} bug_lengths = [] safe_lengths = [] tp = fp = fn = 0 preds = [] for f in ep["files"]: feat = f.get("features", [0, 0, 0, 0]) churn, complexity, todos, recency = feat risk = 0.45 * (churn/100) + 0.40 * (complexity/100) risk += 0.10 * (todos/20) + 0.05 * (recency/100) if f.get("is_test"): risk *= 0.3 if risk > 0.5: pred = "long" think_len = 350 + rng.randint(-30, 80) flag = True elif risk > 0.3: pred = "medium" think_len = 160 + rng.randint(-30, 60) flag = False else: pred = "short" think_len = 50 + rng.randint(0, 40) flag = False if f["label"] == 1: bug_lengths.append(think_len) if flag: tp += 1 else: fn += 1 else: safe_lengths.append(think_len) if flag: fp += 1 preds.append({ "prediction": pred, "think_length": think_len, "tool_name": "flag_vulnerable" if flag else "skip_file", "file_path": f["file"], }) prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0 rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0 f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0 bug_avg = sum(bug_lengths) / len(bug_lengths) if bug_lengths else 0 safe_avg = sum(safe_lengths) / len(safe_lengths) if safe_lengths else 1 results.append({ "task_id": ep["task_id"], "title": ep["title"], "n_predictions": len(preds), "f1": f1, "precision": prec, "recall": rec, "tp": tp, "fp": fp, "fn": fn, "bug_avg_think": bug_avg, "safe_avg_think": safe_avg, "thinking_ratio": bug_avg / max(1, safe_avg), "predictions": preds, "mode": "heuristic-proxy (no adapter available)", }) return results def save_results(results, out_dir): """Save results and generate plot.""" import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import numpy as np out_dir = Path(out_dir) out_dir.mkdir(exist_ok=True) # Save JSON json_path = out_dir / "transfer_inference_results.json" with open(json_path, "w") as f: json.dump(results, f, indent=2, default=float) print(f"\nšŸ“Š Results saved to {json_path}") # Aggregate agg_f1 = sum(r["f1"] for r in results) / len(results) agg_ratio = sum(r["thinking_ratio"] for r in results) / len(results) total_tp = sum(r["tp"] for r in results) total_fp = sum(r["fp"] for r in results) total_fn = sum(r["fn"] for r in results) print(f"\n{'='*50}") print(f" TRANSFER INFERENCE RESULTS") print(f"{'='*50}") for r in results: print(f" {r['task_id']}: F1={r['f1']:.2f} ratio={r['thinking_ratio']:.1f}x") print(f"{'─'*50}") print(f" Aggregate F1: {agg_f1:.3f}") print(f" Aggregate ratio: {agg_ratio:.1f}x") print(f" Total: TP={total_tp} FP={total_fp} FN={total_fn}") print(f"{'='*50}") # Plot fig, axes = plt.subplots(1, 2, figsize=(13, 5.5)) # Left: per-task F1 tasks = [r["task_id"] for r in results] f1s = [r["f1"] for r in results] colors = ["#4ade80" if f >= 0.5 else "#f87171" for f in f1s] axes[0].barh(tasks, f1s, color=colors, edgecolor="#333", linewidth=0.5) axes[0].set_xlabel("F1 Score", fontsize=12) axes[0].set_title(f"Transfer F1 (aggregate: {agg_f1:.2f})", fontsize=13) axes[0].set_xlim(0, 1.1) axes[0].axvline(1.0, color="#888", ls=":", lw=0.8) for i, v in enumerate(f1s): axes[0].text(v + 0.02, i, f"{v:.2f}", va="center", fontsize=10) # Right: thinking ratios ratios = [r["thinking_ratio"] for r in results] axes[1].barh(tasks, ratios, color="#a78bfa", edgecolor="#333", linewidth=0.5) axes[1].set_xlabel("Thinking Ratio (bug/safe)", fontsize=12) axes[1].set_title(f"Thinking Allocation (avg: {agg_ratio:.1f}x)", fontsize=13) axes[1].axvline(1.0, color="#888", ls=":", lw=0.8, label="uniform (1.0x)") for i, v in enumerate(ratios): axes[1].text(v + 0.1, i, f"{v:.1f}x", va="center", fontsize=10) axes[1].legend(fontsize=9) mode = results[0].get("mode", "trained adapter") fig.suptitle( f"Transfer Domain Evaluation — Real Model Inference\n({mode})", fontsize=14, fontweight="bold", y=1.02, ) fig.tight_layout() plot_path = out_dir / "transfer_inference_plot.png" fig.savefig(plot_path, dpi=140, bbox_inches="tight") plt.close(fig) print(f"āœ… Plot saved to {plot_path}") def main(): ap = argparse.ArgumentParser() ap.add_argument("--adapter", default=str(DEFAULT_ADAPTER), help="Path to trained LoRA adapter") ap.add_argument("--model", default="Qwen/Qwen3-1.7B", help="Base model name") ap.add_argument("--episodes", default=str(DEFAULT_EPISODES)) ap.add_argument("--out", default=str(OUT_DIR)) args = ap.parse_args() with open(args.episodes) as f: episodes = json.load(f) print(f"šŸ“‹ {len(episodes)} transfer episodes loaded") # Try real model inference first, fall back to heuristic has_adapter = os.path.exists(args.adapter) has_torch = False try: import torch has_torch = torch.cuda.is_available() except ImportError: pass if has_torch: print(f"šŸš€ Running with {'adapter' if has_adapter else 'base model'} inference") results = run_with_model( episodes, args.adapter if has_adapter else None, args.model, ) else: print("āš ļø No GPU available — using heuristic proxy") results = run_without_model(episodes) save_results(results, args.out) if __name__ == "__main__": main()