fix output issue
Browse filesSigned-off-by: binliu <binliu@nvidia.com>
- hugging_face_pipeline.py +1 -1
- vista3d_pipeline.py +2 -0
hugging_face_pipeline.py
CHANGED
|
@@ -32,7 +32,7 @@ class HuggingFacePipelineHelper:
|
|
| 32 |
config_dict = kwargs.pop("config_dict", None)
|
| 33 |
self._update_config(config, config_dict)
|
| 34 |
model = VISTA3DModel(config)
|
| 35 |
-
model.from_pretrained(
|
| 36 |
pretrained_model_name_or_path=pretrained_model_name_or_path
|
| 37 |
)
|
| 38 |
return VISTA3DPipeline(model, **kwargs)
|
|
|
|
| 32 |
config_dict = kwargs.pop("config_dict", None)
|
| 33 |
self._update_config(config, config_dict)
|
| 34 |
model = VISTA3DModel(config)
|
| 35 |
+
model = model.from_pretrained(
|
| 36 |
pretrained_model_name_or_path=pretrained_model_name_or_path
|
| 37 |
)
|
| 38 |
return VISTA3DPipeline(model, **kwargs)
|
vista3d_pipeline.py
CHANGED
|
@@ -433,6 +433,8 @@ class VISTA3DPipeline(Pipeline):
|
|
| 433 |
return outputs
|
| 434 |
|
| 435 |
def postprocess(self, outputs, **kwargs):
|
|
|
|
|
|
|
| 436 |
for key, value in kwargs.items():
|
| 437 |
if key not in self.POSTPROCESSING_EXTRA_ARGS:
|
| 438 |
logging.warning(f"Cannot set parameter {key} for postprocessing.")
|
|
|
|
| 433 |
return outputs
|
| 434 |
|
| 435 |
def postprocess(self, outputs, **kwargs):
|
| 436 |
+
outputs[Keys.IMAGE] = outputs[Keys.IMAGE].to(self.device)
|
| 437 |
+
outputs[Keys.PRED] = outputs[Keys.PRED].to(self.device)
|
| 438 |
for key, value in kwargs.items():
|
| 439 |
if key not in self.POSTPROCESSING_EXTRA_ARGS:
|
| 440 |
logging.warning(f"Cannot set parameter {key} for postprocessing.")
|