multimodalart HF Staff commited on
Commit
34757ae
·
verified ·
1 Parent(s): 14237e0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -33
app.py CHANGED
@@ -54,42 +54,10 @@ pipe = HeliosPyramidPipeline.from_pretrained(
54
  is_distilled=True
55
  )
56
 
57
- aoti_load_(pipe.transformer, "multimodalart/helios-distilled-transformer", "helios_distilled_transformer.pt2")
58
 
59
  pipe.to("cuda")
60
 
61
- # ---------------------------------------------------------------------------
62
- # 🔥 AOT LOADING LOGIC 🔥
63
- # ---------------------------------------------------------------------------
64
- # AOT_FILENAME = "helios_distilled_transformer.pt2"
65
- # AOT_PATH = os.path.join(_APP_DIR, AOT_FILENAME)
66
-
67
- #def load_aot_model(path, original_module):
68
- # """
69
- # Loads a raw AOTI package (.pt2) and patches the original module.
70
- # """
71
- # print(f"[AOT] Loading AOTI package from {path}...")
72
- #
73
- # compiled_model = torch._inductor.aoti_load_package(path)
74
- #
75
- # original_module.forward = compiled_model
76
- #
77
- # original_module.to("meta")
78
- #
79
- # print("[AOT] Model patched successfully!")
80
-
81
- #if os.path.exists(AOT_PATH):
82
- # try:
83
- # load_aot_model(AOT_PATH, pipe.transformer)
84
- #    print("[AOT] ✅ Loaded compiled graph")
85
- # except Exception as e:
86
- # print(f"[AOT] ❌ Failed to load compiled graph: {e}")
87
- # # Restore device if failed
88
- # pipe.to("cuda")
89
- # pipe.transformer.set_attention_backend("_flash_3_hub")
90
- #else:
91
- # print(f"[AOT] ⚠️ No compiled graph found at {AOT_PATH}.")
92
-
93
  pipe.transformer.set_attention_backend("_flash_3_hub")
94
 
95
  # ---------------------------------------------------------------------------
 
54
  is_distilled=True
55
  )
56
 
57
+ # aoti_load_(pipe.transformer, "multimodalart/helios-distilled-transformer", "helios_distilled_transformer.pt2")
58
 
59
  pipe.to("cuda")
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  pipe.transformer.set_attention_backend("_flash_3_hub")
62
 
63
  # ---------------------------------------------------------------------------