Spaces:
Paused
Paused
| """ | |
| Mistral Tokenizer Wrapper | |
| Provides correct tokenization for Devstral using mistral-common library. | |
| The Tekken tokenizer used by Devstral is incompatible with HuggingFace's | |
| standard tokenization approach. This wrapper uses mistral-common to | |
| produce correct token sequences for the model. | |
| """ | |
| import logging | |
| from typing import List, Optional, Set | |
| logger = logging.getLogger(__name__) | |
class MistralTokenizerWrapper:
    """
    Wrapper around mistral-common's ``MistralTokenizer`` for Devstral.

    Uses ``encode_chat_completion()`` to produce the token IDs the model
    actually expects, rather than HF's text-based approach which produces
    corrupted tokens for Tekken-based models.

    Attributes:
        tokenizer: The underlying ``MistralTokenizer``, or None on failure.
    """

    def __init__(self, model_name: str):
        """
        Initialize the Mistral tokenizer from the HuggingFace hub.

        Never raises: a missing ``mistral_common`` dependency or any load
        failure is logged and recorded via ``is_available()`` instead.

        Args:
            model_name: HuggingFace model path
                (e.g., "mistralai/Devstral-Small-2507")
        """
        # Pre-initialize to the failure state; only flip on full success.
        self.tokenizer = None
        self._available = False
        try:
            # Imported lazily so this module stays importable when
            # mistral-common is not installed.
            from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

            self.tokenizer = MistralTokenizer.from_hf_hub(model_name)
            self._available = True
            logger.info("Loaded MistralTokenizer for %s", model_name)
        except ImportError as e:
            logger.warning("mistral-common not available: %s", e)
        except Exception as e:
            logger.error("Failed to load MistralTokenizer: %s", e)

    def is_available(self) -> bool:
        """Return True if the tokenizer was loaded successfully."""
        return self._available

    def encode_chat(
        self,
        system_prompt: str,
        user_prompt: str
    ) -> List[int]:
        """
        Encode chat messages to token IDs using mistral-common.

        This produces the correct token sequence for Devstral, including
        proper handling of control tokens like [INST] and [/INST].

        Args:
            system_prompt: System message content. Falsy values (e.g. "")
                omit the system message entirely.
            user_prompt: User message content (e.g., "def quicksort(arr):")

        Returns:
            List of token IDs ready for model input.

        Raises:
            RuntimeError: If the tokenizer failed to load.
        """
        if not self._available:
            raise RuntimeError("MistralTokenizer not available")

        from mistral_common.protocol.instruct.messages import (
            SystemMessage, UserMessage
        )
        from mistral_common.protocol.instruct.request import ChatCompletionRequest

        # Build messages list; the system message is optional.
        messages = []
        if system_prompt:
            messages.append(SystemMessage(content=system_prompt))
        messages.append(UserMessage(content=user_prompt))

        # Encode using mistral-common's chat completion encoding.
        request = ChatCompletionRequest(messages=messages)
        tokenized = self.tokenizer.encode_chat_completion(request)

        logger.info("Encoded chat: %d tokens", len(tokenized.tokens))
        return tokenized.tokens

    def decode(self, token_ids: List[int]) -> str:
        """
        Decode token IDs back to text.

        Args:
            token_ids: List of token IDs to decode

        Returns:
            Decoded text string

        Raises:
            RuntimeError: If the tokenizer failed to load.
        """
        if not self._available:
            raise RuntimeError("MistralTokenizer not available")
        return self.tokenizer.decode(token_ids)

    def decode_token(self, token_id: int) -> str:
        """
        Decode a single token ID to text.

        Args:
            token_id: Single token ID to decode

        Returns:
            Decoded text for this token

        Raises:
            RuntimeError: If the tokenizer failed to load.
        """
        # Delegate to decode() so the availability check lives in one place.
        return self.decode([token_id])

    def get_control_token_ids(self) -> Set[int]:
        """
        Return the full set of control/special token IDs known to the
        underlying Tekkenizer (e.g. ``<s>``, ``</s>``, ``[INST]``, ``[/INST]``,
        ``[SYSTEM_PROMPT]``, tool-call markers, etc.).

        These IDs are needed to label tokens with an accurate ``is_special``
        flag in the trace response. The HF tokenizer's ``all_special_ids``
        misses Mistral-specific chat-template delimiters, so we source them
        directly from mistral-common.

        Tries multiple attribute paths for robustness across mistral-common
        versions. Falls back to an empty set (with a warning) if none work —
        callers should still have the HF ``all_special_ids`` as a baseline.
        """
        if not self._available:
            return set()

        try:
            inner = self.tokenizer.instruct_tokenizer.tokenizer
        except AttributeError:
            logger.warning(
                "MistralTokenizer has no instruct_tokenizer.tokenizer attribute"
            )
            return set()

        # Preferred path: Tekkenizer reserves ranks [0, num_special_tokens)
        # for control tokens, so we can materialise the full set cheaply.
        num_special = getattr(inner, "num_special_tokens", None)
        if isinstance(num_special, int) and num_special > 0:
            return set(range(num_special))

        # Fallback: try a couple of commonly-used attribute shapes.
        for attr in ("_special_tokens", "special_tokens"):
            specials = getattr(inner, attr, None)
            if isinstance(specials, dict):
                # dict[str, int] — values are token IDs
                try:
                    return {int(v) for v in specials.values()}
                except Exception:
                    pass
            if isinstance(specials, (list, tuple, set)):
                try:
                    return {int(v) for v in specials}
                except Exception:
                    pass

        logger.warning(
            "Could not determine control token ids from MistralTokenizer; "
            "is_special will be limited to HF tokenizer's all_special_ids"
        )
        return set()
def create_mistral_tokenizer(model_name: str) -> Optional[MistralTokenizerWrapper]:
    """
    Factory function to create a MistralTokenizerWrapper.

    Args:
        model_name: HuggingFace model path

    Returns:
        MistralTokenizerWrapper instance, or None if mistral-common is not
        available or loading fails.
    """
    wrapper = MistralTokenizerWrapper(model_name)
    # BUG FIX: `is_available` is a method — the bare attribute reference was a
    # bound-method object and therefore always truthy, so this factory never
    # returned None even when loading failed. It must be called.
    if wrapper.is_available():
        return wrapper
    return None