# api/backend/mistral_tokenizer.py
# Author: gary-boon
# Commit b5e4add: Fix empty text and incorrect is_special for Mistral control tokens
"""
Mistral Tokenizer Wrapper
Provides correct tokenization for Devstral using mistral-common library.
The Tekken tokenizer used by Devstral is incompatible with HuggingFace's
standard tokenization approach. This wrapper uses mistral-common to
produce correct token sequences for the model.
"""
import logging
from typing import List, Optional, Set
logger = logging.getLogger(__name__)
class MistralTokenizerWrapper:
    """
    Wrapper around mistral-common's MistralTokenizer for Devstral.

    Uses encode_chat_completion() to produce correct token IDs
    that the model actually expects, rather than HF's text-based approach
    which produces corrupted tokens for Tekken-based models.
    """

    def __init__(self, model_name: str):
        """
        Initialize the Mistral tokenizer from HuggingFace hub.

        Loading is best-effort: if mistral-common is missing or the hub
        download fails, the wrapper is marked unavailable rather than
        raising, so callers can fall back to the HF tokenizer.

        Args:
            model_name: HuggingFace model path (e.g., "mistralai/Devstral-Small-2507")
        """
        try:
            # Imported lazily so this module can be loaded without
            # mistral-common installed.
            from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
            self.tokenizer = MistralTokenizer.from_hf_hub(model_name)
            self._available = True
            # Lazy %-style args: the message is only formatted if emitted.
            logger.info("Loaded MistralTokenizer for %s", model_name)
        except ImportError as e:
            logger.warning("mistral-common not available: %s", e)
            self._available = False
            self.tokenizer = None
        except Exception as e:
            logger.error("Failed to load MistralTokenizer: %s", e)
            self._available = False
            self.tokenizer = None

    @property
    def is_available(self) -> bool:
        """Check if the tokenizer was loaded successfully."""
        return self._available

    def _require_available(self) -> None:
        """Raise RuntimeError if the tokenizer failed to load.

        Centralises the availability guard used by every public method.
        """
        if not self._available:
            raise RuntimeError("MistralTokenizer not available")

    def encode_chat(
        self,
        system_prompt: str,
        user_prompt: str
    ) -> List[int]:
        """
        Encode chat messages to token IDs using mistral-common.

        This produces the correct token sequence for Devstral, including
        proper handling of control tokens like [INST] and [/INST].

        Args:
            system_prompt: System message content; an empty/falsy value
                omits the system message entirely.
            user_prompt: User message content (e.g., "def quicksort(arr):")

        Returns:
            List of token IDs ready for model input

        Raises:
            RuntimeError: If the tokenizer was not loaded successfully.
        """
        self._require_available()
        from mistral_common.protocol.instruct.messages import (
            SystemMessage, UserMessage
        )
        from mistral_common.protocol.instruct.request import ChatCompletionRequest

        # Build messages list; the system message is optional.
        messages = []
        if system_prompt:
            messages.append(SystemMessage(content=system_prompt))
        messages.append(UserMessage(content=user_prompt))

        # Encode using mistral-common's chat completion encoding
        request = ChatCompletionRequest(messages=messages)
        tokenized = self.tokenizer.encode_chat_completion(request)
        logger.info("Encoded chat: %d tokens", len(tokenized.tokens))
        return tokenized.tokens

    def decode(self, token_ids: List[int]) -> str:
        """
        Decode token IDs back to text.

        Args:
            token_ids: List of token IDs to decode

        Returns:
            Decoded text string

        Raises:
            RuntimeError: If the tokenizer was not loaded successfully.
        """
        self._require_available()
        return self.tokenizer.decode(token_ids)

    def decode_token(self, token_id: int) -> str:
        """
        Decode a single token ID to text.

        Args:
            token_id: Single token ID to decode

        Returns:
            Decoded text for this token

        Raises:
            RuntimeError: If the tokenizer was not loaded successfully.
        """
        # Delegate to decode() so the availability guard and decoding
        # logic live in one place.
        return self.decode([token_id])

    def get_control_token_ids(self) -> Set[int]:
        """
        Return the full set of control/special token IDs known to the
        underlying Tekkenizer (e.g. ``<s>``, ``</s>``, ``[INST]``, ``[/INST]``,
        ``[SYSTEM_PROMPT]``, tool-call markers, etc.).

        These IDs are needed to label tokens with an accurate ``is_special``
        flag in the trace response. The HF tokenizer's ``all_special_ids``
        misses Mistral-specific chat-template delimiters, so we source them
        directly from mistral-common.

        Tries multiple attribute paths for robustness across mistral-common
        versions. Falls back to an empty set (with a warning) if none work —
        callers should still have the HF ``all_special_ids`` as a baseline.
        """
        if not self._available:
            return set()
        try:
            inner = self.tokenizer.instruct_tokenizer.tokenizer
        except AttributeError:
            logger.warning(
                "MistralTokenizer has no instruct_tokenizer.tokenizer attribute"
            )
            return set()

        # Preferred path: Tekkenizer reserves ranks [0, num_special_tokens)
        # for control tokens, so we can materialise the full set cheaply.
        num_special = getattr(inner, "num_special_tokens", None)
        if isinstance(num_special, int) and num_special > 0:
            return set(range(num_special))

        # Fallback: try a couple of commonly-used attribute shapes.
        for attr in ("_special_tokens", "special_tokens"):
            specials = getattr(inner, attr, None)
            if isinstance(specials, dict):
                # dict[str, int] — values are token IDs
                try:
                    return {int(v) for v in specials.values()}
                except Exception:
                    pass
            if isinstance(specials, (list, tuple, set)):
                try:
                    return {int(v) for v in specials}
                except Exception:
                    pass

        logger.warning(
            "Could not determine control token ids from MistralTokenizer; "
            "is_special will be limited to HF tokenizer's all_special_ids"
        )
        return set()
def create_mistral_tokenizer(model_name: str) -> Optional[MistralTokenizerWrapper]:
    """
    Build a MistralTokenizerWrapper for the given model.

    A convenience factory: loading failures (missing mistral-common,
    hub errors) are absorbed by the wrapper, and this function maps an
    unusable wrapper to ``None``.

    Args:
        model_name: HuggingFace model path

    Returns:
        A ready-to-use MistralTokenizerWrapper, or None if loading failed.
    """
    wrapper = MistralTokenizerWrapper(model_name)
    return wrapper if wrapper.is_available else None