ABLingss commited on
Commit
c6078db
·
1 Parent(s): 0d2b17a

fix-aoti-bug

Browse files
__pycache__/app.cpython-312.pyc DELETED
Binary file (14.1 kB)
 
heartlib/src/heartlib/heartmula/acceleration.py CHANGED
@@ -62,15 +62,13 @@ def compile_module_with_aoti(
62
  "kwargs": call.kwargs,
63
  }
64
  if dynamic_shapes:
65
- if call.kwargs:
66
- dynamic_shape_args = tree_map(lambda _: None, call.args)
67
- dynamic_shape_tree = tree_map(lambda _: None, call.kwargs)
68
- for name, spec in dynamic_shapes.items():
69
- if name in dynamic_shape_tree:
70
- dynamic_shape_tree[name] = spec
71
- export_kwargs["dynamic_shapes"] = (dynamic_shape_args, dynamic_shape_tree)
72
- else:
73
- export_kwargs["dynamic_shapes"] = dynamic_shapes
74
 
75
  exported = torch.export.export(module, **export_kwargs)
76
  compiled_module = spaces.aoti_compile(exported)
 
62
  "kwargs": call.kwargs,
63
  }
64
  if dynamic_shapes:
65
+ # Follow the HF ZeroGPU-AoTI blog pattern: build a flat dict
66
+ # that mirrors call.kwargs with None for non-dynamic dims, then
67
+ # overlay the caller-supplied dynamic dim specs.
68
+ dynamic_shape_tree = tree_map(lambda _: None, call.kwargs)
69
+ for name, spec in dynamic_shapes.items():
70
+ dynamic_shape_tree[name] = spec
71
+ export_kwargs["dynamic_shapes"] = dynamic_shape_tree
 
 
72
 
73
  exported = torch.export.export(module, **export_kwargs)
74
  compiled_module = spaces.aoti_compile(exported)