Commit ·
ee8949b
1
Parent(s): 9746db7
global model
Browse files
- app.py +3 -2
- inference.py +5 -3
app.py
CHANGED
|
@@ -21,7 +21,7 @@ print(ckpt_path)
|
|
| 21 |
# else:
|
| 22 |
# mark_dtype = torch.float16
|
| 23 |
# device = torch.device('cpu')
|
| 24 |
-
mark_dtype = torch.
|
| 25 |
device = torch.device('cuda')
|
| 26 |
|
| 27 |
model = VGGT_OriAny_Ref(out_dim=900, dtype=mark_dtype, nopretrain=True)
|
|
@@ -69,7 +69,8 @@ def run_inference(image_ref, image_tgt, do_rm_bkg):
|
|
| 69 |
pil_ref = background_preprocess(pil_ref, True)
|
| 70 |
|
| 71 |
try:
|
| 72 |
-
ans_dict = inf_single_case(model, pil_ref, pil_tgt)
|
|
|
|
| 73 |
except Exception as e:
|
| 74 |
print("Inference error:", e)
|
| 75 |
raise gr.Error(f"Inference failed: {str(e)}")
|
|
|
|
| 21 |
# else:
|
| 22 |
# mark_dtype = torch.float16
|
| 23 |
# device = torch.device('cpu')
|
| 24 |
+
mark_dtype = torch.bfloat16
|
| 25 |
device = torch.device('cuda')
|
| 26 |
|
| 27 |
model = VGGT_OriAny_Ref(out_dim=900, dtype=mark_dtype, nopretrain=True)
|
|
|
|
| 69 |
pil_ref = background_preprocess(pil_ref, True)
|
| 70 |
|
| 71 |
try:
|
| 72 |
+
# ans_dict = inf_single_case(model, pil_ref, pil_tgt)
|
| 73 |
+
ans_dict = inf_single_case(pil_ref, pil_tgt)
|
| 74 |
except Exception as e:
|
| 75 |
print("Inference error:", e)
|
| 76 |
raise gr.Error(f"Inference failed: {str(e)}")
|
inference.py
CHANGED
|
@@ -177,7 +177,8 @@ def preprocess_images(image_list, mode="crop"):
|
|
| 177 |
return images
|
| 178 |
|
| 179 |
@torch.no_grad()
|
| 180 |
-
def inf_single_batch(model, batch):
|
|
|
|
| 181 |
device = model.get_device()
|
| 182 |
batch_img_inputs = batch # (B, S, 3, H, W)
|
| 183 |
# print(batch_img_inputs.shape)
|
|
@@ -229,12 +230,13 @@ def inf_single_batch(model, batch):
|
|
| 229 |
# input PIL Image
|
| 230 |
@spaces.GPU
|
| 231 |
@torch.no_grad()
|
| 232 |
-
def inf_single_case(model, image_ref, image_tgt):
|
|
|
|
| 233 |
if image_tgt is None:
|
| 234 |
image_list = [image_ref]
|
| 235 |
else:
|
| 236 |
image_list = [image_ref, image_tgt]
|
| 237 |
image_tensors = preprocess_images(image_list, mode="pad").to('cuda')
|
| 238 |
-
ans_dict = inf_single_batch(model, batch=image_tensors.unsqueeze(0))
|
| 239 |
print(ans_dict)
|
| 240 |
return ans_dict
|
|
|
|
| 177 |
return images
|
| 178 |
|
| 179 |
@torch.no_grad()
|
| 180 |
+
def inf_single_batch(batch):
|
| 181 |
+
global model
|
| 182 |
device = model.get_device()
|
| 183 |
batch_img_inputs = batch # (B, S, 3, H, W)
|
| 184 |
# print(batch_img_inputs.shape)
|
|
|
|
| 230 |
# input PIL Image
|
| 231 |
@spaces.GPU
|
| 232 |
@torch.no_grad()
|
| 233 |
+
def inf_single_case(image_ref, image_tgt):
|
| 234 |
+
global model
|
| 235 |
if image_tgt is None:
|
| 236 |
image_list = [image_ref]
|
| 237 |
else:
|
| 238 |
image_list = [image_ref, image_tgt]
|
| 239 |
image_tensors = preprocess_images(image_list, mode="pad").to('cuda')
|
| 240 |
+
ans_dict = inf_single_batch(batch=image_tensors.unsqueeze(0))
|
| 241 |
print(ans_dict)
|
| 242 |
return ans_dict
|