gionuibk commited on
Commit
794839a
·
verified ·
1 Parent(s): c825a55

Upload models/pytorch_wrapper.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. models/pytorch_wrapper.py +8 -0
models/pytorch_wrapper.py CHANGED
@@ -66,6 +66,14 @@ class PyTorchModel:
66
  raise RuntimeError("Model not initialized.")
67
 
68
  try:
 
 
 
 
 
 
 
 
69
  # Convert to Tensor
70
  tensor_in = torch.from_numpy(input_data).float().to(self.device)
71
 
 
66
  raise RuntimeError("Model not initialized.")
67
 
68
  try:
69
+ # Handle dictionary input (ONNX style)
70
+ if isinstance(input_data, dict):
71
+ # Take the first value (assuming single input) or look for "input"
72
+ if "input" in input_data:
73
+ input_data = input_data["input"]
74
+ else:
75
+ input_data = list(input_data.values())[0]
76
+
77
  # Convert to Tensor
78
  tensor_in = torch.from_numpy(input_data).float().to(self.device)
79