"""
Mistral Tokenizer Wrapper

Provides correct tokenization for Devstral using the mistral-common library.

The Tekken tokenizer used by Devstral is incompatible with HuggingFace's
standard tokenization approach. This wrapper uses mistral-common to produce
correct token sequences for the model.
"""

import logging
from typing import Any, List, Optional, Set

logger = logging.getLogger(__name__)


class MistralTokenizerWrapper:
    """
    Wrapper around mistral-common's MistralTokenizer for Devstral.

    Uses encode_chat_completion() to produce correct token IDs that the model
    actually expects, rather than HF's text-based approach which produces
    corrupted tokens for Tekken-based models.

    Construction never raises: if mistral-common cannot be imported or the
    tokenizer fails to load, ``is_available`` is False. In that state the
    encode/decode methods raise ``RuntimeError``, and
    ``get_control_token_ids()`` returns an empty set.
    """

    def __init__(self, model_name: str):
        """
        Initialize the Mistral tokenizer from HuggingFace hub.

        Args:
            model_name: HuggingFace model path (e.g., "mistralai/Devstral-Small-2507")
        """
        # Set the "unavailable" state up front so every failure path below
        # leaves the instance in a consistent state.
        self.tokenizer = None
        self._available = False
        try:
            # Imported lazily: mistral-common is an optional dependency.
            from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

            tokenizer = MistralTokenizer.from_hf_hub(model_name)
        except ImportError as e:
            logger.warning("mistral-common not available: %s", e)
            return
        except Exception as e:
            # Deliberate best-effort boundary: any load failure (network,
            # missing files, incompatible model) just disables the wrapper.
            logger.error("Failed to load MistralTokenizer: %s", e)
            return

        self.tokenizer = tokenizer
        self._available = True
        logger.info("Loaded MistralTokenizer for %s", model_name)

    @property
    def is_available(self) -> bool:
        """Check if the tokenizer was loaded successfully."""
        return self._available

    def _require_available(self) -> None:
        """Raise RuntimeError if the underlying tokenizer failed to load."""
        if not self._available:
            raise RuntimeError("MistralTokenizer not available")

    def encode_chat(
        self, system_prompt: str, user_prompt: str
    ) -> List[int]:
        """
        Encode chat messages to token IDs using mistral-common.

        This produces the correct token sequence for Devstral, including
        proper handling of control tokens like [INST] and [/INST].

        Args:
            system_prompt: System message content; omitted when empty.
            user_prompt: User message content (e.g., "def quicksort(arr):")

        Returns:
            List of token IDs ready for model input

        Raises:
            RuntimeError: If the tokenizer is not available.
        """
        self._require_available()

        from mistral_common.protocol.instruct.messages import (
            SystemMessage,
            UserMessage,
        )
        from mistral_common.protocol.instruct.request import ChatCompletionRequest

        # A falsy system prompt ("" or None) is skipped rather than sent as an
        # empty system message.
        messages = []
        if system_prompt:
            messages.append(SystemMessage(content=system_prompt))
        messages.append(UserMessage(content=user_prompt))

        # Encode using mistral-common's chat completion encoding.
        request = ChatCompletionRequest(messages=messages)
        tokenized = self.tokenizer.encode_chat_completion(request)

        logger.info("Encoded chat: %d tokens", len(tokenized.tokens))
        return tokenized.tokens

    def decode(self, token_ids: List[int]) -> str:
        """
        Decode token IDs back to text.

        Args:
            token_ids: List of token IDs to decode

        Returns:
            Decoded text string

        Raises:
            RuntimeError: If the tokenizer is not available.
        """
        self._require_available()
        return self.tokenizer.decode(token_ids)

    def decode_token(self, token_id: int) -> str:
        """
        Decode a single token ID to text.

        Args:
            token_id: Single token ID to decode

        Returns:
            Decoded text for this token

        Raises:
            RuntimeError: If the tokenizer is not available.
        """
        return self.decode([token_id])

    def get_control_token_ids(self) -> Set[int]:
        """
        Return the full set of control/special token IDs known to the
        underlying Tekkenizer (e.g. ``<s>``, ``</s>``, ``[INST]``,
        ``[/INST]``, ``[SYSTEM_PROMPT]``, tool-call markers, etc.).

        These IDs are needed to label tokens with an accurate ``is_special``
        flag in the trace response. The HF tokenizer's ``all_special_ids``
        misses Mistral-specific chat-template delimiters, so we source them
        directly from mistral-common.

        Tries multiple attribute paths for robustness across mistral-common
        versions. Falls back to an empty set (with a warning) if none work.
        Callers should still have the HF ``all_special_ids`` as a baseline.
        """
        if not self._available:
            return set()

        try:
            inner = self.tokenizer.instruct_tokenizer.tokenizer
        except AttributeError:
            logger.warning(
                "MistralTokenizer has no instruct_tokenizer.tokenizer attribute"
            )
            return set()

        # Preferred path: Tekkenizer reserves ranks [0, num_special_tokens)
        # for control tokens, so we can materialise the full set cheaply.
        # bool is a subclass of int; exclude it so a stray True is not read
        # as a count of 1.
        num_special = getattr(inner, "num_special_tokens", None)
        if (
            isinstance(num_special, int)
            and not isinstance(num_special, bool)
            and num_special > 0
        ):
            return set(range(num_special))

        # Fallback: try a couple of commonly-used attribute shapes. Only a
        # non-empty result short-circuits, so an empty container does not
        # hide a usable one (or suppress the warning below).
        for attr in ("_special_tokens", "special_tokens"):
            ids = self._coerce_token_ids(getattr(inner, attr, None))
            if ids:
                return ids

        logger.warning(
            "Could not determine control token ids from MistralTokenizer; "
            "is_special will be limited to HF tokenizer's all_special_ids"
        )
        return set()

    @staticmethod
    def _coerce_token_ids(specials: Any) -> Optional[Set[int]]:
        """
        Convert one candidate special-token container into a set of IDs.

        Accepted shapes:
          * dict[str, int]: the values are the token IDs;
          * list/tuple/set of IDs;
          * list/tuple/set of mappings with a ``"rank"`` key. This is an
            assumed shape for token-info records; verify it against the
            installed mistral-common version.

        Returns None if ``specials`` has none of these shapes or if any entry
        cannot be converted to ``int``.
        """
        if isinstance(specials, dict):
            values = list(specials.values())
        elif isinstance(specials, (list, tuple, set)):
            values = [
                item["rank"] if isinstance(item, dict) and "rank" in item else item
                for item in specials
            ]
        else:
            return None

        try:
            return {int(v) for v in values}
        except (TypeError, ValueError, OverflowError):
            # e.g. a container of token *strings* rather than IDs.
            return None


def create_mistral_tokenizer(model_name: str) -> Optional[MistralTokenizerWrapper]:
    """
    Factory function to create a MistralTokenizerWrapper.

    Returns None if mistral-common is not available or loading fails.

    Args:
        model_name: HuggingFace model path

    Returns:
        MistralTokenizerWrapper instance or None
    """
    wrapper = MistralTokenizerWrapper(model_name)
    return wrapper if wrapper.is_available else None