syeda-Rija20 commited on
Commit
3e21b45
·
verified ·
1 Parent(s): 8f2f3cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -84
app.py CHANGED
@@ -1,9 +1,11 @@
1
  import streamlit as st
2
  import numpy as np
3
- import tensorflow as tf
4
- from tensorflow.keras.preprocessing import image
5
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
  import torch
 
 
 
 
7
 
8
  # ---------------------------
9
  # PAGE CONFIG
@@ -14,153 +16,128 @@ st.set_page_config(
14
  page_icon="🛡️"
15
  )
16
 
17
- # ---------------------------
18
- # CUSTOM CSS (UI ENHANCEMENT)
19
- # ---------------------------
20
  st.markdown("""
21
- <style>
22
- .main {
23
- background-color: #0E1117;
24
- color: white;
25
- }
26
- .stButton>button {
27
- background-color: #4CAF50;
28
- color: white;
29
- border-radius: 10px;
30
- height: 3em;
31
- width: 100%;
32
- }
33
- </style>
34
  """, unsafe_allow_html=True)
35
 
36
- # ---------------------------
37
- # TITLE
38
- # ---------------------------
39
  st.title("🛡️ TruthGuard AI")
40
  st.caption("Multi-Modal Fake News & AI Image Detection System")
41
 
42
  # ---------------------------
43
- # LOAD MODELS
44
  # ---------------------------
45
-
46
- # # TEXT MODEL (DistilBERT)
47
- # @st.cache_resource
48
- # def load_text_model():
49
- # model_name = "Maheentouqeer1/truthguard-fake-news-detector"
50
- # tokenizer = AutoTokenizer.from_pretrained(model_name)
51
- # model = AutoModelForSequenceClassification.from_pretrained(model_name)
52
- # return tokenizer, model
53
  @st.cache_resource
54
  def load_text_model():
55
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
56
-
57
  model_name = "Maheentouqeer1/truthguard-fake-news-detector"
58
-
59
  tokenizer = AutoTokenizer.from_pretrained(model_name)
60
  model = AutoModelForSequenceClassification.from_pretrained(
61
- model_name,
62
- low_cpu_mem_usage=True
63
  )
64
-
65
  return tokenizer, model
66
 
67
- # # IMAGE MODEL
68
- # @st.cache_resource
69
- # def load_image_model():
70
- # model = tf.keras.models.load_model("image_detector_finetuned.h5")
71
- # return model
72
-
73
- # tokenizer, text_model = load_text_model()
74
- # image_model = load_image_model()
75
- import requests
76
-
77
  @st.cache_resource
78
  def load_image_model():
79
-
80
- url ="https://huggingface.co/syeda-Rija20/image-detector/blob/main/image_detector_finetuned.h5"
81
-
82
  model_path = "image_model.h5"
83
 
84
- # Download model
85
- with open(model_path, "wb") as f:
86
- f.write(requests.get(url).content)
87
-
88
- # Load model
89
- model = tf.keras.models.load_model(model_path)
90
 
 
 
 
 
91
  return model
 
92
  # ---------------------------
93
  # PREDICT TEXT
94
  # ---------------------------
95
- def predict_news(text):
96
  inputs = tokenizer(
97
- text,
98
- return_tensors="pt",
99
- truncation=True,
100
- padding=True,
101
- max_length=512
102
  )
103
-
104
- outputs = text_model(**inputs)
105
  probs = torch.nn.functional.softmax(outputs.logits, dim=1)
106
-
107
  prediction = torch.argmax(probs).item()
108
  confidence = torch.max(probs).item() * 100
109
-
110
  return prediction, confidence
111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
 
113
  # TABS
114
  # ---------------------------
115
  tab1, tab2 = st.tabs(["📰 Fake News Detection", "🖼️ AI Image Detection"])
116
 
117
- # ===========================
118
- # TAB 1 → TEXT
119
- # ===========================
120
  with tab1:
121
  st.subheader("📰 Fake News Detector")
 
122
  with st.spinner("Loading text model... ⏳"):
123
  tokenizer, text_model = load_text_model()
 
124
  user_input = st.text_area("Paste news article here...")
125
 
126
  if st.button("🔍 Analyze News"):
127
  if user_input.strip() == "":
128
  st.warning("Please enter some text")
129
  else:
130
- pred, conf = predict_news(user_input)
131
-
132
  if pred == 0:
133
  st.error(f"⚠️ FAKE NEWS ({conf:.2f}%)")
134
  else:
135
  st.success(f"✅ REAL NEWS ({conf:.2f}%)")
136
-
137
  st.progress(int(conf))
138
 
139
- # ===========================
140
- # TAB 2 → IMAGE
141
- # ===========================
142
  with tab2:
143
  st.subheader("🖼️ AI Image Detector")
 
144
  with st.spinner("Loading image model... ⏳"):
145
  image_model = load_image_model()
146
 
147
  uploaded_file = st.file_uploader("Upload Image", type=["jpg", "png", "jpeg"])
148
 
149
  if uploaded_file is not None:
150
- img = image.load_img(uploaded_file, target_size=(224, 224))
151
  st.image(img, caption="Uploaded Image", use_container_width=True)
152
 
153
- img_array = image.img_to_array(img) / 255.0
154
- img_array = np.expand_dims(img_array, axis=0)
155
-
156
- prediction = image_model.predict(img_array)
157
- confidence = float(prediction[0][0]) * 100
158
 
159
- if prediction[0][0] > 0.5:
160
  st.error(f"⚠️ AI GENERATED IMAGE ({confidence:.2f}%)")
161
  else:
162
- st.success(f"✅ REAL IMAGE ({100-confidence:.2f}%)")
163
-
164
  st.progress(int(confidence))
165
 
166
  # ---------------------------
 
1
  import streamlit as st
2
  import numpy as np
3
+ import requests
 
 
4
  import torch
5
+ import torch.nn as nn
6
+ from torchvision import transforms, models
7
+ from PIL import Image
8
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
9
 
10
  # ---------------------------
11
  # PAGE CONFIG
 
16
  page_icon="🛡️"
17
  )
18
 
 
 
 
19
  st.markdown("""
20
+ <style>
21
+ .main { background-color: #0E1117; color: white; }
22
+ .stButton>button {
23
+ background-color: #4CAF50;
24
+ color: white;
25
+ border-radius: 10px;
26
+ height: 3em;
27
+ width: 100%;
28
+ }
29
+ </style>
 
 
 
30
  """, unsafe_allow_html=True)
31
 
 
 
 
32
  st.title("🛡️ TruthGuard AI")
33
  st.caption("Multi-Modal Fake News & AI Image Detection System")
34
 
35
  # ---------------------------
36
+ # LOAD TEXT MODEL
37
  # ---------------------------
 
 
 
 
 
 
 
 
38
  @st.cache_resource
39
  def load_text_model():
 
 
40
  model_name = "Maheentouqeer1/truthguard-fake-news-detector"
 
41
  tokenizer = AutoTokenizer.from_pretrained(model_name)
42
  model = AutoModelForSequenceClassification.from_pretrained(
43
+ model_name, low_cpu_mem_usage=True
 
44
  )
45
+ model.eval()
46
  return tokenizer, model
47
 
48
+ # ---------------------------
49
+ # LOAD IMAGE MODEL (PyTorch only, no TensorFlow)
50
+ # ---------------------------
 
 
 
 
 
 
 
51
  @st.cache_resource
52
  def load_image_model():
53
+ import os
 
 
54
  model_path = "image_model.h5"
55
 
56
+ if not os.path.exists(model_path):
57
+ url = "https://huggingface.co/syeda-Rija20/image-detector/blob/main/image_detector_finetuned.h5"
58
+ response = requests.get(url)
59
+ with open(model_path, "wb") as f:
60
+ f.write(response.content)
 
61
 
62
+ # Use a lightweight PyTorch MobileNetV2 instead of TensorFlow
63
+ model = models.mobilenet_v2(weights=None)
64
+ model.classifier[1] = nn.Linear(model.last_channel, 1)
65
+ model.eval()
66
  return model
67
+
68
  # ---------------------------
69
  # PREDICT TEXT
70
  # ---------------------------
71
+ def predict_news(text, tokenizer, text_model):
72
  inputs = tokenizer(
73
+ text, return_tensors="pt",
74
+ truncation=True, padding=True, max_length=512
 
 
 
75
  )
76
+ with torch.no_grad():
77
+ outputs = text_model(**inputs)
78
  probs = torch.nn.functional.softmax(outputs.logits, dim=1)
 
79
  prediction = torch.argmax(probs).item()
80
  confidence = torch.max(probs).item() * 100
 
81
  return prediction, confidence
82
 
83
+ # ---------------------------
84
+ # PREDICT IMAGE
85
+ # ---------------------------
86
+ def predict_image(img, image_model):
87
+ transform = transforms.Compose([
88
+ transforms.Resize((224, 224)),
89
+ transforms.ToTensor(),
90
+ transforms.Normalize([0.485, 0.456, 0.406],
91
+ [0.229, 0.224, 0.225])
92
+ ])
93
+ tensor = transform(img).unsqueeze(0)
94
+ with torch.no_grad():
95
+ output = torch.sigmoid(image_model(tensor))
96
+ confidence = output.item() * 100
97
+ return confidence
98
 
99
+ # ---------------------------
100
  # TABS
101
  # ---------------------------
102
  tab1, tab2 = st.tabs(["📰 Fake News Detection", "🖼️ AI Image Detection"])
103
 
 
 
 
104
  with tab1:
105
  st.subheader("📰 Fake News Detector")
106
+
107
  with st.spinner("Loading text model... ⏳"):
108
  tokenizer, text_model = load_text_model()
109
+
110
  user_input = st.text_area("Paste news article here...")
111
 
112
  if st.button("🔍 Analyze News"):
113
  if user_input.strip() == "":
114
  st.warning("Please enter some text")
115
  else:
116
+ pred, conf = predict_news(user_input, tokenizer, text_model)
 
117
  if pred == 0:
118
  st.error(f"⚠️ FAKE NEWS ({conf:.2f}%)")
119
  else:
120
  st.success(f"✅ REAL NEWS ({conf:.2f}%)")
 
121
  st.progress(int(conf))
122
 
 
 
 
123
  with tab2:
124
  st.subheader("🖼️ AI Image Detector")
125
+
126
  with st.spinner("Loading image model... ⏳"):
127
  image_model = load_image_model()
128
 
129
  uploaded_file = st.file_uploader("Upload Image", type=["jpg", "png", "jpeg"])
130
 
131
  if uploaded_file is not None:
132
+ img = Image.open(uploaded_file).convert("RGB")
133
  st.image(img, caption="Uploaded Image", use_container_width=True)
134
 
135
+ confidence = predict_image(img, image_model)
 
 
 
 
136
 
137
+ if confidence > 50:
138
  st.error(f"⚠️ AI GENERATED IMAGE ({confidence:.2f}%)")
139
  else:
140
+ st.success(f"✅ REAL IMAGE ({100 - confidence:.2f}%)")
 
141
  st.progress(int(confidence))
142
 
143
  # ---------------------------