"""ToxiScan: Streamlit app that classifies uploaded images as toxic or
non-toxic with a Vision Transformer (ViT) backbone and a binary head."""

import streamlit as st
import torch
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor

from utils import load_image_vit, predict_toxicity_vit, get_label


@st.cache_resource
def _load_model():
    """Load the ViT model and processor once per server process.

    Streamlit re-executes the script on every interaction; without caching,
    the model and checkpoint were re-loaded for every uploaded image.

    Returns:
        (model, processor, device): ``model`` is on ``device`` with a 2-way
        classification head; fine-tuned weights are loaded from
        ``toxic_classifier.pth`` when that file exists, otherwise the
        ImageNet-pretrained backbone is used as-is (with a warning).
    """
    model_name = "google/vit-base-patch16-224"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ViTForImageClassification.from_pretrained(model_name)
    processor = ViTImageProcessor.from_pretrained(model_name)

    # Replace the 1000-class ImageNet head with a binary (toxic/non-toxic) one.
    model.classifier = torch.nn.Linear(model.classifier.in_features, 2)
    try:
        # weights_only=True: only tensors are unpickled, never arbitrary
        # objects — safe for a plain state-dict checkpoint.
        state = torch.load(
            "toxic_classifier.pth", map_location=device, weights_only=True
        )
        # strict=False tolerates missing/extra keys from the fine-tune run.
        model.load_state_dict(state, strict=False)
    except FileNotFoundError:
        st.warning(
            "Using pre-trained ImageNet weights. For toxic classification, "
            "upload toxic_classifier.pth."
        )
    model.to(device)
    return model, processor, device


def classify_image(uploaded_file):
    """Classify one uploaded image as toxic or non-toxic.

    Args:
        uploaded_file: a Streamlit ``UploadedFile`` (or any file-like object
            accepted by ``load_image_vit``).

    Returns:
        (label, probabilities): human-readable label from ``get_label`` and
        the per-class probabilities (index 0 = non-toxic, index 1 = toxic —
        presumed from how ``main`` reads them; confirm against ``utils``).
    """
    model, processor, device = _load_model()
    inputs = load_image_vit(uploaded_file, processor)
    prediction, probabilities = predict_toxicity_vit(model, inputs, device)
    return get_label(prediction), probabilities


def main():
    """Render the Streamlit UI: upload widget, prediction, and score chart."""
    st.title("ToxiScan - Toxic Image Classifier")
    st.write(
        "Upload an image to detect if it contains toxic content using a "
        "pre-trained Vision Transformer."
    )

    uploaded_file = st.file_uploader(
        "Choose an image...", type=["jpg", "png", "jpeg"]
    )
    if uploaded_file is None:
        return

    # Show the user's image before running inference.
    image = Image.open(uploaded_file)
    # use_container_width replaces the deprecated use_column_width parameter.
    st.image(image, caption="Uploaded Image", use_container_width=True)

    with st.spinner("Analyzing..."):
        label, probabilities = classify_image(uploaded_file)

    st.subheader("Results")
    st.write(f"**Prediction:** {label}")
    st.write("**Confidence Scores:**")
    st.write(f"- Toxic: {probabilities[1]:.2%}")
    st.write(f"- Non-Toxic: {probabilities[0]:.2%}")

    # Bar chart for visualization.
    st.bar_chart({"Toxic": probabilities[1], "Non-Toxic": probabilities[0]})


if __name__ == "__main__":
    main()