# collaborative-decoding / src/utils/prompt_formatter.py
# Author: Alon Albalak
# major update: all data saved on HF (prompts, results), unified utilities
# commit 57be184
"""Unified prompt formatting and response extraction"""
import re
class PromptFormatter:
    """Utilities for building generation prompts and parsing model output.

    Both operations are model-family aware; currently only the Llama
    ("meta-llama") template is supported, and any other model name raises
    NotImplementedError.
    """

    @staticmethod
    def format_prompt(model_name, prompt, partial_response, continuation):
        """Assemble the full text sent to the model for generation.

        Args:
            model_name: HF model identifier; selects the prompt template.
            prompt: the user/task prompt.
            partial_response: assistant text generated so far.
            continuation: candidate continuation appended after it.

        Returns:
            The formatted prompt string.

        Raises:
            NotImplementedError: if no template exists for this model family.
        """
        if "meta-llama" not in model_name:
            raise NotImplementedError(f"Prompt formatting not implemented for model: {model_name}")
        return f"{prompt}\n\nAssistant: {partial_response}{continuation}"

    @staticmethod
    def extract_assistant_response(model_name, full_response):
        """Pull the assistant's reply out of the full generated text.

        Splits on the first "Assistant:" tag; if the tag never appears, the
        whole (stripped) text is returned unchanged.

        Raises:
            NotImplementedError: if no extraction rule exists for this model
                family.
        """
        if "meta-llama" not in model_name:
            raise NotImplementedError(f"Response extraction not implemented for model: {model_name}")
        # Warn (but proceed) when the model echoed the tag more than once;
        # only the first occurrence is used as the split point.
        tag_count = len(re.findall(r"Assistant:\s*", full_response))
        if tag_count > 1:
            print(f"Found multiple assistant tags ({tag_count})\nFull response:\n{full_response}\n**")
        head, sep, tail = full_response.partition("Assistant:")
        # No tag found -> partition leaves everything in `head`; fall back to it.
        return (tail if sep else head).strip()