File size: 25,697 Bytes
1b35d41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bfa6ae5
1b35d41
 
 
503bd66
1b35d41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bfa6ae5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b35d41
 
bfa6ae5
 
 
 
 
1b35d41
bfa6ae5
1b35d41
 
 
bfa6ae5
 
1b35d41
 
 
 
 
 
 
 
 
 
bfa6ae5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b35d41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bfa6ae5
 
 
 
 
 
 
 
 
 
1b35d41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bfa6ae5
 
1b35d41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bfa6ae5
 
1b35d41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0200006
 
 
 
 
 
 
 
 
1b35d41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
"""
Inference Script β€” Code Migration Environment
==============================================

Runs migration tasks using a locally loaded model with 4-bit quantization.
Logs everything to files: console log, per-task JSON with all steps/actions/outputs.

Environment variables:
    MODEL_NAME       (default: google/gemma-4-E4B-it)
    DATASET_PATH     (default: bundled verified dataset)
    DIFFICULTY       (default: all)
    MAX_STEPS        (default: 30)
    MAX_TEST_EXEC    (default: 5)
    TASK_LIMIT       (default: 3)
    LOG_DIR          (default: ./logs)
"""

from __future__ import annotations

import json
import logging
import os
import re
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

from code_migration.models import CodeMigrationAction, _TOOL_REQUIRED_ARGS
from code_migration.server.code_migration_environment import CodeMigrationEnvironment
from code_migration.research_agent import ResearchAgent

# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
_SCRIPT_DIR = os.path.dirname(__file__)

MODEL_NAME = os.environ.get("MODEL_NAME", "google/gemma-4-E4B-it")
# Optional path to a trained LoRA adapter; unset -> None -> base model only.
ADAPTER_PATH = os.environ.get("ADAPTER_PATH")
DATASET_PATH = os.environ.get("DATASET_PATH", os.path.join(_SCRIPT_DIR, "data", "eval.jsonl"))
DIFFICULTY = os.environ.get("DIFFICULTY", "all")
MAX_STEPS = int(os.environ.get("MAX_STEPS", "30"))
MAX_TEST_EXEC = int(os.environ.get("MAX_TEST_EXEC", "5"))
# The large default effectively means "run every task in the dataset".
TASK_LIMIT = int(os.environ.get("TASK_LIMIT", "9999"))
LOG_DIR = os.environ.get("LOG_DIR", "./logs")
# Sampling parameters (fixed in code, not read from the environment).
TEMPERATURE = 0.3
MAX_NEW_TOKENS = 400

# ---------------------------------------------------------------------------
# Logging setup — console + file
# ---------------------------------------------------------------------------
# Every run writes into its own timestamped directory under LOG_DIR.
run_id = datetime.now().strftime("%Y%m%d_%H%M%S")
log_dir = Path(LOG_DIR) / run_id
log_dir.mkdir(parents=True, exist_ok=True)

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    datefmt="%H:%M:%S",
    handlers=[
        logging.StreamHandler(sys.stdout),
        # Explicit UTF-8: log lines contain non-ASCII characters (e.g. the "━"
        # separators), which the platform default encoding (such as cp1252 on
        # Windows) may be unable to encode, dropping those records from the file.
        logging.FileHandler(log_dir / "console.log", encoding="utf-8"),
    ],
)
log = logging.getLogger("inference")


# ---------------------------------------------------------------------------
# Tool block for system prompt
# ---------------------------------------------------------------------------
TOOL_BLOCK = """Available tools:
- list_dir(dir_path?): List files/subdirs (default /work)
- search_dir(regex_pattern, dir_path?): Search .py file contents for regex
- search_file(regex_pattern, file_path): Search one file for regex
- view_file(file_path, line_no): View Β±50 lines around line_no
- edit_file(file_path, start_line, end_line, replacement_text): Replace lines
- replace_all_in_file(file_path, regex_pattern, replacement_string): Regex replace
- revert_last(): Undo last edit
- execute_tests(): Run tests in Docker
- search_last_log(regex_pattern): Search last test log
- view_last_log(line_no): View last test log"""

SYSTEM_PROMPT = (
    "You are an expert Python developer fixing failing tests after dependency upgrades.\n\n"
    + TOOL_BLOCK + "\n\n"
    "RULES:\n"
    "- Output EXACTLY ONE JSON tool call: {\"name\": \"...\", \"arguments\": {...}}\n"
    "- Do NOT repeat the same action. Act on info you already have.\n"
    "- search_dir searches file CONTENTS not filenames.\n"
    "- Be decisive: view error β†’ find code β†’ edit β†’ test. 4-8 steps.\n"
    "- NEVER make the same tool call with the same arguments twice in a row. Do something different first.\n"
    "- execute_tests can be re-run after making edits β€” that's expected.\n"
    "- Don't re-read files you already have in context. Use the info from previous steps.\n"
    "- If the research agent already found the fix pattern, apply it directly β€” don't search again.\n"
    "- CRITICAL: Line numbers in test logs are TEST LOG line numbers, NOT source file line numbers.\n"
    "  Always use search_file or view_file to find the ACTUAL line number in the source file before editing.\n"
    "  Use replace_all_in_file when possible β€” it doesn't need line numbers and is safer.\n"
)


# ---------------------------------------------------------------------------
# Model family detection
# ---------------------------------------------------------------------------
def _detect_model_family(model_name: str) -> str:
    """Detect model family from model name string.

    Returns: 'gemma', 'qwen3', 'qwen2', or 'unknown'
    Qwen3/3.5 uses <think>...</think> blocks and enable_thinking param.
    Qwen2.5 uses <|im_end|> tokens, no thinking.
    Gemma 4 uses <|channel>thought...<channel|> blocks and <|think|> token.
    """
    name_lower = model_name.lower()
    if "gemma" in name_lower:
        return "gemma"
    if "qwen3" in name_lower:
        return "qwen3"
    if "qwen" in name_lower:
        return "qwen2"
    return "unknown"


# Family of the configured MODEL_NAME; generate_tool_call reads it to choose
# chat-template options, extra stop tokens and output cleanup.
MODEL_FAMILY = _detect_model_family(model_name=MODEL_NAME)


# ---------------------------------------------------------------------------
# Device detection
# ---------------------------------------------------------------------------
def _get_device() -> str:
    """Return best available device: 'cuda', 'mps', or 'cpu'."""
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"
    return "cpu"

DEVICE = _get_device()


# ---------------------------------------------------------------------------
# Model loading — supports CUDA (4-bit), MPS (float16), CPU (float32)
# ---------------------------------------------------------------------------
def load_model(model_name: str, adapter_path: Optional[str] = None):
    """Load a tokenizer and causal LM on the best available device.

    Device strategy (chosen by ``_get_device``):
      - CUDA: 4-bit NF4 quantization via bitsandbytes (bfloat16 compute, double quant)
      - MPS (Apple Silicon): float16, no quantization
      - CPU: float32 fallback

    Supports Gemma 4, Qwen 3.5 and Qwen 2.5 model families. For the 'gemma'
    family, every ``Gemma4ClippableLinear`` module that has a ``linear``
    attribute is replaced in its parent by that inner module.

    Args:
        model_name: Hugging Face model id or local path.
        adapter_path: Optional path to a trained LoRA adapter. If falsy, the
            base model is returned without an adapter.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    family = _detect_model_family(model_name)
    device = _get_device()
    log.info("Loading %s (family=%s) on device=%s", model_name, family, device)
    if adapter_path:
        log.info("  + LoRA adapter from: %s", adapter_path)
    else:
        log.info("  (base model, no adapter)")
    t0 = time.time()

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # generate() is passed pad_token_id, so fall back to EOS when no pad token exists.
        tokenizer.pad_token = tokenizer.eos_token

    if device == "cuda":
        # CUDA: use 4-bit quantization
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=bnb_config,
            device_map={"": 0},
            trust_remote_code=True,
        )
    elif device == "mps":
        # Apple Silicon: float16, no quantization
        log.info("  Using float16 on MPS (Apple Silicon)")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            trust_remote_code=True,
        ).to("mps")
    else:
        # CPU fallback: float32
        log.info("  Using float32 on CPU (this will be slow)")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,
            trust_remote_code=True,
        )

    # Gemma 4 specific: unwrap ClippableLinear (presumably so PEFT sees plain
    # Linear layers when an adapter is attached). Collect first, then replace,
    # so the module tree is not mutated while named_modules() iterates it.
    if family == "gemma":
        replacements = [
            (name, module.linear)
            for name, module in model.named_modules()
            if type(module).__name__ == "Gemma4ClippableLinear" and hasattr(module, "linear")
        ]
        for name, inner in replacements:
            parent_name, _, attr_name = name.rpartition(".")
            parent = model.get_submodule(parent_name) if parent_name else model
            setattr(parent, attr_name, inner)
        if replacements:
            log.info("Unwrapped %d ClippableLinear modules", len(replacements))

    # Load LoRA adapter if provided
    if adapter_path:
        from peft import PeftModel
        log.info("Loading LoRA adapter...")
        model = PeftModel.from_pretrained(model, adapter_path)
        log.info("LoRA adapter loaded.")

    elapsed = time.time() - t0
    if device == "cuda":
        mem_gb = torch.cuda.memory_allocated(0) / 1e9
    else:
        # No allocator query is used for MPS/CPU; estimate from parameter storage.
        mem_gb = sum(p.numel() * p.element_size() for p in model.parameters()) / 1e9
    log.info("Loaded in %.1fs | memory: ~%.2f GB | device: %s", elapsed, mem_gb, device)

    return model, tokenizer


# ---------------------------------------------------------------------------
# Generation
# ---------------------------------------------------------------------------
def _strip_model_artifacts(raw_text: str, family: str) -> str:
    """Strip model-specific artifacts from generated text.

    Gemma 4: thinking blocks <|channel>thought...<channel|>, special tokens
    Qwen 3/3.5: thinking blocks <think>...</think>
    Qwen 2.5: <|im_end|>, <|endoftext|>

    CRITICAL: After stripping, truncate at the first <|im_end|> or <|im_start|>
    to prevent the model from hallucinating multi-turn conversations.
    """
    clean = raw_text

    if family == "gemma":
        clean = re.sub(r"<\|channel>thought\n.*?<channel\|>", "", clean, flags=re.DOTALL)
        for tok in ["<turn|>", "<|turn>", "<eos>", "</s>"]:
            clean = clean.replace(tok, "")
    elif family == "qwen3":
        # Strip <think>...</think> blocks first
        clean = re.sub(r"<think>.*?</think>", "", clean, flags=re.DOTALL)
        # Truncate at first <|im_end|> β€” everything after is hallucinated
        im_end = clean.find("<|im_end|>")
        if im_end != -1:
            clean = clean[:im_end]
        for tok in ["<|im_end|>", "<|endoftext|>", "<|im_start|>"]:
            clean = clean.replace(tok, "")
    elif family == "qwen2":
        # Truncate at first <|im_end|>
        im_end = clean.find("<|im_end|>")
        if im_end != -1:
            clean = clean[:im_end]
        for tok in ["<|im_end|>", "<|endoftext|>", "<|im_start|>"]:
            clean = clean.replace(tok, "")
    else:
        # Generic cleanup
        for tok in ["<eos>", "</s>", "<|im_end|>", "<|endoftext|>", "<turn|>", "<|turn>"]:
            clean = clean.replace(tok, "")

    return clean.strip()


def generate_tool_call(model, tokenizer, messages: List[Dict]) -> Dict[str, Any]:
    """Generate and parse one tool call from the conversation so far.

    Behaviour per module-level MODEL_FAMILY:
      - qwen3: the chat template is rendered with ``enable_thinking=False``.
      - gemma: the ``<|think|>`` trigger is removed from the rendered prompt.
      - qwen3/qwen2: ``<|im_end|>`` / ``<|endoftext|>`` are added as stop tokens
        alongside the tokenizer's EOS id(s).

    Args:
        model: causal LM returned by ``load_model``.
        tokenizer: matching tokenizer with a chat template.
        messages: chat history as a list of ``{"role": ..., "content": ...}`` dicts.

    Returns:
        The dict from ``_parse_tool_call`` (``tool_name``, ``tool_args``) plus
        ``raw_text`` (decoded completion with special tokens), ``clean_text``
        (after artifact stripping) and ``input_tokens`` (prompt length).
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    template_kwargs: Dict[str, Any] = {"tokenize": False, "add_generation_prompt": True}
    if MODEL_FAMILY == "qwen3":
        # Qwen3/3.5: disable thinking mode for direct JSON output
        template_kwargs["enable_thinking"] = False
    text = tokenizer.apply_chat_template(messages, **template_kwargs)

    # Gemma 4: strip thinking trigger to disable thinking mode
    if MODEL_FAMILY == "gemma":
        text = text.replace("<|think|>", "")

    # The rendered template already contains whatever special tokens it needs
    # (e.g. a BOS). Letting the tokenizer add them again would duplicate them.
    inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False).to(model.device)
    input_len = inputs["input_ids"].shape[-1]

    gen_kwargs = dict(
        max_new_tokens=MAX_NEW_TOKENS,
        temperature=TEMPERATURE,
        top_p=0.95,
        top_k=50,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
    )

    # Qwen: stop at end-of-turn so the model cannot hallucinate further turns.
    if MODEL_FAMILY in ("qwen3", "qwen2"):
        stop_ids = set()
        for tok_str in ("<|im_end|>", "<|endoftext|>"):
            tid = tokenizer.convert_tokens_to_ids(tok_str)
            if tid is not None and tid != tokenizer.unk_token_id:
                stop_ids.add(tid)
        if stop_ids:
            # Passing eos_token_id replaces the default, so keep the tokenizer's own EOS too.
            eos = tokenizer.eos_token_id
            if isinstance(eos, int):
                stop_ids.add(eos)
            elif isinstance(eos, list):
                stop_ids.update(eos)
            gen_kwargs["eos_token_id"] = sorted(stop_ids)

    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)

    raw_text = tokenizer.decode(outputs[0][input_len:], skip_special_tokens=False)
    del inputs, outputs
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    clean = _strip_model_artifacts(raw_text, MODEL_FAMILY)

    parsed = _parse_tool_call(clean)
    parsed["raw_text"] = raw_text
    parsed["clean_text"] = clean
    parsed["input_tokens"] = input_len
    return parsed


def _parse_tool_call(text: str) -> Dict[str, Any]:
    """Parse JSON tool call from model output.

    Finds the FIRST complete JSON object β€” ignores any hallucinated
    multi-turn content that may follow.
    """
    text = text.strip()
    # Strip markdown fences
    if text.startswith("```"):
        lines = text.split("\n")
        lines = [l for l in lines if not l.strip().startswith("```")]
        text = "\n".join(lines).strip()

    # Find the first { and then find its matching }
    start = text.find("{")
    if start == -1:
        return {"tool_name": "list_dir", "tool_args": {}}

    # Try progressively longer substrings to find valid JSON
    depth = 0
    for i in range(start, len(text)):
        if text[i] == "{":
            depth += 1
        elif text[i] == "}":
            depth -= 1
            if depth == 0:
                candidate = text[start:i + 1]
                try:
                    data = json.loads(candidate)
                    if "tool_name" in data:
                        return {"tool_name": data["tool_name"], "tool_args": data.get("tool_args", {})}
                    if "name" in data:
                        return {"tool_name": data["name"], "tool_args": data.get("arguments", data.get("parameters", {}))}
                    if "action" in data:
                        action = data.pop("action")
                        return {"tool_name": action, "tool_args": data}
                except json.JSONDecodeError:
                    pass
                # First balanced braces didn't parse β€” keep looking
                break

    # Fallback: try rfind approach
    end = text.rfind("}")
    if start != -1 and end != -1 and end > start:
        try:
            data = json.loads(text[start:end + 1])
            if "name" in data:
                return {"tool_name": data["name"], "tool_args": data.get("arguments", data.get("parameters", {}))}
        except json.JSONDecodeError:
            pass

    return {"tool_name": "list_dir", "tool_args": {}}


# ---------------------------------------------------------------------------
# Run a single task with full logging
# ---------------------------------------------------------------------------
def _append_exchange(
    messages: List[Dict[str, str]],
    assistant_content: str,
    user_content: str,
    keep_last: int = 20,
) -> None:
    """Append one assistant/user exchange to ``messages`` and trim history in place.

    ``messages[:2]`` (system prompt + initial test output) are always kept. Once
    the list exceeds ``2 + keep_last`` entries, older entries after those two
    are dropped so only the last ``keep_last`` remain. Exchanges are always
    appended in pairs and ``keep_last`` is even, so the retained tail still
    starts with an assistant turn and roles keep alternating.
    """
    messages.append({"role": "assistant", "content": assistant_content})
    messages.append({"role": "user", "content": user_content})
    if len(messages) > 2 + keep_last:
        del messages[2:len(messages) - keep_last]


def run_task(model, tokenizer, env, task_index: int) -> Dict[str, Any]:
    """Run one migration episode and return its summary result.

    Phases:
      1. Reset ``env`` to ``task_index``. If the reset observation is already
         done, the task is recorded as failed.
      2. Run ``ResearchAgent`` to gather migration context, which is appended
         to the system prompt.
      3. Up to MAX_STEPS times: generate one tool call, execute it, log it, and
         append it to the conversation.
    A detailed per-task JSON log is written via ``_save_task_log``.

    Returns:
        Dict with keys ``repo_name``, ``difficulty``, ``success``, ``steps`` and
        ``total_reward``.
    """
    task_log = {
        "task_index": task_index,
        "model": MODEL_NAME,
        "adapter": ADAPTER_PATH or None,
        "timestamp": datetime.now().isoformat(),
        "steps": [],
    }

    obs = env.reset(task_index=task_index)
    repo_name = obs.metadata.get("repo_name", "unknown")
    difficulty = obs.metadata.get("difficulty", "unknown")
    task_log["repo_name"] = repo_name
    task_log["difficulty"] = difficulty
    task_log["initial_observation"] = obs.tool_output[:5000]

    log.info("━" * 60)
    log.info("  Task %d: %s (difficulty=%s)", task_index + 1, repo_name, difficulty)
    log.info("━" * 60)

    if obs.done:
        log.info("  ERROR: reset failed: %s", obs.tool_output[:300])
        task_log["success"] = False
        task_log["error"] = obs.tool_output[:500]
        _save_task_log(task_log)
        return {"repo_name": repo_name, "difficulty": difficulty,
                "success": False, "steps": 0, "total_reward": 0.0}

    # --- RESEARCH PHASE: gather migration context ---
    log.info("  [RESEARCH] Running research agent...")
    research = ResearchAgent(model, tokenizer, max_steps=12, model_name=MODEL_NAME)

    # Task metadata from the environment (private attribute); defaults apply when absent.
    task_meta = getattr(env, "_current_task", None) or None
    old_py = task_meta.reproduction_target_version if task_meta else "3.6"
    new_py = task_meta.migration_target_version if task_meta else "3.12"
    related_mods = task_meta.related_modules if task_meta else "builtin"
    dep_versions = task_meta.dependency_versions if task_meta else ""

    research_context = research.research(
        repo_name=repo_name,
        old_python=old_py,
        new_python=new_py,
        related_modules=related_mods,
        test_output=obs.tool_output,
        dependency_versions=dep_versions,
    )

    task_log["research_context"] = research_context
    task_log["research_steps"] = getattr(research, "last_research_steps", [])
    log.info("  [RESEARCH] Done (%d chars, %d steps)",
             len(research_context), len(task_log["research_steps"]))

    # --- BUILD SYSTEM PROMPT with research context ---
    system_with_research = (
        SYSTEM_PROMPT
        + "\n\n=== MIGRATION RESEARCH (gathered by research agent) ===\n"
        + research_context
        + "\n=== END RESEARCH ===\n\n"
        "A research agent has already analyzed the error and found the relevant "
        "breaking changes above. Use this information to make the fix directly. "
        "Don't waste steps searching for what already has been found.\n"
    )

    messages = [
        {"role": "system", "content": system_with_research},
        {"role": "user", "content": obs.tool_output},
    ]

    total_reward = 0.0
    steps = 0
    success = False
    last_tool_key: str = ""

    for step_num in range(1, MAX_STEPS + 1):
        if obs.done:
            break

        # Generate
        t0 = time.time()
        try:
            result = generate_tool_call(model, tokenizer, messages)
            gen_time = time.time() - t0
            tool_name = result["tool_name"]
            tool_args = result["tool_args"]
        except Exception as e:
            gen_time = time.time() - t0
            log.info("  Step %d [%.1fs]: GENERATION FAILED β€” %s", step_num, gen_time, e)
            tool_name, tool_args = "list_dir", {}
            result = {"raw_text": str(e), "clean_text": "", "input_tokens": 0}

        # Detect exact consecutive repetition (same tool AND same args as last step).
        # The nudge is delivered inside the next user turn, not as a separate
        # message. This keeps strict user/assistant alternation and places the
        # nudge after the repeated call.
        repeat_nudge = ""
        curr_key = f"{tool_name}:{json.dumps(tool_args, sort_keys=True)}"
        if curr_key == last_tool_key and tool_name not in ("execute_tests", "revert_last"):
            repeat_nudge = (
                f"You just called {tool_name} with the exact same arguments. "
                "Do NOT repeat. Try a different action β€” edit a file, search something else, or run tests."
            )
            log.info("  [NUDGE] Exact repeat detected")
        last_tool_key = curr_key

        # Validate: an unknown (or non-string) tool name gets a corrective nudge
        # instead of wasting a step on list_dir.
        if not isinstance(tool_name, str) or tool_name not in _TOOL_REQUIRED_ARGS:
            nudge = (
                f"Invalid tool '{tool_name}'. Output EXACTLY ONE JSON tool call.\n"
                f"Available tools: {', '.join(_TOOL_REQUIRED_ARGS.keys())}\n"
                f"Format: {{\"name\": \"tool_name\", \"arguments\": {{...}}}}"
            )
            if repeat_nudge:
                nudge = f"{nudge}\n\n{repeat_nudge}"
            log.info("  [NUDGE] Invalid tool '%s' β€” injecting correction", tool_name)
            _append_exchange(
                messages,
                json.dumps({"name": tool_name, "arguments": tool_args}, default=str),
                nudge,
            )
            continue
        try:
            action = CodeMigrationAction(tool_name=tool_name, tool_args=tool_args)
        except Exception as e:
            log.info("  [WARN] Could not build %s action (%s); falling back to list_dir",
                     tool_name, e)
            action = CodeMigrationAction(tool_name="list_dir", tool_args={})

        # Execute
        obs = env.step(action)
        steps = step_num
        total_reward += obs.reward

        # Check success
        if action.tool_name == "execute_tests" and obs.metadata.get("last_test_exit_code") == 0:
            success = True

        # Log to console
        args_short = json.dumps(action.tool_args, default=str)[:200]
        result_short = obs.tool_output.replace("\n", " ")[:300]
        reward_s = f" r={obs.reward:.2f}" if abs(obs.reward) > 0.001 else ""
        done_s = " DONE!" if obs.done else ""
        log.info("  Step %d [%.1fs] %s(%s)", step_num, gen_time, action.tool_name, args_short)
        log.info("    β†’ %s%s%s", result_short, reward_s, done_s)

        # Log to task JSON
        step_entry = {
            "step": step_num,
            "gen_time_s": round(gen_time, 2),
            "tool_name": action.tool_name,
            "tool_args": action.tool_args,
            "raw_model_output": result.get("raw_text", ""),
            "clean_model_output": result.get("clean_text", ""),
            "input_tokens": result.get("input_tokens", 0),
            "tool_result": obs.tool_output,
            "reward": obs.reward,
            "done": obs.done,
            "metadata": obs.metadata,
        }
        task_log["steps"].append(step_entry)

        # Update conversation
        tool_result_msg = f"Tool result:\n{obs.tool_output}"
        if repeat_nudge:
            tool_result_msg += f"\n\n{repeat_nudge}"
        _append_exchange(
            messages,
            json.dumps({"name": action.tool_name, "arguments": action.tool_args}),
            tool_result_msg,
        )

        if obs.done:
            break

    # Terminal reward: a failed task reports a flat -3.0, replacing (not adding
    # to) the per-step rewards accumulated above.
    # NOTE(review): confirm the overwrite is intended rather than a -3.0 penalty on top.
    if not success:
        total_reward = -3.0

    # Summary
    icon = "PASS" if success else "FAIL"
    log.info("  Result: %s | steps=%d | reward=%.2f", icon, steps, total_reward)

    task_log["success"] = success
    task_log["total_steps"] = steps
    task_log["total_reward"] = total_reward
    _save_task_log(task_log)

    return {"repo_name": repo_name, "difficulty": difficulty,
            "success": success, "steps": steps, "total_reward": total_reward}


def _save_task_log(task_log: Dict) -> None:
    """Save one task's detailed log as pretty-printed JSON in the run's ``log_dir``.

    The file is named ``task_<repo>.json``, with "/" in the repo name replaced
    by "__". If that file already exists, ``__<task_index>`` is appended to the
    stem instead of overwriting it. Since ``log_dir`` is per-run, an existing
    file presumably comes from an earlier task for the same repo. Values that
    are not JSON-serializable are written via ``str()``.
    """
    repo_safe = str(task_log.get("repo_name", "unknown")).replace("/", "__")
    path = log_dir / f"task_{repo_safe}.json"
    if path.exists():
        path = log_dir / f"task_{repo_safe}__{task_log.get('task_index', 'dup')}.json"
    with open(path, "w", encoding="utf-8") as f:
        json.dump(task_log, f, indent=2, default=str)
    log.info("  Task log saved: %s", path)


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """Entry point: load the model, run up to TASK_LIMIT tasks, and report results.

    Logs a per-task line and aggregate pass@1 / average reward / average steps.
    Writes ``summary.json`` into the per-run ``log_dir``. A task that raises is
    logged with its traceback and recorded as failed (with an ``error`` field)
    so the remaining tasks still run.
    """
    log.info("=" * 60)
    log.info("Code Migration Inference")
    log.info("  model:      %s", MODEL_NAME)
    log.info("  adapter:    %s", ADAPTER_PATH or "(none β€” base model)")
    log.info("  difficulty:  %s", DIFFICULTY)
    log.info("  max_steps:   %d", MAX_STEPS)
    log.info("  log_dir:     %s", log_dir)
    log.info("=" * 60)

    model, tokenizer = load_model(MODEL_NAME, ADAPTER_PATH)

    env = CodeMigrationEnvironment(
        dataset_path=DATASET_PATH,
        max_steps=MAX_STEPS,
        max_test_executions=MAX_TEST_EXEC,
        difficulty_filter=DIFFICULTY if DIFFICULTY != "all" else None,
    )

    num_tasks = min(TASK_LIMIT, len(env._loader))
    log.info("Tasks to run: %d", num_tasks)

    results = []
    for i in range(num_tasks):
        try:
            r = run_task(model, tokenizer, env, i)
        except Exception as e:
            # log.exception keeps the traceback (same ERROR level as before).
            log.exception("Task %d crashed: %s", i, e)
            r = {"repo_name": "error", "difficulty": "unknown",
                 "success": False, "steps": 0, "total_reward": 0.0,
                 "error": f"{type(e).__name__}: {e}"}
        results.append(r)

    # Summary
    log.info("\n" + "=" * 60)
    log.info("SUMMARY")
    log.info("=" * 60)

    successes = sum(1 for r in results if r["success"])
    total = len(results)
    avg_r = sum(r["total_reward"] for r in results) / max(total, 1)
    avg_s = sum(r["steps"] for r in results) / max(total, 1)

    log.info("  pass@1:     %d/%d (%.1f%%)", successes, total, 100 * successes / max(total, 1))
    log.info("  avg reward: %.3f", avg_r)
    log.info("  avg steps:  %.1f", avg_s)

    for r in results:
        icon = "PASS" if r["success"] else "FAIL"
        log.info("  [%s] %s (d=%s, steps=%d, r=%.2f)",
                 icon, r["repo_name"], r["difficulty"], r["steps"], r["total_reward"])

    # Save summary
    summary = {
        "run_id": run_id,
        "model": MODEL_NAME,
        "adapter": ADAPTER_PATH or None,
        "mode": "trained" if ADAPTER_PATH else "base",
        "difficulty": DIFFICULTY,
        "pass_at_1": f"{successes}/{total}",
        "avg_reward": avg_r,
        "avg_steps": avg_s,
        "results": results,
    }
    summary_path = log_dir / "summary.json"
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)
    log.info("Summary saved: %s", summary_path)


if __name__ == "__main__":
    main()