JanviMl committed (verified)
Commit f07da63 · Parent(s): 9886d85

Update app.py

Files changed (1):
  app.py  +34 -18
app.py CHANGED
@@ -1,40 +1,56 @@
 import streamlit as st
-import torch
-from model import ToxicImageClassifier
-from utils import load_image, predict_toxicity, get_label
+from transformers import ViTForImageClassification, ViTImageProcessor
+from utils import load_image_vit, predict_toxicity_vit, get_label
 from PIL import Image
+import torch
+import io
 
-def main():
-    st.title("ToxiScan - Toxic Image Classifier")
-    st.write("Upload an image to check if it contains toxic content")
-
-    # Load model
+def classify_image(uploaded_file):
+    # Load pre-trained ViT model and processor
+    model_name = "google/vit-base-patch16-224"
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    model = ToxicImageClassifier()
+    model = ViTForImageClassification.from_pretrained(model_name)
+    processor = ViTImageProcessor.from_pretrained(model_name)
+
+    # Modify for binary classification (toxic/non-toxic)
     try:
-        model.load_state_dict(torch.load("toxic_classifier.pth", map_location=device))
+        model.classifier = torch.nn.Linear(model.classifier.in_features, 2)
+        model.load_state_dict(torch.load("toxic_classifier.pth", map_location=device), strict=False)
     except FileNotFoundError:
-        st.error("Model weights (toxic_classifier.pth) not found. Please train the model first.")
-        return
+        st.warning("Using pre-trained ImageNet weights. For toxic classification, upload toxic_classifier.pth.")
     model.to(device)
 
+    # Process image and predict
+    inputs = load_image_vit(uploaded_file, processor)
+    prediction, probabilities = predict_toxicity_vit(model, inputs, device)
+    label = get_label(prediction)
+
+    return label, probabilities
+
+def main():
+    st.title("ToxiScan - Toxic Image Classifier")
+    st.write("Upload an image to detect if it contains toxic content using a pre-trained Vision Transformer.")
+
     # File uploader
     uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])
 
     if uploaded_file is not None:
-        # Display image
+        # Display uploaded image
        image = Image.open(uploaded_file)
         st.image(image, caption="Uploaded Image", use_column_width=True)
 
         # Process and predict
         with st.spinner("Analyzing..."):
-            image_tensor = load_image(uploaded_file)
-            prediction, probabilities = predict_toxicity(model, image_tensor, device)
-            label = get_label(prediction)
+            label, probabilities = classify_image(uploaded_file)
 
         # Display results
-        st.write(f"Prediction: **{label}**")
-        st.write(f"Confidence: Toxic: {probabilities[1]:.2%}, Non-Toxic: {probabilities[0]:.2%}")
+        st.subheader("Results")
+        st.write(f"**Prediction:** {label}")
+        st.write(f"**Confidence Scores:**")
+        st.write(f"- Toxic: {probabilities[1]:.2%}")
+        st.write(f"- Non-Toxic: {probabilities[0]:.2%}")
+
+        # Bar chart for visualization
         st.bar_chart({"Toxic": probabilities[1], "Non-Toxic": probabilities[0]})
 
 if __name__ == "__main__":
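
The updated app.py imports three helpers from utils (load_image_vit, predict_toxicity_vit, get_label) that are not part of this commit. Below is a minimal sketch of what they might look like, assuming the model and processor come from transformers as above and that class index 1 means "Toxic" (which is what the confidence display in app.py implies); the repository's actual utils.py may differ.

    from PIL import Image
    import torch

    def load_image_vit(uploaded_file, processor):
        # Open the uploaded file and convert it to ViT model inputs
        # (a dict containing a "pixel_values" tensor of shape (1, 3, 224, 224)).
        image = Image.open(uploaded_file).convert("RGB")
        return processor(images=image, return_tensors="pt")

    def predict_toxicity_vit(model, inputs, device):
        # Run a forward pass and return the predicted class index plus class probabilities.
        model.eval()
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            logits = model(**inputs).logits  # shape (1, 2) after the classifier head swap
        probabilities = torch.softmax(logits, dim=-1).squeeze(0).tolist()
        prediction = int(logits.argmax(dim=-1).item())
        return prediction, probabilities

    def get_label(prediction):
        # Map the class index to a human-readable label (assumed: 1 = Toxic, 0 = Non-Toxic).
        return "Toxic" if prediction == 1 else "Non-Toxic"

With helpers along these lines, the app launches with "streamlit run app.py"; probabilities is a plain two-element list, so the probabilities[0] / probabilities[1] indexing and the :.2% formatting in app.py work as written.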