vllm-translategemma-4b-it / test_chat_template.py
swan.blanc
Add test template
b7567ca
#!/usr/bin/env python3
"""
Test script for the translation Gemma chat template.
This script tests the Jinja template used for the vllm-translategemma-4b-it model,
validating various input formats and edge cases.
"""
from pathlib import Path
from typing import List, Dict, Any

import pytest
from jinja2 import Environment, FileSystemLoader
@pytest.fixture
def template_tester():
    """Build a TranslationTemplateTester using its default template path."""
    tester = TranslationTemplateTester()
    return tester
class TranslationTemplateTester:
    """Render the translation chat template so tests can inspect the output.

    The Jinja environment is created with ``trim_blocks``, ``lstrip_blocks``
    and ``keep_trailing_newline`` enabled, and exposes a ``raise_exception``
    global that the template can call to abort rendering with ``ValueError``.
    """

    def __init__(self, template_path: str = "chat_template.jinja"):
        """Create the Jinja environment and load the template.

        The current working directory is searched first; the directory that
        contains this file is used as a fallback, so the template is still
        found when pytest is started from a different directory.

        Args:
            template_path: Template file name, relative to a search directory.

        Raises:
            jinja2.TemplateNotFound: If no search directory has the template.
        """
        search_path = [".", str(Path(__file__).resolve().parent)]
        self.env = Environment(
            loader=FileSystemLoader(search_path),
            trim_blocks=True,
            lstrip_blocks=True,
            keep_trailing_newline=True,
        )
        # The template reports invalid conversations by calling this global.
        self.env.globals['raise_exception'] = self._raise_exception
        self.template = self.env.get_template(template_path)

    @staticmethod
    def _raise_exception(message: str):
        """Raise ``ValueError(message)``; installed as a template global.

        Args:
            message: Error text supplied by the template.

        Raises:
            ValueError: Always.
        """
        raise ValueError(message)

    def render_template(
        self,
        messages: List[Dict[str, Any]],
        add_generation_prompt: bool = False,
    ) -> str:
        """Render the template for ``messages`` and return the result.

        Args:
            messages: Chat messages, each a dict with ``role`` and
                ``content`` keys. ``content`` may be ``None``; the tests use
                that to exercise the template's error handling.
            add_generation_prompt: Forwarded unchanged to the template.

        Returns:
            The rendered prompt. ``bos_token`` is always passed as ``"<bos>"``.

        Raises:
            ValueError: If the template calls ``raise_exception``.
        """
        return self.template.render(
            messages=messages,
            add_generation_prompt=add_generation_prompt,
            bos_token="<bos>",
        )
def test_single_turn_translation(template_tester):
    """Render a lone user request with the generation prompt enabled."""
    user_turn = {
        "role": "user",
        "content": "<<<source>>>en<<<target>>>es<<<text>>>Hello, how are you?",
    }
    rendered = template_tester.render_template([user_turn], add_generation_prompt=True)

    expected_fragments = (
        "<bos>",  # BOS token
        "<start_of_turn>user",  # user turn header
        "English (en) to Spanish (es) translator",  # translation instruction
        "Hello, how are you?",  # the text to translate
        "<start_of_turn>model\n",  # model turn header
        "<end_of_turn>\n",  # turn terminator
    )
    for fragment in expected_fragments:
        assert fragment in rendered
def test_conversation_with_assistant(template_tester):
    """Render a user request followed by an assistant reply."""
    user_turn = {
        "role": "user",
        "content": "<<<source>>>fr<<<target>>>en<<<text>>>Bonjour, comment allez-vous?",
    }
    assistant_turn = {"role": "assistant", "content": "Hello, how are you?"}
    rendered = template_tester.render_template(
        [user_turn, assistant_turn], add_generation_prompt=False
    )

    # Fragments expected from the user turn, then from the assistant turn.
    expected_fragments = (
        "<start_of_turn>user",
        "French (fr) to English (en) translator",
        "Bonjour, comment allez-vous?",
        "<start_of_turn>model\n",
        "Hello, how are you?",
    )
    for fragment in expected_fragments:
        assert fragment in rendered
@pytest.mark.parametrize(
    "source_code,target_code,text,source_name,target_name",
    [
        ("en", "de", "Good morning", "English", "German"),
        ("zh", "en", "你好", "Chinese", "English"),
        ("es", "ja", "Buenos días", "Spanish", "Japanese"),
        ("ru", "fr", "Здравствуйте", "Russian", "French"),
        ("ar", "de", "مرحبا", "Arabic", "German"),
    ],
)
def test_different_language_pairs(template_tester, source_code, target_code, text, source_name, target_name):
    """Check the rendered language labels and text for several code pairs."""
    prompt = f"<<<source>>>{source_code}<<<target>>>{target_code}<<<text>>>{text}"
    conversation = [{"role": "user", "content": prompt}]
    rendered = template_tester.render_template(conversation, add_generation_prompt=False)

    source_label = f"{source_name} ({source_code})"
    target_label = f"{target_name} ({target_code})"
    assert source_label in rendered
    assert target_label in rendered
    assert text in rendered
def test_custom_prompt(template_tester):
    """Render a ``<<<custom>>>`` message and check it is passed through as-is."""
    custom_turn = {
        "role": "user",
        "content": "<<<custom>>>Write a poem about translation",
    }
    rendered = template_tester.render_template([custom_turn], add_generation_prompt=True)

    # The custom text follows the user turn header directly ...
    assert "<start_of_turn>user\nWrite a poem about translation" in rendered
    # ... and no translation instruction is added.
    assert "translator" not in rendered
def test_language_code_underscore_conversion(template_tester):
    """Check that ``zh_Hans`` is rendered with a hyphen as ``zh-Hans``."""
    conversation = [
        {"role": "user", "content": "<<<source>>>zh_Hans<<<target>>>en<<<text>>>你好"},
    ]
    rendered = template_tester.render_template(conversation, add_generation_prompt=False)

    # The underscore form is the input; the hyphen form is the expected output.
    for label in ("Chinese (zh-Hans)", "English (en)"):
        assert label in rendered
def test_add_generation_prompt_false(template_tester):
    """Check how the output ends when no generation prompt is requested."""
    user_turn = {
        "role": "user",
        "content": "<<<source>>>en<<<target>>>es<<<text>>>Hello",
    }
    assistant_turn = {"role": "assistant", "content": "Hola"}
    rendered = template_tester.render_template(
        [user_turn, assistant_turn], add_generation_prompt=False
    )

    # The output must close the last turn and must not open a new model turn.
    turn_terminator = "<end_of_turn>\n"
    generation_prompt = "<start_of_turn>model\n"
    assert rendered.endswith(turn_terminator)
    assert not rendered.endswith(generation_prompt)
def test_multiple_turns(template_tester):
    """Render two user/assistant exchanges with the generation prompt enabled."""
    exchange = [
        ("user", "<<<source>>>en<<<target>>>es<<<text>>>Hello"),
        ("assistant", "Hola"),
        ("user", "<<<source>>>es<<<target>>>en<<<text>>>Gracias"),
        ("assistant", "Thank you"),
    ]
    conversation = [{"role": role, "content": content} for role, content in exchange]
    rendered = template_tester.render_template(conversation, add_generation_prompt=True)

    # Both user turns must be rendered.
    user_turn_count = rendered.count("<start_of_turn>user")
    assert user_turn_count == 2
    # Only the presence of a model turn is checked; the exact number of model
    # turn headers is deliberately not pinned here.
    assert "<start_of_turn>model" in rendered
def test_trimming_content(template_tester):
    """Test that whitespace around the text to translate is trimmed.

    The message uses the same ``en`` -> ``es`` pair as the other tests and
    pads the text with a leading and a trailing space. The text itself must be
    rendered, while the padded form must not appear in the output.
    """
    messages = [
        {
            "role": "user",
            # The spaces around the text are deliberate: they are the padding
            # the template is expected to strip.
            "content": "<<<source>>>en<<<target>>>es<<<text>>> Hello World ",
        }
    ]
    result = template_tester.render_template(messages, add_generation_prompt=False)
    # The text content should be present ...
    assert "Hello World" in result
    # ... but not with its original surrounding whitespace.
    assert " Hello World " not in result
def test_error_start_with_assistant(template_tester):
    """Expect a ValueError when the first message comes from the assistant."""
    conversation = [{"role": "assistant", "content": "Hello"}]

    with pytest.raises(ValueError, match="start with a user prompt"):
        template_tester.render_template(conversation, add_generation_prompt=False)
def test_error_alternating_roles(template_tester):
    """Expect a ValueError when two user messages follow each other."""
    first_user_turn = {
        "role": "user",
        "content": "<<<source>>>en<<<target>>>es<<<text>>>Hello",
    }
    second_user_turn = {
        "role": "user",
        "content": "<<<source>>>en<<<target>>>es<<<text>>>World",
    }
    conversation = [first_user_turn, second_user_turn]

    with pytest.raises(ValueError, match="alternate"):
        template_tester.render_template(conversation, add_generation_prompt=False)
def test_error_missing_translation_markers(template_tester):
    """Expect a ValueError for user content without any ``<<<...>>>`` marker."""
    conversation = [{"role": "user", "content": "translate this"}]

    # The error message is expected to mention "source" or "target".
    with pytest.raises(ValueError, match="source|target"):
        template_tester.render_template(conversation, add_generation_prompt=False)
def test_error_assistant_no_content(template_tester):
    """Expect a ValueError when the assistant message has ``None`` content."""
    user_turn = {
        "role": "user",
        "content": "<<<source>>>en<<<target>>>es<<<text>>>Hello",
    }
    empty_assistant_turn = {"role": "assistant", "content": None}
    conversation = [user_turn, empty_assistant_turn]

    with pytest.raises(ValueError, match="content"):
        template_tester.render_template(conversation, add_generation_prompt=False)
def test_invalid_role(template_tester):
    """Expect a ValueError when a message uses the ``system`` role."""
    user_turn = {
        "role": "user",
        "content": "<<<source>>>en<<<target>>>es<<<text>>>Hello",
    }
    system_turn = {"role": "system", "content": "You are a translator"}
    conversation = [user_turn, system_turn]

    with pytest.raises(ValueError, match="user or assistant"):
        template_tester.render_template(conversation, add_generation_prompt=False)
def test_edge_case_empty_text(template_tester):
    """Render a translation request whose text part is empty."""
    conversation = [
        {"role": "user", "content": "<<<source>>>en<<<target>>>es<<<text>>>"},
    ]

    # Rendering must succeed, and the user turn must still be emitted.
    rendered = template_tester.render_template(conversation, add_generation_prompt=False)
    assert "<start_of_turn>user" in rendered
def test_complex_multilingual(template_tester):
    """Render a Japanese source text and check that it survives unchanged."""
    japanese_text_turn = {
        "role": "user",
        "content": "<<<source>>>ja<<<target>>>en<<<text>>>こんにちは世界",
    }
    rendered = template_tester.render_template(
        [japanese_text_turn], add_generation_prompt=False
    )

    for fragment in ("Japanese (ja)", "こんにちは世界"):
        assert fragment in rendered