BigSalmon committed on
Commit
2f6f751
·
verified ·
1 Parent(s): 888f20b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -22
app.py CHANGED
@@ -94,28 +94,55 @@ document.addEventListener('click', function(e) {
94
  return ''.join(parts)
95
 
96
 
97
def predict_next(text, num_tokens, temperature):
    """Generate a continuation of ``text`` and return only the new part.

    Parameters
    ----------
    text : str
        Prompt to continue. Blank/whitespace-only input returns "".
    num_tokens : int | float
        Requested number of new tokens; coerced to int and capped so the
        total sequence stays within the 512-token context window.
    temperature : float
        Sampling temperature. 0 (or negative) switches to greedy decoding.

    Returns
    -------
    str
        The generated continuation (input text stripped off), or a short
        notice when the prompt already fills the context window.

    NOTE(review): relies on module-level ``tokenizer``, ``model`` and
    ``device`` defined elsewhere in this file.
    """
    if not text.strip():
        return ""
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(device)
    input_len = inputs['input_ids'].shape[1]
    # Cap generation so input + output never exceeds the 512-token window.
    max_tokens = min(int(num_tokens), 512 - input_len)

    if max_tokens <= 0:
        return "Input too long to generate more."

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=temperature > 0,
            # generate() rejects temperature <= 0, so fall back to a neutral
            # 1.0 when greedy decoding is selected (value is unused then).
            temperature=temperature if temperature > 0 else 1.0,
            top_k=50,
            pad_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.1,
        )
    # BUGFIX: decode only the newly generated ids. The previous
    # ``result[len(text):]`` slice assumed the decoded prefix reproduces the
    # input byte-for-byte, which tokenizers do not guarantee (whitespace and
    # unicode normalization), so the continuation could be mis-sliced.
    new_ids = output_ids[0][input_len:]
    return tokenizer.decode(new_ids, skip_special_tokens=True).strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
 
121
  with gr.Blocks() as demo:
@@ -125,18 +152,17 @@ with gr.Blocks() as demo:
125
 
126
  with gr.Row():
127
  top_k_input = gr.Number(label="# Alternatives (Analysis)", value=10, minimum=1, maximum=200, step=1)
128
- num_tokens_input = gr.Number(label="# Tokens to Predict", value=10, minimum=1, maximum=100, step=1)
129
- temperature_input = gr.Slider(label="Temperature (Prediction)", minimum=0.0, maximum=2.0, value=0.7, step=0.05)
130
 
131
  with gr.Row():
132
  btn_analyze = gr.Button("Analyze", variant="primary")
133
  btn_predict = gr.Button("Predict Next", variant="secondary")
134
 
135
  output_analysis = gr.HTML(label="Analysis Output")
136
- output_prediction = gr.Textbox(label="Predicted Continuation", lines=3, interactive=False)
137
 
138
  btn_analyze.click(fn=analyze_text, inputs=[text_input, top_k_input], outputs=output_analysis)
139
- btn_predict.click(fn=predict_next, inputs=[text_input, num_tokens_input, temperature_input], outputs=output_prediction)
140
 
141
  demo.launch(
142
  server_name="0.0.0.0",
 
94
  return ''.join(parts)
95
 
96
 
97
def predict_next(text, num_candidates):
    """Render an HTML table of the model's top next-token candidates.

    Parameters
    ----------
    text : str
        Prompt whose next token is predicted. Blank input yields a
        placeholder message instead of a table.
    num_candidates : int | float
        How many candidates to display; coerced to int and clamped to
        ``[1, vocab_size]``.

    Returns
    -------
    str
        An HTML fragment (dark-themed table of token / probability /
        log-probability rows) for a Gradio ``HTML`` component.

    NOTE(review): relies on module-level ``tokenizer``, ``model``,
    ``device`` and ``html_lib`` defined elsewhere in this file.
    """
    num_candidates = max(1, int(num_candidates))
    if not text.strip():
        return "<p style='color:#999;text-align:center;padding:40px'>Enter text and click Predict Next.</p>"

    tokens = tokenizer.encode(text)
    # BUGFIX: keep the *last* 512 tokens, not the first. The next token is
    # predicted from the end of the prompt; truncating with tokens[:512]
    # discarded the tail and predicted a continuation of token 512 instead
    # of the actual end of the text.
    if len(tokens) > 512:
        tokens = tokens[-512:]

    with torch.no_grad():
        input_ids = torch.tensor([tokens]).to(device)
        # Logits for the position following the final prompt token.
        logits = model(input_ids).logits[0, -1].cpu()

    probs = torch.softmax(logits, dim=-1)
    # BUGFIX: log_softmax is numerically stable; torch.log(softmax(...))
    # underflows to -inf for very small probabilities.
    log_probs = torch.log_softmax(logits, dim=-1)
    # BUGFIX: topk(k) raises if k exceeds the vocabulary size — clamp.
    num_candidates = min(num_candidates, probs.numel())
    top_p, top_idx = probs.topk(num_candidates)
    top_lp = log_probs[top_idx]

    rows = ""
    for j in range(num_candidates):
        # Escape the decoded token so markup-like tokens render literally.
        tok_text = html_lib.escape(tokenizer.decode([top_idx[j].item()]))
        p = top_p[j].item()
        lp = top_lp[j].item()
        # Bar width tracks probability; floor of 1% keeps tiny bars visible.
        bar_width = max(1, int(p * 100))
        # Hue sweeps red (0) -> green (120) as probability rises.
        hue = min(p * 120, 120)
        rows += f"""<tr>
            <td style="padding:6px 12px;font-weight:600;color:#e0e0e0;white-space:nowrap">{j+1}</td>
            <td style="padding:6px 12px;font-family:monospace;font-size:15px;color:#7fdbca;white-space:nowrap">{tok_text}</td>
            <td style="padding:6px 12px;width:100%">
                <div style="background:hsla({hue},80%,50%,0.25);border-radius:4px;height:22px;width:{bar_width}%;min-width:2px;display:flex;align-items:center;padding-left:6px">
                    <span style="font-size:11px;color:hsl({hue},80%,70%);font-weight:600">{p:.4f}</span>
                </div>
            </td>
            <td style="padding:6px 12px;font-family:monospace;font-size:13px;color:#888;white-space:nowrap">{lp:.4f}</td>
        </tr>"""

    html = f"""<div style="font-family:'Segoe UI',sans-serif;background:#1a1a2e;border-radius:12px;padding:16px;overflow-x:auto">
        <div style="color:#ff79c6;font-size:11px;text-transform:uppercase;letter-spacing:1px;margin-bottom:10px">
        Top {num_candidates} predicted next tokens</div>
        <table style="width:100%;border-collapse:collapse">
        <thead><tr style="border-bottom:1px solid #333">
        <th style="padding:6px 12px;text-align:left;color:#666;font-size:11px">#</th>
        <th style="padding:6px 12px;text-align:left;color:#666;font-size:11px">TOKEN</th>
        <th style="padding:6px 12px;text-align:left;color:#666;font-size:11px">PROBABILITY</th>
        <th style="padding:6px 12px;text-align:left;color:#666;font-size:11px">LOG PROB</th>
        </tr></thead>
        <tbody>{rows}</tbody>
        </table></div>"""
    return html
146
 
147
 
148
  with gr.Blocks() as demo:
 
152
 
153
  with gr.Row():
154
  top_k_input = gr.Number(label="# Alternatives (Analysis)", value=10, minimum=1, maximum=200, step=1)
155
+ num_candidates_input = gr.Number(label="# Next Token Candidates", value=10, minimum=1, maximum=200, step=1)
 
156
 
157
  with gr.Row():
158
  btn_analyze = gr.Button("Analyze", variant="primary")
159
  btn_predict = gr.Button("Predict Next", variant="secondary")
160
 
161
  output_analysis = gr.HTML(label="Analysis Output")
162
+ output_prediction = gr.HTML(label="Predicted Next Tokens")
163
 
164
  btn_analyze.click(fn=analyze_text, inputs=[text_input, top_k_input], outputs=output_analysis)
165
+ btn_predict.click(fn=predict_next, inputs=[text_input, num_candidates_input], outputs=output_prediction)
166
 
167
  demo.launch(
168
  server_name="0.0.0.0",