File size: 2,481 Bytes
afaf90f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# Download the upscale models & place them inside models/upscaler_models.
# Edit the names/paths below to match your setup.

import torch
import folder_paths
from spandrel import ModelLoader, ImageModelDescriptor

# Checkpoint to convert, and where the resulting ONNX graph is written.
model_name = "4xNomos2_otf_esrgan.pth"
onnx_save_path = "./4xNomos2_otf_esrgan.onnx"

# Resolve the checkpoint through ComfyUI's model registry, load it with
# spandrel, then move the bare torch module to the GPU in inference mode.
model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name)
descriptor = ModelLoader().load_from_file(model_path)
model = descriptor.model.eval().cuda()

# Check dynamic shapes for esrgan 4x model
def supports_dynamic_shapes_esrgan(model, scale=4, device=None):
    """Probe *model* with several input sizes to verify dynamic-shape support.

    Runs the model on a handful of NCHW inputs of varying spatial size and
    checks that every output keeps the batch/channel dims and scales the
    spatial dims by *scale*.

    Args:
        model: eval-mode torch module mapping (N, 3, H, W) -> (N, 3, H*scale, W*scale).
        scale: expected spatial upscaling factor (default 4 for an ESRGAN 4x model).
        device: device the probe tensors are created on. When None, it is
            inferred from the model's parameters, falling back to "cuda"
            (the original hard-coded behavior) for parameterless modules.

    Returns:
        True if every probe produced the expected output shape, else False.
    """
    if device is None:
        device = next((p.device for p in model.parameters()), torch.device("cuda"))

    # Mix of square and non-square sizes, small to large.
    input_shapes = [
        (1, 3, 64, 64),
        (1, 3, 128, 128),
        (1, 3, 256, 192),
        (1, 3, 512, 256),
        (1, 3, 512, 512),
    ]

    all_passed = True

    with torch.no_grad():
        for shape in input_shapes:
            try:
                dummy_input = torch.randn(*shape, device=device)
                output = model(dummy_input)

                # NCHW layout: dims 2/3 are height/width.
                expected_h = shape[2] * scale
                expected_w = shape[3] * scale

                # Explicit raises instead of `assert` so the checks survive
                # running under `python -O` (asserts are stripped there).
                if output.shape[0] != shape[0]:
                    raise ValueError("Batch size mismatch")
                if output.shape[1] != shape[1]:
                    raise ValueError("Channel mismatch")
                if output.shape[2] != expected_h:
                    raise ValueError(f"Height mismatch: expected {expected_h}, got {output.shape[2]}")
                if output.shape[3] != expected_w:
                    raise ValueError(f"Width mismatch: expected {expected_w}, got {output.shape[3]}")

                print(f"Success: input {shape} → output {output.shape}")
            except Exception as e:
                all_passed = False
                print(f"Failure: input {shape} → error: {e}")
                # Free any partially-allocated GPU memory before the next probe.
                torch.cuda.empty_cache()

    if all_passed:
        print("Success: Dynamic shapes supported.")
    else:
        print("Failure: Dynamic shapes NOT supported.")
    return all_passed

# Use smaller dummy input if model supports dynamic shapes (the traced
# example size doesn't matter then, so pick the cheapest one).
if supports_dynamic_shapes_esrgan(model):
    shape = (1, 3, 64, 64)
    print(f"Using {shape} input (less VRAM usage)")
else:
    shape = (1, 3, 512, 512)
    print(f"Using {shape} input (large VRAM usage)")

x = torch.rand(*shape).cuda()

# NCHW layout: dim 0 = batch, dim 2 = height, dim 3 = width.
# (Fixed: the labels previously had height/width swapped — harmless to the
# exporter, which only uses them as names, but misleading in the ONNX graph.)
dynamic_axes = {
    "input": {0: "batch_size", 2: "height", 3: "width"},
    "output": {0: "batch_size", 2: "height", 3: "width"},
}

with torch.no_grad():
    torch.onnx.export(
        model,
        x,
        onnx_save_path,
        verbose=True,
        input_names=['input'],
        output_names=['output'],
        opset_version=17,
        export_params=True,
        dynamic_axes=dynamic_axes,
    )

print("Saved onnx to:", onnx_save_path)