wracell commited on
Commit
c93638f
·
1 Parent(s): d26f75e
Files changed (1) hide show
  1. app.py +20 -8
app.py CHANGED
@@ -2,7 +2,6 @@ import streamlit as st
2
  import numpy as np
3
  import cv2
4
  from PIL import Image
5
- from torchvision import transforms, models
6
  from segment_anything import sam_model_registry, SamPredictor
7
  from io import BytesIO
8
  import base64
@@ -27,21 +26,34 @@ def preprocess_image(image):
27
 
28
  # Segment garment using SAM
29
  def segment_garment(image, predictor):
 
30
  image_np = np.array(image.convert("RGB"))
31
  predictor.set_image(image_np)
32
 
 
33
  height, width, _ = image_np.shape
34
- input_point = np.array([[width // 2, height // 2]])
35
- input_label = np.array([1])
36
-
37
- masks, _, _ = predictor.predict(point_coords=input_point, point_labels=input_label)
 
 
 
 
 
 
 
 
 
 
 
38
  mask = masks[0]
39
  mask_resized = cv2.resize(mask.astype(np.uint8) * 255, (width, height), interpolation=cv2.INTER_NEAREST)
40
- mask_resized = np.stack([mask_resized] * 3, axis=-1)
41
 
42
- segmented = np.where(mask_resized > 0, image_np, 0)
 
43
  return Image.fromarray(segmented)
44
-
45
  # AI garment analysis
46
  def analyze_garment(image, style_pref=None, feedback=None, generate_variations=False):
47
  image_bytes = BytesIO()
 
2
  import numpy as np
3
  import cv2
4
  from PIL import Image
 
5
  from segment_anything import sam_model_registry, SamPredictor
6
  from io import BytesIO
7
  import base64
 
26
 
27
  # Segment garment using SAM
28
  def segment_garment(image, predictor):
29
+ # Convert image to NumPy and set it in SAM
30
  image_np = np.array(image.convert("RGB"))
31
  predictor.set_image(image_np)
32
 
33
+ # Define multiple input points (center, upper body, lower body)
34
  height, width, _ = image_np.shape
35
+ input_points = np.array([
36
+ [width // 2, height // 3], # upper torso
37
+ [width // 2, height // 2], # center
38
+ [width // 2, 2 * height // 3] # lower torso
39
+ ])
40
+ input_labels = np.array([1, 1, 1]) # all positive prompts
41
+
42
+ # Generate mask using multiple points
43
+ masks, scores, logits = predictor.predict(
44
+ point_coords=input_points,
45
+ point_labels=input_labels,
46
+ multimask_output=False
47
+ )
48
+
49
+ # Post-process mask
50
  mask = masks[0]
51
  mask_resized = cv2.resize(mask.astype(np.uint8) * 255, (width, height), interpolation=cv2.INTER_NEAREST)
52
+ mask_rgb = np.stack([mask_resized] * 3, axis=-1)
53
 
54
+ # Apply mask to original image
55
+ segmented = np.where(mask_rgb > 0, image_np, 0)
56
  return Image.fromarray(segmented)
 
57
  # AI garment analysis
58
  def analyze_garment(image, style_pref=None, feedback=None, generate_variations=False):
59
  image_bytes = BytesIO()