import os
from urllib.parse import urlparse

import gradio as gr
import requests
import torch
import torch.nn.functional as F
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
# NewsAPI credential taken from the environment. It defaults to an empty
# string, which check_news_exists() treats as "API key not set".
NEWS_API_KEY = os.getenv("NEWS_API_KEY", "")

# Domains of the 15 outlets whose coverage is treated as authoritative.
TRUSTED_NEWS_SOURCES = [
    "reuters.com", "apnews.com", "bbc.com",
    "bbc.co.uk", "theguardian.com", "nytimes.com",
    "washingtonpost.com", "bloomberg.com", "cnn.com",
    "aljazeera.com", "forbes.com", "ft.com",
    "economist.com", "time.com", "nbcnews.com",
]
def _load_classifier(repo_id):
    """Load the tokenizer and sequence-classification model for *repo_id*.

    The tokenizer is loaded first, then the model, which is switched to
    evaluation mode before both are returned as ``(tokenizer, model)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModelForSequenceClassification.from_pretrained(repo_id)
    model.eval()
    return tokenizer, model


print("Loading TRAK models...")

# Model 1 - TRAK Fake Detection BERT (used through a text-classification pipeline)
clf1 = pipeline("text-classification", model="abd8433/TRAK-fake-detection-bert")

# Model 2 - TRAK Fake Detection Distilroberta
tokenizer2, model2 = _load_classifier("abd8433/TRAK-fake-detection-Distilroberta")

# Model 3 - TRAK Fake Detection TinyBERT
tokenizer3, model3 = _load_classifier("abd8433/TRAK-fake-detection-tinybert")

# Model 4 - TRAK Fake Detection RoBERTa
tokenizer4, model4 = _load_classifier("abd8433/TRAK-fake-Detection-roberta")

# Model 5 - TRAK RoBERTa T Fake Detection
tokenizer5, model5 = _load_classifier("abd8433/TRAK-Roberta-t-fake-detection")

print("All TRAK models loaded!")
def get_fake_score_model1(text):
    """Return model 1's fake score for *text* as a percentage.

    The text is passed to the ``clf1`` pipeline with truncation enabled and a
    maximum length of 512, and the first result is used. Its score is taken
    as-is when the label is ``LABEL_0`` and inverted (``1 - score``)
    otherwise; the value is scaled to 0-100 and rounded to two decimals.
    """
    top = clf1(text, truncation=True, max_length=512)[0]
    confidence = top["score"]
    # NOTE(review): LABEL_0 is treated as the "fake" class here, while the
    # other scorers in this file read index 1 - confirm against the model card.
    fake_probability = confidence if top["label"] == "LABEL_0" else 1 - confidence
    return round(fake_probability * 100, 2)
def get_fake_score_model2(text):
    """Return model 2's fake score for *text* as a percentage.

    Tokenizes with truncation and padding to a fixed length of 512, runs
    ``model2`` without gradient tracking, applies softmax over the class
    dimension and reports the probability at index 1, scaled to 0-100 and
    rounded to two decimals.
    """
    # NOTE(review): padding="max_length" pads every input to 512 tokens; the
    # other scorers in this file do not pad - confirm it is intentional.
    batch = tokenizer2(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=512,
    )
    with torch.no_grad():
        output = model2(**batch)
    class_probs = torch.softmax(output["logits"], dim=1)[0]
    # presumably index 1 is the "fake" class - verify against the model config
    return round(float(class_probs[1]) * 100, 2)
def get_fake_score_model3(text):
    """Return model 3's fake score for *text* as a percentage.

    Tokenizes with truncation at a maximum length of 512, runs ``model3``
    without gradient tracking, applies softmax over the class dimension and
    reports the probability at index 1, scaled to 0-100 and rounded to two
    decimals.
    """
    batch = tokenizer3(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        output = model3(**batch)
    class_probs = torch.softmax(output.logits, dim=1)[0]
    # presumably index 1 is the "fake" class - verify against the model config
    return round(float(class_probs[1]) * 100, 2)
def get_fake_score_model4(text):
    """Return model 4's fake score for *text* as a percentage.

    Tokenizes with truncation at a maximum length of 512, runs ``model4``
    without gradient tracking and applies softmax over the class dimension.
    The class to report is looked up in ``model4.config.id2label``: the last
    entry whose label contains "fake" (case-insensitive) is used, falling
    back to index 1 when no label matches. The probability is scaled to
    0-100 and rounded to two decimals.
    """
    batch = tokenizer4(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        output = model4(**batch)
    class_probs = torch.softmax(output.logits, dim=1)[0]

    fake_ids = [
        class_id
        for class_id, name in model4.config.id2label.items()
        if "fake" in name.lower()
    ]
    # Last match wins; index 1 is the fallback when no label mentions "fake".
    target = fake_ids[-1] if fake_ids else 1
    return round(float(class_probs[target]) * 100, 2)
def get_fake_score_model5(text):
    """Return model 5's fake score for *text* as a percentage.

    Tokenizes with truncation at a maximum length of 512, runs ``model5``
    without gradient tracking and applies softmax over the class dimension.
    The reported class is the last entry of ``model5.config.id2label`` whose
    label contains "fake" (case-insensitive), or index 1 when none does. The
    probability is scaled to 0-100 and rounded to two decimals.
    """
    tokens = tokenizer5(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        result = model5(**tokens)
    distribution = torch.softmax(result.logits, dim=1)[0]

    matching_ids = [
        label_id
        for label_id, label_name in model5.config.id2label.items()
        if "fake" in label_name.lower()
    ]
    # Keep the last matching id; default to index 1 when nothing matches.
    chosen = matching_ids[-1] if matching_ids else 1
    return round(float(distribution[chosen]) * 100, 2)
def _is_trusted_url(url):
    """Return True when *url*'s hostname is a trusted domain or a subdomain of one."""
    try:
        host = (urlparse(url).hostname or "").lower()
    except ValueError:
        # urlparse rejects some malformed URLs (e.g. broken IPv6 literals).
        return False
    return any(
        host == domain or host.endswith("." + domain)
        for domain in TRUSTED_NEWS_SOURCES
    )


def _source_name(article):
    """Return the article's source name, or a placeholder when it is missing."""
    source = article.get("source") if isinstance(article, dict) else None
    name = source.get("name") if isinstance(source, dict) else None
    return name or "Unknown source"


def check_news_exists(text):
    """Look up the start of *text* on NewsAPI and report whether it is covered.

    The first 80 characters of *text* are sent as the search query to the
    NewsAPI ``/v2/everything`` endpoint (English, top 5 by relevancy).

    Returns:
        A ``(exists, info, is_trusted)`` tuple:

        * ``exists`` - True when at least one article was returned.
        * ``info`` - the source name of the matching article, or a status
          message ("API key not set", "Not found in news",
          "News check failed") when ``exists`` is False.
        * ``is_trusted`` - True when one of the returned articles is hosted
          on a domain listed in ``TRUSTED_NEWS_SOURCES``.

    The check is best-effort: network, HTTP-level and JSON errors are
    reported as ``(False, "News check failed", False)`` instead of raising.
    """
    if not NEWS_API_KEY:
        return False, "API key not set", False

    query = (text or "")[:80].strip()
    if not query:
        # An empty query is rejected by the API; skip the network round trip.
        return False, "Not found in news", False

    try:
        response = requests.get(
            "https://newsapi.org/v2/everything",
            params={
                "q": query,
                "apiKey": NEWS_API_KEY,
                "pageSize": 5,
                "language": "en",
                "sortBy": "relevancy",
            },
            timeout=5,
        )
        data = response.json()
    except (requests.RequestException, ValueError):
        # ValueError covers a body that is not valid JSON.
        return False, "News check failed", False

    # An error payload (bad key, rate limit, ...) means the lookup did not
    # run, which is different from "no coverage found".
    if not isinstance(data, dict) or data.get("status") == "error":
        return False, "News check failed", False

    articles = [a for a in (data.get("articles") or []) if isinstance(a, dict)]
    if (data.get("totalResults") or 0) <= 0 or not articles:
        return False, "Not found in news", False

    # Match on the parsed hostname, not on a substring of the URL: a plain
    # substring test lets "microsoft.com" pass for "ft.com" and lets any URL
    # whose path mentions a trusted domain pass as trusted.
    for article in articles:
        if _is_trusted_url(article.get("url") or ""):
            return True, _source_name(article), True  # found in trusted source

    # Found in news, but none of the results is from a trusted domain.
    return True, _source_name(articles[0]), False
def detect(text):
    """Classify *text* as real, suspicious or fake news.

    Combines a NewsAPI cross-check with the five TRAK model scores.

    Returns:
        A 5-tuple of strings for the Gradio outputs:
        ``(verdict, real_confidence, fake_confidence, news_info, debug)``.

    Decision rules:
        * If the story is found on a trusted source, the verdict is forced to
          REAL with 100% confidence.
        * Otherwise the five fake scores (0-100) are averaged; when the story
          is found in any news source the average is reduced by 30%.
        * Each model with a score of 50 or more casts one "fake" vote. The
          verdict is FAKE with at least 3 votes and an average of 65 or more,
          SUSPICIOUS with at least 3 votes and an average of 45 or more or
          with exactly 2 votes and an average of 55 or more, and REAL
          otherwise.
    """
    # Run the news lookup first: its result does not depend on the models,
    # and a trusted-source hit makes the five model inferences unnecessary.
    exists_in_news, news_source, is_trusted = check_news_exists(text)

    # If found in TOP 15 trusted sources → force REAL
    if is_trusted:
        return (
            "✅ REAL",
            "100%",
            "0%",
            f"✅ Verified in trusted source: {news_source}",
            "Trusted source override applied — skipped model voting",
        )

    score1 = get_fake_score_model1(text)
    score2 = get_fake_score_model2(text)
    score3 = get_fake_score_model3(text)
    score4 = get_fake_score_model4(text)
    score5 = get_fake_score_model5(text)
    scores = (score1, score2, score3, score4, score5)

    avg_fake = round(sum(scores) / len(scores), 2)

    # If found in any news → reduce fake score by 30%
    if exists_in_news:
        avg_fake = round(max(0, avg_fake * 0.7), 2)

    avg_real = round(100 - avg_fake, 2)

    # One "fake" vote per model whose score reaches 50.
    votes_fake = sum(1 for score in scores if score >= 50)

    if votes_fake >= 3 and avg_fake >= 65:
        verdict = "❌ FAKE"
    elif votes_fake >= 3 and avg_fake >= 45:
        verdict = "⚠️ SUSPICIOUS"
    elif votes_fake == 2 and avg_fake >= 55:
        verdict = "⚠️ SUSPICIOUS"
    else:
        verdict = "✅ REAL"

    news_info = (
        f"Found in: {news_source}"
        if exists_in_news
        else "Not found in real news sources"
    )
    debug = (
        f"M1:{score1} M2:{score2} M3:{score3} M4:{score4} M5:{score5} "
        f"Votes:{votes_fake}/5"
    )
    return verdict, f"{avg_real}%", f"{avg_fake}%", news_info, debug
# Input widget: a ten-line text area for the article body.
inputs = gr.Textbox(
    label="News Article",
    lines=10,
    placeholder="Paste news article here...",
)

# Output widgets, in the order of the tuple returned by detect().
out1, out2, out3, out4, out5 = (
    gr.Textbox(label=caption)
    for caption in (
        "Verdict",
        "Real Confidence",
        "Fake Confidence",
        "News Verification",
        "Debug",
    )
)

demo = gr.Interface(
    fn=detect,
    inputs=inputs,
    outputs=[out1, out2, out3, out4, out5],
    title="TRAK Fake News Detector",
    description="Uses 5 TRAK AI models plus NewsAPI verification against top 15 trusted news sources.",
)

# Start the web UI.
demo.launch()