"""Unit tests for the ``inference`` module.

The tests exercise request scheduling, tool-call parsing, rate-limit retry
handling, payload extraction helpers and the ``run_task`` driver, using
``SimpleNamespace`` stand-ins for the LLM client and the trainer environment.
"""

from types import SimpleNamespace

import pytest

import inference


def make_tool_message(name: str, arguments: str):
    """Build a fake chat message carrying exactly one tool call.

    Args:
        name: Tool (function) name placed on the tool call.
        arguments: Raw argument string placed on the tool call, passed
            through unparsed.

    Returns:
        A ``SimpleNamespace`` with a one-element ``tool_calls`` list whose
        entry exposes ``function.name`` and ``function.arguments``.
    """
    return SimpleNamespace(
        tool_calls=[
            SimpleNamespace(
                function=SimpleNamespace(name=name, arguments=arguments)
            )
        ]
    )


def _single_choice_response(tool_name: str, arguments: str):
    """Build a fake completion response with one choice holding one tool call."""
    return SimpleNamespace(
        choices=[SimpleNamespace(message=make_tool_message(tool_name, arguments))]
    )


def _client_for(completions):
    """Wrap a fake completions object so it is reachable as ``client.chat.completions``."""
    return SimpleNamespace(chat=SimpleNamespace(completions=completions))


def _instant_scheduler():
    """Build a ``RequestScheduler`` with no minimum gap, a frozen clock and a no-op sleep."""
    return inference.RequestScheduler(
        min_gap_seconds=0.0,
        rpm_limit=5,
        time_fn=lambda: 0.0,
        sleep_fn=lambda seconds: None,
    )


def test_request_scheduler_enforces_min_gap():
    """A second turn 2.0s after the first sleeps the remaining 10.5s of a 12.5s gap."""
    current_time = {"value": 100.0}
    sleeps = []

    def fake_time():
        return current_time["value"]

    def fake_sleep(seconds: float):
        # Record the request and advance the fake clock as a real sleep would.
        sleeps.append(seconds)
        current_time["value"] += seconds

    scheduler = inference.RequestScheduler(
        min_gap_seconds=12.5,
        rpm_limit=5,
        time_fn=fake_time,
        sleep_fn=fake_sleep,
    )

    scheduler.wait_for_turn()
    current_time["value"] += 2.0
    scheduler.wait_for_turn()

    assert sleeps == [pytest.approx(10.5)]


def test_parse_tool_call_accepts_valid_tool_call():
    """A single well-formed tool call is parsed into a name plus decoded arguments."""
    message = make_tool_message("run_epochs", '{"num_epochs": 8}')

    tool_call = inference.parse_tool_call(message)

    assert tool_call == {"tool_name": "run_epochs", "arguments": {"num_epochs": 8}}


def test_parse_tool_call_rejects_text_only_response():
    """A message without tool calls raises ``InferenceError``."""
    message = SimpleNamespace(tool_calls=[])

    with pytest.raises(inference.InferenceError, match="exactly one tool call"):
        inference.parse_tool_call(message)


def test_parse_tool_call_rejects_malformed_arguments():
    """Tool-call arguments that are not valid JSON raise ``InferenceError``."""
    message = make_tool_message("run_epochs", "{bad json")

    with pytest.raises(inference.InferenceError, match="valid JSON"):
        inference.parse_tool_call(message)


def test_request_action_retries_rate_limit_with_retry_after(monkeypatch):
    """A rate-limit error is retried once after sleeping for the ``retry-after`` value."""
    sleeps = []
    monkeypatch.setattr(inference.time, "sleep", lambda seconds: sleeps.append(seconds))

    class DummyRateLimitError(Exception):
        def __init__(self, retry_after: str):
            self.response = SimpleNamespace(headers={"retry-after": retry_after})

    monkeypatch.setattr(inference, "RateLimitError", DummyRateLimitError)

    # First call fails with a rate limit, second call succeeds.
    responses = [
        DummyRateLimitError("7"),
        _single_choice_response("submit_model", "{}"),
    ]

    class FakeCompletions:
        def create(self, **kwargs):
            response = responses.pop(0)
            if isinstance(response, Exception):
                raise response
            return response

    client = _client_for(FakeCompletions())
    scheduler = _instant_scheduler()
    stats = inference.TaskStats()

    tool_call = inference.request_action(
        client=client,
        scheduler=scheduler,
        messages=[{"role": "user", "content": "test"}],
        stats=stats,
        decision_index=1,
        max_decisions=3,
    )

    assert tool_call["tool_name"] == "submit_model"
    assert stats.requests == 2
    assert stats.retries == 1
    assert sleeps == [7.0]


def test_extract_reset_metadata_reads_wrapped_result_payload():
    """Metadata is read from ``observation.result.data`` when ``metadata`` is empty."""
    payload = {
        "observation": {
            "metadata": {},
            "result": {
                "data": {
                    "task_name": "MNIST Digit Classifier",
                    "difficulty": "easy",
                    "dataset": "mnist",
                    "max_epochs": 100,
                }
            },
        }
    }

    metadata = inference.extract_reset_metadata(payload)

    assert metadata["task_name"] == "MNIST Digit Classifier"
    assert metadata["difficulty"] == "easy"


def test_extract_result_data_reads_json_content_text():
    """A JSON string inside a ``content`` text item is decoded and normalized."""
    payload = {
        "observation": {
            "result": {
                "content": [
                    {
                        "type": "text",
                        "text": '{"status":"configured","metrics":{"current_epoch":0,"remaining_budget":100,"train_loss":0.0,"val_loss":0.0,"train_accuracy":0.0,"val_accuracy":0.0,"best_val_accuracy":0.0,"convergence_signal":"not_started","is_diverged":false}}',
                    }
                ]
            }
        }
    }

    result = inference.extract_result_data(payload)
    normalized = inference.normalize_tool_result(result)

    assert normalized["current_epoch"] == 0
    assert normalized["remaining_budget"] == 100
    assert normalized["status"] == "configured"


def test_run_task_forces_configure_then_submit(monkeypatch):
    """With two allowed decisions, the first forces configure and the last forces submit."""
    monkeypatch.setitem(inference.LLM_MAX_STEPS, "easy_mnist", 2)
    tool_choices = []
    responses = [
        _single_choice_response(
            "configure_training",
            '{"optimizer":"adam","learning_rate":0.001,"batch_size":64,"weight_decay":0.0,"dropout":0.0,"lr_schedule":"cosine","warmup_epochs":3,"augmentation":false,"augmentation_strength":0.0}',
        ),
        _single_choice_response("submit_model", "{}"),
    ]

    class FakeCompletions:
        def create(self, **kwargs):
            # Capture the forced tool choice sent with each request.
            tool_choices.append(kwargs["tool_choice"])
            return responses.pop(0)

    client = _client_for(FakeCompletions())
    scheduler = _instant_scheduler()

    class FakeEnv:
        def sync(self):
            return self

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            return None

        def reset(self, **kwargs):
            return SimpleNamespace(
                observation=SimpleNamespace(
                    metadata={
                        "task_id": kwargs["task_id"],
                        "task_name": "MNIST Digit Classifier",
                        "task_description": "desc",
                        "difficulty": "easy",
                        "model_type": "simple_mlp",
                        "dataset": "mnist",
                        "max_epochs": 100,
                        "target_metric": "val_accuracy",
                        "target_value": 0.96,
                    }
                )
            )

        def step(self, action):
            if action.tool_name == "configure_training":
                return SimpleNamespace(
                    observation=SimpleNamespace(
                        metadata={},
                        result={
                            "data": {
                                "status": "configured",
                                "metrics": {
                                    "current_epoch": 0,
                                    "remaining_budget": 100,
                                    "train_loss": 0.0,
                                    "val_loss": 0.0,
                                    "train_accuracy": 0.0,
                                    "val_accuracy": 0.0,
                                    "best_val_accuracy": 0.0,
                                    "convergence_signal": "not_started",
                                    "is_diverged": False,
                                    "current_config": {
                                        "optimizer": "adam",
                                        "learning_rate": 0.001,
                                    },
                                },
                            }
                        },
                    ),
                    reward=0.0,
                    done=False,
                )
            # Any other tool is answered with the final graded result.
            return SimpleNamespace(
                observation=SimpleNamespace(
                    metadata={},
                    result={"data": {"grade": {"score": 0.8}}},
                ),
                reward=0.8,
                done=True,
            )

    monkeypatch.setattr(inference, "MLTrainerEnv", lambda base_url: FakeEnv())

    result = inference.run_task(client, scheduler, "easy_mnist")

    assert tool_choices[0] == {"type": "function", "function": {"name": "configure_training"}}
    assert tool_choices[-1] == {"type": "function", "function": {"name": "submit_model"}}
    assert result["llm_decisions"] == 2
    assert result["final_score"] == 0.8


def test_run_task_reuses_same_env_session(monkeypatch):
    """Reset and step both happen inside a single enter/exit of the environment."""
    calls = []

    class FakeCompletions:
        def create(self, **kwargs):
            return _single_choice_response("submit_model", "{}")

    client = _client_for(FakeCompletions())
    scheduler = _instant_scheduler()
    monkeypatch.setitem(inference.LLM_MAX_STEPS, "easy_mnist", 1)

    class FakeEnv:
        def sync(self):
            return self

        def __enter__(self):
            calls.append("enter")
            return self

        def __exit__(self, exc_type, exc, tb):
            calls.append("exit")
            return None

        def reset(self, **kwargs):
            calls.append(("reset", kwargs["task_id"]))
            return SimpleNamespace(
                observation=SimpleNamespace(
                    metadata={
                        "task_id": kwargs["task_id"],
                        "task_name": "MNIST Digit Classifier",
                        "difficulty": "easy",
                        "dataset": "mnist",
                        "max_epochs": 100,
                        "target_metric": "val_accuracy",
                        "target_value": 0.96,
                    }
                )
            )

        def step(self, action):
            calls.append(("step", action.tool_name))
            return SimpleNamespace(
                observation=SimpleNamespace(
                    metadata={},
                    result={"data": {"grade": {"score": 0.7}}},
                ),
                reward=0.7,
                done=True,
            )

    monkeypatch.setattr(inference, "MLTrainerEnv", lambda base_url: FakeEnv())

    result = inference.run_task(client, scheduler, "easy_mnist")

    assert result["final_score"] == 0.7
    assert calls == ["enter", ("reset", "easy_mnist"), ("step", "submit_model"), "exit"]


def test_request_action_does_not_send_seed(monkeypatch):
    """The completion request carries no ``seed`` keyword argument."""
    captured_kwargs = {}

    class FakeCompletions:
        def create(self, **kwargs):
            captured_kwargs.update(kwargs)
            return _single_choice_response("submit_model", "{}")

    client = _client_for(FakeCompletions())
    scheduler = _instant_scheduler()
    stats = inference.TaskStats()

    inference.request_action(
        client=client,
        scheduler=scheduler,
        messages=[{"role": "user", "content": "test"}],
        stats=stats,
        decision_index=1,
        max_decisions=3,
    )

    assert "seed" not in captured_kwargs


def test_apply_tool_context_preserves_current_config():
    """``current_config`` set by configure_training survives a later run_epochs update."""
    configured = inference.apply_tool_context(
        "configure_training",
        {"optimizer": "adam", "learning_rate": 0.001, "batch_size": 64},
        {},
        {"current_epoch": 0, "remaining_budget": 100},
    )

    assert configured["current_config"]["optimizer"] == "adam"

    updated = inference.apply_tool_context(
        "run_epochs",
        {"num_epochs": 10},
        configured,
        {"current_epoch": 10, "best_val_accuracy": 0.95},
    )

    assert updated["current_config"]["batch_size"] == 64
    assert updated["best_val_accuracy"] == 0.95


def test_merge_task_metadata_fills_missing_fields():
    """Fields absent from the given metadata are filled in for a known task id."""
    merged = inference.merge_task_metadata("easy_mnist", {"task_id": "easy_mnist"})

    assert merged["difficulty"] == "easy"
    assert merged["dataset"] == "mnist"
    assert merged["max_epochs"] == 100