|
|
import gradio as gr |
|
|
import torch |
|
|
from transformers import AutoTokenizer, AutoModel |
|
|
from PIL import Image |
|
|
from torchvision import transforms |
|
|
import json |
|
|
from torch import nn |
|
|
from typing import Literal |
|
|
import os |
|
|
import logging |
|
|
import traceback |
|
|
import warnings |
|
|
import time |
|
|
import signal |
|
|
import sys |
|
|
|
|
|
# Suppress every warning category process-wide.
warnings.filterwarnings("ignore")

# Root logger: INFO and above, formatted as "<time> - <level> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Model ids handed to the Hugging Face Auto* loaders further down the file.
LIGHTWEIGHT_TEXT_MODEL = "distilbert-base-uncased"
LIGHTWEIGHT_IMAGE_MODEL = "microsoft/resnet-18"

# Publish the bind address and port through the process environment.
os.environ.update(
    GRADIO_SERVER_NAME="0.0.0.0",
    GRADIO_SERVER_PORT="7860",
)
|
|
|
|
|
class MultimodalClassifier(nn.Module):
    """Classify an (image, text) pair by fusing two pretrained encoders.

    Each modality is encoded by a Hugging Face ``AutoModel``, projected to
    ``projection_dim`` features, fused, passed through a small MLP and
    finally mapped to ``num_classes`` logits.

    Args:
        text_encoder_id_or_path: Model id or path given to
            ``AutoModel.from_pretrained`` for the text encoder. Its config
            must expose ``hidden_size``.
        image_encoder_id_or_path: Model id or path given to
            ``AutoModel.from_pretrained`` for the image encoder.
        projection_dim: Width of both projections and of the fusion MLP.
        fusion_method: ``"concat"`` concatenates the two projected vectors;
            every other value multiplies them element-wise, so ``"align"``
            and ``"cosine_similarity"`` currently behave identically.
        proj_dropout: Dropout probability applied after each projection.
        fusion_dropout: Dropout probability used twice in the fusion MLP.
        num_classes: Number of output logits.
    """

    # Input width used for the image projection when the encoder config does
    # not advertise its stage widths (the value that was previously hard-coded).
    _DEFAULT_IMAGE_FEATURE_DIM = 512

    def __init__(
        self,
        text_encoder_id_or_path: str,
        image_encoder_id_or_path: str,
        projection_dim: int,
        fusion_method: Literal["concat", "align", "cosine_similarity"] = "concat",
        proj_dropout: float = 0.1,
        fusion_dropout: float = 0.1,
        num_classes: int = 1,
    ) -> None:
        super().__init__()

        self.fusion_method = fusion_method
        self.projection_dim = projection_dim
        self.num_classes = num_classes

        # Text branch: pretrained encoder followed by a linear projection.
        self.text_encoder = AutoModel.from_pretrained(text_encoder_id_or_path)
        self.text_projection = nn.Sequential(
            nn.Linear(self.text_encoder.config.hidden_size, self.projection_dim),
            nn.Dropout(proj_dropout),
        )

        # NOTE(review): trust_remote_code=True lets the model repository run
        # arbitrary Python at load time; confirm it is needed for the image
        # checkpoints used with this class.
        self.image_encoder = AutoModel.from_pretrained(image_encoder_id_or_path, trust_remote_code=True)
        # Attach an identity module under the name ``classifier``.
        # NOTE(review): presumably a no-op for headless AutoModel backbones -- confirm.
        self.image_encoder.classifier = nn.Identity()

        # Size the projection from the encoder instead of assuming 512 channels.
        image_feature_dim = self._image_feature_dim(
            getattr(self.image_encoder, "config", None)
        )
        self.image_projection = nn.Sequential(
            nn.Linear(image_feature_dim, self.projection_dim),
            nn.Dropout(proj_dropout),
        )

        # Concatenation doubles the width entering the fusion MLP.
        fusion_input_dim = self.projection_dim * 2 if fusion_method == "concat" else self.projection_dim
        self.fusion_layer = nn.Sequential(
            nn.Dropout(fusion_dropout),
            nn.Linear(fusion_input_dim, self.projection_dim),
            nn.GELU(),
            nn.Dropout(fusion_dropout),
        )

        self.classifier = nn.Linear(self.projection_dim, self.num_classes)

    @staticmethod
    def _image_feature_dim(encoder_config) -> int:
        """Return the input width for the image projection.

        Uses the last entry of ``encoder_config.hidden_sizes`` when that
        attribute exists and is non-empty; otherwise falls back to 512.
        """
        hidden_sizes = getattr(encoder_config, "hidden_sizes", None)
        if hidden_sizes:
            return int(hidden_sizes[-1])
        return MultimodalClassifier._DEFAULT_IMAGE_FEATURE_DIM

    def forward(self, pixel_values: torch.Tensor, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Return the classification logits for a batch.

        Args:
            pixel_values: Image batch passed to the image encoder.
            input_ids: Token ids passed to the text encoder.
            attention_mask: Attention mask passed to the text encoder.

        Returns:
            The output of the final linear layer (``num_classes`` features
            in the last dimension).
        """
        full_text_features = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask, return_dict=True).last_hidden_state
        # Keep only the first token position as the sequence summary.
        full_text_features = full_text_features[:, 0, :]
        full_text_features = self.text_projection(full_text_features)

        resnet_image_features = self.image_encoder(pixel_values=pixel_values).last_hidden_state
        # Average over the last two axes.
        # NOTE(review): assumes these are the spatial axes of a
        # (batch, channels, H, W) feature map -- confirm for non-ResNet encoders.
        resnet_image_features = resnet_image_features.mean(dim=[-2, -1])
        resnet_image_features = self.image_projection(resnet_image_features)

        if self.fusion_method == "concat":
            fused_features = torch.cat([full_text_features, resnet_image_features], dim=-1)
        else:
            # Every non-"concat" method is an element-wise product.
            fused_features = full_text_features * resnet_image_features

        fused_features = self.fusion_layer(fused_features)
        classification_output = self.classifier(fused_features)
        return classification_output
|
|
|
|
|
def load_model():
    """Build the classifier described by ``config.json`` and load its weights.

    The encoders are always the module-level lightweight models; the
    ``text_encoder_id_or_path`` entry of the config is required to be present
    but is not used to pick the encoder. When ``model_weights.pth`` exists in
    the working directory it is loaded non-strictly on CPU; otherwise the
    model is returned with its freshly initialised head and a warning is
    logged.

    Returns:
        The constructed ``MultimodalClassifier``.

    Raises:
        FileNotFoundError: ``config.json`` does not exist.
        ValueError: ``config.json`` is not valid JSON.
        KeyError: A required configuration key is missing.
        Exception: Any other failure is logged and re-raised unchanged.
    """
    try:
        if not os.path.exists("config.json"):
            raise FileNotFoundError("config.json file not found. Please ensure it exists in the current directory.")

        logger.info("Loading configuration from config.json...")
        # Read as UTF-8 explicitly so parsing does not depend on the platform locale.
        with open("config.json", "r", encoding="utf-8") as f:
            config = json.load(f)

        required_keys = ["text_encoder_id_or_path", "projection_dim", "fusion_method",
                         "proj_dropout", "fusion_dropout", "num_classes"]
        for key in required_keys:
            if key not in config:
                raise KeyError(f"Missing required key '{key}' in config.json")

        logger.info("Initializing MultimodalClassifier with lightweight models...")
        model = MultimodalClassifier(
            text_encoder_id_or_path=LIGHTWEIGHT_TEXT_MODEL,
            image_encoder_id_or_path=LIGHTWEIGHT_IMAGE_MODEL,
            projection_dim=config["projection_dim"],
            fusion_method=config["fusion_method"],
            proj_dropout=config["proj_dropout"],
            fusion_dropout=config["fusion_dropout"],
            num_classes=config["num_classes"],
        )

        if os.path.exists("model_weights.pth"):
            logger.info("Loading model weights...")
            # NOTE(review): torch.load unpickles the file; only load checkpoints
            # from a trusted source (consider weights_only=True where supported).
            checkpoint = torch.load("model_weights.pth", map_location=torch.device('cpu'))
            load_result = model.load_state_dict(checkpoint, strict=False)
            # strict=False tolerates mismatches; report them so a checkpoint
            # that does not fit the architecture is not silently half-applied.
            if load_result.missing_keys:
                logger.warning(
                    "Checkpoint is missing %d parameter(s), left at their initial values: %s",
                    len(load_result.missing_keys),
                    load_result.missing_keys,
                )
            if load_result.unexpected_keys:
                logger.warning(
                    "Checkpoint contains %d unexpected parameter(s), ignored: %s",
                    len(load_result.unexpected_keys),
                    load_result.unexpected_keys,
                )
        else:
            logger.warning("model_weights.pth not found. Using untrained model for demonstration.")
            logger.warning("For best results, please provide the trained model weights.")

        logger.info("Model loaded successfully!")
        return model

    except FileNotFoundError as e:
        logger.error(f"File error: {e}")
        raise
    except json.JSONDecodeError as e:
        logger.error(f"JSON parsing error in config.json: {e}")
        # Chain the decoder error so the original position info stays in the traceback.
        raise ValueError(f"Invalid JSON format in config.json: {e}") from e
    except KeyError as e:
        logger.error(f"Configuration error: {e}")
        raise
    except Exception as e:
        # logger.exception records the message together with the traceback.
        logger.exception(f"Unexpected error loading model: {e}")
        raise
|
|
|
|
|
def initialize_components():
    """Populate the module-level ``model`` and ``text_tokenizer``.

    Loads the classifier, switches it to eval mode, then loads the tokenizer
    for the lightweight text model. Failures are logged, not raised.

    Returns:
        ``True`` when both components were created, ``False`` otherwise.
    """
    global model, text_tokenizer

    try:
        logger.info("Initializing model and tokenizer...")
        # The global is bound before eval() so a later failure leaves it set.
        model = load_model()
        model.eval()

        logger.info("Loading DistilBERT tokenizer...")
        text_tokenizer = AutoTokenizer.from_pretrained(LIGHTWEIGHT_TEXT_MODEL)
    except Exception as init_error:
        logger.error(f"Failed to initialize components: {init_error}")
        logger.error(traceback.format_exc())
        return False
    else:
        logger.info("All components initialized successfully!")
        return True
|
|
|
|
|
# Module-level singletons; initialize_components() rebinds them via ``global``.
model = None
text_tokenizer = None

# Runs at import time; False means loading the model or tokenizer failed.
initialization_success = initialize_components()

# Image preprocessing: resize to 224x224, convert to a tensor, then normalise
# each channel with the mean/std values below.
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
image_transform = transforms.Compose(_preprocessing_steps)
|
|
|
|
|
def validate_inputs(image, text):
    """Check that the application and the user inputs are usable.

    Args:
        image: The uploaded image; expected to offer the PIL ``mode`` /
            ``convert`` API.
        text: The news text entered by the user.

    Returns:
        None. The inputs are only inspected, never modified.

    Raises:
        RuntimeError: Initialisation failed or the model/tokenizer is unset.
        ValueError: The image is missing or unusable, or the text is empty,
            shorter than 10 characters after stripping, or longer than 2000
            characters.
    """
    if not initialization_success:
        raise RuntimeError("Model initialization failed. Please check the logs for details.")

    if model is None or text_tokenizer is None:
        raise RuntimeError("Model or tokenizer not loaded. Please restart the application.")

    if image is None:
        raise ValueError("Please upload an image.")

    # Strip once and reuse the result for both the emptiness and length checks.
    stripped_text = text.strip() if text else ""
    if not stripped_text:
        raise ValueError("Please enter some text for analysis.")

    if len(stripped_text) < 10:
        raise ValueError("Text is too short. Please provide at least 10 characters.")

    # The upper bound is measured on the raw text, whitespace included.
    if len(text) > 2000:
        raise ValueError("Text is too long. Please provide text with less than 2000 characters.")

    try:
        # Probe only: the converted image is discarded on purpose, because the
        # caller performs its own conversion. This verifies that an image in an
        # unusual mode can be converted at all.
        if image.mode not in ('RGB', 'RGBA', 'L'):
            image.convert('RGB')
    except Exception as e:
        raise ValueError(f"Invalid image format: {e}") from e
|
|
|
|
|
def simple_fake_news_detector(text: str) -> str:
    """Label *text* by counting indicator phrases (rule-based fallback).

    The text is lower-cased and searched for two fixed phrase lists. Each
    phrase counts once if it occurs as a substring. The label follows the
    larger count; equal counts (including zero hits) yield "Uncertain".
    """
    fake_phrases = (
        "breaking news", "shocking", "you won't believe", "doctors hate",
        "one weird trick", "click here", "urgent", "exclusive", "leaked",
        "nazi salute", "hail trump", "take our country back", "step on toes",
    )
    real_phrases = (
        "according to", "reported by", "official statement", "confirmed",
        "study shows", "research indicates", "data reveals", "analysis",
    )

    lowered = text.lower()
    # Summing booleans counts how many phrases occur in the text.
    fake_hits = sum(phrase in lowered for phrase in fake_phrases)
    real_hits = sum(phrase in lowered for phrase in real_phrases)

    if fake_hits == real_hits:
        return "Uncertain (Rule-based)"
    if fake_hits > real_hits:
        return "Fake News (Rule-based)"
    return "Real News (Rule-based)"
|
|
|
|
|
def predict(image: Image.Image, text: str) -> str:
    """Classify an image/text pair as fake or real news.

    Flow: validate the inputs; if ``model_weights.pth`` is absent, answer
    with the rule-based detector; otherwise tokenize the text, preprocess the
    image and run the neural model. Any exception raised during inference
    falls back to the rule-based detector.

    Args:
        image: Uploaded PIL image.
        text: News text to analyse.

    Returns:
        ``"Fake News"`` / ``"Real News"`` from the model, a
        ``"... (Rule-based)"`` label from the fallback, or a string starting
        with ``"Error: "``. This function never raises.
    """
    try:
        logger.info("Starting prediction...")

        validate_inputs(image, text)

        # NOTE(review): this checks the file on every request, while the
        # weights themselves are presumably loaded only once elsewhere --
        # confirm that a file added later cannot route requests to an
        # untrained model.
        if not os.path.exists("model_weights.pth"):
            logger.info("Using rule-based fallback detection...")
            return simple_fake_news_detector(text)

        logger.info("Processing text input...")
        text_inputs = text_tokenizer(
            text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=512,
        )

        logger.info("Processing image input...")
        try:
            if image.mode != 'RGB':
                image = image.convert('RGB')
            # unsqueeze(0) adds the leading batch axis.
            image_input = image_transform(image).unsqueeze(0)
        except Exception as e:
            # Chain the cause so the underlying failure stays in the traceback.
            raise ValueError(f"Failed to process image: {e}") from e

        logger.info("Running model inference...")
        with torch.no_grad():
            try:
                classification_output = model(
                    pixel_values=image_input,
                    input_ids=text_inputs["input_ids"],
                    attention_mask=text_inputs["attention_mask"],
                )
                # .item() needs a single-element tensor; a multi-logit output
                # raises here and is routed to the fallback below.
                predicted_class = torch.sigmoid(classification_output).round().item()
            except Exception as e:
                logger.warning(f"Model inference failed, using fallback: {e}")
                return simple_fake_news_detector(text)

        result = "Fake News" if predicted_class == 1 else "Real News"
        logger.info(f"Prediction completed: {result}")
        return result

    except ValueError as e:
        logger.warning(f"Input validation error: {e}")
        return f"Error: {e}"
    except RuntimeError as e:
        logger.error(f"Runtime error: {e}")
        return f"Error: {e}"
    except Exception as e:
        # logger.exception records the message together with the traceback.
        logger.exception(f"Unexpected error during prediction: {e}")
        return "Error: An unexpected error occurred. Please try again."
|
|
|
|
|
def create_interface():
    """Build the main Gradio interface.

    When initialisation succeeded the interface wraps ``predict``. Otherwise
    an interface is returned whose handler always answers with a fixed
    initialisation-error message. Construction errors are logged and
    re-raised.
    """
    try:
        if initialization_success:
            image_box = gr.Image(type="pil")
            text_box = gr.Textbox(lines=2)
            result_box = gr.Textbox(lines=1)
            return gr.Interface(
                fn=predict,
                inputs=[image_box, text_box],
                outputs=result_box,
                title="Fake News Detector",
                allow_flagging="never",
            )

        # Degraded mode: same input widgets, but every request yields the error text.
        error_msg = "Failed to initialize the model. Please check that config.json and model_weights.pth files exist and are valid."
        logger.error(error_msg)

        def error_function(image, text):
            return error_msg

        image_box = gr.Image(type="pil", label="Upload Related Image")
        text_box = gr.Textbox(lines=2, placeholder="Enter news text for classification...", label="Input Text")
        error_label = gr.Label(label="Error")
        return gr.Interface(
            fn=error_function,
            inputs=[image_box, text_box],
            outputs=error_label,
            title="Fake News Detector - Initialization Error",
            description=error_msg,
        )

    except Exception as build_error:
        logger.error(f"Failed to create interface: {build_error}")
        logger.error(traceback.format_exc())
        raise
|
|
|
|
|
def create_simple_interface():
    """Build a reduced ``predict`` interface with unlabeled widgets.

    Takes a PIL image and a two-line text box, returns a one-line text box,
    with flagging disabled. Construction errors are logged and re-raised.
    """
    try:
        image_box = gr.Image(type="pil")
        text_box = gr.Textbox(lines=2)
        result_box = gr.Textbox(lines=1)
        simple_ui = gr.Interface(
            fn=predict,
            inputs=[image_box, text_box],
            outputs=result_box,
            title="Fake News Detector",
            allow_flagging="never",
        )
        return simple_ui
    except Exception as build_error:
        logger.error(f"Failed to create simple interface: {build_error}")
        raise
|
|
|
|
|
def create_ultra_minimal_interface():
    """Build the last-resort ``predict`` interface with default widgets.

    The image input explicitly requests PIL images. ``gr.Image()`` without a
    ``type`` delivers numpy arrays, which ``predict`` cannot handle because
    it relies on the PIL ``mode`` / ``convert`` API, so every request would
    end in an "Invalid image format" error.

    Returns:
        A ``gr.Interface`` wrapping ``predict`` with flagging disabled.

    Raises:
        Exception: Any construction error is logged and re-raised.
    """
    try:
        return gr.Interface(
            fn=predict,
            inputs=[gr.Image(type="pil"), gr.Textbox()],
            outputs=gr.Textbox(),
            allow_flagging="never",
        )
    except Exception as e:
        logger.error(f"Failed to create ultra minimal interface: {e}")
        raise
|
|
|
|
|
def main():
    """Create the Gradio interface and serve it.

    Interface creation falls back from the full interface to the simple one
    and finally to the ultra-minimal one. Launching is tried on port 7860 and,
    if that raises, once more on port 7861. Any remaining failure is logged
    and printed; nothing is re-raised.
    """
    try:
        logger.info("Starting Fake News Detector application...")

        try:
            interface = create_interface()
        except Exception as full_error:
            logger.warning(f"Failed to create full interface, trying simple version: {full_error}")
            try:
                interface = create_simple_interface()
            except Exception as simple_error:
                logger.warning(f"Failed to create simple interface, using ultra minimal: {simple_error}")
                interface = create_ultra_minimal_interface()

        logger.info("Launching Gradio interface...")
        # Options shared by both launch attempts; only the port differs.
        launch_options = dict(
            server_name="0.0.0.0",
            share=True,
            quiet=False,
            inbrowser=False,
        )
        try:
            interface.launch(server_port=7860, **launch_options)
        except Exception as launch_error:
            logger.warning(f"Launch failed on port 7860, trying port 7861: {launch_error}")
            interface.launch(server_port=7861, **launch_options)

    except Exception as startup_error:
        logger.error(f"Failed to start application: {startup_error}")
        logger.error(traceback.format_exc())
        print(f"Error starting application: {startup_error}")
        print("Please check the logs for more details.")


# Run the application only when executed as a script.
if __name__ == "__main__":
    main()
|
|
|