# app.py import streamlit as st import torch import numpy as np from transformers import CLIPProcessor from PIL import Image import onnxruntime as ort import os def inject_css(): st.markdown(""" """, unsafe_allow_html=True) @st.cache_resource def load_model_and_processor(): try: current_dir = os.path.dirname(__file__) processor_path = os.path.join(current_dir, "clip_processor") onnx_path = os.path.join(current_dir, "clip_model", "train_quantized.onnx") processor = CLIPProcessor.from_pretrained(processor_path) if not os.path.exists(onnx_path): st.error(f"❌ ONNX model not found at: {onnx_path}") st.info("Make sure 'clip_model/train_quantized.onnx' exists.") return None, None, None session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"]) return session, processor, "onnx" except Exception as e: st.error(f"Failed to load model or processor: {str(e)}") return None, None, None def predict_text_only(text, image, session, processor, device): inputs = processor( text=[text], images=image, return_tensors="np", padding="max_length", truncation=True, max_length=77 ) # Replace with zero image values for text-only inputs['pixel_values'] = np.zeros((1, 3, 224, 224), dtype=np.float32) mean = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32).reshape(1, 3, 1, 1) std = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32).reshape(1, 3, 1, 1) inputs['pixel_values'] = (inputs['pixel_values'] - mean) / std onnx_inputs = { "input_ids": np.array(inputs["input_ids"], dtype=np.int64), "attention_mask": np.array(inputs["attention_mask"], dtype=np.int64), "pixel_values": np.array(inputs["pixel_values"], dtype=np.float32) } logits = session.run(["logits"], onnx_inputs)[0] probs = torch.softmax(torch.from_numpy(logits[0]), dim=0).numpy() pred = np.argmax(probs) conf = probs[pred] return "Real" if pred == 1 else "Fake", conf def predict_image_only(text, image, session, processor, device): inputs = processor( text=[text], images=image, return_tensors="np", truncation=True, max_length=77, do_convert_rgb=True, do_normalize=True, image_mean=[0.48145466, 0.4578275, 0.40821073], image_std=[0.26862954, 0.26130258, 0.27577711], input_data_format="channels_last" ) # Zero out text inputs for image-only inputs['input_ids'].fill(0) inputs['attention_mask'].fill(0) onnx_inputs = { "input_ids": np.array(inputs["input_ids"], dtype=np.int64), "attention_mask": np.array(inputs["attention_mask"], dtype=np.int64), "pixel_values": np.array(inputs["pixel_values"], dtype=np.float32) } logits = session.run(["logits"], onnx_inputs)[0] probs = torch.softmax(torch.from_numpy(logits[0]), dim=0).numpy() pred = np.argmax(probs) conf = probs[pred] return "Real" if pred == 1 else "Fake", conf def main(): inject_css() st.set_page_config(page_title="Multimodal BN-EN Fake News Scanner", layout="centered") st.markdown("
Enter text and upload an image to analyze text-only and image-only predictions.
", unsafe_allow_html=True) session, processor, device = load_model_and_processor() if session is None: st.stop() text_input = st.text_area("Enter News Text", placeholder="Type a headline or article snippet...", height=180) uploaded_image = st.file_uploader("Upload News Image", type=["jpg", "jpeg", "png"], help="Upload a related image") if st.button("Analyze Multimodal Input"): if not text_input.strip(): st.warning("Please enter news text.") elif not uploaded_image: st.warning("Please upload a news image.") else: try: image = Image.open(uploaded_image).convert("RGB") except Exception as e: st.error(f"Cannot open image: {e}") return with st.spinner("Running text-only and image-only analysis..."): text_pred, text_conf = predict_text_only(text_input, image, session, processor, device) img_pred, img_conf = predict_image_only(text_input, image, session, processor, device) st.session_state.modality_results = { "text": {"label": text_pred, "conf": text_conf}, "image": {"label": img_pred, "conf": img_conf} } if 'modality_results' in st.session_state: res = st.session_state.modality_results text_icon = "🟢" if res['text']['label'] == "Real" else "🔴" text_class = "text-real" if res['text']['label'] == "Real" else "text-fake" st.markdown(f"""