Spaces:
Sleeping
Sleeping
Add per-token alternatives + hover tooltip UI
Browse files- server.py +62 -7
- static/css/styles.css +46 -1
- static/js/main.js +92 -3
server.py
CHANGED
|
@@ -97,26 +97,81 @@ class WordPredictionRequest(BaseModel):
|
|
| 97 |
@app.post("/api/chat")
|
| 98 |
def chat_generate(req: ChatRequest):
|
| 99 |
try:
|
|
|
|
| 100 |
prompt = (
|
| 101 |
"<|im_start|>system\nYou are a helpful AI assistant.<|im_end|>\n"
|
| 102 |
f"<|im_start|>user\n{req.message}<|im_end|>\n"
|
| 103 |
"<|im_start|>assistant\n"
|
| 104 |
)
|
| 105 |
inputs = chat_tokenizer(prompt, return_tensors="pt").to(chat_model.device)
|
|
|
|
|
|
|
| 106 |
outputs = chat_model.generate(
|
| 107 |
**inputs,
|
| 108 |
max_new_tokens=req.max_new_tokens,
|
| 109 |
temperature=req.temperature,
|
| 110 |
-
do_sample=
|
| 111 |
-
|
|
|
|
| 112 |
eos_token_id=chat_tokenizer.eos_token_id,
|
| 113 |
pad_token_id=chat_tokenizer.eos_token_id,
|
| 114 |
)
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
except Exception as e:
|
| 121 |
return {"success": False, "error": str(e)}
|
| 122 |
|
|
|
|
| 97 |
@app.post("/api/chat")
|
| 98 |
def chat_generate(req: ChatRequest):
|
| 99 |
try:
|
| 100 |
+
# Build prompt and run generation while requesting per-step scores
|
| 101 |
prompt = (
|
| 102 |
"<|im_start|>system\nYou are a helpful AI assistant.<|im_end|>\n"
|
| 103 |
f"<|im_start|>user\n{req.message}<|im_end|>\n"
|
| 104 |
"<|im_start|>assistant\n"
|
| 105 |
)
|
| 106 |
inputs = chat_tokenizer(prompt, return_tensors="pt").to(chat_model.device)
|
| 107 |
+
|
| 108 |
+
# Generate deterministically (greedy) while returning scores for each generated step
|
| 109 |
outputs = chat_model.generate(
|
| 110 |
**inputs,
|
| 111 |
max_new_tokens=req.max_new_tokens,
|
| 112 |
temperature=req.temperature,
|
| 113 |
+
do_sample=False,
|
| 114 |
+
output_scores=True,
|
| 115 |
+
return_dict_in_generate=True,
|
| 116 |
eos_token_id=chat_tokenizer.eos_token_id,
|
| 117 |
pad_token_id=chat_tokenizer.eos_token_id,
|
| 118 |
)
|
| 119 |
+
|
| 120 |
+
# Full sequence and newly generated token ids
|
| 121 |
+
sequence = outputs.sequences[0]
|
| 122 |
+
start_idx = inputs["input_ids"].size(1)
|
| 123 |
+
generated_ids = sequence[start_idx:].tolist()
|
| 124 |
+
|
| 125 |
+
# Decode the full reply
|
| 126 |
+
reply = chat_tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
|
| 127 |
+
|
| 128 |
+
# Prepare per-token alternatives using the per-step logits/scores
|
| 129 |
+
tokens_info = []
|
| 130 |
+
# outputs.scores is a tuple with one entry per generated step
|
| 131 |
+
if hasattr(outputs, "scores") and outputs.scores is not None:
|
| 132 |
+
for i, logits in enumerate(outputs.scores):
|
| 133 |
+
# logits shape: (batch_size, vocab_size)
|
| 134 |
+
probs = torch.softmax(logits[0], dim=-1)
|
| 135 |
+
chosen_id = generated_ids[i]
|
| 136 |
+
|
| 137 |
+
# Get top-k (we ask for 6 and drop the chosen token if present)
|
| 138 |
+
topk = torch.topk(probs, k=6)
|
| 139 |
+
alts = []
|
| 140 |
+
for idx, val in zip(topk.indices.tolist(), topk.values.tolist()):
|
| 141 |
+
if idx == chosen_id:
|
| 142 |
+
continue
|
| 143 |
+
alts.append({
|
| 144 |
+
"id": idx,
|
| 145 |
+
"token": chat_tokenizer.decode([idx], skip_special_tokens=True).strip(),
|
| 146 |
+
"probability": float(val)
|
| 147 |
+
})
|
| 148 |
+
if len(alts) >= 5:
|
| 149 |
+
break
|
| 150 |
+
|
| 151 |
+
# Fallback: if not enough alts, sample additional highest-prob tokens
|
| 152 |
+
if len(alts) < 5:
|
| 153 |
+
# get full topk of vocab (expensive but rare for short max_new_tokens)
|
| 154 |
+
fallback_topk = torch.topk(probs, k=10)
|
| 155 |
+
for idx, val in zip(fallback_topk.indices.tolist(), fallback_topk.values.tolist()):
|
| 156 |
+
if idx == chosen_id:
|
| 157 |
+
continue
|
| 158 |
+
if any(a["id"] == idx for a in alts):
|
| 159 |
+
continue
|
| 160 |
+
alts.append({
|
| 161 |
+
"id": idx,
|
| 162 |
+
"token": chat_tokenizer.decode([idx], skip_special_tokens=True).strip(),
|
| 163 |
+
"probability": float(val)
|
| 164 |
+
})
|
| 165 |
+
if len(alts) >= 5:
|
| 166 |
+
break
|
| 167 |
+
|
| 168 |
+
tokens_info.append({
|
| 169 |
+
"id": chosen_id,
|
| 170 |
+
"token": chat_tokenizer.decode([chosen_id], skip_special_tokens=True).strip(),
|
| 171 |
+
"alternatives": alts
|
| 172 |
+
})
|
| 173 |
+
|
| 174 |
+
return {"success": True, "response": reply, "tokens": tokens_info}
|
| 175 |
except Exception as e:
|
| 176 |
return {"success": False, "error": str(e)}
|
| 177 |
|
static/css/styles.css
CHANGED
|
@@ -409,4 +409,49 @@ textarea:focus {
|
|
| 409 |
background-color: rgba(220, 38, 38, 0.2);
|
| 410 |
border-left: 4px solid rgb(220, 38, 38);
|
| 411 |
margin: 0 auto;
|
| 412 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 409 |
background-color: rgba(220, 38, 38, 0.2);
|
| 410 |
border-left: 4px solid rgb(220, 38, 38);
|
| 411 |
margin: 0 auto;
|
| 412 |
+
}
|
| 413 |
+
|
| 414 |
+
/* Token hover alternatives */
|
| 415 |
+
.generated-text {
|
| 416 |
+
display: inline-block;
|
| 417 |
+
line-height: 1.6;
|
| 418 |
+
}
|
| 419 |
+
.generated-token {
|
| 420 |
+
display: inline-block;
|
| 421 |
+
padding: 2px 4px;
|
| 422 |
+
margin-right: 1px;
|
| 423 |
+
border-radius: 4px;
|
| 424 |
+
cursor: pointer;
|
| 425 |
+
color: var(--text-light);
|
| 426 |
+
}
|
| 427 |
+
.generated-token:hover {
|
| 428 |
+
background: rgba(139,92,246,0.12);
|
| 429 |
+
}
|
| 430 |
+
.alt-tooltip {
|
| 431 |
+
position: absolute;
|
| 432 |
+
display: none;
|
| 433 |
+
min-width: 160px;
|
| 434 |
+
background: linear-gradient(180deg, #111827, #0b1220);
|
| 435 |
+
color: var(--text-light);
|
| 436 |
+
border: 1px solid rgba(139,92,246,0.18);
|
| 437 |
+
border-radius: 8px;
|
| 438 |
+
padding: 8px;
|
| 439 |
+
box-shadow: 0 8px 24px rgba(2,6,23,0.6);
|
| 440 |
+
z-index: 2000;
|
| 441 |
+
}
|
| 442 |
+
.alt-title {
|
| 443 |
+
font-weight: 600;
|
| 444 |
+
font-size: 0.9rem;
|
| 445 |
+
margin-bottom: 6px;
|
| 446 |
+
opacity: 0.9;
|
| 447 |
+
}
|
| 448 |
+
.alt-row {
|
| 449 |
+
display: flex;
|
| 450 |
+
justify-content: space-between;
|
| 451 |
+
gap: 8px;
|
| 452 |
+
padding: 6px 6px;
|
| 453 |
+
border-radius: 6px;
|
| 454 |
+
}
|
| 455 |
+
.alt-row:hover { background: rgba(255,255,255,0.02); }
|
| 456 |
+
.alt-token { color: var(--text-light); }
|
| 457 |
+
.alt-prob { color: var(--accent-gray); font-size: 0.85rem; }
|
static/js/main.js
CHANGED
|
@@ -148,8 +148,84 @@ async function sendMessage() {
|
|
| 148 |
const data = await response.json();
|
| 149 |
const botMessage = document.createElement('div');
|
| 150 |
botMessage.className = 'message assistant';
|
| 151 |
-
|
| 152 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
|
| 154 |
// Clear input
|
| 155 |
input.value = '';
|
|
@@ -317,4 +393,17 @@ document.getElementById('summary-input').addEventListener('keypress', (e) => {
|
|
| 317 |
e.preventDefault();
|
| 318 |
generateSummary();
|
| 319 |
}
|
| 320 |
-
});
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
const data = await response.json();
|
| 149 |
const botMessage = document.createElement('div');
|
| 150 |
botMessage.className = 'message assistant';
|
| 151 |
+
|
| 152 |
+
// If the server returned per-token info, render tokens individually so we
|
| 153 |
+
// can show alternative tokens on hover. Otherwise, fall back to plain text.
|
| 154 |
+
if (data.tokens && Array.isArray(data.tokens) && data.tokens.length > 0) {
|
| 155 |
+
const frag = document.createDocumentFragment();
|
| 156 |
+
const wrapper = document.createElement('div');
|
| 157 |
+
wrapper.className = 'generated-text';
|
| 158 |
+
|
| 159 |
+
data.tokens.forEach((t, idx) => {
|
| 160 |
+
const span = document.createElement('span');
|
| 161 |
+
span.className = 'generated-token';
|
| 162 |
+
span.setAttribute('data-token-index', idx);
|
| 163 |
+
span.textContent = t.token || '';
|
| 164 |
+
// store alternatives on the element for quick access
|
| 165 |
+
span._alternatives = t.alternatives || [];
|
| 166 |
+
wrapper.appendChild(span);
|
| 167 |
+
});
|
| 168 |
+
|
| 169 |
+
frag.appendChild(wrapper);
|
| 170 |
+
botMessage.appendChild(frag);
|
| 171 |
+
chatOutput.appendChild(botMessage);
|
| 172 |
+
|
| 173 |
+
// Tooltip element for showing alternatives
|
| 174 |
+
let tooltip = document.getElementById('alt-tooltip');
|
| 175 |
+
if (!tooltip) {
|
| 176 |
+
tooltip = document.createElement('div');
|
| 177 |
+
tooltip.id = 'alt-tooltip';
|
| 178 |
+
tooltip.className = 'alt-tooltip';
|
| 179 |
+
document.body.appendChild(tooltip);
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
// Attach hover listeners
|
| 183 |
+
wrapper.querySelectorAll('.generated-token').forEach(el => {
|
| 184 |
+
el.addEventListener('mouseenter', (ev) => {
|
| 185 |
+
const alts = el._alternatives || [];
|
| 186 |
+
if (!alts.length) return;
|
| 187 |
+
// build tooltip html
|
| 188 |
+
tooltip.innerHTML = '';
|
| 189 |
+
const title = document.createElement('div');
|
| 190 |
+
title.className = 'alt-title';
|
| 191 |
+
title.textContent = 'Alternatives';
|
| 192 |
+
tooltip.appendChild(title);
|
| 193 |
+
|
| 194 |
+
alts.forEach(a => {
|
| 195 |
+
const row = document.createElement('div');
|
| 196 |
+
row.className = 'alt-row';
|
| 197 |
+
const tok = document.createElement('span');
|
| 198 |
+
tok.className = 'alt-token';
|
| 199 |
+
tok.textContent = a.token || '';
|
| 200 |
+
const prob = document.createElement('span');
|
| 201 |
+
prob.className = 'alt-prob';
|
| 202 |
+
prob.textContent = `${(a.probability * 100).toFixed(2)}%`;
|
| 203 |
+
row.appendChild(tok);
|
| 204 |
+
row.appendChild(prob);
|
| 205 |
+
// click to insert token into input (optional UX)
|
| 206 |
+
row.addEventListener('click', () => {
|
| 207 |
+
const chatInput = document.getElementById('chat-input');
|
| 208 |
+
insertAtCursor(chatInput, a.token || '');
|
| 209 |
+
});
|
| 210 |
+
tooltip.appendChild(row);
|
| 211 |
+
});
|
| 212 |
+
|
| 213 |
+
// Position tooltip near the hovered token
|
| 214 |
+
const rect = el.getBoundingClientRect();
|
| 215 |
+
tooltip.style.display = 'block';
|
| 216 |
+
tooltip.style.left = `${rect.left + window.scrollX}px`;
|
| 217 |
+
tooltip.style.top = `${rect.bottom + window.scrollY + 6}px`;
|
| 218 |
+
});
|
| 219 |
+
|
| 220 |
+
el.addEventListener('mouseleave', () => {
|
| 221 |
+
const tooltip = document.getElementById('alt-tooltip');
|
| 222 |
+
if (tooltip) tooltip.style.display = 'none';
|
| 223 |
+
});
|
| 224 |
+
});
|
| 225 |
+
} else {
|
| 226 |
+
botMessage.textContent = data.response || 'Sorry, I could not process your request.';
|
| 227 |
+
chatOutput.appendChild(botMessage);
|
| 228 |
+
}
|
| 229 |
|
| 230 |
// Clear input
|
| 231 |
input.value = '';
|
|
|
|
| 393 |
e.preventDefault();
|
| 394 |
generateSummary();
|
| 395 |
}
|
| 396 |
+
});
|
| 397 |
+
|
| 398 |
+
// Helper to insert text at the cursor position for input/textarea
|
| 399 |
+
function insertAtCursor(el, text) {
|
| 400 |
+
if (!el) return;
|
| 401 |
+
const start = typeof el.selectionStart === 'number' ? el.selectionStart : el.value.length;
|
| 402 |
+
const end = typeof el.selectionEnd === 'number' ? el.selectionEnd : start;
|
| 403 |
+
const before = el.value.substring(0, start);
|
| 404 |
+
const after = el.value.substring(end);
|
| 405 |
+
el.value = before + text + after;
|
| 406 |
+
const pos = before.length + text.length;
|
| 407 |
+
el.selectionStart = el.selectionEnd = pos;
|
| 408 |
+
el.focus();
|
| 409 |
+
}
|