import gradio as gr
import torch
import html as html_lib
from transformers import AutoTokenizer, AutoModelForCausalLM
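
# Load the tokenizer and model once at startup; inference only, on GPU when available.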
tokenizer = AutoTokenizer.from_pretrained("BigSalmon/InformalToFormalLincoln123Paraphrase")
model = AutoModelForCausalLM.from_pretrained("BigSalmon/InformalToFormalLincoln123Paraphrase")
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
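
# Map a probability in [0, 1] to a text colour and a translucent background:
# the hue sweeps from red (p near 0) toward green (p near 1).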
def get_color(p):
    hue = min(p * 120, 120)
    return f"hsl({hue},80%,35%)", f"hsla({hue},80%,50%,0.15)"
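
# Run the model over the pasted text and wrap every token in a colour-coded <span>
# whose tooltip shows the probability the model assigned to that token and the
# top-k alternatives it preferred at that position.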
def analyze_text(text, top_k):
    top_k = max(1, int(top_k))
    if not text.strip():
        return "<p>Paste some text and click Analyze.</p>"
    tokens = tokenizer.encode(text)
    # Keep the sequence within the model's context window.
    if len(tokens) > 512:
        tokens = tokens[:512]
    with torch.no_grad():
        input_ids = torch.tensor([tokens]).to(device)
        all_logits = model(input_ids).logits[0].cpu()
    # Minimal inline styling for the token spans and their hover/pinnable tooltips.
    css = """<style>
    .tokens { line-height: 2.4; font-family: monospace; }
    .tok { position: relative; border-radius: 3px; padding: 1px 2px; cursor: pointer; }
    .tok .tip { display: none; position: absolute; left: 0; top: 100%; z-index: 10;
                background: #fff; color: #111; border: 1px solid #999; border-radius: 6px;
                padding: 8px; min-width: 240px; font-size: 13px; white-space: normal; }
    .tok:hover .tip, .tok.pinned .tip { display: block; }
    .tip .w { display: inline-block; width: 60%; }
    .tip .p { display: inline-block; width: 35%; text-align: right; }
    .tip .hi { font-weight: 700; }
    </style>"""
    # Clicking anywhere else in the container dismisses any pinned tooltip.
    parts = [
        css,
        '<div class="tokens" '
        'onclick="this.querySelectorAll(\'.pinned\').forEach(el => el.classList.remove(\'pinned\'))">',
    ]
    for i in range(len(tokens)):
        tok = html_lib.escape(tokenizer.decode([tokens[i]]))
        if i == 0:
            # The first token has no preceding context, so no probability is available.
            parts.append(f'<span class="tok">{tok}</span>')
            continue
        # The logits at position i-1 are the model's distribution over the token at position i.
        probs = torch.softmax(all_logits[i - 1], dim=-1)
        actual_p = probs[tokens[i]].item()
        top_p, top_idx = probs.topk(top_k)
        color, bg = get_color(actual_p)
        rank = None
        alts = ""
        for j in range(top_k):
            a_text = html_lib.escape(tokenizer.decode([top_idx[j].item()]))
            a_p = top_p[j].item()
            hit = top_idx[j].item() == tokens[i]
            if hit:
                rank = j + 1
            cls = ' class="w hi"' if hit else ' class="w"'
            pcls = ' class="p hi"' if hit else ' class="p"'
            alts += f'<div><span{cls}>{a_text}</span><span{pcls}>{a_p:.4f}</span></div>'
        rank_s = f"rank #{rank}" if rank else f"rank >{top_k}"
        tooltip = (
            f'<div class="tip">'
            f'<div>“{tok}”</div>'
            f'<div>P = {actual_p:.4f} ({rank_s})</div>'
            f'<div>Top {top_k} alternatives</div>'
            f'{alts}'
            f'</div>'
        )
        # Hover previews the tooltip; clicking toggles the "pinned" class so it stays open.
        parts.append(
            f'<span class="tok" style="color:{color};background:{bg};" '
            f'onclick="event.stopPropagation(); this.classList.toggle(\'pinned\')">'
            f'{tok}{tooltip}</span>'
        )
    parts.append('</div>')
    return ''.join(parts)
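
# Return an HTML table of the top candidates for the next token, with a probability
# bar and log-probability for each.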
def predict_next(text, num_candidates):
    num_candidates = max(1, int(num_candidates))
    if not text.strip():
        return "<p>Enter text and click Predict Next.</p>"
    tokens = tokenizer.encode(text)
    if len(tokens) > 512:
        tokens = tokens[:512]
    with torch.no_grad():
        input_ids = torch.tensor([tokens]).to(device)
        logits = model(input_ids).logits[0, -1].cpu()
    probs = torch.softmax(logits, dim=-1)
    log_probs = torch.log(probs)
    top_p, top_idx = probs.topk(num_candidates)
    top_lp = log_probs[top_idx]
    rows = ""
    for j in range(num_candidates):
        tok_text = html_lib.escape(tokenizer.decode([top_idx[j].item()]))
        p = top_p[j].item()
        lp = top_lp[j].item()
        bar_width = max(1, int(p * 100))  # bar length as a percentage of the cell
        hue = min(p * 120, 120)           # red (low p) through green (high p)
        rows += (
            f'<tr><td>{j + 1}</td>'
            f'<td><code>{tok_text}</code></td>'
            f'<td><div style="background:hsla({hue},80%,50%,0.35);width:{bar_width}%;'
            f'padding:2px 6px;border-radius:3px;white-space:nowrap;">{p:.4f}</div></td>'
            f'<td>{lp:.4f}</td></tr>'
        )
    html = f"""
    <h3>Top {num_candidates} predicted next tokens</h3>
    <table style="width:100%;border-collapse:collapse;">
      <thead><tr><th>#</th><th>TOKEN</th><th>PROBABILITY</th><th>LOG PROB</th></tr></thead>
      <tbody>{rows}</tbody>
    </table>
    """
    return html
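
# UI: a single textbox drives two actions, per-token analysis and next-token prediction.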
with gr.Blocks(
    theme=gr.themes.Soft(),
    css="footer{display:none!important}.main{max-width:960px;margin:auto}",
) as demo:
    gr.Markdown(
        "# 🔍 Token Probability Explorer & Predictor\n"
        "Paste text, **hover** to preview or **click** a token to pin its tooltip open. "
        "Click elsewhere to dismiss."
    )
    text_input = gr.Textbox(label="Input Text", placeholder="Paste your text here…", lines=5)
    with gr.Row():
        top_k_input = gr.Number(label="# Alternatives (Analysis)", value=10, minimum=1, maximum=200, step=1)
        num_candidates_input = gr.Number(label="# Next Token Candidates", value=10, minimum=1, maximum=200, step=1)
    with gr.Row():
        btn_analyze = gr.Button("Analyze", variant="primary")
        btn_predict = gr.Button("Predict Next", variant="secondary")
    output_analysis = gr.HTML(label="Analysis Output")
    output_prediction = gr.HTML(label="Predicted Next Tokens")
    btn_analyze.click(fn=analyze_text, inputs=[text_input, top_k_input], outputs=output_analysis)
    btn_predict.click(fn=predict_next, inputs=[text_input, num_candidates_input], outputs=output_prediction)

demo.launch(server_name="0.0.0.0", server_port=7860)