# ToxiScan / app.py
import streamlit as st
from transformers import ViTForImageClassification, ViTImageProcessor
from utils import load_image_vit, predict_toxicity_vit, get_label
from PIL import Image
import torch
import pandas as pd
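
# The helpers imported from utils are defined elsewhere in this repo. Assumed
# signatures, inferred from the call sites below (see the reference sketches
# after classify_image):
#   load_image_vit(uploaded_file, processor) -> model inputs (pixel_values tensor)
#   predict_toxicity_vit(model, inputs, device) -> (class index, [p_non_toxic, p_toxic])
#   get_label(prediction) -> "Toxic" or "Non-Toxic"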


@st.cache_resource
def load_model():
    """Load the pre-trained ViT backbone once per session and adapt it for binary classification."""
    model_name = "google/vit-base-patch16-224"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ViTForImageClassification.from_pretrained(model_name)
    processor = ViTImageProcessor.from_pretrained(model_name)

    # Replace the 1000-class ImageNet head with a 2-class head (toxic / non-toxic)
    model.classifier = torch.nn.Linear(model.classifier.in_features, 2)
    try:
        model.load_state_dict(torch.load("toxic_classifier.pth", map_location=device), strict=False)
    except FileNotFoundError:
        st.warning("toxic_classifier.pth not found; using pre-trained ImageNet weights. "
                   "For meaningful toxicity predictions, add the fine-tuned toxic_classifier.pth.")
    model.to(device)
    model.eval()
    return model, processor, device


def classify_image(uploaded_file):
    model, processor, device = load_model()

    # Preprocess the image and run inference
    inputs = load_image_vit(uploaded_file, processor)
    prediction, probabilities = predict_toxicity_vit(model, inputs, device)
    label = get_label(prediction)
    return label, probabilities
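

# For reference, minimal sketches of what the utils helpers are assumed to do.
# These are hypothetical re-implementations (underscore-prefixed so they do not
# shadow the real imports); the actual code lives in utils.py.

def _load_image_vit_sketch(uploaded_file, processor):
    """Open the uploaded file as RGB and convert it to ViT model inputs."""
    image = Image.open(uploaded_file).convert("RGB")
    return processor(images=image, return_tensors="pt")


def _predict_toxicity_vit_sketch(model, inputs, device):
    """Forward pass returning (predicted class index, per-class probabilities)."""
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits
    probabilities = torch.softmax(logits, dim=-1).squeeze(0).tolist()
    return int(torch.argmax(logits, dim=-1).item()), probabilities


def _get_label_sketch(prediction):
    """Map the class index to a label; assumes 0 = Non-Toxic, 1 = Toxic,
    matching how main() indexes the probability list."""
    return "Toxic" if prediction == 1 else "Non-Toxic"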


def main():
    st.title("ToxiScan - Toxic Image Classifier")
    st.write("Upload an image to detect whether it contains toxic content, "
             "using a pre-trained Vision Transformer.")

    # File uploader
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])

    if uploaded_file is not None:
        # Display the uploaded image
        image = Image.open(uploaded_file)
        st.image(image, caption="Uploaded Image", use_container_width=True)

        # Rewind the buffer so the image can be read again during preprocessing
        uploaded_file.seek(0)

        # Process and predict
        with st.spinner("Analyzing..."):
            label, probabilities = classify_image(uploaded_file)

        # Display results
        st.subheader("Results")
        st.write(f"**Prediction:** {label}")
        st.write("**Confidence Scores:**")
        st.write(f"- Toxic: {probabilities[1]:.2%}")
        st.write(f"- Non-Toxic: {probabilities[0]:.2%}")
        # Bar chart for visualization (labels as the index so each class gets its own bar)
        chart_data = pd.DataFrame(
            {"Confidence": [float(probabilities[1]), float(probabilities[0])]},
            index=["Toxic", "Non-Toxic"],
        )
        st.bar_chart(chart_data)


if __name__ == "__main__":
    main()
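
# To run the app locally with the standard Streamlit CLI:
#   streamlit run app.py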