File size: 1,306 Bytes
57be184
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
"""Unified prompt formatting and response extraction"""

import logging
import re

class PromptFormatter:
    """Handles prompt formatting and assistant response extraction.

    Only model names containing ``"meta-llama"`` are supported; any other
    model name raises ``NotImplementedError``.
    """

    # Tag that ``format_prompt`` places before the assistant's turn and that
    # ``extract_assistant_response`` splits on. Kept in one place so the two
    # methods cannot drift apart.
    _ASSISTANT_TAG = "Assistant:"

    @staticmethod
    def _is_llama(model_name):
        """Return True if ``model_name`` names a meta-llama model."""
        return "meta-llama" in model_name

    @staticmethod
    def format_prompt(model_name, prompt, partial_response, continuation):
        """Format the full prompt for generation.

        The result is ``prompt``, a blank line, the ``"Assistant: "`` tag, then
        ``partial_response`` immediately followed by ``continuation``.

        Args:
            model_name: Model identifier; must contain ``"meta-llama"``.
            prompt: The user prompt text.
            partial_response: Beginning of the assistant's answer, placed
                directly after the tag.
            continuation: Text appended directly after ``partial_response``.

        Returns:
            The formatted prompt string.

        Raises:
            NotImplementedError: If ``model_name`` is not a meta-llama model.
        """
        if not PromptFormatter._is_llama(model_name):
            raise NotImplementedError(f"Prompt formatting not implemented for model: {model_name}")
        tag = PromptFormatter._ASSISTANT_TAG
        return f"{prompt}\n\n{tag} {partial_response}{continuation}"

    @staticmethod
    def extract_assistant_response(model_name, full_response):
        """Extract the assistant's response from the full generated text.

        Returns everything after the FIRST ``"Assistant:"`` tag, with leading
        and trailing whitespace stripped. Any later ``"Assistant:"`` tags are
        kept as part of the returned text; when more than one tag is present a
        warning containing the full text is logged. If no tag is present, the
        whole text is returned stripped.

        Args:
            model_name: Model identifier; must contain ``"meta-llama"``.
            full_response: Generated text, presumably including the prompt
                produced by ``format_prompt`` — confirm against callers.

        Returns:
            The extracted, stripped assistant response.

        Raises:
            NotImplementedError: If ``model_name`` is not a meta-llama model.
        """
        if not PromptFormatter._is_llama(model_name):
            raise NotImplementedError(f"Response extraction not implemented for model: {model_name}")

        tag = PromptFormatter._ASSISTANT_TAG
        tag_count = full_response.count(tag)
        if tag_count > 1:
            # Report the anomaly through logging rather than writing to stdout
            # from library code, so callers can filter or redirect it.
            logging.getLogger(__name__).warning(
                "Found multiple assistant tags (%d)\nFull response:\n%s\n**",
                tag_count,
                full_response,
            )

        # Split on the first tag only, keeping any later tags in the response.
        # NOTE(review): if the prompt text itself contains "Assistant:", the
        # split happens inside the prompt — confirm prompts never contain it.
        return full_response.split(tag, maxsplit=1)[-1].strip()