Update app.py
Browse files
app.py
CHANGED
|
@@ -94,28 +94,55 @@ document.addEventListener('click', function(e) {
|
|
| 94 |
return ''.join(parts)
|
| 95 |
|
| 96 |
|
| 97 |
-
def predict_next(text,
|
|
|
|
| 98 |
if not text.strip():
|
| 99 |
-
return ""
|
| 100 |
-
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(device)
|
| 101 |
-
input_len = inputs['input_ids'].shape[1]
|
| 102 |
-
max_tokens = min(int(num_tokens), 512 - input_len)
|
| 103 |
|
| 104 |
-
|
| 105 |
-
|
|
|
|
| 106 |
|
| 107 |
with torch.no_grad():
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
|
| 121 |
with gr.Blocks() as demo:
|
|
@@ -125,18 +152,17 @@ with gr.Blocks() as demo:
|
|
| 125 |
|
| 126 |
with gr.Row():
|
| 127 |
top_k_input = gr.Number(label="# Alternatives (Analysis)", value=10, minimum=1, maximum=200, step=1)
|
| 128 |
-
|
| 129 |
-
temperature_input = gr.Slider(label="Temperature (Prediction)", minimum=0.0, maximum=2.0, value=0.7, step=0.05)
|
| 130 |
|
| 131 |
with gr.Row():
|
| 132 |
btn_analyze = gr.Button("Analyze", variant="primary")
|
| 133 |
btn_predict = gr.Button("Predict Next", variant="secondary")
|
| 134 |
|
| 135 |
output_analysis = gr.HTML(label="Analysis Output")
|
| 136 |
-
output_prediction = gr.
|
| 137 |
|
| 138 |
btn_analyze.click(fn=analyze_text, inputs=[text_input, top_k_input], outputs=output_analysis)
|
| 139 |
-
btn_predict.click(fn=predict_next, inputs=[text_input,
|
| 140 |
|
| 141 |
demo.launch(
|
| 142 |
server_name="0.0.0.0",
|
|
|
|
| 94 |
return ''.join(parts)
|
| 95 |
|
| 96 |
|
| 97 |
+
def predict_next(text, num_candidates):
    """Predict the most likely next tokens for *text* and render an HTML table.

    Args:
        text: Prompt whose continuation is being predicted.
        num_candidates: Number of top candidate tokens to display
            (clamped to at least 1 and at most the vocabulary size).

    Returns:
        An HTML string with a styled table of the top candidates, their
        probabilities and log-probabilities, or a placeholder paragraph
        when the input is empty/whitespace.
    """
    num_candidates = max(1, int(num_candidates))
    if not text.strip():
        return "<p style='color:#999;text-align:center;padding:40px'>Enter text and click Predict Next.</p>"

    tokens = tokenizer.encode(text)
    if len(tokens) > 512:
        # Keep the LAST 512 tokens: for next-token prediction the most
        # recent context is what matters. (Previously this kept the first
        # 512, discarding the text immediately before the prediction point.)
        tokens = tokens[-512:]

    with torch.no_grad():
        # NOTE(review): assumes `device` / `model` / `tokenizer` are the
        # module-level globals set up earlier in app.py.
        input_ids = torch.tensor([tokens]).to(device)
        logits = model(input_ids).logits[0, -1].cpu()

        probs = torch.softmax(logits, dim=-1)
        # log_softmax on the raw logits is numerically stabler than
        # torch.log(softmax(...)), which can produce -inf via underflow.
        log_probs = torch.log_softmax(logits, dim=-1)
        # Never request more entries than the vocabulary holds —
        # topk raises otherwise.
        num_candidates = min(num_candidates, probs.numel())
        top_p, top_idx = probs.topk(num_candidates)
        top_lp = log_probs[top_idx]

    rows = ""
    for j in range(num_candidates):
        # Escape the decoded token so markup-like tokens render as text.
        tok_text = html_lib.escape(tokenizer.decode([top_idx[j].item()]))
        p = top_p[j].item()
        lp = top_lp[j].item()
        bar_width = max(1, int(p * 100))  # bar spans 1–100% of the cell
        hue = min(p * 120, 120)  # 0 (red) → 120 (green) as p grows
        rows += f"""<tr>
            <td style="padding:6px 12px;font-weight:600;color:#e0e0e0;white-space:nowrap">{j+1}</td>
            <td style="padding:6px 12px;font-family:monospace;font-size:15px;color:#7fdbca;white-space:nowrap">{tok_text}</td>
            <td style="padding:6px 12px;width:100%">
                <div style="background:hsla({hue},80%,50%,0.25);border-radius:4px;height:22px;width:{bar_width}%;min-width:2px;display:flex;align-items:center;padding-left:6px">
                    <span style="font-size:11px;color:hsl({hue},80%,70%);font-weight:600">{p:.4f}</span>
                </div>
            </td>
            <td style="padding:6px 12px;font-family:monospace;font-size:13px;color:#888;white-space:nowrap">{lp:.4f}</td>
        </tr>"""

    html = f"""<div style="font-family:'Segoe UI',sans-serif;background:#1a1a2e;border-radius:12px;padding:16px;overflow-x:auto">
        <div style="color:#ff79c6;font-size:11px;text-transform:uppercase;letter-spacing:1px;margin-bottom:10px">
            Top {num_candidates} predicted next tokens</div>
        <table style="width:100%;border-collapse:collapse">
            <thead><tr style="border-bottom:1px solid #333">
                <th style="padding:6px 12px;text-align:left;color:#666;font-size:11px">#</th>
                <th style="padding:6px 12px;text-align:left;color:#666;font-size:11px">TOKEN</th>
                <th style="padding:6px 12px;text-align:left;color:#666;font-size:11px">PROBABILITY</th>
                <th style="padding:6px 12px;text-align:left;color:#666;font-size:11px">LOG PROB</th>
            </tr></thead>
            <tbody>{rows}</tbody>
        </table></div>"""
    return html
|
| 146 |
|
| 147 |
|
| 148 |
with gr.Blocks() as demo:
|
|
|
|
| 152 |
|
| 153 |
with gr.Row():
|
| 154 |
top_k_input = gr.Number(label="# Alternatives (Analysis)", value=10, minimum=1, maximum=200, step=1)
|
| 155 |
+
num_candidates_input = gr.Number(label="# Next Token Candidates", value=10, minimum=1, maximum=200, step=1)
|
|
|
|
| 156 |
|
| 157 |
with gr.Row():
|
| 158 |
btn_analyze = gr.Button("Analyze", variant="primary")
|
| 159 |
btn_predict = gr.Button("Predict Next", variant="secondary")
|
| 160 |
|
| 161 |
output_analysis = gr.HTML(label="Analysis Output")
|
| 162 |
+
output_prediction = gr.HTML(label="Predicted Next Tokens")
|
| 163 |
|
| 164 |
btn_analyze.click(fn=analyze_text, inputs=[text_input, top_k_input], outputs=output_analysis)
|
| 165 |
+
btn_predict.click(fn=predict_next, inputs=[text_input, num_candidates_input], outputs=output_prediction)
|
| 166 |
|
| 167 |
demo.launch(
|
| 168 |
server_name="0.0.0.0",
|