"""
Mistral Tokenizer Wrapper
Provides correct tokenization for Devstral using the mistral-common library.

The Tekken tokenizer used by Devstral is incompatible with HuggingFace's
standard tokenization approach. This wrapper uses mistral-common to
produce correct token sequences for the model.
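
Typical usage (a sketch, not a guaranteed contract; the prompts are
placeholders and the model name mirrors the example used below):

    wrapper = create_mistral_tokenizer("mistralai/Devstral-Small-2507")
    if wrapper is not None:
        token_ids = wrapper.encode_chat(
            "You are a coding assistant.",
            "def quicksort(arr):",
        )
        text = wrapper.decode(token_ids)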
"""

import logging
from typing import List, Optional, Set

logger = logging.getLogger(__name__)


class MistralTokenizerWrapper:
    """
    Wrapper around mistral-common's MistralTokenizer for Devstral.

    Uses encode_chat_completion() to produce the token IDs that the model
    actually expects, rather than HF's text-based approach, which produces
    corrupted tokens for Tekken-based models.
    """

    def __init__(self, model_name: str):
        """
        Initialize the Mistral tokenizer from HuggingFace hub.

        Args:
            model_name: HuggingFace model path (e.g., "mistralai/Devstral-Small-2507")
        """
        try:
            from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
            self.tokenizer = MistralTokenizer.from_hf_hub(model_name)
            self._available = True
            logger.info(f"Loaded MistralTokenizer for {model_name}")
        except ImportError as e:
            logger.warning(f"mistral-common not available: {e}")
            self._available = False
            self.tokenizer = None
        except Exception as e:
            logger.error(f"Failed to load MistralTokenizer: {e}")
            self._available = False
            self.tokenizer = None

    @property
    def is_available(self) -> bool:
        """Check if the tokenizer was loaded successfully."""
        return self._available

    def encode_chat(
        self,
        system_prompt: str,
        user_prompt: str
    ) -> List[int]:
        """
        Encode chat messages to token IDs using mistral-common.

        This produces the correct token sequence for Devstral, including
        proper handling of control tokens like [INST] and [/INST].

        Args:
            system_prompt: System message content
            user_prompt: User message content (e.g., "def quicksort(arr):")

        Returns:
            List of token IDs ready for model input
        """
        if not self._available:
            raise RuntimeError("MistralTokenizer not available")

        from mistral_common.protocol.instruct.messages import (
            SystemMessage, UserMessage
        )
        from mistral_common.protocol.instruct.request import ChatCompletionRequest

        # Build messages list
        messages = []
        if system_prompt:
            messages.append(SystemMessage(content=system_prompt))
        messages.append(UserMessage(content=user_prompt))

        # Encode using mistral-common's chat completion encoding
        request = ChatCompletionRequest(messages=messages)
        tokenized = self.tokenizer.encode_chat_completion(request)

        logger.info(f"Encoded chat: {len(tokenized.tokens)} tokens")
        return tokenized.tokens

    def decode(self, token_ids: List[int]) -> str:
        """
        Decode token IDs back to text.

        Args:
            token_ids: List of token IDs to decode

        Returns:
            Decoded text string
        """
        if not self._available:
            raise RuntimeError("MistralTokenizer not available")

        return self.tokenizer.decode(token_ids)

    def decode_token(self, token_id: int) -> str:
        """
        Decode a single token ID to text.

        Args:
            token_id: Single token ID to decode

        Returns:
            Decoded text for this token
        """
        if not self._available:
            raise RuntimeError("MistralTokenizer not available")

        return self.tokenizer.decode([token_id])

    def get_control_token_ids(self) -> Set[int]:
        """
        Return the full set of control/special token IDs known to the
        underlying Tekkenizer (e.g. ``<s>``, ``</s>``, ``[INST]``, ``[/INST]``,
        ``[SYSTEM_PROMPT]``, tool-call markers, etc.).

        These IDs are needed to label tokens with an accurate ``is_special``
        flag in the trace response. The HF tokenizer's ``all_special_ids``
        misses Mistral-specific chat-template delimiters, so we source them
        directly from mistral-common.

        Tries multiple attribute paths for robustness across mistral-common
        versions. Falls back to an empty set (with a warning) if none work —
        callers should still have the HF ``all_special_ids`` as a baseline.
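
        Example (a sketch; ``hf_special_ids`` is a placeholder for whatever
        the caller already collected from the HF tokenizer's
        ``all_special_ids``):

            special_ids = set(hf_special_ids) | wrapper.get_control_token_ids()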
        """
        if not self._available:
            return set()

        try:
            inner = self.tokenizer.instruct_tokenizer.tokenizer
        except AttributeError:
            logger.warning(
                "MistralTokenizer has no instruct_tokenizer.tokenizer attribute"
            )
            return set()

        # Preferred path: Tekkenizer reserves ranks [0, num_special_tokens)
        # for control tokens, so we can materialise the full set cheaply.
        num_special = getattr(inner, "num_special_tokens", None)
        if isinstance(num_special, int) and num_special > 0:
            return set(range(num_special))

        # Fallback: try a couple of commonly-used attribute shapes.
        for attr in ("_special_tokens", "special_tokens"):
            specials = getattr(inner, attr, None)
            if isinstance(specials, dict):
                # dict[str, int] — values are token IDs
                try:
                    return {int(v) for v in specials.values()}
                except Exception:
                    pass
            if isinstance(specials, (list, tuple, set)):
                try:
                    return {int(v) for v in specials}
                except Exception:
                    pass

        logger.warning(
            "Could not determine control token ids from MistralTokenizer; "
            "is_special will be limited to HF tokenizer's all_special_ids"
        )
        return set()


def create_mistral_tokenizer(model_name: str) -> Optional[MistralTokenizerWrapper]:
    """
    Factory function to create a MistralTokenizerWrapper.

    Returns None if mistral-common is not available or loading fails.

    Args:
        model_name: HuggingFace model path

    Returns:
        MistralTokenizerWrapper instance or None
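
    Example (a sketch; ``model_name`` and the prompts are placeholders, and
    the fallback branch is up to the caller):

        mistral_tok = create_mistral_tokenizer(model_name)
        if mistral_tok is not None:
            token_ids = mistral_tok.encode_chat(system_prompt, user_prompt)
        else:
            # fall back to the standard HF tokenizer path
            ...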
    """
    wrapper = MistralTokenizerWrapper(model_name)
    if wrapper.is_available:
        return wrapper
    return None