# LogProbs7 / app.py — Hugging Face Space by BigSalmon (rev 2f6f751)
import gradio as gr
import torch
import html as html_lib
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load the tokenizer and causal-LM weights once at import time so every
# Gradio request reuses the same in-memory model.
tokenizer = AutoTokenizer.from_pretrained("BigSalmon/InformalToFormalLincoln123Paraphrase")
model = AutoModelForCausalLM.from_pretrained("BigSalmon/InformalToFormalLincoln123Paraphrase")
model.eval()  # inference only: disables dropout and other train-time behavior
# Prefer GPU when available; all forward passes below run on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
def get_color(p):
    """Map a probability to a pair of CSS color strings.

    Returns (foreground, background): hue runs from red (p == 0)
    to green (p >= 1), so likely tokens read as green.
    """
    h = min(p * 120, 120)
    foreground = f"hsl({h},80%,35%)"
    background = f"hsla({h},80%,50%,0.15)"
    return foreground, background
def analyze_text(text, top_k):
    """Render per-token probability analysis of *text* as interactive HTML.

    Each token is colored by the probability the model assigned to it given
    the preceding tokens; hovering (or clicking to pin) shows the top-k
    alternative continuations at that position.

    Args:
        text: Input string to analyze; truncated to 512 tokens.
        top_k: Number of alternatives to list per token (clamped to >= 1 and
            to the vocabulary size).

    Returns:
        An HTML string (CSS + JS + token spans) for a gr.HTML component.
    """
    top_k = max(1, int(top_k))
    if not text.strip():
        return "<p style='color:#999;text-align:center;padding:40px'>Paste some text and click Analyze.</p>"
    tokens = tokenizer.encode(text)
    if len(tokens) > 512:
        tokens = tokens[:512]  # keep the sequence within a safe context length
    with torch.no_grad():
        input_ids = torch.tensor([tokens]).to(device)
        all_logits = model(input_ids).logits[0].cpu()
    # Fix: one batched softmax over the whole sequence instead of a separate
    # softmax call per token inside the Python loop below.
    all_probs = torch.softmax(all_logits, dim=-1)
    # Fix: topk raises if asked for more entries than the vocabulary has.
    top_k = min(top_k, all_probs.size(-1))
    css = """<style>
.tc{display:flex;flex-wrap:wrap;gap:5px;padding:20px;line-height:2.4;font-family:'Segoe UI',sans-serif}
.tw{position:relative;display:inline-block}
.tk{padding:4px 7px;border-radius:6px;cursor:pointer;font-size:15px;transition:.2s;border:1px solid transparent;user-select:none}
.tw:hover .tk{transform:translateY(-2px);box-shadow:0 4px 14px rgba(0,0,0,.18);border-color:#999}
.tt{display:none;position:absolute;bottom:calc(100% + 8px);left:50%;transform:translateX(-50%);
background:#1a1a2e;color:#eee;padding:14px;border-radius:12px;font-size:13px;z-index:9999;
box-shadow:0 10px 30px rgba(0,0,0,.35);min-width:220px;max-height:350px;overflow-y:auto}
.tt::after{content:'';position:absolute;top:100%;left:0;width:100%;height:12px}
.tw:hover .tt{display:block}
.tw.pinned .tt{display:block}
.tw.pinned .tk{transform:translateY(-2px);box-shadow:0 4px 14px rgba(0,0,0,.18);border-color:#999;outline:2px solid #7fdbca}
.th{font-weight:700;font-size:14px;color:#7fdbca;border-bottom:1px solid #333;padding-bottom:6px;margin-bottom:6px}
.tp{color:#ffd700;margin-bottom:8px}
.at{color:#ff79c6;font-size:10px;text-transform:uppercase;letter-spacing:1px;margin-bottom:4px}
.aw{display:flex;justify-content:space-between;padding:2px 0;font-size:12px}
.aw .w{color:#c3cee3}.aw .p{color:#666;margin-left:14px}
.hi{font-weight:700;color:#7fdbca!important}
</style>
<script>
document.addEventListener('click', function(e) {
const tk = e.target.closest('.tk');
const tw = tk ? tk.closest('.tw') : null;
if (tw) {
const wasPinned = tw.classList.contains('pinned');
document.querySelectorAll('.tw.pinned').forEach(el => el.classList.remove('pinned'));
if (!wasPinned) tw.classList.add('pinned');
} else if (!e.target.closest('.tt')) {
document.querySelectorAll('.tw.pinned').forEach(el => el.classList.remove('pinned'));
}
});
</script>"""
    parts = [css, '<div class="tc">']
    for i in range(len(tokens)):
        tok = html_lib.escape(tokenizer.decode([tokens[i]]))
        if i == 0:
            # The first token has no preceding context, so no probability.
            parts.append(f'<div class="tw"><span class="tk" style="background:rgba(128,128,128,.1);color:#888">{tok}</span></div>')
            continue
        # Distribution the model predicted at the *previous* position.
        probs = all_probs[i - 1]
        actual_p = probs[tokens[i]].item()
        top_p, top_idx = probs.topk(top_k)
        color, bg = get_color(actual_p)
        rank = None  # 1-based rank of the actual token among alternatives, if present
        alts = ""
        for j in range(top_k):
            a_text = html_lib.escape(tokenizer.decode([top_idx[j].item()]))
            a_p = top_p[j].item()
            hit = top_idx[j].item() == tokens[i]
            if hit:
                rank = j + 1
            cls = ' class="w hi"' if hit else ' class="w"'
            pcls = ' class="p hi"' if hit else ' class="p"'
            alts += f'<div class="aw"><span{cls}>{a_text}</span><span{pcls}>{a_p:.4f}</span></div>'
        rank_s = f"rank #{rank}" if rank else f"rank &gt;{top_k}"
        tooltip = f'''<div class="tt">
<div class="th">&ldquo;{tok}&rdquo;</div>
<div class="tp">P = {actual_p:.4f} &nbsp;({rank_s})</div>
<div class="at">Top {top_k} alternatives</div>{alts}</div>'''
        parts.append(f'<div class="tw"><span class="tk" style="background:{bg};color:{color}">{tok}</span>{tooltip}</div>')
    parts.append('</div>')
    return ''.join(parts)
def predict_next(text, num_candidates):
    """Render a table of the model's top next-token predictions as HTML.

    Args:
        text: Context string; truncated to 512 tokens before the forward pass.
        num_candidates: Number of candidate tokens to show (clamped to >= 1
            and to the vocabulary size).

    Returns:
        An HTML table string (rank, token, probability bar, log-prob) for a
        gr.HTML component.
    """
    num_candidates = max(1, int(num_candidates))
    if not text.strip():
        return "<p style='color:#999;text-align:center;padding:40px'>Enter text and click Predict Next.</p>"
    tokens = tokenizer.encode(text)
    if len(tokens) > 512:
        tokens = tokens[:512]  # keep the sequence within a safe context length
    with torch.no_grad():
        input_ids = torch.tensor([tokens]).to(device)
        logits = model(input_ids).logits[0, -1].cpu()  # next-token logits only
    probs = torch.softmax(logits, dim=-1)
    # Fix: log_softmax on the logits is numerically stabler than
    # torch.log(softmax(...)) and avoids -inf when a probability
    # underflows to exactly 0.
    log_probs = torch.log_softmax(logits, dim=-1)
    # Fix: topk raises if asked for more entries than the vocabulary has.
    num_candidates = min(num_candidates, probs.size(-1))
    top_p, top_idx = probs.topk(num_candidates)
    top_lp = log_probs[top_idx]
    rows = ""
    for j in range(num_candidates):
        tok_text = html_lib.escape(tokenizer.decode([top_idx[j].item()]))
        p = top_p[j].item()
        lp = top_lp[j].item()
        bar_width = max(1, int(p * 100))  # percentage width; never fully invisible
        hue = min(p * 120, 120)  # red (unlikely) -> green (likely)
        rows += f"""<tr>
<td style="padding:6px 12px;font-weight:600;color:#e0e0e0;white-space:nowrap">{j+1}</td>
<td style="padding:6px 12px;font-family:monospace;font-size:15px;color:#7fdbca;white-space:nowrap">{tok_text}</td>
<td style="padding:6px 12px;width:100%">
<div style="background:hsla({hue},80%,50%,0.25);border-radius:4px;height:22px;width:{bar_width}%;min-width:2px;display:flex;align-items:center;padding-left:6px">
<span style="font-size:11px;color:hsl({hue},80%,70%);font-weight:600">{p:.4f}</span>
</div>
</td>
<td style="padding:6px 12px;font-family:monospace;font-size:13px;color:#888;white-space:nowrap">{lp:.4f}</td>
</tr>"""
    html = f"""<div style="font-family:'Segoe UI',sans-serif;background:#1a1a2e;border-radius:12px;padding:16px;overflow-x:auto">
<div style="color:#ff79c6;font-size:11px;text-transform:uppercase;letter-spacing:1px;margin-bottom:10px">
Top {num_candidates} predicted next tokens</div>
<table style="width:100%;border-collapse:collapse">
<thead><tr style="border-bottom:1px solid #333">
<th style="padding:6px 12px;text-align:left;color:#666;font-size:11px">#</th>
<th style="padding:6px 12px;text-align:left;color:#666;font-size:11px">TOKEN</th>
<th style="padding:6px 12px;text-align:left;color:#666;font-size:11px">PROBABILITY</th>
<th style="padding:6px 12px;text-align:left;color:#666;font-size:11px">LOG PROB</th>
</tr></thead>
<tbody>{rows}</tbody>
</table></div>"""
    return html
# UI layout and wiring.
# Fix: `theme` and `css` are gr.Blocks() constructor arguments, not
# demo.launch() arguments — passing them to launch() raises TypeError.
with gr.Blocks(
    theme=gr.themes.Soft(),
    css="footer{display:none!important}.main{max-width:960px;margin:auto}",
) as demo:
    gr.Markdown("# 🔍 Token Probability Explorer & Predictor\nPaste text, **hover** to preview or **click** a token to pin its tooltip open. Click elsewhere to dismiss.")
    text_input = gr.Textbox(label="Input Text", placeholder="Paste your text here…", lines=5)
    with gr.Row():
        top_k_input = gr.Number(label="# Alternatives (Analysis)", value=10, minimum=1, maximum=200, step=1)
        num_candidates_input = gr.Number(label="# Next Token Candidates", value=10, minimum=1, maximum=200, step=1)
    with gr.Row():
        btn_analyze = gr.Button("Analyze", variant="primary")
        btn_predict = gr.Button("Predict Next", variant="secondary")
    output_analysis = gr.HTML(label="Analysis Output")
    output_prediction = gr.HTML(label="Predicted Next Tokens")
    btn_analyze.click(fn=analyze_text, inputs=[text_input, top_k_input], outputs=output_analysis)
    btn_predict.click(fn=predict_next, inputs=[text_input, num_candidates_input], outputs=output_prediction)
# Bind to all interfaces on the standard HF Spaces port.
demo.launch(server_name="0.0.0.0", server_port=7860)