Spaces:
Sleeping
Sleeping
wracell
commited on
Commit
·
c93638f
1
Parent(s):
d26f75e
changes
Browse files
app.py
CHANGED
|
@@ -2,7 +2,6 @@ import streamlit as st
|
|
| 2 |
import numpy as np
|
| 3 |
import cv2
|
| 4 |
from PIL import Image
|
| 5 |
-
from torchvision import transforms, models
|
| 6 |
from segment_anything import sam_model_registry, SamPredictor
|
| 7 |
from io import BytesIO
|
| 8 |
import base64
|
|
@@ -27,21 +26,34 @@ def preprocess_image(image):
|
|
| 27 |
|
| 28 |
# Segment garment using SAM
|
| 29 |
def segment_garment(image, predictor):
|
|
|
|
| 30 |
image_np = np.array(image.convert("RGB"))
|
| 31 |
predictor.set_image(image_np)
|
| 32 |
|
|
|
|
| 33 |
height, width, _ = image_np.shape
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
mask = masks[0]
|
| 39 |
mask_resized = cv2.resize(mask.astype(np.uint8) * 255, (width, height), interpolation=cv2.INTER_NEAREST)
|
| 40 |
-
|
| 41 |
|
| 42 |
-
|
|
|
|
| 43 |
return Image.fromarray(segmented)
|
| 44 |
-
|
| 45 |
# AI garment analysis
|
| 46 |
def analyze_garment(image, style_pref=None, feedback=None, generate_variations=False):
|
| 47 |
image_bytes = BytesIO()
|
|
|
|
| 2 |
import numpy as np
|
| 3 |
import cv2
|
| 4 |
from PIL import Image
|
|
|
|
| 5 |
from segment_anything import sam_model_registry, SamPredictor
|
| 6 |
from io import BytesIO
|
| 7 |
import base64
|
|
|
|
| 26 |
|
| 27 |
# Segment garment using SAM
|
| 28 |
def segment_garment(image, predictor):
|
| 29 |
+
# Convert image to NumPy and set it in SAM
|
| 30 |
image_np = np.array(image.convert("RGB"))
|
| 31 |
predictor.set_image(image_np)
|
| 32 |
|
| 33 |
+
# Define multiple input points (center, upper body, lower body)
|
| 34 |
height, width, _ = image_np.shape
|
| 35 |
+
input_points = np.array([
|
| 36 |
+
[width // 2, height // 3], # upper torso
|
| 37 |
+
[width // 2, height // 2], # center
|
| 38 |
+
[width // 2, 2 * height // 3] # lower torso
|
| 39 |
+
])
|
| 40 |
+
input_labels = np.array([1, 1, 1]) # all positive prompts
|
| 41 |
+
|
| 42 |
+
# Generate mask using multiple points
|
| 43 |
+
masks, scores, logits = predictor.predict(
|
| 44 |
+
point_coords=input_points,
|
| 45 |
+
point_labels=input_labels,
|
| 46 |
+
multimask_output=False
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
# Post-process mask
|
| 50 |
mask = masks[0]
|
| 51 |
mask_resized = cv2.resize(mask.astype(np.uint8) * 255, (width, height), interpolation=cv2.INTER_NEAREST)
|
| 52 |
+
mask_rgb = np.stack([mask_resized] * 3, axis=-1)
|
| 53 |
|
| 54 |
+
# Apply mask to original image
|
| 55 |
+
segmented = np.where(mask_rgb > 0, image_np, 0)
|
| 56 |
return Image.fromarray(segmented)
|
|
|
|
| 57 |
# AI garment analysis
|
| 58 |
def analyze_garment(image, style_pref=None, feedback=None, generate_variations=False):
|
| 59 |
image_bytes = BytesIO()
|