| import streamlit as st |
| import torch |
| import numpy as np |
| from PIL import Image |
| from torchvision import transforms |
| import json |
| import cv2 |
| import requests |
| from io import BytesIO |
| from src.model import SwinTransformerMultiLabel |
|
|
| |
# Page header and usage hint shown at the top of the Streamlit app.
st.title("STAR Multi-Label Classifier")
st.write("Upload an image or provide a URL to classify and blur sensitive areas.")
|
|
| |
| label_file = "data/labels.json" |
| with open(label_file, "r") as f: |
| label_data = json.load(f) |
|
|
| |
| class_labels = sorted(set(tag for tags in label_data.values() for tag in tags)) |
| NUM_CLASSES = len(class_labels) |
|
|
| |
| sensitive_labels = {"breast", "vagina", "penis"} |
|
|
| |
| model_path = "models/star.pth" |
| model = SwinTransformerMultiLabel(num_classes=NUM_CLASSES) |
| model.load_state_dict(torch.load(model_path, map_location="cpu")) |
| model.eval() |
|
|
| |
# Preprocessing pipeline: fixed 224x224 input, normalized with the standard
# ImageNet mean/std used by torchvision pretrained backbones.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
])
|
|
| |
def detect_sensitive_areas(image):
    """Build a binary mask over regions that should be blurred.

    NOTE(review): despite the app's "sensitive areas" wording, this uses
    OpenCV's *full-body* Haar cascade, so the mask covers entire detected
    bodies, not specific body parts. A dedicated detector would be needed
    for finer-grained masking.

    Args:
        image: PIL image in RGB mode (callers convert via ``.convert("RGB")``).

    Returns:
        ``np.ndarray`` of shape (H, W), uint8: 1 inside each detected body
        bounding box, 0 elsewhere.

    Raises:
        RuntimeError: if the bundled Haar cascade file cannot be loaded.
    """
    image_cv = np.array(image)
    gray = cv2.cvtColor(image_cv, cv2.COLOR_RGB2GRAY)

    cascade_path = cv2.data.haarcascades + "haarcascade_fullbody.xml"
    body_cascade = cv2.CascadeClassifier(cascade_path)
    if body_cascade.empty():
        # Fail with a clear message now rather than an opaque cv2.error
        # inside detectMultiScale.
        raise RuntimeError(f"Failed to load Haar cascade: {cascade_path}")

    bodies = body_cascade.detectMultiScale(
        gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30)
    )

    # Same dtype/shape as the grayscale image; mark detected boxes with 1.
    mask = np.zeros_like(gray)
    for (x, y, w, h) in bodies:
        mask[y:y + h, x:x + w] = 1

    return mask
|
|
| |
def blur_sensitive_parts(image, mask, blur_intensity=25):
    """Blur the masked regions of *image*, leaving the rest untouched.

    Args:
        image: PIL image (RGB).
        mask: (H, W) array where 1 marks a pixel to blur.
        blur_intensity: Gaussian sigma passed to OpenCV (kernel is fixed 51x51).

    Returns:
        A new PIL image with only the masked pixels blurred.
    """
    pixels = np.array(image)
    softened = cv2.GaussianBlur(pixels, (51, 51), blur_intensity)

    # Copy the original and overwrite only the masked pixels with their
    # blurred counterparts; the 2-D boolean index broadcasts across channels.
    selected = mask == 1
    composite = pixels.copy()
    composite[selected] = softened[selected]

    return Image.fromarray(composite)
|
|
| |
| option = st.radio("Choose an input method:", ("Upload Image", "Enter Image URL")) |
|
|
| image = None |
|
|
| if option == "Upload Image": |
| uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"]) |
| if uploaded_file is not None: |
| image = Image.open(uploaded_file).convert("RGB") |
|
|
| elif option == "Enter Image URL": |
| image_url = st.text_input("Enter image URL:") |
| if image_url: |
| try: |
| response = requests.get(image_url) |
| image = Image.open(BytesIO(response.content)).convert("RGB") |
| except Exception as e: |
| st.error(f"Error fetching image: {e}") |
|
|
| |
# Main pipeline: classify the image, blur detected regions, show the results.
if image:
    # Preprocess and add a batch dimension for the model.
    batch = transform(image).unsqueeze(0)

    # Inference without gradient tracking.
    with torch.no_grad():
        scores = model(batch)[0]

    # NOTE(review): scores are compared to 0.5 directly; if
    # SwinTransformerMultiLabel returns raw logits (no sigmoid inside the
    # model), torch.sigmoid should be applied first — confirm in src.model.
    predicted_labels = [
        label for idx, label in enumerate(class_labels) if scores[idx] > 0.5
    ]

    # Mask-and-blur the detected regions, then render the result.
    mask = detect_sensitive_areas(image)
    blurred_image = blur_sensitive_parts(image, mask)

    st.image(blurred_image, caption="Blurred Image (Sensitive Areas Only)", use_container_width=True)

    st.write("✅ **Predicted Labels:**", ", ".join(predicted_labels) if predicted_labels else "No sensitive content detected")
|
|