gary-boon Claude Opus 4.6 (1M context) committed on
Commit
b5e4add
·
1 Parent(s): e375e45

Fix empty text and incorrect is_special for Mistral control tokens

Browse files

mistral-common's decode_token() returns "" for chat-template tokens
(<s>, [INST], [/INST], [SYSTEM_PROMPT], [/SYSTEM_PROMPT], tool markers).
Fall back to the HF tokenizer so every token arrives with a printable
string form. Also widen special_token_ids from {eos, bos, pad, unk} to
include all control-token IDs from mistral-common's Tekkenizer, fixing
is_special for chat-template delimiters.

Applied to both streaming and non-streaming endpoints, plus the
logit-candidate decode helper.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

backend/mistral_tokenizer.py CHANGED
@@ -8,7 +8,7 @@ produce correct token sequences for the model.
8
  """
9
 
10
  import logging
11
- from typing import List, Optional
12
 
13
  logger = logging.getLogger(__name__)
14
 
@@ -118,6 +118,59 @@ class MistralTokenizerWrapper:
118
  result = self.tokenizer.decode([token_id])
119
  return result
120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  def create_mistral_tokenizer(model_name: str) -> Optional[MistralTokenizerWrapper]:
123
  """
 
8
  """
9
 
10
  import logging
11
+ from typing import List, Optional, Set
12
 
13
  logger = logging.getLogger(__name__)
14
 
 
118
  result = self.tokenizer.decode([token_id])
119
  return result
120
 
121
+ def get_control_token_ids(self) -> Set[int]:
122
+ """
123
+ Return the full set of control/special token IDs known to the
124
+ underlying Tekkenizer (e.g. ``<s>``, ``</s>``, ``[INST]``, ``[/INST]``,
125
+ ``[SYSTEM_PROMPT]``, tool-call markers, etc.).
126
+
127
+ These IDs are needed to label tokens with an accurate ``is_special``
128
+ flag in the trace response. The HF tokenizer's ``all_special_ids``
129
+ misses Mistral-specific chat-template delimiters, so we source them
130
+ directly from mistral-common.
131
+
132
+ Tries multiple attribute paths for robustness across mistral-common
133
+ versions. Falls back to an empty set (with a warning) if none work —
134
+ callers should still have the HF ``all_special_ids`` as a baseline.
135
+ """
136
+ if not self._available:
137
+ return set()
138
+
139
+ try:
140
+ inner = self.tokenizer.instruct_tokenizer.tokenizer
141
+ except AttributeError:
142
+ logger.warning(
143
+ "MistralTokenizer has no instruct_tokenizer.tokenizer attribute"
144
+ )
145
+ return set()
146
+
147
+ # Preferred path: Tekkenizer reserves ranks [0, num_special_tokens)
148
+ # for control tokens, so we can materialise the full set cheaply.
149
+ num_special = getattr(inner, "num_special_tokens", None)
150
+ if isinstance(num_special, int) and num_special > 0:
151
+ return set(range(num_special))
152
+
153
+ # Fallback: try a couple of commonly-used attribute shapes.
154
+ for attr in ("_special_tokens", "special_tokens"):
155
+ specials = getattr(inner, attr, None)
156
+ if isinstance(specials, dict):
157
+ # dict[str, int] — values are token IDs
158
+ try:
159
+ return {int(v) for v in specials.values()}
160
+ except Exception:
161
+ pass
162
+ if isinstance(specials, (list, tuple, set)):
163
+ try:
164
+ return {int(v) for v in specials}
165
+ except Exception:
166
+ pass
167
+
168
+ logger.warning(
169
+ "Could not determine control token ids from MistralTokenizer; "
170
+ "is_special will be limited to HF tokenizer's all_special_ids"
171
+ )
172
+ return set()
173
+
174
 
175
  def create_mistral_tokenizer(model_name: str) -> Optional[MistralTokenizerWrapper]:
176
  """
backend/model_service.py CHANGED
@@ -1846,8 +1846,19 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
1846
  prompt_token_ids = manager.mistral_tokenizer.encode_chat(system_prompt, prompt)
1847
  inputs = {"input_ids": torch.tensor([prompt_token_ids]).to(manager.device)}
1848
  prompt_length = len(prompt_token_ids)
1849
- # Decode tokens using MistralTokenizer for accuracy
1850
- prompt_tokens = [manager.mistral_tokenizer.decode_token(tid) for tid in prompt_token_ids]
 
 
 
 
 
 
 
 
 
 
 
1851
  logger.info(f"Used MistralTokenizer for Devstral: {prompt_length} tokens")
1852
  else:
1853
  # Standard HF tokenization for other models
@@ -2476,12 +2487,23 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
2476
  from .tokenizer_utils import TokenizerMetadata
2477
  token_metadata = TokenizerMetadata(manager.tokenizer)
2478
 
2479
- special_token_ids = {
 
 
 
 
 
 
 
2480
  manager.tokenizer.eos_token_id,
2481
  manager.tokenizer.bos_token_id,
2482
  manager.tokenizer.pad_token_id,
2483
- manager.tokenizer.unk_token_id
2484
- }
 
 
 
 
2485
 
2486
  def build_token_data(token_ids, token_texts, token_type):
2487
  """Build token data with full metadata for hover tooltips"""
@@ -2610,7 +2632,18 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
2610
  prompt_token_ids = manager.mistral_tokenizer.encode_chat(system_prompt, prompt)
2611
  inputs = {"input_ids": torch.tensor([prompt_token_ids]).to(manager.device)}
2612
  prompt_length = len(prompt_token_ids)
2613
- prompt_tokens = [manager.mistral_tokenizer.decode_token(tid) for tid in prompt_token_ids]
 
 
 
 
 
 
 
 
 
 
 
2614
  else:
2615
  inputs = manager.tokenizer(formatted_prompt, return_tensors="pt").to(manager.device)
2616
  prompt_length = inputs["input_ids"].shape[1]
@@ -2859,10 +2892,19 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
2859
  top_raw_logits, top_raw_indices = torch.topk(raw_logits, k=min(top_n_display, len(raw_logits)))
2860
 
2861
  # Build raw logits entries (before temperature)
2862
- # Use correct tokenizer for Devstral vs other models
 
 
 
2863
  def decode_token(tid):
2864
  if manager.model_id == "devstral-small" and manager.mistral_tokenizer is not None:
2865
- return manager.mistral_tokenizer.decode_token(tid)
 
 
 
 
 
 
2866
  else:
2867
  return manager.tokenizer.decode([tid], skip_special_tokens=False)
2868
 
@@ -3543,12 +3585,21 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
3543
  from .tokenizer_utils import TokenizerMetadata
3544
  token_metadata_builder = TokenizerMetadata(manager.tokenizer)
3545
 
3546
- special_token_ids_set = {
 
 
 
 
 
3547
  manager.tokenizer.eos_token_id,
3548
  manager.tokenizer.bos_token_id,
3549
  manager.tokenizer.pad_token_id,
3550
- manager.tokenizer.unk_token_id
3551
- }
 
 
 
 
3552
 
3553
  def build_token_data(token_ids, token_texts, token_type):
3554
  multi_split_flags = token_metadata_builder.is_multi_split_identifier(token_ids)
 
1846
  prompt_token_ids = manager.mistral_tokenizer.encode_chat(system_prompt, prompt)
1847
  inputs = {"input_ids": torch.tensor([prompt_token_ids]).to(manager.device)}
1848
  prompt_length = len(prompt_token_ids)
1849
+ # Decode tokens using MistralTokenizer for accuracy. mistral-common's
1850
+ # decode_token() returns "" for some control tokens (<s>, [INST],
1851
+ # [/INST], [SYSTEM_PROMPT], etc.). Fall back to the HF tokenizer so
1852
+ # every token arrives with a printable string form.
1853
+ def _decode_prompt_token(tid: int) -> str:
1854
+ text = manager.mistral_tokenizer.decode_token(tid)
1855
+ if text:
1856
+ return text
1857
+ try:
1858
+ return manager.tokenizer.decode([tid], skip_special_tokens=False) or ""
1859
+ except Exception:
1860
+ return ""
1861
+ prompt_tokens = [_decode_prompt_token(tid) for tid in prompt_token_ids]
1862
  logger.info(f"Used MistralTokenizer for Devstral: {prompt_length} tokens")
1863
  else:
1864
  # Standard HF tokenization for other models
 
2487
  from .tokenizer_utils import TokenizerMetadata
2488
  token_metadata = TokenizerMetadata(manager.tokenizer)
2489
 
2490
+ # Include every id the tokenizer considers a special / added token
2491
+ # (BOS, EOS, PAD, UNK, chat-template delimiters like [INST]/[/INST],
2492
+ # system-prompt markers, tool-call markers, etc.). The HF tokenizer's
2493
+ # all_special_ids misses Mistral-specific chat-template delimiters, so
2494
+ # we also union in the control-token ids from mistral-common when the
2495
+ # Mistral path is active.
2496
+ special_token_ids = set(getattr(manager.tokenizer, "all_special_ids", []) or [])
2497
+ for tok_id in (
2498
  manager.tokenizer.eos_token_id,
2499
  manager.tokenizer.bos_token_id,
2500
  manager.tokenizer.pad_token_id,
2501
+ manager.tokenizer.unk_token_id,
2502
+ ):
2503
+ if tok_id is not None:
2504
+ special_token_ids.add(tok_id)
2505
+ if manager.mistral_tokenizer is not None:
2506
+ special_token_ids |= manager.mistral_tokenizer.get_control_token_ids()
2507
 
2508
  def build_token_data(token_ids, token_texts, token_type):
2509
  """Build token data with full metadata for hover tooltips"""
 
2632
  prompt_token_ids = manager.mistral_tokenizer.encode_chat(system_prompt, prompt)
2633
  inputs = {"input_ids": torch.tensor([prompt_token_ids]).to(manager.device)}
2634
  prompt_length = len(prompt_token_ids)
2635
+ # mistral-common's decode_token() returns "" for control tokens
2636
+ # (<s>, [INST], [/INST], [SYSTEM_PROMPT], etc.). Fall back to
2637
+ # the HF tokenizer so every token has a printable string form.
2638
+ def _decode_prompt_token(tid: int) -> str:
2639
+ text = manager.mistral_tokenizer.decode_token(tid)
2640
+ if text:
2641
+ return text
2642
+ try:
2643
+ return manager.tokenizer.decode([tid], skip_special_tokens=False) or ""
2644
+ except Exception:
2645
+ return ""
2646
+ prompt_tokens = [_decode_prompt_token(tid) for tid in prompt_token_ids]
2647
  else:
2648
  inputs = manager.tokenizer(formatted_prompt, return_tensors="pt").to(manager.device)
2649
  prompt_length = inputs["input_ids"].shape[1]
 
2892
  top_raw_logits, top_raw_indices = torch.topk(raw_logits, k=min(top_n_display, len(raw_logits)))
2893
 
2894
  # Build raw logits entries (before temperature)
2895
+ # Use correct tokenizer for Devstral vs other models.
2896
+ # On the Mistral path, fall back to the HF tokenizer when
2897
+ # mistral-common returns "" so logit candidates (e.g. EOS,
2898
+ # chat-template ids) always carry a printable string.
2899
  def decode_token(tid):
2900
  if manager.model_id == "devstral-small" and manager.mistral_tokenizer is not None:
2901
+ text = manager.mistral_tokenizer.decode_token(tid)
2902
+ if text:
2903
+ return text
2904
+ try:
2905
+ return manager.tokenizer.decode([tid], skip_special_tokens=False) or ""
2906
+ except Exception:
2907
+ return ""
2908
  else:
2909
  return manager.tokenizer.decode([tid], skip_special_tokens=False)
2910
 
 
3585
  from .tokenizer_utils import TokenizerMetadata
3586
  token_metadata_builder = TokenizerMetadata(manager.tokenizer)
3587
 
3588
+ # Include every id the tokenizer considers a special / added token
3589
+ # (BOS, EOS, PAD, UNK, chat-template delimiters like [INST]/[/INST],
3590
+ # system-prompt markers, tool-call markers, etc.). See the matching
3591
+ # construction in the non-streaming endpoint for rationale.
3592
+ special_token_ids_set = set(getattr(manager.tokenizer, "all_special_ids", []) or [])
3593
+ for tok_id in (
3594
  manager.tokenizer.eos_token_id,
3595
  manager.tokenizer.bos_token_id,
3596
  manager.tokenizer.pad_token_id,
3597
+ manager.tokenizer.unk_token_id,
3598
+ ):
3599
+ if tok_id is not None:
3600
+ special_token_ids_set.add(tok_id)
3601
+ if manager.mistral_tokenizer is not None:
3602
+ special_token_ids_set |= manager.mistral_tokenizer.get_control_token_ids()
3603
 
3604
  def build_token_data(token_ids, token_texts, token_type):
3605
  multi_split_flags = token_metadata_builder.is_multi_split_identifier(token_ids)