| import gradio as gr | |
| import torch | |
| import html as html_lib | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
# One-time model setup at import: load the paraphrase checkpoint, switch to
# inference mode, and place it on GPU when one is available.
tokenizer = AutoTokenizer.from_pretrained("BigSalmon/InformalToFormalLincoln123Paraphrase")
model = AutoModelForCausalLM.from_pretrained("BigSalmon/InformalToFormalLincoln123Paraphrase")
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
def get_color(p):
    """Map a probability to a (text color, background color) pair.

    The hue runs from 0 (red, p == 0) up to 120 (green, p >= 1) and is
    clamped at 120 for probabilities above 1.

    Args:
        p: probability-like number.

    Returns:
        Tuple of two CSS color strings: an opaque ``hsl(...)`` foreground
        and a translucent ``hsla(...)`` background.
    """
    hue = p * 120
    if hue > 120:
        hue = 120
    foreground = f"hsl({hue},80%,35%)"
    background = f"hsla({hue},80%,50%,0.15)"
    return foreground, background
def analyze_text(text, top_k):
    """Render a per-token probability visualization of `text` as HTML.

    Every token except the first (which has no left context) is colored by
    the probability the model assigned to it given its prefix, and carries a
    tooltip listing the model's top-`top_k` alternative tokens for that
    position. Tooltips open on hover and can be pinned by clicking (handled
    by the embedded <script>).

    Args:
        text: raw input string; blank input yields a placeholder <p>.
        top_k: number of alternatives per token (int-coerced, floored at 1).

    Returns:
        A self-contained HTML string (inline <style> + <script> + markup).
    """
    top_k = max(1, int(top_k))
    if not text.strip():
        return "<p style='color:#999;text-align:center;padding:40px'>Paste some text and click Analyze.</p>"
    tokens = tokenizer.encode(text)
    # Hard cap on sequence length. NOTE(review): 512 may be below the model's
    # actual context window — confirm before relying on this bound.
    if len(tokens) > 512:
        tokens = tokens[:512]
    with torch.no_grad():
        input_ids = torch.tensor([tokens]).to(device)
        # Single forward pass; keep logits for every position, moved to CPU.
        all_logits = model(input_ids).logits[0].cpu()
    # Inline stylesheet plus the click-to-pin tooltip script, emitted once
    # ahead of the token markup.
    css = """<style>
.tc{display:flex;flex-wrap:wrap;gap:5px;padding:20px;line-height:2.4;font-family:'Segoe UI',sans-serif}
.tw{position:relative;display:inline-block}
.tk{padding:4px 7px;border-radius:6px;cursor:pointer;font-size:15px;transition:.2s;border:1px solid transparent;user-select:none}
.tw:hover .tk{transform:translateY(-2px);box-shadow:0 4px 14px rgba(0,0,0,.18);border-color:#999}
.tt{display:none;position:absolute;bottom:calc(100% + 8px);left:50%;transform:translateX(-50%);
background:#1a1a2e;color:#eee;padding:14px;border-radius:12px;font-size:13px;z-index:9999;
box-shadow:0 10px 30px rgba(0,0,0,.35);min-width:220px;max-height:350px;overflow-y:auto}
.tt::after{content:'';position:absolute;top:100%;left:0;width:100%;height:12px}
.tw:hover .tt{display:block}
.tw.pinned .tt{display:block}
.tw.pinned .tk{transform:translateY(-2px);box-shadow:0 4px 14px rgba(0,0,0,.18);border-color:#999;outline:2px solid #7fdbca}
.th{font-weight:700;font-size:14px;color:#7fdbca;border-bottom:1px solid #333;padding-bottom:6px;margin-bottom:6px}
.tp{color:#ffd700;margin-bottom:8px}
.at{color:#ff79c6;font-size:10px;text-transform:uppercase;letter-spacing:1px;margin-bottom:4px}
.aw{display:flex;justify-content:space-between;padding:2px 0;font-size:12px}
.aw .w{color:#c3cee3}.aw .p{color:#666;margin-left:14px}
.hi{font-weight:700;color:#7fdbca!important}
</style>
<script>
document.addEventListener('click', function(e) {
const tk = e.target.closest('.tk');
const tw = tk ? tk.closest('.tw') : null;
if (tw) {
const wasPinned = tw.classList.contains('pinned');
document.querySelectorAll('.tw.pinned').forEach(el => el.classList.remove('pinned'));
if (!wasPinned) tw.classList.add('pinned');
} else if (!e.target.closest('.tt')) {
document.querySelectorAll('.tw.pinned').forEach(el => el.classList.remove('pinned'));
}
});
</script>"""
    parts = [css, '<div class="tc">']
    for i in range(len(tokens)):
        tok = html_lib.escape(tokenizer.decode([tokens[i]]))
        if i == 0:
            # The first token has no preceding context, so no probability is
            # defined for it: render it grey, without a tooltip.
            parts.append(f'<div class="tw"><span class="tk" style="background:rgba(128,128,128,.1);color:#888">{tok}</span></div>')
            continue
        # Distribution over the vocabulary at position i, as predicted from
        # tokens[:i] (logits at position i-1 score the token at position i).
        probs = torch.softmax(all_logits[i - 1], dim=-1)
        actual_p = probs[tokens[i]].item()
        # NOTE(review): topk raises if top_k exceeds the vocab size; the UI
        # caps the input at 200, which should be safe — confirm for the model.
        top_p, top_idx = probs.topk(top_k)
        color, bg = get_color(actual_p)
        rank = None  # 1-based rank of the actual token within the top-k, if found
        alts = ""
        for j in range(top_k):
            a_text = html_lib.escape(tokenizer.decode([top_idx[j].item()]))
            a_p = top_p[j].item()
            # Does this alternative match the token actually present in the text?
            hit = top_idx[j].item() == tokens[i]
            if hit: rank = j + 1
            cls = ' class="w hi"' if hit else ' class="w"'
            pcls = ' class="p hi"' if hit else ' class="p"'
            alts += f'<div class="aw"><span{cls}>{a_text}</span><span{pcls}>{a_p:.4f}</span></div>'
        rank_s = f"rank #{rank}" if rank else f"rank >{top_k}"
        tooltip = f'''<div class="tt">
<div class="th">“{tok}”</div>
<div class="tp">P = {actual_p:.4f} ({rank_s})</div>
<div class="at">Top {top_k} alternatives</div>{alts}</div>'''
        parts.append(f'<div class="tw"><span class="tk" style="background:{bg};color:{color}">{tok}</span>{tooltip}</div>')
    parts.append('</div>')
    return ''.join(parts)
def predict_next(text, num_candidates):
    """Render a table of the model's most likely next tokens as HTML.

    Runs one forward pass over `text` and reads the probability distribution
    at the final position only.

    Args:
        text: prompt string; blank input yields a placeholder <p>.
        num_candidates: number of table rows (int-coerced, floored at 1).

    Returns:
        An HTML string: a table with rank, token, probability bar, and
        log-probability columns.
    """
    num_candidates = max(1, int(num_candidates))
    if not text.strip():
        return "<p style='color:#999;text-align:center;padding:40px'>Enter text and click Predict Next.</p>"
    tokens = tokenizer.encode(text)
    # Same truncation policy as analyze_text. NOTE(review): confirm 512
    # matches the model's context window.
    if len(tokens) > 512:
        tokens = tokens[:512]
    with torch.no_grad():
        input_ids = torch.tensor([tokens]).to(device)
        # Only the last position's logits matter for next-token prediction.
        logits = model(input_ids).logits[0, -1].cpu()
    probs = torch.softmax(logits, dim=-1)
    log_probs = torch.log(probs)
    top_p, top_idx = probs.topk(num_candidates)
    top_lp = log_probs[top_idx]
    rows = ""
    for j in range(num_candidates):
        tok_text = html_lib.escape(tokenizer.decode([top_idx[j].item()]))
        p = top_p[j].item()
        lp = top_lp[j].item()
        # Bar width in percent, floored at 1 so tiny probabilities stay visible.
        bar_width = max(1, int(p * 100))
        # Hue 0 (red) .. 120 (green), mirroring get_color's scale.
        hue = min(p * 120, 120)
        rows += f"""<tr>
<td style="padding:6px 12px;font-weight:600;color:#e0e0e0;white-space:nowrap">{j+1}</td>
<td style="padding:6px 12px;font-family:monospace;font-size:15px;color:#7fdbca;white-space:nowrap">{tok_text}</td>
<td style="padding:6px 12px;width:100%">
<div style="background:hsla({hue},80%,50%,0.25);border-radius:4px;height:22px;width:{bar_width}%;min-width:2px;display:flex;align-items:center;padding-left:6px">
<span style="font-size:11px;color:hsl({hue},80%,70%);font-weight:600">{p:.4f}</span>
</div>
</td>
<td style="padding:6px 12px;font-family:monospace;font-size:13px;color:#888;white-space:nowrap">{lp:.4f}</td>
</tr>"""
    # `html` here is a local string; the stdlib module is imported as html_lib,
    # so there is no shadowing of the escape helper.
    html = f"""<div style="font-family:'Segoe UI',sans-serif;background:#1a1a2e;border-radius:12px;padding:16px;overflow-x:auto">
<div style="color:#ff79c6;font-size:11px;text-transform:uppercase;letter-spacing:1px;margin-bottom:10px">
Top {num_candidates} predicted next tokens</div>
<table style="width:100%;border-collapse:collapse">
<thead><tr style="border-bottom:1px solid #333">
<th style="padding:6px 12px;text-align:left;color:#666;font-size:11px">#</th>
<th style="padding:6px 12px;text-align:left;color:#666;font-size:11px">TOKEN</th>
<th style="padding:6px 12px;text-align:left;color:#666;font-size:11px">PROBABILITY</th>
<th style="padding:6px 12px;text-align:left;color:#666;font-size:11px">LOG PROB</th>
</tr></thead>
<tbody>{rows}</tbody>
</table></div>"""
    return html
# --- UI wiring -------------------------------------------------------------
# FIX: `theme=` and `css=` were originally passed to `demo.launch()`, but
# those are `gr.Blocks()` constructor arguments; `launch()` does not accept
# them and raises TypeError at startup. They now go where Gradio expects them.
with gr.Blocks(
    theme=gr.themes.Soft(),
    css="footer{display:none!important}.main{max-width:960px;margin:auto}",
) as demo:
    gr.Markdown("# 🔍 Token Probability Explorer & Predictor\nPaste text, **hover** to preview or **click** a token to pin its tooltip open. Click elsewhere to dismiss.")
    text_input = gr.Textbox(label="Input Text", placeholder="Paste your text here…", lines=5)
    with gr.Row():
        top_k_input = gr.Number(label="# Alternatives (Analysis)", value=10, minimum=1, maximum=200, step=1)
        num_candidates_input = gr.Number(label="# Next Token Candidates", value=10, minimum=1, maximum=200, step=1)
    with gr.Row():
        btn_analyze = gr.Button("Analyze", variant="primary")
        btn_predict = gr.Button("Predict Next", variant="secondary")
    output_analysis = gr.HTML(label="Analysis Output")
    output_prediction = gr.HTML(label="Predicted Next Tokens")
    # Wire buttons to the two analysis entry points defined above.
    btn_analyze.click(fn=analyze_text, inputs=[text_input, top_k_input], outputs=output_analysis)
    btn_predict.click(fn=predict_next, inputs=[text_input, num_candidates_input], outputs=output_prediction)

# Bind to all interfaces on a fixed port (container-friendly).
demo.launch(server_name="0.0.0.0", server_port=7860)