File size: 4,116 Bytes
4fd3db7 e6abe34 4fd3db7 4fcd204 e6abe34 7668eb6 c10fd45 4fd3db7 c10fd45 7668eb6 4fcd204 fdf437f 4fcd204 4fd3db7 f78d2cc 4fd3db7 4fcd204 4fd3db7 4fcd204 4fd3db7 7668eb6 fdf437f f78d2cc fdf437f f78d2cc fdf437f f78d2cc fdf437f e6abe34 fdf437f f78d2cc fdf437f e6abe34 7668eb6 4fd3db7 7668eb6 e6abe34 4fd3db7 e6abe34 4fd3db7 e6abe34 fdf437f c10fd45 fdf437f e6abe34 8623a7a fdf437f 4fd3db7 c10fd45 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 | import streamlit as st
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
import json
import cv2
import requests
from io import BytesIO
from src.model import SwinTransformerMultiLabel # Import from src folder
# Title and description
st.title("STAR Multi-Label Classifier")
st.write("Upload an image or provide a URL to classify and blur sensitive areas.")

# Load class labels from JSON.
# The file is read as a mapping whose values are iterables of tag strings
# (that is how it is consumed below).
label_file = "data/labels.json"
# JSON is UTF-8 by specification; an explicit encoding avoids mis-decoding
# non-ASCII tags on platforms whose locale default is something else.
with open(label_file, "r", encoding="utf-8") as f:
    label_data = json.load(f)

# Extract unique class labels. Sorting gives a deterministic index -> label
# mapping, which is what the model's output positions are matched against.
class_labels = sorted({tag for tags in label_data.values() for tag in tags})
NUM_CLASSES = len(class_labels)

# Define sensitive body parts to blur
# NOTE(review): this set is not referenced anywhere else in the visible part
# of this script - blurring is currently driven by the OpenCV detector only.
sensitive_labels = {"breast", "vagina", "penis"}  # Adjust based on your dataset
# Load trained model
model_path = "models/star.pth"


@st.cache_resource
def _load_model(path, num_classes):
    """Build the classifier, load its weights on CPU and switch to eval mode.

    Streamlit re-executes this script on every user interaction. Caching the
    loaded network as a shared resource means the checkpoint is read from
    disk only once per server process instead of on every rerun.

    Args:
        path: Filesystem path of the saved state dict.
        num_classes: Number of output classes the network is built with.

    Returns:
        The model instance with weights loaded, in evaluation mode.
    """
    net = SwinTransformerMultiLabel(num_classes=num_classes)
    # NOTE(review): torch.load unpickles the checkpoint, so only trusted files
    # should be loaded here; consider weights_only=True if the installed
    # torch version supports it.
    net.load_state_dict(torch.load(path, map_location="cpu"))
    net.eval()
    return net


model = _load_model(model_path, NUM_CLASSES)
# Define image preprocessing transformations
# Per-channel mean / std used for normalisation. These values match the
# ImageNet statistics commonly used with torchvision pretrained backbones.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Pipeline: resize to a fixed 224x224 input, convert to a tensor, then
# normalise each channel.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
]
transform = transforms.Compose(_preprocess_steps)
# Function to detect sensitive areas using OpenCV
def detect_sensitive_areas(image):
    """Build a mask of regions to blur using OpenCV's full-body Haar cascade.

    Despite the function name, no body-part-specific detection happens here:
    the only detector used is ``haarcascade_fullbody.xml``, so every detected
    full-body bounding box is marked in its entirety.

    Args:
        image: PIL image. It is converted with ``np.array`` and treated as
            RGB when producing the grayscale input for the detector.

    Returns:
        A 2-D NumPy array with the same shape and dtype as the grayscale
        image, holding 1 inside each detected box and 0 everywhere else.
    """
    rgb_pixels = np.array(image)
    grayscale = cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2GRAY)

    # The cascade definition ships with OpenCV's bundled data directory.
    cascade_file = cv2.data.haarcascades + "haarcascade_fullbody.xml"
    detector = cv2.CascadeClassifier(cascade_file)
    detections = detector.detectMultiScale(
        grayscale, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30)
    )

    # Start from an all-zero mask and paint each detection rectangle with 1.
    region_mask = np.zeros_like(grayscale)
    for left, top, width, height in detections:
        region_mask[top:top + height, left:left + width] = 1
    return region_mask
# Function to blur only detected sensitive areas
def blur_sensitive_parts(image, mask, blur_intensity=25):
    """Return a copy of ``image`` that is Gaussian-blurred where ``mask`` is 1.

    Args:
        image: PIL image to process.
        mask: 2-D array aligned with the image; pixels whose mask value
            equals 1 are taken from the blurred version, all others are
            left untouched.
        blur_intensity: Passed to ``cv2.GaussianBlur`` as the sigma argument.
            The kernel size itself is fixed at 51x51.

    Returns:
        A new PIL image built from the composited pixel array.
    """
    pixels = np.array(image)
    # Blur the whole frame once, then pick per pixel between the two versions.
    fully_blurred = cv2.GaussianBlur(pixels, (51, 51), blur_intensity)

    # Expand the single-channel mask to three channels so it lines up with
    # the colour image in np.where.
    mask_3ch = np.repeat(mask[:, :, None], 3, axis=2)
    use_blur = mask_3ch == 1
    composited = np.where(use_blur, fully_blurred, pixels)
    return Image.fromarray(composited)
# UI for image input: Upload or URL
option = st.radio("Choose an input method:", ("Upload Image", "Enter Image URL"))
image = None  # Stays None until an image has been loaded successfully

if option == "Upload Image":
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        image = Image.open(uploaded_file).convert("RGB")
elif option == "Enter Image URL":
    image_url = st.text_input("Enter image URL:")
    if image_url:
        # NOTE(review): the URL is user-supplied and fetched server-side;
        # consider restricting schemes/hosts before deploying publicly.
        try:
            # Without a timeout an unresponsive host would hang the app
            # indefinitely.
            response = requests.get(image_url, timeout=10)
            # Report HTTP errors (404, 500, ...) directly instead of trying
            # to decode an error page as an image.
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
        except Exception as e:
            # UI boundary: any failure to fetch or decode is shown to the
            # user rather than crashing the script.
            st.error(f"Error fetching image: {e}")

# Process the image if provided
if image is not None:
    # Preprocess image for model: add a leading batch dimension of size 1
    img_tensor = transform(image).unsqueeze(0)

    # Perform inference (no gradients needed)
    with torch.no_grad():
        output = model(img_tensor)

    # Keep every class whose score exceeds 0.5.
    # NOTE(review): this threshold assumes the model already outputs
    # probabilities (e.g. applies a sigmoid) - confirm against src.model.
    scores = output[0]
    predicted_indices = [i for i in range(NUM_CLASSES) if scores[i] > 0.5]
    predicted_labels = [class_labels[i] for i in predicted_indices]

    # Detect sensitive areas
    mask = detect_sensitive_areas(image)

    # Apply selective blurring
    blurred_image = blur_sensitive_parts(image, mask)

    # Display only the blurred image
    st.image(blurred_image, caption="Blurred Image (Sensitive Areas Only)", use_container_width=True)

    # Display results
    st.write("✅ **Predicted Labels:**", ", ".join(predicted_labels) if predicted_labels else "No sensitive content detected")
|