File size: 8,432 Bytes
12d70dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
#!/usr/bin/env python3
"""Export AGORA planner to all formats: safetensors, ONNX, TRT FP16, TRT FP32.

Usage: CUDA_VISIBLE_DEVICES=3 python scripts/export_all.py
"""

from __future__ import annotations

import gc
import os
import shutil
import time
from pathlib import Path

import torch

# Project identity and the artifact root everything hangs off.
PROJECT = "project_agora"
ARTIFACTS = "/mnt/artifacts-datai"
# Trained model tree (the merged checkpoint lives in its "merged" subdir)
# and the destination tree for every exported format.
MODEL_DIR = "/".join((ARTIFACTS, "models", PROJECT, "agora-planner-v1"))
EXPORT_DIR = "/".join((ARTIFACTS, "exports", PROJECT))
MERGED_DIR = MODEL_DIR + "/merged"

# Ensure the export root exists as soon as the module is imported.
Path(EXPORT_DIR).mkdir(parents=True, exist_ok=True)


def _is_unchanged_copy(src: Path, dst: Path) -> bool:
    """Return True if *dst* exists with the same size and mtime as *src*.

    ``shutil.copy2`` carries the source mtime over to the copy, so a matching
    size and nanosecond mtime means *dst* is a copy of the current *src*. Any
    mismatch (including filesystems with coarser timestamp precision) simply
    leads to a re-copy, which is always safe.
    """
    try:
        dst_stat = dst.stat()
    except FileNotFoundError:
        return False
    src_stat = src.stat()
    return (dst_stat.st_size == src_stat.st_size
            and dst_stat.st_mtime_ns == src_stat.st_mtime_ns)


def export_safetensors():
    """Mirror the merged safetensors checkpoint into ``EXPORT_DIR/safetensors``.

    The merged model is expected to have been written by training already; this
    step verifies that ``*.safetensors`` files exist in MERGED_DIR and copies
    them, plus whichever config/tokenizer side files are present, into the
    export tree. Files whose exported copy already matches the source (same
    size and mtime) are skipped, so re-running the pipeline does not re-copy
    multi-GB shards — consistent with the other steps skipping finished work.

    Returns:
        True if safetensors files were found and mirrored, False if MERGED_DIR
        contains none.
    """
    print("\n[1/5] SAFETENSORS CHECK")
    merged = Path(MERGED_DIR)
    st_files = list(merged.glob("*.safetensors"))
    if not st_files:
        print("  ERROR: No safetensors found in merged dir")
        return False

    total_size = sum(f.stat().st_size for f in st_files)
    print(f"  Already exists: {len(st_files)} files, {total_size / 1e9:.2f} GB")

    dst = Path(EXPORT_DIR) / "safetensors"
    # parents=True: EXPORT_DIR is created at import time but may have been
    # removed since.
    dst.mkdir(parents=True, exist_ok=True)

    # Config + tokenizer side files; each one is copied only if it exists.
    side_names = (
        "config.json", "tokenizer.json", "tokenizer_config.json",
        "generation_config.json", "special_tokens_map.json",
        "vocab.json", "merges.txt",
        "added_tokens.json", "tokenizer.model", "chat_template.jinja",
    )
    side_files = [merged / name for name in side_names if (merged / name).exists()]

    skipped = 0
    for src in st_files + side_files:
        target = dst / src.name
        if _is_unchanged_copy(src, target):
            skipped += 1
            continue
        shutil.copy2(src, target)
    print(f"  Copied to {dst}")
    if skipped:
        print(f"  ({skipped} file(s) already up to date, not re-copied)")
    return True


def export_pth():
    """Export the merged model's fp16 state dict as a single ``.pth`` file.

    The state dict is first saved to a ``.partial`` sibling and then atomically
    renamed into place. An interrupted save therefore never leaves a truncated
    ``agora_planner_v1.pth`` that a later run would skip as "Already exists".

    Returns:
        True if the file already exists or was written, False if transformers
        cannot be imported or loading/saving the model fails.
    """
    print("\n[2/5] PTH EXPORT")

    pth_path = Path(EXPORT_DIR) / "agora_planner_v1.pth"
    if pth_path.exists():
        print(f"  Already exists: {pth_path} ({pth_path.stat().st_size / 1e9:.2f} GB)")
        return True

    # Imported only when actually needed, so the skip path above works even
    # without transformers installed.
    try:
        from transformers import AutoModelForCausalLM
    except ImportError as e:
        print(f"  ERROR: transformers is not importable: {e}")
        return False

    tmp_path = pth_path.with_name(pth_path.name + ".partial")
    model = None
    try:
        print("  Loading model...")
        # NOTE(review): `dtype=` is the newer transformers spelling; older
        # releases expect `torch_dtype=` — confirm against the pinned version.
        model = AutoModelForCausalLM.from_pretrained(
            MERGED_DIR,
            dtype=torch.float16,
            trust_remote_code=True,
        )
        print(f"  Saving to {pth_path}...")
        torch.save(model.state_dict(), tmp_path)
        os.replace(tmp_path, pth_path)  # atomic on the same filesystem
    except Exception as e:
        print(f"  ERROR: {e}")
        tmp_path.unlink(missing_ok=True)
        return False
    finally:
        # Drop the model reference and release memory on success and failure.
        model = None
        gc.collect()
        torch.cuda.empty_cache()

    size_gb = pth_path.stat().st_size / 1e9
    print(f"  Saved: {size_gb:.2f} GB")
    return True


def export_onnx():
    """Export the merged model to ONNX with optimum's ``main_export``.

    The export is written into a sibling staging directory (``onnx.partial``).
    It is moved to ``EXPORT_DIR/onnx`` only after it has produced a top-level
    ``.onnx`` file. A failed or killed export (e.g. OOM during a large-model
    export) therefore never leaves behind a directory that a later run, or
    ``export_trt``, would mistake for a finished export.

    Returns:
        True if an ONNX export already exists or was produced, False on failure.
    """
    import traceback

    print("\n[3/5] ONNX EXPORT")

    onnx_dir = Path(EXPORT_DIR) / "onnx"
    if onnx_dir.exists() and list(onnx_dir.glob("*.onnx")):
        total = sum(f.stat().st_size for f in onnx_dir.rglob("*") if f.is_file())
        print(f"  Already exists: {onnx_dir} ({total / 1e9:.2f} GB total)")
        return True

    staging = onnx_dir.with_name(onnx_dir.name + ".partial")
    if staging.exists():
        shutil.rmtree(staging)  # leftover from an earlier interrupted run
    staging.mkdir(parents=True)

    print(f"  Exporting with optimum to {onnx_dir}...")
    try:
        from optimum.exporters.onnx import main_export

        main_export(
            MERGED_DIR,
            output=str(staging),
            task="text-generation",
            opset=18,
            trust_remote_code=True,
        )
    except Exception as e:
        traceback.print_exc()
        print(f"  ERROR: {e}")
        shutil.rmtree(staging, ignore_errors=True)
        return False

    # The skip check above and export_trt both look only at top-level *.onnx.
    if not any(staging.glob("*.onnx")):
        print(f"  ERROR: optimum produced no top-level .onnx file in {staging}")
        shutil.rmtree(staging, ignore_errors=True)
        return False

    if onnx_dir.exists():
        # Reaching this point means onnx_dir held no top-level .onnx file, so it
        # is an empty or partial leftover; replace it with the finished export.
        shutil.rmtree(onnx_dir)
    staging.rename(onnx_dir)

    onnx_files = list(onnx_dir.rglob("*.onnx"))
    total = sum(f.stat().st_size for f in onnx_dir.rglob("*") if f.is_file())
    print(f"  Exported: {len(onnx_files)} ONNX files, {total / 1e9:.2f} GB total")
    return True


def _find_onnx_source():
    """Return the ONNX file to build the TensorRT engine from, or None.

    Candidates are the top-level ``*.onnx`` files in ``EXPORT_DIR/onnx`` plus a
    flat ``EXPORT_DIR/agora_planner_v1.onnx`` if present; the largest is chosen.
    """
    onnx_dir = Path(EXPORT_DIR) / "onnx"
    candidates = list(onnx_dir.glob("*.onnx")) if onnx_dir.exists() else []
    flat_onnx = Path(EXPORT_DIR) / "agora_planner_v1.onnx"
    if flat_onnx.exists():
        candidates.append(flat_onnx)
    # Heuristic: the largest .onnx is taken as the main model rather than an
    # auxiliary subgraph. NOTE(review): with external-data exports the .onnx
    # files contain only the graph — confirm this still picks the right one.
    return max(candidates, key=lambda p: p.stat().st_size, default=None)


def _trt_profile_shapes(shape, opt_len=512, max_len=1024, max_batch=4):
    """Return ``(min, opt, max)`` optimization-profile shapes for one input.

    Dynamic dims (-1) become 1 in min/opt and ``max_len`` in max. A dynamic
    last dim (presumably the sequence axis — confirm against the exported
    graph) gets ``opt_len`` in the opt shape. For inputs of rank >= 2 whose
    leading dim is dynamic, that dim is treated as batch: 1 / 1 / ``max_batch``.
    Static dims are always kept unchanged, because TensorRT requires min, opt
    and max to equal a static dimension.
    """
    dims = [int(d) for d in shape]
    min_shape = [1 if d == -1 else d for d in dims]
    opt_shape = list(min_shape)
    max_shape = [max_len if d == -1 else d for d in dims]
    if dims and dims[-1] == -1:
        opt_shape[-1] = opt_len
    if len(dims) >= 2 and dims[0] == -1:
        min_shape[0] = 1
        opt_shape[0] = 1
        max_shape[0] = max_batch
    return tuple(min_shape), tuple(opt_shape), tuple(max_shape)


def export_trt(precision: str = "fp16"):
    """Build a TensorRT engine from the exported ONNX model.

    Args:
        precision: ``"fp16"`` or ``"fp32"``. With ``"fp16"`` the FP16 builder
            flag is set when the platform reports fast FP16; otherwise a
            warning is printed and the engine is built in FP32 (still saved
            under the ``_fp16`` name).

    Returns:
        True if the engine already exists or was built, False when the
        precision is unsupported, TensorRT cannot be imported, no ONNX model
        is found, or parsing/building fails.
    """
    if precision not in ("fp16", "fp32"):
        print(f"  ERROR: unsupported TRT precision {precision!r} (expected 'fp16' or 'fp32')")
        return False

    step = "4" if precision == "fp16" else "5"
    print(f"\n[{step}/5] TENSORRT {precision.upper()} EXPORT")

    # Check for a finished engine first, so a re-run reports it even when the
    # ONNX source was cleaned up or TensorRT is not installed.
    trt_path = Path(EXPORT_DIR) / f"agora_planner_v1_trt_{precision}.engine"
    if trt_path.exists():
        print(f"  Already exists: {trt_path} ({trt_path.stat().st_size / 1e9:.2f} GB)")
        return True

    try:
        import tensorrt as trt
    except ImportError as e:
        print(f"  ERROR: TensorRT is not importable: {e}")
        return False

    onnx_path = _find_onnx_source()
    if onnx_path is None:
        print(f"  ERROR: No ONNX model found in {Path(EXPORT_DIR) / 'onnx'} or {EXPORT_DIR}")
        return False

    print(f"  ONNX source: {onnx_path} ({onnx_path.stat().st_size / 1e9:.2f} GB)")
    print(f"  Building TRT {precision.upper()} engine...")

    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    # EXPLICIT_BATCH is deprecated and ignored in TensorRT 10 (every network is
    # explicit-batch); it is kept here for older TensorRT releases.
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = trt.OnnxParser(network, logger)
    # Use TensorRT's native InstanceNormalization layer instead of the plugin.
    # This flag is unrelated to external-data weights.
    parser.set_flag(trt.OnnxParserFlag.NATIVE_INSTANCENORM)

    print("  Parsing ONNX model...")
    # parse_from_file knows the model's path, so external-data weight files
    # stored next to the .onnx can be resolved.
    if not parser.parse_from_file(str(onnx_path)):
        for idx in range(parser.num_errors):
            print(f"  PARSE ERROR: {parser.get_error(idx)}")
        return False

    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 4 << 30)  # 4 GB

    if precision == "fp16":
        if builder.platform_has_fast_fp16:
            config.set_flag(trt.BuilderFlag.FP16)
            print("  FP16 enabled")
        else:
            print("  WARNING: FP16 not supported, falling back to FP32")

    # One optimization profile covering every (possibly dynamic) input.
    profile = builder.create_optimization_profile()
    for idx in range(network.num_inputs):
        inp = network.get_input(idx)
        min_shape, opt_shape, max_shape = _trt_profile_shapes(inp.shape)
        profile.set_shape(inp.name, min_shape, opt_shape, max_shape)
        print(f"  Input '{inp.name}': min={min_shape} opt={opt_shape} max={max_shape}")
    config.add_optimization_profile(profile)

    print("  Building engine (this may take 10-30 minutes)...")
    t0 = time.time()
    engine_bytes = builder.build_serialized_network(network, config)
    elapsed = time.time() - t0

    if engine_bytes is None:
        print("  ERROR: TRT engine build failed")
        return False

    # Write to a temporary sibling and rename, so an interrupted write never
    # leaves a truncated engine that the next run would skip as existing.
    tmp_path = trt_path.with_name(trt_path.name + ".partial")
    with open(tmp_path, "wb") as f:
        f.write(engine_bytes)
    os.replace(tmp_path, trt_path)

    size_gb = trt_path.stat().st_size / 1e9
    print(f"  Saved: {trt_path} ({size_gb:.2f} GB) in {elapsed:.0f}s")
    return True


def main():
    """Run the full export pipeline and return a process exit code.

    Steps run in order: safetensors copy, .pth, ONNX, TensorRT FP16, TensorRT
    FP32. If one step raises an unexpected exception, its traceback is printed
    and the step is recorded as a failure. The remaining steps and the results
    summary still run.

    Returns:
        0 if every step reported success, 1 otherwise.
    """
    import traceback

    print("=" * 60)
    print("AGORA PLANNER — FULL EXPORT PIPELINE")
    print("=" * 60)
    print(f"Source:  {MERGED_DIR}")
    print(f"Output:  {EXPORT_DIR}")
    print(f"Device:  {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")

    steps = (
        ("safetensors", export_safetensors),
        ("pth", export_pth),
        ("onnx", export_onnx),
        ("trt_fp16", lambda: export_trt("fp16")),
        ("trt_fp32", lambda: export_trt("fp32")),
    )
    results = {}
    for fmt, step in steps:
        try:
            results[fmt] = step()
        except Exception:
            # Record the failure and keep going; later steps and the summary
            # are still useful.
            traceback.print_exc()
            results[fmt] = False

    print("\n" + "=" * 60)
    print("EXPORT RESULTS")
    print("=" * 60)
    for fmt, ok in results.items():
        status = "PASS" if ok else "FAIL"
        print(f"  [{status}] {fmt}")

    # List every exported file with a human-readable size.
    print(f"\nFiles in {EXPORT_DIR}:")
    for f in sorted(Path(EXPORT_DIR).rglob("*")):
        if not f.is_file():
            continue
        size = f.stat().st_size
        if size > 1e9:
            human = f"{size / 1e9:.2f} GB"
        elif size > 1e6:
            human = f"{size / 1e6:.0f} MB"
        else:
            human = f"{size / 1e3:.0f} KB"
        print(f"  {f.relative_to(EXPORT_DIR)}: {human}")

    all_pass = all(results.values())
    print(f"\nOVERALL: {'ALL PASS' if all_pass else 'SOME FAILED'}")
    return 0 if all_pass else 1


if __name__ == "__main__":
    # Propagate main()'s return value (0 = all exports passed) as the exit status.
    raise SystemExit(main())