gary-boon Claude Opus 4.5 committed on
Commit
2c6343b
·
1 Parent(s): bc1f0e0

feat: add top-k/top-p sampling and detailed logits/probability tracking

Browse files

- Add top_k and top_p sampling parameters to research attention endpoints
- Capture raw logits before temperature scaling for visualization
- Track greedy token (argmax) separately from sampled token
- Add raw probabilities (T=1) for comparison in UI
- Track cumulative probability and rank for each token alternative
- Add sampling metadata (temperature, top_k, top_p, was_greedy)
- Handle selected token outliers when not in top-N display

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

Files changed (1) hide show
  1. backend/model_service.py +234 -34
backend/model_service.py CHANGED
@@ -1617,6 +1617,8 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
1617
  max_tokens = request.get("max_tokens", 8)
1618
  auto_complete = request.get("auto_complete", False)
1619
  temperature = request.get("temperature", 0.7)
 
 
1620
 
1621
  # If auto_complete mode, ensure we have a reasonable upper limit
1622
  if auto_complete:
@@ -1791,27 +1793,79 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
1791
  )
1792
 
1793
  # Get logits for next token
1794
- logits = outputs.logits[0, -1, :]
1795
 
1796
- # Apply temperature and sample
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1797
  if temperature > 0:
1798
  logits = logits / temperature
1799
  probs = torch.softmax(logits, dim=0)
1800
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1801
  if temperature == 0:
1802
- next_token_id = torch.argmax(probs, dim=-1).item()
1803
  else:
1804
- next_token_id = torch.multinomial(probs, 1).item()
 
 
 
 
1805
  next_token_text = manager.tokenizer.decode([next_token_id], skip_special_tokens=False)
1806
 
1807
  generated_token_ids.append(next_token_id)
1808
  generated_tokens.append(next_token_text)
1809
 
1810
- # Capture top-k token alternatives with probabilities
1811
  # Use log_softmax for numerical stability at low temperatures
1812
- import math
1813
- top_k = 5 # Get top 5 alternatives
1814
- _, top_indices = torch.topk(logits, k=min(top_k, len(logits)))
1815
 
1816
  # Use log_softmax (numerically stable) then exp() for probabilities
1817
  # This avoids underflow that occurs with softmax at low temperatures
@@ -1820,19 +1874,69 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
1820
  top_probs = torch.exp(log_probs[top_indices])
1821
 
1822
  alternatives = []
1823
- for prob, idx in zip(top_probs.tolist(), top_indices.tolist()):
 
 
1824
  token_text = manager.tokenizer.decode([idx], skip_special_tokens=False)
 
 
 
1825
  alternatives.append({
1826
  "token": token_text,
1827
  "token_id": idx,
1828
  "probability": prob,
1829
- "log_probability": math.log(prob) if prob > 0 else float('-inf')
 
 
 
1830
  })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1831
  token_alternatives_by_step.append({
1832
  "step": step,
1833
  "selected_token": next_token_text,
1834
  "selected_token_id": next_token_id,
1835
- "alternatives": alternatives
 
 
1836
  })
1837
 
1838
  # Process attention and hidden states for ALL layers
@@ -2168,6 +2272,8 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
2168
  max_tokens = request.get("max_tokens", 8)
2169
  auto_complete = request.get("auto_complete", False)
2170
  temperature = request.get("temperature", 0.7)
 
 
2171
 
2172
  # If auto_complete mode, ensure we have a reasonable upper limit
2173
  if auto_complete:
@@ -2362,57 +2468,151 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
2362
  )
2363
 
2364
  # Get logits for next token
2365
- logits = outputs.logits[0, -1, :]
 
 
 
 
 
2366
 
2367
- # Apply temperature and sample
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2368
  if temperature > 0:
2369
  logits = logits / temperature
2370
  probs = torch.softmax(logits, dim=0)
2371
 
2372
- if temperature == 0:
2373
- next_token_id = torch.argmax(probs, dim=-1).item()
 
 
 
 
2374
  else:
2375
- next_token_id = torch.multinomial(probs, 1).item()
2376
 
2377
- # Use correct tokenizer for Devstral vs other models
2378
- if manager.model_id == "devstral-small" and manager.mistral_tokenizer is not None:
2379
- next_token_text = manager.mistral_tokenizer.decode_token(next_token_id)
 
 
 
 
 
 
 
 
 
 
 
 
2380
  else:
2381
- next_token_text = manager.tokenizer.decode([next_token_id], skip_special_tokens=False)
 
 
 
 
 
2382
 
2383
  generated_token_ids.append(next_token_id)
2384
  generated_tokens.append(next_token_text)
2385
 
2386
- # Capture top-k token alternatives
2387
  # Use log_softmax for numerical stability at low temperatures
2388
- import math as math_module
2389
- top_k = 5
2390
- _, top_indices = torch.topk(logits, k=min(top_k, len(logits)))
2391
 
2392
  # Use log_softmax (numerically stable) then exp() for probabilities
2393
- # This avoids underflow that occurs with softmax at low temperatures
2394
- # Note: logits is ALREADY temperature-scaled above, so no need to divide again
2395
  log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
2396
  top_probs = torch.exp(log_probs[top_indices])
2397
 
2398
  alternatives = []
2399
- for prob, idx in zip(top_probs.tolist(), top_indices.tolist()):
2400
- # Use correct tokenizer for Devstral vs other models
2401
- if manager.model_id == "devstral-small" and manager.mistral_tokenizer is not None:
2402
- token_text = manager.mistral_tokenizer.decode_token(idx)
2403
- else:
2404
- token_text = manager.tokenizer.decode([idx], skip_special_tokens=False)
 
2405
  alternatives.append({
2406
  "token": token_text,
2407
  "token_id": idx,
2408
  "probability": prob,
2409
- "log_probability": math_module.log(prob) if prob > 0 else float('-inf')
 
 
 
2410
  })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2411
  token_alternatives_by_step.append({
2412
  "step": step,
2413
  "selected_token": next_token_text,
2414
  "selected_token_id": next_token_id,
2415
- "alternatives": alternatives
 
 
2416
  })
2417
 
2418
  # === STAGE 3: EXTRACTING (per layer within each token) ===
 
1617
  max_tokens = request.get("max_tokens", 8)
1618
  auto_complete = request.get("auto_complete", False)
1619
  temperature = request.get("temperature", 0.7)
1620
+ top_k_param = request.get("top_k", None) # Top-k sampling parameter
1621
+ top_p_param = request.get("top_p", None) # Top-p (nucleus) sampling parameter
1622
 
1623
  # If auto_complete mode, ensure we have a reasonable upper limit
1624
  if auto_complete:
 
1793
  )
1794
 
1795
  # Get logits for next token
1796
+ raw_logits = outputs.logits[0, -1, :].clone() # Clone raw logits before any scaling
1797
 
1798
+ # Capture raw logits for top-10 tokens (before temperature scaling)
1799
+ import math
1800
+ top_n_display = 10 # Get top 10 alternatives for display
1801
+ top_raw_logits, top_raw_indices = torch.topk(raw_logits, k=min(top_n_display, len(raw_logits)))
1802
+
1803
+ # Build raw logits entries (before temperature)
1804
+ logits_entries = []
1805
+ for rank, (logit_val, idx) in enumerate(zip(top_raw_logits.tolist(), top_raw_indices.tolist())):
1806
+ token_text = manager.tokenizer.decode([idx], skip_special_tokens=False)
1807
+ logits_entries.append({
1808
+ "token": token_text,
1809
+ "token_id": idx,
1810
+ "logit": logit_val,
1811
+ "rank": rank + 1
1812
+ })
1813
+
1814
+ # Greedy token (argmax of raw logits, before any sampling)
1815
+ greedy_token_id = torch.argmax(raw_logits).item()
1816
+ greedy_token = manager.tokenizer.decode([greedy_token_id], skip_special_tokens=False)
1817
+
1818
+ # Compute raw probabilities (T=1) for comparison visualization
1819
+ raw_probs = torch.softmax(raw_logits, dim=0)
1820
+
1821
+ # Apply temperature scaling
1822
+ logits = raw_logits.clone()
1823
  if temperature > 0:
1824
  logits = logits / temperature
1825
  probs = torch.softmax(logits, dim=0)
1826
 
1827
+ # Apply top-k filtering if specified
1828
+ if top_k_param is not None and top_k_param > 0:
1829
+ top_k_probs, top_k_indices = torch.topk(probs, k=min(top_k_param, len(probs)))
1830
+ probs_filtered = torch.zeros_like(probs)
1831
+ probs_filtered[top_k_indices] = top_k_probs
1832
+ probs_filtered = probs_filtered / probs_filtered.sum() # Renormalize
1833
+ else:
1834
+ probs_filtered = probs
1835
+
1836
+ # Apply top-p (nucleus) filtering if specified
1837
+ if top_p_param is not None and top_p_param < 1.0:
1838
+ sorted_probs, sorted_indices = torch.sort(probs_filtered, descending=True)
1839
+ cumulative_probs = torch.cumsum(sorted_probs, dim=0)
1840
+ # Find cutoff index where cumulative exceeds top_p
1841
+ cutoff_mask = cumulative_probs > top_p_param
1842
+ # Shift mask by 1 to keep at least one token
1843
+ cutoff_mask[1:] = cutoff_mask[:-1].clone()
1844
+ cutoff_mask[0] = False
1845
+ # Zero out tokens beyond cutoff
1846
+ sorted_probs[cutoff_mask] = 0
1847
+ # Scatter back to original order
1848
+ probs_filtered = torch.zeros_like(probs)
1849
+ probs_filtered.scatter_(0, sorted_indices, sorted_probs)
1850
+ if probs_filtered.sum() > 0:
1851
+ probs_filtered = probs_filtered / probs_filtered.sum() # Renormalize
1852
+
1853
  if temperature == 0:
1854
+ next_token_id = torch.argmax(probs_filtered, dim=-1).item()
1855
  else:
1856
+ # Ensure valid probability distribution for multinomial
1857
+ if probs_filtered.sum() > 0:
1858
+ next_token_id = torch.multinomial(probs_filtered, 1).item()
1859
+ else:
1860
+ next_token_id = torch.argmax(probs, dim=-1).item()
1861
  next_token_text = manager.tokenizer.decode([next_token_id], skip_special_tokens=False)
1862
 
1863
  generated_token_ids.append(next_token_id)
1864
  generated_tokens.append(next_token_text)
1865
 
1866
+ # Capture top-10 token alternatives with probabilities
1867
  # Use log_softmax for numerical stability at low temperatures
1868
+ _, top_indices = torch.topk(logits, k=min(top_n_display, len(logits)))
 
 
1869
 
1870
  # Use log_softmax (numerically stable) then exp() for probabilities
1871
  # This avoids underflow that occurs with softmax at low temperatures
 
1874
  top_probs = torch.exp(log_probs[top_indices])
1875
 
1876
  alternatives = []
1877
+ cumulative = 0.0
1878
+ selected_in_top = False
1879
+ for rank, (prob, idx) in enumerate(zip(top_probs.tolist(), top_indices.tolist())):
1880
  token_text = manager.tokenizer.decode([idx], skip_special_tokens=False)
1881
+ cumulative += prob
1882
+ if idx == next_token_id:
1883
+ selected_in_top = True
1884
  alternatives.append({
1885
  "token": token_text,
1886
  "token_id": idx,
1887
  "probability": prob,
1888
+ "raw_probability": raw_probs[idx].item(), # T=1 probability for comparison
1889
+ "log_probability": math.log(prob) if prob > 0 else float('-inf'),
1890
+ "cumulative_probability": cumulative,
1891
+ "rank": rank + 1
1892
  })
1893
+
1894
+ # If selected token is not in top-N, add it with its actual probability
1895
+ if not selected_in_top:
1896
+ selected_prob = probs[next_token_id].item()
1897
+ selected_raw_prob = raw_probs[next_token_id].item()
1898
+ selected_log_prob = log_probs[next_token_id].item()
1899
+ selected_logit = raw_logits[next_token_id].item()
1900
+ # Find the rank of the selected token
1901
+ sorted_indices = torch.argsort(raw_logits, descending=True)
1902
+ selected_rank = (sorted_indices == next_token_id).nonzero(as_tuple=True)[0].item() + 1
1903
+ alternatives.append({
1904
+ "token": next_token_text,
1905
+ "token_id": next_token_id,
1906
+ "probability": selected_prob,
1907
+ "raw_probability": selected_raw_prob, # T=1 probability for comparison
1908
+ "log_probability": selected_log_prob,
1909
+ "cumulative_probability": None, # Not in sequence
1910
+ "rank": selected_rank,
1911
+ "is_selected_outlier": True # Flag for UI
1912
+ })
1913
+ # Also add to logits if not present
1914
+ if next_token_id not in [e["token_id"] for e in logits_entries]:
1915
+ logits_entries.append({
1916
+ "token": next_token_text,
1917
+ "token_id": next_token_id,
1918
+ "logit": selected_logit,
1919
+ "rank": selected_rank,
1920
+ "is_selected_outlier": True
1921
+ })
1922
+
1923
+ # Build sampling metadata
1924
+ sampling_metadata = {
1925
+ "temperature": temperature,
1926
+ "top_k": top_k_param,
1927
+ "top_p": top_p_param,
1928
+ "greedy_token_id": greedy_token_id,
1929
+ "greedy_token": greedy_token,
1930
+ "was_greedy": next_token_id == greedy_token_id
1931
+ }
1932
+
1933
  token_alternatives_by_step.append({
1934
  "step": step,
1935
  "selected_token": next_token_text,
1936
  "selected_token_id": next_token_id,
1937
+ "alternatives": alternatives,
1938
+ "logits": logits_entries,
1939
+ "sampling": sampling_metadata
1940
  })
1941
 
1942
  # Process attention and hidden states for ALL layers
 
2272
  max_tokens = request.get("max_tokens", 8)
2273
  auto_complete = request.get("auto_complete", False)
2274
  temperature = request.get("temperature", 0.7)
2275
+ top_k_param = request.get("top_k", None) # Top-k sampling parameter
2276
+ top_p_param = request.get("top_p", None) # Top-p (nucleus) sampling parameter
2277
 
2278
  # If auto_complete mode, ensure we have a reasonable upper limit
2279
  if auto_complete:
 
2468
  )
2469
 
2470
  # Get logits for next token
2471
+ raw_logits = outputs.logits[0, -1, :].clone() # Clone raw logits before any scaling
2472
+
2473
+ # Capture raw logits for top-10 tokens (before temperature scaling)
2474
+ import math as math_module
2475
+ top_n_display = 10 # Get top 10 alternatives for display
2476
+ top_raw_logits, top_raw_indices = torch.topk(raw_logits, k=min(top_n_display, len(raw_logits)))
2477
 
2478
+ # Build raw logits entries (before temperature)
2479
+ # Use correct tokenizer for Devstral vs other models
2480
+ def decode_token(tid):
2481
+ if manager.model_id == "devstral-small" and manager.mistral_tokenizer is not None:
2482
+ return manager.mistral_tokenizer.decode_token(tid)
2483
+ else:
2484
+ return manager.tokenizer.decode([tid], skip_special_tokens=False)
2485
+
2486
+ logits_entries = []
2487
+ for rank, (logit_val, idx) in enumerate(zip(top_raw_logits.tolist(), top_raw_indices.tolist())):
2488
+ token_text = decode_token(idx)
2489
+ logits_entries.append({
2490
+ "token": token_text,
2491
+ "token_id": idx,
2492
+ "logit": logit_val,
2493
+ "rank": rank + 1
2494
+ })
2495
+
2496
+ # Greedy token (argmax of raw logits, before any sampling)
2497
+ greedy_token_id = torch.argmax(raw_logits).item()
2498
+ greedy_token = decode_token(greedy_token_id)
2499
+
2500
+ # Compute raw probabilities (T=1) for comparison visualization
2501
+ raw_probs = torch.softmax(raw_logits, dim=0)
2502
+
2503
+ # Apply temperature scaling
2504
+ logits = raw_logits.clone()
2505
  if temperature > 0:
2506
  logits = logits / temperature
2507
  probs = torch.softmax(logits, dim=0)
2508
 
2509
+ # Apply top-k filtering if specified
2510
+ if top_k_param is not None and top_k_param > 0:
2511
+ top_k_probs, top_k_indices = torch.topk(probs, k=min(top_k_param, len(probs)))
2512
+ probs_filtered = torch.zeros_like(probs)
2513
+ probs_filtered[top_k_indices] = top_k_probs
2514
+ probs_filtered = probs_filtered / probs_filtered.sum() # Renormalize
2515
  else:
2516
+ probs_filtered = probs
2517
 
2518
+ # Apply top-p (nucleus) filtering if specified
2519
+ if top_p_param is not None and top_p_param < 1.0:
2520
+ sorted_probs, sorted_indices = torch.sort(probs_filtered, descending=True)
2521
+ cumulative_probs = torch.cumsum(sorted_probs, dim=0)
2522
+ cutoff_mask = cumulative_probs > top_p_param
2523
+ cutoff_mask[1:] = cutoff_mask[:-1].clone()
2524
+ cutoff_mask[0] = False
2525
+ sorted_probs[cutoff_mask] = 0
2526
+ probs_filtered = torch.zeros_like(probs)
2527
+ probs_filtered.scatter_(0, sorted_indices, sorted_probs)
2528
+ if probs_filtered.sum() > 0:
2529
+ probs_filtered = probs_filtered / probs_filtered.sum()
2530
+
2531
+ if temperature == 0:
2532
+ next_token_id = torch.argmax(probs_filtered, dim=-1).item()
2533
  else:
2534
+ if probs_filtered.sum() > 0:
2535
+ next_token_id = torch.multinomial(probs_filtered, 1).item()
2536
+ else:
2537
+ next_token_id = torch.argmax(probs, dim=-1).item()
2538
+
2539
+ next_token_text = decode_token(next_token_id)
2540
 
2541
  generated_token_ids.append(next_token_id)
2542
  generated_tokens.append(next_token_text)
2543
 
2544
+ # Capture top-10 token alternatives with probabilities
2545
  # Use log_softmax for numerical stability at low temperatures
2546
+ _, top_indices = torch.topk(logits, k=min(top_n_display, len(logits)))
 
 
2547
 
2548
  # Use log_softmax (numerically stable) then exp() for probabilities
 
 
2549
  log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
2550
  top_probs = torch.exp(log_probs[top_indices])
2551
 
2552
  alternatives = []
2553
+ cumulative = 0.0
2554
+ selected_in_top = False
2555
+ for rank, (prob, idx) in enumerate(zip(top_probs.tolist(), top_indices.tolist())):
2556
+ token_text = decode_token(idx)
2557
+ cumulative += prob
2558
+ if idx == next_token_id:
2559
+ selected_in_top = True
2560
  alternatives.append({
2561
  "token": token_text,
2562
  "token_id": idx,
2563
  "probability": prob,
2564
+ "raw_probability": raw_probs[idx].item(), # T=1 probability for comparison
2565
+ "log_probability": math_module.log(prob) if prob > 0 else float('-inf'),
2566
+ "cumulative_probability": cumulative,
2567
+ "rank": rank + 1
2568
  })
2569
+
2570
+ # If selected token is not in top-N, add it with its actual probability
2571
+ if not selected_in_top:
2572
+ selected_prob = probs[next_token_id].item()
2573
+ selected_raw_prob = raw_probs[next_token_id].item()
2574
+ selected_log_prob = log_probs[next_token_id].item()
2575
+ selected_logit = raw_logits[next_token_id].item()
2576
+ # Find the rank of the selected token
2577
+ sorted_indices = torch.argsort(raw_logits, descending=True)
2578
+ selected_rank = (sorted_indices == next_token_id).nonzero(as_tuple=True)[0].item() + 1
2579
+ alternatives.append({
2580
+ "token": next_token_text,
2581
+ "token_id": next_token_id,
2582
+ "probability": selected_prob,
2583
+ "raw_probability": selected_raw_prob, # T=1 probability for comparison
2584
+ "log_probability": selected_log_prob,
2585
+ "cumulative_probability": None,
2586
+ "rank": selected_rank,
2587
+ "is_selected_outlier": True
2588
+ })
2589
+ # Also add to logits if not present
2590
+ if next_token_id not in [e["token_id"] for e in logits_entries]:
2591
+ logits_entries.append({
2592
+ "token": next_token_text,
2593
+ "token_id": next_token_id,
2594
+ "logit": selected_logit,
2595
+ "rank": selected_rank,
2596
+ "is_selected_outlier": True
2597
+ })
2598
+
2599
+ # Build sampling metadata
2600
+ sampling_metadata = {
2601
+ "temperature": temperature,
2602
+ "top_k": top_k_param,
2603
+ "top_p": top_p_param,
2604
+ "greedy_token_id": greedy_token_id,
2605
+ "greedy_token": greedy_token,
2606
+ "was_greedy": next_token_id == greedy_token_id
2607
+ }
2608
+
2609
  token_alternatives_by_step.append({
2610
  "step": step,
2611
  "selected_token": next_token_text,
2612
  "selected_token_id": next_token_id,
2613
+ "alternatives": alternatives,
2614
+ "logits": logits_entries,
2615
+ "sampling": sampling_metadata
2616
  })
2617
 
2618
  # === STAGE 3: EXTRACTING (per layer within each token) ===