# Copyright (c) 2025 CMS Manhattan
# All rights reserved.
# Author: Konstantin Vladimirovich Grabko
# Email: grabko@cmsmanhattan.com
# Phone: +1(516)777-0945
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
# Additional terms:
# Any commercial use or distribution of this software or derivative works
# requires explicit written permission from the copyright holder.
"""Inspect where the tensors of a TorchScript model live.

Usage: python diagnostics_print_jit_constants.py <model.pt> [map_location]

The model is loaded with ``torch.jit.load`` (``map_location`` defaults to
``'cpu'``; pass ``none`` to keep the devices recorded in the file). The
script then prints, section by section: named parameters, named buffers,
``state_dict()`` entries, tensor constants embedded in each submodule's
compiled ``forward`` code, and the root module's forward graph (falling
back to its ``code`` text).

NOTE: ``map_location`` remaps tensor storages while loading, so with the
default ``'cpu'`` the devices reported for parameters/buffers are expected
to be cpu. Device values hard-coded into the code itself (e.g. by tracing)
are presumably not remapped and remain visible in the printed graph/code.
"""

import sys
from pathlib import Path

import torch


def _print_named_tensors(module, method_name, tag):
    """Print ``tag: name device shape`` for each pair yielded by ``module.<method_name>()``.

    Best-effort: if the method is missing or raises, the error is reported
    and the section is skipped instead of aborting the whole diagnostic.
    """
    try:
        for name, tensor in getattr(module, method_name)():
            print(f"  {tag}:", name, tensor.device, tuple(tensor.shape))
    except Exception as e:  # deliberate: diagnostics must keep going
        print(f"  ({method_name}() not available / raised):", e)


def _print_state_dict(module):
    """Print every ``state_dict()`` key with device and shape, then the set of devices seen."""
    print("\nstate_dict keys and devices:")
    try:
        state = module.state_dict()
    except Exception as e:  # deliberate: diagnostics must keep going
        print("  state_dict() failed:", e)
        return

    devices = set()
    for key, value in state.items():
        if isinstance(value, torch.Tensor):
            devices.add(str(value.device))
            print("   ", key, value.device, tuple(value.shape))
        else:
            # state_dict() can carry non-tensor extra state; report its type.
            print("   ", key, "- (non-tensor)", type(value).__name__)
    print("Devices in state_dict():", ", ".join(sorted(devices)) or "(none)")


def _print_tensor_constants(module):
    """Print tensors baked into each submodule's compiled ``forward`` as constants.

    Uses ``ScriptModule.code_with_constants``, whose second element maps the
    ``CONSTANTS.cN`` names seen in ``.code`` to their values. Submodules that
    have no compiled ``forward`` (or where the lookup fails) are skipped.
    """
    print("\nTensor constants in TorchScript code (module, constant, device, shape):")
    found = 0
    for mod_name, submodule in module.named_modules():
        try:
            _, const_map = submodule.code_with_constants
            constants = dict(const_map.const_mapping)
        except Exception:
            # Not every submodule exposes a compiled forward(); nothing to list.
            continue
        for const_name, value in constants.items():
            if isinstance(value, torch.Tensor):
                found += 1
                print(
                    "  CONST:",
                    mod_name or "<root>",
                    const_name,
                    value.device,
                    tuple(value.shape),
                )
    if not found:
        print("  (no tensor constants found)")


def _print_graph(module):
    """Print the root forward graph, falling back to the generated ``code`` text."""
    print(
        "\nAttempt to show TorchScript graph (short version). "
        "Look for prim::Constant Tensor entries:"
    )
    try:
        print(module.graph)
    except Exception as e:  # deliberate: fall back to the code view
        print("  Could not print graph directly:", e)
        try:
            print("m.code:")
            print(module.code)
        except Exception as e2:
            print("  Also could not print m.code:", e2)


def main():
    """Inspect the TorchScript file named on the command line.

    Command line:
        argv[1]: path to a TorchScript archive (as written by ``torch.jit.save``).
        argv[2]: optional ``map_location`` for ``torch.jit.load``; defaults to
            ``'cpu'``. The literal ``none`` (any case) loads without remapping.

    Returns:
        Process exit status: 0 on success, 1 on a usage error, a missing
        file, or a load failure.
    """
    if len(sys.argv) < 2:
        print(
            "Usage: python diagnostics_print_jit_constants.py <model.pt> [map_location]",
            file=sys.stderr,
        )
        print(
            "  map_location defaults to 'cpu'; pass 'none' to keep the saved devices.",
            file=sys.stderr,
        )
        return 1

    path = Path(sys.argv[1])
    # is_file() rather than exists(): a directory is not a loadable archive.
    if not path.is_file():
        print("File not found:", path, file=sys.stderr)
        return 1

    map_location = sys.argv[2] if len(sys.argv) > 2 else "cpu"
    if map_location.lower() == "none":
        map_location = None

    print(f"Loading JIT model (map_location={map_location!r}) for inspection...")
    try:
        model = torch.jit.load(str(path), map_location=map_location)
    except Exception as e:  # corrupt / non-TorchScript file, unavailable device, ...
        print("Failed to load TorchScript model:", e, file=sys.stderr)
        return 1
    print("Loaded. Collecting info...\n")

    print("Named parameters (name, device, shape):")
    _print_named_tensors(model, "named_parameters", "PARAM")
    print("\nNamed buffers (name, device, shape):")
    _print_named_tensors(model, "named_buffers", "BUFFER")
    _print_state_dict(model)
    _print_tensor_constants(model)
    _print_graph(model)

    print(
        "\nIf you find prim::Constant values with Tensor on CPU, "
        "those likely cause device mismatch."
    )
    print("Recommendation: re-create JIT on target device (see retrace_to_cuda.py).")
    return 0


if __name__ == "__main__":
    sys.exit(main())