Spaces:
Sleeping
Sleeping
gary-boon
Claude Opus 4.5
committed on
Commit
·
2c6343b
1
Parent(s):
bc1f0e0
feat: add top-k/top-p sampling and detailed logits/probability tracking
Browse files
- Add top_k and top_p sampling parameters to research attention endpoints
- Capture raw logits before temperature scaling for visualization
- Track greedy token (argmax) separately from sampled token
- Add raw probabilities (T=1) for comparison in UI
- Track cumulative probability and rank for each token alternative
- Add sampling metadata (temperature, top_k, top_p, was_greedy)
- Handle selected token outliers when not in top-N display
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- backend/model_service.py +234 -34
backend/model_service.py
CHANGED
|
@@ -1617,6 +1617,8 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
|
|
| 1617 |
max_tokens = request.get("max_tokens", 8)
|
| 1618 |
auto_complete = request.get("auto_complete", False)
|
| 1619 |
temperature = request.get("temperature", 0.7)
|
|
|
|
|
|
|
| 1620 |
|
| 1621 |
# If auto_complete mode, ensure we have a reasonable upper limit
|
| 1622 |
if auto_complete:
|
|
@@ -1791,27 +1793,79 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
|
|
| 1791 |
)
|
| 1792 |
|
| 1793 |
# Get logits for next token
|
| 1794 |
-
|
| 1795 |
|
| 1796 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1797 |
if temperature > 0:
|
| 1798 |
logits = logits / temperature
|
| 1799 |
probs = torch.softmax(logits, dim=0)
|
| 1800 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1801 |
if temperature == 0:
|
| 1802 |
-
next_token_id = torch.argmax(
|
| 1803 |
else:
|
| 1804 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1805 |
next_token_text = manager.tokenizer.decode([next_token_id], skip_special_tokens=False)
|
| 1806 |
|
| 1807 |
generated_token_ids.append(next_token_id)
|
| 1808 |
generated_tokens.append(next_token_text)
|
| 1809 |
|
| 1810 |
-
# Capture top-
|
| 1811 |
# Use log_softmax for numerical stability at low temperatures
|
| 1812 |
-
|
| 1813 |
-
top_k = 5 # Get top 5 alternatives
|
| 1814 |
-
_, top_indices = torch.topk(logits, k=min(top_k, len(logits)))
|
| 1815 |
|
| 1816 |
# Use log_softmax (numerically stable) then exp() for probabilities
|
| 1817 |
# This avoids underflow that occurs with softmax at low temperatures
|
|
@@ -1820,19 +1874,69 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
|
|
| 1820 |
top_probs = torch.exp(log_probs[top_indices])
|
| 1821 |
|
| 1822 |
alternatives = []
|
| 1823 |
-
|
|
|
|
|
|
|
| 1824 |
token_text = manager.tokenizer.decode([idx], skip_special_tokens=False)
|
|
|
|
|
|
|
|
|
|
| 1825 |
alternatives.append({
|
| 1826 |
"token": token_text,
|
| 1827 |
"token_id": idx,
|
| 1828 |
"probability": prob,
|
| 1829 |
-
"
|
|
|
|
|
|
|
|
|
|
| 1830 |
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1831 |
token_alternatives_by_step.append({
|
| 1832 |
"step": step,
|
| 1833 |
"selected_token": next_token_text,
|
| 1834 |
"selected_token_id": next_token_id,
|
| 1835 |
-
"alternatives": alternatives
|
|
|
|
|
|
|
| 1836 |
})
|
| 1837 |
|
| 1838 |
# Process attention and hidden states for ALL layers
|
|
@@ -2168,6 +2272,8 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
|
|
| 2168 |
max_tokens = request.get("max_tokens", 8)
|
| 2169 |
auto_complete = request.get("auto_complete", False)
|
| 2170 |
temperature = request.get("temperature", 0.7)
|
|
|
|
|
|
|
| 2171 |
|
| 2172 |
# If auto_complete mode, ensure we have a reasonable upper limit
|
| 2173 |
if auto_complete:
|
|
@@ -2362,57 +2468,151 @@ async def analyze_research_attention_stream(request: Dict[str, Any], authenticat
|
|
| 2362 |
)
|
| 2363 |
|
| 2364 |
# Get logits for next token
|
| 2365 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2366 |
|
| 2367 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2368 |
if temperature > 0:
|
| 2369 |
logits = logits / temperature
|
| 2370 |
probs = torch.softmax(logits, dim=0)
|
| 2371 |
|
| 2372 |
-
|
| 2373 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2374 |
else:
|
| 2375 |
-
|
| 2376 |
|
| 2377 |
-
#
|
| 2378 |
-
if
|
| 2379 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2380 |
else:
|
| 2381 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2382 |
|
| 2383 |
generated_token_ids.append(next_token_id)
|
| 2384 |
generated_tokens.append(next_token_text)
|
| 2385 |
|
| 2386 |
-
# Capture top-
|
| 2387 |
# Use log_softmax for numerical stability at low temperatures
|
| 2388 |
-
|
| 2389 |
-
top_k = 5
|
| 2390 |
-
_, top_indices = torch.topk(logits, k=min(top_k, len(logits)))
|
| 2391 |
|
| 2392 |
# Use log_softmax (numerically stable) then exp() for probabilities
|
| 2393 |
-
# This avoids underflow that occurs with softmax at low temperatures
|
| 2394 |
-
# Note: logits is ALREADY temperature-scaled above, so no need to divide again
|
| 2395 |
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
|
| 2396 |
top_probs = torch.exp(log_probs[top_indices])
|
| 2397 |
|
| 2398 |
alternatives = []
|
| 2399 |
-
|
| 2400 |
-
|
| 2401 |
-
|
| 2402 |
-
|
| 2403 |
-
|
| 2404 |
-
|
|
|
|
| 2405 |
alternatives.append({
|
| 2406 |
"token": token_text,
|
| 2407 |
"token_id": idx,
|
| 2408 |
"probability": prob,
|
| 2409 |
-
"
|
|
|
|
|
|
|
|
|
|
| 2410 |
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2411 |
token_alternatives_by_step.append({
|
| 2412 |
"step": step,
|
| 2413 |
"selected_token": next_token_text,
|
| 2414 |
"selected_token_id": next_token_id,
|
| 2415 |
-
"alternatives": alternatives
|
|
|
|
|
|
|
| 2416 |
})
|
| 2417 |
|
| 2418 |
# === STAGE 3: EXTRACTING (per layer within each token) ===
|
|
|
|
| 1617 |
max_tokens = request.get("max_tokens", 8)
|
| 1618 |
auto_complete = request.get("auto_complete", False)
|
| 1619 |
temperature = request.get("temperature", 0.7)
|
| 1620 |
+
top_k_param = request.get("top_k", None) # Top-k sampling parameter
|
| 1621 |
+
top_p_param = request.get("top_p", None) # Top-p (nucleus) sampling parameter
|
| 1622 |
|
| 1623 |
# If auto_complete mode, ensure we have a reasonable upper limit
|
| 1624 |
if auto_complete:
|
|
|
|
| 1793 |
)
|
| 1794 |
|
| 1795 |
# Get logits for next token
|
| 1796 |
+
raw_logits = outputs.logits[0, -1, :].clone() # Clone raw logits before any scaling
|
| 1797 |
|
| 1798 |
+
# Capture raw logits for top-10 tokens (before temperature scaling)
|
| 1799 |
+
import math
|
| 1800 |
+
top_n_display = 10 # Get top 10 alternatives for display
|
| 1801 |
+
top_raw_logits, top_raw_indices = torch.topk(raw_logits, k=min(top_n_display, len(raw_logits)))
|
| 1802 |
+
|
| 1803 |
+
# Build raw logits entries (before temperature)
|
| 1804 |
+
logits_entries = []
|
| 1805 |
+
for rank, (logit_val, idx) in enumerate(zip(top_raw_logits.tolist(), top_raw_indices.tolist())):
|
| 1806 |
+
token_text = manager.tokenizer.decode([idx], skip_special_tokens=False)
|
| 1807 |
+
logits_entries.append({
|
| 1808 |
+
"token": token_text,
|
| 1809 |
+
"token_id": idx,
|
| 1810 |
+
"logit": logit_val,
|
| 1811 |
+
"rank": rank + 1
|
| 1812 |
+
})
|
| 1813 |
+
|
| 1814 |
+
# Greedy token (argmax of raw logits, before any sampling)
|
| 1815 |
+
greedy_token_id = torch.argmax(raw_logits).item()
|
| 1816 |
+
greedy_token = manager.tokenizer.decode([greedy_token_id], skip_special_tokens=False)
|
| 1817 |
+
|
| 1818 |
+
# Compute raw probabilities (T=1) for comparison visualization
|
| 1819 |
+
raw_probs = torch.softmax(raw_logits, dim=0)
|
| 1820 |
+
|
| 1821 |
+
# Apply temperature scaling
|
| 1822 |
+
logits = raw_logits.clone()
|
| 1823 |
if temperature > 0:
|
| 1824 |
logits = logits / temperature
|
| 1825 |
probs = torch.softmax(logits, dim=0)
|
| 1826 |
|
| 1827 |
+
# Apply top-k filtering if specified
|
| 1828 |
+
if top_k_param is not None and top_k_param > 0:
|
| 1829 |
+
top_k_probs, top_k_indices = torch.topk(probs, k=min(top_k_param, len(probs)))
|
| 1830 |
+
probs_filtered = torch.zeros_like(probs)
|
| 1831 |
+
probs_filtered[top_k_indices] = top_k_probs
|
| 1832 |
+
probs_filtered = probs_filtered / probs_filtered.sum() # Renormalize
|
| 1833 |
+
else:
|
| 1834 |
+
probs_filtered = probs
|
| 1835 |
+
|
| 1836 |
+
# Apply top-p (nucleus) filtering if specified
|
| 1837 |
+
if top_p_param is not None and top_p_param < 1.0:
|
| 1838 |
+
sorted_probs, sorted_indices = torch.sort(probs_filtered, descending=True)
|
| 1839 |
+
cumulative_probs = torch.cumsum(sorted_probs, dim=0)
|
| 1840 |
+
# Find cutoff index where cumulative exceeds top_p
|
| 1841 |
+
cutoff_mask = cumulative_probs > top_p_param
|
| 1842 |
+
# Shift mask by 1 to keep at least one token
|
| 1843 |
+
cutoff_mask[1:] = cutoff_mask[:-1].clone()
|
| 1844 |
+
cutoff_mask[0] = False
|
| 1845 |
+
# Zero out tokens beyond cutoff
|
| 1846 |
+
sorted_probs[cutoff_mask] = 0
|
| 1847 |
+
# Scatter back to original order
|
| 1848 |
+
probs_filtered = torch.zeros_like(probs)
|
| 1849 |
+
probs_filtered.scatter_(0, sorted_indices, sorted_probs)
|
| 1850 |
+
if probs_filtered.sum() > 0:
|
| 1851 |
+
probs_filtered = probs_filtered / probs_filtered.sum() # Renormalize
|
| 1852 |
+
|
| 1853 |
if temperature == 0:
|
| 1854 |
+
next_token_id = torch.argmax(probs_filtered, dim=-1).item()
|
| 1855 |
else:
|
| 1856 |
+
# Ensure valid probability distribution for multinomial
|
| 1857 |
+
if probs_filtered.sum() > 0:
|
| 1858 |
+
next_token_id = torch.multinomial(probs_filtered, 1).item()
|
| 1859 |
+
else:
|
| 1860 |
+
next_token_id = torch.argmax(probs, dim=-1).item()
|
| 1861 |
next_token_text = manager.tokenizer.decode([next_token_id], skip_special_tokens=False)
|
| 1862 |
|
| 1863 |
generated_token_ids.append(next_token_id)
|
| 1864 |
generated_tokens.append(next_token_text)
|
| 1865 |
|
| 1866 |
+
# Capture top-10 token alternatives with probabilities
|
| 1867 |
# Use log_softmax for numerical stability at low temperatures
|
| 1868 |
+
_, top_indices = torch.topk(logits, k=min(top_n_display, len(logits)))
|
|
|
|
|
|
|
| 1869 |
|
| 1870 |
# Use log_softmax (numerically stable) then exp() for probabilities
|
| 1871 |
# This avoids underflow that occurs with softmax at low temperatures
|
|
|
|
| 1874 |
top_probs = torch.exp(log_probs[top_indices])
|
| 1875 |
|
| 1876 |
alternatives = []
|
| 1877 |
+
cumulative = 0.0
|
| 1878 |
+
selected_in_top = False
|
| 1879 |
+
for rank, (prob, idx) in enumerate(zip(top_probs.tolist(), top_indices.tolist())):
|
| 1880 |
token_text = manager.tokenizer.decode([idx], skip_special_tokens=False)
|
| 1881 |
+
cumulative += prob
|
| 1882 |
+
if idx == next_token_id:
|
| 1883 |
+
selected_in_top = True
|
| 1884 |
alternatives.append({
|
| 1885 |
"token": token_text,
|
| 1886 |
"token_id": idx,
|
| 1887 |
"probability": prob,
|
| 1888 |
+
"raw_probability": raw_probs[idx].item(), # T=1 probability for comparison
|
| 1889 |
+
"log_probability": math.log(prob) if prob > 0 else float('-inf'),
|
| 1890 |
+
"cumulative_probability": cumulative,
|
| 1891 |
+
"rank": rank + 1
|
| 1892 |
})
|
| 1893 |
+
|
| 1894 |
+
# If selected token is not in top-N, add it with its actual probability
|
| 1895 |
+
if not selected_in_top:
|
| 1896 |
+
selected_prob = probs[next_token_id].item()
|
| 1897 |
+
selected_raw_prob = raw_probs[next_token_id].item()
|
| 1898 |
+
selected_log_prob = log_probs[next_token_id].item()
|
| 1899 |
+
selected_logit = raw_logits[next_token_id].item()
|
| 1900 |
+
# Find the rank of the selected token
|
| 1901 |
+
sorted_indices = torch.argsort(raw_logits, descending=True)
|
| 1902 |
+
selected_rank = (sorted_indices == next_token_id).nonzero(as_tuple=True)[0].item() + 1
|
| 1903 |
+
alternatives.append({
|
| 1904 |
+
"token": next_token_text,
|
| 1905 |
+
"token_id": next_token_id,
|
| 1906 |
+
"probability": selected_prob,
|
| 1907 |
+
"raw_probability": selected_raw_prob, # T=1 probability for comparison
|
| 1908 |
+
"log_probability": selected_log_prob,
|
| 1909 |
+
"cumulative_probability": None, # Not in sequence
|
| 1910 |
+
"rank": selected_rank,
|
| 1911 |
+
"is_selected_outlier": True # Flag for UI
|
| 1912 |
+
})
|
| 1913 |
+
# Also add to logits if not present
|
| 1914 |
+
if next_token_id not in [e["token_id"] for e in logits_entries]:
|
| 1915 |
+
logits_entries.append({
|
| 1916 |
+
"token": next_token_text,
|
| 1917 |
+
"token_id": next_token_id,
|
| 1918 |
+
"logit": selected_logit,
|
| 1919 |
+
"rank": selected_rank,
|
| 1920 |
+
"is_selected_outlier": True
|
| 1921 |
+
})
|
| 1922 |
+
|
| 1923 |
+
# Build sampling metadata
|
| 1924 |
+
sampling_metadata = {
|
| 1925 |
+
"temperature": temperature,
|
| 1926 |
+
"top_k": top_k_param,
|
| 1927 |
+
"top_p": top_p_param,
|
| 1928 |
+
"greedy_token_id": greedy_token_id,
|
| 1929 |
+
"greedy_token": greedy_token,
|
| 1930 |
+
"was_greedy": next_token_id == greedy_token_id
|
| 1931 |
+
}
|
| 1932 |
+
|
| 1933 |
token_alternatives_by_step.append({
|
| 1934 |
"step": step,
|
| 1935 |
"selected_token": next_token_text,
|
| 1936 |
"selected_token_id": next_token_id,
|
| 1937 |
+
"alternatives": alternatives,
|
| 1938 |
+
"logits": logits_entries,
|
| 1939 |
+
"sampling": sampling_metadata
|
| 1940 |
})
|
| 1941 |
|
| 1942 |
# Process attention and hidden states for ALL layers
|
|
|
|
| 2272 |
max_tokens = request.get("max_tokens", 8)
|
| 2273 |
auto_complete = request.get("auto_complete", False)
|
| 2274 |
temperature = request.get("temperature", 0.7)
|
| 2275 |
+
top_k_param = request.get("top_k", None) # Top-k sampling parameter
|
| 2276 |
+
top_p_param = request.get("top_p", None) # Top-p (nucleus) sampling parameter
|
| 2277 |
|
| 2278 |
# If auto_complete mode, ensure we have a reasonable upper limit
|
| 2279 |
if auto_complete:
|
|
|
|
| 2468 |
)
|
| 2469 |
|
| 2470 |
# Get logits for next token
|
| 2471 |
+
raw_logits = outputs.logits[0, -1, :].clone() # Clone raw logits before any scaling
|
| 2472 |
+
|
| 2473 |
+
# Capture raw logits for top-10 tokens (before temperature scaling)
|
| 2474 |
+
import math as math_module
|
| 2475 |
+
top_n_display = 10 # Get top 10 alternatives for display
|
| 2476 |
+
top_raw_logits, top_raw_indices = torch.topk(raw_logits, k=min(top_n_display, len(raw_logits)))
|
| 2477 |
|
| 2478 |
+
# Build raw logits entries (before temperature)
|
| 2479 |
+
# Use correct tokenizer for Devstral vs other models
|
| 2480 |
+
def decode_token(tid):
|
| 2481 |
+
if manager.model_id == "devstral-small" and manager.mistral_tokenizer is not None:
|
| 2482 |
+
return manager.mistral_tokenizer.decode_token(tid)
|
| 2483 |
+
else:
|
| 2484 |
+
return manager.tokenizer.decode([tid], skip_special_tokens=False)
|
| 2485 |
+
|
| 2486 |
+
logits_entries = []
|
| 2487 |
+
for rank, (logit_val, idx) in enumerate(zip(top_raw_logits.tolist(), top_raw_indices.tolist())):
|
| 2488 |
+
token_text = decode_token(idx)
|
| 2489 |
+
logits_entries.append({
|
| 2490 |
+
"token": token_text,
|
| 2491 |
+
"token_id": idx,
|
| 2492 |
+
"logit": logit_val,
|
| 2493 |
+
"rank": rank + 1
|
| 2494 |
+
})
|
| 2495 |
+
|
| 2496 |
+
# Greedy token (argmax of raw logits, before any sampling)
|
| 2497 |
+
greedy_token_id = torch.argmax(raw_logits).item()
|
| 2498 |
+
greedy_token = decode_token(greedy_token_id)
|
| 2499 |
+
|
| 2500 |
+
# Compute raw probabilities (T=1) for comparison visualization
|
| 2501 |
+
raw_probs = torch.softmax(raw_logits, dim=0)
|
| 2502 |
+
|
| 2503 |
+
# Apply temperature scaling
|
| 2504 |
+
logits = raw_logits.clone()
|
| 2505 |
if temperature > 0:
|
| 2506 |
logits = logits / temperature
|
| 2507 |
probs = torch.softmax(logits, dim=0)
|
| 2508 |
|
| 2509 |
+
# Apply top-k filtering if specified
|
| 2510 |
+
if top_k_param is not None and top_k_param > 0:
|
| 2511 |
+
top_k_probs, top_k_indices = torch.topk(probs, k=min(top_k_param, len(probs)))
|
| 2512 |
+
probs_filtered = torch.zeros_like(probs)
|
| 2513 |
+
probs_filtered[top_k_indices] = top_k_probs
|
| 2514 |
+
probs_filtered = probs_filtered / probs_filtered.sum() # Renormalize
|
| 2515 |
else:
|
| 2516 |
+
probs_filtered = probs
|
| 2517 |
|
| 2518 |
+
# Apply top-p (nucleus) filtering if specified
|
| 2519 |
+
if top_p_param is not None and top_p_param < 1.0:
|
| 2520 |
+
sorted_probs, sorted_indices = torch.sort(probs_filtered, descending=True)
|
| 2521 |
+
cumulative_probs = torch.cumsum(sorted_probs, dim=0)
|
| 2522 |
+
cutoff_mask = cumulative_probs > top_p_param
|
| 2523 |
+
cutoff_mask[1:] = cutoff_mask[:-1].clone()
|
| 2524 |
+
cutoff_mask[0] = False
|
| 2525 |
+
sorted_probs[cutoff_mask] = 0
|
| 2526 |
+
probs_filtered = torch.zeros_like(probs)
|
| 2527 |
+
probs_filtered.scatter_(0, sorted_indices, sorted_probs)
|
| 2528 |
+
if probs_filtered.sum() > 0:
|
| 2529 |
+
probs_filtered = probs_filtered / probs_filtered.sum()
|
| 2530 |
+
|
| 2531 |
+
if temperature == 0:
|
| 2532 |
+
next_token_id = torch.argmax(probs_filtered, dim=-1).item()
|
| 2533 |
else:
|
| 2534 |
+
if probs_filtered.sum() > 0:
|
| 2535 |
+
next_token_id = torch.multinomial(probs_filtered, 1).item()
|
| 2536 |
+
else:
|
| 2537 |
+
next_token_id = torch.argmax(probs, dim=-1).item()
|
| 2538 |
+
|
| 2539 |
+
next_token_text = decode_token(next_token_id)
|
| 2540 |
|
| 2541 |
generated_token_ids.append(next_token_id)
|
| 2542 |
generated_tokens.append(next_token_text)
|
| 2543 |
|
| 2544 |
+
# Capture top-10 token alternatives with probabilities
|
| 2545 |
# Use log_softmax for numerical stability at low temperatures
|
| 2546 |
+
_, top_indices = torch.topk(logits, k=min(top_n_display, len(logits)))
|
|
|
|
|
|
|
| 2547 |
|
| 2548 |
# Use log_softmax (numerically stable) then exp() for probabilities
|
|
|
|
|
|
|
| 2549 |
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
|
| 2550 |
top_probs = torch.exp(log_probs[top_indices])
|
| 2551 |
|
| 2552 |
alternatives = []
|
| 2553 |
+
cumulative = 0.0
|
| 2554 |
+
selected_in_top = False
|
| 2555 |
+
for rank, (prob, idx) in enumerate(zip(top_probs.tolist(), top_indices.tolist())):
|
| 2556 |
+
token_text = decode_token(idx)
|
| 2557 |
+
cumulative += prob
|
| 2558 |
+
if idx == next_token_id:
|
| 2559 |
+
selected_in_top = True
|
| 2560 |
alternatives.append({
|
| 2561 |
"token": token_text,
|
| 2562 |
"token_id": idx,
|
| 2563 |
"probability": prob,
|
| 2564 |
+
"raw_probability": raw_probs[idx].item(), # T=1 probability for comparison
|
| 2565 |
+
"log_probability": math_module.log(prob) if prob > 0 else float('-inf'),
|
| 2566 |
+
"cumulative_probability": cumulative,
|
| 2567 |
+
"rank": rank + 1
|
| 2568 |
})
|
| 2569 |
+
|
| 2570 |
+
# If selected token is not in top-N, add it with its actual probability
|
| 2571 |
+
if not selected_in_top:
|
| 2572 |
+
selected_prob = probs[next_token_id].item()
|
| 2573 |
+
selected_raw_prob = raw_probs[next_token_id].item()
|
| 2574 |
+
selected_log_prob = log_probs[next_token_id].item()
|
| 2575 |
+
selected_logit = raw_logits[next_token_id].item()
|
| 2576 |
+
# Find the rank of the selected token
|
| 2577 |
+
sorted_indices = torch.argsort(raw_logits, descending=True)
|
| 2578 |
+
selected_rank = (sorted_indices == next_token_id).nonzero(as_tuple=True)[0].item() + 1
|
| 2579 |
+
alternatives.append({
|
| 2580 |
+
"token": next_token_text,
|
| 2581 |
+
"token_id": next_token_id,
|
| 2582 |
+
"probability": selected_prob,
|
| 2583 |
+
"raw_probability": selected_raw_prob, # T=1 probability for comparison
|
| 2584 |
+
"log_probability": selected_log_prob,
|
| 2585 |
+
"cumulative_probability": None,
|
| 2586 |
+
"rank": selected_rank,
|
| 2587 |
+
"is_selected_outlier": True
|
| 2588 |
+
})
|
| 2589 |
+
# Also add to logits if not present
|
| 2590 |
+
if next_token_id not in [e["token_id"] for e in logits_entries]:
|
| 2591 |
+
logits_entries.append({
|
| 2592 |
+
"token": next_token_text,
|
| 2593 |
+
"token_id": next_token_id,
|
| 2594 |
+
"logit": selected_logit,
|
| 2595 |
+
"rank": selected_rank,
|
| 2596 |
+
"is_selected_outlier": True
|
| 2597 |
+
})
|
| 2598 |
+
|
| 2599 |
+
# Build sampling metadata
|
| 2600 |
+
sampling_metadata = {
|
| 2601 |
+
"temperature": temperature,
|
| 2602 |
+
"top_k": top_k_param,
|
| 2603 |
+
"top_p": top_p_param,
|
| 2604 |
+
"greedy_token_id": greedy_token_id,
|
| 2605 |
+
"greedy_token": greedy_token,
|
| 2606 |
+
"was_greedy": next_token_id == greedy_token_id
|
| 2607 |
+
}
|
| 2608 |
+
|
| 2609 |
token_alternatives_by_step.append({
|
| 2610 |
"step": step,
|
| 2611 |
"selected_token": next_token_text,
|
| 2612 |
"selected_token_id": next_token_id,
|
| 2613 |
+
"alternatives": alternatives,
|
| 2614 |
+
"logits": logits_entries,
|
| 2615 |
+
"sampling": sampling_metadata
|
| 2616 |
})
|
| 2617 |
|
| 2618 |
# === STAGE 3: EXTRACTING (per layer within each token) ===
|