Fabrice-TIERCELIN commited on
Commit
d8ae88b
·
verified ·
1 Parent(s): 627fa06

def export_compiled_transformers_to_zip() -> str:

Browse files
Files changed (1) hide show
  1. app.py +51 -0
app.py CHANGED
@@ -424,6 +424,44 @@ def generate_video_on_gpu(
424
 
425
  return output_frames_list
426
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
427
 
428
  # --- 3. Gradio User Interface ---
429
 
@@ -497,6 +535,19 @@ with gr.Blocks(js=js) as app:
497
  download_button = gr.DownloadButton(elem_id="download_btn", interactive = True)
498
  video_information = gr.HTML(value = "")
499
 
 
 
 
 
 
 
 
 
 
 
 
 
 
500
  # Main video generation button
501
  ui_inputs = [
502
  start_image,
 
424
 
425
  return output_frames_list
426
 
427
+ def export_compiled_transformers_to_zip() -> str:
428
+ """
429
+ Bundle compiled_transformer_1 and compiled_transformer_2 into a zip file and return the file path.
430
+ """
431
+ ct1 = getattr(optimization, "COMPILED_TRANSFORMER_1", None)
432
+ ct2 = getattr(optimization, "COMPILED_TRANSFORMER_2", None)
433
+
434
+ if ct1 is None or ct2 is None:
435
+ raise gr.Error("Compiled transformers are not available yet (compilation may have failed).")
436
+
437
+ payload_1 = ct1.to_serializable_dict()
438
+ payload_2 = ct2.to_serializable_dict()
439
+
440
+ tmp_zip = tempfile.NamedTemporaryFile(suffix=".zip", delete=False)
441
+ tmp_zip.close()
442
+
443
+ with zipfile.ZipFile(tmp_zip.name, "w", compression=zipfile.ZIP_DEFLATED) as zf:
444
+ # store with torch.save so users can load easily with torch.load()
445
+ buf1 = tempfile.NamedTemporaryFile(suffix=".pt", delete=False)
446
+ buf1.close()
447
+ torch.save(payload_1, buf1.name)
448
+
449
+ buf2 = tempfile.NamedTemporaryFile(suffix=".pt", delete=False)
450
+ buf2.close()
451
+ torch.save(payload_2, buf2.name)
452
+
453
+ zf.write(buf1.name, arcname="compiled_transformer_1.pt")
454
+ zf.write(buf2.name, arcname="compiled_transformer_2.pt")
455
+
456
+ # cleanup intermediate .pt
457
+ try:
458
+ os.remove(buf1.name)
459
+ os.remove(buf2.name)
460
+ except:
461
+ pass
462
+
463
+ return tmp_zip.name
464
+
465
 
466
  # --- 3. Gradio User Interface ---
467
 
 
535
  download_button = gr.DownloadButton(elem_id="download_btn", interactive = True)
536
  video_information = gr.HTML(value = "")
537
 
538
+ with gr.Accordion("🔧 Compilation artifacts (advanced)", open=False):
539
+ gr.Markdown(
540
+ "Télécharge les artefacts compilés AOTInductor générés au démarrage (transformer + transformer_2)."
541
+ )
542
+ export_btn = gr.Button("📦 Préparer l'archive des transformers compilés")
543
+ compiled_download = gr.DownloadButton(label="⬇️ Télécharger compiled_transformers.zip", interactive=False)
544
+
545
+ def _build_and_enable_download():
546
+ path = export_compiled_transformers_to_zip()
547
+ return gr.update(value=path, interactive=True)
548
+
549
+ export_btn.click(fn=_build_and_enable_download, inputs=None, outputs=compiled_download)
550
+
551
  # Main video generation button
552
  ui_inputs = [
553
  start_image,