BigSalmon commited on
Commit
dcbf505
·
verified ·
1 Parent(s): 119a508

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -0
app.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import html as html_lib
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
+
6
+ tokenizer = AutoTokenizer.from_pretrained("BigSalmon/InformalToFormalLincoln123Paraphrase")
7
+ model = AutoModelForCausalLM.from_pretrained("BigSalmon/InformalToFormalLincoln123Paraphrase")
8
+ model.eval()
9
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
+ model.to(device)
11
+
12
+ def get_color(p):
13
+ hue = min(p * 120, 120)
14
+ return f"hsl({hue},80%,35%)", f"hsla({hue},80%,50%,0.15)"
15
+
16
+ def analyze_text(text, top_k):
17
+ top_k = max(1, int(top_k))
18
+ if not text.strip():
19
+ return "<p style='color:#999;text-align:center;padding:40px'>Paste some text and click Analyze.</p>"
20
+
21
+ tokens = tokenizer.encode(text)
22
+ if len(tokens) > 512:
23
+ tokens = tokens[:512]
24
+
25
+ with torch.no_grad():
26
+ input_ids = torch.tensor([tokens]).to(device)
27
+ all_logits = model(input_ids).logits[0].cpu()
28
+
29
+ css = """<style>
30
+ .tc{display:flex;flex-wrap:wrap;gap:5px;padding:20px;line-height:2.4;font-family:'Segoe UI',sans-serif}
31
+ .tw{position:relative;display:inline-block}
32
+ .tk{padding:4px 7px;border-radius:6px;cursor:default;font-size:15px;transition:.2s;border:1px solid transparent}
33
+ .tw:hover .tk{transform:translateY(-2px);box-shadow:0 4px 14px rgba(0,0,0,.18);border-color:#999}
34
+ .tt{display:none;position:absolute;bottom:calc(100% + 8px);left:50%;transform:translateX(-50%);
35
+ background:#1a1a2e;color:#eee;padding:14px;border-radius:12px;font-size:13px;z-index:9999;
36
+ box-shadow:0 10px 30px rgba(0,0,0,.35);min-width:220px;max-height:350px;overflow-y:auto}
37
+ .tw:hover .tt{display:block}
38
+ .th{font-weight:700;font-size:14px;color:#7fdbca;border-bottom:1px solid #333;padding-bottom:6px;margin-bottom:6px}
39
+ .tp{color:#ffd700;margin-bottom:8px}
40
+ .at{color:#ff79c6;font-size:10px;text-transform:uppercase;letter-spacing:1px;margin-bottom:4px}
41
+ .aw{display:flex;justify-content:space-between;padding:2px 0;font-size:12px}
42
+ .aw .w{color:#c3cee3}.aw .p{color:#666;margin-left:14px}
43
+ .hi{font-weight:700;color:#7fdbca!important}
44
+ </style>"""
45
+
46
+ parts = [css, '<div class="tc">']
47
+ for i in range(len(tokens)):
48
+ tok = html_lib.escape(tokenizer.decode([tokens[i]]))
49
+ if i == 0:
50
+ parts.append(f'<div class="tw"><span class="tk" style="background:rgba(128,128,128,.1);color:#888">{tok}</span></div>')
51
+ continue
52
+
53
+ probs = torch.softmax(all_logits[i - 1], dim=-1)
54
+ actual_p = probs[tokens[i]].item()
55
+ top_p, top_idx = probs.topk(top_k)
56
+ color, bg = get_color(actual_p)
57
+
58
+ rank = None
59
+ alts = ""
60
+ for j in range(top_k):
61
+ a_text = html_lib.escape(tokenizer.decode([top_idx[j].item()]))
62
+ a_p = top_p[j].item()
63
+ hit = top_idx[j].item() == tokens[i]
64
+ if hit: rank = j + 1
65
+ cls = ' class="w hi"' if hit else ' class="w"'
66
+ pcls = ' class="p hi"' if hit else ' class="p"'
67
+ alts += f'<div class="aw"><span{cls}>{a_text}</span><span{pcls}>{a_p:.4f}</span></div>'
68
+
69
+ rank_s = f"rank #{rank}" if rank else f"rank &gt;{top_k}"
70
+ tooltip = f'''<div class="tt">
71
+ <div class="th">&ldquo;{tok}&rdquo;</div>
72
+ <div class="tp">P = {actual_p:.4f} &nbsp;({rank_s})</div>
73
+ <div class="at">Top {top_k} alternatives</div>{alts}</div>'''
74
+
75
+ parts.append(f'<div class="tw"><span class="tk" style="background:{bg};color:{color}">{tok}</span>{tooltip}</div>')
76
+
77
+ parts.append('</div>')
78
+ return ''.join(parts)
79
+
80
+
81
+ with gr.Blocks(theme=gr.themes.Soft(), css="footer{display:none!important}.main{max-width:960px;margin:auto}") as demo:
82
+ gr.Markdown("# 🔍 Token Probability Explorer\nPaste text, hover over each token to see its probability and the most likely alternatives.")
83
+ with gr.Row():
84
+ text_input = gr.Textbox(label="Input Text", placeholder="Paste your text here…", lines=5, scale=4)
85
+ top_k_input = gr.Number(label="# Alternatives", value=10, minimum=1, maximum=200, step=1, scale=1)
86
+ btn = gr.Button("Analyze", variant="primary")
87
+ output = gr.HTML()
88
+ btn.click(fn=analyze_text, inputs=[text_input, top_k_input], outputs=output)
89
+
90
+ demo.launch()