"""Export the AGORA planner to every supported format.

Formats: safetensors, PTH, ONNX, TensorRT FP16, TensorRT FP32.

Usage: CUDA_VISIBLE_DEVICES=3 python scripts/export_all.py
"""
|
|
| from __future__ import annotations |
|
|
| import gc |
| import os |
| import shutil |
| import time |
| from pathlib import Path |
|
|
| import torch |
|
|
# Shared artifact store and the project namespace used inside it.
ARTIFACTS = "/mnt/artifacts-datai"
PROJECT = "project_agora"

# Inputs: the trained model directory and its merged-weights subdirectory.
MODEL_DIR = f"{ARTIFACTS}/models/{PROJECT}/agora-planner-v1"
MERGED_DIR = f"{MODEL_DIR}/merged"

# Output root for every exported format.
EXPORT_DIR = f"{ARTIFACTS}/exports/{PROJECT}"

# Created at import time: the exporters below write into EXPORT_DIR (or create
# single-level subdirectories of it) without creating the root themselves.
Path(EXPORT_DIR).mkdir(parents=True, exist_ok=True)
|
|
|
|
def export_safetensors():
    """Check that merged safetensors shards exist and mirror them for export.

    No conversion happens here: the ``*.safetensors`` files found directly in
    ``MERGED_DIR`` are copied to ``EXPORT_DIR/safetensors``, together with
    whichever of the known config/tokenizer files are present next to them.

    Returns:
        True when at least one shard was found and copied, False when the
        merged directory contains no ``*.safetensors`` file.
    """
    print("\n[1/5] SAFETENSORS CHECK")
    merged = Path(MERGED_DIR)
    shards = list(merged.glob("*.safetensors"))
    if not shards:
        print(" ERROR: No safetensors found in merged dir")
        return False

    shard_bytes = sum(shard.stat().st_size for shard in shards)
    print(f" Already exists: {len(shards)} files, {shard_bytes / 1e9:.2f} GB")

    target = Path(EXPORT_DIR) / "safetensors"
    target.mkdir(exist_ok=True)
    for shard in shards:
        shutil.copy2(shard, target / shard.name)

    # Optional sidecar files; any that are missing are silently skipped.
    sidecars = (
        "config.json",
        "tokenizer.json",
        "tokenizer_config.json",
        "generation_config.json",
        "special_tokens_map.json",
        "vocab.json",
        "merges.txt",
    )
    for filename in sidecars:
        candidate = merged / filename
        if candidate.exists():
            shutil.copy2(candidate, target / filename)

    print(f" Copied to {target}")
    return True
|
|
|
|
def export_pth():
    """Export the merged model's state dict as a single ``.pth`` file.

    The model is loaded from ``MERGED_DIR`` in float16 and its ``state_dict()``
    is saved to ``EXPORT_DIR/agora_planner_v1.pth``. If that file already
    exists the export is skipped.

    The checkpoint is first written to a sibling ``.partial`` file and renamed
    into place only after ``torch.save`` returns, so an interrupted or failed
    save never leaves a truncated file that a later run would mistake for a
    finished export.

    Returns:
        True when the file already existed or was written successfully.

    Raises:
        Any exception raised while importing transformers, loading the model,
        or saving the checkpoint is propagated to the caller.
    """
    print("\n[2/5] PTH EXPORT")
    from transformers import AutoModelForCausalLM

    pth_path = Path(EXPORT_DIR) / "agora_planner_v1.pth"
    if pth_path.exists():
        print(f" Already exists: {pth_path} ({pth_path.stat().st_size / 1e9:.2f} GB)")
        return True

    print(" Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        MERGED_DIR,
        dtype=torch.float16,
        trust_remote_code=True,
    )

    tmp_path = pth_path.with_name(pth_path.name + ".partial")
    try:
        print(f" Saving to {pth_path}...")
        torch.save(model.state_dict(), tmp_path)
        # Rename within the same directory: the final name only ever refers
        # to a completely written checkpoint.
        tmp_path.replace(pth_path)
    finally:
        # On failure, drop the partial file; on success it has already been
        # renamed and no longer exists under the temporary name.
        if tmp_path.exists():
            tmp_path.unlink()
        # Release the model before the next export step runs.
        del model
        gc.collect()
        torch.cuda.empty_cache()

    size_gb = pth_path.stat().st_size / 1e9
    print(f" Saved: {size_gb:.2f} GB")
    return True
|
|
|
|
def export_onnx():
    """Export the merged model to ONNX through optimum's ``main_export``.

    Output goes to ``EXPORT_DIR/onnx``. The export is skipped when that
    directory already holds at least one top-level ``*.onnx`` file.

    Returns:
        True when an export already existed or ``main_export`` completed,
        False when importing optimum or running the export raised; the
        exception message is printed rather than propagated.
    """
    print("\n[3/5] ONNX EXPORT")

    out_dir = Path(EXPORT_DIR) / "onnx"

    def dir_bytes():
        # Total size of every file under the ONNX directory, recursively.
        return sum(p.stat().st_size for p in out_dir.rglob("*") if p.is_file())

    if out_dir.exists() and any(out_dir.glob("*.onnx")):
        print(f" Already exists: {out_dir} ({dir_bytes() / 1e9:.2f} GB total)")
        return True

    out_dir.mkdir(parents=True, exist_ok=True)

    print(f" Exporting with optimum to {out_dir}...")
    try:
        from optimum.exporters.onnx import main_export

        main_export(
            MERGED_DIR,
            output=str(out_dir),
            task="text-generation",
            opset=18,
            trust_remote_code=True,
        )

        graph_count = len(list(out_dir.rglob("*.onnx")))
        exported_bytes = dir_bytes()
        print(f" Exported: {graph_count} ONNX files, {exported_bytes / 1e9:.2f} GB total")
    except Exception as exc:
        print(f" ERROR: {exc}")
        return False
    return True
|
|
|
|
def _profile_shapes(shape, max_batch=4, opt_len=512, max_len=1024):
    """Return ``(min, opt, max)`` optimization-profile shapes for one input.

    Only dynamic dimensions (``-1``) are substituted; static dimensions are
    carried unchanged into all three shapes so the profile stays consistent
    with the network definition.

    * min: every dynamic dim becomes 1.
    * opt: a dynamic last dim becomes ``opt_len``, other dynamic dims become 1.
    * max: dynamic dims become ``max_len``, except that for inputs of rank >= 2
      a dynamic leading dim is capped at ``max_batch`` instead.
    """
    dims = tuple(shape)
    last_axis = len(dims) - 1

    min_shape = tuple(1 if d == -1 else d for d in dims)
    opt_shape = tuple(
        (opt_len if axis == last_axis else 1) if d == -1 else d
        for axis, d in enumerate(dims)
    )
    max_shape = tuple(max_len if d == -1 else d for d in dims)

    # Rank >= 2: the leading axis gets the smaller cap, but only when it is
    # actually dynamic; a static leading dim must not be overridden.
    if len(dims) >= 2 and dims[0] == -1:
        max_shape = (max_batch,) + max_shape[1:]

    return min_shape, opt_shape, max_shape


def export_trt(precision: str = "fp16"):
    """Build a serialized TensorRT engine from the exported ONNX model.

    The ONNX source is the largest file among the top-level ``*.onnx`` files
    in ``EXPORT_DIR/onnx`` and the optional flat
    ``EXPORT_DIR/agora_planner_v1.onnx``. The engine is written to
    ``EXPORT_DIR/agora_planner_v1_trt_<precision>.engine``; if that file
    already exists the build is skipped.

    Args:
        precision: ``"fp16"`` requests the FP16 builder flag (applied only when
            the platform reports fast FP16 support). Any other value builds
            without a precision flag and is reported as step 5.

    Returns:
        True when the engine already existed or was built and written.
        False when no ONNX model was found, the ONNX parse failed, or the
        builder returned no engine.

    Raises:
        ImportError: if ``tensorrt`` is not installed (imported before any
            other check).
    """
    step = "4" if precision == "fp16" else "5"
    print(f"\n[{step}/5] TENSORRT {precision.upper()} EXPORT")
    import tensorrt as trt

    onnx_dir = Path(EXPORT_DIR) / "onnx"
    onnx_candidates = list(onnx_dir.glob("*.onnx")) if onnx_dir.exists() else []

    flat_onnx = Path(EXPORT_DIR) / "agora_planner_v1.onnx"
    if flat_onnx.exists():
        onnx_candidates.append(flat_onnx)

    if not onnx_candidates:
        print(f" ERROR: No ONNX model found in {onnx_dir} or {EXPORT_DIR}")
        return False

    # Several graphs may be present; the largest file is used as the source.
    onnx_path = max(onnx_candidates, key=lambda p: p.stat().st_size)
    trt_path = Path(EXPORT_DIR) / f"agora_planner_v1_trt_{precision}.engine"

    if trt_path.exists():
        print(f" Already exists: {trt_path} ({trt_path.stat().st_size / 1e9:.2f} GB)")
        return True

    print(f" ONNX source: {onnx_path} ({onnx_path.stat().st_size / 1e9:.2f} GB)")
    print(f" Building TRT {precision.upper()} engine...")

    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    parser = trt.OnnxParser(network, logger)
    parser.set_flag(trt.OnnxParserFlag.NATIVE_INSTANCENORM)

    print(" Parsing ONNX model...")
    # Parse by path (not from bytes) so the parser can locate files stored
    # next to the .onnx graph.
    if not parser.parse_from_file(str(onnx_path)):
        for err_idx in range(parser.num_errors):
            print(f" PARSE ERROR: {parser.get_error(err_idx)}")
        return False

    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 4 << 30)  # 4 GiB

    if precision == "fp16":
        if builder.platform_has_fast_fp16:
            config.set_flag(trt.BuilderFlag.FP16)
            print(" FP16 enabled")
        else:
            # NOTE: the engine is still written under the "fp16" file name
            # even though no FP16 flag was set.
            print(" WARNING: FP16 not supported, falling back to FP32")

    # One optimization profile covering every network input.
    profile = builder.create_optimization_profile()
    for input_idx in range(network.num_inputs):
        inp = network.get_input(input_idx)
        name = inp.name
        min_shape, opt_shape, max_shape = _profile_shapes(inp.shape)
        profile.set_shape(name, min_shape, opt_shape, max_shape)
        print(f" Input '{name}': min={min_shape} opt={opt_shape} max={max_shape}")
    config.add_optimization_profile(profile)

    print(" Building engine (this may take 10-30 minutes)...")
    t0 = time.time()
    engine_bytes = builder.build_serialized_network(network, config)
    elapsed = time.time() - t0

    if engine_bytes is None:
        print(" ERROR: TRT engine build failed")
        return False

    with open(trt_path, "wb") as f:
        f.write(engine_bytes)

    size_gb = trt_path.stat().st_size / 1e9
    print(f" Saved: {trt_path} ({size_gb:.2f} GB) in {elapsed:.0f}s")
    return True
|
|
|
|
def main():
    """Run every export step in order, print a summary, and return an exit code.

    Steps run in this order: safetensors, pth, onnx, TensorRT fp16, TensorRT
    fp32. After the per-step PASS/FAIL table, every file under ``EXPORT_DIR``
    is listed with a human-readable size.

    Returns:
        0 when every step reported success, 1 otherwise.
    """
    banner = "=" * 60
    print(banner)
    print("AGORA PLANNER — FULL EXPORT PIPELINE")
    print(banner)
    print(f"Source: {MERGED_DIR}")
    print(f"Output: {EXPORT_DIR}")
    if torch.cuda.is_available():
        device = torch.cuda.get_device_name(0)
    else:
        device = "CPU"
    print(f"Device: {device}")

    # Dict displays evaluate their values in source order, so the steps run
    # in the order listed here.
    results = {
        "safetensors": export_safetensors(),
        "pth": export_pth(),
        "onnx": export_onnx(),
        "trt_fp16": export_trt("fp16"),
        "trt_fp32": export_trt("fp32"),
    }

    print("\n" + banner)
    print("EXPORT RESULTS")
    print(banner)
    for fmt, ok in results.items():
        print(f" [{'PASS' if ok else 'FAIL'}] {fmt}")

    print(f"\nFiles in {EXPORT_DIR}:")
    for entry in sorted(Path(EXPORT_DIR).rglob("*")):
        if not entry.is_file():
            continue
        size = entry.stat().st_size
        rel = entry.relative_to(EXPORT_DIR)
        if size > 1e9:
            human = f"{size / 1e9:.2f} GB"
        elif size > 1e6:
            human = f"{size / 1e6:.0f} MB"
        else:
            human = f"{size / 1e3:.0f} KB"
        print(f" {rel}: {human}")

    if all(results.values()):
        print("\nOVERALL: ALL PASS")
        return 0
    print("\nOVERALL: SOME FAILED")
    return 1
|
|
|
|
# Script entry point: exit with main()'s return code (0 = all exports passed).
if __name__ == "__main__":
    raise SystemExit(main())
|
|