nami0342 committed on
Commit
ffb6807
·
1 Parent(s): 939f91e

Warm-up: move model to GPU at first inference

Browse files
Files changed (1) hide show
  1. app.py +32 -29
app.py CHANGED
@@ -232,33 +232,39 @@ else:
232
  print("\n⚠ CPU warm-up completed with warnings")
233
  print("=" * 60 + "\n")
234
 
 
 
 
 
 
 
 
 
 
235
  # GPU Warm-up 함수 (앱 로드 시 자동 실행)
236
- # torch.compile() 첫 번째 컴파일을 미리 수행
237
  @spaces.GPU
238
  def warmup_gpu():
239
- """앱 로드 시 GPU 모델 초기화를 위한 Warm-up 함수 (torch.compile 첫 호출)"""
240
  try:
241
  device = "cuda"
242
  print("=" * 60)
243
- print("GPU Warm-up: Triggering torch.compile() first compilation...")
244
  print("=" * 60)
245
 
246
  # ๋ชจ๋ธ์„ GPU๋กœ ์ด๋™
 
247
  pipe.to(device)
248
  pipe.unet_encoder.to(device)
 
249
 
250
  # 더미 텐서 생성
251
  with torch.no_grad():
252
  with torch.cuda.amp.autocast():
253
  # 1. 더미 프롬프트 임베딩 생성 (Text Encoder GPU warm-up)
254
- print("[GPU Warm-up 1/3] Text Encoder GPU warm-up...")
255
  dummy_prompt = "a photo of white t-shirt"
256
- (
257
- prompt_embeds,
258
- negative_prompt_embeds,
259
- pooled_prompt_embeds,
260
- negative_pooled_prompt_embeds,
261
- ) = pipe.encode_prompt(
262
  dummy_prompt,
263
  num_images_per_prompt=1,
264
  do_classifier_free_guidance=True,
@@ -266,35 +272,31 @@ def warmup_gpu():
266
  )
267
  print("✓ Text Encoder GPU warmed up")
268
 
269
- # 2. 더미 이미지로 VAE 인코딩 (VAE GPU warm-up)
270
- print("[GPU Warm-up 2/3] VAE GPU warm-up...")
271
  dummy_img = torch.randn(1, 3, 1024, 768).to(device, torch.float16)
272
- _ = pipe.vae.encode(dummy_img)
273
- print("✓ VAE GPU warmed up")
 
274
 
275
- # 3. UNet 간단한 forward pass (UNet + torch.compile warm-up)
276
- print("[GPU Warm-up 3/3] UNet GPU warm-up (torch.compile trigger)...")
277
- dummy_latent = torch.randn(1, 4, 128, 96).to(device, torch.float16)
278
- dummy_timestep = torch.tensor([999]).to(device)
279
- _ = pipe.unet(
280
- dummy_latent,
281
- dummy_timestep,
282
- encoder_hidden_states=prompt_embeds.to(device, torch.float16),
283
- )
284
- print("✓ UNet GPU warmed up (torch.compile triggered)")
285
 
286
  # GPU 메모리 정리
287
  torch.cuda.empty_cache()
288
 
289
  print("\n" + "=" * 60)
290
- print("✓ GPU Warm-up completed! torch.compile() compilation done.")
291
- print(" All subsequent requests will be faster.")
 
292
  print("=" * 60 + "\n")
293
 
294
  return "GPU Warm-up completed successfully!"
295
  except Exception as e:
296
  print(f"\n⚠ GPU Warm-up failed: {e}")
297
- print(" First user request will trigger compilation instead.")
298
  return f"GPU Warm-up skipped: {e}"
299
 
300
 
@@ -662,8 +664,9 @@ with image_blocks as demo:
662
  print("✓ Gradio Blocks created")
663
 
664
  gr.Markdown("## DXCO : GENAI-VTON")
665
- gr.Markdown("์ž„์„ฑ๋‚จ, ์œค์ง€์˜, ์กฐ๋ฏผ์ฃผ based on IDM-VTON")
666
- gr.Markdown("이미지는 3:4비율(384x512 또는 768x1024)로 올려주세요")
 
667
 
668
  with gr.Row():
669
  with gr.Column():
 
232
  print("\n⚠ CPU warm-up completed with warnings")
233
  print("=" * 60 + "\n")
234
 
235
+ # torch.compile 오류 시 eager 모드로 폴백 설정
236
+ # 커스텀 UNet forward 메서드 호환성 문제 대응
237
+ try:
238
+ import torch._dynamo
239
+ torch._dynamo.config.suppress_errors = True
240
+ print("✓ torch._dynamo.config.suppress_errors enabled (fallback to eager mode on error)")
241
+ except Exception as e:
242
+ print(f"⚠ torch._dynamo config not available: {e}")
243
+
244
  # GPU Warm-up 함수 (앱 로드 시 자동 실행)
245
+ # Text Encoder, VAE GPU 로딩 및 CUDA 커널 초기화
246
  @spaces.GPU
247
  def warmup_gpu():
248
+ """앱 로드 시 GPU 모델 초기화를 위한 Warm-up 함수"""
249
  try:
250
  device = "cuda"
251
  print("=" * 60)
252
+ print("GPU Warm-up: Loading models to GPU and initializing CUDA kernels...")
253
  print("=" * 60)
254
 
255
  # ๋ชจ๋ธ์„ GPU๋กœ ์ด๋™
256
+ print("[GPU Warm-up 1/4] Moving models to GPU...")
257
  pipe.to(device)
258
  pipe.unet_encoder.to(device)
259
+ print("✓ Models moved to GPU")
260
 
261
  # 더미 텐서 생성
262
  with torch.no_grad():
263
  with torch.cuda.amp.autocast():
264
  # 1. 더미 프롬프트 임베딩 생성 (Text Encoder GPU warm-up)
265
+ print("[GPU Warm-up 2/4] Text Encoder GPU warm-up...")
266
  dummy_prompt = "a photo of white t-shirt"
267
+ _ = pipe.encode_prompt(
 
 
 
 
 
268
  dummy_prompt,
269
  num_images_per_prompt=1,
270
  do_classifier_free_guidance=True,
 
272
  )
273
  print("✓ Text Encoder GPU warmed up")
274
 
275
+ # 2. 더미 이미지로 VAE 인코딩/디코딩 (VAE GPU warm-up)
276
+ print("[GPU Warm-up 3/4] VAE GPU warm-up...")
277
  dummy_img = torch.randn(1, 3, 1024, 768).to(device, torch.float16)
278
+ latents = pipe.vae.encode(dummy_img).latent_dist.sample()
279
+ _ = pipe.vae.decode(latents)
280
+ print("✓ VAE GPU warmed up (encode + decode)")
281
 
282
+ # 3. CUDA 동기화 (커널 로딩 완료 대기)
283
+ print("[GPU Warm-up 4/4] CUDA synchronization...")
284
+ torch.cuda.synchronize()
285
+ print("✓ CUDA kernels initialized")
 
 
 
 
 
 
286
 
287
  # GPU 메모리 정리
288
  torch.cuda.empty_cache()
289
 
290
  print("\n" + "=" * 60)
291
+ print("✓ GPU Warm-up completed!")
292
+ print(" Text Encoder, VAE ready. UNet will compile on first request.")
293
+ print(" (torch.compile errors will fallback to eager mode)")
294
  print("=" * 60 + "\n")
295
 
296
  return "GPU Warm-up completed successfully!"
297
  except Exception as e:
298
  print(f"\n⚠ GPU Warm-up failed: {e}")
299
+ print(" Models will be loaded on first user request.")
300
  return f"GPU Warm-up skipped: {e}"
301
 
302
 
 
664
  print("✓ Gradio Blocks created")
665
 
666
  gr.Markdown("## DXCO : GENAI-VTON")
667
+ gr.Markdown("์ž„์„ฑ๋‚จ, ์œค์ง€์˜, ์กฐ๋ฏผ์ฃผ based on IDM-VTON")
668
+ gr.Markdown("* ๋งจ ์ฒ˜์Œ ์ถ”๋ก  ์‹œ [5๋ถ„] ๊ฑธ๋ฆผ - compile๊ณผ GPU warm-up *")
669
+ gr.Markdown("권장 이미지 사이즈 - 3:4비율(384x512,768x1024)")
670
 
671
  with gr.Row():
672
  with gr.Column():