File size: 2,962 Bytes
c88fe21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# Copyright (c) 2025 CMS Manhattan
# All rights reserved.
# Author: Konstantin Vladimirovich Grabko
# Email: grabko@cmsmanhattan.com
# Phone: +1(516)777-0945
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
# Additional terms:
# Any commercial use or distribution of this software or derivative works
# requires explicit written permission from the copyright holder.


import sys
import torch
from pathlib import Path

def _print_named_tensors(module, method_name, title, label):
    """Print ``name, device, shape`` for each pair yielded by ``module.<method_name>()``.

    Best-effort: a missing method or an error while iterating is reported and
    swallowed so the remaining diagnostics still run.
    """
    print(f"{title} (name, device, shape):")
    try:
        for name, tensor in getattr(module, method_name)():
            print(f"  {label}:", name, tensor.device, tuple(tensor.shape))
    except Exception as e:
        print(f"  ({method_name}() not available / raised):", e)


def _print_state_dict(module):
    """Print every ``state_dict()`` entry and the sorted set of devices seen."""
    print("\nstate_dict keys and devices:")
    try:
        sd = module.state_dict()
    except Exception as e:
        print("  state_dict() failed:", e)
        return
    devices = set()
    for key, value in sd.items():
        # Read both attributes before recording anything, so an entry that is
        # reported as non-tensor never contributes a device to the summary.
        try:
            device, shape = value.device, tuple(value.shape)
        except Exception:
            print(" ", key, " - (non-tensor?)")
            continue
        devices.add(device)
        print(" ", key, device, shape)
    # Sorted string form keeps the summary line deterministic between runs.
    print("Devices in state_dict():", sorted(str(d) for d in devices))


def _print_graph(module):
    """Print the TorchScript graph of ``forward``, falling back to ``module.code``."""
    print("\nTorchScript graph of forward (submodule calls are not inlined here):")
    try:
        print(module.graph)
    except Exception as e:
        print("  Could not print graph directly:", e)
        try:
            code = module.code  # a property, not a method
        except Exception as e2:
            print("  Also could not print m.code:", e2)
        else:
            print("m.code:")
            print(code)


def _find_tensor_constants(graph):
    """Return ``(value_name, device, shape)`` for each tensor ``prim::Constant`` in *graph*.

    ``findAllNodes(kind, True)`` also searches nested blocks (e.g. the bodies of
    ``prim::If`` / ``prim::Loop``). A constant whose ``value`` attribute is not a
    tensor (kind ``"t"``) is skipped.
    """
    found = []
    for node in graph.findAllNodes("prim::Constant", True):
        if not node.hasAttribute("value") or node.kindOf("value") != "t":
            continue
        tensor = node.t("value")
        found.append((node.output().debugName(), tensor.device, tuple(tensor.shape)))
    return found


def _print_tensor_constants(module):
    """List tensor constants found in the fully inlined ``forward`` graph."""
    print("\nTensor constants baked into the inlined forward graph (name, device, shape):")
    try:
        # inlined_graph also includes constants from submodule methods, which
        # the plain ``graph`` only shows as prim::CallMethod nodes.
        constants = _find_tensor_constants(module.inlined_graph)
    except Exception as e:
        print("  Could not scan graph for tensor constants:", e)
        return
    if not constants:
        print("  (none found)")
    for name, device, shape in constants:
        print("  CONST:", name, device, shape)


def main():
    """Load a TorchScript model on CPU and report where its tensors live.

    Reads the model path from ``sys.argv[1]`` and prints the following:
    named parameters, named buffers, ``state_dict()`` entries, the ``forward``
    graph, and every tensor baked into the inlined graph as a ``prim::Constant``.

    Returns:
        Exit status: 0 on success, 1 on a usage error or when the path is not
        an existing file. Errors from ``torch.jit.load`` itself propagate.
    """
    if len(sys.argv) < 2:
        print(
            "Usage: python diagnostics_print_jit_constants.py <jit_model_path>",
            file=sys.stderr,
        )
        return 1
    path = Path(sys.argv[1])
    if not path.is_file():
        print("File not found:", path, file=sys.stderr)
        return 1

    print("Loading JIT model (map_location='cpu') for safe inspection...")
    # map_location remaps storages at load time, so the devices printed below
    # presumably all read "cpu" regardless of where the model was saved.
    m = torch.jit.load(str(path), map_location='cpu')
    print("Loaded. Collecting info...\n")

    _print_named_tensors(m, "named_parameters", "Named parameters", "PARAM")
    print()
    _print_named_tensors(m, "named_buffers", "Named buffers", "BUFFER")
    _print_state_dict(m)
    _print_graph(m)
    _print_tensor_constants(m)

    print("\nTensor constants are baked into the graph and do not move with module.to(device);")
    print("on a non-CPU target they are a likely cause of device-mismatch errors.")
    print("Recommendation: re-create JIT on target device (see retrace_to_cuda.py).")
    return 0

if __name__ == "__main__":
    # Use main()'s return value as the process exit status (None means 0).
    sys.exit(main())