akhaliq HF Staff commited on
Commit
4aa756a
·
verified ·
1 Parent(s): 235cb1f

Update app.py from anycoder

Browse files
Files changed (1) hide show
  1. app.py +43 -16
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  from huggingface_hub import hf_hub_download
 
3
  os.system("pip -qq install facenet_pytorch")
4
  from facenet_pytorch import MTCNN
5
  from torchvision import transforms
@@ -62,41 +63,53 @@ size = 256
62
  means = [0.485, 0.456, 0.406]
63
  stds = [0.229, 0.224, 0.225]
64
 
65
- t_stds = torch.tensor(stds).cuda().half()[:,None,None]
66
- t_means = torch.tensor(means).cuda().half()[:,None,None]
67
-
68
  img_transforms = transforms.Compose([
69
  transforms.ToTensor(),
70
  transforms.Normalize(means,stds)
71
  ])
72
-
73
- def tensor2im(var):
74
- return var.mul(t_stds).add(t_means).mul(255.).clamp(0,255).permute(1,2,0)
75
 
 
 
 
 
 
 
76
  def proc_pil_img(input_image, model):
 
 
 
 
 
 
 
 
77
  transformed_image = img_transforms(input_image)[None,...].cuda().half()
 
78
  with torch.no_grad():
79
  result_image = model(transformed_image)[0]
80
- output_image = tensor2im(result_image)
81
  output_image = output_image.detach().cpu().numpy().astype('uint8')
82
  output_image = PIL.Image.fromarray(output_image)
 
83
  return output_image
84
 
85
- # Load models
86
- modelv4 = torch.jit.load(modelarcanev4).eval().cuda().half()
87
- modelv3 = torch.jit.load(modelarcanev3).eval().cuda().half()
88
- modelv2 = torch.jit.load(modelarcanev2).eval().cuda().half()
89
-
90
  def process(im, version):
 
 
 
 
 
 
 
 
91
  if version == 'v0.4 (Recommended)':
92
- im = scale_by_face_size(im, target_face=256, max_res=1_500_000, max_upscale=1)
93
  res = proc_pil_img(im, modelv4)
94
  elif version == 'v0.3':
95
- im = scale_by_face_size(im, target_face=256, max_res=1_500_000, max_upscale=1)
96
  res = proc_pil_img(im, modelv3)
97
  else:
98
- im = scale_by_face_size(im, target_face=256, max_res=1_500_000, max_upscale=1)
99
  res = proc_pil_img(im, modelv2)
 
100
  return res
101
 
102
  # Custom theme
@@ -184,6 +197,16 @@ custom_css = """
184
  .example-container {
185
  margin-top: 1rem;
186
  }
 
 
 
 
 
 
 
 
 
 
187
  """
188
 
189
  # Build the interface
@@ -196,6 +219,8 @@ with gr.Blocks() as demo:
196
  # 🎨 ArcaneGAN
197
  ### Transform Your Photos into Arcane-Style Art
198
  Upload a portrait and watch it transform into the stunning visual style of Netflix's Arcane series.
 
 
199
  """
200
  )
201
  gr.Markdown(
@@ -263,6 +288,8 @@ with gr.Blocks() as demo:
263
  [GitHub Repository](https://github.com/Sxela/ArcaneGAN) |
264
  [Original Space](https://huggingface.co/spaces/akhaliq/ArcaneGAN)
265
 
 
 
266
  <div style='margin-top: 1rem;'>
267
  <img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_arcanegan' alt='visitor badge'>
268
  </div>
@@ -274,7 +301,7 @@ with gr.Blocks() as demo:
274
  fn=process,
275
  inputs=[input_image, version_selector],
276
  outputs=output_image,
277
- api_name="transform"
278
  )
279
 
280
  input_image.upload(
 
1
  import os
2
  from huggingface_hub import hf_hub_download
3
+ import spaces
4
  os.system("pip -qq install facenet_pytorch")
5
  from facenet_pytorch import MTCNN
6
  from torchvision import transforms
 
63
  means = [0.485, 0.456, 0.406]
64
  stds = [0.229, 0.224, 0.225]
65
 
 
 
 
66
  img_transforms = transforms.Compose([
67
  transforms.ToTensor(),
68
  transforms.Normalize(means,stds)
69
  ])
 
 
 
70
 
71
+ # Load models globally (outside GPU-decorated functions)
72
+ modelv4 = torch.jit.load(modelarcanev4).eval()
73
+ modelv3 = torch.jit.load(modelarcanev3).eval()
74
+ modelv2 = torch.jit.load(modelarcanev2).eval()
75
+
76
+ @spaces.GPU
77
  def proc_pil_img(input_image, model):
78
+ """GPU-accelerated image processing"""
79
+ # Move tensors to GPU inside the decorated function
80
+ t_stds = torch.tensor(stds).cuda().half()[:,None,None]
81
+ t_means = torch.tensor(means).cuda().half()[:,None,None]
82
+
83
+ # Move model to GPU
84
+ model = model.cuda().half()
85
+
86
  transformed_image = img_transforms(input_image)[None,...].cuda().half()
87
+
88
  with torch.no_grad():
89
  result_image = model(transformed_image)[0]
90
+ output_image = result_image.mul(t_stds).add(t_means).mul(255.).clamp(0,255).permute(1,2,0)
91
  output_image = output_image.detach().cpu().numpy().astype('uint8')
92
  output_image = PIL.Image.fromarray(output_image)
93
+
94
  return output_image
95
 
96
+ @spaces.GPU
 
 
 
 
97
  def process(im, version):
98
+ """Main processing function with GPU acceleration"""
99
+ if im is None:
100
+ return None
101
+
102
+ # Scale image (CPU operation)
103
+ im = scale_by_face_size(im, target_face=256, max_res=1_500_000, max_upscale=1)
104
+
105
+ # Select model based on version
106
  if version == 'v0.4 (Recommended)':
 
107
  res = proc_pil_img(im, modelv4)
108
  elif version == 'v0.3':
 
109
  res = proc_pil_img(im, modelv3)
110
  else:
 
111
  res = proc_pil_img(im, modelv2)
112
+
113
  return res
114
 
115
  # Custom theme
 
197
  .example-container {
198
  margin-top: 1rem;
199
  }
200
+
201
+ .gpu-badge {
202
+ display: inline-block;
203
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
204
+ color: white;
205
+ padding: 0.5rem 1rem;
206
+ border-radius: 20px;
207
+ font-weight: 600;
208
+ margin-top: 0.5rem;
209
+ }
210
  """
211
 
212
  # Build the interface
 
219
  # 🎨 ArcaneGAN
220
  ### Transform Your Photos into Arcane-Style Art
221
  Upload a portrait and watch it transform into the stunning visual style of Netflix's Arcane series.
222
+
223
+ <span class="gpu-badge">⚡ Powered by Zero-GPU</span>
224
  """
225
  )
226
  gr.Markdown(
 
288
  [GitHub Repository](https://github.com/Sxela/ArcaneGAN) |
289
  [Original Space](https://huggingface.co/spaces/akhaliq/ArcaneGAN)
290
 
291
+ **⚡ Zero-GPU Optimization**: This Space uses Hugging Face's Zero-GPU infrastructure for efficient GPU allocation.
292
+
293
  <div style='margin-top: 1rem;'>
294
  <img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_arcanegan' alt='visitor badge'>
295
  </div>
 
301
  fn=process,
302
  inputs=[input_image, version_selector],
303
  outputs=output_image,
304
+ api_visibility="public"
305
  )
306
 
307
  input_image.upload(