ramagururadhakrishnan commited on
Commit
7668eb6
·
verified ·
1 Parent(s): fdf437f

- URL as Image Source

Files changed (1) hide show
  1. app.py +24 -6
app.py CHANGED
@@ -5,11 +5,13 @@ from PIL import Image
5
  from torchvision import transforms
6
  import json
7
  import cv2
 
 
8
  from src.model import SwinTransformerMultiLabel # Import from src folder
9
 
10
  # Title and description
11
  st.title("STAR Multi-Label Classifier")
12
- st.write("Upload an image to classify and blur sensitive areas.")
13
 
14
  # Load class labels from JSON
15
  label_file = "data/labels.json"
@@ -36,7 +38,7 @@ transform = transforms.Compose([
36
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
37
  ])
38
 
39
- # Function to detect sensitive parts using OpenCV
40
  def detect_sensitive_areas(image):
41
  """Detect sensitive areas (breasts, vagina, penis) using OpenCV."""
42
  image_cv = np.array(image) # Convert PIL to NumPy
@@ -68,11 +70,27 @@ def blur_sensitive_parts(image, mask, blur_intensity=25):
68
 
69
  return Image.fromarray(result) # Convert back to PIL Image
70
 
71
- # Upload image
72
- uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
73
- if uploaded_file is not None:
74
- image = Image.open(uploaded_file).convert("RGB")
75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  # Preprocess image for model
77
  img_tensor = transform(image).unsqueeze(0)
78
 
 
5
  from torchvision import transforms
6
  import json
7
  import cv2
8
+ import requests
9
+ from io import BytesIO
10
  from src.model import SwinTransformerMultiLabel # Import from src folder
11
 
12
  # Title and description
13
  st.title("STAR Multi-Label Classifier")
14
+ st.write("Upload an image or provide a URL to classify and blur sensitive areas.")
15
 
16
  # Load class labels from JSON
17
  label_file = "data/labels.json"
 
38
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
39
  ])
40
 
41
+ # Function to detect sensitive areas using OpenCV
42
  def detect_sensitive_areas(image):
43
  """Detect sensitive areas (breasts, vagina, penis) using OpenCV."""
44
  image_cv = np.array(image) # Convert PIL to NumPy
 
70
 
71
  return Image.fromarray(result) # Convert back to PIL Image
72
 
73
+ # UI for image input: Upload or URL
74
+ option = st.radio("Choose an input method:", ("Upload Image", "Enter Image URL"))
 
 
75
 
76
+ image = None # Placeholder for the image
77
+
78
+ if option == "Upload Image":
79
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
80
+ if uploaded_file is not None:
81
+ image = Image.open(uploaded_file).convert("RGB")
82
+
83
+ elif option == "Enter Image URL":
84
+ image_url = st.text_input("Enter image URL:")
85
+ if image_url:
86
+ try:
87
+ response = requests.get(image_url)
88
+ image = Image.open(BytesIO(response.content)).convert("RGB")
89
+ except Exception as e:
90
+ st.error(f"Error fetching image: {e}")
91
+
92
+ # Process the image if provided
93
+ if image:
94
  # Preprocess image for model
95
  img_tensor = transform(image).unsqueeze(0)
96