Update src/pipeline.py
Browse files- src/pipeline.py +4 -2
src/pipeline.py
CHANGED
|
@@ -879,8 +879,10 @@ def load_pipeline() -> Pipeline:
|
|
| 879 |
inputs1 = get_example_inputs()
|
| 880 |
print(f"AoT pre compiled path is {path}")
|
| 881 |
|
| 882 |
-
transformer = torch._inductor.aoti_load_package(path)
|
| 883 |
-
|
|
|
|
|
|
|
| 884 |
for _ in range(2):
|
| 885 |
_ = transformer(**inputs1)[0]
|
| 886 |
|
|
|
|
| 879 |
inputs1 = get_example_inputs()
|
| 880 |
print(f"AoT pre compiled path is {path}")
|
| 881 |
|
| 882 |
+
# transformer = torch._inductor.aoti_load_package(path)
|
| 883 |
+
transformer = torch._export.aot_load(path, "cuda")
|
| 884 |
+
print(f"{transformer(**inputs1)[0].shape=}")
|
| 885 |
+
|
| 886 |
for _ in range(2):
|
| 887 |
_ = transformer(**inputs1)[0]
|
| 888 |
|