#!/usr/bin/env python3
"""
Dump a cached .pt sample to JSON for manual debugging.

Usage:
    python scripts/dump_cache_sample.py                                    # Dump first sample
    python scripts/dump_cache_sample.py --index 5                          # Dump sample at index 5
    python scripts/dump_cache_sample.py --file data/cache/sample_ABC123.pt # Dump specific file
    python scripts/dump_cache_sample.py --output debug.json                # Custom output path
"""

import argparse
import json
import os
import sys

# Add project root to path so torch.load can find project modules when unpickling
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from datetime import datetime
from pathlib import Path

import numpy as np
import torch


def convert_to_serializable(obj):
    """Recursively convert non-JSON-serializable objects.

    Tensors / ndarrays become tagged dicts ({"__type__": ..., "shape": ...,
    "dtype": ..., "data": ...}); large tensors are truncated for readability.
    Unknown objects fall back to a truncated ``repr`` so the dump never fails.
    """
    if obj is None:
        return None
    if isinstance(obj, (str, int, float, bool)):
        return obj
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    # np.bool_ is NOT a subclass of bool; without this it would hit the
    # repr fallback instead of serializing as a JSON boolean.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.ndarray):
        return {"__type__": "ndarray", "shape": list(obj.shape),
                "dtype": str(obj.dtype), "data": obj.tolist()}
    if isinstance(obj, torch.Tensor):
        data = obj.tolist()
        # Truncate large tensors for readability
        if obj.numel() > 50:
            flat = obj.flatten().tolist()
            data = flat[:20] + [f"... ({obj.numel()} elements total)"]
        return {"__type__": "tensor", "shape": list(obj.shape),
                "dtype": str(obj.dtype), "data": data}
    # Handle EmbeddingPooler specifically (project type, matched by name so
    # this script does not need to import it)
    if type(obj).__name__ == 'EmbeddingPooler':
        try:
            items = obj.get_all_items()
            return {
                "__type__": "EmbeddingPooler",
                "count": len(items),
                "items": [convert_to_serializable(item) for item in items]
            }
        except Exception:
            # Best-effort: any failure degrades to a plain repr dump.
            return {"__type__": "EmbeddingPooler", "repr": str(obj)}
    if isinstance(obj, datetime):
        return {"__type__": "datetime", "value": obj.isoformat()}
    if isinstance(obj, bytes):
        return {"__type__": "bytes", "length": len(obj),
                "preview": obj[:100].hex() if len(obj) > 0 else ""}
    if isinstance(obj, dict):
        return {str(k): convert_to_serializable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [convert_to_serializable(item) for item in obj]
    if isinstance(obj, set):
        # Recurse into members: a set of numpy scalars / tensors would
        # otherwise produce a list json.dump cannot serialize.
        return {"__type__": "set",
                "data": [convert_to_serializable(item) for item in obj]}
    # Fallback: try str representation
    try:
        return {"__type__": type(obj).__name__, "repr": str(obj)[:500]}
    except Exception:
        return {"__type__": "unknown", "repr": ""}


def main():
    """Locate a cached sample, load it, and dump it to a JSON file.

    Returns 0 on success, 1 on any user-facing error (missing file/dir,
    bad index, load failure).
    """
    parser = argparse.ArgumentParser(description="Dump cached .pt sample to JSON")
    parser.add_argument("--index", "-i", type=int, default=0,
                        help="Index of sample to dump (default: 0)")
    parser.add_argument("--file", "-f", type=str, default=None,
                        help="Direct path to .pt file (overrides --index)")
    parser.add_argument("--cache_dir", "-c", type=str, default="data/cache",
                        help="Cache directory (default: data/cache)")
    parser.add_argument("--output", "-o", type=str, default=None,
                        help="Output JSON path (default: auto-generated)")
    parser.add_argument("--compact", action="store_true",
                        help="Compact JSON output (no indentation)")
    args = parser.parse_args()

    # Determine which file to load
    if args.file:
        filepath = Path(args.file)
        if not filepath.exists():
            print(f"ERROR: File not found: {filepath}")
            return 1
    else:
        cache_dir = Path(args.cache_dir)
        if not cache_dir.is_dir():
            print(f"ERROR: Cache directory not found: {cache_dir}")
            return 1
        cached_files = sorted(cache_dir.glob("sample_*.pt"))
        if not cached_files:
            print(f"ERROR: No sample_*.pt files found in {cache_dir}")
            return 1
        if args.index >= len(cached_files):
            print(f"ERROR: Index {args.index} out of range. Found {len(cached_files)} files.")
            return 1
        filepath = cached_files[args.index]

    print(f"Loading: {filepath}")

    # Load the .pt file. NOTE: weights_only=False unpickles arbitrary objects —
    # only safe because these are locally produced cache files.
    try:
        data = torch.load(filepath, map_location="cpu", weights_only=False)
    except Exception as e:
        print(f"ERROR: Failed to load file: {e}")
        return 1

    # Convert to JSON-serializable format
    print("Converting to JSON-serializable format...")
    serializable_data = convert_to_serializable(data)

    # Add metadata
    output_data = {
        "__metadata__": {
            "source_file": str(filepath.absolute()),
            "dumped_at": datetime.now().isoformat(),
            "cache_format": "context" if isinstance(data, dict) and "event_sequence" in data else "legacy"
        },
        "data": serializable_data
    }

    # Determine output path
    if args.output:
        output_path = Path(args.output)
    else:
        # Default: Save to current directory (root) instead of inside cache dir
        output_path = Path.cwd() / filepath.with_suffix(".json").name

    # Write JSON (explicit UTF-8 since ensure_ascii=False may emit non-ASCII)
    print(f"Writing to: {output_path}")
    indent = None if args.compact else 2
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(output_data, f, indent=indent, ensure_ascii=False)

    # Print summary
    if isinstance(data, dict):
        print("\n=== Summary ===")
        print(f"Top-level keys: {list(data.keys())}")
        print(f"Cache format: {'context' if 'event_sequence' in data else 'legacy'}")
        if 'event_sequence' in data:
            print(f"Event count: {len(data['event_sequence'])}")
        if 'trades' in data:
            print(f"Trade count: {len(data['trades'])}")
        if 'source_token' in data:
            print(f"Source token: {data['source_token']}")
        if 'class_id' in data:
            print(f"Class ID: {data['class_id']}")
        if 'context_bucket' in data:
            print(f"Context bucket: {data['context_bucket']}")
        if 'context_score' in data:
            print(f"Context score: {data['context_score']}")
        if 'quality_score' in data:
            print(f"Quality score: {data['quality_score']}")

    print(f"\nDone! JSON saved to: {output_path}")
    return 0


if __name__ == "__main__":
    sys.exit(main())