Fabrice-TIERCELIN committed on
Commit
e52b043
·
verified ·
1 Parent(s): 7d6f890

payload["constants_map"] = constants

Browse files
Files changed (1) hide show
  1. optimization_utils.py +4 -1
optimization_utils.py CHANGED
@@ -86,6 +86,9 @@ class ZeroGPUCompiledModelFromDict:
86
  if payload.get("format") != "zerogpu_aoti_v1":
87
  raise ValueError(f"Unsupported payload format: {payload.get('format')}")
88
  self.archive_file = BytesIO(payload["archive_bytes"])
 
 
 
89
  self.constants_map_cpu: dict[str, torch.Tensor] = payload["constants_map"]
90
  self.device = device
91
  self.compiled_model: ContextVar[AOTICompiledModel | None] = ContextVar("compiled_model", default=None)
@@ -98,7 +101,7 @@ class ZeroGPUCompiledModelFromDict:
98
 
99
  if not self._loaded_constants:
100
  # Move constants to target device (cuda) and keep dtype as-is (bf16)
101
- constants_map = {k: v.to(self.device, non_blocking=True) for k, v in self.constants_map_cpu.items()}
102
  compiled_model.load_constants(constants_map, check_full_update=True, user_managed=True)
103
  self._loaded_constants = True
104
 
 
86
  if payload.get("format") != "zerogpu_aoti_v1":
87
  raise ValueError(f"Unsupported payload format: {payload.get('format')}")
88
  self.archive_file = BytesIO(payload["archive_bytes"])
89
+ constants = payload["constants_map"]
90
+ constants = {k: v.to(device=device, dtype=torch.bfloat16).contiguous() for k, v in constants.items()}
91
+ payload["constants_map"] = constants
92
  self.constants_map_cpu: dict[str, torch.Tensor] = payload["constants_map"]
93
  self.device = device
94
  self.compiled_model: ContextVar[AOTICompiledModel | None] = ContextVar("compiled_model", default=None)
 
101
 
102
  if not self._loaded_constants:
103
  # Move constants to target device (cuda) and keep dtype as-is (bf16)
104
+ constants_map = {k: v.to(device="cuda", dtype=torch.bfloat16).contiguous() for k, v in self.constants_map_cpu.items()}
105
  compiled_model.load_constants(constants_map, check_full_update=True, user_managed=True)
106
  self._loaded_constants = True
107