# mainframe-modernization-openenv: tests/test_training_workflow.py
# Last updated by Ishangtxl: "Sync from GitHub c4e4dad" (commit a537615).
import json
from pathlib import Path
from legacy_cobol_env.eval.providers import create_provider
from legacy_cobol_env.eval.providers import LocalTransformersProvider
from legacy_cobol_env.training.train_sft import SFTArgs, build_sft_plan, load_jsonl_rows, write_dry_run_artifacts
def test_build_sft_plan_reads_dataset_without_gpu_dependencies(tmp_path: Path):
    """Build an SFT plan from a one-row chat-format JSONL dataset.

    The plan must count the single example, name a model, enable LoRA and
    echo back the requested output directory. Per the test name, planning
    should not require GPU-only dependencies.
    """
    out_dir = tmp_path / "out"
    jsonl_path = tmp_path / "tiny.jsonl"
    example_row = {
        "task_id": "invoice_occurs_001",
        "family_id": "invoice_occurs_totals",
        "messages": [
            {"role": "user", "content": "prompt"},
            {"role": "assistant", "content": "{\"code\": \"def migrate(input_record: str) -> str:\\n return input_record\\n\"}"},
        ],
    }
    # One JSON object per line, newline-terminated (JSONL).
    jsonl_path.write_text(json.dumps(example_row) + "\n", encoding="utf-8")

    plan = build_sft_plan(SFTArgs(dataset=str(jsonl_path), output_dir=str(out_dir)))

    assert plan["dataset_examples"] == 1
    assert plan["model_name"]
    assert plan["uses_lora"] is True
    assert plan["output_dir"] == str(out_dir)
def test_load_jsonl_rows_rejects_missing_messages(tmp_path: Path):
    """A JSONL row without a ``messages`` key must raise a ValueError naming that key."""
    bad_path = tmp_path / "bad.jsonl"
    bad_path.write_text(json.dumps({"task_id": "x"}) + "\n", encoding="utf-8")

    caught = None
    try:
        load_jsonl_rows(bad_path)
    except ValueError as err:
        # Keep a reference; ``err`` itself is unbound once the except block exits.
        caught = err

    if caught is None:
        raise AssertionError("expected invalid dataset row to fail")
    assert "messages" in str(caught)
def test_local_transformers_provider_requires_model_path():
    """Creating a local-transformers provider from an empty env mapping must fail.

    The ValueError message is expected to mention ``LOCAL_MODEL_PATH``.
    """
    empty_env = {}
    try:
        create_provider("local-transformers", empty_env)
    except ValueError as err:
        message = str(err)
    else:
        raise AssertionError("expected missing local model path to fail")
    assert "LOCAL_MODEL_PATH" in message
def test_local_transformers_provider_detects_peft_adapter_config(tmp_path: Path):
    """Read the base model name from an ``adapter_config.json`` in the model directory."""
    adapter_dir = tmp_path / "adapter"
    adapter_dir.mkdir()
    adapter_config = {"base_model_name_or_path": "base-model"}
    config_file = adapter_dir / "adapter_config.json"
    config_file.write_text(json.dumps(adapter_config), encoding="utf-8")

    provider = LocalTransformersProvider(model_path=str(adapter_dir))

    assert provider._adapter_base_model_path() == "base-model"
def test_local_transformers_provider_env_accepts_base_model_path(tmp_path: Path):
    """Store ``LOCAL_BASE_MODEL_PATH`` from the env mapping on the provider.

    Also pins that ``load_in_4bit`` is True when ``LOCAL_LOAD_IN_4BIT`` is
    absent from the mapping.
    """
    adapter_dir = tmp_path / "adapter"
    adapter_dir.mkdir()
    # An empty adapter config, so the base model must come from the env mapping.
    (adapter_dir / "adapter_config.json").write_text("{}", encoding="utf-8")
    env = {
        "LOCAL_MODEL_PATH": str(adapter_dir),
        "LOCAL_BASE_MODEL_PATH": "Qwen/base",
    }

    provider = create_provider("local-transformers", env)

    assert provider.base_model_path == "Qwen/base"
    assert provider.load_in_4bit is True
def test_local_transformers_provider_can_disable_4bit(tmp_path: Path):
    """Setting ``LOCAL_LOAD_IN_4BIT`` to ``"false"`` turns off 4-bit loading."""
    model_dir = tmp_path / "model"
    model_dir.mkdir()
    env = {"LOCAL_MODEL_PATH": str(model_dir), "LOCAL_LOAD_IN_4BIT": "false"}

    provider = create_provider("local-transformers", env)

    assert provider.load_in_4bit is False
def test_local_transformers_provider_formats_chat_prompt_with_template():
    """``_format_prompt`` should render the prompt through the tokenizer's chat template.

    The stub tokenizer returns the message roles joined by ``|``. The expected
    ``"system|user"`` therefore shows that a system message is placed before
    the user prompt.
    """

    class StubTokenizer:
        def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=False):
            # The provider must ask for text (not token ids) with a generation prompt appended.
            assert tokenize is False
            assert add_generation_prompt is True
            roles = [msg["role"] for msg in messages]
            return "|".join(roles)

    provider = LocalTransformersProvider(model_path="unused")
    rendered = provider._format_prompt(StubTokenizer(), "hello")
    assert rendered == "system|user"
def test_write_dry_run_artifacts_creates_metadata_loss_and_plot(tmp_path: Path):
    """A dry run writes a metadata file, a loss CSV with a header, and an SVG loss plot."""
    plan = dict(
        dataset_examples=6,
        model_name="Qwen/Qwen2.5-Coder-7B-Instruct",
        output_dir=str(tmp_path / "model"),
    )

    artifacts = write_dry_run_artifacts(plan, tmp_path)

    assert artifacts["metadata"].exists()
    csv_lines = artifacts["loss_csv"].read_text(encoding="utf-8").splitlines()
    assert csv_lines[0] == "step,loss"
    plot_text = artifacts["loss_plot"].read_text(encoding="utf-8")
    assert plot_text.startswith("<svg")