uhdessai commited on
Commit
02553a7
·
verified ·
1 Parent(s): d356138

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +253 -111
app.py CHANGED
@@ -3,170 +3,312 @@
3
  # import os
4
  # import random
5
  # from PIL import Image
 
 
6
 
7
- # # Paths
8
- # OUTPUT_DIR = "outputs"
9
- # MODEL_PATH = "top_model.pkl" # Adjusted for local Hugging Face repo structure
 
 
 
10
 
11
- # # Ensure the output directory exists
12
  # os.makedirs(OUTPUT_DIR, exist_ok=True)
 
13
 
14
- # # Function to generate images using StyleGAN3
15
  # def generate_images():
16
- # command = f"python stylegan3/gen_images.py --outdir={OUTPUT_DIR} --trunc=1 --seeds='1-50' --network={MODEL_PATH}"
 
 
 
 
 
 
 
17
  # try:
18
- # subprocess.run(command, shell=True, check=True)
19
  # except subprocess.CalledProcessError as e:
20
- # return f"Error generating images: {e}"
21
 
22
- # # Function to select 5 random images
23
  # def get_random_images():
24
  # image_files = [f for f in os.listdir(OUTPUT_DIR) if f.endswith(".png")]
25
  # if len(image_files) < 10:
26
  # generate_images()
27
  # image_files = [f for f in os.listdir(OUTPUT_DIR) if f.endswith(".png")]
28
  # random_images = random.sample(image_files, min(10, len(image_files)))
29
- # return [Image.open(os.path.join(OUTPUT_DIR, img)) for img in random_images]
 
30
 
31
- # # Gradio function
32
- # def generate_and_display():
33
- # generate_images()
34
- # return get_random_images()
35
 
36
- # # UI
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  # with gr.Blocks() as demo:
38
  # gr.Markdown("# 🎨 AI-Generated Clothing Designs - Tops")
 
39
  # generate_button = gr.Button("Generate New Designs")
40
- # output_gallery = gr.Gallery(label="Generated Designs", columns=5, rows=2)
41
- # generate_button.click(fn=generate_and_display, outputs=output_gallery)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  # if __name__ == "__main__":
44
  # demo.launch()
45
 
 
 
 
 
 
 
 
46
 
47
 
 
 
48
 
 
 
 
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
 
52
- import gradio as gr
53
- import subprocess
54
  import os
55
- import random
56
- from PIL import Image
57
- import shutil
58
- import requests
59
-
60
- # === Setup Paths ===
61
- BASE_DIR = os.path.dirname(os.path.abspath(__file__))
62
- GEN_SCRIPT = os.path.join(BASE_DIR, "stylegan3", "gen_images.py")
63
- OUTPUT_DIR = os.path.join(BASE_DIR, "outputs")
64
- MODEL_PATH = os.path.join(BASE_DIR, "top_model.pkl")
65
- SAVE_DIR = os.path.join(BASE_DIR, "saved_images")
66
-
67
- os.makedirs(OUTPUT_DIR, exist_ok=True)
68
- os.makedirs(SAVE_DIR, exist_ok=True)
69
-
70
- # === Image Generation Function ===
71
- def generate_images():
72
- command = [
73
- "python",
74
- GEN_SCRIPT,
75
- f"--outdir={OUTPUT_DIR}",
76
- "--trunc=1",
77
- "--seeds=3-5,7,9,12-14,16-26,29,31,32,34,40,41",
78
- f"--network={MODEL_PATH}"
79
- ]
80
- try:
81
- subprocess.run(command, check=True, capture_output=True, text=True)
82
- except subprocess.CalledProcessError as e:
83
- return f"Error generating images:\n{e.stderr}"
84
-
85
- # === Select Random Images from Output Folder ===
86
- def get_random_images():
87
- image_files = [f for f in os.listdir(OUTPUT_DIR) if f.endswith(".png")]
88
- if len(image_files) < 10:
89
- generate_images()
90
- image_files = [f for f in os.listdir(OUTPUT_DIR) if f.endswith(".png")]
91
- random_images = random.sample(image_files, min(10, len(image_files)))
92
- image_paths = [os.path.join(OUTPUT_DIR, img) for img in random_images]
93
- return image_paths
94
-
95
- # === Send Image to Backend ===
96
  def send_to_backend(img_path, user_id):
 
 
97
  if not user_id:
 
98
  return "❌ user_id not found in URL."
99
 
100
  if not img_path or not os.path.exists(img_path):
 
101
  return "⚠️ No image selected or image not found."
102
 
103
  try:
104
  with open(img_path, 'rb') as f:
105
  files = {'file': ('generated_image.png', f, 'image/png')}
106
-
107
- # Your backend endpoint here
108
- url = f" https://7da2-2409-4042-6e81-1806-de6-b8e5-836c-6b95.ngrok-free.app/images/upload/{user_id}"
109
  response = requests.post(url, files=files)
110
-
111
- if response.status_code == 201:
 
112
  return "✅ Image uploaded and saved to database!"
113
  else:
114
  return f"❌ Upload failed: {response.status_code} - {response.text}"
115
 
116
  except Exception as e:
 
117
  return f"⚠️ Error: {str(e)}"
118
 
119
- # === Gradio Interface ===
120
- with gr.Blocks() as demo:
121
- gr.Markdown("# 🎨 AI-Generated Clothing Designs - Tops")
122
 
123
- generate_button = gr.Button("Generate New Designs")
124
- user_id_state = gr.State()
125
 
126
- @demo.load(inputs=None, outputs=[user_id_state])
127
- def get_user_id(request: gr.Request):
128
- return request.query_params.get("user_id", "")
129
 
130
- image_components = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  file_paths = []
132
- save_buttons = []
133
- outputs = []
134
-
135
- # Use 3 columns layout
136
- for row_idx in range(4): # 4 rows (to cover 10 images)
137
- with gr.Row():
138
- for col_idx in range(3): # 3 columns
139
- i = row_idx * 3 + col_idx
140
- if i >= 10:
141
- break
142
- with gr.Column():
143
- img = gr.Image(width=180, height=180, label=f"Design {i+1}")
144
- image_components.append(img)
145
-
146
- file_path = gr.Textbox(visible=False)
147
- file_paths.append(file_path)
148
-
149
- save_btn = gr.Button("💾 Save to DB")
150
- save_buttons.append(save_btn)
151
-
152
- output = gr.Textbox(label="Status", interactive=False)
153
- outputs.append(output)
154
-
155
- save_btn.click(
156
- fn=send_to_backend,
157
- inputs=[file_path, user_id_state],
158
- outputs=output
159
- )
160
-
161
- # Generate button logic
162
- def generate_and_display_images():
163
- image_paths = get_random_images()
164
- return image_paths + image_paths # One for display, one for hidden path tracking
165
-
166
- generate_button.click(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  fn=generate_and_display_images,
 
168
  outputs=image_components + file_paths
169
  )
170
 
171
- if __name__ == "__main__":
172
- demo.launch()
 
3
  # import os
4
  # import random
5
  # from PIL import Image
6
+ # import shutil
7
+ # import requests
8
 
9
+ # # === Setup Paths ===
10
+ # BASE_DIR = os.path.dirname(os.path.abspath(__file__))
11
+ # GEN_SCRIPT = os.path.join(BASE_DIR, "stylegan3", "gen_images.py")
12
+ # OUTPUT_DIR = os.path.join(BASE_DIR, "outputs")
13
+ # MODEL_PATH = os.path.join(BASE_DIR, "top_model.pkl")
14
+ # SAVE_DIR = os.path.join(BASE_DIR, "saved_images")
15
 
 
16
  # os.makedirs(OUTPUT_DIR, exist_ok=True)
17
+ # os.makedirs(SAVE_DIR, exist_ok=True)
18
 
19
+ # # === Image Generation Function ===
20
  # def generate_images():
21
+ # command = [
22
+ # "python",
23
+ # GEN_SCRIPT,
24
+ # f"--outdir={OUTPUT_DIR}",
25
+ # "--trunc=1",
26
+ # "--seeds=3-5,7,9,12-14,16-26,29,31,32,34,40,41",
27
+ # f"--network={MODEL_PATH}"
28
+ # ]
29
  # try:
30
+ # subprocess.run(command, check=True, capture_output=True, text=True)
31
  # except subprocess.CalledProcessError as e:
32
+ # return f"Error generating images:\n{e.stderr}"
33
 
34
+ # # === Select Random Images from Output Folder ===
35
  # def get_random_images():
36
  # image_files = [f for f in os.listdir(OUTPUT_DIR) if f.endswith(".png")]
37
  # if len(image_files) < 10:
38
  # generate_images()
39
  # image_files = [f for f in os.listdir(OUTPUT_DIR) if f.endswith(".png")]
40
  # random_images = random.sample(image_files, min(10, len(image_files)))
41
+ # image_paths = [os.path.join(OUTPUT_DIR, img) for img in random_images]
42
+ # return image_paths
43
 
44
+ # # === Send Image to Backend ===
45
+ # def send_to_backend(img_path, user_id):
46
+ # if not user_id:
47
+ # return "❌ user_id not found in URL."
48
 
49
+ # if not img_path or not os.path.exists(img_path):
50
+ # return "⚠️ No image selected or image not found."
51
+
52
+ # try:
53
+ # with open(img_path, 'rb') as f:
54
+ # files = {'file': ('generated_image.png', f, 'image/png')}
55
+
56
+ # # Your backend endpoint here
57
+ # url = f" https://7da2-2409-4042-6e81-1806-de6-b8e5-836c-6b95.ngrok-free.app/images/upload/{user_id}"
58
+ # response = requests.post(url, files=files)
59
+
60
+ # if response.status_code == 201:
61
+ # return "✅ Image uploaded and saved to database!"
62
+ # else:
63
+ # return f"❌ Upload failed: {response.status_code} - {response.text}"
64
+
65
+ # except Exception as e:
66
+ # return f"⚠️ Error: {str(e)}"
67
+
68
+ # # === Gradio Interface ===
69
  # with gr.Blocks() as demo:
70
  # gr.Markdown("# 🎨 AI-Generated Clothing Designs - Tops")
71
+
72
  # generate_button = gr.Button("Generate New Designs")
73
+ # user_id_state = gr.State()
74
+
75
+ # @demo.load(inputs=None, outputs=[user_id_state])
76
+ # def get_user_id(request: gr.Request):
77
+ # return request.query_params.get("user_id", "")
78
+
79
+ # image_components = []
80
+ # file_paths = []
81
+ # save_buttons = []
82
+ # outputs = []
83
+
84
+ # # Use 3 columns layout
85
+ # for row_idx in range(4): # 4 rows (to cover 10 images)
86
+ # with gr.Row():
87
+ # for col_idx in range(3): # 3 columns
88
+ # i = row_idx * 3 + col_idx
89
+ # if i >= 10:
90
+ # break
91
+ # with gr.Column():
92
+ # img = gr.Image(width=180, height=180, label=f"Design {i+1}")
93
+ # image_components.append(img)
94
+
95
+ # file_path = gr.Textbox(visible=False)
96
+ # file_paths.append(file_path)
97
+
98
+ # save_btn = gr.Button("💾 Save to DB")
99
+ # save_buttons.append(save_btn)
100
+
101
+ # output = gr.Textbox(label="Status", interactive=False)
102
+ # outputs.append(output)
103
+
104
+ # save_btn.click(
105
+ # fn=send_to_backend,
106
+ # inputs=[file_path, user_id_state],
107
+ # outputs=output
108
+ # )
109
+
110
+ # # Generate button logic
111
+ # def generate_and_display_images():
112
+ # image_paths = get_random_images()
113
+ # return image_paths + image_paths # One for display, one for hidden path tracking
114
+
115
+ # generate_button.click(
116
+ # fn=generate_and_display_images,
117
+ # outputs=image_components + file_paths
118
+ # )
119
 
120
  # if __name__ == "__main__":
121
  # demo.launch()
122
 
123
+ import torch
124
+ from transformers import CLIPModel, CLIPProcessor
125
+ from PIL import Image
126
+ import numpy as np
127
+ import pickle
128
+ import gradio as gr
129
+ import tempfile
130
 
131
 
132
+ # Force CPU usage for optimization
133
+ device = torch.device("cpu")
134
 
135
+ # Load your GAN model
136
+ with open("top_model.pkl", "rb") as f:
137
+ G = pickle.load(f)['G_ema'].eval().cpu() # Ensure model is in eval mode and on CPU
138
 
139
+ # Load CLIP model and processor
140
+ clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").eval().cpu()
141
+ clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
142
+
143
+ # def send_to_backend(img_path, user_id):
144
+ # if not user_id:
145
+ # return "❌ user_id not found in URL."
146
+
147
+ # if not img_path or not os.path.exists(img_path):
148
+ # return "⚠️ No image selected or image not found."
149
+
150
+ # try:
151
+ # with open(img_path, 'rb') as f:
152
+ # files = {'file': ('generated_image.png', f, 'image/png')}
153
+
154
+ # # Your backend endpoint here
155
+ # url = f"https://e335-103-40-74-83.ngrok-free.app/images/upload/{user_id}"
156
+ # response = requests.post(url, files=files)
157
+
158
+ # if response.status_code == 201:
159
+ # return "✅ Image uploaded and saved to database!"
160
+ # else:
161
+ # console.log({response.text})
162
+ # return f"❌ Upload failed: {response.status_code} - {response.text}"
163
+
164
+ # except Exception as e:
165
+ # return f"⚠️ Error: {str(e)}"
166
 
167
 
 
 
168
  import os
169
+ import requests # Make sure you import this!
170
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  def send_to_backend(img_path, user_id):
172
+ print(f"💡 [DEBUG] Sending image to backend | img_path={img_path}, user_id={user_id}")
173
+
174
  if not user_id:
175
+ print("❌ [DEBUG] Missing user_id in URL.")
176
  return "❌ user_id not found in URL."
177
 
178
  if not img_path or not os.path.exists(img_path):
179
+ print("⚠️ [DEBUG] Image path invalid or does not exist.")
180
  return "⚠️ No image selected or image not found."
181
 
182
  try:
183
  with open(img_path, 'rb') as f:
184
  files = {'file': ('generated_image.png', f, 'image/png')}
185
+ url = f"https://e335-103-40-74-83.ngrok-free.app/images/upload/{user_id}"
186
+ print(f"🔁 [DEBUG] Sending POST to {url}")
 
187
  response = requests.post(url, files=files)
188
+
189
+ print(f"📩 [DEBUG] Response: {response.status_code} - {response.text}")
190
+ if response.status_code == 201 or response.status_code == 200:
191
  return "✅ Image uploaded and saved to database!"
192
  else:
193
  return f"❌ Upload failed: {response.status_code} - {response.text}"
194
 
195
  except Exception as e:
196
+ print(f"⚠️ [ERROR] Exception during upload: {str(e)}")
197
  return f"⚠️ Error: {str(e)}"
198
 
 
 
 
199
 
 
 
200
 
 
 
 
201
 
202
+
203
+ # Generate images
204
+ def generate_images(G, num_images=10): # Reduce for CPU performance
205
+ z = torch.randn(num_images, G.z_dim)
206
+ c = None
207
+ with torch.no_grad():
208
+ images = G(z, c)
209
+ images = (images.clamp(-1, 1) + 1) * (255 / 2)
210
+ images = images.permute(0, 2, 3, 1).numpy().astype(np.uint8)
211
+ return z, images
212
+
213
+ # Rank images using CLIP
214
+ def rank_by_clip(images, prompt, top_k=3): # Reduce top_k for speed
215
+ images_pil = [Image.fromarray(img) for img in images]
216
+ inputs = clip_processor(text=[prompt], images=images_pil, return_tensors="pt", padding=True)
217
+
218
+ with torch.no_grad():
219
+ image_features = clip_model.get_image_features(pixel_values=inputs["pixel_values"])
220
+ text_features = clip_model.get_text_features(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
221
+
222
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
223
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
224
+
225
+ similarity = (image_features @ text_features.T).squeeze()
226
+
227
+ top_indices = similarity.argsort(descending=True)[:top_k]
228
+ best_images = [images_pil[i] for i in top_indices]
229
+ return best_images
230
+
231
+ # Gradio interface function
232
+ def generate_top_dresses(prompt):
233
+ _, images = generate_images(G, num_images=20)
234
+ top_images = rank_by_clip(images, prompt, top_k=2)
235
+
236
  file_paths = []
237
+ for i, img in enumerate(top_images):
238
+ temp_path = tempfile.NamedTemporaryFile(suffix=".png", delete=False).name
239
+ img.save(temp_path)
240
+ file_paths.append(temp_path)
241
+
242
+ return top_images, file_paths
243
+
244
+ # Launch Gradio
245
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
246
+ gr.Markdown("""
247
+ # 👗 AI Top Generator
248
+ _Type in your dream outfit, and let the AI bring your fashion vision to life!_
249
+ Just describe and see how AI transforms your words into fashion.
250
+ """)
251
+
252
+ with gr.Row():
253
+ input_box = gr.Textbox(
254
+ label="Describe your Design",
255
+ placeholder="e.g., 'Black sleeveless crop top'",
256
+ lines=2
257
+ )
258
+
259
+ with gr.Row():
260
+ submit_button = gr.Button("Generate Designs")
261
+ user_id_state = gr.State()
262
+
263
+ @demo.load(inputs=None, outputs=[user_id_state])
264
+ def get_user_id(request: gr.Request):
265
+ return request.query_params.get("user_id", "")
266
+
267
+ image_components = []
268
+ file_paths = []
269
+ save_buttons = []
270
+ outputs = []
271
+
272
+ with gr.Row():
273
+ for i in range(2): # Only 2 images
274
+ with gr.Column():
275
+ img = gr.Image(width=180, height=180, label=f"Design {i+1}")
276
+ image_components.append(img)
277
+
278
+ file_path = gr.Textbox(visible=False)
279
+ file_paths.append(file_path)
280
+
281
+ save_btn = gr.Button("💾 Save to DB")
282
+ save_buttons.append(save_btn)
283
+
284
+ output = gr.Textbox(label="Status", interactive=False)
285
+ outputs.append(output)
286
+
287
+ save_btn.click(
288
+ fn=send_to_backend,
289
+ inputs=[file_path, user_id_state],
290
+ outputs=output
291
+ )
292
+
293
+
294
+ examples = gr.Examples(
295
+ examples = [
296
+ ["Striped crop top"],
297
+ ["Simple blue round-neck top with short sleeves"]
298
+ ],
299
+ inputs=[input_box]
300
+ )
301
+ # Generate button logicg
302
+ def generate_and_display_images(prompt):
303
+ images, paths = generate_top_dresses(prompt)
304
+ return images + paths
305
+
306
+
307
+ submit_button.click(
308
  fn=generate_and_display_images,
309
+ inputs=[input_box],
310
  outputs=image_components + file_paths
311
  )
312
 
313
+
314
+ demo.launch()