from __future__ import annotations import importlib.util import json import math import os import re import sys from collections import Counter, defaultdict from pathlib import Path sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from datasets import Dataset try: from unsloth import FastLanguageModel, is_bfloat16_supported HAS_UNSLOTH = True except ImportError: HAS_UNSLOTH = False def is_bfloat16_supported() -> bool: return False try: from unsloth import PatchFastRL PatchFastRL("GRPO", FastLanguageModel) except ImportError: pass try: from bug_bank import BugBank, BugSample, build_bug_bank from seed_bank import SEED_BANK, SeedSpec, get_seed_by_id from server.bug_injector import infer_bug_operator from server.executor import execute_code from server.graders import ( compute_ast_distance, compute_proposer_reward, compute_solver_reward, is_effectively_unchanged, reset_reward_history, ) from training.dual_role_sampler import sample_proposer_prompt, sample_solver_prompt except ImportError: from ..bug_bank import BugBank, BugSample, build_bug_bank from ..seed_bank import SEED_BANK, SeedSpec, get_seed_by_id from ..server.bug_injector import infer_bug_operator from ..server.executor import execute_code from ..server.graders import ( compute_ast_distance, compute_proposer_reward, compute_solver_reward, is_effectively_unchanged, reset_reward_history, ) from .dual_role_sampler import sample_proposer_prompt, sample_solver_prompt DEFAULT_MODEL_ID = "unsloth/Qwen2.5-Coder-3B-Instruct" DEFAULT_FALLBACK_MODEL_ID = "Qwen/Qwen2.5-Coder-3B-Instruct" DEFAULT_OUTPUT_DIR = Path("debugzero_model") DEFAULT_SOLVER_WEIGHT = 2 DEFAULT_NUM_GENERATIONS = 4 DEFAULT_MAX_STEPS = 80 DEFAULT_MAX_PROMPT_LENGTH = 768 DEFAULT_MAX_COMPLETION_LENGTH = 256 DRY_RUN_MAX_STEPS = 2 DEFAULT_PROPOSER_METRICS_PATH = DEFAULT_OUTPUT_DIR / "proposer_metrics.json" TARGETED_PROMPT_RATIO = 0.75 def extract_python_code(text: str) -> str: match = re.search(r"```(?:python)?\s(.*?)```", text, flags=re.DOTALL) if match: return match.group(1).strip() return text.strip() def completion_to_text(completion) -> str: if isinstance(completion, list) and completion: item = completion[0] if isinstance(item, dict): return item.get("content", "") return str(item) return str(completion) def prompt_to_text(prompt) -> str: if isinstance(prompt, list): parts = [] for item in prompt: if isinstance(item, dict): parts.append(str(item.get("content", ""))) else: parts.append(str(item)) return "\n".join(part for part in parts if part) if isinstance(prompt, dict): return str(prompt.get("content", "")) return str(prompt) def execute_candidate(seed: SeedSpec, candidate_code: str) -> dict[str, object]: result = execute_code(candidate_code, seed.test) execution_result = result.output[:500] if result.output else "" unsafe_code = execution_result.startswith("Unsafe import detected.") return { "tests_passed": result.passed, "syntax_error": result.syntax_error, "unsafe_code": unsafe_code, "execution_result": execution_result, } def build_mixed_role_dataset( bug_bank: BugBank, solver_weight: int = DEFAULT_SOLVER_WEIGHT, ) -> Dataset: rows: list[dict[str, object]] = [] for bug_sample in bug_bank.train_samples: prompt_text = sample_solver_prompt( bug_sample.buggy_code, bug_sample.execution_result, mode="concise", ) rows.append( { "prompt": [{"role": "user", "content": prompt_text}], "role": "solver", "seed_id": bug_sample.seed_id, "original_code": bug_sample.original_code, "buggy_code": bug_sample.buggy_code, "bug_operator": bug_sample.bug_operator, "execution_result": bug_sample.execution_result, } ) target_proposer_rows = max(1, math.ceil(len(rows) / solver_weight)) if rows else len(SEED_BANK) proposer_rows = build_weighted_proposer_rows(bug_bank, target_proposer_rows) rows.extend(proposer_rows) return Dataset.from_list(rows) def create_dataset() -> tuple[Dataset, BugBank]: bug_bank = build_bug_bank() return build_mixed_role_dataset(bug_bank), bug_bank def prop_rew(prompts, completions, **kwargs): rewards: list[float] = [] roles = kwargs.get("role", []) seed_ids = kwargs.get("seed_id", []) original_codes = kwargs.get("original_code", []) for i, completion in enumerate(completions): role = roles[i] if i < len(roles) else roles[0] if role != "proposer": rewards.append(0.0) continue seed_id = seed_ids[i] if i < len(seed_ids) else seed_ids[0] original_code = original_codes[i] if i < len(original_codes) else original_codes[0] seed = get_seed_by_id(seed_id) candidate_code = extract_python_code(completion_to_text(completion)) execution_meta = execute_candidate(seed, candidate_code) unchanged_code = is_effectively_unchanged(original_code, candidate_code) changed_but_passing = ( (not unchanged_code) and execution_meta["tests_passed"] and (not execution_meta["syntax_error"]) ) proposer_meta = { "seed_id": seed.seed_id, "tests_passed": execution_meta["tests_passed"], "syntax_error": execution_meta["syntax_error"], "unsafe_code": execution_meta["unsafe_code"], "unchanged_code": unchanged_code, "changed_but_passing": changed_but_passing, "plausibility_score": 0.0, } if not execution_meta["syntax_error"]: proposer_meta["plausibility_score"] = compute_ast_distance(original_code, candidate_code) rewards.append(compute_proposer_reward(proposer_meta)) return rewards def solv_rew(prompts, completions, **kwargs): rewards: list[float] = [] roles = kwargs.get("role", []) seed_ids = kwargs.get("seed_id", []) for i, completion in enumerate(completions): role = roles[i] if i < len(roles) else roles[0] if role != "solver": rewards.append(0.0) continue seed_id = seed_ids[i] if i < len(seed_ids) else seed_ids[0] seed = get_seed_by_id(seed_id) candidate_code = extract_python_code(completion_to_text(completion)) execution_meta = execute_candidate(seed, candidate_code) solver_meta = { "seed_id": seed.seed_id, "tests_passed": execution_meta["tests_passed"], "syntax_error": execution_meta["syntax_error"], "unsafe_code": execution_meta["unsafe_code"], } rewards.append(compute_solver_reward(solver_meta)) return rewards def evaluate_bug_sample(candidate_code: str, bug_sample: BugSample) -> dict[str, object]: seed = get_seed_by_id(bug_sample.seed_id) evaluation = execute_candidate(seed, candidate_code) reward = compute_solver_reward( { "seed_id": bug_sample.seed_id, "tests_passed": evaluation["tests_passed"], "syntax_error": evaluation["syntax_error"], "unsafe_code": evaluation["unsafe_code"], } ) return {**evaluation, "reward": reward} def evaluate_solver_fixed_set(model, tokenizer, bug_bank: BugBank) -> dict[str, float]: results = [] for bug_sample in bug_bank.eval_samples: prompt = sample_solver_prompt( bug_sample.buggy_code, bug_sample.execution_result, mode="concise", ) candidate_code = generate_code(model, tokenizer, prompt, do_sample=False) results.append(evaluate_bug_sample(candidate_code, bug_sample)) return summarize_solver_results(results) def evaluate_proposer_fixed_set(model, tokenizer) -> dict[str, float]: results = [] for seed in SEED_BANK: prompt = sample_proposer_prompt(seed.original_code) candidate_code = generate_code(model, tokenizer, prompt, do_sample=False) evaluation = execute_candidate(seed, candidate_code) unchanged_code = is_effectively_unchanged(seed.original_code, candidate_code) valid_bug = (not evaluation["tests_passed"]) and (not evaluation["syntax_error"]) changed_but_passing = ( (not unchanged_code) and evaluation["tests_passed"] and (not evaluation["syntax_error"]) ) reward = compute_proposer_reward( { "seed_id": seed.seed_id, "tests_passed": evaluation["tests_passed"], "syntax_error": evaluation["syntax_error"], "unsafe_code": evaluation["unsafe_code"], "unchanged_code": unchanged_code, "changed_but_passing": changed_but_passing, "plausibility_score": 0.0 if evaluation["syntax_error"] else compute_ast_distance(seed.original_code, candidate_code), } ) results.append( { "seed_id": seed.seed_id, **evaluation, "reward": reward, "unchanged_code": unchanged_code, "valid_bug": valid_bug, "changed_but_passing": changed_but_passing, "likely_bug_family": infer_bug_operator(seed.original_code, candidate_code) or "unknown", } ) summary = summarize_proposer_results(results) summary["by_seed"] = summarize_proposer_by_seed(results) summary["by_bug_family"] = summarize_proposer_by_bug_family(results) return summary def summarize_solver_results(results: list[dict[str, object]]) -> dict[str, float]: total = len(results) or 1 passed = sum(1 for result in results if result["tests_passed"]) syntax_errors = sum(1 for result in results if result["syntax_error"]) mean_reward = sum(float(result["reward"]) for result in results) / total return { "pass_rate": passed / total, "syntax_error_rate": syntax_errors / total, "mean_reward": mean_reward, } def summarize_proposer_results(results: list[dict[str, object]]) -> dict[str, float]: total = len(results) or 1 bug_rate = sum( 1 for result in results if (not result["tests_passed"]) and (not result["syntax_error"]) ) unchanged = sum(1 for result in results if result.get("unchanged_code")) changed_but_passing = sum(1 for result in results if result.get("changed_but_passing")) syntax_errors = sum(1 for result in results if result["syntax_error"]) mean_reward = sum(float(result["reward"]) for result in results) / total return { "break_rate": bug_rate / total, "valid_bug_rate": bug_rate / total, "unchanged_rate": unchanged / total, "changed_but_passing_rate": changed_but_passing / total, "syntax_error_rate": syntax_errors / total, "mean_reward": mean_reward, } def summarize_proposer_by_seed(results: list[dict[str, object]]) -> dict[str, dict[str, float]]: grouped: dict[str, list[dict[str, object]]] = defaultdict(list) for result in results: grouped[str(result["seed_id"])].append(result) summary: dict[str, dict[str, float]] = {} for seed_id, seed_results in grouped.items(): total = len(seed_results) or 1 summary[seed_id] = { "valid_bug_rate": sum(1 for item in seed_results if item.get("valid_bug")) / total, "unchanged_rate": sum(1 for item in seed_results if item.get("unchanged_code")) / total, "changed_but_passing_rate": sum( 1 for item in seed_results if item.get("changed_but_passing") ) / total, "mean_reward": sum(float(item["reward"]) for item in seed_results) / total, } return summary def summarize_proposer_by_bug_family(results: list[dict[str, object]]) -> dict[str, dict[str, float]]: grouped: dict[str, list[dict[str, object]]] = defaultdict(list) for result in results: grouped[str(result.get("likely_bug_family", "unknown"))].append(result) summary: dict[str, dict[str, float]] = {} for family, family_results in grouped.items(): total = len(family_results) or 1 summary[family] = { "count": float(total), "valid_bug_rate": sum(1 for item in family_results if item.get("valid_bug")) / total, "mean_reward": sum(float(item["reward"]) for item in family_results) / total, } return summary def build_weighted_proposer_rows(bug_bank: BugBank, target_proposer_rows: int) -> list[dict[str, object]]: if target_proposer_rows <= 0: return [] prior_seed_rates = load_prior_seed_break_rates() operator_counts = Counter(sample.bug_operator for sample in bug_bank.train_samples) seed_to_operators: dict[str, list[str]] = defaultdict(list) for sample in bug_bank.train_samples: seed_to_operators[sample.seed_id].append(sample.bug_operator) seed_weights = {} for seed in SEED_BANK: prior_break_rate = prior_seed_rates.get(seed.seed_id, 0.5) seed_weights[seed.seed_id] = max(1, 1 + round((1.0 - prior_break_rate) * 2)) rows: list[dict[str, object]] = [] focus_counters = Counter() ordered_seeds = sorted(SEED_BANK, key=lambda seed: (-seed_weights[seed.seed_id], seed.seed_id)) # Keep every seed represented before adding extra weight to weak seeds. for seed in SEED_BANK[:target_proposer_rows]: bug_focus = choose_proposer_bug_focus( seed.seed_id, seed_to_operators[seed.seed_id], operator_counts, focus_counters, len(rows), target_proposer_rows, ) prompt_text = sample_proposer_prompt(seed.original_code, bug_focus=bug_focus) rows.append( { "prompt": [{"role": "user", "content": prompt_text}], "role": "proposer", "seed_id": seed.seed_id, "original_code": seed.original_code, "bug_focus": bug_focus if bug_focus else "", } ) while len(rows) < target_proposer_rows: for seed in ordered_seeds: extra_weight = max(0, seed_weights[seed.seed_id] - 1) for _ in range(extra_weight): if len(rows) >= target_proposer_rows: break bug_focus = choose_proposer_bug_focus( seed.seed_id, seed_to_operators[seed.seed_id], operator_counts, focus_counters, len(rows), target_proposer_rows, ) prompt_text = sample_proposer_prompt(seed.original_code, bug_focus=bug_focus) rows.append( { "prompt": [{"role": "user", "content": prompt_text}], "role": "proposer", "seed_id": seed.seed_id, "original_code": seed.original_code, "bug_focus": bug_focus if bug_focus else "", } ) if len(rows) >= target_proposer_rows: break return rows def choose_proposer_bug_focus( seed_id: str, operators: list[str], operator_counts: Counter, focus_counters: Counter, row_index: int, total_rows: int, ) -> str | None: unique_operators = sorted(set(operators), key=lambda op: (operator_counts[op], op)) if not unique_operators: return None if row_index >= math.ceil(total_rows * TARGETED_PROMPT_RATIO): return None del seed_id chosen = min(unique_operators, key=lambda op: (focus_counters[op], operator_counts[op], op)) focus_counters[chosen] += 1 return chosen def load_prior_seed_break_rates() -> dict[str, float]: if not DEFAULT_PROPOSER_METRICS_PATH.exists(): return {} try: data = json.loads(DEFAULT_PROPOSER_METRICS_PATH.read_text(encoding="utf-8")) except (OSError, json.JSONDecodeError): return {} seed_metrics = data.get("post_proposer_metrics", {}).get("by_seed", {}) return { str(seed_id): float(metrics.get("valid_bug_rate", 0.5)) for seed_id, metrics in seed_metrics.items() if isinstance(metrics, dict) } def save_metrics_artifact( pre_proposer_metrics: dict[str, object], post_proposer_metrics: dict[str, object], ) -> Path: DEFAULT_OUTPUT_DIR.mkdir(parents=True, exist_ok=True) artifact = { "pre_proposer_metrics": pre_proposer_metrics, "post_proposer_metrics": post_proposer_metrics, } DEFAULT_PROPOSER_METRICS_PATH.write_text( json.dumps(artifact, indent=2, sort_keys=True), encoding="utf-8", ) return DEFAULT_PROPOSER_METRICS_PATH def generate_code( model, tokenizer, prompt: str | list[dict[str, str]], *, do_sample: bool, max_new_tokens: int = DEFAULT_MAX_COMPLETION_LENGTH, ) -> str: import torch if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token model.eval() if isinstance(prompt, list): prompt_text = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True) else: prompt_text = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True) encoded = tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=DEFAULT_MAX_PROMPT_LENGTH) model_device = next(model.parameters()).device encoded = {key: value.to(model_device) for key, value in encoded.items()} generation_kwargs = { "max_new_tokens": max_new_tokens, "do_sample": do_sample, "pad_token_id": tokenizer.pad_token_id, "eos_token_id": tokenizer.eos_token_id, } if do_sample: generation_kwargs["temperature"] = 0.7 generation_kwargs["top_p"] = 0.95 with torch.no_grad(): output = model.generate(**encoded, **generation_kwargs) decoded = tokenizer.decode(output[0], skip_special_tokens=True) completion = decoded[len(prompt) :] if decoded.startswith(prompt) else decoded return extract_python_code(completion) def get_training_profile(dry_run: bool) -> dict[str, int | float | bool | str]: has_bitsandbytes = importlib.util.find_spec("bitsandbytes") is not None return { "per_device_train_batch_size": 1, "gradient_accumulation_steps": 4, "learning_rate": 2e-5, "max_steps": DRY_RUN_MAX_STEPS if dry_run else DEFAULT_MAX_STEPS, "num_generations": 2 if dry_run else DEFAULT_NUM_GENERATIONS, "max_completion_length": DEFAULT_MAX_COMPLETION_LENGTH, "report_to": "none", "optim": "adamw_torch" if dry_run or not has_bitsandbytes else "adamw_8bit", } def load_training_model_and_tokenizer( dry_run: bool, dataset: Dataset, bug_bank: BugBank, ): if dry_run: return build_tiny_local_model_and_tokenizer(dataset, bug_bank) if HAS_UNSLOTH: print("Initializing Unsloth FastLanguageModel...") model, tokenizer = FastLanguageModel.from_pretrained( model_name=DEFAULT_MODEL_ID, max_seq_length=DEFAULT_MAX_PROMPT_LENGTH + DEFAULT_MAX_COMPLETION_LENGTH, load_in_4bit=True, fast_inference=True, ) model = FastLanguageModel.get_peft_model( model, r=16, target_modules=[ "q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", ], lora_alpha=16, lora_dropout=0, bias="none", use_gradient_checkpointing="unsloth", random_state=3407, use_rslora=False, loftq_config=None, ) return model, tokenizer # Unsloth is failing to load (e.g., due to Kaggle/Colab CUDA mismatch). # Falling back to standard HuggingFace PEFT (LoRA). print("Unsloth not available. Falling back to standard Transformers loading.") from transformers import AutoModelForCausalLM, AutoTokenizer from peft import LoraConfig, get_peft_model tokenizer = AutoTokenizer.from_pretrained(DEFAULT_FALLBACK_MODEL_ID) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token model = AutoModelForCausalLM.from_pretrained(DEFAULT_FALLBACK_MODEL_ID, torch_dtype=torch.bfloat16, device_map="auto") peft_config = LoraConfig( r=16, lora_alpha=16, target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], lora_dropout=0, bias="none", task_type="CAUSAL_LM", ) model = get_peft_model(model, peft_config) return model, tokenizer def build_tiny_local_model_and_tokenizer(dataset: Dataset, bug_bank: BugBank): from tokenizers import Tokenizer from tokenizers.models import WordLevel from tokenizers.pre_tokenizers import Whitespace from tokenizers.trainers import WordLevelTrainer from transformers import GPT2Config, GPT2LMHeadModel, PreTrainedTokenizerFast corpus = [prompt_to_text(row["prompt"]) for row in dataset] corpus.extend(sample.original_code for sample in bug_bank.train_samples) corpus.extend(sample.buggy_code for sample in bug_bank.train_samples) corpus.extend(sample.original_code for sample in bug_bank.eval_samples) corpus.extend(sample.buggy_code for sample in bug_bank.eval_samples) corpus.extend(seed.test for seed in SEED_BANK) tokenizer_object = Tokenizer(WordLevel(unk_token="")) tokenizer_object.pre_tokenizer = Whitespace() trainer = WordLevelTrainer( special_tokens=["", "", "", ""], min_frequency=1, ) tokenizer_object.train_from_iterator(corpus, trainer=trainer) tokenizer = PreTrainedTokenizerFast( tokenizer_object=tokenizer_object, bos_token="", eos_token="", unk_token="", pad_token="", ) tokenizer.chat_template = ( "{% for message in messages %}" "{{ message['role'] }}: {{ message['content'] }}\n" "{% endfor %}" "{% if add_generation_prompt %}assistant: {% endif %}" ) model_config = GPT2Config( vocab_size=tokenizer.vocab_size, n_positions=DEFAULT_MAX_PROMPT_LENGTH + DEFAULT_MAX_COMPLETION_LENGTH, n_ctx=DEFAULT_MAX_PROMPT_LENGTH + DEFAULT_MAX_COMPLETION_LENGTH, n_embd=128, n_layer=2, n_head=2, bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id, ) model = GPT2LMHeadModel(model_config) return model, tokenizer def get_trl_classes(): if os.name == "nt" and not sys.flags.utf8_mode: print("Windows detected. Use `python -X utf8` when running this file locally.") from trl import GRPOConfig, GRPOTrainer return GRPOConfig, GRPOTrainer def create_trainer(model, tokenizer, dataset: Dataset, dry_run: bool): GRPOConfig, GRPOTrainer = get_trl_classes() profile = get_training_profile(dry_run) supported_kwargs = importlib.import_module("inspect").signature(GRPOConfig.__init__).parameters config_kwargs = { "output_dir": str(DEFAULT_OUTPUT_DIR), "per_device_train_batch_size": profile["per_device_train_batch_size"], "gradient_accumulation_steps": profile["gradient_accumulation_steps"], "learning_rate": profile["learning_rate"], "max_steps": profile["max_steps"], "num_generations": profile["num_generations"], "max_prompt_length": DEFAULT_MAX_PROMPT_LENGTH, "max_completion_length": profile["max_completion_length"], "bf16": (not dry_run) and HAS_UNSLOTH and is_bfloat16_supported(), "fp16": (not dry_run) and not is_bfloat16_supported(), "use_cpu": dry_run, "logging_steps": 1 if dry_run else 5, "optim": profile["optim"], "report_to": profile["report_to"], "disable_tqdm": True, } training_args = GRPOConfig(**{k: v for k, v in config_kwargs.items() if k in supported_kwargs}) print(f"Starting GRPO training for {training_args.max_steps} episodes (steps)...") print("To change the number of episodes, modify 'max_steps' in the training profile.") return GRPOTrainer( model=model, reward_funcs=[prop_rew, solv_rew], args=training_args, train_dataset=dataset, processing_class=tokenizer, ) def save_results_plot( pre_solver_metrics: dict[str, float], post_solver_metrics: dict[str, float], pre_proposer_metrics: dict[str, float], post_proposer_metrics: dict[str, float], log_history: list[dict[str, float]], ) -> Path | None: try: import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt except ImportError: print("matplotlib is not installed, skipping plot generation.") return None DEFAULT_OUTPUT_DIR.mkdir(parents=True, exist_ok=True) plot_path = DEFAULT_OUTPUT_DIR / "debugzero_results.png" figure, axes = plt.subplots(1, 2, figsize=(10, 4)) axes[0].bar( ["solver pre", "solver post", "proposer pre", "proposer post"], [ pre_solver_metrics["pass_rate"], post_solver_metrics["pass_rate"], pre_proposer_metrics["break_rate"], post_proposer_metrics["break_rate"], ], color=["#4f81bd", "#4f81bd", "#c0504d", "#c0504d"], ) axes[0].set_ylim(0.0, 1.0) axes[0].set_title("Fixed Eval Rates") axes[0].set_ylabel("Rate") steps = [entry["step"] for entry in log_history if "step" in entry] losses = [entry["loss"] for entry in log_history if "loss" in entry] if steps and losses: axes[1].plot(steps[: len(losses)], losses, marker="o") axes[1].set_title("Training Loss") axes[1].set_xlabel("Step") axes[1].set_ylabel("Loss") else: axes[1].bar( ["solver reward pre", "solver reward post"], [ pre_solver_metrics["mean_reward"], post_solver_metrics["mean_reward"], ], color=["#9bbb59", "#9bbb59"], ) axes[1].set_title("Solver Mean Reward") figure.tight_layout() figure.savefig(plot_path) plt.close(figure) return plot_path def run_workflow(dry_run: bool = False) -> dict[str, object]: dataset, bug_bank = create_dataset() print( f"Built dataset with {len(dataset)} rows from " f"{len(bug_bank.train_samples)} training bugs and {len(bug_bank.eval_samples)} eval bugs." ) model, tokenizer = load_training_model_and_tokenizer(dry_run, dataset, bug_bank) trainer = create_trainer(model, tokenizer, dataset, dry_run) reset_reward_history() pre_solver_metrics = evaluate_solver_fixed_set(model, tokenizer, bug_bank) pre_proposer_metrics = evaluate_proposer_fixed_set(model, tokenizer) print("Pre-training solver metrics:", pre_solver_metrics) print("Pre-training proposer metrics:", pre_proposer_metrics) reset_reward_history() train_result = trainer.train() post_solver_metrics = evaluate_solver_fixed_set(trainer.model, tokenizer, bug_bank) post_proposer_metrics = evaluate_proposer_fixed_set(trainer.model, tokenizer) plot_path = save_results_plot( pre_solver_metrics, post_solver_metrics, pre_proposer_metrics, post_proposer_metrics, trainer.state.log_history, ) metrics_artifact_path = save_metrics_artifact(pre_proposer_metrics, post_proposer_metrics) results = { "train_result": train_result, "pre_solver_metrics": pre_solver_metrics, "post_solver_metrics": post_solver_metrics, "pre_proposer_metrics": pre_proposer_metrics, "post_proposer_metrics": post_proposer_metrics, "plot_path": str(plot_path) if plot_path else None, "metrics_artifact_path": str(metrics_artifact_path), "dataset_size": len(dataset), "train_bug_count": len(bug_bank.train_samples), "eval_bug_count": len(bug_bank.eval_samples), } print("Post-training solver metrics:", post_solver_metrics) print("Post-training proposer metrics:", post_proposer_metrics) if plot_path: print(f"Saved plot to {plot_path}") print(f"Saved proposer metrics to {metrics_artifact_path}") return results def main(): import argparse parser = argparse.ArgumentParser() parser.add_argument("--dry_run", action="store_true", help="Run a tiny local GRPO smoke test.") args = parser.parse_args() run_workflow(dry_run=args.dry_run) if __name__ == "__main__": main()