kaisex commited on
Commit
bd4670b
·
verified ·
1 Parent(s): c7b339e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -14
app.py CHANGED
@@ -1,30 +1,91 @@
1
  import torch
 
 
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import gradio as gr
4
 
5
- # Load tokenizer and model from current directory
6
- model_path = "./DeBERTa" # assuming all files are in the root
 
 
7
  tokenizer = AutoTokenizer.from_pretrained(model_path)
8
- model = AutoModelForSequenceClassification.from_pretrained(model_path)
9
- model.eval()
10
 
11
- # Prediction function
12
- def predict(text):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
14
  with torch.no_grad():
15
- outputs = model(**inputs)
16
  probs = torch.softmax(outputs.logits, dim=1).squeeze().tolist()
17
-
18
  return {"Real News": probs[0], "Fake News": probs[1]}
19
 
20
- # Gradio Interface
21
- iface = gr.Interface(
22
- fn=predict,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  inputs=gr.Textbox(lines=4, placeholder="Enter news article or headline..."),
24
  outputs=gr.Label(num_top_classes=2),
25
- title="Fake News Detection (DeBERTa)",
26
- description="A model trained using DeBERTa to detect fake news."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  )
28
 
29
  if __name__ == "__main__":
30
- iface.launch()
 
1
  import torch
2
+ from torchvision import transforms
3
+ from PIL import Image
4
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
  import gradio as gr
6
 
7
+ # ---------------------------------------------------------
8
+ # 1. Load DeBERTa text model
9
+ # ---------------------------------------------------------
10
+ model_path = "./DeBERTa"
11
  tokenizer = AutoTokenizer.from_pretrained(model_path)
12
+ text_model = AutoModelForSequenceClassification.from_pretrained(model_path)
13
+ text_model.eval()
14
 
15
+ # ---------------------------------------------------------
16
+ # 2. Load ViT image model
17
+ # ---------------------------------------------------------
18
+ class ViTModel(torch.nn.Module):
19
+ def __init__(self, base_model):
20
+ super().__init__()
21
+ self.model = base_model
22
+
23
+ def forward(self, x):
24
+ return self.model(x)
25
+
26
+ # Load your ViT model weights
27
+ vit_model = torch.load("trained_vit_final.pth", map_location=torch.device("cpu"))
28
+ vit_model.eval()
29
+
30
+ # Image preprocessing (modify if needed)
31
+ image_transforms = transforms.Compose([
32
+ transforms.Resize((224, 224)),
33
+ transforms.ToTensor(),
34
+ transforms.Normalize(
35
+ mean=[0.485, 0.456, 0.406],
36
+ std=[0.229, 0.224, 0.225]
37
+ )
38
+ ])
39
+
40
+ # ---------------------------------------------------------
41
+ # 3. Prediction functions
42
+ # ---------------------------------------------------------
43
+
44
+ def predict_text(text):
45
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
46
  with torch.no_grad():
47
+ outputs = text_model(**inputs)
48
  probs = torch.softmax(outputs.logits, dim=1).squeeze().tolist()
 
49
  return {"Real News": probs[0], "Fake News": probs[1]}
50
 
51
+ def predict_image(img):
52
+ img = Image.fromarray(img)
53
+ img_tensor = image_transforms(img).unsqueeze(0)
54
+
55
+ with torch.no_grad():
56
+ logits = vit_model(img_tensor)
57
+ probs = torch.softmax(logits, dim=1).squeeze().tolist()
58
+
59
+ # EDIT LABELS depending on your ViT classes
60
+ return {
61
+ "Real News": probs[0],
62
+ "Fake News": probs[1]
63
+ }
64
+
65
+ # ---------------------------------------------------------
66
+ # 4. Create Gradio tabs: Text and Image
67
+ # ---------------------------------------------------------
68
+ text_tab = gr.Interface(
69
+ fn=predict_text,
70
  inputs=gr.Textbox(lines=4, placeholder="Enter news article or headline..."),
71
  outputs=gr.Label(num_top_classes=2),
72
+ title="Text Fake News Detector",
73
+ )
74
+
75
+ image_tab = gr.Interface(
76
+ fn=predict_image,
77
+ inputs=gr.Image(type="numpy"),
78
+ outputs=gr.Label(num_top_classes=2),
79
+ title="Image Fake News Detector (ViT)",
80
+ )
81
+
82
+ # ---------------------------------------------------------
83
+ # 5. Combine into one app
84
+ # ---------------------------------------------------------
85
+ app = gr.TabbedInterface(
86
+ [text_tab, image_tab],
87
+ ["Text Detection", "Image Detection"]
88
  )
89
 
90
  if __name__ == "__main__":
91
+ app.launch()