sahadev10 commited on
Commit
0331cd5
·
verified ·
1 Parent(s): 10a7eb3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -6
app.py CHANGED
@@ -64,7 +64,6 @@
64
  # iface.launch()
65
 
66
 
67
-
68
  import gradio as gr
69
  import torch
70
  import numpy as np
@@ -76,13 +75,10 @@ import requests
76
  import io
77
  import warnings
78
 
79
- # Suppress deprecated torch warnings
80
  warnings.filterwarnings("ignore")
81
-
82
- # --- Load the pre-trained StyleGAN model ---
83
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
84
- model_path = 'dress_model.pkl'
85
 
 
86
  with open(model_path, 'rb') as f:
87
  G = legacy.load_network_pkl(f)['G_ema'].to(device)
88
 
@@ -117,7 +113,100 @@ def style_mixing_interface(image1, image2, mix_value):
117
  buffer.seek(0)
118
  return mixed_img, buffer
119
 
120
- def send_to_backend(image_buffer, "74"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  if not user_id:
122
  return "❌ user_id not found."
123
 
 
64
  # iface.launch()
65
 
66
 
 
67
  import gradio as gr
68
  import torch
69
  import numpy as np
 
75
  import io
76
  import warnings
77
 
 
78
  warnings.filterwarnings("ignore")
 
 
79
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
80
 
81
+ model_path = 'dress_model.pkl'
82
  with open(model_path, 'rb') as f:
83
  G = legacy.load_network_pkl(f)['G_ema'].to(device)
84
 
 
113
  buffer.seek(0)
114
  return mixed_img, buffer
115
 
116
+ def send_to_backend(image_buffer, user_id):
117
+ if not user_id:
118
+ return "❌ user_id not found in URL."
119
+
120
+ try:
121
+ files = {'file': ('generated_image.png', image_buffer, 'image/png')}
122
+ url = f"https://361d-103-40-74-78.ngrok-free.app/customisation/upload/{user_id}"
123
+
124
+ response = requests.post(url, files=files)
125
+
126
+ if response.status_code == 201:
127
+ return "✅ Image uploaded and saved to database!"
128
+ else:
129
+ return f"❌ Upload failed: {response.status_code} - {response.text}"
130
+
131
+ except Exception as e:
132
+ return f"⚠️ Error: {str(e)}"
133
+
134
+ with gr.Blocks(title="Style Mixing for Clothing Design") as iface:
135
+ user_id_state = gr.State()
136
+
137
+ gr.Markdown("## Style Mixing for Clothing Design\nUpload two projected clothing images and mix their styles.")
138
+
139
+ with gr.Row():
140
+ image1_input = gr.Image(label="First Clothing Image", type="filepath")
141
+ image2_input = gr.Image(label="Second Clothing Image", type="filepath")
142
+
143
+ mix_slider = gr.Slider(label="Style Mixing Strength (Layers 0 to N)", minimum=0, maximum=9, step=1, value=5)
144
+
145
+ with gr.Row():
146
+ output_image = gr.Image(label="Mixed Clothing Design")
147
+ save_button = gr.Button("Download & Save to Database")
148
+
149
+ image_buffer = gr.State()
150
+ save_status = gr.Textbox(label="Save Status", interactive=False)
151
+
152
+ def mix_and_store(image1, image2, mix_value):
153
+ result_image, buffer = style_mixing_interface(image1, image2, mix_value)
154
+ return result_image, buffer
155
+
156
+ mix_slider.change(
157
+ mix_and_store,
158
+ inputs=[image1_input, image2_input, mix_slider],
159
+ outputs=[output_image, image_buffer]
160
+ )
161
+
162
+ save_button.click(
163
+ send_to_backend,
164
+ inputs=[image_buffer, user_id_state],
165
+ outputs=[save_status]
166
+ )
167
+
168
+ # Initialization function that extracts user_id
169
+ def init_fn(request: gr.Request):
170
+ user_id = request.query_params.get("user_id", "")
171
+ return {user_id_state: user_id}
172
+
173
+ iface.load(fn=None, inputs=None, outputs=None, preprocess=False, queue=False, show_progress=False)
174
+ iface.launch(initialize=init_fn)
175
+ (model_path, 'rb') as f:
176
+ G = legacy.load_network_pkl(f)['G_ema'].to(device)
177
+
178
+ def mix_styles(image1_path, image2_path, styles_to_mix):
179
+ image1_name = os.path.splitext(os.path.basename(image1_path))[0]
180
+ image2_name = os.path.splitext(os.path.basename(image2_path))[0]
181
+
182
+ latent_vector_1 = np.load(os.path.join("projection_results", image1_name, "projected_w.npz"))['w']
183
+ latent_vector_2 = np.load(os.path.join("projection_results", image2_name, "projected_w.npz"))['w']
184
+
185
+ latent_1_tensor = torch.from_numpy(latent_vector_1).to(device)
186
+ latent_2_tensor = torch.from_numpy(latent_vector_2).to(device)
187
+
188
+ mixed_latent = latent_1_tensor.clone()
189
+ mixed_latent[:, styles_to_mix] = latent_2_tensor[:, styles_to_mix]
190
+
191
+ with torch.no_grad():
192
+ image = G.synthesis(mixed_latent, noise_mode='const')
193
+
194
+ image = (image.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()
195
+ mixed_image = Image.fromarray(image[0], 'RGB')
196
+ return mixed_image
197
+
198
+ def style_mixing_interface(image1, image2, mix_value):
199
+ if image1 is None or image2 is None:
200
+ return None, None
201
+ selected_layers = list(range(mix_value + 1))
202
+ mixed_img = mix_styles(image1, image2, selected_layers)
203
+
204
+ buffer = io.BytesIO()
205
+ mixed_img.save(buffer, format="PNG")
206
+ buffer.seek(0)
207
+ return mixed_img, buffer
208
+
209
+ def send_to_backend(image_buffer, user_id):
210
  if not user_id:
211
  return "❌ user_id not found."
212