ButterM40 committed on
Commit
5a6a589
·
1 Parent(s): ef0620c

Add per-token alternatives + hover tooltip UI

Browse files
Files changed (3) hide show
  1. server.py +62 -7
  2. static/css/styles.css +46 -1
  3. static/js/main.js +92 -3
server.py CHANGED
@@ -97,26 +97,81 @@ class WordPredictionRequest(BaseModel):
@app.post("/api/chat")
def chat_generate(req: ChatRequest):
    """Sample a chat reply for ``req.message`` (pre-commit version).

    The message is wrapped in a ChatML-style prompt, a continuation is
    sampled from ``chat_model`` (nucleus sampling, ``top_p=0.9``), and the
    decoded text is returned.

    Returns:
        ``{"success": True, "response": <text>}`` on success, or
        ``{"success": False, "error": <message>}`` if anything raises.
    """
    try:
        system_turn = "<|im_start|>system\nYou are a helpful AI assistant.<|im_end|>\n"
        user_turn = f"<|im_start|>user\n{req.message}<|im_end|>\n"
        assistant_header = "<|im_start|>assistant\n"
        chatml_prompt = system_turn + user_turn + assistant_header

        encoded = chat_tokenizer(chatml_prompt, return_tensors="pt").to(chat_model.device)

        sampling_options = {
            "max_new_tokens": req.max_new_tokens,
            "temperature": req.temperature,
            "do_sample": True,
            "top_p": 0.9,
            "eos_token_id": chat_tokenizer.eos_token_id,
            "pad_token_id": chat_tokenizer.eos_token_id,
        }
        full_sequence = chat_model.generate(**encoded, **sampling_options)[0]

        # Everything after the prompt tokens is the newly generated continuation.
        prompt_length = encoded["input_ids"].size(1)
        text = chat_tokenizer.decode(
            full_sequence[prompt_length:], skip_special_tokens=True
        ).strip()
        if text:
            return {"success": True, "response": text}

        # Continuation decoded to nothing: fall back to decoding the whole sequence.
        whole_text = chat_tokenizer.decode(full_sequence, skip_special_tokens=True).strip()
        return {"success": True, "response": whole_text}
    except Exception as exc:
        return {"success": False, "error": str(exc)}
122
 
 
def _top_alternatives(probs, chosen_id, limit=5):
    """Return up to ``limit`` most probable tokens other than ``chosen_id``.

    Args:
        probs: 1-D tensor of next-token probabilities over the vocabulary.
        chosen_id: id of the token that was actually generated at this step.
        limit: maximum number of alternatives to return.

    Returns:
        A list of ``{"id", "token", "probability"}`` dicts, highest
        probability first.
    """
    # Ask for one extra candidate so that dropping the chosen token still
    # leaves ``limit`` entries; clamp to the vocabulary size because
    # torch.topk raises when k exceeds the dimension.
    k = min(limit + 1, probs.size(-1))
    top = torch.topk(probs, k=k)

    alternatives = []
    for token_id, prob in zip(top.indices.tolist(), top.values.tolist()):
        if token_id == chosen_id:
            continue
        alternatives.append({
            "id": token_id,
            "token": chat_tokenizer.decode([token_id], skip_special_tokens=True).strip(),
            "probability": float(prob),
        })
        if len(alternatives) >= limit:
            break
    return alternatives


@app.post("/api/chat")
def chat_generate(req: ChatRequest):
    """Generate a greedy chat reply and report per-token alternatives.

    The message is wrapped in a ChatML-style prompt and decoded greedily
    (``do_sample=False``) while requesting the per-step scores, so that
    each generated token can be paired with the other most probable
    candidates at that step.

    Returns:
        On success ``{"success": True, "response": <text>, "tokens": [...]}``
        where each entry of ``tokens`` is
        ``{"id", "token", "alternatives"}`` and ``alternatives`` holds up to
        five ``{"id", "token", "probability"}`` dicts.
        If anything raises: ``{"success": False, "error": <message>}``.
    """
    try:
        prompt = (
            "<|im_start|>system\nYou are a helpful AI assistant.<|im_end|>\n"
            f"<|im_start|>user\n{req.message}<|im_end|>\n"
            "<|im_start|>assistant\n"
        )
        inputs = chat_tokenizer(prompt, return_tensors="pt").to(chat_model.device)

        # Greedy decoding: ``req.temperature`` is intentionally not forwarded,
        # since temperature only affects sampling and sampling is disabled.
        outputs = chat_model.generate(
            **inputs,
            max_new_tokens=req.max_new_tokens,
            do_sample=False,
            output_scores=True,
            return_dict_in_generate=True,
            eos_token_id=chat_tokenizer.eos_token_id,
            pad_token_id=chat_tokenizer.eos_token_id,
        )

        # Tokens after the prompt are the newly generated ones.
        prompt_length = inputs["input_ids"].size(1)
        generated_ids = outputs.sequences[0][prompt_length:].tolist()
        reply = chat_tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

        tokens_info = []
        # outputs.scores holds one (batch_size, vocab_size) tensor per step.
        step_scores = getattr(outputs, "scores", None)
        if step_scores is not None:
            # zip keeps ids and scores in lockstep and cannot index out of range.
            for chosen_id, logits in zip(generated_ids, step_scores):
                # Softmax in float32 so half-precision scores keep precision.
                probs = torch.softmax(logits[0].float(), dim=-1)
                tokens_info.append({
                    "id": chosen_id,
                    "token": chat_tokenizer.decode(
                        [chosen_id], skip_special_tokens=True
                    ).strip(),
                    "alternatives": _top_alternatives(probs, chosen_id),
                })

        return {"success": True, "response": reply, "tokens": tokens_info}
    except Exception as e:
        # Endpoint boundary: report the failure to the client instead of raising.
        return {"success": False, "error": str(e)}
177
 
static/css/styles.css CHANGED
@@ -409,4 +409,49 @@ textarea:focus {
409
  background-color: rgba(220, 38, 38, 0.2);
410
  border-left: 4px solid rgb(220, 38, 38);
411
  margin: 0 auto;
412
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
409
  background-color: rgba(220, 38, 38, 0.2);
410
  border-left: 4px solid rgb(220, 38, 38);
411
  margin: 0 auto;
412
+ }
413
+
/* ---- Token hover alternatives ---- */

/* Container holding the per-token spans of one generated reply. */
.generated-text {
  display: inline-block;
  line-height: 1.6;
}

/* A single generated token; the pointer cursor hints that it is interactive. */
.generated-token {
  display: inline-block;
  padding: 2px 4px;
  margin-right: 1px;
  border-radius: 4px;
  cursor: pointer;
  color: var(--text-light);
}

.generated-token:hover {
  background: rgba(139, 92, 246, 0.12);
}

/* Floating panel listing alternatives; hidden until display is switched on. */
.alt-tooltip {
  position: absolute;
  display: none;
  min-width: 160px;
  background: linear-gradient(180deg, #111827, #0b1220);
  color: var(--text-light);
  border: 1px solid rgba(139, 92, 246, 0.18);
  border-radius: 8px;
  padding: 8px;
  box-shadow: 0 8px 24px rgba(2, 6, 23, 0.6);
  z-index: 2000;
}

/* Heading line at the top of the tooltip. */
.alt-title {
  font-weight: 600;
  font-size: 0.9rem;
  margin-bottom: 6px;
  opacity: 0.9;
}

/* One alternative: token text on the left, probability on the right. */
.alt-row {
  display: flex;
  justify-content: space-between;
  gap: 8px;
  padding: 6px;
  border-radius: 6px;
}

.alt-row:hover {
  background: rgba(255, 255, 255, 0.02);
}

.alt-token {
  color: var(--text-light);
}

.alt-prob {
  color: var(--accent-gray);
  font-size: 0.85rem;
}
static/js/main.js CHANGED
@@ -148,8 +148,84 @@ async function sendMessage() {
148
  const data = await response.json();
149
  const botMessage = document.createElement('div');
150
  botMessage.className = 'message assistant';
151
- botMessage.textContent = data.response || 'Sorry, I could not process your request.';
152
- chatOutput.appendChild(botMessage);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
  // Clear input
155
  input.value = '';
@@ -317,4 +393,17 @@ document.getElementById('summary-input').addEventListener('keypress', (e) => {
317
  e.preventDefault();
318
  generateSummary();
319
  }
320
- });
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  const data = await response.json();
149
  const botMessage = document.createElement('div');
150
  botMessage.className = 'message assistant';
151
+
152
+ // If the server returned per-token info, render tokens individually so we
153
+ // can show alternative tokens on hover. Otherwise, fall back to plain text.
154
+ if (data.tokens && Array.isArray(data.tokens) && data.tokens.length > 0) {
155
+ const frag = document.createDocumentFragment();
156
+ const wrapper = document.createElement('div');
157
+ wrapper.className = 'generated-text';
158
+
159
+ data.tokens.forEach((t, idx) => {
160
+ const span = document.createElement('span');
161
+ span.className = 'generated-token';
162
+ span.setAttribute('data-token-index', idx);
163
+ span.textContent = t.token || '';
164
+ // store alternatives on the element for quick access
165
+ span._alternatives = t.alternatives || [];
166
+ wrapper.appendChild(span);
167
+ });
168
+
169
+ frag.appendChild(wrapper);
170
+ botMessage.appendChild(frag);
171
+ chatOutput.appendChild(botMessage);
172
+
173
+ // Tooltip element for showing alternatives
174
+ let tooltip = document.getElementById('alt-tooltip');
175
+ if (!tooltip) {
176
+ tooltip = document.createElement('div');
177
+ tooltip.id = 'alt-tooltip';
178
+ tooltip.className = 'alt-tooltip';
179
+ document.body.appendChild(tooltip);
180
+ }
181
+
182
+ // Attach hover listeners
183
+ wrapper.querySelectorAll('.generated-token').forEach(el => {
184
+ el.addEventListener('mouseenter', (ev) => {
185
+ const alts = el._alternatives || [];
186
+ if (!alts.length) return;
187
+ // build tooltip html
188
+ tooltip.innerHTML = '';
189
+ const title = document.createElement('div');
190
+ title.className = 'alt-title';
191
+ title.textContent = 'Alternatives';
192
+ tooltip.appendChild(title);
193
+
194
+ alts.forEach(a => {
195
+ const row = document.createElement('div');
196
+ row.className = 'alt-row';
197
+ const tok = document.createElement('span');
198
+ tok.className = 'alt-token';
199
+ tok.textContent = a.token || '';
200
+ const prob = document.createElement('span');
201
+ prob.className = 'alt-prob';
202
+ prob.textContent = `${(a.probability * 100).toFixed(2)}%`;
203
+ row.appendChild(tok);
204
+ row.appendChild(prob);
205
+ // click to insert token into input (optional UX)
206
+ row.addEventListener('click', () => {
207
+ const chatInput = document.getElementById('chat-input');
208
+ insertAtCursor(chatInput, a.token || '');
209
+ });
210
+ tooltip.appendChild(row);
211
+ });
212
+
213
+ // Position tooltip near the hovered token
214
+ const rect = el.getBoundingClientRect();
215
+ tooltip.style.display = 'block';
216
+ tooltip.style.left = `${rect.left + window.scrollX}px`;
217
+ tooltip.style.top = `${rect.bottom + window.scrollY + 6}px`;
218
+ });
219
+
220
+ el.addEventListener('mouseleave', () => {
221
+ const tooltip = document.getElementById('alt-tooltip');
222
+ if (tooltip) tooltip.style.display = 'none';
223
+ });
224
+ });
225
+ } else {
226
+ botMessage.textContent = data.response || 'Sorry, I could not process your request.';
227
+ chatOutput.appendChild(botMessage);
228
+ }
229
 
230
  // Clear input
231
  input.value = '';
 
393
  e.preventDefault();
394
  generateSummary();
395
  }
396
+ });
397
+
/**
 * Insert `text` into an input/textarea at the caret, replacing any selection.
 *
 * If the element exposes no numeric selection offsets, the text is appended
 * at the end of the current value. Afterwards the caret is placed right after
 * the inserted text and the element is focused.
 *
 * @param {?{value: string, selectionStart?: number, selectionEnd?: number, focus: Function}} el
 *   Target element; nothing happens when it is null/undefined.
 * @param {string} text Text to insert.
 */
function insertAtCursor(el, text) {
  if (!el) {
    return;
  }

  const current = el.value;

  let from = current.length;
  if (typeof el.selectionStart === 'number') {
    from = el.selectionStart;
  }

  let to = from;
  if (typeof el.selectionEnd === 'number') {
    to = el.selectionEnd;
  }

  const head = current.substring(0, from);
  const tail = current.substring(to);
  el.value = head + text + tail;

  // Collapse the selection to a caret just past the inserted text.
  const caret = head.length + text.length;
  el.selectionEnd = caret;
  el.selectionStart = caret;
  el.focus();
}