#!/usr/bin/env python3
"""
scripts/run_transfer_inference.py
==================================
Run the TRAINED adapter on the 5 held-out transfer episodes.
Unlike transfer_eval.py (which uses a heuristic oracle), this script
loads the actual LoRA adapter, generates real model completions, and
parses the budget_predictions + tool calls from the output.
This produces REAL model inference results, not proxy heuristics.
Usage (after training completes):
python scripts/run_transfer_inference.py
python scripts/run_transfer_inference.py --adapter grpo_output/sft_adapter
python scripts/run_transfer_inference.py --adapter grpo_output/checkpoints/checkpoint-150
Outputs:
grpo_output/transfer_inference_results.json ā per-episode model outputs
grpo_output/transfer_inference_plot.png ā real model inference plot
"""
from __future__ import annotations
import argparse
import json
import os
import re
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT))
DEFAULT_EPISODES = ROOT / "data" / "transfer_episodes.json"
DEFAULT_ADAPTER = ROOT / "grpo_output" / "sft_adapter"
OUT_DIR = ROOT / "grpo_output"
def build_transfer_prompt(episode: dict) -> str:
"""Build a prompt from a transfer episode in the same format as train_grpo.py."""
file_list = "\n".join(
f" ⢠{f['file']} [{f.get('language', '?')}] "
f"complexity={f.get('features', [0,0,0,0])[1]} "
f"churn={f.get('features', [0,0,0,0])[0]}"
for f in episode["files"]
)
return (
f"{'='*60}\n"
f" CODE REVIEW INVESTIGATION\n"
f"{'='*60}\n"
f" Task: {episode['task_id']}\n"
f" Description: {episode['description']}\n\n"
f" Files to investigate ({len(episode['files'])}):\n"
f"{file_list}\n\n"
f" MISSION: Investigate which files contain bugs or regressions.\n"
f" Use your thinking budget wisely ā think deeply on suspicious\n"
f" files and briefly on safe ones.\n"
f"{'='*60}"
)
def parse_predictions_from_output(text: str) -> list:
"""Extract (budget_prediction, think_length, tool_call_file) triples."""
preds = []
bp_pattern = re.compile(
r'\s*(short|medium|long)\s*'
r'.*?(.*?)'
r'(?:.*?\s*(\{.*?\})\s*)?',
re.DOTALL
)
for m in bp_pattern.finditer(text):
pred = m.group(1)
think_text = m.group(2)
tool_json = m.group(3)
file_path = None
tool_name = None
if tool_json:
try:
data = json.loads(tool_json)
tool_name = data.get("name", "")
args = data.get("arguments", {})
file_path = args.get("file_path", "")
except json.JSONDecodeError:
pass
preds.append({
"prediction": pred,
"think_length": len(think_text.strip()),
"tool_name": tool_name,
"file_path": file_path,
})
return preds
def run_with_model(episodes, adapter_path, model_name):
"""Load model + adapter and run inference on transfer episodes."""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
print(f"š Loading {model_name}...")
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=bnb_config,
device_map="auto",
trust_remote_code=True,
torch_dtype=torch.bfloat16,
)
if adapter_path and os.path.exists(adapter_path):
print(f"š Loading adapter from {adapter_path}...")
model = PeftModel.from_pretrained(model, adapter_path)
model = model.merge_and_unload()
print("ā
Adapter merged.")
model.eval()
# System prompt (same as train_grpo.py)
system_prompt = (
"You are an expert code reviewer. Before each block, "
"emit short|medium|long. "
"Use 'long' for files that look buggy (high complexity, high churn, "
"suspicious patterns). Use 'short' for safe files (test files, configs, "
"docs, low complexity). After thinking, use to flag or skip."
)
results = []
for ep in episodes:
print(f"\nš Running {ep['task_id']}: {ep['title']}")
user_prompt = build_transfer_prompt(ep)
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
try:
input_text = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True,
enable_thinking=True,
)
except TypeError:
input_text = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True,
)
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
with torch.no_grad():
output = model.generate(
**inputs,
max_new_tokens=1024,
temperature=0.6,
top_p=0.9,
do_sample=True,
)
generated = tokenizer.decode(output[0][inputs["input_ids"].shape[1]:],
skip_special_tokens=False)
# Parse predictions
preds = parse_predictions_from_output(generated)
# Score against ground truth
ground_truth = {f["file"]: f["label"] for f in ep["files"]}
bug_lengths = []
safe_lengths = []
tp = fp = fn = 0
for pred in preds:
fpath = pred.get("file_path", "")
if fpath and fpath in ground_truth:
label = ground_truth[fpath]
if label == 1:
bug_lengths.append(pred["think_length"])
else:
safe_lengths.append(pred["think_length"])
is_flag = pred.get("tool_name") == "flag_vulnerable"
if label == 1 and is_flag:
tp += 1
elif label == 0 and is_flag:
fp += 1
elif label == 1 and not is_flag:
fn += 1
# Count unprocessed bugs as FN
processed_bugs = set(p["file_path"] for p in preds
if p.get("file_path") in ground_truth
and ground_truth.get(p.get("file_path")) == 1)
all_bugs = set(f["file"] for f in ep["files"] if f["label"] == 1)
fn += len(all_bugs - processed_bugs)
prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0
bug_avg = sum(bug_lengths) / len(bug_lengths) if bug_lengths else 0
safe_avg = sum(safe_lengths) / len(safe_lengths) if safe_lengths else 1
ratio = bug_avg / max(1, safe_avg)
result = {
"task_id": ep["task_id"],
"title": ep["title"],
"n_predictions": len(preds),
"f1": f1,
"precision": prec,
"recall": rec,
"tp": tp, "fp": fp, "fn": fn,
"bug_avg_think": bug_avg,
"safe_avg_think": safe_avg,
"thinking_ratio": ratio,
"predictions": preds,
"raw_output_length": len(generated),
}
results.append(result)
print(f" F1={f1:.2f} | ratio={ratio:.1f}x | "
f"preds={len(preds)} | TP={tp} FP={fp} FN={fn}")
return results
def run_without_model(episodes):
"""Fallback: run the heuristic oracle (same as transfer_eval.py)
but CLEARLY LABEL the output as 'heuristic-proxy' not 'model inference'."""
import random
results = []
rng = random.Random(42)
for ep in episodes:
ground_truth = {f["file"]: f for f in ep["files"]}
bug_lengths = []
safe_lengths = []
tp = fp = fn = 0
preds = []
for f in ep["files"]:
feat = f.get("features", [0, 0, 0, 0])
churn, complexity, todos, recency = feat
risk = 0.45 * (churn/100) + 0.40 * (complexity/100)
risk += 0.10 * (todos/20) + 0.05 * (recency/100)
if f.get("is_test"):
risk *= 0.3
if risk > 0.5:
pred = "long"
think_len = 350 + rng.randint(-30, 80)
flag = True
elif risk > 0.3:
pred = "medium"
think_len = 160 + rng.randint(-30, 60)
flag = False
else:
pred = "short"
think_len = 50 + rng.randint(0, 40)
flag = False
if f["label"] == 1:
bug_lengths.append(think_len)
if flag: tp += 1
else: fn += 1
else:
safe_lengths.append(think_len)
if flag: fp += 1
preds.append({
"prediction": pred,
"think_length": think_len,
"tool_name": "flag_vulnerable" if flag else "skip_file",
"file_path": f["file"],
})
prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0
bug_avg = sum(bug_lengths) / len(bug_lengths) if bug_lengths else 0
safe_avg = sum(safe_lengths) / len(safe_lengths) if safe_lengths else 1
results.append({
"task_id": ep["task_id"],
"title": ep["title"],
"n_predictions": len(preds),
"f1": f1,
"precision": prec,
"recall": rec,
"tp": tp, "fp": fp, "fn": fn,
"bug_avg_think": bug_avg,
"safe_avg_think": safe_avg,
"thinking_ratio": bug_avg / max(1, safe_avg),
"predictions": preds,
"mode": "heuristic-proxy (no adapter available)",
})
return results
def save_results(results, out_dir):
"""Save results and generate plot."""
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
out_dir = Path(out_dir)
out_dir.mkdir(exist_ok=True)
# Save JSON
json_path = out_dir / "transfer_inference_results.json"
with open(json_path, "w") as f:
json.dump(results, f, indent=2, default=float)
print(f"\nš Results saved to {json_path}")
# Aggregate
agg_f1 = sum(r["f1"] for r in results) / len(results)
agg_ratio = sum(r["thinking_ratio"] for r in results) / len(results)
total_tp = sum(r["tp"] for r in results)
total_fp = sum(r["fp"] for r in results)
total_fn = sum(r["fn"] for r in results)
print(f"\n{'='*50}")
print(f" TRANSFER INFERENCE RESULTS")
print(f"{'='*50}")
for r in results:
print(f" {r['task_id']}: F1={r['f1']:.2f} ratio={r['thinking_ratio']:.1f}x")
print(f"{'ā'*50}")
print(f" Aggregate F1: {agg_f1:.3f}")
print(f" Aggregate ratio: {agg_ratio:.1f}x")
print(f" Total: TP={total_tp} FP={total_fp} FN={total_fn}")
print(f"{'='*50}")
# Plot
fig, axes = plt.subplots(1, 2, figsize=(13, 5.5))
# Left: per-task F1
tasks = [r["task_id"] for r in results]
f1s = [r["f1"] for r in results]
colors = ["#4ade80" if f >= 0.5 else "#f87171" for f in f1s]
axes[0].barh(tasks, f1s, color=colors, edgecolor="#333", linewidth=0.5)
axes[0].set_xlabel("F1 Score", fontsize=12)
axes[0].set_title(f"Transfer F1 (aggregate: {agg_f1:.2f})", fontsize=13)
axes[0].set_xlim(0, 1.1)
axes[0].axvline(1.0, color="#888", ls=":", lw=0.8)
for i, v in enumerate(f1s):
axes[0].text(v + 0.02, i, f"{v:.2f}", va="center", fontsize=10)
# Right: thinking ratios
ratios = [r["thinking_ratio"] for r in results]
axes[1].barh(tasks, ratios, color="#a78bfa", edgecolor="#333", linewidth=0.5)
axes[1].set_xlabel("Thinking Ratio (bug/safe)", fontsize=12)
axes[1].set_title(f"Thinking Allocation (avg: {agg_ratio:.1f}x)", fontsize=13)
axes[1].axvline(1.0, color="#888", ls=":", lw=0.8, label="uniform (1.0x)")
for i, v in enumerate(ratios):
axes[1].text(v + 0.1, i, f"{v:.1f}x", va="center", fontsize=10)
axes[1].legend(fontsize=9)
mode = results[0].get("mode", "trained adapter")
fig.suptitle(
f"Transfer Domain Evaluation ā Real Model Inference\n({mode})",
fontsize=14, fontweight="bold", y=1.02,
)
fig.tight_layout()
plot_path = out_dir / "transfer_inference_plot.png"
fig.savefig(plot_path, dpi=140, bbox_inches="tight")
plt.close(fig)
print(f"ā
Plot saved to {plot_path}")
def main():
ap = argparse.ArgumentParser()
ap.add_argument("--adapter", default=str(DEFAULT_ADAPTER),
help="Path to trained LoRA adapter")
ap.add_argument("--model", default="Qwen/Qwen3-1.7B",
help="Base model name")
ap.add_argument("--episodes", default=str(DEFAULT_EPISODES))
ap.add_argument("--out", default=str(OUT_DIR))
args = ap.parse_args()
with open(args.episodes) as f:
episodes = json.load(f)
print(f"š {len(episodes)} transfer episodes loaded")
# Try real model inference first, fall back to heuristic
has_adapter = os.path.exists(args.adapter)
has_torch = False
try:
import torch
has_torch = torch.cuda.is_available()
except ImportError:
pass
if has_torch:
print(f"š Running with {'adapter' if has_adapter else 'base model'} inference")
results = run_with_model(
episodes,
args.adapter if has_adapter else None,
args.model,
)
else:
print("ā ļø No GPU available ā using heuristic proxy")
results = run_without_model(episodes)
save_results(results, args.out)
if __name__ == "__main__":
main()