Spaces:
Runtime error
Runtime error
payload["constants_map"] = constants
Browse files- optimization_utils.py +4 -1
optimization_utils.py
CHANGED
|
@@ -86,6 +86,9 @@ class ZeroGPUCompiledModelFromDict:
|
|
| 86 |
if payload.get("format") != "zerogpu_aoti_v1":
|
| 87 |
raise ValueError(f"Unsupported payload format: {payload.get('format')}")
|
| 88 |
self.archive_file = BytesIO(payload["archive_bytes"])
|
|
|
|
|
|
|
|
|
|
| 89 |
self.constants_map_cpu: dict[str, torch.Tensor] = payload["constants_map"]
|
| 90 |
self.device = device
|
| 91 |
self.compiled_model: ContextVar[AOTICompiledModel | None] = ContextVar("compiled_model", default=None)
|
|
@@ -98,7 +101,7 @@ class ZeroGPUCompiledModelFromDict:
|
|
| 98 |
|
| 99 |
if not self._loaded_constants:
|
| 100 |
# Move constants to target device (cuda) and keep dtype as-is (bf16)
|
| 101 |
-
constants_map = {k: v.to(
|
| 102 |
compiled_model.load_constants(constants_map, check_full_update=True, user_managed=True)
|
| 103 |
self._loaded_constants = True
|
| 104 |
|
|
|
|
| 86 |
if payload.get("format") != "zerogpu_aoti_v1":
|
| 87 |
raise ValueError(f"Unsupported payload format: {payload.get('format')}")
|
| 88 |
self.archive_file = BytesIO(payload["archive_bytes"])
|
| 89 |
+
constants = payload["constants_map"]
|
| 90 |
+
constants = {k: v.to(device=device, dtype=torch.bfloat16).contiguous() for k, v in constants.items()}
|
| 91 |
+
payload["constants_map"] = constants
|
| 92 |
self.constants_map_cpu: dict[str, torch.Tensor] = payload["constants_map"]
|
| 93 |
self.device = device
|
| 94 |
self.compiled_model: ContextVar[AOTICompiledModel | None] = ContextVar("compiled_model", default=None)
|
|
|
|
| 101 |
|
| 102 |
if not self._loaded_constants:
|
| 103 |
# Move constants to target device (cuda) and keep dtype as-is (bf16)
|
| 104 |
+
constants_map = {k: v.to(device="cuda", dtype=torch.bfloat16).contiguous() for k, v in self.constants_map_cpu.items()}
|
| 105 |
compiled_model.load_constants(constants_map, check_full_update=True, user_managed=True)
|
| 106 |
self._loaded_constants = True
|
| 107 |
|