project_agora / scripts /export_all.py
ilessio-aiflowlab's picture
[AGORA] Full export: pth + safetensors + ONNX + TRT fp16 + TRT fp32
12d70dc verified
#!/usr/bin/env python3
"""Export AGORA planner to all formats: safetensors, .pth, ONNX, TRT FP16, TRT FP32.

Usage: CUDA_VISIBLE_DEVICES=3 python scripts/export_all.py
"""
from __future__ import annotations
import gc
import os
import shutil
import time
from pathlib import Path
import torch
# Project identifier and artifact root from which every other path is derived.
PROJECT = "project_agora"
ARTIFACTS = "/mnt/artifacts-datai"

# Inputs: the trained model directory and its merged checkpoint.
MODEL_DIR = f"{ARTIFACTS}/models/{PROJECT}/agora-planner-v1"
MERGED_DIR = f"{MODEL_DIR}/merged"

# Output: every exported artifact is written below this directory, which is
# created (together with any missing parents) as soon as the module loads.
EXPORT_DIR = f"{ARTIFACTS}/exports/{PROJECT}"
Path(EXPORT_DIR).mkdir(parents=True, exist_ok=True)
def export_safetensors():
    """Check for the merged safetensors checkpoint and copy it to the exports.

    Looks for ``*.safetensors`` files directly inside ``MERGED_DIR`` (the
    original note says training already produced them, so this step only
    verifies and copies). Every weight file found is copied to
    ``EXPORT_DIR/safetensors`` together with whichever config / tokenizer
    sidecar files are present.

    Returns:
        True when at least one safetensors file was found and copied,
        False when ``MERGED_DIR`` contains none.
    """
    print("\n[1/5] SAFETENSORS CHECK")
    merged = Path(MERGED_DIR)
    st_files = list(merged.glob("*.safetensors"))
    if not st_files:
        print(" ERROR: No safetensors found in merged dir")
        return False

    total_size = sum(f.stat().st_size for f in st_files)
    print(f" Already exists: {len(st_files)} files, {total_size / 1e9:.2f} GB")

    # Copy the weight files into the export tree.
    dst = Path(EXPORT_DIR) / "safetensors"
    dst.mkdir(parents=True, exist_ok=True)
    for f in st_files:
        shutil.copy2(f, dst / f.name)

    # Sidecar files are optional individually; each is copied only if present.
    # The shard index is included so that a checkpoint split across several
    # *.safetensors files can still be resolved from the copied directory.
    sidecar_names = (
        "model.safetensors.index.json",
        "config.json",
        "tokenizer.json",
        "tokenizer_config.json",
        "generation_config.json",
        "special_tokens_map.json",
        "vocab.json",
        "merges.txt",
    )
    for name in sidecar_names:
        src = merged / name
        if src.exists():
            shutil.copy2(src, dst / name)

    print(f" Copied to {dst}")
    return True
def export_pth():
    """Export the merged model's state dict as a single ``.pth`` file.

    The target is ``EXPORT_DIR/agora_planner_v1.pth``. If that file already
    exists the step is skipped. Otherwise the model is loaded from
    ``MERGED_DIR`` with ``dtype=torch.float16`` and its ``state_dict()`` is
    saved with ``torch.save``.

    The file is first written under a temporary name and then renamed into
    place, so an interrupted save never leaves a truncated file that a later
    run would mistake for a finished export.

    Returns:
        True when the file already existed or was written successfully.
        Errors from loading or saving propagate to the caller.
    """
    print("\n[2/5] PTH EXPORT")
    from transformers import AutoModelForCausalLM

    pth_path = Path(EXPORT_DIR) / "agora_planner_v1.pth"
    if pth_path.exists():
        print(f" Already exists: {pth_path} ({pth_path.stat().st_size / 1e9:.2f} GB)")
        return True

    print(" Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        MERGED_DIR,
        dtype=torch.float16,
        trust_remote_code=True,
    )

    tmp_path = pth_path.with_name(pth_path.name + ".tmp")
    try:
        print(f" Saving to {pth_path}...")
        torch.save(model.state_dict(), tmp_path)
        # Publish the finished file under its final name in one step.
        os.replace(tmp_path, pth_path)
    finally:
        # A leftover temp file means the save failed part-way; discard it.
        if tmp_path.exists():
            tmp_path.unlink()
        # Release the model before the next export step, even on failure.
        del model
        gc.collect()
        torch.cuda.empty_cache()

    size_gb = pth_path.stat().st_size / 1e9
    print(f" Saved: {size_gb:.2f} GB")
    return True
def export_onnx():
    """Export the merged model to ONNX through optimum's ``main_export``.

    Output goes to ``EXPORT_DIR/onnx``. The step is skipped when that
    directory already holds at least one top-level ``*.onnx`` file.

    Returns:
        True when an export already existed or the export completed,
        False when the export (including the optimum import) raised; the
        exception message is printed.
    """
    print("\n[3/5] ONNX EXPORT")
    onnx_dir = Path(EXPORT_DIR) / "onnx"

    def _dir_bytes():
        # Combined size of every regular file below the ONNX directory.
        return sum(p.stat().st_size for p in onnx_dir.rglob("*") if p.is_file())

    already_exported = onnx_dir.exists() and any(onnx_dir.glob("*.onnx"))
    if already_exported:
        print(f" Already exists: {onnx_dir} ({_dir_bytes() / 1e9:.2f} GB total)")
        return True

    onnx_dir.mkdir(parents=True, exist_ok=True)
    print(f" Exporting with optimum to {onnx_dir}...")
    try:
        from optimum.exporters.onnx import main_export

        export_options = {
            "output": str(onnx_dir),
            "task": "text-generation",
            "opset": 18,
            "trust_remote_code": True,
        }
        main_export(MERGED_DIR, **export_options)

        exported = list(onnx_dir.rglob("*.onnx"))
        print(f" Exported: {len(exported)} ONNX files, {_dir_bytes() / 1e9:.2f} GB total")
    except Exception as err:
        print(f" ERROR: {err}")
        return False
    return True
def _profile_shapes(shape, opt_len=512, max_len=1024, max_batch=4):
    """Derive (min, opt, max) optimization-profile shapes for one input.

    Static dimensions are kept unchanged in all three shapes. A dynamic
    dimension (-1) becomes 1 in ``min``; ``max_len`` in ``max``; and in
    ``opt`` it becomes ``opt_len`` when it is the last dimension, else 1.
    When the input has at least two dimensions and the leading one is
    dynamic, that dimension is treated as the batch and capped at
    ``max_batch`` in ``max``. A static leading dimension is left alone,
    since a profile must not contradict a fixed size.
    """
    dims = tuple(shape)
    last = len(dims) - 1
    min_shape = tuple(1 if d == -1 else d for d in dims)
    opt_shape = tuple(
        (opt_len if idx == last else 1) if d == -1 else d
        for idx, d in enumerate(dims)
    )
    max_shape = tuple(max_len if d == -1 else d for d in dims)
    if len(dims) >= 2 and dims[0] == -1:
        max_shape = (max_batch,) + max_shape[1:]
    return min_shape, opt_shape, max_shape


def export_trt(precision: str = "fp16"):
    """Build a TensorRT engine from the exported ONNX model.

    The engine is written to
    ``EXPORT_DIR/agora_planner_v1_trt_{precision}.engine``. If that file
    already exists the step is skipped without needing TensorRT or the ONNX
    source. The ONNX source is the largest ``*.onnx`` file among
    ``EXPORT_DIR/onnx/*.onnx`` and ``EXPORT_DIR/agora_planner_v1.onnx``.

    Args:
        precision: ``"fp16"`` enables the FP16 builder flag when the platform
            reports fast FP16 support; any other value builds without it.
            The value is also used verbatim in the engine file name.

    Returns:
        True when the engine already existed or was built and saved,
        False when TensorRT cannot be imported, no ONNX model is found,
        ONNX parsing fails, or the engine build fails.
    """
    step = "4" if precision == "fp16" else "5"
    print(f"\n[{step}/5] TENSORRT {precision.upper()} EXPORT")

    trt_path = Path(EXPORT_DIR) / f"agora_planner_v1_trt_{precision}.engine"
    if trt_path.exists():
        print(f" Already exists: {trt_path} ({trt_path.stat().st_size / 1e9:.2f} GB)")
        return True

    # Report a missing TensorRT as a failed step, like the ONNX step does for
    # its own dependency, so the remaining pipeline and summary still run.
    try:
        import tensorrt as trt
    except ImportError as e:
        print(f" ERROR: TensorRT is not available: {e}")
        return False

    # Find ONNX model (optimum exports to onnx/ directory)
    onnx_dir = Path(EXPORT_DIR) / "onnx"
    onnx_candidates = list(onnx_dir.glob("*.onnx")) if onnx_dir.exists() else []
    # Also check flat file
    flat_onnx = Path(EXPORT_DIR) / "agora_planner_v1.onnx"
    if flat_onnx.exists():
        onnx_candidates.append(flat_onnx)
    if not onnx_candidates:
        print(f" ERROR: No ONNX model found in {onnx_dir} or {EXPORT_DIR}")
        return False
    # Use the largest ONNX file (the main model, not decoder subgraph)
    onnx_path = max(onnx_candidates, key=lambda p: p.stat().st_size)

    print(f" ONNX source: {onnx_path} ({onnx_path.stat().st_size / 1e9:.2f} GB)")
    print(f" Building TRT {precision.upper()} engine...")
    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = trt.OnnxParser(network, logger)
    # Prefer TensorRT's native InstanceNormalization layer over the plugin.
    parser.set_flag(trt.OnnxParserFlag.NATIVE_INSTANCENORM)

    print(" Parsing ONNX model...")
    # Parsing from the file path (rather than from bytes) lets the parser
    # see the model's directory -- presumably needed for external weight
    # files stored next to the .onnx; confirm against the TensorRT docs.
    if not parser.parse_from_file(str(onnx_path)):
        for i in range(parser.num_errors):
            print(f" PARSE ERROR: {parser.get_error(i)}")
        return False

    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 4 << 30)  # 4 GB
    if precision == "fp16":
        if builder.platform_has_fast_fp16:
            config.set_flag(trt.BuilderFlag.FP16)
            print(" FP16 enabled")
        else:
            print(" WARNING: FP16 not supported, falling back to FP32")

    # Set optimization profiles for dynamic shapes
    profile = builder.create_optimization_profile()
    for i in range(network.num_inputs):
        inp = network.get_input(i)
        name = inp.name
        min_shape, opt_shape, max_shape = _profile_shapes(inp.shape)
        profile.set_shape(name, min_shape, opt_shape, max_shape)
        print(f" Input '{name}': min={min_shape} opt={opt_shape} max={max_shape}")
    config.add_optimization_profile(profile)

    print(" Building engine (this may take 10-30 minutes)...")
    t0 = time.time()
    engine_bytes = builder.build_serialized_network(network, config)
    elapsed = time.time() - t0
    if engine_bytes is None:
        print(" ERROR: TRT engine build failed")
        return False

    # Write under a temporary name and rename into place, so an interrupted
    # write never leaves a partial engine that a later run would treat as
    # "Already exists".
    tmp_path = trt_path.with_name(trt_path.name + ".tmp")
    try:
        with open(tmp_path, "wb") as f:
            f.write(engine_bytes)
        os.replace(tmp_path, trt_path)
    finally:
        if tmp_path.exists():
            tmp_path.unlink()

    size_gb = trt_path.stat().st_size / 1e9
    print(f" Saved: {trt_path} ({size_gb:.2f} GB) in {elapsed:.0f}s")
    return True
def main():
    """Run every export step in order and print a pass/fail summary.

    The steps run as safetensors, pth, onnx, TRT fp16, TRT fp32. After the
    per-format results, every file below ``EXPORT_DIR`` is listed with a
    human-readable size.

    Returns:
        0 when every step returned a truthy result, 1 otherwise.
    """
    rule = "=" * 60
    print(rule)
    print("AGORA PLANNER — FULL EXPORT PIPELINE")
    print(rule)
    print(f"Source: {MERGED_DIR}")
    print(f"Output: {EXPORT_DIR}")
    if torch.cuda.is_available():
        device_name = torch.cuda.get_device_name(0)
    else:
        device_name = "CPU"
    print(f"Device: {device_name}")

    # The dict literal evaluates its values top to bottom, so the steps run
    # in exactly this order.
    results = {
        "safetensors": export_safetensors(),
        "pth": export_pth(),
        "onnx": export_onnx(),
        "trt_fp16": export_trt("fp16"),
        "trt_fp32": export_trt("fp32"),
    }

    print("\n" + rule)
    print("EXPORT RESULTS")
    print(rule)
    for fmt, ok in results.items():
        if ok:
            status = "PASS"
        else:
            status = "FAIL"
        print(f" [{status}] {fmt}")

    # List all exports
    print(f"\nFiles in {EXPORT_DIR}:")
    for entry in sorted(Path(EXPORT_DIR).rglob("*")):
        if not entry.is_file():
            continue
        nbytes = entry.stat().st_size
        rel = entry.relative_to(EXPORT_DIR)
        if nbytes > 1e9:
            human = f"{nbytes / 1e9:.2f} GB"
        elif nbytes > 1e6:
            human = f"{nbytes / 1e6:.0f} MB"
        else:
            human = f"{nbytes / 1e3:.0f} KB"
        print(f" {rel}: {human}")

    all_pass = all(results.values())
    verdict = "ALL PASS" if all_pass else "SOME FAILED"
    print(f"\nOVERALL: {verdict}")
    if all_pass:
        return 0
    return 1
# Script entry point: run the pipeline and use its return value as the
# process exit status.
if __name__ == "__main__":
    raise SystemExit(main())