yezdata commited on
Commit
b2cec4f
·
1 Parent(s): e7d33a4

remove top k slider

Browse files
Files changed (1) hide show
  1. main.py +3 -13
main.py CHANGED
@@ -12,7 +12,6 @@ repo_id = "yezdata/EmCoder"
12
 
13
  tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
14
  model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
15
- max_labels = getattr(model.config, "num_labels", 28)
16
 
17
  model.eval()
18
 
@@ -113,7 +112,7 @@ def predict_api(request: PredictRequest):
113
  def health_check():
114
  return {"status": "healthy"}
115
 
116
- def gradio_predict(text, top_n, monte_carlo, n_samples):
117
  request_data = PredictRequest(text=text, monte_carlo=bool(monte_carlo), n_samples=int(n_samples))
118
  response = predict_api(request_data)
119
 
@@ -123,12 +122,10 @@ def gradio_predict(text, top_n, monte_carlo, n_samples):
123
  reverse=True
124
  )
125
 
126
- top_preds = sorted_preds[:int(top_n)]
127
-
128
  standard_rows = []
129
  mc_rows = []
130
 
131
- for label_name, metrics in top_preds:
132
  if monte_carlo:
133
  prob = metrics["mean_probability"]
134
  mc_rows.append([
@@ -168,13 +165,6 @@ with gr.Blocks(title="EmCoder - Probabilistic Emotion Recognition") as ui:
168
  placeholder="Input text for classification...",
169
  lines=3
170
  )
171
- top_n_slider = gr.Slider(
172
- minimum=1,
173
- maximum=max_labels,
174
- value=min(5, max_labels),
175
- step=1,
176
- label="Top Emotions to display"
177
- )
178
  use_mc = gr.Checkbox(label="Use Monte Carlo Dropout (Uncertainty Estimation)", value=False)
179
  mc_samples_slider = gr.Slider(
180
  minimum=5,
@@ -208,7 +198,7 @@ with gr.Blocks(title="EmCoder - Probabilistic Emotion Recognition") as ui:
208
 
209
  submit_btn.click(
210
  fn=gradio_predict,
211
- inputs=[input_text, top_n_slider, use_mc, mc_samples_slider],
212
  outputs=[output_table_standard, output_table_mc]
213
  )
214
 
 
12
 
13
  tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
14
  model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
 
15
 
16
  model.eval()
17
 
 
112
  def health_check():
113
  return {"status": "healthy"}
114
 
115
+ def gradio_predict(text, monte_carlo, n_samples):
116
  request_data = PredictRequest(text=text, monte_carlo=bool(monte_carlo), n_samples=int(n_samples))
117
  response = predict_api(request_data)
118
 
 
122
  reverse=True
123
  )
124
 
 
 
125
  standard_rows = []
126
  mc_rows = []
127
 
128
+ for label_name, metrics in sorted_preds:
129
  if monte_carlo:
130
  prob = metrics["mean_probability"]
131
  mc_rows.append([
 
165
  placeholder="Input text for classification...",
166
  lines=3
167
  )
 
 
 
 
 
 
 
168
  use_mc = gr.Checkbox(label="Use Monte Carlo Dropout (Uncertainty Estimation)", value=False)
169
  mc_samples_slider = gr.Slider(
170
  minimum=5,
 
198
 
199
  submit_btn.click(
200
  fn=gradio_predict,
201
+ inputs=[input_text, use_mc, mc_samples_slider],
202
  outputs=[output_table_standard, output_table_mc]
203
  )
204