from typing import Protocol, List, Tuple

from transformers import AutoTokenizer


class PromptTemplate(Protocol):
    """Structural interface for prompt templates: any class with a matching
    `format` method satisfies it, no explicit inheritance required."""

    def format(self, context: str, user_input: str, chat_history: List[Tuple[str, str]], **kwargs) -> str:
        ...
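

# Hedged sketch (not part of the original listing): a hypothetical consumer
# that depends only on the PromptTemplate protocol. Thanks to structural
# typing, any object with a matching `format` method can be passed in.
def build_prompt(template: PromptTemplate, context: str, question: str,
                 history: List[Tuple[str, str]]) -> str:
    return template.format(context, question, history)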


class LlamaPromptTemplate:
    """Builds a Llama 3 chat prompt by hand from the model's special tokens."""

    def format(self, context: str, user_input: str, chat_history: List[Tuple[str, str]], max_history_turns: int = 1) -> str:
        system_message = f"Please assist based on the following context: {context}"
        # Llama 3 prompts open with <|begin_of_text|> followed by a system block.
        prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>"

        # Replay only the most recent turns to bound prompt length.
        for user_msg, assistant_msg in chat_history[-max_history_turns:]:
            prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_msg}<|eot_id|>"
            prompt += f"<|start_header_id|>assistant<|end_header_id|>\n\n{assistant_msg}<|eot_id|>"

        prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_input}<|eot_id|>"
        # Finish with an open assistant header so the model writes the reply.
        prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
        return prompt
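
# For illustration (made-up inputs), a single-turn call such as
#   LlamaPromptTemplate().format("Paris is in France.", "Which country?", [])
# renders as:
#   <|begin_of_text|><|start_header_id|>system<|end_header_id|>
#
#   Please assist based on the following context: Paris is in France.<|eot_id|><|start_header_id|>user<|end_header_id|>
#
#   Which country?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
#
# with the final assistant header left open for the model's completion.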


class TransformersPromptTemplate:
    """Delegates prompt construction to the tokenizer's built-in chat template."""

    def __init__(self, model_path: str):
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

    def format(self, context: str, user_input: str, chat_history: List[Tuple[str, str]], **kwargs) -> str:
        messages = [
            {
                "role": "system",
                "content": f"Please assist based on the following context: {context}",
            }
        ]

        # Replay the full history as alternating user/assistant messages.
        for user_msg, assistant_msg in chat_history:
            messages.extend([
                {"role": "user", "content": user_msg},
                {"role": "assistant", "content": assistant_msg},
            ])

        messages.append({"role": "user", "content": user_input})

        # With tokenize=False this returns the formatted prompt as a string;
        # add_generation_prompt=True appends the open assistant header.
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
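

# Minimal usage sketch (hedged): the model id below is an assumption; any
# Hugging Face checkpoint whose tokenizer ships a chat template should work,
# and loading it requires network access or a local cache. Both classes
# satisfy PromptTemplate structurally, so either fits the same annotation.
if __name__ == "__main__":
    history = [("What is RAG?", "Retrieval-augmented generation.")]

    manual: PromptTemplate = LlamaPromptTemplate()
    print(manual.format("Some retrieved context.", "Tell me more.", history))

    templated: PromptTemplate = TransformersPromptTemplate(
        "meta-llama/Meta-Llama-3-8B-Instruct"  # placeholder model id
    )
    print(templated.format("Some retrieved context.", "Tell me more.", history))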