gary-boon Claude Opus 4.5 committed on
Commit
bc1f0e0
·
1 Parent(s): f2e89c2

fix: Improve token alternatives numerical stability and temperature control

Browse files

- Use log_softmax instead of softmax for numerically stable probability
computation at low temperatures (fixes underflow showing 0% for alternatives)
- Use topk on logits instead of probs to correctly identify top tokens
regardless of temperature (fixes showing special tokens like <unk>)
- Remove backend temperature override - now controlled by frontend UI
- Use correct tokenizer (MistralTokenizer) for Devstral model when
decoding alternative token texts in SSE endpoint

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

Files changed (1) hide show
  1. backend/model_service.py +35 -11
backend/model_service.py CHANGED
@@ -1646,10 +1646,9 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
1646
  if prompt_style == "instruction":
1647
  logger.info(f"Formatted prompt preview: {formatted_prompt[:200]}...")
1648
 
1649
- # Use model's recommended temperature for instruction models
1650
- if model_config and "recommended_temperature" in model_config:
1651
- temperature = model_config["recommended_temperature"]
1652
- logger.info(f"Using model recommended temperature={temperature}")
1653
 
1654
  # Tokenize and prepare - use MistralTokenizer for Devstral
1655
  if manager.model_id == "devstral-small" and manager.mistral_tokenizer is not None:
@@ -1809,9 +1808,17 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
1809
  generated_tokens.append(next_token_text)
1810
 
1811
  # Capture top-k token alternatives with probabilities
 
1812
  import math
1813
  top_k = 5 # Get top 5 alternatives
1814
- top_probs, top_indices = torch.topk(probs, k=min(top_k, len(probs)))
 
 
 
 
 
 
 
1815
  alternatives = []
1816
  for prob, idx in zip(top_probs.tolist(), top_indices.tolist()):
1817
  token_text = manager.tokenizer.decode([idx], skip_special_tokens=False)
@@ -2190,9 +2197,9 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
2190
 
2191
  prompt_style = model_config.get("prompt_style", "completion") if model_config else "completion"
2192
 
2193
- # Use model's recommended temperature for instruction models
2194
- if model_config and "recommended_temperature" in model_config:
2195
- temperature = model_config["recommended_temperature"]
2196
 
2197
  # Tokenize and prepare - use MistralTokenizer for Devstral
2198
  if manager.model_id == "devstral-small" and manager.mistral_tokenizer is not None:
@@ -2366,18 +2373,35 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
2366
  next_token_id = torch.argmax(probs, dim=-1).item()
2367
  else:
2368
  next_token_id = torch.multinomial(probs, 1).item()
2369
- next_token_text = manager.tokenizer.decode([next_token_id], skip_special_tokens=False)
 
 
 
 
 
2370
 
2371
  generated_token_ids.append(next_token_id)
2372
  generated_tokens.append(next_token_text)
2373
 
2374
  # Capture top-k token alternatives
 
2375
  import math as math_module
2376
  top_k = 5
2377
- top_probs, top_indices = torch.topk(probs, k=min(top_k, len(probs)))
 
 
 
 
 
 
 
2378
  alternatives = []
2379
  for prob, idx in zip(top_probs.tolist(), top_indices.tolist()):
2380
- token_text = manager.tokenizer.decode([idx], skip_special_tokens=False)
 
 
 
 
2381
  alternatives.append({
2382
  "token": token_text,
2383
  "token_id": idx,
 
1646
  if prompt_style == "instruction":
1647
  logger.info(f"Formatted prompt preview: {formatted_prompt[:200]}...")
1648
 
1649
+ # Temperature is now controlled by the frontend UI
1650
+ # The frontend sets appropriate defaults per model (0.15 for Devstral, 0.7 for CodeGen)
1651
+ logger.info(f"Using temperature={temperature} from request")
 
1652
 
1653
  # Tokenize and prepare - use MistralTokenizer for Devstral
1654
  if manager.model_id == "devstral-small" and manager.mistral_tokenizer is not None:
 
1808
  generated_tokens.append(next_token_text)
1809
 
1810
  # Capture top-k token alternatives with probabilities
1811
+ # Use log_softmax for numerical stability at low temperatures
1812
  import math
1813
  top_k = 5 # Get top 5 alternatives
1814
+ _, top_indices = torch.topk(logits, k=min(top_k, len(logits)))
1815
+
1816
+ # Use log_softmax (numerically stable) then exp() for probabilities
1817
+ # This avoids underflow that occurs with softmax at low temperatures
1818
+ # Note: logits is ALREADY temperature-scaled above, so no need to divide again
1819
+ log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
1820
+ top_probs = torch.exp(log_probs[top_indices])
1821
+
1822
  alternatives = []
1823
  for prob, idx in zip(top_probs.tolist(), top_indices.tolist()):
1824
  token_text = manager.tokenizer.decode([idx], skip_special_tokens=False)
 
2197
 
2198
  prompt_style = model_config.get("prompt_style", "completion") if model_config else "completion"
2199
 
2200
+ # Temperature is now controlled by the frontend UI
2201
+ # The frontend sets appropriate defaults per model (0.15 for Devstral, 0.7 for CodeGen)
2202
+ logger.info(f"[SSE] Using temperature={temperature} from request")
2203
 
2204
  # Tokenize and prepare - use MistralTokenizer for Devstral
2205
  if manager.model_id == "devstral-small" and manager.mistral_tokenizer is not None:
 
2373
  next_token_id = torch.argmax(probs, dim=-1).item()
2374
  else:
2375
  next_token_id = torch.multinomial(probs, 1).item()
2376
+
2377
+ # Use correct tokenizer for Devstral vs other models
2378
+ if manager.model_id == "devstral-small" and manager.mistral_tokenizer is not None:
2379
+ next_token_text = manager.mistral_tokenizer.decode_token(next_token_id)
2380
+ else:
2381
+ next_token_text = manager.tokenizer.decode([next_token_id], skip_special_tokens=False)
2382
 
2383
  generated_token_ids.append(next_token_id)
2384
  generated_tokens.append(next_token_text)
2385
 
2386
  # Capture top-k token alternatives
2387
+ # Use log_softmax for numerical stability at low temperatures
2388
  import math as math_module
2389
  top_k = 5
2390
+ _, top_indices = torch.topk(logits, k=min(top_k, len(logits)))
2391
+
2392
+ # Use log_softmax (numerically stable) then exp() for probabilities
2393
+ # This avoids underflow that occurs with softmax at low temperatures
2394
+ # Note: logits is ALREADY temperature-scaled above, so no need to divide again
2395
+ log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
2396
+ top_probs = torch.exp(log_probs[top_indices])
2397
+
2398
  alternatives = []
2399
  for prob, idx in zip(top_probs.tolist(), top_indices.tolist()):
2400
+ # Use correct tokenizer for Devstral vs other models
2401
+ if manager.model_id == "devstral-small" and manager.mistral_tokenizer is not None:
2402
+ token_text = manager.mistral_tokenizer.decode_token(idx)
2403
+ else:
2404
+ token_text = manager.tokenizer.decode([idx], skip_special_tokens=False)
2405
  alternatives.append({
2406
  "token": token_text,
2407
  "token_id": idx,