# STAR / app.py
# Author: ramagururadhakrishnan
# Last update: commit 7668eb6 (verified)
import streamlit as st
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
import json
import cv2
import requests
from io import BytesIO
from src.model import SwinTransformerMultiLabel # Import from src folder
# --- App header ---------------------------------------------------------
st.title("STAR Multi-Label Classifier")
st.write("Upload an image or provide a URL to classify and blur sensitive areas.")

# --- Label vocabulary ---------------------------------------------------
# labels.json maps image ids to lists of tags; the class list is the sorted
# set of every tag that appears anywhere in the file.
label_file = "data/labels.json"
with open(label_file, "r") as f:
    label_data = json.load(f)

class_labels = sorted({tag for tags in label_data.values() for tag in tags})
NUM_CLASSES = len(class_labels)

# Tags considered sensitive for blurring purposes.
# NOTE(review): this set is not referenced anywhere else in this file —
# detection below uses a full-body cascade, not these labels.
sensitive_labels = {"breast", "vagina", "penis"}  # Adjust based on your dataset

# --- Model --------------------------------------------------------------
# Checkpoint is loaded onto CPU so the app also runs on GPU-less hosts.
model_path = "models/star.pth"
model = SwinTransformerMultiLabel(num_classes=NUM_CLASSES)
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()

# --- Preprocessing: 224x224 input with ImageNet mean/std ----------------
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
# Function to detect regions to blur using OpenCV
def detect_sensitive_areas(image):
    """Build a binary mask of regions to blur.

    NOTE: despite the app's goal of targeting specific sensitive areas,
    this uses OpenCV's Haar *full-body* cascade, so the mask marks whole
    detected bodies rather than individual body parts.

    Args:
        image: PIL image in RGB mode (callers convert with .convert("RGB")).

    Returns:
        2-D uint8 NumPy array with the image's height/width, containing 1
        inside each detected bounding box and 0 elsewhere.

    Raises:
        RuntimeError: if the bundled Haar cascade file cannot be loaded.
    """
    image_cv = np.array(image)  # PIL -> HxWx3 NumPy array
    gray = cv2.cvtColor(image_cv, cv2.COLOR_RGB2GRAY)

    # Load OpenCV's pre-trained full-body Haar cascade.
    cascade_path = cv2.data.haarcascades + "haarcascade_fullbody.xml"
    body_cascade = cv2.CascadeClassifier(cascade_path)
    if body_cascade.empty():
        # Fail loudly here instead of letting detectMultiScale raise a
        # cryptic cv2.error later.
        raise RuntimeError(f"Failed to load Haar cascade: {cascade_path}")

    # Detect full bodies (coarse approach, not precise for specific areas).
    bodies = body_cascade.detectMultiScale(
        gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30)
    )

    # 0 everywhere, 1 inside each detected bounding box.
    mask = np.zeros_like(gray)
    for (x, y, w, h) in bodies:
        mask[y:y + h, x:x + w] = 1
    return mask
# Function to blur only the masked regions of an image
def blur_sensitive_parts(image, mask, blur_intensity=25):
    """Blur the masked regions of *image* while keeping the rest sharp.

    Args:
        image: PIL image (RGB).
        mask: 2-D array with 1 marking pixels to blur
            (as produced by detect_sensitive_areas).
        blur_intensity: Gaussian sigma; the 51x51 kernel size is fixed.

    Returns:
        New PIL image in which only the masked pixels are blurred.
    """
    frame = np.array(image)
    # Blur the entire frame once, then copy the blurred pixels back only
    # where the mask selects them.
    softened = cv2.GaussianBlur(frame, (51, 51), blur_intensity)
    selected = mask == 1
    result = frame.copy()
    result[selected] = softened[selected]
    return Image.fromarray(result)
# UI for image input: Upload or URL
option = st.radio("Choose an input method:", ("Upload Image", "Enter Image URL"))

image = None  # Will hold the PIL image once an input is provided

if option == "Upload Image":
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        image = Image.open(uploaded_file).convert("RGB")
elif option == "Enter Image URL":
    image_url = st.text_input("Enter image URL:")
    if image_url:
        try:
            # Timeout so a dead URL cannot hang the Streamlit session, and
            # raise_for_status so HTTP errors (404, 500, ...) surface as a
            # readable message instead of a PIL decode failure.
            response = requests.get(image_url, timeout=10)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
        except Exception as e:
            st.error(f"Error fetching image: {e}")
# Process the image if one was provided
if image is not None:
    # Preprocess for the model: resize, tensorize, normalize, add batch dim.
    img_tensor = transform(image).unsqueeze(0)

    # Inference without gradient tracking.
    with torch.no_grad():
        output = model(img_tensor)
        # NOTE(review): this thresholds the raw model outputs at 0.5. If
        # SwinTransformerMultiLabel returns logits (no sigmoid inside the
        # model), torch.sigmoid(output) should be applied first — confirm
        # against src/model.py.
        scores = output[0]
        predicted_labels = [
            label for label, score in zip(class_labels, scores) if score > 0.5
        ]

    # Detect regions to blur and apply selective blurring.
    mask = detect_sensitive_areas(image)
    blurred_image = blur_sensitive_parts(image, mask)

    # Display only the blurred version of the image.
    st.image(blurred_image, caption="Blurred Image (Sensitive Areas Only)", use_container_width=True)

    # Report predictions.
    st.write("✅ **Predicted Labels:**", ", ".join(predicted_labels) if predicted_labels else "No sensitive content detected")