File size: 7,755 Bytes
dcbf505
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
06e1cfe
dcbf505
 
 
 
06e1cfe
dcbf505
06e1cfe
 
dcbf505
 
 
 
 
 
06e1cfe
 
 
 
 
 
 
 
 
 
 
 
 
 
dcbf505
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f6f751
 
be68400
2f6f751
be68400
2f6f751
 
 
be68400
 
2f6f751
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be68400
 
888f20b
be68400
 
 
 
dcbf505
be68400
2f6f751
be68400
 
 
 
 
 
2f6f751
be68400
 
2f6f751
dcbf505
888f20b
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import gradio as gr
import torch
import html as html_lib
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the causal-LM checkpoint and its tokenizer from the Hugging Face hub.
# NOTE(review): this runs at import time and downloads weights on first use —
# network access (or a warm HF cache) is required before the app can start.
tokenizer = AutoTokenizer.from_pretrained("BigSalmon/InformalToFormalLincoln123Paraphrase")
model = AutoModelForCausalLM.from_pretrained("BigSalmon/InformalToFormalLincoln123Paraphrase")
model.eval()  # inference only; disables dropout etc.
# Prefer GPU when available; all forward passes below move inputs to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

def get_color(p):
    """Map a probability *p* to a ``(text_color, background_color)`` CSS pair.

    Low probabilities map toward red (hue 0), high ones toward green
    (hue 120); the hue is clamped at 120 for p >= 1.
    """
    h = p * 120
    if h > 120:
        h = 120
    text_color = f"hsl({h},80%,35%)"
    bg_color = f"hsla({h},80%,50%,0.15)"
    return text_color, bg_color

def analyze_text(text, top_k):
    """Render a per-token probability analysis of *text* as an HTML fragment.

    Runs one forward pass, then for every token after the first shows the
    model's probability of that token given its prefix (color-coded), with a
    hover/click tooltip listing the top-*top_k* alternative tokens at that
    position.

    Args:
        text: Input text; truncated to the first 512 tokens.
        top_k: Number of alternatives per tooltip (coerced to >= 1 and
            clamped to the vocabulary size).

    Returns:
        An HTML string (styles + script + token <div>s) for a gr.HTML output.
    """
    top_k = max(1, int(top_k))
    if not text.strip():
        return "<p style='color:#999;text-align:center;padding:40px'>Paste some text and click Analyze.</p>"

    tokens = tokenizer.encode(text)
    if len(tokens) > 512:
        tokens = tokens[:512]

    with torch.no_grad():
        input_ids = torch.tensor([tokens]).to(device)
        all_logits = model(input_ids).logits[0].cpu()

    # Softmax once over every position here, instead of re-running it for
    # each token inside the rendering loop below.
    all_probs = torch.softmax(all_logits, dim=-1)
    # topk(k) raises RuntimeError if k exceeds the vocabulary size; clamp.
    top_k = min(top_k, all_probs.size(-1))

    css = """<style>
.tc{display:flex;flex-wrap:wrap;gap:5px;padding:20px;line-height:2.4;font-family:'Segoe UI',sans-serif}
.tw{position:relative;display:inline-block}
.tk{padding:4px 7px;border-radius:6px;cursor:pointer;font-size:15px;transition:.2s;border:1px solid transparent;user-select:none}
.tw:hover .tk{transform:translateY(-2px);box-shadow:0 4px 14px rgba(0,0,0,.18);border-color:#999}
.tt{display:none;position:absolute;bottom:calc(100% + 8px);left:50%;transform:translateX(-50%);
background:#1a1a2e;color:#eee;padding:14px;border-radius:12px;font-size:13px;z-index:9999;
box-shadow:0 10px 30px rgba(0,0,0,.35);min-width:220px;max-height:350px;overflow-y:auto}
.tt::after{content:'';position:absolute;top:100%;left:0;width:100%;height:12px}
.tw:hover .tt{display:block}
.tw.pinned .tt{display:block}
.tw.pinned .tk{transform:translateY(-2px);box-shadow:0 4px 14px rgba(0,0,0,.18);border-color:#999;outline:2px solid #7fdbca}
.th{font-weight:700;font-size:14px;color:#7fdbca;border-bottom:1px solid #333;padding-bottom:6px;margin-bottom:6px}
.tp{color:#ffd700;margin-bottom:8px}
.at{color:#ff79c6;font-size:10px;text-transform:uppercase;letter-spacing:1px;margin-bottom:4px}
.aw{display:flex;justify-content:space-between;padding:2px 0;font-size:12px}
.aw .w{color:#c3cee3}.aw .p{color:#666;margin-left:14px}
.hi{font-weight:700;color:#7fdbca!important}
</style>
<script>
document.addEventListener('click', function(e) {
    const tk = e.target.closest('.tk');
    const tw = tk ? tk.closest('.tw') : null;
    if (tw) {
        const wasPinned = tw.classList.contains('pinned');
        document.querySelectorAll('.tw.pinned').forEach(el => el.classList.remove('pinned'));
        if (!wasPinned) tw.classList.add('pinned');
    } else if (!e.target.closest('.tt')) {
        document.querySelectorAll('.tw.pinned').forEach(el => el.classList.remove('pinned'));
    }
});
</script>"""
    # NOTE(review): scripts injected via gr.HTML may not execute in all Gradio
    # versions (innerHTML insertion) — verify the click-to-pin behavior works.

    parts = [css, '<div class="tc">']
    for i in range(len(tokens)):
        tok = html_lib.escape(tokenizer.decode([tokens[i]]))
        if i == 0:
            # The first token has no preceding context, so no probability.
            parts.append(f'<div class="tw"><span class="tk" style="background:rgba(128,128,128,.1);color:#888">{tok}</span></div>')
            continue

        # Logits at position i-1 are the model's distribution for token i.
        probs = all_probs[i - 1]
        actual_p = probs[tokens[i]].item()
        top_p, top_idx = probs.topk(top_k)
        color, bg = get_color(actual_p)

        rank = None
        alts = ""
        for j in range(top_k):
            a_text = html_lib.escape(tokenizer.decode([top_idx[j].item()]))
            a_p = top_p[j].item()
            hit = top_idx[j].item() == tokens[i]
            if hit: rank = j + 1
            cls = ' class="w hi"' if hit else ' class="w"'
            pcls = ' class="p hi"' if hit else ' class="p"'
            alts += f'<div class="aw"><span{cls}>{a_text}</span><span{pcls}>{a_p:.4f}</span></div>'

        # Rank is only known when the actual token appears within the top k.
        rank_s = f"rank #{rank}" if rank else f"rank &gt;{top_k}"
        tooltip = f'''<div class="tt">
<div class="th">&ldquo;{tok}&rdquo;</div>
<div class="tp">P = {actual_p:.4f} &nbsp;({rank_s})</div>
<div class="at">Top {top_k} alternatives</div>{alts}</div>'''

        parts.append(f'<div class="tw"><span class="tk" style="background:{bg};color:{color}">{tok}</span>{tooltip}</div>')

    parts.append('</div>')
    return ''.join(parts)


def predict_next(text, num_candidates):
    """Render an HTML table of the model's top next-token predictions for *text*.

    Runs one forward pass and shows the *num_candidates* most likely next
    tokens with probability bars and log-probabilities.

    Args:
        text: Input text; truncated to the first 512 tokens.
        num_candidates: Number of candidates to show (coerced to >= 1 and
            clamped to the vocabulary size).

    Returns:
        An HTML string for a gr.HTML output.
    """
    num_candidates = max(1, int(num_candidates))
    if not text.strip():
        return "<p style='color:#999;text-align:center;padding:40px'>Enter text and click Predict Next.</p>"

    tokens = tokenizer.encode(text)
    if len(tokens) > 512:
        tokens = tokens[:512]

    with torch.no_grad():
        input_ids = torch.tensor([tokens]).to(device)
        logits = model(input_ids).logits[0, -1].cpu()

    probs = torch.softmax(logits, dim=-1)
    # log_softmax is numerically stable; torch.log(torch.softmax(...)) yields
    # -inf for probabilities that underflow to zero.
    log_probs = torch.log_softmax(logits, dim=-1)
    # topk(k) raises RuntimeError if k exceeds the vocabulary size; clamp.
    num_candidates = min(num_candidates, probs.size(-1))
    top_p, top_idx = probs.topk(num_candidates)
    top_lp = log_probs[top_idx]

    rows = ""
    for j in range(num_candidates):
        tok_text = html_lib.escape(tokenizer.decode([top_idx[j].item()]))
        p = top_p[j].item()
        lp = top_lp[j].item()
        # Bar width tracks probability but stays visible for tiny values.
        bar_width = max(1, int(p * 100))
        hue = min(p * 120, 120)
        rows += f"""<tr>
<td style="padding:6px 12px;font-weight:600;color:#e0e0e0;white-space:nowrap">{j+1}</td>
<td style="padding:6px 12px;font-family:monospace;font-size:15px;color:#7fdbca;white-space:nowrap">{tok_text}</td>
<td style="padding:6px 12px;width:100%">
  <div style="background:hsla({hue},80%,50%,0.25);border-radius:4px;height:22px;width:{bar_width}%;min-width:2px;display:flex;align-items:center;padding-left:6px">
    <span style="font-size:11px;color:hsl({hue},80%,70%);font-weight:600">{p:.4f}</span>
  </div>
</td>
<td style="padding:6px 12px;font-family:monospace;font-size:13px;color:#888;white-space:nowrap">{lp:.4f}</td>
</tr>"""

    html = f"""<div style="font-family:'Segoe UI',sans-serif;background:#1a1a2e;border-radius:12px;padding:16px;overflow-x:auto">
<div style="color:#ff79c6;font-size:11px;text-transform:uppercase;letter-spacing:1px;margin-bottom:10px">
Top {num_candidates} predicted next tokens</div>
<table style="width:100%;border-collapse:collapse">
<thead><tr style="border-bottom:1px solid #333">
<th style="padding:6px 12px;text-align:left;color:#666;font-size:11px">#</th>
<th style="padding:6px 12px;text-align:left;color:#666;font-size:11px">TOKEN</th>
<th style="padding:6px 12px;text-align:left;color:#666;font-size:11px">PROBABILITY</th>
<th style="padding:6px 12px;text-align:left;color:#666;font-size:11px">LOG PROB</th>
</tr></thead>
<tbody>{rows}</tbody>
</table></div>"""
    return html


# BUG FIX: `theme` and `css` are constructor arguments of gr.Blocks, not of
# Blocks.launch() — passing them to launch() raises TypeError at startup.
with gr.Blocks(
    theme=gr.themes.Soft(),
    css="footer{display:none!important}.main{max-width:960px;margin:auto}",
) as demo:
    gr.Markdown("# 🔍 Token Probability Explorer & Predictor\nPaste text, **hover** to preview or **click** a token to pin its tooltip open. Click elsewhere to dismiss.")

    text_input = gr.Textbox(label="Input Text", placeholder="Paste your text here…", lines=5)

    with gr.Row():
        top_k_input = gr.Number(label="# Alternatives (Analysis)", value=10, minimum=1, maximum=200, step=1)
        num_candidates_input = gr.Number(label="# Next Token Candidates", value=10, minimum=1, maximum=200, step=1)

    with gr.Row():
        btn_analyze = gr.Button("Analyze", variant="primary")
        btn_predict = gr.Button("Predict Next", variant="secondary")

    output_analysis = gr.HTML(label="Analysis Output")
    output_prediction = gr.HTML(label="Predicted Next Tokens")

    btn_analyze.click(fn=analyze_text, inputs=[text_input, top_k_input], outputs=output_analysis)
    btn_predict.click(fn=predict_next, inputs=[text_input, num_candidates_input], outputs=output_prediction)

demo.launch(
    server_name="0.0.0.0",
    server_port=7860,
)