"""Tests for batch_runner checkpoint behavior — incremental writes, resume, atomicity."""
import json
import os
from pathlib import Path
from threading import Lock
from unittest.mock import patch, MagicMock
import pytest
# batch_runner uses relative imports, ensure project root is on path
import sys
sys.path.insert(0, str(Path(__file__).parent.parent))
from batch_runner import BatchRunner
@pytest.fixture
def runner(tmp_path):
    """Build a bare BatchRunner whose files all live under tmp_path.

    __new__ bypasses __init__ so no real setup/IO runs; only the attributes
    the checkpoint methods read are populated.
    """
    instance = BatchRunner.__new__(BatchRunner)
    instance.run_name = "test_run"
    instance.prompts_file = tmp_path / "prompts.jsonl"
    instance.prompts_file.write_text("")  # empty prompts file is sufficient
    instance.output_file = tmp_path / "output.jsonl"
    instance.checkpoint_file = tmp_path / "checkpoint.json"
    return instance
class TestSaveCheckpoint:
    """Verify _save_checkpoint writes valid, atomic JSON."""

    def test_writes_valid_json(self, runner):
        payload = {"run_name": "test", "completed_prompts": [1, 2, 3], "batch_stats": {}}
        runner._save_checkpoint(payload)
        on_disk = json.loads(runner.checkpoint_file.read_text())
        assert on_disk["run_name"] == "test"
        assert on_disk["completed_prompts"] == [1, 2, 3]

    def test_adds_last_updated(self, runner):
        runner._save_checkpoint({"run_name": "test", "completed_prompts": []})
        on_disk = json.loads(runner.checkpoint_file.read_text())
        # _save_checkpoint is expected to stamp the write time itself.
        assert "last_updated" in on_disk
        assert on_disk["last_updated"] is not None

    def test_overwrites_previous_checkpoint(self, runner):
        first = {"run_name": "test", "completed_prompts": [1]}
        second = {"run_name": "test", "completed_prompts": [1, 2, 3]}
        runner._save_checkpoint(first)
        runner._save_checkpoint(second)
        on_disk = json.loads(runner.checkpoint_file.read_text())
        assert on_disk["completed_prompts"] == [1, 2, 3]

    def test_with_lock(self, runner):
        # Passing an explicit Lock must not change the written contents.
        runner._save_checkpoint(
            {"run_name": "test", "completed_prompts": [42]}, lock=Lock()
        )
        on_disk = json.loads(runner.checkpoint_file.read_text())
        assert on_disk["completed_prompts"] == [42]

    def test_without_lock(self, runner):
        runner._save_checkpoint(
            {"run_name": "test", "completed_prompts": [99]}, lock=None
        )
        on_disk = json.loads(runner.checkpoint_file.read_text())
        assert on_disk["completed_prompts"] == [99]

    def test_creates_parent_dirs(self, tmp_path):
        # A checkpoint path under directories that don't exist yet should
        # still be writable — _save_checkpoint must create them.
        deep_runner = BatchRunner.__new__(BatchRunner)
        deep_runner.checkpoint_file = tmp_path / "deep" / "nested" / "checkpoint.json"
        deep_runner._save_checkpoint({"run_name": "test", "completed_prompts": []})
        assert deep_runner.checkpoint_file.exists()

    def test_no_temp_files_left(self, runner):
        # Atomic write via temp file + rename must clean up after itself.
        runner._save_checkpoint({"run_name": "test", "completed_prompts": []})
        leftovers = [
            entry
            for entry in runner.checkpoint_file.parent.iterdir()
            if ".tmp" in entry.name
        ]
        assert leftovers == []
class TestLoadCheckpoint:
    """Verify _load_checkpoint reads existing data or returns defaults."""

    def test_returns_empty_when_no_file(self, runner):
        # No checkpoint on disk -> default (empty) completed list.
        loaded = runner._load_checkpoint()
        assert loaded.get("completed_prompts", []) == []

    def test_loads_existing_checkpoint(self, runner):
        stored = {
            "run_name": "test_run",
            "completed_prompts": [5, 10, 15],
            "batch_stats": {"0": {"processed": 3}},
        }
        runner.checkpoint_file.write_text(json.dumps(stored))
        loaded = runner._load_checkpoint()
        assert loaded["completed_prompts"] == [5, 10, 15]
        assert loaded["batch_stats"]["0"]["processed"] == 3

    def test_handles_corrupt_json(self, runner):
        # A half-written/corrupt checkpoint must degrade to defaults,
        # never raise.
        runner.checkpoint_file.write_text("{broken json!!")
        loaded = runner._load_checkpoint()
        assert isinstance(loaded, dict)
class TestResumePreservesProgress:
    """Verify that initializing a run with resume=True loads prior checkpoint."""

    @staticmethod
    def _resolve_checkpoint(runner):
        """Mirror run()'s resume logic in one place.

        Returns the loaded checkpoint when it belongs to runner.run_name,
        otherwise a fresh default state. Previously this reset block was
        copy-pasted into both tests; keeping it here ensures both tests
        exercise the same logic.
        """
        data = runner._load_checkpoint()
        if data.get("run_name") != runner.run_name:
            data = {
                "run_name": runner.run_name,
                "completed_prompts": [],
                "batch_stats": {},
                "last_updated": None,
            }
        return data

    def test_completed_prompts_loaded_from_checkpoint(self, runner):
        # Simulate a prior run that completed prompts 0-4.
        prior = {
            "run_name": "test_run",
            "completed_prompts": [0, 1, 2, 3, 4],
            "batch_stats": {"0": {"processed": 5}},
            "last_updated": "2026-01-01T00:00:00",
        }
        runner.checkpoint_file.write_text(json.dumps(prior))
        checkpoint_data = self._resolve_checkpoint(runner)
        completed_set = set(checkpoint_data.get("completed_prompts", []))
        assert completed_set == {0, 1, 2, 3, 4}

    def test_different_run_name_starts_fresh(self, runner):
        # A checkpoint from another run must be ignored, not resumed.
        prior = {
            "run_name": "different_run",
            "completed_prompts": [0, 1, 2],
            "batch_stats": {},
        }
        runner.checkpoint_file.write_text(json.dumps(prior))
        checkpoint_data = self._resolve_checkpoint(runner)
        assert checkpoint_data["completed_prompts"] == []
        assert checkpoint_data["run_name"] == "test_run"