# Abs6187's picture
# Update app.py
# 15cc9ad verified
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModel
from PIL import Image
from torchvision import transforms
import json
from torch import nn
from typing import Literal
import os
import logging
import traceback
import warnings
import time
import signal
import sys
# Suppress every Python warning so the console only carries our own log records.
warnings.filterwarnings("ignore")

logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Lighter checkpoints are used to reduce storage.
LIGHTWEIGHT_TEXT_MODEL = "distilbert-base-uncased"  # much smaller than BERT
LIGHTWEIGHT_IMAGE_MODEL = "microsoft/resnet-18"  # smaller than ResNet-34

# Pin the Gradio host/port through the environment to prevent protocol errors.
os.environ.update(
    {
        "GRADIO_SERVER_NAME": "0.0.0.0",
        "GRADIO_SERVER_PORT": "7860",
    }
)
class MultimodalClassifier(nn.Module):
    """Classifier that fuses a pretrained text encoder with a pretrained image encoder.

    Each encoder's output is projected to ``projection_dim``, the two
    projections are fused, passed through a small fusion MLP, and mapped to
    ``num_classes`` logits.

    Args:
        text_encoder_id_or_path: Identifier or path handed to
            ``AutoModel.from_pretrained`` for the text encoder.
        image_encoder_id_or_path: Identifier or path handed to
            ``AutoModel.from_pretrained`` for the image encoder.
        projection_dim: Width of the shared projection space.
        fusion_method: ``"concat"`` concatenates the two projections; any
            other value (including ``"align"`` and ``"cosine_similarity"``)
            multiplies them element-wise.
        proj_dropout: Dropout probability applied after each projection.
        fusion_dropout: Dropout probability used inside the fusion layer.
        num_classes: Number of output logits.
    """

    # Width used when the image encoder's config does not expose
    # `hidden_sizes` (this is the value the projection was hard-coded to).
    _DEFAULT_IMAGE_FEATURE_DIM = 512

    def __init__(
        self,
        text_encoder_id_or_path: str,
        image_encoder_id_or_path: str,
        projection_dim: int,
        fusion_method: Literal["concat", "align", "cosine_similarity"] = "concat",
        proj_dropout: float = 0.1,
        fusion_dropout: float = 0.1,
        num_classes: int = 1,
    ) -> None:
        super().__init__()
        self.fusion_method = fusion_method
        self.projection_dim = projection_dim
        self.num_classes = num_classes

        self.text_encoder = AutoModel.from_pretrained(text_encoder_id_or_path)
        self.text_projection = nn.Sequential(
            nn.Linear(self.text_encoder.config.hidden_size, self.projection_dim),
            nn.Dropout(proj_dropout),
        )

        # NOTE(security): trust_remote_code=True lets the checkpoint repository
        # supply its own model code; only point this at trusted repositories.
        self.image_encoder = AutoModel.from_pretrained(image_encoder_id_or_path, trust_remote_code=True)
        # Overwrite the encoder's `classifier` attribute with a pass-through;
        # forward() only consumes `last_hidden_state`.
        self.image_encoder.classifier = nn.Identity()
        # Size the projection from the encoder itself instead of assuming 512,
        # so encoders with a different final width do not fail in forward().
        image_feature_dim = self._infer_image_feature_dim(self.image_encoder)
        self.image_projection = nn.Sequential(
            nn.Linear(image_feature_dim, self.projection_dim),
            nn.Dropout(proj_dropout),
        )

        # Concatenation doubles the fused width; the element-wise product keeps it.
        fusion_input_dim = self.projection_dim * 2 if fusion_method == "concat" else self.projection_dim
        self.fusion_layer = nn.Sequential(
            nn.Dropout(fusion_dropout),
            nn.Linear(fusion_input_dim, self.projection_dim),
            nn.GELU(),
            nn.Dropout(fusion_dropout),
        )
        self.classifier = nn.Linear(self.projection_dim, self.num_classes)

    @staticmethod
    def _infer_image_feature_dim(image_encoder) -> int:
        """Return the last entry of ``image_encoder.config.hidden_sizes``.

        Falls back to ``_DEFAULT_IMAGE_FEATURE_DIM`` when the encoder has no
        config, or the config has no (or an empty) ``hidden_sizes``.
        """
        config = getattr(image_encoder, "config", None)
        hidden_sizes = getattr(config, "hidden_sizes", None)
        if hidden_sizes:
            return int(hidden_sizes[-1])
        return MultimodalClassifier._DEFAULT_IMAGE_FEATURE_DIM

    def forward(self, pixel_values: torch.Tensor, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Compute classification logits for a batch of image/text pairs.

        Args:
            pixel_values: Image batch passed to the image encoder.
            input_ids: Token ids passed to the text encoder.
            attention_mask: Attention mask passed to the text encoder.

        Returns:
            The output of the final linear classifier (raw logits; no
            sigmoid/softmax is applied here).
        """
        full_text_features = self.text_encoder(
            input_ids=input_ids, attention_mask=attention_mask, return_dict=True
        ).last_hidden_state
        # Keep only the first token position as the sequence summary.
        full_text_features = full_text_features[:, 0, :]
        full_text_features = self.text_projection(full_text_features)

        resnet_image_features = self.image_encoder(pixel_values=pixel_values).last_hidden_state
        # Average over the last two dims — presumably the spatial (H, W) axes
        # of a (batch, channels, H, W) feature map; TODO confirm for non-ResNet encoders.
        resnet_image_features = resnet_image_features.mean(dim=[-2, -1])
        resnet_image_features = self.image_projection(resnet_image_features)

        if self.fusion_method == "concat":
            fused_features = torch.cat([full_text_features, resnet_image_features], dim=-1)
        else:
            # Every non-"concat" method currently shares the element-wise product.
            fused_features = full_text_features * resnet_image_features

        fused_features = self.fusion_layer(fused_features)
        classification_output = self.classifier(fused_features)
        return classification_output
def load_model():
    """Build the MultimodalClassifier from config.json and optional weights.

    Reads ``config.json`` from the current directory, checks that the
    required keys are present, and constructs the classifier with the
    lightweight encoder constants. The ``text_encoder_id_or_path`` key is
    required to exist, but its value is not used. If ``model_weights.pth``
    exists it is loaded non-strictly; otherwise the model is returned with
    whatever weights the constructor produced.

    Returns:
        The constructed ``MultimodalClassifier`` (not switched to eval mode here).

    Raises:
        FileNotFoundError: ``config.json`` does not exist.
        ValueError: ``config.json`` is not valid JSON.
        KeyError: A required configuration key is missing.
        Exception: Any other failure is logged with its traceback and re-raised.
    """
    try:
        if not os.path.exists("config.json"):
            raise FileNotFoundError("config.json file not found. Please ensure it exists in the current directory.")

        logger.info("Loading configuration from config.json...")
        with open("config.json", "r", encoding="utf-8") as f:
            config = json.load(f)

        required_keys = ["text_encoder_id_or_path", "projection_dim", "fusion_method",
                         "proj_dropout", "fusion_dropout", "num_classes"]
        for key in required_keys:
            if key not in config:
                raise KeyError(f"Missing required key '{key}' in config.json")

        logger.info("Initializing MultimodalClassifier with lightweight models...")
        model = MultimodalClassifier(
            text_encoder_id_or_path=LIGHTWEIGHT_TEXT_MODEL,  # Use DistilBERT instead of BERT
            image_encoder_id_or_path=LIGHTWEIGHT_IMAGE_MODEL,  # Use ResNet-18 instead of ResNet-34
            projection_dim=config["projection_dim"],
            fusion_method=config["fusion_method"],
            proj_dropout=config["proj_dropout"],
            fusion_dropout=config["fusion_dropout"],
            num_classes=config["num_classes"],
        )

        if os.path.exists("model_weights.pth"):
            logger.info("Loading model weights...")
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints that come from a trusted source.
            checkpoint = torch.load("model_weights.pth", map_location=torch.device('cpu'))
            # strict=False tolerates key mismatches, so surface them instead
            # of silently running with partially initialized weights.
            load_result = model.load_state_dict(checkpoint, strict=False)
            if load_result.missing_keys:
                logger.warning(
                    f"{len(load_result.missing_keys)} parameter(s) missing from model_weights.pth "
                    f"(left at their initial values), e.g. {load_result.missing_keys[:5]}"
                )
            if load_result.unexpected_keys:
                logger.warning(
                    f"{len(load_result.unexpected_keys)} unexpected key(s) in model_weights.pth "
                    f"were ignored, e.g. {load_result.unexpected_keys[:5]}"
                )
        else:
            logger.warning("model_weights.pth not found. Using untrained model for demonstration.")
            logger.warning("For best results, please provide the trained model weights.")

        logger.info("Model loaded successfully!")
        return model
    except FileNotFoundError as e:
        logger.error(f"File error: {e}")
        raise
    except json.JSONDecodeError as e:
        logger.error(f"JSON parsing error in config.json: {e}")
        # Chain the decoder error so the original location stays in the traceback.
        raise ValueError(f"Invalid JSON format in config.json: {e}") from e
    except KeyError as e:
        logger.error(f"Configuration error: {e}")
        raise
    except Exception as e:
        logger.error(f"Unexpected error loading model: {e}")
        logger.error(traceback.format_exc())
        raise
def initialize_components():
    """Populate the module-level ``model`` and ``text_tokenizer``.

    Loads the classifier, switches it to eval mode, and loads the tokenizer
    for ``LIGHTWEIGHT_TEXT_MODEL``. Failures are logged (message plus
    traceback) rather than raised.

    Returns:
        True when every step completed, False when any step raised.
    """
    global model, text_tokenizer

    succeeded = False
    try:
        logger.info("Initializing model and tokenizer...")
        model = load_model()
        model.eval()

        logger.info("Loading DistilBERT tokenizer...")
        text_tokenizer = AutoTokenizer.from_pretrained(LIGHTWEIGHT_TEXT_MODEL)

        logger.info("All components initialized successfully!")
        succeeded = True
    except Exception as init_error:
        logger.error(f"Failed to initialize components: {init_error}")
        logger.error(traceback.format_exc())
    return succeeded
# Module-level handles; initialize_components() rebinds both via `global`.
model = None
text_tokenizer = None
# Runs at import time. False means loading raised and the failure was logged.
initialization_success = initialize_components()
# Preprocessing applied to an uploaded image before inference: resize to
# 224x224, convert to a tensor, then normalize each channel.
# NOTE(review): the mean/std values look like the usual ImageNet channel
# statistics — confirm they match what the image encoder was trained with.
image_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def validate_inputs(image, text):
    """Check application state and user inputs before a prediction.

    Returns None when everything is acceptable.

    Raises:
        RuntimeError: Initialization failed, or the model/tokenizer is unset.
        ValueError: The image is missing or cannot be inspected/converted, or
            the text is empty, shorter than 10 characters after stripping, or
            longer than 2000 characters.
    """
    if not initialization_success:
        raise RuntimeError("Model initialization failed. Please check the logs for details.")

    components_ready = model is not None and text_tokenizer is not None
    if not components_ready:
        raise RuntimeError("Model or tokenizer not loaded. Please restart the application.")

    if image is None:
        raise ValueError("Please upload an image.")

    stripped_text = text.strip() if text else ""
    if not stripped_text:
        raise ValueError("Please enter some text for analysis.")
    if len(stripped_text) < 10:
        raise ValueError("Text is too short. Please provide at least 10 characters.")
    # The upper bound is measured on the raw text, whitespace included.
    if len(text) > 2000:
        raise ValueError("Text is too long. Please provide text with less than 2000 characters.")

    accepted_modes = ('RGB', 'RGBA', 'L')
    try:
        if image.mode not in accepted_modes:
            # The converted copy is discarded; the call only proves the image
            # can be converted. The caller's image is left untouched.
            image.convert('RGB')
    except Exception as mode_error:
        raise ValueError(f"Invalid image format: {mode_error}")
def simple_fake_news_detector(text: str) -> str:
    """Keyword-count fallback used when the neural model cannot be used.

    Counts how many phrases from a "fake" list and from a "real" list occur
    as case-insensitive substrings of ``text`` (each phrase counts at most
    once) and reports whichever list scored higher.

    Returns:
        "Fake News (Rule-based)", "Real News (Rule-based)", or
        "Uncertain (Rule-based)" when the two counts are equal.
    """
    sensational_phrases = (
        "breaking news", "shocking", "you won't believe", "doctors hate",
        "one weird trick", "click here", "urgent", "exclusive", "leaked",
        "nazi salute", "hail trump", "take our country back", "step on toes",
    )
    sourced_phrases = (
        "according to", "reported by", "official statement", "confirmed",
        "study shows", "research indicates", "data reveals", "analysis",
    )

    haystack = text.lower()

    def count_hits(phrases):
        return len([phrase for phrase in phrases if phrase in haystack])

    # Positive balance: more "fake" hits; negative: more "real" hits.
    balance = count_hits(sensational_phrases) - count_hits(sourced_phrases)
    if balance > 0:
        return "Fake News (Rule-based)"
    if balance < 0:
        return "Real News (Rule-based)"
    return "Uncertain (Rule-based)"
def predict(image: Image.Image, text: str) -> str:
    """Classify an image/text pair and return a human-readable verdict.

    The inputs are validated first. When ``model_weights.pth`` is absent, or
    when model inference raises, the rule-based detector's answer is
    returned instead. No exception escapes: failures are logged and turned
    into an ``"Error: ..."`` string.

    Returns:
        "Fake News" or "Real News" from the model, a "(Rule-based)" verdict
        from the fallback, or an "Error: ..." message.
    """
    try:
        logger.info("Starting prediction...")
        validate_inputs(image, text)

        # If model weights are not available, use simple rule-based detection
        weights_available = os.path.exists("model_weights.pth")
        if not weights_available:
            logger.info("Using rule-based fallback detection...")
            return simple_fake_news_detector(text)

        logger.info("Processing text input...")
        tokenizer_options = {
            "return_tensors": "pt",
            "padding": "max_length",
            "truncation": True,
            "max_length": 512,
        }
        encoded_text = text_tokenizer(text, **tokenizer_options)

        logger.info("Processing image input...")
        try:
            rgb_image = image if image.mode == 'RGB' else image.convert('RGB')
            # unsqueeze(0) adds the leading batch dimension.
            pixel_batch = image_transform(rgb_image).unsqueeze(0)
        except Exception as image_error:
            raise ValueError(f"Failed to process image: {image_error}")

        logger.info("Running model inference...")
        with torch.no_grad():
            try:
                logits = model(
                    pixel_values=pixel_batch,
                    input_ids=encoded_text["input_ids"],
                    attention_mask=encoded_text["attention_mask"],
                )
                # Sigmoid then round gives 0.0 or 1.0; .item() needs a single-element output.
                predicted_class = torch.sigmoid(logits).round().item()
            except Exception as inference_error:
                logger.warning(f"Model inference failed, using fallback: {inference_error}")
                return simple_fake_news_detector(text)

        if predicted_class == 1:
            result = "Fake News"
        else:
            result = "Real News"
        logger.info(f"Prediction completed: {result}")
        return result
    except ValueError as input_error:
        logger.warning(f"Input validation error: {input_error}")
        return f"Error: {input_error}"
    except RuntimeError as runtime_error:
        logger.error(f"Runtime error: {runtime_error}")
        return f"Error: {runtime_error}"
    except Exception as unexpected_error:
        logger.error(f"Unexpected error during prediction: {unexpected_error}")
        logger.error(traceback.format_exc())
        return "Error: An unexpected error occurred. Please try again."
def create_interface():
    """Build the main Gradio interface.

    When initialization succeeded, the interface wraps ``predict``. When it
    failed, the interface has the same two inputs but always answers with a
    fixed error message.

    Raises:
        Exception: Any construction failure is logged with its traceback and re-raised.
    """
    try:
        if initialization_success:
            image_input = gr.Image(type="pil")
            text_input = gr.Textbox(lines=2)
            result_output = gr.Textbox(lines=1)
            return gr.Interface(
                fn=predict,
                inputs=[image_input, text_input],
                outputs=result_output,
                title="Fake News Detector",
                allow_flagging="never",
            )

        error_msg = "Failed to initialize the model. Please check that config.json and model_weights.pth files exist and are valid."
        logger.error(error_msg)

        def error_function(image, text):
            # Inputs are ignored; every request gets the same explanation.
            return error_msg

        image_input = gr.Image(type="pil", label="Upload Related Image")
        text_input = gr.Textbox(
            lines=2,
            placeholder="Enter news text for classification...",
            label="Input Text",
        )
        error_output = gr.Label(label="Error")
        return gr.Interface(
            fn=error_function,
            inputs=[image_input, text_input],
            outputs=error_output,
            title="Fake News Detector - Initialization Error",
            description=error_msg,
        )
    except Exception as build_error:
        logger.error(f"Failed to create interface: {build_error}")
        logger.error(traceback.format_exc())
        raise
def create_simple_interface():
    """Create a minimal interface to avoid protocol errors.

    Wraps ``predict`` with a PIL image input, a two-line text input, and a
    one-line text output. Construction failures are logged and re-raised.
    """
    try:
        image_input = gr.Image(type="pil")
        text_input = gr.Textbox(lines=2)
        result_output = gr.Textbox(lines=1)
        simple_ui = gr.Interface(
            fn=predict,
            inputs=[image_input, text_input],
            outputs=result_output,
            title="Fake News Detector",
            allow_flagging="never",
        )
    except Exception as build_error:
        logger.error(f"Failed to create simple interface: {build_error}")
        raise
    return simple_ui
def create_ultra_minimal_interface():
    """Create the most minimal interface possible.

    Wraps ``predict`` with default-configured components, except that the
    image input is delivered as a PIL image: ``predict`` and
    ``validate_inputs`` read ``image.mode`` and call ``image.convert``, which
    the default NumPy-array delivery does not support.

    Raises:
        Exception: Any construction failure is logged and re-raised.
    """
    try:
        return gr.Interface(
            fn=predict,
            # type="pil" is required by predict(); gr.Image() alone hands over
            # a NumPy array and every request would be rejected as an invalid image.
            inputs=[gr.Image(type="pil"), gr.Textbox()],
            outputs=gr.Textbox(),
            allow_flagging="never"
        )
    except Exception as e:
        logger.error(f"Failed to create ultra minimal interface: {e}")
        raise
def main():
    """Build the Gradio interface and launch the server.

    Interface construction falls back from the full interface to the simple
    one, and from the simple one to the ultra-minimal one. Launch is tried
    on port 7860 and, if that raises, once more on port 7861. Any failure
    that is still unhandled is logged and printed; it is not re-raised.
    """
    try:
        logger.info("Starting Fake News Detector application...")

        try:
            interface = create_interface()
        except Exception as full_error:
            logger.warning(f"Failed to create full interface, trying simple version: {full_error}")
            try:
                interface = create_simple_interface()
            except Exception as simple_error:
                logger.warning(f"Failed to create simple interface, using ultra minimal: {simple_error}")
                interface = create_ultra_minimal_interface()

        logger.info("Launching Gradio interface...")
        # Options common to both launch attempts; only the port differs.
        launch_options = {
            "server_name": "0.0.0.0",
            "share": True,
            "quiet": False,
            "inbrowser": False,
        }
        try:
            interface.launch(server_port=7860, **launch_options)
        except Exception as launch_error:
            logger.warning(f"Launch failed on port 7860, trying port 7861: {launch_error}")
            interface.launch(server_port=7861, **launch_options)
    except Exception as startup_error:
        logger.error(f"Failed to start application: {startup_error}")
        logger.error(traceback.format_exc())
        print(f"Error starting application: {startup_error}")
        print("Please check the logs for more details.")


# Start the app only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    main()