manbeast3b commited on
Commit
ecd4ec6
·
verified ·
1 Parent(s): 5a1987b

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +4 -2
src/pipeline.py CHANGED
@@ -879,8 +879,10 @@ def load_pipeline() -> Pipeline:
879
  inputs1 = get_example_inputs()
880
  print(f"AoT pre compiled path is {path}")
881
 
882
- transformer = torch._inductor.aoti_load_package(path)
883
-
 
 
884
  for _ in range(2):
885
  _ = transformer(**inputs1)[0]
886
 
 
879
  inputs1 = get_example_inputs()
880
  print(f"AoT pre compiled path is {path}")
881
 
882
+ # transformer = torch._inductor.aoti_load_package(path)
883
+ transformer = torch._export.aot_load(path, "cuda")
884
+ print(f"{transformer(**inputs1)[0].shape=}")
885
+
886
  for _ in range(2):
887
  _ = transformer(**inputs1)[0]
888