Keetawan commited on
Commit
ef8449d
·
1 Parent(s): c9b77b5

fix:show all prob test

Browse files
Files changed (1) hide show
  1. app.py +12 -5
app.py CHANGED
@@ -24,6 +24,7 @@ emoji_to_emotion = {
24
  # Function to make predictions
25
  def predict_sentiment(text):
26
  start_time = timer()
 
27
  inputs = model.tokenizer(text, return_tensors="pt", truncation=True, padding=True)
28
  input_ids = inputs["input_ids"]
29
  attention_mask = inputs["attention_mask"]
@@ -32,13 +33,19 @@ def predict_sentiment(text):
32
  outputs = model(input_ids=input_ids, attention_mask=attention_mask)
33
 
34
  logits = outputs.logits
35
- _, predicted_class = torch.max(logits, dim=1)
36
- pred_time = round(timer() - start_time, 5)
37
 
38
  # Map predicted class to emoji
39
- result = emoji_to_emotion[predicted_class.item()]
 
 
 
 
 
 
 
40
 
41
- return result,pred_time
42
 
43
  # Create title, description and article strings
44
  title = "Emoji-aware Sentiment Analysis using Roberta Model"
@@ -50,7 +57,7 @@ article = "Sentiment Analysis, also known as opinion mining, is a branch of Natu
50
  iface = gr.Interface(
51
  fn=predict_sentiment,
52
  inputs="text",
53
- outputs=[gr.Label(num_top_classes=7, label="Predictions"), # what are the outputs?
54
  gr.Number(label="Prediction time (s)")],
55
  title=title,
56
  description=description,
 
24
  # Function to make predictions
25
  def predict_sentiment(text):
26
  start_time = timer()
27
+
28
  inputs = model.tokenizer(text, return_tensors="pt", truncation=True, padding=True)
29
  input_ids = inputs["input_ids"]
30
  attention_mask = inputs["attention_mask"]
 
33
  outputs = model(input_ids=input_ids, attention_mask=attention_mask)
34
 
35
  logits = outputs.logits
36
+ probabilities = torch.nn.functional.softmax(logits, dim=1)
 
37
 
38
  # Map predicted class to emoji
39
+ predicted_class = torch.argmax(logits, dim=1).item()
40
+ result = emoji_to_emotion[predicted_class]
41
+
42
+ # Create a dictionary of class probabilities
43
+ class_probabilities = {emoji_to_emotion[i]: float(probabilities[0, i]) for i in range(len(emoji_to_emotion))}
44
+
45
+ # Calculate prediction time
46
+ pred_time = round(timer() - start_time, 5)
47
 
48
+ return class_probabilities, pred_time
49
 
50
  # Create title, description and article strings
51
  title = "Emoji-aware Sentiment Analysis using Roberta Model"
 
57
  iface = gr.Interface(
58
  fn=predict_sentiment,
59
  inputs="text",
60
+ outputs=[gr.Label(num_top_classes=7, label="Predictions"),
61
  gr.Number(label="Prediction time (s)")],
62
  title=title,
63
  description=description,