Ploypatcha commited on
Commit
6d97e8d
·
verified ·
1 Parent(s): 22d0c4f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -12
app.py CHANGED
@@ -11,26 +11,20 @@ model.eval()
11
  labels = ["happy", "love", "angry", "sadness", "fear", "trust",
12
  "disgust", "surprise", "anticipation", "optimism", "pessimism"]
13
 
14
- def predict(text, threshold=0.5):
15
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(model.device)
16
  with torch.no_grad():
17
  outputs = model(**inputs)
18
  probs = torch.sigmoid(outputs.logits)[0].cpu().numpy()
19
 
20
- # เลือกเฉพาะอารมณ์ที่ prob > threshold
21
- results = {
22
- labels[i]: float(np.round(probs[i], 3))
23
- for i in range(len(labels)) if probs[i] > threshold
24
- }
25
 
26
- # ถ้าไม่มีอารมณ์ไหนผ่าน threshold
27
- if not results:
28
- return {"No dominant emotion": 1.0}
29
-
30
- return results
31
 
32
  gr.Interface(
33
  fn=predict,
34
  inputs=gr.Textbox(label="Enter english comment"),
35
- outputs=gr.Label(label="Detected Emotions")
36
  ).launch()
 
11
  labels = ["happy", "love", "angry", "sadness", "fear", "trust",
12
  "disgust", "surprise", "anticipation", "optimism", "pessimism"]
13
 
14
+ def predict(text):
15
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(model.device)
16
  with torch.no_grad():
17
  outputs = model(**inputs)
18
  probs = torch.sigmoid(outputs.logits)[0].cpu().numpy()
19
 
20
+ max_idx = int(np.argmax(probs))
21
+ max_label = labels[max_idx]
22
+ max_score = round(probs[max_idx] * 100)
 
 
23
 
24
+ return f"{max_label} ({max_score}%)"
 
 
 
 
25
 
26
  gr.Interface(
27
  fn=predict,
28
  inputs=gr.Textbox(label="Enter english comment"),
29
+ outputs=gr.Text(label="Top Emotion")
30
  ).launch()