import json
from io import BytesIO

import cv2
import numpy as np
import requests
import streamlit as st
import torch
from PIL import Image
from torchvision import transforms

from src.model import SwinTransformerMultiLabel  # Import from src folder

# Title and description
st.title("STAR Multi-Label Classifier")
st.write("Upload an image or provide a URL to classify and blur sensitive areas.")

# Load class labels from JSON
label_file = "data/labels.json"
with open(label_file, "r", encoding="utf-8") as f:
    label_data = json.load(f)

# Extract unique class labels (every tag that appears under any key of the JSON).
# NOTE(review): column i of the model output is mapped to class_labels[i], so this
# sorted order must match the order used at training time — confirm.
class_labels = sorted({tag for tags in label_data.values() for tag in tags})
NUM_CLASSES = len(class_labels)

# Define sensitive body parts to blur
# NOTE(review): this set is not referenced by the blurring code visible here;
# blurring is driven only by the Haar-cascade mask, not by predicted labels.
sensitive_labels = {"breast", "vagina", "penis"}  # Adjust based on your dataset

# Load trained model
model_path = "models/star.pth"


@st.cache_resource
def _load_model(path, num_classes):
    """Build the classifier, load its weights and switch it to eval mode.

    Streamlit re-executes this whole script on every widget interaction;
    ``st.cache_resource`` keeps one instance per server process instead of
    re-reading the checkpoint on each rerun.
    """
    net = SwinTransformerMultiLabel(num_classes=num_classes)
    # weights_only=True limits unpickling to tensors and plain containers,
    # which is all a state_dict needs, so a tampered checkpoint cannot run code.
    state_dict = torch.load(path, map_location="cpu", weights_only=True)
    net.load_state_dict(state_dict)
    net.eval()
    return net


model = _load_model(model_path, NUM_CLASSES)

# Define image preprocessing transformations
# (resize to 224x224, convert to a float tensor, normalise with ImageNet mean/std)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


# Function to build the blur mask using OpenCV
def detect_sensitive_areas(image):
    """Return a 0/1 mask marking every full-body region OpenCV detects.

    Despite the name, this does not localise specific body parts: it runs
    OpenCV's bundled ``haarcascade_fullbody.xml`` detector and marks each
    returned bounding box in full.

    Args:
        image: RGB PIL image (the caller converts with ``.convert("RGB")``).

    Returns:
        A 2-D ``uint8`` array with the image's height and width, where 1 marks
        pixels inside a detected body box and 0 everything else.

    Raises:
        RuntimeError: if the cascade XML cannot be loaded.
    """
    image_cv = np.array(image)  # Convert PIL to NumPy
    gray = cv2.cvtColor(image_cv, cv2.COLOR_RGB2GRAY)

    cascade_path = cv2.data.haarcascades + "haarcascade_fullbody.xml"
    body_cascade = cv2.CascadeClassifier(cascade_path)
    if body_cascade.empty():
        # CascadeClassifier does not raise on a missing/corrupt file; it yields an
        # empty classifier that would fail later with an opaque assertion error.
        raise RuntimeError(f"Could not load Haar cascade from {cascade_path}")

    # Detect full bodies (coarse: boxes cover whole figures, not specific areas)
    bodies = body_cascade.detectMultiScale(
        gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30)
    )

    mask = np.zeros_like(gray)
    for (x, y, w, h) in bodies:
        mask[y:y + h, x:x + w] = 1
    return mask
# Function to blur only detected sensitive areas
def blur_sensitive_parts(image, mask, blur_intensity=25, kernel_size=51):
    """Blur the masked regions of *image* while keeping the rest unchanged.

    Args:
        image: RGB PIL image (3 channels).
        mask: 2-D array with the image's height and width; pixels equal to 1
            are replaced by their blurred value.
        blur_intensity: Gaussian sigma, passed to ``cv2.GaussianBlur`` as sigmaX.
        kernel_size: Gaussian kernel width and height in pixels; OpenCV requires
            it to be odd and positive.

    Returns:
        A new PIL image.
    """
    image_np = np.array(image)  # Convert PIL to NumPy
    selected = np.asarray(mask) == 1
    if not selected.any():
        # Nothing to blur: skip the (costly) full-image Gaussian pass.
        return Image.fromarray(image_np)
    blurred = cv2.GaussianBlur(image_np, (kernel_size, kernel_size), blur_intensity)
    # Broadcast the (H, W) mask over the channel axis rather than copying it 3 times.
    result = np.where(selected[..., np.newaxis], blurred, image_np)
    return Image.fromarray(result)


# UI for image input: Upload or URL
option = st.radio("Choose an input method:", ("Upload Image", "Enter Image URL"))

image = None  # Placeholder for the image

if option == "Upload Image":
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        try:
            image = Image.open(uploaded_file).convert("RGB")
        except Exception as e:  # UI boundary: report any decode failure to the user
            st.error(f"Error reading image: {e}")
elif option == "Enter Image URL":
    image_url = st.text_input("Enter image URL:")
    if image_url:
        try:
            # NOTE(review): the URL is user-supplied and fetched server-side; if this
            # app is publicly deployed, consider restricting schemes/hosts (SSRF).
            # The timeout prevents a slow/unresponsive host from hanging the app.
            response = requests.get(image_url, timeout=10)
            # Surface HTTP errors directly instead of a confusing PIL decode error.
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
        except Exception as e:  # UI boundary: show network/decode errors to the user
            st.error(f"Error fetching image: {e}")

# Process the image if provided
if image is not None:
    # Preprocess image for model (add a batch dimension of 1)
    img_tensor = transform(image).unsqueeze(0)

    # Perform inference
    with torch.no_grad():
        output = model(img_tensor)
    # NOTE(review): thresholding at 0.5 assumes the model already outputs
    # probabilities; if it returns raw logits, apply torch.sigmoid first.
    scores = output[0].tolist()
    predicted_labels = [
        label for label, score in zip(class_labels, scores) if score > 0.5
    ]

    # Detect sensitive areas
    mask = detect_sensitive_areas(image)

    # Apply selective blurring
    blurred_image = blur_sensitive_parts(image, mask)

    # Display only the blurred image
    st.image(blurred_image, caption="Blurred Image (Sensitive Areas Only)", use_container_width=True)

    # Display results
    st.write("✅ **Predicted Labels:**", ", ".join(predicted_labels) if predicted_labels else "No sensitive content detected")