KenLumod commited on
Commit
413299e
·
verified ·
1 Parent(s): 434284a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -0
app.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoModel, AutoTokenizer
4
+
5
+ # 1. Model Definition (Must Match Training Architecture)
6
+ class BERTFakeNewsClassifier(torch.nn.Module):
7
+ def __init__(self, base_model):
8
+ super().__init__()
9
+ self.bert = base_model
10
+ self.classifier = torch.nn.Sequential(
11
+ torch.nn.Dropout(0.1),
12
+ torch.nn.Linear(768, 512),
13
+ torch.nn.ReLU(),
14
+ torch.nn.Linear(512, 2),
15
+ torch.nn.LogSoftmax(dim=1)
16
+ )
17
+
18
+ def forward(self, input_ids, attention_mask):
19
+ outputs = self.bert(input_ids, attention_mask=attention_mask)
20
+ pooled = outputs.last_hidden_state[:, 0, 🙂 # Use [CLS] token
21
+ return self.classifier(pooled)
22
+
23
+ # 2. Load Model (Optimized for Inference)
24
+ def load_model_for_inference(model_path="KenLumod/ML-Project-Fake-Real-News-Detector-Final"):
25
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
+
27
+ # Load components
28
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
29
+ base_model = AutoModel.from_pretrained("bert-base-uncased")
30
+ model = BERTFakeNewsClassifier(base_model)
31
+
32
+ # Load trained weights
33
+ state_dict = torch.load(f"{model_path}/pytorch_model.bin", map_location=device)
34
+ model.load_state_dict(state_dict)
35
+ model.to(device).eval()
36
+
37
+ return model, tokenizer, device
38
+
39
+ # 3. Prediction Function (Gradio-Compatible)
40
+ def predict(text, model, tokenizer, device, max_length=128):
41
+ inputs = tokenizer(
42
+ text,
43
+ max_length=max_length,
44
+ padding="max_length",
45
+ truncation=True,
46
+ return_tensors="pt"
47
+ ).to(device)
48
+
49
+ with torch.no_grad():
50
+ logits = model(**inputs)
51
+ probs = torch.exp(logits).cpu().numpy()[0] # Convert log-probs to probabilities
52
+
53
+ return {"REAL": float(probs[0]), "FAKE": float(probs[1])}
54
+
55
+ # 4. Gradio Interface Builder
56
+ def create_gradio_interface():
57
+ model, tokenizer, device = load_model_for_inference()
58
+
59
+ def classify_text(text):
60
+ return predict(text, model, tokenizer, device)
61
+
62
+ return gr.Interface(
63
+ fn=classify_text,
64
+ inputs=gr.Textbox(
65
+ label="News Content",
66
+ placeholder="Paste news article or headline here...",
67
+ lines=3
68
+ ),
69
+ outputs=gr.Label(
70
+ label="Detection Result",
71
+ num_top_classes=2
72
+ ),
73
+ examples=[
74
+ ["BREAKING: Trump arrested at Mar-a-Lago - Secret Service confirms"],
75
+ ["Congress passes bipartisan infrastructure bill after months of negotiation"],
76
+ ["NASA discovers alien city on Mars - photos leaked"]
77
+ ],
78
+ title="Fake News Detector (BERT)",
79
+ description="Classifies news content as REAL or FAKE using fine-tuned BERT",
80
+ allow_flagging="never"
81
+ )
82
+
83
+ # 5. Launch
84
+ if _name_ == "_main_":
85
+ demo = create_gradio_interface()
86
+ demo.launch(
87
+ server_name="0.0.0.0",
88
+ server_port=7860,
89
+ share=False # Set to True for temporary public link
90
+ )