import gradio as gr
import torch
import html as html_lib
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load the paraphrase fine-tuned causal LM and its tokenizer from the
# HuggingFace Hub (downloads on first run, then uses the local cache).
tokenizer = AutoTokenizer.from_pretrained("BigSalmon/InformalToFormalLincoln123Paraphrase")
model = AutoModelForCausalLM.from_pretrained("BigSalmon/InformalToFormalLincoln123Paraphrase")
model.eval()  # inference only — disables dropout etc.
# Use the GPU when available; inputs are moved to the same device before each forward pass.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
def get_color(p):
    """Map a probability *p* to a pair of CSS colour strings.

    The hue sweeps from red (p == 0) to green (p >= 1), so likely tokens
    render green and unlikely ones red.

    Returns:
        (text_color, background_color): an opaque ``hsl(...)`` foreground
        and a translucent ``hsla(...)`` background sharing the same hue.
    """
    scaled = p * 120
    # Clamp the hue to the green end of the scale for probabilities > 1.
    hue = scaled if scaled <= 120 else 120
    text_color = f"hsl({hue},80%,35%)"
    background_color = f"hsla({hue},80%,50%,0.15)"
    return text_color, background_color
def analyze_text(text, top_k):
    """Render a per-token probability analysis of *text* as one HTML string.

    Each token after the first is coloured by the probability the model
    assigned to it given the preceding tokens; hovering (or clicking to pin)
    a token shows a tooltip with the top-``top_k`` alternative tokens.

    Args:
        text: raw input text, tokenised with the module-level ``tokenizer``.
        top_k: number of alternatives per position (coerced to int, floor 1).

    Returns:
        A self-contained HTML string (inline <style>/<script> + token spans).
    """
    top_k = max(1, int(top_k))  # guard against 0/negative/float UI input
    if not text.strip():
        # Friendly placeholder instead of running the model on empty input.
        return "<p style='color:#999;text-align:center;padding:40px'>Paste some text and click Analyze.</p>"
    tokens = tokenizer.encode(text)
    # NOTE(review): hard cap of 512 tokens — presumably within the model's
    # context window (confirm); longer input is silently truncated at the end.
    if len(tokens) > 512:
        tokens = tokens[:512]
    with torch.no_grad():
        input_ids = torch.tensor([tokens]).to(device)
        # logits[0] -> (seq_len, vocab); logits at position i predict token i+1.
        all_logits = model(input_ids).logits[0].cpu()
    # Inline stylesheet plus a click handler that "pins" at most one tooltip
    # open at a time (clicking elsewhere unpins it).
    css = """<style>
.tc{display:flex;flex-wrap:wrap;gap:5px;padding:20px;line-height:2.4;font-family:'Segoe UI',sans-serif}
.tw{position:relative;display:inline-block}
.tk{padding:4px 7px;border-radius:6px;cursor:pointer;font-size:15px;transition:.2s;border:1px solid transparent;user-select:none}
.tw:hover .tk{transform:translateY(-2px);box-shadow:0 4px 14px rgba(0,0,0,.18);border-color:#999}
.tt{display:none;position:absolute;bottom:calc(100% + 8px);left:50%;transform:translateX(-50%);
background:#1a1a2e;color:#eee;padding:14px;border-radius:12px;font-size:13px;z-index:9999;
box-shadow:0 10px 30px rgba(0,0,0,.35);min-width:220px;max-height:350px;overflow-y:auto}
.tt::after{content:'';position:absolute;top:100%;left:0;width:100%;height:12px}
.tw:hover .tt{display:block}
.tw.pinned .tt{display:block}
.tw.pinned .tk{transform:translateY(-2px);box-shadow:0 4px 14px rgba(0,0,0,.18);border-color:#999;outline:2px solid #7fdbca}
.th{font-weight:700;font-size:14px;color:#7fdbca;border-bottom:1px solid #333;padding-bottom:6px;margin-bottom:6px}
.tp{color:#ffd700;margin-bottom:8px}
.at{color:#ff79c6;font-size:10px;text-transform:uppercase;letter-spacing:1px;margin-bottom:4px}
.aw{display:flex;justify-content:space-between;padding:2px 0;font-size:12px}
.aw .w{color:#c3cee3}.aw .p{color:#666;margin-left:14px}
.hi{font-weight:700;color:#7fdbca!important}
</style>
<script>
document.addEventListener('click', function(e) {
const tk = e.target.closest('.tk');
const tw = tk ? tk.closest('.tw') : null;
if (tw) {
const wasPinned = tw.classList.contains('pinned');
document.querySelectorAll('.tw.pinned').forEach(el => el.classList.remove('pinned'));
if (!wasPinned) tw.classList.add('pinned');
} else if (!e.target.closest('.tt')) {
document.querySelectorAll('.tw.pinned').forEach(el => el.classList.remove('pinned'));
}
});
</script>"""
    parts = [css, '<div class="tc">']
    for i in range(len(tokens)):
        tok = html_lib.escape(tokenizer.decode([tokens[i]]))
        if i == 0:
            # The first token has no preceding context, hence no probability —
            # render it neutral grey with no tooltip.
            parts.append(f'<div class="tw"><span class="tk" style="background:rgba(128,128,128,.1);color:#888">{tok}</span></div>')
            continue
        # Distribution the model predicted *for this position*, taken from the
        # previous position's logits.
        probs = torch.softmax(all_logits[i - 1], dim=-1)
        actual_p = probs[tokens[i]].item()
        # NOTE(review): topk raises if top_k exceeds the vocabulary size; the
        # UI caps the field at 200, so this looks unreachable — confirm.
        top_p, top_idx = probs.topk(top_k)
        color, bg = get_color(actual_p)
        rank = None  # 1-based rank of the actual token within the top-k, if present
        alts = ""
        for j in range(top_k):
            a_text = html_lib.escape(tokenizer.decode([top_idx[j].item()]))
            a_p = top_p[j].item()
            hit = top_idx[j].item() == tokens[i]
            if hit: rank = j + 1
            # Highlight the alternative row that matches the actual token.
            cls = ' class="w hi"' if hit else ' class="w"'
            pcls = ' class="p hi"' if hit else ' class="p"'
            alts += f'<div class="aw"><span{cls}>{a_text}</span><span{pcls}>{a_p:.4f}</span></div>'
        rank_s = f"rank #{rank}" if rank else f"rank >{top_k}"
        tooltip = f'''<div class="tt">
<div class="th">“{tok}”</div>
<div class="tp">P = {actual_p:.4f} ({rank_s})</div>
<div class="at">Top {top_k} alternatives</div>{alts}</div>'''
        parts.append(f'<div class="tw"><span class="tk" style="background:{bg};color:{color}">{tok}</span>{tooltip}</div>')
    parts.append('</div>')
    return ''.join(parts)
def predict_next(text, num_candidates):
    """Render an HTML table of the model's most likely next tokens for *text*.

    Args:
        text: prompt to continue; tokenised with the module-level ``tokenizer``.
        num_candidates: requested number of candidates (coerced to int,
            floor 1, and clamped to the vocabulary size).

    Returns:
        A self-contained HTML string containing a ranked probability table,
        or a placeholder paragraph when *text* is blank.
    """
    num_candidates = max(1, int(num_candidates))  # guard against 0/negative/float UI input
    if not text.strip():
        return "<p style='color:#999;text-align:center;padding:40px'>Enter text and click Predict Next.</p>"
    tokens = tokenizer.encode(text)
    # Keep the *last* 512 tokens: the prediction must continue the end of the
    # user's text, so over-long input is truncated from the front. (The
    # previous tokens[:512] predicted a continuation of the truncated prefix
    # instead of the text the user actually ended with.)
    if len(tokens) > 512:
        tokens = tokens[-512:]
    with torch.no_grad():
        input_ids = torch.tensor([tokens]).to(device)
        # Logits at the final position predict the next token.
        logits = model(input_ids).logits[0, -1].cpu()
    # log_softmax is numerically stable: softmax followed by torch.log can
    # underflow to -inf for very unlikely tokens.
    log_probs = torch.log_softmax(logits, dim=-1)
    probs = log_probs.exp()
    # topk raises if k exceeds the vocabulary size; clamp defensively.
    num_candidates = min(num_candidates, probs.numel())
    top_p, top_idx = probs.topk(num_candidates)
    top_lp = log_probs[top_idx]
    rows = ""
    for j in range(num_candidates):
        tok_text = html_lib.escape(tokenizer.decode([top_idx[j].item()]))
        p = top_p[j].item()
        lp = top_lp[j].item()
        bar_width = max(1, int(p * 100))  # percent width of the probability bar
        hue = min(p * 120, 120)  # red (unlikely) -> green (likely)
        rows += f"""<tr>
<td style="padding:6px 12px;font-weight:600;color:#e0e0e0;white-space:nowrap">{j+1}</td>
<td style="padding:6px 12px;font-family:monospace;font-size:15px;color:#7fdbca;white-space:nowrap">{tok_text}</td>
<td style="padding:6px 12px;width:100%">
<div style="background:hsla({hue},80%,50%,0.25);border-radius:4px;height:22px;width:{bar_width}%;min-width:2px;display:flex;align-items:center;padding-left:6px">
<span style="font-size:11px;color:hsl({hue},80%,70%);font-weight:600">{p:.4f}</span>
</div>
</td>
<td style="padding:6px 12px;font-family:monospace;font-size:13px;color:#888;white-space:nowrap">{lp:.4f}</td>
</tr>"""
    table_html = f"""<div style="font-family:'Segoe UI',sans-serif;background:#1a1a2e;border-radius:12px;padding:16px;overflow-x:auto">
<div style="color:#ff79c6;font-size:11px;text-transform:uppercase;letter-spacing:1px;margin-bottom:10px">
Top {num_candidates} predicted next tokens</div>
<table style="width:100%;border-collapse:collapse">
<thead><tr style="border-bottom:1px solid #333">
<th style="padding:6px 12px;text-align:left;color:#666;font-size:11px">#</th>
<th style="padding:6px 12px;text-align:left;color:#666;font-size:11px">TOKEN</th>
<th style="padding:6px 12px;text-align:left;color:#666;font-size:11px">PROBABILITY</th>
<th style="padding:6px 12px;text-align:left;color:#666;font-size:11px">LOG PROB</th>
</tr></thead>
<tbody>{rows}</tbody>
</table></div>"""
    return table_html
# theme= and css= are gr.Blocks() constructor arguments; gradio's launch()
# does not accept them, so passing them to launch() (as before) fails at
# startup with a TypeError.
with gr.Blocks(
    theme=gr.themes.Soft(),
    css="footer{display:none!important}.main{max-width:960px;margin:auto}",
) as demo:
    gr.Markdown("# 🔍 Token Probability Explorer & Predictor\nPaste text, **hover** to preview or **click** a token to pin its tooltip open. Click elsewhere to dismiss.")
    text_input = gr.Textbox(label="Input Text", placeholder="Paste your text here…", lines=5)
    with gr.Row():
        top_k_input = gr.Number(label="# Alternatives (Analysis)", value=10, minimum=1, maximum=200, step=1)
        num_candidates_input = gr.Number(label="# Next Token Candidates", value=10, minimum=1, maximum=200, step=1)
    with gr.Row():
        btn_analyze = gr.Button("Analyze", variant="primary")
        btn_predict = gr.Button("Predict Next", variant="secondary")
    output_analysis = gr.HTML(label="Analysis Output")
    output_prediction = gr.HTML(label="Predicted Next Tokens")
    # Wire the buttons to the two analysis functions defined above.
    btn_analyze.click(fn=analyze_text, inputs=[text_input, top_k_input], outputs=output_analysis)
    btn_predict.click(fn=predict_next, inputs=[text_input, num_candidates_input], outputs=output_prediction)

# Bind on all interfaces (e.g. for container deployment) on port 7860.
demo.launch(
    server_name="0.0.0.0",
    server_port=7860,
)