|
|
import streamlit as st |
|
|
from transformers import ViTForImageClassification, ViTImageProcessor |
|
|
from utils import load_image_vit, predict_toxicity_vit, get_label |
|
|
from PIL import Image |
|
|
import torch |
|
|
import io |
|
|
|
|
|
@st.cache_resource(show_spinner=False)
def _load_toxicity_model(model_name, weights_path):
    """Build the two-class ViT classifier once and reuse it across Streamlit reruns.

    Args:
        model_name: Hugging Face checkpoint providing the ViT backbone and processor.
        weights_path: Path to a fine-tuned state dict saved in ``torch.load`` format.

    Returns:
        ``(model, processor, device, warning)``. ``warning`` is ``None`` when the
        classifier head was restored from ``weights_path``; otherwise it is a
        message telling the user why predictions are not meaningful.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ViTForImageClassification.from_pretrained(model_name)
    processor = ViTImageProcessor.from_pretrained(model_name)

    # Replace the pretrained head with a fresh 2-way linear layer. Its weights
    # are random until the checkpoint below overwrites them.
    model.classifier = torch.nn.Linear(model.classifier.in_features, 2)

    warning = None
    try:
        # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
        state_dict = torch.load(weights_path, map_location=device)
    except FileNotFoundError:
        warning = (
            f"{weights_path} not found: the ViT backbone uses pre-trained weights, but "
            "the 2-class toxicity head is untrained, so predictions are not meaningful. "
            f"Provide {weights_path} for toxic classification."
        )
    else:
        # strict=False tolerates partial checkpoints, but a key mismatch would then
        # silently leave the head random, so check for that case explicitly.
        incompatible = model.load_state_dict(state_dict, strict=False)
        if any(key.startswith("classifier.") for key in incompatible.missing_keys):
            warning = (
                f"{weights_path} does not contain classifier weights: the 2-class "
                "toxicity head is untrained, so predictions are not meaningful."
            )

    model.to(device)
    # Inference only: make sure dropout layers are disabled.
    model.eval()
    return model, processor, device, warning


def classify_image(uploaded_file, model_name="google/vit-base-patch16-224",
                   weights_path="toxic_classifier.pth"):
    """Classify an uploaded image as toxic / non-toxic with a ViT model.

    Args:
        uploaded_file: Image source passed to ``load_image_vit`` (in this app, the
            object returned by ``st.file_uploader``).
        model_name: Hugging Face checkpoint for the backbone and image processor.
        weights_path: Fine-tuned 2-class state dict to load on top of the backbone.

    Returns:
        ``(label, probabilities)``. ``label`` comes from ``get_label(prediction)``.
        ``probabilities`` is returned unchanged from ``predict_toxicity_vit``; the
        app displays index 0 as non-toxic and index 1 as toxic.
    """
    model, processor, device, warning = _load_toxicity_model(model_name, weights_path)
    if warning is not None:
        st.warning(warning)

    # The caller may already have read the stream (e.g. to display the image);
    # rewind so preprocessing sees the whole file.
    if hasattr(uploaded_file, "seek"):
        uploaded_file.seek(0)

    inputs = load_image_vit(uploaded_file, processor)
    prediction, probabilities = predict_toxicity_vit(model, inputs, device)
    return get_label(prediction), probabilities
|
|
|
|
|
def main():
    """Render the ToxiScan page: upload an image, classify it, and show the scores."""
    st.title("ToxiScan - Toxic Image Classifier")
    st.write("Upload an image to detect if it contains toxic content using a pre-trained Vision Transformer.")

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])
    if uploaded_file is None:
        # Nothing uploaded yet; the page shows only the title and the uploader.
        return

    try:
        image = Image.open(uploaded_file)
        # Decode now, so a corrupt or truncated upload fails here rather than inside st.image.
        image.load()
    except OSError:  # PIL.UnidentifiedImageError is a subclass of OSError
        st.error("Could not read the uploaded file as an image.")
        return

    # NOTE(review): newer Streamlit releases deprecate use_column_width in favour
    # of use_container_width; confirm against the pinned Streamlit version.
    st.image(image, caption="Uploaded Image", use_column_width=True)

    # Decoding the image above advanced the upload's stream; rewind it so the
    # classifier reads the file from the start.
    uploaded_file.seek(0)
    with st.spinner("Analyzing..."):
        label, probabilities = classify_image(uploaded_file)

    # Index 0 is shown as non-toxic and index 1 as toxic. float() converts tensor
    # or NumPy scalars into plain floats for formatting and for the chart data.
    non_toxic = float(probabilities[0])
    toxic = float(probabilities[1])

    st.subheader("Results")
    st.write(f"**Prediction:** {label}")
    st.write("**Confidence Scores:**")
    st.write(f"- Toxic: {toxic:.2%}")
    st.write(f"- Non-Toxic: {non_toxic:.2%}")
    st.bar_chart({"Toxic": toxic, "Non-Toxic": non_toxic})
|
|
|
|
|
# Entry point: render the app when this file is executed as the main script
# (this is how `streamlit run` executes it). Importing the module does not
# start the UI.
if __name__ == "__main__":
    main()