import html as html_lib

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the tokenizer and causal LM once at startup; keep the model in eval mode
# on the best available device.
tokenizer = AutoTokenizer.from_pretrained("BigSalmon/InformalToFormalLincoln123Paraphrase")
model = AutoModelForCausalLM.from_pretrained("BigSalmon/InformalToFormalLincoln123Paraphrase")
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


def get_color(p):
    """Map a probability in [0, 1] to a text colour and a translucent background (red -> green)."""
    hue = min(p * 120, 120)
    return f"hsl({hue},80%,35%)", f"hsla({hue},80%,50%,0.15)"


def analyze_text(text, top_k):
    """Colour every token of `text` by the probability the model assigned to it, with a
    hover/click tooltip listing the top-k alternatives at that position."""
    top_k = max(1, int(top_k))
    if not text.strip():
        return '<div style="color:#888;font-style:italic">Paste some text and click Analyze.</div>'
" tokens = tokenizer.encode(text) if len(tokens) > 512: tokens = tokens[:512] with torch.no_grad(): input_ids = torch.tensor([tokens]).to(device) all_logits = model(input_ids).logits[0].cpu() css = """ """ parts = [css, '
'] for i in range(len(tokens)): tok = html_lib.escape(tokenizer.decode([tokens[i]])) if i == 0: parts.append(f'
{tok}
') continue probs = torch.softmax(all_logits[i - 1], dim=-1) actual_p = probs[tokens[i]].item() top_p, top_idx = probs.topk(top_k) color, bg = get_color(actual_p) rank = None alts = "" for j in range(top_k): a_text = html_lib.escape(tokenizer.decode([top_idx[j].item()])) a_p = top_p[j].item() hit = top_idx[j].item() == tokens[i] if hit: rank = j + 1 cls = ' class="w hi"' if hit else ' class="w"' pcls = ' class="p hi"' if hit else ' class="p"' alts += f'
{a_text}{a_p:.4f}
' rank_s = f"rank #{rank}" if rank else f"rank >{top_k}" tooltip = f'''
“{tok}”
P = {actual_p:.4f}  ({rank_s})
Top {top_k} alternatives
{alts}
''' parts.append(f'
{tok}{tooltip}
') parts.append('
') return ''.join(parts) def predict_next(text, num_candidates): num_candidates = max(1, int(num_candidates)) if not text.strip(): return "

Enter text and click Predict Next.

" tokens = tokenizer.encode(text) if len(tokens) > 512: tokens = tokens[:512] with torch.no_grad(): input_ids = torch.tensor([tokens]).to(device) logits = model(input_ids).logits[0, -1].cpu() probs = torch.softmax(logits, dim=-1) log_probs = torch.log(probs) top_p, top_idx = probs.topk(num_candidates) top_lp = log_probs[top_idx] rows = "" for j in range(num_candidates): tok_text = html_lib.escape(tokenizer.decode([top_idx[j].item()])) p = top_p[j].item() lp = top_lp[j].item() bar_width = max(1, int(p * 100)) hue = min(p * 120, 120) rows += f""" {j+1} {tok_text}
{p:.4f}
{lp:.4f} """ html = f"""
Top {num_candidates} predicted next tokens
{rows}
# TOKEN PROBABILITY LOG PROB
""" return html with gr.Blocks() as demo: gr.Markdown("# 🔍 Token Probability Explorer & Predictor\nPaste text, **hover** to preview or **click** a token to pin its tooltip open. Click elsewhere to dismiss.") text_input = gr.Textbox(label="Input Text", placeholder="Paste your text here…", lines=5) with gr.Row(): top_k_input = gr.Number(label="# Alternatives (Analysis)", value=10, minimum=1, maximum=200, step=1) num_candidates_input = gr.Number(label="# Next Token Candidates", value=10, minimum=1, maximum=200, step=1) with gr.Row(): btn_analyze = gr.Button("Analyze", variant="primary") btn_predict = gr.Button("Predict Next", variant="secondary") output_analysis = gr.HTML(label="Analysis Output") output_prediction = gr.HTML(label="Predicted Next Tokens") btn_analyze.click(fn=analyze_text, inputs=[text_input, top_k_input], outputs=output_analysis) btn_predict.click(fn=predict_next, inputs=[text_input, num_candidates_input], outputs=output_prediction) demo.launch( server_name="0.0.0.0", server_port=7860, theme=gr.themes.Soft(), css="footer{display:none!important}.main{max-width:960px;margin:auto}" )