BART-ender commited on
Commit
7ee1b6d
·
verified ·
1 Parent(s): 63d8572

test(generation): add unit tests for render_prompt

Browse files
Files changed (1) hide show
  1. tests/test_prompting.py +51 -0
tests/test_prompting.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from typing import Any
3
+ import pytest
4
+ from app.generation.prompting import render_prompt, SYSTEM_PROMPT
5
+
6
+ class TokenizerWithTemplate:
7
+ def __init__(self, template: str | None = "some template"):
8
+ self.chat_template = template
9
+
10
+ def apply_chat_template(self, messages: list[dict[str, str]], **kwargs: Any) -> str:
11
+ if self.chat_template is None:
12
+ raise ValueError("No template set")
13
+ return f"TEMPLATED: {messages[-1]['content']}"
14
+
15
+ class TokenizerWithoutTemplate:
16
+ pass
17
+
18
+ def test_render_prompt_with_valid_template():
19
+ tokenizer = TokenizerWithTemplate()
20
+ prompt = render_prompt(tokenizer, "Hello")
21
+ assert prompt == "TEMPLATED: Hello"
22
+
23
+ def test_render_prompt_with_none_template():
24
+ # This simulates the failure case reported by the user
25
+ tokenizer = TokenizerWithTemplate(template=None)
26
+ prompt = render_prompt(tokenizer, "Hello")
27
+ assert "User: Hello" in prompt
28
+ assert "Assistant:" in prompt
29
+ assert SYSTEM_PROMPT in prompt
30
+
31
+ def test_render_prompt_with_exception_in_apply():
32
+ tokenizer = TokenizerWithTemplate()
33
+ # Mock apply_chat_template to raise exception
34
+ def broken_apply(*args, **kwargs):
35
+ raise RuntimeError("broken")
36
+ tokenizer.apply_chat_template = broken_apply
37
+
38
+ prompt = render_prompt(tokenizer, "Hello")
39
+ assert "User: Hello" in prompt
40
+
41
+ def test_render_prompt_without_apply_method():
42
+ tokenizer = TokenizerWithoutTemplate()
43
+ prompt = render_prompt(tokenizer, "Hello")
44
+ assert "User: Hello" in prompt
45
+
46
+ def test_render_prompt_hrm_text_fallback():
47
+ tokenizer = TokenizerWithoutTemplate()
48
+ setattr(tokenizer, "name_or_path", "sapientinc/HRM-Text-1B")
49
+ prompt = render_prompt(tokenizer, "Hello")
50
+ assert "<|im_start|><|quad_end|><|object_ref_end|><|im_end|>" in prompt
51
+ assert "User: Hello" in prompt