BenTouss commited on
Commit
2c05295
·
verified ·
1 Parent(s): c932303

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -15
app.py CHANGED
@@ -2,30 +2,28 @@ import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
4
 
5
- # Load model and tokenizer
6
  tokenizer = AutoTokenizer.from_pretrained("BenTouss/mdeberta-eurochef")
7
  model = AutoModelForSequenceClassification.from_pretrained("BenTouss/mdeberta-eurochef")
8
 
9
- def get_labels(text):
10
  inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
11
-
12
- # Get predictions
13
  with torch.no_grad():
14
  outputs = model(**inputs)
15
  probs = torch.sigmoid(outputs.logits)[0]
16
-
17
- # Get predicted labels (threshold = 0.6)
18
- predicted_labels = []
19
  for idx, prob in enumerate(probs):
20
  if prob > 0.6:
21
- label = model.config.id2label[idx]
22
- predicted_labels.append((label, prob.item()))
23
-
24
- return predicted_labels
25
 
 
 
26
 
27
- def greet(name):
28
- return "Hello " + name + "!!"
 
 
 
29
 
30
- demo = gr.Interface(fn=get_labels, inputs="text", outputs="text")
31
- demo.launch()
 
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
4
 
 
5
  tokenizer = AutoTokenizer.from_pretrained("BenTouss/mdeberta-eurochef")
6
  model = AutoModelForSequenceClassification.from_pretrained("BenTouss/mdeberta-eurochef")
7
 
8
+ def get_labels_table(text):
9
  inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
10
+
 
11
  with torch.no_grad():
12
  outputs = model(**inputs)
13
  probs = torch.sigmoid(outputs.logits)[0]
14
+
15
+ rows = []
 
16
  for idx, prob in enumerate(probs):
17
  if prob > 0.6:
18
+ rows.append([model.config.id2label[idx], float(prob)])
 
 
 
19
 
20
+ rows.sort(key=lambda x: x[1], reverse=True)
21
+ return rows
22
 
23
+ demo = gr.Interface(
24
+ fn=get_labels_table,
25
+ inputs=gr.Textbox(lines=3, label="Text"),
26
+ outputs=gr.Dataframe(headers=["label", "score"]),
27
+ )
28
 
29
+ demo.launch()