EXt1 commited on
Commit
5f8c0cf
·
1 Parent(s): 7053c39

add prob level

Browse files
Files changed (1) hide show
  1. app.py +11 -7
app.py CHANGED
@@ -11,19 +11,23 @@ def classify_fake_news(text):
11
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
12
  outputs = model(**inputs)
13
  logits = outputs.logits
14
- predicted_class = logits.argmax().item()
15
-
16
- if predicted_class == 1:
17
- return "Fake News"
18
- else:
19
- return "Real News"
 
 
 
 
20
 
21
  # Create Gradio interface
22
  gr.Interface(
23
  fn=classify_fake_news,
24
  inputs=gr.Textbox(lines=8, placeholder="Enter text here..."),
25
  outputs="text",
26
- title="Thai Fake News Classification using mdeberta-v3",
27
  description="Classifies Thai News as Fake or Real with 91 percent accuracy using a fine-tuned BERT model",
28
  theme="compact"
29
  ).launch()
 
11
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
12
  outputs = model(**inputs)
13
  logits = outputs.logits
14
+ probs = F.softmax(logits, dim=1)
15
+ probs = probs.detach().cpu().numpy()[0]
16
+
17
+ labels = ["Real News", "Fake News"]
18
+ predicted_class = probs.argmax()
19
+
20
+ label = labels[predicted_class]
21
+ prob = float(probs[predicted_class]) * 100
22
+
23
+ return label, f"{prob:.2f}%"
24
 
25
  # Create Gradio interface
26
  gr.Interface(
27
  fn=classify_fake_news,
28
  inputs=gr.Textbox(lines=8, placeholder="Enter text here..."),
29
  outputs="text",
30
+ title="Thai Fake News Classification using mdeberta-v3-base",
31
  description="Classifies Thai News as Fake or Real with 91 percent accuracy using a fine-tuned BERT model",
32
  theme="compact"
33
  ).launch()