vsamasworm committed on
Commit
ee8949b
·
1 Parent(s): 9746db7

global model

Browse files
Files changed (2) hide show
  1. app.py +3 -2
  2. inference.py +5 -3
app.py CHANGED
@@ -21,7 +21,7 @@ print(ckpt_path)
21
  # else:
22
  # mark_dtype = torch.float16
23
  # device = torch.device('cpu')
24
- mark_dtype = torch.float16
25
  device = torch.device('cuda')
26
 
27
  model = VGGT_OriAny_Ref(out_dim=900, dtype=mark_dtype, nopretrain=True)
@@ -69,7 +69,8 @@ def run_inference(image_ref, image_tgt, do_rm_bkg):
69
  pil_ref = background_preprocess(pil_ref, True)
70
 
71
  try:
72
- ans_dict = inf_single_case(model, pil_ref, pil_tgt)
 
73
  except Exception as e:
74
  print("Inference error:", e)
75
  raise gr.Error(f"Inference failed: {str(e)}")
 
21
  # else:
22
  # mark_dtype = torch.float16
23
  # device = torch.device('cpu')
24
+ mark_dtype = torch.bfloat16
25
  device = torch.device('cuda')
26
 
27
  model = VGGT_OriAny_Ref(out_dim=900, dtype=mark_dtype, nopretrain=True)
 
69
  pil_ref = background_preprocess(pil_ref, True)
70
 
71
  try:
72
+ # ans_dict = inf_single_case(model, pil_ref, pil_tgt)
73
+ ans_dict = inf_single_case(pil_ref, pil_tgt)
74
  except Exception as e:
75
  print("Inference error:", e)
76
  raise gr.Error(f"Inference failed: {str(e)}")
inference.py CHANGED
@@ -177,7 +177,8 @@ def preprocess_images(image_list, mode="crop"):
177
  return images
178
 
179
  @torch.no_grad()
180
- def inf_single_batch(model, batch):
 
181
  device = model.get_device()
182
  batch_img_inputs = batch # (B, S, 3, H, W)
183
  # print(batch_img_inputs.shape)
@@ -229,12 +230,13 @@ def inf_single_batch(model, batch):
229
  # input PIL Image
230
  @spaces.GPU
231
  @torch.no_grad()
232
- def inf_single_case(model, image_ref, image_tgt):
 
233
  if image_tgt is None:
234
  image_list = [image_ref]
235
  else:
236
  image_list = [image_ref, image_tgt]
237
  image_tensors = preprocess_images(image_list, mode="pad").to('cuda')
238
- ans_dict = inf_single_batch(model=model, batch=image_tensors.unsqueeze(0))
239
  print(ans_dict)
240
  return ans_dict
 
177
  return images
178
 
179
  @torch.no_grad()
180
+ def inf_single_batch(batch):
181
+ global model
182
  device = model.get_device()
183
  batch_img_inputs = batch # (B, S, 3, H, W)
184
  # print(batch_img_inputs.shape)
 
230
  # input PIL Image
231
  @spaces.GPU
232
  @torch.no_grad()
233
+ def inf_single_case(image_ref, image_tgt):
234
+ global model
235
  if image_tgt is None:
236
  image_list = [image_ref]
237
  else:
238
  image_list = [image_ref, image_tgt]
239
  image_tensors = preprocess_images(image_list, mode="pad").to('cuda')
240
+ ans_dict = inf_single_batch(batch=image_tensors.unsqueeze(0))
241
  print(ans_dict)
242
  return ans_dict