sahadev10 commited on
Commit
1f03fe0
·
verified ·
1 Parent(s): 5d4f45d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -56
app.py CHANGED
@@ -65,7 +65,6 @@
65
 
66
 
67
 
68
-
69
  import gradio as gr
70
  import torch
71
  import numpy as np
@@ -73,94 +72,107 @@ from PIL import Image
73
  import os
74
  import legacy
75
  import torch_utils
 
76
  import requests
77
- import io
78
- import base64
79
-
80
- # Load the pre-trained StyleGAN model
 
 
 
 
 
 
 
 
 
 
 
 
81
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
82
- model_path = 'dress_model.pkl' # Place your .pkl in the same directory or update path
83
 
84
- # Load StyleGAN Generator
85
  with open(model_path, 'rb') as f:
86
  G = legacy.load_network_pkl(f)['G_ema'].to(device)
87
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  def mix_styles(image1_path, image2_path, styles_to_mix):
89
- # Extract image names (without extensions)
90
  image1_name = os.path.splitext(os.path.basename(image1_path))[0]
91
  image2_name = os.path.splitext(os.path.basename(image2_path))[0]
92
 
93
- # Load latent vectors from .npz
94
  latent_vector_1 = np.load(os.path.join("projection_results", image1_name, "projected_w.npz"))['w']
95
  latent_vector_2 = np.load(os.path.join("projection_results", image2_name, "projected_w.npz"))['w']
96
 
97
- # Convert to torch tensors
98
  latent_1_tensor = torch.from_numpy(latent_vector_1).to(device)
99
  latent_2_tensor = torch.from_numpy(latent_vector_2).to(device)
100
 
101
- # Mix layers
102
  mixed_latent = latent_1_tensor.clone()
103
  mixed_latent[:, styles_to_mix] = latent_2_tensor[:, styles_to_mix]
104
 
105
- # Generate image
106
  with torch.no_grad():
107
  image = G.synthesis(mixed_latent, noise_mode='const')
108
 
109
- # Convert to image
110
  image = (image.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()
111
  mixed_image = Image.fromarray(image[0], 'RGB')
112
  return mixed_image
113
 
114
- def style_mixing_interface(image1, image2, mix_value):
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  if image1 is None or image2 is None:
116
- return None, None
 
 
 
 
 
 
117
 
118
  selected_layers = list(range(mix_value + 1))
119
  mixed_img = mix_styles(image1, image2, selected_layers)
120
 
121
- # Convert to base64
122
- buffer = io.BytesIO()
123
- mixed_img.save(buffer, format="PNG")
124
- img_bytes = buffer.getvalue()
125
- img_base64 = base64.b64encode(img_bytes).decode("utf-8")
126
 
127
- return mixed_img, img_base64
128
-
129
- def send_to_backend(base64_img):
130
- try:
131
- response = requests.post(
132
- "http://localhost:3000/customisation/save", # Change if using different port/route
133
- json={"image": base64_img},
134
- timeout=10
135
- )
136
- if response.status_code == 200:
137
- return "✅ Saved to database!"
138
- else:
139
- return f"❌ Failed to save: {response.status_code} - {response.text}"
140
- except Exception as e:
141
- return f"⚠️ Error: {str(e)}"
142
 
143
  # Gradio UI
144
- with gr.Blocks(title="Style Mixing for Clothing Design") as iface:
145
- gr.Markdown("## Style Mixing for Clothing Design\nUpload two projected clothing images and select how many early layers to mix.")
146
-
147
- with gr.Row():
148
- image1_input = gr.Image(label="First Clothing Image", type="filepath")
149
- image2_input = gr.Image(label="Second Clothing Image", type="filepath")
150
-
151
- mix_slider = gr.Slider(label="Style Mixing Strength (Layers 0 to N)", minimum=0, maximum=9, step=1, value=5)
152
-
153
- output_image = gr.Image(label="Mixed Clothing Design")
154
- base64_output = gr.Textbox(visible=False)
155
-
156
- download_button = gr.Button("Download & Save to Database")
157
- save_status = gr.Textbox(label="Save Status", interactive=False)
158
-
159
- def mix_and_return(image1, image2, mix_value):
160
- return style_mixing_interface(image1, image2, mix_value)
161
-
162
- mix_slider.change(mix_and_return, inputs=[image1_input, image2_input, mix_slider], outputs=[output_image, base64_output])
163
-
164
- download_button.click(fn=send_to_backend, inputs=[base64_output], outputs=[save_status])
165
 
166
  iface.launch()
 
 
65
 
66
 
67
 
 
68
  import gradio as gr
69
  import torch
70
  import numpy as np
 
72
  import os
73
  import legacy
74
  import torch_utils
75
+ import jwt
76
  import requests
77
+ import tempfile
78
+ from fastapi import Request
79
+ from gradio.routes import app as fastapi_app
80
+ from starlette.requests import Request as StarletteRequest
81
+ from starlette.middleware.cors import CORSMiddleware
82
+
83
+ # Allow frontend access
84
+ fastapi_app.add_middleware(
85
+ CORSMiddleware,
86
+ allow_origins=["*"], # You can restrict to your frontend domain
87
+ allow_credentials=True,
88
+ allow_methods=["*"],
89
+ allow_headers=["*"],
90
+ )
91
+
92
+ # Load StyleGAN model
93
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
94
+ model_path = 'dress_model.pkl'
95
 
 
96
  with open(model_path, 'rb') as f:
97
  G = legacy.load_network_pkl(f)['G_ema'].to(device)
98
 
99
+ # Helper to decode JWT from cookie
100
+ def get_user_id_from_cookie(cookie_str):
101
+ try:
102
+ if 'access_token=' in cookie_str:
103
+ token = cookie_str.split('access_token=')[1].split(';')[0]
104
+ decoded = jwt.decode(token, 'your_jwt_secret', algorithms=['HS256'])
105
+ return decoded.get('user_id')
106
+ except Exception as e:
107
+ print("JWT decode error:", e)
108
+ return None
109
+
110
+ # Style mixing
111
  def mix_styles(image1_path, image2_path, styles_to_mix):
 
112
  image1_name = os.path.splitext(os.path.basename(image1_path))[0]
113
  image2_name = os.path.splitext(os.path.basename(image2_path))[0]
114
 
 
115
  latent_vector_1 = np.load(os.path.join("projection_results", image1_name, "projected_w.npz"))['w']
116
  latent_vector_2 = np.load(os.path.join("projection_results", image2_name, "projected_w.npz"))['w']
117
 
 
118
  latent_1_tensor = torch.from_numpy(latent_vector_1).to(device)
119
  latent_2_tensor = torch.from_numpy(latent_vector_2).to(device)
120
 
 
121
  mixed_latent = latent_1_tensor.clone()
122
  mixed_latent[:, styles_to_mix] = latent_2_tensor[:, styles_to_mix]
123
 
 
124
  with torch.no_grad():
125
  image = G.synthesis(mixed_latent, noise_mode='const')
126
 
 
127
  image = (image.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()
128
  mixed_image = Image.fromarray(image[0], 'RGB')
129
  return mixed_image
130
 
131
+ # Save image to backend
132
+ def upload_to_backend(img: Image.Image, user_id: str):
133
+ with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
134
+ img.save(tmp_file.name)
135
+ with open(tmp_file.name, "rb") as f:
136
+ response = requests.post(
137
+ f"http://localhost:3000/customisation/upload/{user_id}",
138
+ files={"file": f}
139
+ )
140
+ os.remove(tmp_file.name)
141
+ return response.json() if response.ok else {"error": "Upload failed", "details": response.text}
142
+
143
+ # Main Gradio function
144
+ def style_mixing_interface(image1, image2, mix_value, request: StarletteRequest = None):
145
  if image1 is None or image2 is None:
146
+ return None
147
+
148
+ # Get user_id from JWT cookie
149
+ cookie_header = request.headers.get('cookie', '')
150
+ user_id = get_user_id_from_cookie(cookie_header)
151
+ if not user_id:
152
+ return "❌ Invalid or missing JWT. Please log in again."
153
 
154
  selected_layers = list(range(mix_value + 1))
155
  mixed_img = mix_styles(image1, image2, selected_layers)
156
 
157
+ # Upload image to backend
158
+ upload_response = upload_to_backend(mixed_img, user_id)
 
 
 
159
 
160
+ print("Upload response:", upload_response)
161
+ return mixed_img
 
 
 
 
 
 
 
 
 
 
 
 
 
162
 
163
  # Gradio UI
164
+ iface = gr.Interface(
165
+ fn=style_mixing_interface,
166
+ inputs=[
167
+ gr.Image(label="First Clothing Image", type="filepath"),
168
+ gr.Image(label="Second Clothing Image", type="filepath"),
169
+ gr.Slider(label="Style Mixing Strength (Layers 0 to N)", minimum=0, maximum=9, step=1, value=5)
170
+ ],
171
+ outputs=gr.Image(label="Mixed Clothing Design"),
172
+ live=True,
173
+ title="Style Mixing for Clothing Design",
174
+ description="Upload two projected images and choose how many early layers to mix."
175
+ )
 
 
 
 
 
 
 
 
 
176
 
177
  iface.launch()
178
+