# api/backend/prompt_formatter.py
"""
Prompt Formatter Service
Handles formatting prompts appropriately for different model types:
- Completion models: Raw text continuation
- Instruction models: System prompt + user message with chat template
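
Typical use (a sketch; the import path simply mirrors this file's location):

    from api.backend.prompt_formatter import format_prompt
    formatted = format_prompt(user_text, model_config, tokenizer)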
"""
from typing import Dict, Optional, Any, List
import logging

logger = logging.getLogger(__name__)

def _try_mistral_common_format(messages: List[Dict[str, str]], model_name: str) -> Optional[str]:
    """
    Try to use mistral_common for proper Mistral/Devstral chat formatting.

    Returns None if mistral_common is not available or fails.
    """
    try:
        from mistral_common.protocol.instruct.messages import (
            SystemMessage, UserMessage
        )
        from mistral_common.protocol.instruct.request import ChatCompletionRequest
        from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

        # Load the tokenizer from the HF hub
        tokenizer = MistralTokenizer.from_hf_hub(model_name)

        # Build messages
        mistral_messages = []
        for msg in messages:
            if msg["role"] == "system":
                mistral_messages.append(SystemMessage(content=msg["content"]))
            elif msg["role"] == "user":
                mistral_messages.append(UserMessage(content=msg["content"]))

        # Encode to get token IDs
        request = ChatCompletionRequest(messages=mistral_messages)
        tokenized = tokenizer.encode_chat_completion(request)

        # Decode back to text for use with the HF tokenizer.
        # This gives us the properly formatted prompt string.
        decoded = tokenizer.decode(tokenized.tokens)
        logger.info(f"Used mistral_common format for {model_name}")
        return decoded
    except ImportError:
        logger.warning("mistral_common not available, using fallback format")
        return None
    except Exception as e:
        logger.warning(f"mistral_common formatting failed: {e}, using fallback")
        return None
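

# Illustrative call (a sketch: the repo id is an example, not something this
# module pins, and from_hf_hub fetches tokenizer files from the hub, so the
# first call needs network access):
#
#     formatted = _try_mistral_common_format(
#         [{"role": "system", "content": "You are a coding assistant."},
#          {"role": "user", "content": "def quicksort(arr):"}],
#         "mistralai/Devstral-Small-2505",
#     )
#     # -> the decoded chat-formatted prompt string, or None on any failure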

class PromptFormatter:
    """
    Unified prompt formatting for different model types.

    Completion models (CodeGen, Code Llama base):
    - Pass the prompt through unchanged
    - The model treats it as text to continue

    Instruction models (Devstral, instruct variants):
    - Use mistral_common for Mistral/Devstral models
    - Fall back to the tokenizer's chat_template if available
    - Final fallback to a manual Mistral format
    """

    def format(
        self,
        prompt: str,
        model_config: Dict[str, Any],
        tokenizer: Any,
        system_prompt_override: Optional[str] = None
    ) -> str:
        """
        Format a prompt appropriately for the model type.

        Args:
            prompt: The user's input (e.g., "def quicksort(arr):")
            model_config: Model configuration from model_config.py
            tokenizer: HuggingFace tokenizer for the model
            system_prompt_override: Optional override for the default system prompt

        Returns:
            Formatted prompt ready for tokenization
        """
        prompt_style = model_config.get("prompt_style", "completion")
        if prompt_style == "instruction":
            return self._format_instruction(
                prompt,
                model_config,
                tokenizer,
                system_prompt_override
            )

        # Completion style: return the raw prompt
        return prompt

    def _format_instruction(
        self,
        prompt: str,
        model_config: Dict[str, Any],
        tokenizer: Any,
        system_prompt_override: Optional[str] = None
    ) -> str:
        """
        Format a prompt for instruction-tuned models.

        Priority:
        1. mistral_common for Mistral/Devstral models
        2. The tokenizer's native chat_template
        3. Manual Mistral format fallback
        """
        # Get the system prompt (override > model default > generic fallback)
        system_prompt = system_prompt_override or model_config.get("system_prompt")
        if not system_prompt:
            system_prompt = "You are a helpful coding assistant. Continue the code provided."

        # Build the messages list
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt}
        ]

        # For Mistral/Devstral models, try mistral_common first
        architecture = model_config.get("architecture", "")
        hf_path = model_config.get("hf_path", "")
        if architecture == "mistral" or "mistral" in hf_path.lower():
            formatted = _try_mistral_common_format(messages, hf_path)
            if formatted:
                return formatted

        # Try the tokenizer's native chat template
        if hasattr(tokenizer, 'chat_template') and tokenizer.chat_template is not None:
            try:
                formatted = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
                logger.info("Used HF tokenizer chat_template")
                return formatted
            except Exception as e:
                logger.warning(f"chat_template failed: {e}, using manual format")

        # Fallback: manual Mistral/Llama format.
        # Note: don't include <s>; the tokenizer adds it during tokenization.
        return self._manual_mistral_format(prompt, system_prompt)

    def _manual_mistral_format(self, prompt: str, system_prompt: str) -> str:
        """
        Manual Mistral instruction format as fallback.

        Format: [INST] {system}\n\n{user} [/INST]
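
        Example (shape only, with illustrative inputs):
            _manual_mistral_format("def add(a, b):", "Be brief.")
            -> "[INST] Be brief.\n\ndef add(a, b): [/INST]"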

        Note: <s> is NOT included as the tokenizer adds BOS automatically.
        """
        logger.info("Using manual Mistral instruction format")
        return f"[INST] {system_prompt}\n\n{prompt} [/INST]"


# Singleton instance for convenience
_formatter = PromptFormatter()


def format_prompt(
    prompt: str,
    model_config: Dict[str, Any],
    tokenizer: Any,
    system_prompt_override: Optional[str] = None
) -> str:
    """
    Convenience function to format a prompt.

    Args:
        prompt: The user's input (e.g., "def quicksort(arr):")
        model_config: Model configuration from model_config.py
        tokenizer: HuggingFace tokenizer for the model
        system_prompt_override: Optional override for the default system prompt

    Returns:
        Formatted prompt ready for tokenization
    """
    return _formatter.format(prompt, model_config, tokenizer, system_prompt_override)
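

# Minimal smoke test (a sketch, not part of the API): the configs and the bare
# object() tokenizer below are hypothetical stand-ins, chosen so the instruction
# path exercises the manual [INST] fallback without touching the network.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    dummy_tokenizer = object()  # has no chat_template attribute

    # Completion style: the prompt passes through unchanged.
    print(format_prompt(
        "def quicksort(arr):", {"prompt_style": "completion"}, dummy_tokenizer
    ))

    # Instruction style with a non-Mistral architecture and no chat_template:
    # falls through to the manual [INST] ... [/INST] format.
    print(format_prompt(
        "def quicksort(arr):",
        {"prompt_style": "instruction", "architecture": "llama"},
        dummy_tokenizer,
    ))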