|
|
"""Tests for training infrastructure.""" |
|
|
|
|
|
|
|
|
class TestSwiftConfig:
    """Checks for the ms-swift training configuration object."""

    def test_default_config(self):
        """A freshly constructed config exposes the expected default values."""
        from zen_translator.training import SwiftTrainingConfig

        cfg = SwiftTrainingConfig()

        # Checked in a fixed order, one attribute at a time.
        expected_defaults = {
            "model_type": "qwen3-omni",
            "train_type": "lora",
            "lora_rank": 64,
            "lora_alpha": 128,
        }
        for attribute, expected in expected_defaults.items():
            assert getattr(cfg, attribute) == expected

    def test_to_swift_args(self):
        """The swift CLI argument list contains the default model/LoRA flags."""
        from zen_translator.training import SwiftTrainingConfig

        cfg = SwiftTrainingConfig()
        cli_args = cfg.to_swift_args()

        required_flags = (
            "--model_type=qwen3-omni",
            "--train_type=lora",
            "--lora_rank=64",
        )
        for flag in required_flags:
            assert flag in cli_args

    def test_to_yaml(self, tmp_path):
        """Exporting to YAML writes a file that loads back with the defaults."""
        from zen_translator.training import SwiftTrainingConfig

        cfg = SwiftTrainingConfig()
        target = tmp_path / "config.yaml"

        cfg.to_yaml(target)

        assert target.exists()

        import yaml

        with open(target) as handle:
            loaded = yaml.safe_load(handle)

        # Only the two nested values below are verified.
        assert loaded["model"]["type"] == "qwen3-omni"
        assert loaded["lora"]["rank"] == 64
|
|
|
|
|
|
|
|
class TestZenIdentityConfig:
    """Checks for the Zen identity fine-tuning configuration."""

    def test_identity_system_prompt(self):
        """The default system prompt mentions both identity markers."""
        from zen_translator.training import ZenIdentityConfig

        cfg = ZenIdentityConfig()

        for marker in ("Zen Translator", "Hanzo AI"):
            assert marker in cfg.system_prompt
|
|
|
|
|
|
|
|
class TestNewsAnchorConfig:
    """Checks for the news anchor training configuration."""

    def test_anchor_names(self):
        """The default config lists at least one anchor, including cnn and bbc."""
        from zen_translator.training import NewsAnchorConfig

        cfg = NewsAnchorConfig()

        assert len(cfg.anchor_names) > 0
        for anchor in ("cnn", "bbc"):
            assert anchor in cfg.anchor_names

    def test_news_domains(self):
        """The default config covers the politics and technology domains."""
        from zen_translator.training import NewsAnchorConfig

        cfg = NewsAnchorConfig()

        for domain in ("politics", "technology"):
            assert domain in cfg.news_domains
|
|
|
|
|
|
|
|
class TestNewsChannels:
    """Tests for the predefined ``NEWS_CHANNELS`` mapping."""

    def test_channels_defined(self):
        """The mapping is non-empty and contains the cnn, bbc and nhk keys."""
        from zen_translator.training import NEWS_CHANNELS

        assert len(NEWS_CHANNELS) > 0
        assert "cnn" in NEWS_CHANNELS
        assert "bbc" in NEWS_CHANNELS
        assert "nhk" in NEWS_CHANNELS

    def test_channel_urls(self):
        """Every channel URL starts with ``https://`` and contains ``youtube.com``.

        The channel name is included in each assertion message so a failure
        identifies the offending entry instead of only showing the URL.
        """
        from zen_translator.training import NEWS_CHANNELS

        for name, url in NEWS_CHANNELS.items():
            assert url.startswith("https://"), (
                f"channel {name!r}: URL does not start with https://: {url!r}"
            )
            assert "youtube.com" in url, (
                f"channel {name!r}: URL does not contain youtube.com: {url!r}"
            )
|
|
|
|
|
|
|
|
class TestCreateTrainingDataset:
    """Checks for training dataset file creation."""

    def test_create_jsonl_dataset(self, tmp_path):
        """A single record produces a one-line JSONL file with a ``conversations`` key."""
        from zen_translator.training import create_training_dataset

        user_turn = {"role": "user", "content": "Hello"}
        assistant_turn = {"role": "assistant", "content": "Hi there!"}
        records = [{"conversations": [user_turn, assistant_turn]}]

        destination = tmp_path / "train.jsonl"
        create_training_dataset(records, destination, format="jsonl")

        assert destination.exists()

        import json

        with open(destination) as handle:
            rows = list(handle)

        # One input record is expected to yield exactly one line.
        assert len(rows) == 1
        first_record = json.loads(rows[0])
        assert "conversations" in first_record
|
|
|