import streamlit as st
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
import json
import cv2
import requests
from io import BytesIO
from src.model import SwinTransformerMultiLabel  # Import from src folder
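# SwinTransformerMultiLabel is this repo's own wrapper (src/model.py); it is
# assumed to take num_classes and return one output per class.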

# Title and description
st.title("STAR Multi-Label Classifier")
st.write("Upload an image or provide a URL to classify and blur sensitive areas.")

# Load class labels from JSON
label_file = "data/labels.json"
with open(label_file, "r") as f:
    label_data = json.load(f)

# Extract unique class labels
class_labels = sorted(set(tag for tags in label_data.values() for tag in tags))
NUM_CLASSES = len(class_labels)
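# labels.json is assumed to map each image filename to its list of tags, e.g.
#   {"img_001.jpg": ["breast"], "img_002.jpg": []}
# so the sorted union of all tags fixes the model's output index order.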

# Define sensitive body parts to blur
sensitive_labels = {"breast", "vagina", "penis"}  # Adjust based on your dataset

# Load trained model
model_path = "models/star.pth"
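
# Optional guard (a minimal sketch): stop with a readable error if the
# checkpoint is missing, e.g. on a fresh clone without trained weights.
import os  # local import to keep this sketch self-contained
if not os.path.exists(model_path):
    st.error(f"Model weights not found at {model_path}. Train or download the weights first.")
    st.stop()
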
model = SwinTransformerMultiLabel(num_classes=NUM_CLASSES)
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()

# Define image preprocessing transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
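
# The mean/std above are the standard ImageNet statistics; a Swin backbone
# pretrained on ImageNet expects the same normalization at inference time.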

# Function to detect sensitive areas using OpenCV
def detect_sensitive_areas(image):
    """Approximate sensitive regions using OpenCV's full-body Haar cascade.

    OpenCV ships no detector for specific body parts, so the whole detected
    body is masked as a coarse proxy for the sensitive areas.
    """
    image_cv = np.array(image)  # Convert PIL to NumPy
    gray = cv2.cvtColor(image_cv, cv2.COLOR_RGB2GRAY)

    # Load OpenCV's bundled pre-trained full-body Haar cascade
    body_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + "haarcascade_fullbody.xml")

    # Detect full bodies (basic approach, not perfect for specific areas)
    bodies = body_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))

    # Generate a blank mask for blurring
    mask = np.zeros_like(gray)

    # Mark detected regions in mask
    for (x, y, w, h) in bodies:
        mask[y:y+h, x:x+w] = 1  # Mark full-body regions (adjust for more precision)

    return mask

# Function to blur only detected sensitive areas
def blur_sensitive_parts(image, mask, blur_intensity=25):
    """Blurs only the detected sensitive areas while keeping the rest of the image clear."""
    image_np = np.array(image)  # Convert PIL to NumPy

    blurred = cv2.GaussianBlur(image_np, (51, 51), blur_intensity)  # blur_intensity is the Gaussian sigma
    mask = np.repeat(mask[:, :, np.newaxis], 3, axis=2)  # Convert mask to 3 channels
    result = np.where(mask == 1, blurred, image_np)  # Blend only blurred areas

    return Image.fromarray(result)  # Convert back to PIL Image
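
# Optional refinement (a sketch, not wired in below): feather the mask so the
# blur fades out at region edges instead of stopping at a hard rectangle.
def blur_with_feathered_mask(image, mask, blur_intensity=25, feather=31):
    """Like blur_sensitive_parts, but alpha-blends along softened mask edges."""
    image_np = np.array(image).astype(np.float32)
    blurred = cv2.GaussianBlur(image_np, (51, 51), blur_intensity)
    alpha = cv2.GaussianBlur(mask.astype(np.float32), (feather, feather), 0)  # 0/1 mask -> [0, 1] weights
    alpha = alpha[:, :, np.newaxis]  # Broadcast over the color channels
    result = alpha * blurred + (1.0 - alpha) * image_np
    return Image.fromarray(result.astype(np.uint8))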

# UI for image input: Upload or URL
option = st.radio("Choose an input method:", ("Upload Image", "Enter Image URL"))

image = None  # Placeholder for the image

if option == "Upload Image":
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        image = Image.open(uploaded_file).convert("RGB")

elif option == "Enter Image URL":
    image_url = st.text_input("Enter image URL:")
    if image_url:
        try:
            response = requests.get(image_url, timeout=10)
            response.raise_for_status()  # Surface HTTP errors instead of failing on decode
            image = Image.open(BytesIO(response.content)).convert("RGB")
        except Exception as e:
            st.error(f"Error fetching image: {e}")

# Process the image if provided
if image:
    # Preprocess image for model
    img_tensor = transform(image).unsqueeze(0)

    # Perform inference
    with torch.no_grad():
        output = torch.sigmoid(model(img_tensor))  # Map logits to probabilities (assumes the model outputs raw logits)
        predicted_indices = [i for i in range(NUM_CLASSES) if output[0][i] > 0.5]
        predicted_labels = [class_labels[i] for i in predicted_indices]
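
    # Optional: expose per-class probabilities for debugging (a sketch; relies
    # on the sigmoid applied above, so every score is in [0, 1]).
    with st.expander("Per-class probabilities"):
        for i, label in enumerate(class_labels):
            st.write(f"{label}: {output[0][i].item():.3f}")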

    # Blur only when a sensitive label was actually predicted; otherwise show
    # the image unmodified (this is what the sensitive_labels set above is for)
    if sensitive_labels.intersection(predicted_labels):
        # Detect sensitive areas and apply selective blurring
        mask = detect_sensitive_areas(image)
        blurred_image = blur_sensitive_parts(image, mask)
        st.image(blurred_image, caption="Blurred Image (Sensitive Areas Only)", use_container_width=True)
    else:
        st.image(image, caption="Uploaded Image", use_container_width=True)

    # Display results
    st.write("✅ **Predicted Labels:**", ", ".join(predicted_labels) if predicted_labels else "No sensitive content detected")