#!/usr/bin/env python3 """Export AGORA planner to all formats: safetensors, ONNX, TRT FP16, TRT FP32. Usage: CUDA_VISIBLE_DEVICES=3 python scripts/export_all.py """ from __future__ import annotations import gc import os import shutil import time from pathlib import Path import torch PROJECT = "project_agora" ARTIFACTS = "/mnt/artifacts-datai" MODEL_DIR = f"{ARTIFACTS}/models/{PROJECT}/agora-planner-v1" EXPORT_DIR = f"{ARTIFACTS}/exports/{PROJECT}" MERGED_DIR = f"{MODEL_DIR}/merged" os.makedirs(EXPORT_DIR, exist_ok=True) def export_safetensors(): """Export merged model as safetensors (already done by training, verify).""" print("\n[1/5] SAFETENSORS CHECK") st_files = list(Path(MERGED_DIR).glob("*.safetensors")) if st_files: total_size = sum(f.stat().st_size for f in st_files) print(f" Already exists: {len(st_files)} files, {total_size / 1e9:.2f} GB") # Copy to exports dst = Path(EXPORT_DIR) / "safetensors" dst.mkdir(exist_ok=True) for f in st_files: shutil.copy2(f, dst / f.name) # Also copy config + tokenizer for name in ["config.json", "tokenizer.json", "tokenizer_config.json", "generation_config.json", "special_tokens_map.json", "vocab.json", "merges.txt"]: src = Path(MERGED_DIR) / name if src.exists(): shutil.copy2(src, dst / name) print(f" Copied to {dst}") return True else: print(" ERROR: No safetensors found in merged dir") return False def export_pth(): """Export as single .pth file.""" print("\n[2/5] PTH EXPORT") from transformers import AutoModelForCausalLM pth_path = Path(EXPORT_DIR) / "agora_planner_v1.pth" if pth_path.exists(): print(f" Already exists: {pth_path} ({pth_path.stat().st_size / 1e9:.2f} GB)") return True print(" Loading model...") model = AutoModelForCausalLM.from_pretrained( MERGED_DIR, dtype=torch.float16, trust_remote_code=True, ) print(f" Saving to {pth_path}...") torch.save(model.state_dict(), pth_path) size_gb = pth_path.stat().st_size / 1e9 print(f" Saved: {size_gb:.2f} GB") del model gc.collect() torch.cuda.empty_cache() return True 
def export_onnx():
    """Export the merged model to ONNX via optimum (handles >2GB external data).

    Skips the work if onnx/ already contains .onnx files. Returns True on
    success, False on any export error (logged, not raised).
    """
    print("\n[3/5] ONNX EXPORT")
    onnx_dir = Path(EXPORT_DIR) / "onnx"
    if onnx_dir.exists() and list(onnx_dir.glob("*.onnx")):
        total = sum(f.stat().st_size for f in onnx_dir.rglob("*") if f.is_file())
        print(f" Already exists: {onnx_dir} ({total / 1e9:.2f} GB total)")
        return True
    onnx_dir.mkdir(parents=True, exist_ok=True)
    print(f" Exporting with optimum to {onnx_dir}...")
    try:
        from optimum.exporters.onnx import main_export

        main_export(
            MERGED_DIR,
            output=str(onnx_dir),
            task="text-generation",
            opset=18,
            trust_remote_code=True,
        )
        onnx_files = list(onnx_dir.rglob("*.onnx"))
        # Total includes external-data blobs, not just the .onnx graph files.
        total = sum(f.stat().st_size for f in onnx_dir.rglob("*") if f.is_file())
        print(f" Exported: {len(onnx_files)} ONNX files, {total / 1e9:.2f} GB total")
        return True
    except Exception as e:
        # Best-effort: a failed ONNX export should not abort the pipeline.
        print(f" ERROR: {e}")
        return False


def export_trt(precision: str = "fp16"):
    """Build a TensorRT engine from the exported ONNX model.

    precision: "fp16" or "fp32". Skips the work if the engine file already
    exists. Returns True on success, False if no ONNX source is found, the
    parse fails, or the engine build fails.
    """
    step = "4" if precision == "fp16" else "5"
    print(f"\n[{step}/5] TENSORRT {precision.upper()} EXPORT")
    import tensorrt as trt

    # Find ONNX model (optimum exports to onnx/ directory)
    onnx_dir = Path(EXPORT_DIR) / "onnx"
    onnx_candidates = list(onnx_dir.glob("*.onnx")) if onnx_dir.exists() else []
    # Also check flat file
    flat_onnx = Path(EXPORT_DIR) / "agora_planner_v1.onnx"
    if flat_onnx.exists():
        onnx_candidates.append(flat_onnx)
    if not onnx_candidates:
        print(f" ERROR: No ONNX model found in {onnx_dir} or {EXPORT_DIR}")
        return False
    # Use the largest ONNX file (the main model, not decoder subgraph)
    onnx_path = max(onnx_candidates, key=lambda p: p.stat().st_size)

    trt_path = Path(EXPORT_DIR) / f"agora_planner_v1_trt_{precision}.engine"
    if trt_path.exists():
        print(f" Already exists: {trt_path} ({trt_path.stat().st_size / 1e9:.2f} GB)")
        return True

    print(f" ONNX source: {onnx_path} ({onnx_path.stat().st_size / 1e9:.2f} GB)")
    print(f" Building TRT {precision.upper()} engine...")
    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    parser = trt.OnnxParser(network, logger)
    # For external data models, set the model path for parser
    parser.set_flag(trt.OnnxParserFlag.NATIVE_INSTANCENORM)

    print(" Parsing ONNX model...")
    # parse_from_file resolves external-data tensors relative to the model path.
    if not parser.parse_from_file(str(onnx_path)):
        for err_idx in range(parser.num_errors):
            print(f" PARSE ERROR: {parser.get_error(err_idx)}")
        return False

    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 4 << 30)  # 4 GB
    if precision == "fp16":
        if builder.platform_has_fast_fp16:
            config.set_flag(trt.BuilderFlag.FP16)
            print(" FP16 enabled")
        else:
            print(" WARNING: FP16 not supported, falling back to FP32")

    # Set optimization profiles for dynamic shapes.
    # BUGFIX: the original assigned opt_shape twice (the first assignment was
    # dead code) and shadowed the input-loop index inside the genexp.
    profile = builder.create_optimization_profile()
    for input_idx in range(network.num_inputs):
        inp = network.get_input(input_idx)
        shape = inp.shape
        last = len(shape) - 1
        # Dynamic dims (-1): min=1, max=1024; opt=512 on the last (assumed
        # sequence) dim and 1 elsewhere — TODO confirm seq dim is last for
        # every input.
        min_shape = tuple(1 if d == -1 else d for d in shape)
        opt_shape = tuple(
            (512 if pos == last else 1) if d == -1 else d
            for pos, d in enumerate(shape)
        )
        max_shape = tuple(1024 if d == -1 else d for d in shape)
        # Override batch dim: optimize for batch 1, allow up to 4.
        if len(shape) >= 2:
            min_shape = (1,) + min_shape[1:]
            opt_shape = (1,) + opt_shape[1:]
            max_shape = (4,) + max_shape[1:]
        profile.set_shape(inp.name, min_shape, opt_shape, max_shape)
        print(f" Input '{inp.name}': min={min_shape} opt={opt_shape} max={max_shape}")
    config.add_optimization_profile(profile)

    print(" Building engine (this may take 10-30 minutes)...")
    t0 = time.time()
    engine_bytes = builder.build_serialized_network(network, config)
    elapsed = time.time() - t0
    if engine_bytes is None:
        print(" ERROR: TRT engine build failed")
        return False
    with open(trt_path, "wb") as f:
        f.write(engine_bytes)
    size_gb = trt_path.stat().st_size / 1e9
    print(f" Saved: {trt_path} ({size_gb:.2f} GB) in {elapsed:.0f}s")
    return True


def main():
    """Run every export step in order and print a pass/fail summary.

    Returns 0 if all formats exported, 1 otherwise (used as process exit code).
    """
    print("=" * 60)
    print("AGORA PLANNER — FULL EXPORT PIPELINE")
    print("=" * 60)
    print(f"Source: {MERGED_DIR}")
    print(f"Output: {EXPORT_DIR}")
    print(f"Device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")

    results = {}
    results["safetensors"] = export_safetensors()
    results["pth"] = export_pth()
    results["onnx"] = export_onnx()
    results["trt_fp16"] = export_trt("fp16")
    results["trt_fp32"] = export_trt("fp32")

    print("\n" + "=" * 60)
    print("EXPORT RESULTS")
    print("=" * 60)
    for fmt, ok in results.items():
        status = "PASS" if ok else "FAIL"
        print(f" [{status}] {fmt}")

    # List all exports
    print(f"\nFiles in {EXPORT_DIR}:")
    for f in sorted(Path(EXPORT_DIR).rglob("*")):
        if f.is_file():
            size = f.stat().st_size
            if size > 1e9:
                print(f" {f.relative_to(EXPORT_DIR)}: {size / 1e9:.2f} GB")
            elif size > 1e6:
                print(f" {f.relative_to(EXPORT_DIR)}: {size / 1e6:.0f} MB")
            else:
                print(f" {f.relative_to(EXPORT_DIR)}: {size / 1e3:.0f} KB")

    all_pass = all(results.values())
    print(f"\nOVERALL: {'ALL PASS' if all_pass else 'SOME FAILED'}")
    return 0 if all_pass else 1


if __name__ == "__main__":
    import sys

    sys.exit(main())