Ploypatcha commited on
Commit
22d0c4f
·
verified ·
1 Parent(s): 266535f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -4
app.py CHANGED
@@ -8,14 +8,29 @@ tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
9
  model.eval()
10
 
11
- labels = ["happy", "love", "angry", "sadness", "fear", "trust", "disgust", "surprise", "anticipation", "optimism", "pessimism"]
 
12
 
13
- def predict(text):
14
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(model.device)
15
  with torch.no_grad():
16
  outputs = model(**inputs)
17
  probs = torch.sigmoid(outputs.logits)[0].cpu().numpy()
18
- results = {labels[i]: float(np.round(probs[i], 3)) for i in range(len(labels))}
 
 
 
 
 
 
 
 
 
 
19
  return results
20
 
21
- gr.Interface(fn=predict, inputs=gr.Textbox(label="Enter english comment"), outputs="label").launch()
 
 
 
 
 
8
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
9
  model.eval()
10
 
11
+ labels = ["happy", "love", "angry", "sadness", "fear", "trust",
12
+ "disgust", "surprise", "anticipation", "optimism", "pessimism"]
13
 
14
+ def predict(text, threshold=0.5):
15
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(model.device)
16
  with torch.no_grad():
17
  outputs = model(**inputs)
18
  probs = torch.sigmoid(outputs.logits)[0].cpu().numpy()
19
+
20
+ # เลือกเฉพาะอารมณ์ที่ prob > threshold
21
+ results = {
22
+ labels[i]: float(np.round(probs[i], 3))
23
+ for i in range(len(labels)) if probs[i] > threshold
24
+ }
25
+
26
+ # ถ้าไม่มีอารมณ์ไหนผ่าน threshold
27
+ if not results:
28
+ return {"No dominant emotion": 1.0}
29
+
30
  return results
31
 
32
+ gr.Interface(
33
+ fn=predict,
34
+ inputs=gr.Textbox(label="Enter english comment"),
35
+ outputs=gr.Label(label="Detected Emotions")
36
+ ).launch()