| """ | |
| Prompt Formatter Service | |
| Handles formatting prompts appropriately for different model types: | |
| - Completion models: Raw text continuation | |
| - Instruction models: System prompt + user message with chat template | |
| """ | |
| from typing import Dict, Optional, Any, List | |
| import logging | |
| logger = logging.getLogger(__name__) | |


def _try_mistral_common_format(messages: List[Dict[str, str]], model_name: str) -> Optional[str]:
    """
    Try to use mistral_common for proper Mistral/Devstral chat formatting.
    Returns None if mistral_common is not available or fails.
    """
    try:
        from mistral_common.protocol.instruct.messages import (
            SystemMessage, UserMessage
        )
        from mistral_common.protocol.instruct.request import ChatCompletionRequest
        from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

        # Load the tokenizer from HF hub
        tokenizer = MistralTokenizer.from_hf_hub(model_name)

        # Build messages
        mistral_messages = []
        for msg in messages:
            if msg["role"] == "system":
                mistral_messages.append(SystemMessage(content=msg["content"]))
            elif msg["role"] == "user":
                mistral_messages.append(UserMessage(content=msg["content"]))

        # Encode to get token IDs
        request = ChatCompletionRequest(messages=mistral_messages)
        tokenized = tokenizer.encode_chat_completion(request)

        # Decode back to text for use with HF tokenizer
        # This gives us the properly formatted prompt string
        decoded = tokenizer.decode(tokenized.tokens)
        logger.info(f"Used mistral_common format for {model_name}")
        return decoded
    except ImportError:
        logger.warning("mistral_common not available, using fallback format")
        return None
    except Exception as e:
        logger.warning(f"mistral_common formatting failed: {e}, using fallback")
        return None
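
# Illustrative call shape for the helper above (example values only; the repo id
# is an assumption, not one this module is known to use):
#   _try_mistral_common_format(
#       [{"role": "system", "content": "You are a helpful coding assistant."},
#        {"role": "user", "content": "def quicksort(arr):"}],
#       "mistralai/Devstral-Small-2505",
#   )
# returns the fully formatted prompt string, or None so callers can fall back.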


class PromptFormatter:
    """
    Unified prompt formatting for different model types.

    Completion models (CodeGen, Code Llama base):
    - Pass the prompt through unchanged
    - The model treats it as text to continue

    Instruction models (Devstral, instruct variants):
    - Use mistral_common for Mistral/Devstral models
    - Fall back to the tokenizer's chat_template if available
    - Final fallback to a manual Mistral format
    """

    def format(
        self,
        prompt: str,
        model_config: Dict[str, Any],
        tokenizer: Any,
        system_prompt_override: Optional[str] = None
    ) -> str:
        """
        Format a prompt appropriately for the model type.

        Args:
            prompt: The user's input (e.g., "def quicksort(arr):")
            model_config: Model configuration from model_config.py
            tokenizer: HuggingFace tokenizer for the model
            system_prompt_override: Optional override for the default system prompt

        Returns:
            Formatted prompt ready for tokenization
        """
        prompt_style = model_config.get("prompt_style", "completion")

        if prompt_style == "instruction":
            return self._format_instruction(
                prompt,
                model_config,
                tokenizer,
                system_prompt_override
            )

        # Completion style: return raw prompt
        return prompt

    def _format_instruction(
        self,
        prompt: str,
        model_config: Dict[str, Any],
        tokenizer: Any,
        system_prompt_override: Optional[str] = None
    ) -> str:
        """
        Format a prompt for instruction-tuned models.

        Priority:
        1. mistral_common for Mistral/Devstral models
        2. Tokenizer's native chat_template
        3. Manual Mistral format fallback
        """
        # Get system prompt (override > model default > generic fallback)
        system_prompt = system_prompt_override or model_config.get("system_prompt")
        if not system_prompt:
            system_prompt = "You are a helpful coding assistant. Continue the code provided."

        # Build messages list
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt}
        ]

        # For Mistral/Devstral models, try mistral_common first
        architecture = model_config.get("architecture", "")
        hf_path = model_config.get("hf_path", "")
        if architecture == "mistral" or "mistral" in hf_path.lower():
            formatted = _try_mistral_common_format(messages, hf_path)
            if formatted:
                return formatted

        # Try tokenizer's native chat template
        if hasattr(tokenizer, 'chat_template') and tokenizer.chat_template is not None:
            try:
                formatted = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
                logger.info("Used HF tokenizer chat_template")
                return formatted
            except Exception as e:
                logger.warning(f"chat_template failed: {e}, using manual format")

        # Fallback: Manual Mistral/Llama format
        # Note: Don't include <s> as the tokenizer adds it during tokenization
        return self._manual_mistral_format(prompt, system_prompt)

    def _manual_mistral_format(self, prompt: str, system_prompt: str) -> str:
        """
        Manual Mistral instruction format as a fallback.

        Format: [INST] {system}\n\n{user} [/INST]
        Note: <s> is NOT included as the tokenizer adds BOS automatically.
        """
        logger.info("Using manual Mistral instruction format")
        return f"[INST] {system_prompt}\n\n{prompt} [/INST]"


# Singleton instance for convenience
_formatter = PromptFormatter()


def format_prompt(
    prompt: str,
    model_config: Dict[str, Any],
    tokenizer: Any,
    system_prompt_override: Optional[str] = None
) -> str:
    """
    Convenience function to format a prompt.

    Args:
        prompt: The user's input (e.g., "def quicksort(arr):")
        model_config: Model configuration from model_config.py
        tokenizer: HuggingFace tokenizer for the model
        system_prompt_override: Optional override for the default system prompt

    Returns:
        Formatted prompt ready for tokenization
    """
    return _formatter.format(prompt, model_config, tokenizer, system_prompt_override)
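

# Minimal usage sketch (illustrative; the config values and repo id below are
# example assumptions, not a known entry in model_config.py):
#
#   from transformers import AutoTokenizer
#
#   cfg = {
#       "hf_path": "mistralai/Mistral-7B-Instruct-v0.3",
#       "architecture": "mistral",
#       "prompt_style": "instruction",
#   }
#   tok = AutoTokenizer.from_pretrained(cfg["hf_path"])
#   text = format_prompt("def quicksort(arr):", cfg, tok)
#   model_inputs = tok(text, return_tensors="pt")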