fix-aoti-bug
Browse files
__pycache__/app.cpython-312.pyc
DELETED
|
Binary file (14.1 kB)
|
|
|
heartlib/src/heartlib/heartmula/acceleration.py
CHANGED
|
@@ -62,15 +62,13 @@ def compile_module_with_aoti(
|
|
| 62 |
"kwargs": call.kwargs,
|
| 63 |
}
|
| 64 |
if dynamic_shapes:
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
else:
|
| 73 |
-
export_kwargs["dynamic_shapes"] = dynamic_shapes
|
| 74 |
|
| 75 |
exported = torch.export.export(module, **export_kwargs)
|
| 76 |
compiled_module = spaces.aoti_compile(exported)
|
|
|
|
| 62 |
"kwargs": call.kwargs,
|
| 63 |
}
|
| 64 |
if dynamic_shapes:
|
| 65 |
+
# Follow the HF ZeroGPU-AoTI blog pattern: build a flat dict
|
| 66 |
+
# that mirrors call.kwargs with None for non-dynamic dims, then
|
| 67 |
+
# overlay the caller-supplied dynamic dim specs.
|
| 68 |
+
dynamic_shape_tree = tree_map(lambda _: None, call.kwargs)
|
| 69 |
+
for name, spec in dynamic_shapes.items():
|
| 70 |
+
dynamic_shape_tree[name] = spec
|
| 71 |
+
export_kwargs["dynamic_shapes"] = dynamic_shape_tree
|
|
|
|
|
|
|
| 72 |
|
| 73 |
exported = torch.export.export(module, **export_kwargs)
|
| 74 |
compiled_module = spaces.aoti_compile(exported)
|