| |
| """ |
| Test script for the translation Gemma chat template. |
| |
| This script tests the Jinja template used for vllm-translategemma-4b-it model, |
| validating various input formats and edge cases. |
| """ |
|
|
| import pytest |
| from jinja2 import Environment, FileSystemLoader |
| from typing import List, Dict, Any |
|
|
|
|
@pytest.fixture
def template_tester():
    """Build a new TranslationTemplateTester for the requesting test."""
    tester = TranslationTemplateTester()
    return tester
|
|
|
|
class TranslationTemplateTester:
    """Helper that loads the translation chat template and renders it.

    The template is located with a ``FileSystemLoader`` rooted at ``"."``,
    i.e. the current working directory, so the tests have to be run from the
    directory that contains the template file.
    """

    def __init__(self, template_path: str = "chat_template.jinja"):
        """Create the Jinja environment and load the template.

        Args:
            template_path: Template name passed to ``Environment.get_template``;
                resolved by the loader relative to the working directory.
        """
        self.env = Environment(
            loader=FileSystemLoader("."),
            trim_blocks=True,
            lstrip_blocks=True,
            keep_trailing_newline=True,
        )
        # Register the helper as a template global so a template can call
        # raise_exception(...) and have it surface as a Python ValueError.
        self.env.globals['raise_exception'] = self._raise_exception
        self.template = self.env.get_template(template_path)

    @staticmethod
    def _raise_exception(message: str):
        """Raise ``ValueError(message)``; exposed to the template as a global.

        Raises:
            ValueError: Always, carrying ``message``.
        """
        raise ValueError(message)

    def render_template(
        self,
        messages: List[Dict[str, Any]],
        add_generation_prompt: bool = False,
    ) -> str:
        """Render the loaded template for a conversation.

        Args:
            messages: Chat messages as dicts with ``"role"`` and ``"content"``
                keys. Values are typed ``Any`` because the tests in this file
                deliberately pass non-string content (``None``) to exercise
                the template's error handling.
            add_generation_prompt: Forwarded to the template unchanged.

        Returns:
            The rendered prompt string. ``bos_token`` is always supplied to
            the template as ``"<bos>"``.
        """
        return self.template.render(
            messages=messages,
            add_generation_prompt=add_generation_prompt,
            bos_token="<bos>",
        )
|
|
|
|
def test_single_turn_translation(template_tester):
    """A lone user message with a generation prompt renders all expected parts."""
    user_turn = {
        "role": "user",
        "content": "<<<source>>>en<<<target>>>es<<<text>>>Hello, how are you?",
    }

    rendered = template_tester.render_template([user_turn], add_generation_prompt=True)

    # Same checks, same order: BOS token, user turn contents, then the
    # model-turn opener and the end-of-turn marker.
    expected_fragments = (
        "<bos>",
        "<start_of_turn>user",
        "English (en) to Spanish (es) translator",
        "Hello, how are you?",
        "<start_of_turn>model\n",
        "<end_of_turn>\n",
    )
    for fragment in expected_fragments:
        assert fragment in rendered
|
|
|
|
def test_conversation_with_assistant(template_tester):
    """A user request followed by an assistant reply renders both turns."""
    request = {
        "role": "user",
        "content": "<<<source>>>fr<<<target>>>en<<<text>>>Bonjour, comment allez-vous?",
    }
    reply = {"role": "assistant", "content": "Hello, how are you?"}

    rendered = template_tester.render_template(
        [request, reply],
        add_generation_prompt=False,
    )

    # User-side fragments first, then the model turn and its text.
    for fragment in (
        "<start_of_turn>user",
        "French (fr) to English (en) translator",
        "Bonjour, comment allez-vous?",
        "<start_of_turn>model\n",
        "Hello, how are you?",
    ):
        assert fragment in rendered
|
|
|
|
@pytest.mark.parametrize(
    "source_code,target_code,text,source_name,target_name", [
        ("en", "de", "Good morning", "English", "German"),
        ("zh", "en", "你好", "Chinese", "English"),
        ("es", "ja", "Buenos días", "Spanish", "Japanese"),
        ("ru", "fr", "Здравствуйте", "Russian", "French"),
        ("ar", "de", "مرحبا", "Arabic", "German"),
    ]
)
def test_different_language_pairs(template_tester, source_code, target_code, text, source_name, target_name):
    """Each language pair shows up as ``Name (code)`` alongside the text."""
    payload = f"<<<source>>>{source_code}<<<target>>>{target_code}<<<text>>>{text}"
    conversation = [{"role": "user", "content": payload}]

    rendered = template_tester.render_template(conversation, add_generation_prompt=False)

    source_label = f"{source_name} ({source_code})"
    target_label = f"{target_name} ({target_code})"
    assert source_label in rendered
    assert target_label in rendered
    assert text in rendered
|
|
|
|
def test_custom_prompt(template_tester):
    """A ``<<<custom>>>`` message is passed through without translator text."""
    conversation = [
        {"role": "user", "content": "<<<custom>>>Write a poem about translation"},
    ]

    rendered = template_tester.render_template(conversation, add_generation_prompt=True)

    # The custom text must directly follow the user turn header ...
    user_turn = "<start_of_turn>user\nWrite a poem about translation"
    assert user_turn in rendered
    # ... and the translation instruction wording must be absent.
    assert "translator" not in rendered
|
|
|
|
def test_language_code_underscore_conversion(template_tester):
    """An underscore language code (``zh_Hans``) is rendered with a hyphen."""
    user_turn = {
        "role": "user",
        "content": "<<<source>>>zh_Hans<<<target>>>en<<<text>>>你好",
    }

    rendered = template_tester.render_template([user_turn], add_generation_prompt=False)

    # Source code is expected in hyphenated form; target is unchanged.
    for label in ("Chinese (zh-Hans)", "English (en)"):
        assert label in rendered
|
|
|
|
def test_add_generation_prompt_false(template_tester):
    """Without a generation prompt the output ends on the end-of-turn marker."""
    exchange = [
        ("user", "<<<source>>>en<<<target>>>es<<<text>>>Hello"),
        ("assistant", "Hola"),
    ]
    conversation = [{"role": role, "content": content} for role, content in exchange]

    rendered = template_tester.render_template(conversation, add_generation_prompt=False)

    closes_last_turn = rendered.endswith("<end_of_turn>\n")
    opens_model_turn = rendered.endswith("<start_of_turn>model\n")
    assert closes_last_turn
    assert not opens_model_turn
|
|
|
|
def test_multiple_turns(template_tester):
    """Two user/assistant exchanges render two user turns and a model turn."""
    exchange = [
        ("user", "<<<source>>>en<<<target>>>es<<<text>>>Hello"),
        ("assistant", "Hola"),
        ("user", "<<<source>>>es<<<target>>>en<<<text>>>Gracias"),
        ("assistant", "Thank you"),
    ]
    conversation = [{"role": role, "content": content} for role, content in exchange]

    rendered = template_tester.render_template(conversation, add_generation_prompt=True)

    user_turn_count = rendered.count("<start_of_turn>user")
    assert user_turn_count == 2

    assert "<start_of_turn>model" in rendered
|
|
|
|
def test_trimming_content(template_tester):
    """Padded text still appears in the rendered output.

    NOTE(review): the assertion only checks that the bare text is present, so
    it would also pass if the surrounding whitespace were kept — confirm
    whether a stricter check against the template's exact output is wanted.
    """
    # NOTE(review): the target code is "/es" (leading slash) — confirm this is
    # intentional rather than a typo for "es".
    padded_request = {
        "role": "user",
        "content": "<<<source>>>en<<<target>>>/es<<<text>>> Hello World ",
    }

    rendered = template_tester.render_template([padded_request], add_generation_prompt=False)

    assert "Hello World" in rendered
|
|
|
|
def test_error_start_with_assistant(template_tester):
    """Rendering a conversation that opens with an assistant turn must fail."""
    conversation = [{"role": "assistant", "content": "Hello"}]

    # The ValueError message is expected to match "start with a user prompt".
    with pytest.raises(ValueError, match="start with a user prompt"):
        template_tester.render_template(conversation, add_generation_prompt=False)
|
|
|
|
def test_error_alternating_roles(template_tester):
    """Two consecutive user turns must be rejected with an 'alternate' error."""
    contents = (
        "<<<source>>>en<<<target>>>es<<<text>>>Hello",
        "<<<source>>>en<<<target>>>es<<<text>>>World",
    )
    conversation = [{"role": "user", "content": content} for content in contents]

    with pytest.raises(ValueError, match="alternate"):
        template_tester.render_template(conversation, add_generation_prompt=False)
|
|
|
|
def test_error_missing_translation_markers(template_tester):
    """User content without any ``<<<...>>>`` markers must be rejected."""
    plain_request = {"role": "user", "content": "translate this"}

    # The error text is expected to mention "source" or "target".
    with pytest.raises(ValueError, match="source|target"):
        template_tester.render_template([plain_request], add_generation_prompt=False)
|
|
|
|
def test_error_assistant_no_content(template_tester):
    """An assistant turn whose content is ``None`` must be rejected."""
    request = {
        "role": "user",
        "content": "<<<source>>>en<<<target>>>es<<<text>>>Hello",
    }
    empty_reply = {"role": "assistant", "content": None}

    with pytest.raises(ValueError, match="content"):
        template_tester.render_template(
            [request, empty_reply],
            add_generation_prompt=False,
        )
|
|
|
|
def test_invalid_role(template_tester):
    """A ``system`` turn after the user turn must be rejected."""
    request = {
        "role": "user",
        "content": "<<<source>>>en<<<target>>>es<<<text>>>Hello",
    }
    system_turn = {"role": "system", "content": "You are a translator"}

    # The error text is expected to match "user or assistant".
    with pytest.raises(ValueError, match="user or assistant"):
        template_tester.render_template(
            [request, system_turn],
            add_generation_prompt=False,
        )
|
|
|
|
def test_edge_case_empty_text(template_tester):
    """A translation request with nothing after ``<<<text>>>`` still renders."""
    empty_request = {
        "role": "user",
        "content": "<<<source>>>en<<<target>>>es<<<text>>>",
    }

    rendered = template_tester.render_template([empty_request], add_generation_prompt=False)

    # Only the presence of the user turn header is checked here.
    assert "<start_of_turn>user" in rendered
|
|
|
|
def test_complex_multilingual(template_tester):
    """Non-Latin (Japanese) source text is carried through to the output."""
    japanese_request = {
        "role": "user",
        "content": "<<<source>>>ja<<<target>>>en<<<text>>>こんにちは世界",
    }

    rendered = template_tester.render_template([japanese_request], add_generation_prompt=False)

    for fragment in ("Japanese (ja)", "こんにちは世界"):
        assert fragment in rendered