# Tests for the prompt_helper prompt styles (extraction noise removed).
import pytest
from llama_index.core.llms import ChatMessage, MessageRole

from private_gpt.components.llm.prompt_helper import (
    ChatMLPromptStyle,
    DefaultPromptStyle,
    Llama2PromptStyle,
    MistralPromptStyle,
    TagPromptStyle,
    get_prompt_style,
)
@pytest.mark.parametrize(
    ("prompt_style", "expected_prompt_style"),
    [
        ("default", DefaultPromptStyle),
        ("llama2", Llama2PromptStyle),
        ("tag", TagPromptStyle),
        ("mistral", MistralPromptStyle),
        ("chatml", ChatMLPromptStyle),
    ],
)
def test_get_prompt_style_success(prompt_style, expected_prompt_style):
    """Each known style name resolves to the matching prompt-style class.

    The parametrize decorator was missing: without it pytest reports
    "fixture 'prompt_style' not found", and the style classes imported at
    the top of the file would be unused.
    """
    assert isinstance(get_prompt_style(prompt_style), expected_prompt_style)
def test_get_prompt_style_failure():
    """An unknown style name raises ValueError with a descriptive message."""
    prompt_style = "unknown"
    with pytest.raises(ValueError) as exc_info:
        get_prompt_style(prompt_style)
    # The assert must sit OUTSIDE the `with` block: any statement after the
    # raising call inside the block is unreachable and would silently never run.
    assert str(exc_info.value) == f"Unknown prompt_style='{prompt_style}'"
def test_tag_prompt_style_format():
    """TagPromptStyle renders each turn as a ``<|role|>:`` prefixed line."""
    style = TagPromptStyle()
    conversation = [
        ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
    ]
    expected = "\n".join(
        [
            "<|system|>: You are an AI assistant.",
            "<|user|>: Hello, how are you doing?",
            "<|assistant|>: ",
        ]
    )
    assert style.messages_to_prompt(conversation) == expected
def test_tag_prompt_style_format_with_system_prompt():
    """A caller-supplied system message is passed through verbatim."""
    style = TagPromptStyle()
    system_text = "FOO BAR Custom sys prompt from messages."
    user_text = "Hello, how are you doing?"
    conversation = [
        ChatMessage(content=system_text, role=MessageRole.SYSTEM),
        ChatMessage(content=user_text, role=MessageRole.USER),
    ]
    expected = (
        f"<|system|>: {system_text}\n" f"<|user|>: {user_text}\n" "<|assistant|>: "
    )
    assert style.messages_to_prompt(conversation) == expected
def test_mistral_prompt_style_format():
    """MistralPromptStyle wraps turns in [INST] tags, the system turn in <s>…</s>."""
    style = MistralPromptStyle()
    system_text = "You are an AI assistant."
    user_text = "Hello, how are you doing?"
    conversation = [
        ChatMessage(content=system_text, role=MessageRole.SYSTEM),
        ChatMessage(content=user_text, role=MessageRole.USER),
    ]
    expected = f"<s>[INST] {system_text} [/INST]</s>[INST] {user_text} [/INST]"
    assert style.messages_to_prompt(conversation) == expected
def test_chatml_prompt_style_format():
    """ChatMLPromptStyle emits <|im_start|>/<|im_end|> delimited turns and
    leaves an open assistant header for the model to complete."""
    style = ChatMLPromptStyle()
    system_text = "You are an AI assistant."
    user_text = "Hello, how are you doing?"
    conversation = [
        ChatMessage(content=system_text, role=MessageRole.SYSTEM),
        ChatMessage(content=user_text, role=MessageRole.USER),
    ]
    expected = (
        f"<|im_start|>system\n{system_text}<|im_end|>\n"
        f"<|im_start|>user\n{user_text}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    assert style.messages_to_prompt(conversation) == expected
def test_llama2_prompt_style_format():
    """Llama2PromptStyle nests the system text in a <<SYS>> block inside one [INST]."""
    style = Llama2PromptStyle()
    system_text = "You are an AI assistant."
    user_text = "Hello, how are you doing?"
    conversation = [
        ChatMessage(content=system_text, role=MessageRole.SYSTEM),
        ChatMessage(content=user_text, role=MessageRole.USER),
    ]
    expected = f"<s> [INST] <<SYS>>\n {system_text} \n<</SYS>>\n\n {user_text} [/INST]"
    assert style.messages_to_prompt(conversation) == expected
def test_llama2_prompt_style_with_system_prompt():
    """A custom system message replaces the default inside the <<SYS>> block."""
    style = Llama2PromptStyle()
    system_text = "FOO BAR Custom sys prompt from messages."
    user_text = "Hello, how are you doing?"
    conversation = [
        ChatMessage(content=system_text, role=MessageRole.SYSTEM),
        ChatMessage(content=user_text, role=MessageRole.USER),
    ]
    expected = f"<s> [INST] <<SYS>>\n {system_text} \n<</SYS>>\n\n {user_text} [/INST]"
    assert style.messages_to_prompt(conversation) == expected