akhaliq HF Staff committed on
Commit
877acf3
·
verified ·
1 Parent(s): f12153a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -14
app.py CHANGED
@@ -68,30 +68,26 @@ stds = [0.229, 0.224, 0.225]
68
 
69
  img_transforms = transforms.Compose([
70
  transforms.ToTensor(),
71
- transforms.Normalize(means,stds)
72
  ])
73
 
74
- # Load models globally on CPU
75
- modelv4_cpu = torch.jit.load(modelarcanev4, map_location='cpu').eval()
76
- modelv3_cpu = torch.jit.load(modelarcanev3, map_location='cpu').eval()
77
- modelv2_cpu = torch.jit.load(modelarcanev2, map_location='cpu').eval()
78
-
79
  @spaces.GPU
80
  def proc_pil_img(input_image, model_path):
81
- """GPU-accelerated image processing"""
82
- # Load model fresh on GPU to avoid device mismatch
83
  model = torch.jit.load(model_path, map_location='cuda').eval()
84
 
85
- # Create tensors on GPU
86
- t_stds = torch.tensor(stds).cuda().view(3, 1, 1)
87
- t_means = torch.tensor(means).cuda().view(3, 1, 1)
88
 
89
- # Transform and move to GPU
90
- transformed_image = img_transforms(input_image).unsqueeze(0).cuda()
91
 
92
  with torch.no_grad():
93
  result_image = model(transformed_image)[0]
94
- output_image = result_image.mul(t_stds).add(t_means).mul(255.).clamp(0, 255).permute(1, 2, 0)
 
95
  output_image = output_image.cpu().numpy().astype('uint8')
96
  output_image = PIL.Image.fromarray(output_image)
97
 
 
68
 
69
  img_transforms = transforms.Compose([
70
  transforms.ToTensor(),
71
+ transforms.Normalize(means, stds)
72
  ])
73
 
 
 
 
 
 
74
  @spaces.GPU
75
  def proc_pil_img(input_image, model_path):
76
+ """GPU-accelerated image processing with half precision support"""
77
+ # Load model on GPU
78
  model = torch.jit.load(model_path, map_location='cuda').eval()
79
 
80
+ # Create tensors on GPU in half precision to match model
81
+ t_stds = torch.tensor(stds).cuda().half().view(3, 1, 1)
82
+ t_means = torch.tensor(means).cuda().half().view(3, 1, 1)
83
 
84
+ # Transform image and move to GPU with half precision
85
+ transformed_image = img_transforms(input_image).unsqueeze(0).cuda().half()
86
 
87
  with torch.no_grad():
88
  result_image = model(transformed_image)[0]
89
+ # Convert back to float for post-processing
90
+ output_image = result_image.float().mul(t_stds.float()).add(t_means.float()).mul(255.).clamp(0, 255).permute(1, 2, 0)
91
  output_image = output_image.cpu().numpy().astype('uint8')
92
  output_image = PIL.Image.fromarray(output_image)
93