kaisex commited on
Commit
f5d94b8
·
verified ·
1 Parent(s): e8719ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -51
app.py CHANGED
@@ -1,10 +1,7 @@
1
  import torch
2
- from PIL import Image
3
  from transformers import (
4
  AutoTokenizer,
5
- AutoModelForSequenceClassification,
6
- ViTForImageClassification,
7
- AutoImageProcessor
8
  )
9
  import gradio as gr
10
 
@@ -20,25 +17,7 @@ text_model.eval()
20
 
21
 
22
  # ---------------------------------------------------------
23
- # 2. Load ViT image model
24
- # ---------------------------------------------------------
25
- # Your model weights are local, but your folder likely does NOT contain the processor.
26
- # So we load a pretrained processor instead.
27
- processor_model_name = "google/vit-base-patch16-224-in21k"
28
-
29
- image_processor = AutoImageProcessor.from_pretrained(processor_model_name)
30
-
31
- vit_path = "./trained_vit_final"
32
-
33
- vit_model = ViTForImageClassification.from_pretrained(
34
- vit_path,
35
- ignore_mismatched_sizes=True
36
- )
37
- vit_model.eval()
38
-
39
-
40
- # ---------------------------------------------------------
41
- # 3. Prediction functions
42
  # ---------------------------------------------------------
43
 
44
  def predict_text(text):
@@ -54,45 +33,20 @@ def predict_text(text):
54
  }
55
 
56
 
57
- def predict_image(img):
58
- inputs = image_processor(images=img, return_tensors="pt")
59
-
60
- with torch.no_grad():
61
- outputs = vit_model(**inputs)
62
- probs = torch.softmax(outputs.logits, dim=1).squeeze().tolist()
63
-
64
- return {
65
- "Real News": float(probs[0]),
66
- "Fake News": float(probs[1])
67
- }
68
-
69
-
70
  # ---------------------------------------------------------
71
- # 4. Gradio UI
72
  # ---------------------------------------------------------
73
 
74
- text_tab = gr.Interface(
75
  fn=predict_text,
76
  inputs=gr.Textbox(lines=4, placeholder="Enter news article or headline..."),
77
  outputs=gr.Label(num_top_classes=2),
78
  title="Text Fake News Detector",
79
  )
80
 
81
- image_tab = gr.Interface(
82
- fn=predict_image,
83
- inputs=gr.Image(type="numpy"),
84
- outputs=gr.Label(num_top_classes=2),
85
- title="Image Fake News Detector (ViT)",
86
- )
87
-
88
- app = gr.TabbedInterface(
89
- [text_tab, image_tab],
90
- ["Text Detection", "Image Detection"]
91
- )
92
-
93
 
94
  # ---------------------------------------------------------
95
- # 5. Launch
96
  # ---------------------------------------------------------
97
 
98
  if __name__ == "__main__":
 
1
  import torch
 
2
  from transformers import (
3
  AutoTokenizer,
4
+ AutoModelForSequenceClassification
 
 
5
  )
6
  import gradio as gr
7
 
 
17
 
18
 
19
  # ---------------------------------------------------------
20
+ # 2. Prediction function (text only)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  # ---------------------------------------------------------
22
 
23
  def predict_text(text):
 
33
  }
34
 
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  # ---------------------------------------------------------
37
+ # 3. Gradio UI (single tab)
38
  # ---------------------------------------------------------
39
 
40
+ app = gr.Interface(
41
  fn=predict_text,
42
  inputs=gr.Textbox(lines=4, placeholder="Enter news article or headline..."),
43
  outputs=gr.Label(num_top_classes=2),
44
  title="Text Fake News Detector",
45
  )
46
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  # ---------------------------------------------------------
49
+ # 4. Launch
50
  # ---------------------------------------------------------
51
 
52
  if __name__ == "__main__":