import json from pathlib import Path from legacy_cobol_env.eval.providers import create_provider from legacy_cobol_env.eval.providers import LocalTransformersProvider from legacy_cobol_env.training.train_sft import SFTArgs, build_sft_plan, load_jsonl_rows, write_dry_run_artifacts def test_build_sft_plan_reads_dataset_without_gpu_dependencies(tmp_path: Path): dataset = tmp_path / "tiny.jsonl" dataset.write_text( json.dumps( { "task_id": "invoice_occurs_001", "family_id": "invoice_occurs_totals", "messages": [ {"role": "user", "content": "prompt"}, {"role": "assistant", "content": "{\"code\": \"def migrate(input_record: str) -> str:\\n return input_record\\n\"}"}, ], } ) + "\n", encoding="utf-8", ) plan = build_sft_plan(SFTArgs(dataset=str(dataset), output_dir=str(tmp_path / "out"))) assert plan["dataset_examples"] == 1 assert plan["model_name"] assert plan["uses_lora"] is True assert plan["output_dir"] == str(tmp_path / "out") def test_load_jsonl_rows_rejects_missing_messages(tmp_path: Path): dataset = tmp_path / "bad.jsonl" dataset.write_text(json.dumps({"task_id": "x"}) + "\n", encoding="utf-8") try: load_jsonl_rows(dataset) except ValueError as exc: assert "messages" in str(exc) else: raise AssertionError("expected invalid dataset row to fail") def test_local_transformers_provider_requires_model_path(): try: create_provider("local-transformers", {}) except ValueError as exc: assert "LOCAL_MODEL_PATH" in str(exc) else: raise AssertionError("expected missing local model path to fail") def test_local_transformers_provider_detects_peft_adapter_config(tmp_path: Path): adapter_dir = tmp_path / "adapter" adapter_dir.mkdir() (adapter_dir / "adapter_config.json").write_text(json.dumps({"base_model_name_or_path": "base-model"}), encoding="utf-8") provider = LocalTransformersProvider(model_path=str(adapter_dir)) assert provider._adapter_base_model_path() == "base-model" def test_local_transformers_provider_env_accepts_base_model_path(tmp_path: Path): adapter_dir = tmp_path / "adapter" adapter_dir.mkdir() (adapter_dir / "adapter_config.json").write_text("{}", encoding="utf-8") provider = create_provider( "local-transformers", {"LOCAL_MODEL_PATH": str(adapter_dir), "LOCAL_BASE_MODEL_PATH": "Qwen/base"}, ) assert provider.base_model_path == "Qwen/base" assert provider.load_in_4bit is True def test_local_transformers_provider_can_disable_4bit(tmp_path: Path): model_dir = tmp_path / "model" model_dir.mkdir() provider = create_provider( "local-transformers", {"LOCAL_MODEL_PATH": str(model_dir), "LOCAL_LOAD_IN_4BIT": "false"}, ) assert provider.load_in_4bit is False def test_local_transformers_provider_formats_chat_prompt_with_template(): class FakeTokenizer: def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=False): assert tokenize is False assert add_generation_prompt is True return "|".join(message["role"] for message in messages) provider = LocalTransformersProvider(model_path="unused") assert provider._format_prompt(FakeTokenizer(), "hello") == "system|user" def test_write_dry_run_artifacts_creates_metadata_loss_and_plot(tmp_path: Path): plan = { "dataset_examples": 6, "model_name": "Qwen/Qwen2.5-Coder-7B-Instruct", "output_dir": str(tmp_path / "model"), } artifacts = write_dry_run_artifacts(plan, tmp_path) assert artifacts["metadata"].exists() assert artifacts["loss_csv"].read_text(encoding="utf-8").splitlines()[0] == "step,loss" assert artifacts["loss_plot"].read_text(encoding="utf-8").startswith("