"""
Mistral Tokenizer Wrapper
Provides correct tokenization for Devstral using mistral-common library.
The Tekken tokenizer used by Devstral is incompatible with HuggingFace's
standard tokenization approach. This wrapper uses mistral-common to
produce correct token sequences for the model.
"""
import logging
from typing import List, Optional, Set
logger = logging.getLogger(__name__)
class MistralTokenizerWrapper:
    """
    Wrapper around mistral-common's MistralTokenizer for Devstral.
    Uses encode_chat_completion() to produce correct token IDs
    that the model actually expects, rather than HF's text-based approach
    which produces corrupted tokens for Tekken-based models.
    """
    def __init__(self, model_name: str):
        """
        Initialize the Mistral tokenizer from HuggingFace hub.
        Args:
            model_name: HuggingFace model path (e.g., "mistralai/Devstral-Small-2507")
        """
        try:
            # Imported lazily so this module stays importable when
            # mistral-common is not installed; we degrade to "unavailable".
            from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
            self.tokenizer = MistralTokenizer.from_hf_hub(model_name)
            self._available = True
            # Lazy %-style args: the message is only built if the record
            # is actually emitted.
            logger.info("Loaded MistralTokenizer for %s", model_name)
        except ImportError as e:
            logger.warning("mistral-common not available: %s", e)
            self._available = False
            self.tokenizer = None
        except Exception as e:
            # from_hf_hub can fail for many reasons (network, auth, missing
            # repo files); record the failure instead of crashing the caller.
            logger.error("Failed to load MistralTokenizer: %s", e)
            self._available = False
            self.tokenizer = None
    @property
    def is_available(self) -> bool:
        """Check if the tokenizer was loaded successfully."""
        return self._available
    def encode_chat(
        self,
        system_prompt: str,
        user_prompt: str
    ) -> List[int]:
        """
        Encode chat messages to token IDs using mistral-common.
        This produces the correct token sequence for Devstral, including
        proper handling of control tokens like [INST] and [/INST].
        Args:
            system_prompt: System message content (skipped when empty/falsy)
            user_prompt: User message content (e.g., "def quicksort(arr):")
        Returns:
            List of token IDs ready for model input
        Raises:
            RuntimeError: If the tokenizer failed to load in __init__.
        """
        if not self._available:
            raise RuntimeError("MistralTokenizer not available")
        # Imported lazily for the same reason as in __init__.
        from mistral_common.protocol.instruct.messages import (
            SystemMessage, UserMessage
        )
        from mistral_common.protocol.instruct.request import ChatCompletionRequest
        # Build messages list; an empty system prompt is omitted entirely
        # rather than sent as an empty SystemMessage.
        messages = []
        if system_prompt:
            messages.append(SystemMessage(content=system_prompt))
        messages.append(UserMessage(content=user_prompt))
        # Encode using mistral-common's chat completion encoding
        request = ChatCompletionRequest(messages=messages)
        tokenized = self.tokenizer.encode_chat_completion(request)
        logger.info("Encoded chat: %d tokens", len(tokenized.tokens))
        return tokenized.tokens
    def decode(self, token_ids: List[int]) -> str:
        """
        Decode token IDs back to text.
        Args:
            token_ids: List of token IDs to decode
        Returns:
            Decoded text string
        Raises:
            RuntimeError: If the tokenizer failed to load in __init__.
        """
        if not self._available:
            raise RuntimeError("MistralTokenizer not available")
        return self.tokenizer.decode(token_ids)
    def decode_token(self, token_id: int) -> str:
        """
        Decode a single token ID to text.
        Args:
            token_id: Single token ID to decode
        Returns:
            Decoded text for this token
        Raises:
            RuntimeError: If the tokenizer failed to load in __init__.
        """
        # Delegate to decode() so the availability check and decoding
        # behavior live in exactly one place.
        return self.decode([token_id])
    def get_control_token_ids(self) -> Set[int]:
        """
        Return the full set of control/special token IDs known to the
        underlying Tekkenizer (e.g. ``<s>``, ``</s>``, ``[INST]``, ``[/INST]``,
        ``[SYSTEM_PROMPT]``, tool-call markers, etc.).
        These IDs are needed to label tokens with an accurate ``is_special``
        flag in the trace response. The HF tokenizer's ``all_special_ids``
        misses Mistral-specific chat-template delimiters, so we source them
        directly from mistral-common.
        Tries multiple attribute paths for robustness across mistral-common
        versions. Falls back to an empty set (with a warning) if none work —
        callers should still have the HF ``all_special_ids`` as a baseline.
        """
        if not self._available:
            return set()
        try:
            inner = self.tokenizer.instruct_tokenizer.tokenizer
        except AttributeError:
            logger.warning(
                "MistralTokenizer has no instruct_tokenizer.tokenizer attribute"
            )
            return set()
        # Preferred path: Tekkenizer reserves ranks [0, num_special_tokens)
        # for control tokens, so we can materialise the full set cheaply.
        num_special = getattr(inner, "num_special_tokens", None)
        if isinstance(num_special, int) and num_special > 0:
            return set(range(num_special))
        # Fallback: try a couple of commonly-used attribute shapes.
        for attr in ("_special_tokens", "special_tokens"):
            specials = getattr(inner, attr, None)
            if isinstance(specials, dict):
                # dict[str, int] — values are token IDs
                try:
                    return {int(v) for v in specials.values()}
                except (TypeError, ValueError):
                    # Values weren't int-coercible; try the next shape.
                    pass
            if isinstance(specials, (list, tuple, set)):
                try:
                    return {int(v) for v in specials}
                except (TypeError, ValueError):
                    pass
        logger.warning(
            "Could not determine control token ids from MistralTokenizer; "
            "is_special will be limited to HF tokenizer's all_special_ids"
        )
        return set()
def create_mistral_tokenizer(model_name: str) -> Optional[MistralTokenizerWrapper]:
    """
    Factory function to create a MistralTokenizerWrapper.
    Returns None if mistral-common is not available or loading fails.
    Args:
        model_name: HuggingFace model path
    Returns:
        MistralTokenizerWrapper instance or None
    """
    candidate = MistralTokenizerWrapper(model_name)
    # Collapse the guard-and-return into a single conditional expression:
    # hand back the wrapper only when it loaded successfully.
    return candidate if candidate.is_available else None