tedlasai committed on
Commit
3729b71
·
1 Parent(s): 1344025
Files changed (2) hide show
  1. app.py +1 -1
  2. simple_inference.py +22 -26
app.py CHANGED
@@ -44,7 +44,7 @@ def generate_vstack_from_image(image: Image.Image, input_focal_position: int, nu
44
  args.device = "cuda"
45
 
46
  pipe.to(args.device)
47
- batch = convert_to_batch(args.image_path, input_focal_position=input_focal_position)
48
  output_frames, focal_stack_num = inference_on_image(args, batch, pipeline, device)
49
  save_dir = os.path.join(OUTPUT_DIR, batch['name'])
50
 
 
44
  args.device = "cuda"
45
 
46
  pipe.to(args.device)
47
+ batch = convert_to_batch(image, input_focal_position=input_focal_position)
48
  output_frames, focal_stack_num = inference_on_image(args, batch, pipeline, device)
49
  save_dir = os.path.join(OUTPUT_DIR, batch['name'])
50
 
simple_inference.py CHANGED
@@ -124,29 +124,26 @@ def find_scale(height, width):
124
  # Reduce the scale slightly
125
  scale -= 0.01
126
 
127
- def convert_to_batch(image, input_focal_position, sample_frames=9):
128
- scene, focal_stack_num = image, input_focal_position
129
- from PIL import Image
130
- with Image.open(scene) as img:
131
-
132
- icc_profile = img.info.get("icc_profile")
133
- if icc_profile is None:
134
- icc_profile = "none"
135
- original_pixels = torch.from_numpy(np.array(img)).float().permute(2,0,1)
136
- original_pixels = original_pixels / 255
137
- width, height = img.size
138
- scaled_width, scaled_height = find_scale(width, height)
139
-
140
- img_resized = img.resize((scaled_width, scaled_height))
141
- img_tensor = torch.from_numpy(np.array(img_resized)).float()
142
- img_normalized = img_tensor / 127.5 - 1
143
- img_normalized = img_normalized.permute(2, 0, 1)
144
-
145
- pixels = torch.zeros((1, sample_frames, 3, scaled_height, scaled_width))
146
- pixels[0, focal_stack_num] = img_normalized
147
-
148
- name = os.path.splitext(os.path.basename(scene))[0]
149
- return {"pixel_values": pixels, "focal_stack_num": focal_stack_num, "original_pixel_values": original_pixels, 'icc_profile': icc_profile, "name": name}
150
 
151
 
152
  def inference_on_image(args, batch, pipeline, device):
@@ -240,12 +237,11 @@ def main():
240
  if args.output_dir is not None:
241
  os.makedirs(args.output_dir, exist_ok=True)
242
 
243
-
244
  pipeline, device = load_model(args)
245
 
246
- batch = convert_to_batch(args.image_path, input_focal_position=6)
247
-
248
  with torch.no_grad():
 
 
249
  output_frames, focal_stack_num = inference_on_image(args, batch, pipeline, device)
250
  val_save_dir = os.path.join(args.output_dir, "validation_images", batch['name'])
251
  write_output(val_save_dir, output_frames, focal_stack_num, batch['icc_profile'])
 
124
  # Reduce the scale slightly
125
  scale -= 0.01
126
 
127
+
128
def convert_to_batch(img, input_focal_position, sample_frames=9):
    """Build a single-sample inference batch from one PIL image.

    Args:
        img: an open PIL Image (as produced by ``Image.open``).
        input_focal_position: index of the focal-stack slot this image
            occupies; stored in the batch as ``focal_stack_num``.
        sample_frames: total number of frames in the focal stack (default 9).

    Returns:
        dict with:
            "pixel_values": (1, sample_frames, 3, H', W') tensor, zeros except
                the normalized image at index ``focal_stack_num``;
            "focal_stack_num": the focal position index;
            "original_pixel_values": (3, H, W) tensor of the unresized image
                scaled to [0, 1];
            "icc_profile": the image's ICC profile bytes, or the string "none";
            "name": basename (no extension) of the image's source file, or
                "image" when the PIL object has no filename (e.g. in-memory).
    """
    focal_stack_num = input_focal_position

    # Preserve the ICC profile so output can be color-managed identically.
    icc_profile = img.info.get("icc_profile")
    if icc_profile is None:
        icc_profile = "none"

    # Keep the full-resolution pixels in [0, 1] for later reference output.
    original_pixels = torch.from_numpy(np.array(img)).float().permute(2, 0, 1)
    original_pixels = original_pixels / 255

    width, height = img.size
    # NOTE(review): find_scale is defined as find_scale(height, width) in this
    # file but called here as (width, height) — confirm the argument order.
    scaled_width, scaled_height = find_scale(width, height)

    # Resize and normalize to [-1, 1] (x / 127.5 - 1) for the model input.
    img_resized = img.resize((scaled_width, scaled_height))
    img_tensor = torch.from_numpy(np.array(img_resized)).float()
    img_normalized = img_tensor / 127.5 - 1
    img_normalized = img_normalized.permute(2, 0, 1)

    # Empty focal stack with only the given focal position filled in.
    pixels = torch.zeros((1, sample_frames, 3, scaled_height, scaled_width))
    pixels[0, focal_stack_num] = img_normalized

    # Callers index batch['name'] to build output paths; derive it from the
    # source filename when the image came from disk (Image.open sets
    # ``filename``), otherwise fall back to a stable placeholder.
    source_path = getattr(img, "filename", "") or "image"
    name = os.path.splitext(os.path.basename(source_path))[0]

    return {
        "pixel_values": pixels,
        "focal_stack_num": focal_stack_num,
        "original_pixel_values": original_pixels,
        "icc_profile": icc_profile,
        "name": name,
    }
 
 
 
147
 
148
 
149
  def inference_on_image(args, batch, pipeline, device):
 
237
  if args.output_dir is not None:
238
  os.makedirs(args.output_dir, exist_ok=True)
239
 
 
240
  pipeline, device = load_model(args)
241
 
 
 
242
  with torch.no_grad():
243
+ img = Image.open(args.image_path)
244
+ batch = convert_to_batch(img, input_focal_position=6)
245
  output_frames, focal_stack_num = inference_on_image(args, batch, pipeline, device)
246
  val_save_dir = os.path.join(args.output_dir, "validation_images", batch['name'])
247
  write_output(val_save_dir, output_frames, focal_stack_num, batch['icc_profile'])