wracell committed on
Commit
caef22e
·
1 Parent(s): 804b0e5

modifications and added sam_vit_b.pth

Browse files
Files changed (3) hide show
  1. app.py +64 -92
  2. requirements.txt +2 -1
  3. sam_vit_b.pth +3 -0
app.py CHANGED
@@ -2,136 +2,108 @@ import streamlit as st
2
  import numpy as np
3
  import cv2
4
  from PIL import Image
5
- import torch
6
  import torchvision.transforms as transforms
7
- import torchvision.models as models
8
  from io import BytesIO
9
- from google.generativeai import configure, GenerativeModel
10
  import base64
 
11
 
12
  # Configure Gemini API
13
  configure(api_key="AIzaSyBawh403z5cyyQzFhQo14y7oUQw6nr8mIg")
14
  model = GenerativeModel("gemini-2.0-flash")
15
 
16
- # Load DeepLabV3 model for garment segmentation
17
- def load_segmentation_model():
18
- model = models.segmentation.deeplabv3_resnet101(pretrained=True)
19
- model.eval()
20
- return model
21
-
22
- def segment_garment(image, model):
23
- # Convert to model-compatible format
24
- transform = transforms.Compose([
25
- transforms.Resize((520, 520)),
26
- transforms.ToTensor(),
27
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
28
- ])
29
-
30
- image_tensor = transform(image).unsqueeze(0)
31
-
32
- with torch.no_grad():
33
- output = model(image_tensor)['out'][0]
34
-
35
- # Convert output to segmentation mask
36
- mask = output.argmax(0).byte().cpu().numpy()
37
 
38
- # Debugging: Check if any garment pixels are detected
39
- print("Unique values in mask:", np.unique(mask))
 
 
 
 
40
 
41
- if mask.max() == 0: # No garment detected
42
- print("Warning: No garment detected. Check input image.")
43
- return image # Return original image
 
44
 
45
- # Resize mask to match original image size
46
- mask = Image.fromarray(mask.astype(np.uint8) * 255)
47
- mask = mask.resize(image.size, Image.NEAREST) # Match original size
48
- mask = np.array(mask)
49
 
50
- # Convert grayscale mask to 3-channel
51
- mask = np.stack([mask] * 3, axis=-1)
52
 
53
- # Convert original image to numpy array
54
- image_np = np.array(image)
 
55
 
56
- # Apply the mask (only keep garment pixels)
57
- segmented = np.where(mask > 0, image_np, 0)
58
 
59
  return Image.fromarray(segmented)
60
 
61
- # Preprocess image using OpenCV (Edge Detection & Background Removal)
62
- def preprocess_image(image):
63
- image = np.array(image.convert("RGB"))
64
- gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
65
- edges = cv2.Canny(gray, 100, 200)
66
- return Image.fromarray(edges)
67
-
68
- # AI garment analysis & fashion recommendations
69
  def analyze_garment(image):
70
- # Convert PIL Image to bytes
71
  image_bytes = BytesIO()
72
  image.save(image_bytes, format="PNG")
73
- image_bytes = image_bytes.getvalue()
74
-
75
- # Encode image in Base64
76
- encoded_image = base64.b64encode(image_bytes).decode("utf-8")
77
 
78
- # Prepare request payload in correct Gemini API format
79
  prompt = {
80
  "parts": [
81
  {"text": "Analyze the garment in this image, describing its style, fabric, and design elements. "
82
- "Based on the garment's features, suggest the best occasions to wear it "
83
- "and recommend complementary fashion pieces (e.g., shoes, accessories, layering options). "
84
- "Also, provide a seasonal suitability rating."},
85
  {"inline_data": {"mime_type": "image/png", "data": encoded_image}}
86
  ]
87
  }
88
-
89
- # Call Gemini API
90
  response = model.generate_content(prompt)
91
-
92
  return response.text if response else "Analysis failed."
93
 
94
- # Load segmentation model
95
- segmentation_model = load_segmentation_model()
96
 
97
  # Streamlit UI
98
- st.title("👗 AI-Enhanced Fashion Design with Gemini 2.0 Flash")
99
- st.write("""
100
- ### 🎨 **About This App**
101
- This AI-powered fashion tool helps you analyze and enhance garment designs using Google's **Gemini 2.0 Flash** and DeepLabV3 segmentation.
102
- It provides recommendations on **suitable fashion styles** for different occasions based on the uploaded garment.
103
-
104
- ### 🛠 **How to Use This App**
105
- 1. **Upload an Image** – Select a clothing item or a fashion sketch.
106
- 2. **View Preprocessing & Segmentation** – The app applies edge detection and garment segmentation.
107
- 3. **Analyze Garment** – Click "Analyze Garment" to get AI-powered fashion insights.
108
- 4. **Get Recommendations** – The AI suggests suitable **occasions and styling tips** based on the garment.
109
-
110
- ### 🔍 **How It Works**
111
- - **Edge Detection**: Uses OpenCV to highlight contours and details in the garment.
112
- - **Garment Segmentation**: DeepLabV3 identifies the clothing item and removes the background.
113
- - **AI Fashion Analysis**: Google’s Gemini AI analyzes the **style, fabric, and design** of the garment and provides recommendations.
114
-
115
- ➡️ Try it now by uploading an image of your clothing!
116
  """)
117
 
118
  # File Upload
119
  uploaded_file = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
120
 
121
  if uploaded_file is not None:
122
- # Display uploaded image
123
  image = Image.open(uploaded_file)
124
  st.image(image, caption="Uploaded Image", use_container_width=True)
125
-
126
- # Preprocess Image with OpenCV
127
- processed_image = preprocess_image(image)
128
- st.image(processed_image, caption="Processed Image (Edge Detection)", use_container_width=True)
129
-
130
- # Segment Garment using DeepLabV3
131
- segmented_image = segment_garment(image, segmentation_model)
132
- st.image(segmented_image.convert("RGB"), caption="Segmented Garment", use_container_width=True)
133
 
134
-
135
- if st.button("Analyze Garment"):
136
- result = analyze_garment(image)
137
- st.success(result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import numpy as np
3
  import cv2
4
  from PIL import Image
 
5
  import torchvision.transforms as transforms
6
+ from segment_anything import sam_model_registry, SamPredictor
7
  from io import BytesIO
 
8
  import base64
9
+ from google.generativeai import configure, GenerativeModel
10
 
11
  # Configure Gemini API
12
  configure(api_key="AIzaSyBawh403z5cyyQzFhQo14y7oUQw6nr8mIg")
13
  model = GenerativeModel("gemini-2.0-flash")
14
 
15
+ # Load SAM model with ViT-Base
16
+ def load_sam_model():
17
+ sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b.pth") # Use vit_b instead of vit_h
18
+ predictor = SamPredictor(sam)
19
+ return predictor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
+ # Preprocess image using OpenCV (Edge Detection & Background Removal)
22
+ def preprocess_image(image):
23
+ image = np.array(image.convert("RGB")) # Convert to NumPy array
24
+ gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) # Convert to Grayscale
25
+ edges = cv2.Canny(gray, 100, 200) # Apply Canny Edge Detection
26
+ return Image.fromarray(edges) # Convert back to PIL Image
27
 
28
+ # Segment garment using SAM
29
+ def segment_garment(image, predictor):
30
+ image_np = np.array(image.convert("RGB"))
31
+ predictor.set_image(image_np)
32
 
33
+ # Use center point of image as prompt
34
+ height, width, _ = image_np.shape
35
+ input_point = np.array([[width // 2, height // 2]])
36
+ input_label = np.array([1]) # 1 indicates object selection
37
 
38
+ masks, _, _ = predictor.predict(point_coords=input_point, point_labels=input_label)
39
+ mask = masks[0] # Get first mask
40
 
41
+ # Resize mask to match image
42
+ mask_resized = cv2.resize(mask.astype(np.uint8) * 255, (width, height), interpolation=cv2.INTER_NEAREST)
43
+ mask_resized = np.stack([mask_resized] * 3, axis=-1) # Convert to 3-channel
44
 
45
+ # Apply segmentation mask
46
+ segmented = np.where(mask_resized > 0, image_np, 0)
47
 
48
  return Image.fromarray(segmented)
49
 
50
+ # AI garment analysis
 
 
 
 
 
 
 
51
  def analyze_garment(image):
 
52
  image_bytes = BytesIO()
53
  image.save(image_bytes, format="PNG")
54
+ encoded_image = base64.b64encode(image_bytes.getvalue()).decode("utf-8")
 
 
 
55
 
 
56
  prompt = {
57
  "parts": [
58
  {"text": "Analyze the garment in this image, describing its style, fabric, and design elements. "
59
+ "Suggest the best occasions to wear it and recommend complementary fashion pieces."},
 
 
60
  {"inline_data": {"mime_type": "image/png", "data": encoded_image}}
61
  ]
62
  }
63
+
 
64
  response = model.generate_content(prompt)
 
65
  return response.text if response else "Analysis failed."
66
 
67
+ # Load SAM model
68
+ sam_predictor = load_sam_model()
69
 
70
  # Streamlit UI
71
+ st.title("👗 AI Fashion Analysis with SAM & Gemini AI")
72
+
73
+ # Description and Instructions
74
+ st.markdown("""
75
+ ### 📌 How to Use this App:
76
+ 1. *Upload an Image*: Click the upload button and select a fashion image.
77
+ 2. *Preprocess the Image*: Click 'Preprocess' to apply edge detection and garment segmentation.
78
+ 3. *View Results*: Processed and segmented images will be displayed.
79
+ 4. *Analyze the Garment*: Click 'Analyze Garment' to get AI-based fashion insights.
 
 
 
 
 
 
 
 
 
80
  """)
81
 
82
  # File Upload
83
  uploaded_file = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
84
 
85
  if uploaded_file is not None:
 
86
  image = Image.open(uploaded_file)
87
  st.image(image, caption="Uploaded Image", use_container_width=True)
 
 
 
 
 
 
 
 
88
 
89
+ # Initialize session state for persistence
90
+ if "processed_image" not in st.session_state:
91
+ st.session_state.processed_image = None
92
+ if "segmented_image" not in st.session_state:
93
+ st.session_state.segmented_image = None
94
+
95
+ # Preprocess Button (Runs Edge Detection & Segmentation)
96
+ if st.button("Preprocess"):
97
+ st.session_state.processed_image = preprocess_image(image)
98
+ st.session_state.segmented_image = segment_garment(image, sam_predictor)
99
+
100
+ # Display Preprocessed Images if Available
101
+ if st.session_state.processed_image:
102
+ st.image(st.session_state.processed_image, caption="Edge Detection", use_container_width=True)
103
+ if st.session_state.segmented_image:
104
+ st.image(st.session_state.segmented_image, caption="Segmented Garment", use_container_width=True)
105
+
106
+ # Analyze Garment Button (Gemini AI)
107
+ if st.button("Analyze Garment"):
108
+ result = analyze_garment(image)
109
+ st.success(result)
requirements.txt CHANGED
@@ -4,4 +4,5 @@ opencv-python
4
  pillow
5
  torch
6
  torchvision
7
- google-generativeai
 
 
4
  pillow
5
  torch
6
  torchvision
7
+ google-generativeai
8
+ transformers
sam_vit_b.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912
3
+ size 375042383