sahadev10 commited on
Commit
8cf9a64
·
verified ·
1 Parent(s): 4e3fb94

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -55
app.py CHANGED
@@ -64,7 +64,6 @@
64
  # iface.launch()
65
 
66
 
67
-
68
  import gradio as gr
69
  import torch
70
  import numpy as np
@@ -74,91 +73,71 @@ import legacy
74
  import torch_utils
75
  import jwt
76
  import requests
77
- import tempfile
78
- from fastapi import Request
79
- from gradio.routes import app as fastapi_app
80
- from starlette.requests import Request as StarletteRequest
81
- from starlette.middleware.cors import CORSMiddleware
82
-
83
- # Allow frontend access
84
- fastapi_app.add_middleware(
85
- CORSMiddleware,
86
- allow_origins=["*"], # You can restrict to your frontend domain
87
- allow_credentials=True,
88
- allow_methods=["*"],
89
- allow_headers=["*"],
90
- )
91
 
92
- # Load StyleGAN model
93
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
94
- model_path = 'dress_model.pkl'
95
 
 
96
  with open(model_path, 'rb') as f:
97
  G = legacy.load_network_pkl(f)['G_ema'].to(device)
98
 
99
- # Helper to decode JWT from cookie
100
- def get_user_id_from_cookie(cookie_str):
101
- try:
102
- if 'access_token=' in cookie_str:
103
- token = cookie_str.split('access_token=')[1].split(';')[0]
104
- decoded = jwt.decode(token, 'your_jwt_secret', algorithms=['HS256'])
105
- return decoded.get('user_id')
106
- except Exception as e:
107
- print("JWT decode error:", e)
108
- return None
109
-
110
- # Style mixing
111
  def mix_styles(image1_path, image2_path, styles_to_mix):
 
112
  image1_name = os.path.splitext(os.path.basename(image1_path))[0]
113
  image2_name = os.path.splitext(os.path.basename(image2_path))[0]
114
 
 
115
  latent_vector_1 = np.load(os.path.join("projection_results", image1_name, "projected_w.npz"))['w']
116
  latent_vector_2 = np.load(os.path.join("projection_results", image2_name, "projected_w.npz"))['w']
117
 
 
118
  latent_1_tensor = torch.from_numpy(latent_vector_1).to(device)
119
  latent_2_tensor = torch.from_numpy(latent_vector_2).to(device)
120
 
 
121
  mixed_latent = latent_1_tensor.clone()
122
  mixed_latent[:, styles_to_mix] = latent_2_tensor[:, styles_to_mix]
123
 
 
124
  with torch.no_grad():
125
  image = G.synthesis(mixed_latent, noise_mode='const')
126
 
 
127
  image = (image.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()
128
  mixed_image = Image.fromarray(image[0], 'RGB')
129
  return mixed_image
130
 
131
- # Save image to backend
132
- def upload_to_backend(img: Image.Image, user_id: str):
133
- with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
134
- img.save(tmp_file.name)
135
- with open(tmp_file.name, "rb") as f:
136
- response = requests.post(
137
- f"http://localhost:3000/customisation/upload/{user_id}",
138
- files={"file": f}
139
- )
140
- os.remove(tmp_file.name)
141
- return response.json() if response.ok else {"error": "Upload failed", "details": response.text}
142
-
143
- # Main Gradio function
144
- def style_mixing_interface(image1, image2, mix_value, request: StarletteRequest = None):
145
  if image1 is None or image2 is None:
146
  return None
147
 
148
- # Get user_id from JWT cookie
149
- cookie_header = request.headers.get('cookie', '')
150
- user_id = get_user_id_from_cookie(cookie_header)
151
- if not user_id:
152
- return "❌ Invalid or missing JWT. Please log in again."
 
 
 
153
 
154
  selected_layers = list(range(mix_value + 1))
155
- mixed_img = mix_styles(image1, image2, selected_layers)
156
 
157
- # Upload image to backend
158
- upload_response = upload_to_backend(mixed_img, user_id)
 
 
 
159
 
160
- print("Upload response:", upload_response)
161
- return mixed_img
 
 
 
 
162
 
163
  # Gradio UI
164
  iface = gr.Interface(
@@ -166,13 +145,16 @@ iface = gr.Interface(
166
  inputs=[
167
  gr.Image(label="First Clothing Image", type="filepath"),
168
  gr.Image(label="Second Clothing Image", type="filepath"),
169
- gr.Slider(label="Style Mixing Strength (Layers 0 to N)", minimum=0, maximum=9, step=1, value=5)
 
170
  ],
171
  outputs=gr.Image(label="Mixed Clothing Design"),
172
  live=True,
173
  title="Style Mixing for Clothing Design",
174
- description="Upload two projected images and choose how many early layers to mix."
175
  )
176
 
 
177
  iface.launch()
178
 
 
 
64
  # iface.launch()
65
 
66
 
 
67
  import gradio as gr
68
  import torch
69
  import numpy as np
 
73
  import torch_utils
74
  import jwt
75
  import requests
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
+ # Load the pre-trained StyleGAN model
78
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
79
+ model_path = 'dress_model.pkl' # Place your .pkl in the same directory or update path
80
 
81
+ # Load StyleGAN Generator
82
  with open(model_path, 'rb') as f:
83
  G = legacy.load_network_pkl(f)['G_ema'].to(device)
84
 
85
+ # Function to mix styles of two clothing images
 
 
 
 
 
 
 
 
 
 
 
86
  def mix_styles(image1_path, image2_path, styles_to_mix):
87
+ # Extract image names (without extensions)
88
  image1_name = os.path.splitext(os.path.basename(image1_path))[0]
89
  image2_name = os.path.splitext(os.path.basename(image2_path))[0]
90
 
91
+ # Load latent vectors from .npz
92
  latent_vector_1 = np.load(os.path.join("projection_results", image1_name, "projected_w.npz"))['w']
93
  latent_vector_2 = np.load(os.path.join("projection_results", image2_name, "projected_w.npz"))['w']
94
 
95
+ # Convert to torch tensors
96
  latent_1_tensor = torch.from_numpy(latent_vector_1).to(device)
97
  latent_2_tensor = torch.from_numpy(latent_vector_2).to(device)
98
 
99
+ # Mix layers
100
  mixed_latent = latent_1_tensor.clone()
101
  mixed_latent[:, styles_to_mix] = latent_2_tensor[:, styles_to_mix]
102
 
103
+ # Generate image
104
  with torch.no_grad():
105
  image = G.synthesis(mixed_latent, noise_mode='const')
106
 
107
+ # Convert to image
108
  image = (image.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()
109
  mixed_image = Image.fromarray(image[0], 'RGB')
110
  return mixed_image
111
 
112
+ # Function to handle style mixing via Gradio UI
113
+ def style_mixing_interface(image1, image2, mix_value, cookie):
 
 
 
 
 
 
 
 
 
 
 
 
114
  if image1 is None or image2 is None:
115
  return None
116
 
117
+ # Extract user_id from the JWT token passed via cookies (assuming JWT token is passed as 'cookie' in the request)
118
+ try:
119
+ decoded_token = jwt.decode(cookie, options={"verify_exp": False}) # Decode token without verifying expiration
120
+ user_id = decoded_token.get("user_id", None)
121
+ except jwt.ExpiredSignatureError:
122
+ return "Session expired, please log in again."
123
+ except jwt.InvalidTokenError:
124
+ return "Invalid token, please log in again."
125
 
126
  selected_layers = list(range(mix_value + 1))
127
+ mixed_image = mix_styles(image1, image2, selected_layers)
128
 
129
+ # Call backend API to save the image
130
+ if user_id:
131
+ upload_url = f"http://localhost:3000/customisation/upload/{user_id}"
132
+ files = {'file': ('mixed_image.png', mixed_image.tobytes(), 'image/png')}
133
+ response = requests.post(upload_url, files=files)
134
 
135
+ if response.status_code == 200:
136
+ return "Image uploaded successfully!"
137
+ else:
138
+ return f"Failed to upload image: {response.text}"
139
+ else:
140
+ return "User ID not found in token."
141
 
142
  # Gradio UI
143
  iface = gr.Interface(
 
145
  inputs=[
146
  gr.Image(label="First Clothing Image", type="filepath"),
147
  gr.Image(label="Second Clothing Image", type="filepath"),
148
+ gr.Slider(label="Style Mixing Strength (Layers 0 to N)", minimum=0, maximum=9, step=1, value=5),
149
+ gr.Textbox(label="JWT Token (as cookie)", type="text") # You may pass JWT token here for testing purposes
150
  ],
151
  outputs=gr.Image(label="Mixed Clothing Design"),
152
  live=True,
153
  title="Style Mixing for Clothing Design",
154
+ description="Upload two projected images and choose how many early layers to mix. The resulting image will be saved after mixing."
155
  )
156
 
157
+ # Launch the Gradio interface directly
158
  iface.launch()
159
 
160
+