File size: 2,319 Bytes
afaf90f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
import torch
import time
from utilities import Engine

def export_trt(trt_path=None, onnx_path=None, use_fp16=True):
    option = input("Choose an option:\n1. Convert a single ONNX file\n2. Convert all ONNX files in a directory\nEnter your choice (1 or 2): ")

    if option == '1':
        onnx_path = input("Enter the path to the ONNX model (e.g ./realesrgan.onnx): ")
        onnx_files = [onnx_path]
        trt_dir = input("Enter the path to save the TensorRT engine (e.g ./trt_engine/): ")
    elif option == '2':
        onnx_dir = input("Enter the directory path containing ONNX models (e.g ./onnx_models/): ")
        onnx_files = [os.path.join(onnx_dir, file) for file in os.listdir(onnx_dir) if file.endswith('.onnx')]
        if not onnx_files:
            raise ValueError(f"No .onnx files found in directory: {onnx_dir}")
        trt_dir = input("Enter the directory path to save the TensorRT engines (e.g ./trt_engine/): ")
    else:
        raise ValueError("Invalid option. Please choose either 1 or 2.")

    # Check if trt_dir already exists as a directory
    if not os.path.exists(trt_dir):
        os.makedirs(trt_dir)
        
    #os.makedirs(trt_dir, exist_ok=True)
    total_files = len(onnx_files)
    for index, onnx_path in enumerate(onnx_files):
        engine = Engine(trt_path)

        torch.cuda.empty_cache()
        base_name = os.path.splitext(os.path.basename(onnx_path))[0]
        trt_path = os.path.join(trt_dir, f"{base_name}.engine")

        print(f"Converting {onnx_path} to {trt_path}")

        s = time.time()

        # Initialize Engine with trt_path and clear CUDA cache
        engine = Engine(trt_path)
        torch.cuda.empty_cache()

        ret = engine.build(
        onnx_path,
        use_fp16,
        enable_preview=True,
        input_profile=[
            {"input": [(1,3,256,256), (1,3,512,512), (1,3,1280,1280)]}, # any sizes from 256x256 to 1280x1280
        ],
        )

        e = time.time()
        print(f"Time taken to build: {(e-s)} seconds")
        if index < total_files - 1:
            # Delay for 10 seconds
            print("Delaying for 10 seconds...")
            time.sleep(10)
            print("Resuming operations after delay...")

    return

export_trt()