# Copyright (c) 2025 CMS Manhattan
# All rights reserved.
# Author: Konstantin Vladimirovich Grabko
# Email: grabko@cmsmanhattan.com
# Phone: +1(516)777-0945
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
# Additional terms:
# Any commercial use or distribution of this software or derivative works
# requires explicit written permission from the copyright holder.
import sys
import torch
from pathlib import Path
def _print_named_tensors(module, method_name, tag, heading):
    """Print one `` TAG: name device shape`` line per entry of ``module.<method_name>()``.

    ``method_name`` is ``"named_parameters"`` or ``"named_buffers"``; both yield
    ``(name, tensor)`` pairs. Failures are reported, not raised, so the remaining
    diagnostic sections still run.
    """
    print(heading)
    try:
        for name, tensor in getattr(module, method_name)():
            print(f" {tag}:", name, tensor.device, tuple(tensor.shape))
    except Exception as e:  # best-effort diagnostic: report and keep going
        print(f" ({method_name}() not available / raised):", e)


def _print_state_dict(module):
    """Print each ``state_dict()`` entry with its device, then the set of devices seen."""
    print("\nstate_dict keys and devices:")
    try:
        sd = module.state_dict()
    except Exception as e:  # best-effort diagnostic: report and keep going
        print(" state_dict() failed:", e)
        return
    devices = set()
    for key, value in sd.items():
        try:
            devices.add(value.device)
            print(" ", key, value.device, tuple(value.shape))
        except (AttributeError, TypeError):
            # Entry does not expose tensor-like .device / .shape attributes.
            print(" ", key, " - (non-tensor?)")
    print("Devices in state_dict():", devices)


def _print_graph(module):
    """Print the module's TorchScript graph, falling back to its generated ``code``."""
    print("\nAttempt to show TorchScript graph (short version). Look for prim::Constant Tensor entries:")
    try:
        print(module.graph)
    except Exception as e:
        print(" Could not print graph directly:", e)
        try:
            # ``code`` is a property holding TorchScript source text, not a method.
            print("m.code:")
            print(module.code)
        except Exception as e2:
            print(" Also could not print m.code:", e2)


def main():
    """Load a TorchScript archive on CPU and print device/shape diagnostics.

    The model path is read from ``sys.argv[1]``. The script prints:

    * named parameters and named buffers (name, device, shape);
    * every ``state_dict()`` entry plus the set of devices seen;
    * the TorchScript graph (falling back to ``code``), so tensors baked in as
      ``prim::Constant`` can be spotted.

    Returns:
        0 on success. 1 if the path argument is missing, the path is not a
        regular file, or ``torch.jit.load`` rejects the file.
    """
    if len(sys.argv) < 2:
        print("Usage: python diagnostics_print_jit_constants.py <path_to_jit_model.pt>")
        return 1
    path = Path(sys.argv[1])
    # is_file() rather than exists(): a directory would otherwise pass this
    # check and only fail later inside torch.jit.load.
    if not path.is_file():
        print("File not found:", path)
        return 1
    print("Loading JIT model (map_location='cpu') for safe inspection...")
    try:
        m = torch.jit.load(str(path), map_location='cpu')
    except (RuntimeError, ValueError) as e:
        # e.g. not a TorchScript zip archive, or a truncated/corrupt file.
        print("Failed to load JIT model:", e)
        return 1
    print("Loaded. Collecting info...\n")
    _print_named_tensors(m, "named_parameters", "PARAM", "Named parameters (name, device, shape):")
    _print_named_tensors(m, "named_buffers", "BUFFER", "\nNamed buffers (name, device, shape):")
    _print_state_dict(m)
    _print_graph(m)
    print("\nIf you find prim::Constant values with Tensor on CPU, those likely cause device mismatch.")
    print("Recommendation: re-create JIT on target device (see retrace_to_cuda.py).")
    return 0
# Script entry point: propagate main()'s return value as the process exit
# status (None, as well as 0, exits successfully).
if __name__ == "__main__":
    sys.exit(main())