Spaces:
Sleeping
Sleeping
| """ | |
| Inference Script — Code Migration Environment | |
| ============================================== | |
| Runs migration tasks using a locally loaded model (4-bit NF4 quantization on CUDA, float16 on MPS, float32 on CPU). | |
| Logs everything to files: console log, per-task JSON with all steps/actions/outputs. | |
| Environment variables: | |
| MODEL_NAME (default: google/gemma-4-E4B-it) | |
| ADAPTER_PATH (default: unset — run the base model without a LoRA adapter) | |
| DATASET_PATH (default: bundled verified dataset) | |
| DIFFICULTY (default: all) | |
| MAX_STEPS (default: 30) | |
| MAX_TEST_EXEC (default: 5) | |
| TASK_LIMIT (default: 9999, i.e. run all tasks) | |
| LOG_DIR (default: ./logs) | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import logging | |
| import os | |
| import re | |
| import sys | |
| import time | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig | |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) | |
| from code_migration.models import CodeMigrationAction, _TOOL_REQUIRED_ARGS | |
| from code_migration.server.code_migration_environment import CodeMigrationEnvironment | |
| from code_migration.research_agent import ResearchAgent | |
| # --------------------------------------------------------------------------- | |
| # Config | |
| # --------------------------------------------------------------------------- | |
# Runtime settings. Each of the first eight can be overridden through an
# environment variable of the same name; the last two are fixed.
MODEL_NAME = os.environ.get("MODEL_NAME", "google/gemma-4-E4B-it")
# Path to a trained LoRA adapter; None means the base model is used as-is.
ADAPTER_PATH = os.environ.get("ADAPTER_PATH")
DATASET_PATH = os.environ.get(
    "DATASET_PATH",
    os.path.join(os.path.dirname(__file__), "data", "eval.jsonl"),
)
DIFFICULTY = os.environ.get("DIFFICULTY", "all")
MAX_STEPS = int(os.environ.get("MAX_STEPS", "30"))
MAX_TEST_EXEC = int(os.environ.get("MAX_TEST_EXEC", "5"))
# Large default so that, unless overridden, every task is run.
TASK_LIMIT = int(os.environ.get("TASK_LIMIT", "9999"))
LOG_DIR = os.environ.get("LOG_DIR", "./logs")
# Sampling settings passed to generation.
TEMPERATURE = 0.3
MAX_NEW_TOKENS = 400
| # --------------------------------------------------------------------------- | |
| # Logging setup — console + file | |
| # --------------------------------------------------------------------------- | |
# Each run gets its own timestamped directory under LOG_DIR; it is created
# at import time so that the file handler below can open console.log.
run_id = datetime.now().strftime("%Y%m%d_%H%M%S")
log_dir = Path(LOG_DIR).joinpath(run_id)
log_dir.mkdir(parents=True, exist_ok=True)

# Root logger: every record goes both to stdout and to <log_dir>/console.log.
logging.basicConfig(
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler(log_dir / "console.log"),
    ],
    format="%(asctime)s %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)

# Logger used by the functions in this script.
log = logging.getLogger("inference")
| # --------------------------------------------------------------------------- | |
| # Tool block for system prompt | |
| # --------------------------------------------------------------------------- | |
# Human-readable list of the tools the agent may call. It is embedded
# verbatim into SYSTEM_PROMPT, so any edit here changes what the model sees.
TOOL_BLOCK = """Available tools:
- list_dir(dir_path?): List files/subdirs (default /work)
- search_dir(regex_pattern, dir_path?): Search .py file contents for regex
- search_file(regex_pattern, file_path): Search one file for regex
- view_file(file_path, line_no): View ±50 lines around line_no
- edit_file(file_path, start_line, end_line, replacement_text): Replace lines
- replace_all_in_file(file_path, regex_pattern, replacement_string): Regex replace
- revert_last(): Undo last edit
- execute_tests(): Run tests in Docker
- search_last_log(regex_pattern): Search last test log
- view_last_log(line_no): View last test log"""
# Base system prompt: role statement, the tool list, then behavioural rules.
# The JSON shape requested here ({"name": ..., "arguments": ...}) is one of
# the shapes _parse_tool_call accepts. run_task appends per-task research
# context to this prompt before starting the conversation.
SYSTEM_PROMPT = (
    "You are an expert Python developer fixing failing tests after dependency upgrades.\n\n"
    + TOOL_BLOCK + "\n\n"
    "RULES:\n"
    "- Output EXACTLY ONE JSON tool call: {\"name\": \"...\", \"arguments\": {...}}\n"
    "- Do NOT repeat the same action. Act on info you already have.\n"
    "- search_dir searches file CONTENTS not filenames.\n"
    "- Be decisive: view error → find code → edit → test. 4-8 steps.\n"
    "- NEVER make the same tool call with the same arguments twice in a row. Do something different first.\n"
    "- execute_tests can be re-run after making edits — that's expected.\n"
    "- Don't re-read files you already have in context. Use the info from previous steps.\n"
    "- If the research agent already found the fix pattern, apply it directly — don't search again.\n"
    "- CRITICAL: Line numbers in test logs are TEST LOG line numbers, NOT source file line numbers.\n"
    " Always use search_file or view_file to find the ACTUAL line number in the source file before editing.\n"
    " Use replace_all_in_file when possible — it doesn't need line numbers and is safer.\n"
)
| # --------------------------------------------------------------------------- | |
| # Model family detection | |
| # --------------------------------------------------------------------------- | |
def _detect_model_family(model_name: str) -> str:
    """Classify a model name into one of the families this script handles.

    Matching is a case-insensitive substring test on ``model_name``.

    Returns:
        'gemma', 'qwen3', 'qwen2', or 'unknown'.

    Family notes (as handled elsewhere in this file):
        Qwen3/3.5 uses <think>...</think> blocks and the enable_thinking param.
        Qwen2.5 uses <|im_end|> tokens, no thinking.
        Gemma 4 uses <|channel>thought...<channel|> blocks and the <|think|> token.
    """
    lowered = model_name.lower()
    # Order matters: "qwen3" has to be tested before the broader "qwen",
    # which is the catch-all for every other Qwen name.
    markers = (
        ("gemma", "gemma"),
        ("qwen3", "qwen3"),
        ("qwen", "qwen2"),
    )
    for needle, family in markers:
        if needle in lowered:
            return family
    return "unknown"
# Family of the configured model, resolved once at import time.
# generate_tool_call reads it to choose chat-template options, stop tokens
# and which artifacts to strip from the output.
MODEL_FAMILY = _detect_model_family(MODEL_NAME)
| # --------------------------------------------------------------------------- | |
| # Device detection | |
| # --------------------------------------------------------------------------- | |
def _get_device() -> str:
    """Pick the device to run on, preferring 'cuda', then 'mps', then 'cpu'."""
    if torch.cuda.is_available():
        return "cuda"
    # torch.backends.mps may be absent on some torch builds, so look it up
    # defensively before asking whether it is available.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"
# Device selected at import time.
# NOTE(review): not referenced elsewhere in the visible part of this file;
# load_model calls _get_device() itself. Confirm whether other modules use it.
DEVICE = _get_device()
| # --------------------------------------------------------------------------- | |
| # Model loading — supports CUDA (4-bit), MPS (float16), CPU (float32) | |
| # --------------------------------------------------------------------------- | |
def load_model(model_name: str, adapter_path: str = None):
    """Load a causal LM and its tokenizer on the best available device.

    Loading strategy by device:
      - CUDA: 4-bit NF4 quantization via bitsandbytes, placed on GPU 0
      - MPS (Apple Silicon): float16, no quantization
      - CPU: float32 fallback

    Args:
        model_name: Model identifier passed to ``from_pretrained``.
        adapter_path: Optional path to a trained LoRA adapter; when given,
            the adapter is loaded on top of the base model via peft.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    family = _detect_model_family(model_name)
    device = _get_device()
    log.info("Loading %s (family=%s) on device=%s", model_name, family, device)
    if adapter_path:
        log.info(" + LoRA adapter from: %s", adapter_path)
    else:
        log.info(" (base model, no adapter)")

    started = time.time()

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    # Generation needs a pad token; reuse EOS when the tokenizer has none.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Collect the device-specific from_pretrained arguments, then load once.
    if device == "cuda":
        load_kwargs = {
            "quantization_config": BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
            ),
            "device_map": {"": 0},
        }
    elif device == "mps":
        log.info(" Using float16 on MPS (Apple Silicon)")
        load_kwargs = {"torch_dtype": torch.float16}
    else:
        log.info(" Using float32 on CPU (this will be slow)")
        load_kwargs = {"torch_dtype": torch.float32}

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        **load_kwargs,
    )
    if device == "mps":
        model = model.to("mps")

    # Gemma only: replace each Gemma4ClippableLinear wrapper with the plain
    # linear layer it holds, before any LoRA adapter is attached.
    if family == "gemma":
        # Gather first, swap afterwards, so the module tree is not mutated
        # while named_modules() is iterating over it.
        wrapped = [
            (qualname, module.linear)
            for qualname, module in model.named_modules()
            if type(module).__name__ == "Gemma4ClippableLinear"
            and hasattr(module, "linear")
        ]
        for qualname, inner in wrapped:
            parent_path, _, attr_name = qualname.rpartition(".")
            owner = model.get_submodule(parent_path) if parent_path else model
            setattr(owner, attr_name, inner)
        if wrapped:
            log.info("Unwrapped %d ClippableLinear modules", len(wrapped))

    if adapter_path:
        # Imported lazily so peft is only required when an adapter is used.
        from peft import PeftModel
        log.info("Loading LoRA adapter...")
        model = PeftModel.from_pretrained(model, adapter_path)
        log.info("LoRA adapter loaded.")

    elapsed = time.time() - started
    if device == "cuda":
        mem_gb = torch.cuda.memory_allocated(0) / 1e9
    else:
        # On MPS and CPU the footprint is estimated from the parameter sizes.
        mem_gb = sum(p.numel() * p.element_size() for p in model.parameters()) / 1e9
    log.info("Loaded in %.1fs | memory: ~%.2f GB | device: %s", elapsed, mem_gb, device)
    return model, tokenizer
| # --------------------------------------------------------------------------- | |
| # Generation | |
| # --------------------------------------------------------------------------- | |
def _strip_model_artifacts(raw_text: str, family: str) -> str:
    """Strip model-specific artifacts from generated text.

    Per family:
      - 'gemma': removes ``<|channel>thought\\n...<channel|>`` thinking blocks
        and the turn/EOS marker tokens.
      - 'qwen3': removes ``<think>...</think>`` blocks, then truncates.
      - 'qwen2': truncates.
      - anything else: removes a generic set of end/turn marker tokens.

    For the Qwen families the text is cut at the first ``<|im_end|>`` or
    ``<|im_start|>``, whichever comes first. Everything after either marker
    is a hallucinated follow-up turn, so it must not reach the parser.

    Args:
        raw_text: Decoded model output, special tokens included.
        family: Family label as returned by ``_detect_model_family``.

    Returns:
        The cleaned text with surrounding whitespace removed.
    """
    clean = raw_text

    if family == "gemma":
        clean = re.sub(r"<\|channel>thought\n.*?<channel\|>", "", clean, flags=re.DOTALL)
        leftovers = ("<turn|>", "<|turn>", "<eos>", "</s>")
    elif family in ("qwen3", "qwen2"):
        if family == "qwen3":
            # Thinking blocks go first so a marker inside one cannot
            # trigger the truncation below.
            clean = re.sub(r"<think>.*?</think>", "", clean, flags=re.DOTALL)
        # Cut at the earliest turn boundary; previously only <|im_end|> was
        # honoured, which let text after a bare <|im_start|> through.
        boundaries = [
            pos
            for pos in (clean.find(marker) for marker in ("<|im_end|>", "<|im_start|>"))
            if pos != -1
        ]
        if boundaries:
            clean = clean[:min(boundaries)]
        # Both turn markers are gone after the cut; only this one can remain.
        leftovers = ("<|endoftext|>",)
    else:
        # Generic cleanup for unrecognised families.
        leftovers = ("<eos>", "</s>", "<|im_end|>", "<|endoftext|>", "<turn|>", "<|turn>")

    for token in leftovers:
        clean = clean.replace(token, "")
    return clean.strip()
def generate_tool_call(model, tokenizer, messages: List[Dict]) -> Dict[str, Any]:
    """Generate one tool call from the conversation so far.

    Renders ``messages`` with the tokenizer's chat template, samples up to
    MAX_NEW_TOKENS new tokens, strips family-specific artifacts and parses
    the first JSON tool call out of the result. Handles the Gemma, Qwen3 and
    Qwen2 families via the module-level MODEL_FAMILY.

    Args:
        model: Loaded causal LM; inputs are moved to ``model.device``.
        tokenizer: Tokenizer providing ``apply_chat_template``.
        messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.

    Returns:
        The dict from ``_parse_tool_call`` (``tool_name``, ``tool_args``)
        extended with ``raw_text`` (decoded output, special tokens kept),
        ``clean_text`` (after artifact stripping) and ``input_tokens``
        (prompt length in tokens).
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    template_kwargs = {"tokenize": False, "add_generation_prompt": True}
    if MODEL_FAMILY == "qwen3":
        # Qwen3/3.5: disable thinking mode for direct JSON output.
        template_kwargs["enable_thinking"] = False
    text = tokenizer.apply_chat_template(messages, **template_kwargs)

    # Gemma 4: strip the thinking trigger to disable thinking mode.
    if MODEL_FAMILY == "gemma":
        text = text.replace("<|think|>", "")

    # The chat template has already emitted whatever special tokens the model
    # expects (e.g. BOS). Letting the tokenizer add them again would
    # duplicate them at the start of the prompt.
    inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False).to(model.device)
    input_len = inputs["input_ids"].shape[-1]

    gen_kwargs = dict(
        max_new_tokens=MAX_NEW_TOKENS,
        temperature=TEMPERATURE,
        top_p=0.95,
        top_k=50,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
    )

    # Qwen: stop on end-of-turn markers as well as the regular EOS, so the
    # model cannot run on into a hallucinated next turn.
    if MODEL_FAMILY in ("qwen3", "qwen2"):
        stop_ids = set()
        for tok_str in ("<|im_end|>", "<|endoftext|>"):
            tid = tokenizer.convert_tokens_to_ids(tok_str)
            # Skip markers this tokenizer does not know.
            if tid is not None and tid != tokenizer.unk_token_id:
                stop_ids.add(tid)
        if stop_ids:
            eos = tokenizer.eos_token_id
            if isinstance(eos, int):
                stop_ids.add(eos)
            elif isinstance(eos, (list, tuple)):
                stop_ids.update(eos)
            gen_kwargs["eos_token_id"] = sorted(stop_ids)

    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)

    # Decode only the newly generated tokens. Special tokens are kept so
    # _strip_model_artifacts can see the turn and thinking markers.
    raw_text = tokenizer.decode(outputs[0][input_len:], skip_special_tokens=False)

    # Release tensors before the next step to keep peak GPU memory down.
    del inputs, outputs
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    clean = _strip_model_artifacts(raw_text, MODEL_FAMILY)
    parsed = _parse_tool_call(clean)
    parsed["raw_text"] = raw_text
    parsed["clean_text"] = clean
    parsed["input_tokens"] = input_len
    return parsed
def _normalize_tool_call(data: Any) -> Optional[Dict[str, Any]]:
    """Map one decoded JSON value onto ``{"tool_name", "tool_args"}``, or None.

    Accepted object shapes:
      - ``{"tool_name": ..., "tool_args": {...}}``
      - ``{"name": ..., "arguments" | "parameters": {...}}``
      - ``{"action": ..., <remaining keys are the arguments>}``
    """
    if not isinstance(data, dict):
        return None
    if "tool_name" in data:
        name = data["tool_name"]
        args = data.get("tool_args", {})
    elif "name" in data:
        name = data["name"]
        args = data.get("arguments", data.get("parameters", {}))
    elif "action" in data:
        name = data["action"]
        args = {key: value for key, value in data.items() if key != "action"}
    else:
        return None
    # A non-string name cannot be a tool, and callers use it as a dict key.
    if not isinstance(name, str):
        return None
    # Arguments must be a mapping; e.g. `"arguments": null` becomes {}.
    return {"tool_name": name, "tool_args": args if isinstance(args, dict) else {}}


def _parse_tool_call(text: str) -> Dict[str, Any]:
    """Parse the first JSON tool call found in model output.

    Leading markdown fences are removed. The text is then scanned for the
    first ``{`` at which a complete JSON object decodes and matches one of
    the shapes accepted by ``_normalize_tool_call``. Decoding uses the real
    JSON parser, so braces inside string values are handled correctly, and
    anything after the object (e.g. hallucinated follow-up turns) is ignored.

    Args:
        text: Cleaned model output.

    Returns:
        ``{"tool_name": str, "tool_args": dict}``. When no usable tool call
        is found, the safe default ``{"tool_name": "list_dir",
        "tool_args": {}}`` is returned.
    """
    text = text.strip()

    # Strip markdown fences.
    if text.startswith("```"):
        kept = [line for line in text.split("\n") if not line.strip().startswith("```")]
        text = "\n".join(kept).strip()

    decoder = json.JSONDecoder()
    pos = text.find("{")
    while pos != -1:
        try:
            data, end = decoder.raw_decode(text, pos)
        except json.JSONDecodeError:
            # Not a valid object starting here; try the next opening brace.
            pos = text.find("{", pos + 1)
            continue
        call = _normalize_tool_call(data)
        if call is not None:
            return call
        # Valid JSON but not a tool call: skip past it entirely, so its
        # nested objects are not mistaken for a call.
        pos = text.find("{", end)

    return {"tool_name": "list_dir", "tool_args": {}}
| # --------------------------------------------------------------------------- | |
| # Run a single task with full logging | |
| # --------------------------------------------------------------------------- | |
def run_task(model, tokenizer, env, task_index: int) -> Dict[str, Any]:
    """Run one migration episode and return its result summary.

    Phases:
      1. Reset the environment to ``task_index``; abort if the reset itself
         reports ``done``.
      2. Run the research agent to gather migration context.
      3. Loop for up to MAX_STEPS: generate a tool call, execute it in the
         environment, and feed the result back into the conversation.
      4. Save a detailed per-task JSON log via ``_save_task_log``.

    The conversation always alternates assistant and user messages after
    the initial system + user pair. Nudges (invalid tool, exact repeat) are
    attached so that no two user messages are adjacent; some chat templates
    reject or mis-render non-alternating roles.

    Args:
        model: Loaded causal LM, passed to the research agent and to
            ``generate_tool_call``.
        tokenizer: Tokenizer matching ``model``.
        env: Environment exposing ``reset(task_index=...)`` and
            ``step(action)``. Observations provide ``tool_output``, ``done``,
            ``reward`` and ``metadata``.
        task_index: Index of the task to run.

    Returns:
        Dict with keys ``repo_name``, ``difficulty``, ``success``, ``steps``
        and ``total_reward``.
    """
    task_log: Dict[str, Any] = {
        "task_index": task_index,
        "model": MODEL_NAME,
        "adapter": ADAPTER_PATH or None,
        "timestamp": datetime.now().isoformat(),
        "steps": [],
    }

    obs = env.reset(task_index=task_index)
    repo_name = obs.metadata.get("repo_name", "unknown")
    difficulty = obs.metadata.get("difficulty", "unknown")
    task_log["repo_name"] = repo_name
    task_log["difficulty"] = difficulty
    task_log["initial_observation"] = obs.tool_output[:5000]

    log.info("━" * 60)
    log.info(" Task %d: %s (difficulty=%s)", task_index + 1, repo_name, difficulty)
    log.info("━" * 60)

    # A reset that is already "done" means the task could not be set up.
    if obs.done:
        log.info(" ERROR: reset failed: %s", obs.tool_output[:300])
        task_log["success"] = False
        task_log["error"] = obs.tool_output[:500]
        _save_task_log(task_log)
        return {"repo_name": repo_name, "difficulty": difficulty,
                "success": False, "steps": 0, "total_reward": 0.0}

    # --- RESEARCH PHASE: gather migration context ---
    log.info(" [RESEARCH] Running research agent...")
    research = ResearchAgent(model, tokenizer, max_steps=12, model_name=MODEL_NAME)

    # Task metadata comes from the environment's current task when it is
    # available; otherwise the defaults below are used.
    task_meta = getattr(env, "_current_task", None) or None
    old_py = task_meta.reproduction_target_version if task_meta else "3.6"
    new_py = task_meta.migration_target_version if task_meta else "3.12"
    related_mods = task_meta.related_modules if task_meta else "builtin"
    dep_versions = task_meta.dependency_versions if task_meta else ""

    research_context = research.research(
        repo_name=repo_name,
        old_python=old_py,
        new_python=new_py,
        related_modules=related_mods,
        test_output=obs.tool_output,
        dependency_versions=dep_versions,
    )
    task_log["research_context"] = research_context
    task_log["research_steps"] = getattr(research, "last_research_steps", [])
    log.info(" [RESEARCH] Done (%d chars, %d steps)",
             len(research_context), len(task_log["research_steps"]))

    # --- BUILD SYSTEM PROMPT with research context ---
    system_with_research = (
        SYSTEM_PROMPT
        + "\n\n=== MIGRATION RESEARCH (gathered by research agent) ===\n"
        + research_context
        + "\n=== END RESEARCH ===\n\n"
        "A research agent has already analyzed the error and found the relevant "
        "breaking changes above. Use this information to make the fix directly. "
        "Don't waste steps searching for what already has been found.\n"
    )

    messages = [
        {"role": "system", "content": system_with_research},
        {"role": "user", "content": obs.tool_output},
    ]

    total_reward = 0.0
    steps = 0
    success = False
    last_tool_key: str = ""

    for step_num in range(1, MAX_STEPS + 1):
        if obs.done:
            break

        # Bound the prompt: keep the system message and initial observation
        # plus the 20 most recent messages. Messages are only ever appended
        # in assistant/user pairs, so the kept tail starts on an assistant
        # turn.
        if len(messages) > 22:
            messages = messages[:2] + messages[-20:]

        # Generate
        t0 = time.time()
        try:
            result = generate_tool_call(model, tokenizer, messages)
            gen_time = time.time() - t0
            tool_name = result["tool_name"]
            tool_args = result["tool_args"]
        except Exception as e:
            # Best-effort: a failed generation falls back to a harmless
            # list_dir so the episode can continue.
            gen_time = time.time() - t0
            log.info(" Step %d [%.1fs]: GENERATION FAILED — %s", step_num, gen_time, e)
            tool_name, tool_args = "list_dir", {}
            result = {"raw_text": str(e), "clean_text": "", "input_tokens": 0}

        # Detect exact consecutive repetition: same tool AND same args as
        # the previous step.
        curr_key = f"{tool_name}:{json.dumps(tool_args, sort_keys=True, default=str)}"
        is_repeat = (
            curr_key == last_tool_key
            and tool_name not in ("execute_tests", "revert_last")
        )
        last_tool_key = curr_key

        # Validate. The isinstance guard avoids a TypeError from testing an
        # unhashable value for membership.
        if not isinstance(tool_name, str) or tool_name not in _TOOL_REQUIRED_ARGS:
            # Invalid tool: nudge the model instead of spending an
            # environment step on a fallback action.
            nudge = (
                f"Invalid tool '{tool_name}'. Output EXACTLY ONE JSON tool call.\n"
                f"Available tools: {', '.join(_TOOL_REQUIRED_ARGS.keys())}\n"
                f"Format: {{\"name\": \"tool_name\", \"arguments\": {{...}}}}"
            )
            # Record the rejected output as the assistant turn, so the
            # correction answers it and roles keep alternating.
            rejected = result.get("clean_text") or json.dumps(
                {"name": tool_name, "arguments": tool_args}, default=str
            )
            messages.append({"role": "assistant", "content": rejected})
            messages.append({"role": "user", "content": nudge})
            log.info(" [NUDGE] Invalid tool '%s' — injecting correction", tool_name)
            continue

        if is_repeat:
            log.info(" [NUDGE] Exact repeat detected")

        try:
            action = CodeMigrationAction(tool_name=tool_name, tool_args=tool_args)
        except Exception as e:
            # Arguments rejected by the action model: fall back to list_dir,
            # but leave a trace of why.
            log.info(" Invalid arguments for %s (%s) — falling back to list_dir", tool_name, e)
            action = CodeMigrationAction(tool_name="list_dir", tool_args={})

        # Execute
        obs = env.step(action)
        steps = step_num
        total_reward += obs.reward

        # Success means a test run that exited with code 0.
        if action.tool_name == "execute_tests" and obs.metadata.get("last_test_exit_code") == 0:
            success = True

        # Log to console
        args_short = json.dumps(action.tool_args, default=str)[:200]
        result_short = obs.tool_output.replace("\n", " ")[:300]
        reward_s = f" r={obs.reward:.2f}" if abs(obs.reward) > 0.001 else ""
        done_s = " DONE!" if obs.done else ""
        log.info(" Step %d [%.1fs] %s(%s)", step_num, gen_time, action.tool_name, args_short)
        log.info(" → %s%s%s", result_short, reward_s, done_s)

        # Log to task JSON
        task_log["steps"].append({
            "step": step_num,
            "gen_time_s": round(gen_time, 2),
            "tool_name": action.tool_name,
            "tool_args": action.tool_args,
            "raw_model_output": result.get("raw_text", ""),
            "clean_model_output": result.get("clean_text", ""),
            "input_tokens": result.get("input_tokens", 0),
            "tool_result": obs.tool_output,
            "reward": obs.reward,
            "done": obs.done,
            "metadata": obs.metadata,
        })

        # Update conversation. A repeat nudge rides along with the tool
        # result, so it follows the call it refers to.
        feedback = f"Tool result:\n{obs.tool_output}"
        if is_repeat:
            feedback += (
                f"\n\nYou just called {tool_name} with the exact same arguments. "
                "Do NOT repeat. Try a different action — edit a file, search something else, or run tests."
            )
        messages.append({
            "role": "assistant",
            "content": json.dumps({"name": action.tool_name, "arguments": action.tool_args}),
        })
        messages.append({"role": "user", "content": feedback})

        if obs.done:
            break

    # Terminal reward: a failed task reports a flat penalty. This replaces
    # the reward accumulated during the episode.
    if not success:
        total_reward = -3.0

    # Summary
    icon = "PASS" if success else "FAIL"
    log.info(" Result: %s | steps=%d | reward=%.2f", icon, steps, total_reward)

    task_log["success"] = success
    task_log["total_steps"] = steps
    task_log["total_reward"] = total_reward
    _save_task_log(task_log)

    return {"repo_name": repo_name, "difficulty": difficulty,
            "success": success, "steps": steps, "total_reward": total_reward}
def _save_task_log(task_log: Dict) -> None:
    """Write the detailed per-task log as JSON into the run's log directory.

    The file name carries the task index (when present in ``task_log``) as
    well as the repository name. Without the index, two tasks on the same
    repository, or two tasks that both report "unknown", would overwrite
    each other's log.

    Args:
        task_log: Log dictionary. ``repo_name`` and ``task_index`` are read
            to build the file name. Values that are not JSON-serializable
            are written as their ``str()`` form.
    """
    repo_safe = str(task_log.get("repo_name", "unknown")).replace("/", "__")
    task_index = task_log.get("task_index")
    if task_index is None:
        stem = f"task_{repo_safe}"
    else:
        stem = f"task_{task_index}_{repo_safe}"
    path = log_dir / f"{stem}.json"
    # Explicit encoding: tool output may contain non-ASCII text, and the
    # platform default encoding is not guaranteed to handle it.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(task_log, f, indent=2, default=str)
    log.info(" Task log saved: %s", path)
| # --------------------------------------------------------------------------- | |
| # Main | |
| # --------------------------------------------------------------------------- | |
def main():
    """Run the configured tasks end to end and write a run summary.

    Loads the model (plus optional adapter), builds the environment, runs up
    to TASK_LIMIT tasks through ``run_task``, logs aggregate statistics and
    saves them to ``summary.json`` in the run's log directory. A task that
    raises is recorded as a failure and does not abort the run.
    """
    log.info("=" * 60)
    log.info("Code Migration Inference")
    log.info(" model: %s", MODEL_NAME)
    log.info(" adapter: %s", ADAPTER_PATH or "(none — base model)")
    log.info(" difficulty: %s", DIFFICULTY)
    log.info(" max_steps: %d", MAX_STEPS)
    log.info(" log_dir: %s", log_dir)
    log.info("=" * 60)

    model, tokenizer = load_model(MODEL_NAME, ADAPTER_PATH)

    env = CodeMigrationEnvironment(
        dataset_path=DATASET_PATH,
        max_steps=MAX_STEPS,
        max_test_executions=MAX_TEST_EXEC,
        # "all" means no filtering.
        difficulty_filter=DIFFICULTY if DIFFICULTY != "all" else None,
    )

    num_tasks = min(TASK_LIMIT, len(env._loader))
    log.info("Tasks to run: %d", num_tasks)

    results = []
    for i in range(num_tasks):
        try:
            r = run_task(model, tokenizer, env, i)
        except Exception:
            # Keep the run going, but record the full traceback; the
            # message alone is rarely enough to diagnose a crash.
            log.exception("Task %d crashed", i)
            r = {"repo_name": "error", "difficulty": "unknown",
                 "success": False, "steps": 0, "total_reward": 0.0}
        results.append(r)

    # Summary
    log.info("\n" + "=" * 60)
    log.info("SUMMARY")
    log.info("=" * 60)

    successes = sum(1 for r in results if r["success"])
    total = len(results)
    # max(total, 1) guards the averages against an empty run.
    avg_r = sum(r["total_reward"] for r in results) / max(total, 1)
    avg_s = sum(r["steps"] for r in results) / max(total, 1)

    log.info(" pass@1: %d/%d (%.1f%%)", successes, total, 100 * successes / max(total, 1))
    log.info(" avg reward: %.3f", avg_r)
    log.info(" avg steps: %.1f", avg_s)
    for r in results:
        icon = "PASS" if r["success"] else "FAIL"
        log.info(" [%s] %s (d=%s, steps=%d, r=%.2f)",
                 icon, r["repo_name"], r["difficulty"], r["steps"], r["total_reward"])

    # Save summary
    summary = {
        "run_id": run_id,
        "model": MODEL_NAME,
        "adapter": ADAPTER_PATH or None,
        "mode": "trained" if ADAPTER_PATH else "base",
        "difficulty": DIFFICULTY,
        "pass_at_1": f"{successes}/{total}",
        "avg_reward": avg_r,
        "avg_steps": avg_s,
        "results": results,
    }
    summary_path = log_dir / "summary.json"
    # default=str mirrors the per-task logs, so an unexpected value type
    # cannot lose the summary at the very end of a run.
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2, default=str)
    log.info("Summary saved: %s", summary_path)


if __name__ == "__main__":
    main()