|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import sys |
|
|
import torch |
|
|
from pathlib import Path |
|
|
|
|
|
def _print_named_tensors(module, method_name, tag):
    """Print name, device and shape of every tensor from ``module.<method_name>()``.

    ``method_name`` is the name of a zero-argument method yielding
    ``(name, tensor)`` pairs (``named_parameters`` / ``named_buffers``);
    ``tag`` is the label printed in front of each entry.

    Best-effort: any exception raised while looking up, calling or iterating
    the method is printed rather than propagated, so the remaining sections
    of the report still run.
    """
    try:
        for name, tensor in getattr(module, method_name)():
            print(f" {tag}:", name, tensor.device, tuple(tensor.shape))
    except Exception as e:
        print(f" ({method_name}() not available / raised):", e)


def _print_state_dict(module):
    """Print each ``state_dict()`` entry and the set of devices seen.

    Entries that are not tensors are listed with a "(non-tensor?)" marker and
    do not contribute to the device set. A failure of ``state_dict()`` itself
    is printed rather than propagated.
    """
    try:
        devices = set()
        for key, value in module.state_dict().items():
            if not isinstance(value, torch.Tensor):
                print(" ", key, " - (non-tensor?)")
                continue
            devices.add(value.device)
            print(" ", key, value.device, tuple(value.shape))
        print("Devices in state_dict():", devices)
    except Exception as e:
        print(" state_dict() failed:", e)


def _print_graph(module):
    """Print ``module.graph``; if that fails, fall back to ``module.code``.

    Both attempts are best-effort: failures are printed, never raised.
    """
    try:
        print(module.graph)
    except Exception as e:
        print(" Could not print graph directly:", e)
        try:
            print("m.code():")
            print(module.code)
        except Exception as e2:
            print(" Also could not print m.code():", e2)


def main():
    """Load a TorchScript file on CPU and print where its tensors live.

    Reads the model path from ``sys.argv[1]``, loads it with
    ``torch.jit.load(..., map_location='cpu')`` and prints, in order: named
    parameters, named buffers, ``state_dict()`` entries with the set of
    devices found, and the TorchScript graph.

    Returns:
        int: process exit status -- ``0`` after a completed report, ``2``
        when no path argument was given, ``1`` when the path is not an
        existing regular file.

    Raises:
        Whatever ``torch.jit.load`` raises for an unreadable or invalid
        model file; that error is not intercepted here.
    """
    if len(sys.argv) < 2:
        print("Usage: python diagnostics_print_jit_constants.py <jit_model_path>")
        return 2

    path = Path(sys.argv[1])
    # is_file() rather than exists(): a directory "exists" but cannot be
    # loaded, and would otherwise fail later inside torch.jit.load.
    if not path.is_file():
        print("File not found:", path)
        return 1

    print("Loading JIT model (map_location='cpu') for safe inspection...")
    m = torch.jit.load(str(path), map_location='cpu')
    print("Loaded. Collecting info...\n")

    print("Named parameters (name, device, shape):")
    _print_named_tensors(m, "named_parameters", "PARAM")

    print("\nNamed buffers (name, device, shape):")
    _print_named_tensors(m, "named_buffers", "BUFFER")

    print("\nstate_dict keys and devices:")
    _print_state_dict(m)

    print("\nAttempt to show TorchScript graph (short version). Look for prim::Constant Tensor entries:")
    _print_graph(m)

    print("\nIf you find prim::Constant values with Tensor on CPU, those likely cause device mismatch.")
    print("Recommendation: re-create JIT on target device (see retrace_to_cuda.py).")
    return 0


if __name__ == "__main__":
    # Propagate main()'s status so shell callers can detect usage errors.
    sys.exit(main())