| |
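"""
Dump a cached .pt sample to JSON for manual debugging.

The sample is chosen from the sorted ``sample_*.pt`` files in the cache
directory (``--index``), unless a file is given explicitly with ``--file``.
By default the JSON is written to ``./<sample name>.json`` in the current
working directory.

Usage:
    python scripts/dump_cache_sample.py                                       # Dump first sample
    python scripts/dump_cache_sample.py --index 5                             # Dump sample at index 5
    python scripts/dump_cache_sample.py --file data/cache/sample_ABC123.pt    # Dump specific file
    python scripts/dump_cache_sample.py --cache_dir other/cache               # Read samples from another cache directory
    python scripts/dump_cache_sample.py --output debug.json                   # Custom output path
    python scripts/dump_cache_sample.py --compact                             # Write JSON without indentation
"""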
|
|
| import argparse |
| import json |
| import sys |
| import os |
|
|
| |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
|
| import torch |
| import numpy as np |
| from pathlib import Path |
| from datetime import datetime |
|
|
|
|
def convert_to_serializable(obj, max_elements=50, preview_elements=20):
    """Recursively convert ``obj`` into a structure ``json.dump`` can write.

    JSON scalars pass through unchanged and containers are converted element
    by element. Tensors, arrays and other non-JSON types become dicts tagged
    with a ``"__type__"`` key, so the dump stays readable.

    Args:
        obj: Any object loaded from a cached sample.
        max_elements: Tensors/arrays with more elements than this are
            truncated instead of being dumped in full.
        preview_elements: Number of leading (flattened) elements kept when a
            tensor/array is truncated.

    Returns:
        A JSON-serializable value (None, str, int, float, bool, list or dict).
    """

    # Recursive calls must keep the caller's truncation settings.
    def convert(item):
        return convert_to_serializable(item, max_elements, preview_elements)

    if obj is None:
        return None
    if isinstance(obj, (str, int, float, bool)):
        return obj
    # np.bool_ is not a subclass of np.integer, so it needs its own branch.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, (complex, np.complexfloating)):
        return {"__type__": "complex", "real": float(obj.real), "imag": float(obj.imag)}
    if isinstance(obj, np.ndarray):
        if obj.size > max_elements:
            head = obj.ravel()[:preview_elements].tolist()
            data = [convert(x) for x in head] + [f"... ({obj.size} elements total)"]
        else:
            # tolist() can yield non-JSON values (object/datetime dtypes),
            # so the result is converted as well.
            data = convert(obj.tolist())
        return {"__type__": "ndarray", "shape": list(obj.shape), "dtype": str(obj.dtype), "data": data}
    if isinstance(obj, torch.Tensor):
        numel = obj.numel()
        if numel > max_elements:
            # Materialise only the preview. Calling tolist() on the whole
            # tensor first would build a Python list of every element.
            head = obj.detach().flatten()[:preview_elements].tolist()
            data = [convert(x) for x in head] + [f"... ({numel} elements total)"]
        else:
            data = convert(obj.tolist())
        return {"__type__": "tensor", "shape": list(obj.shape), "dtype": str(obj.dtype), "data": data}

    # Matched by class name so this script does not need to import the class.
    if type(obj).__name__ == "EmbeddingPooler":
        try:
            items = obj.get_all_items()
            return {
                "__type__": "EmbeddingPooler",
                "count": len(items),
                "items": [convert(item) for item in items],
            }
        except Exception:
            # Best effort: fall back to the string form if the pooler cannot be walked.
            return {"__type__": "EmbeddingPooler", "repr": str(obj)}
    if isinstance(obj, datetime):
        return {"__type__": "datetime", "value": obj.isoformat()}
    if isinstance(obj, (bytes, bytearray)):
        return {"__type__": type(obj).__name__, "length": len(obj), "preview": obj[:100].hex()}
    if isinstance(obj, dict):
        return {str(k): convert(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [convert(item) for item in obj]
    if isinstance(obj, (set, frozenset)):
        # Sort by repr so the output order is deterministic across runs,
        # even for mixed or non-comparable item types.
        return {"__type__": type(obj).__name__, "data": [convert(item) for item in sorted(obj, key=repr)]}

    try:
        text = str(obj)
    except Exception:
        return {"__type__": "unknown", "repr": "<not serializable>"}
    return {"__type__": type(obj).__name__, "repr": text[:500]}
|
|
|
|
def _detect_cache_format(data):
    """Return "context" for a dict sample with an "event_sequence" key, else "legacy"."""
    return "context" if isinstance(data, dict) and "event_sequence" in data else "legacy"


def _resolve_sample_path(args):
    """Return the .pt path selected by --file or --index, or None after printing an error."""
    if args.file:
        filepath = Path(args.file)
        if not filepath.is_file():
            print(f"ERROR: File not found: {filepath}", file=sys.stderr)
            return None
        return filepath

    cache_dir = Path(args.cache_dir)
    if not cache_dir.is_dir():
        print(f"ERROR: Cache directory not found: {cache_dir}", file=sys.stderr)
        return None

    cached_files = sorted(cache_dir.glob("sample_*.pt"))
    if not cached_files:
        print(f"ERROR: No sample_*.pt files found in {cache_dir}", file=sys.stderr)
        return None

    # Python-style negative indices (-1 = last file) are accepted, but anything
    # outside [-len, len) is rejected rather than raising IndexError below.
    if not -len(cached_files) <= args.index < len(cached_files):
        print(f"ERROR: Index {args.index} out of range. Found {len(cached_files)} files.", file=sys.stderr)
        return None
    return cached_files[args.index]


def _print_summary(data):
    """Print a short overview of well-known top-level keys of a dict sample."""
    if not isinstance(data, dict):
        return
    print("\n=== Summary ===")
    print(f"Top-level keys: {list(data.keys())}")
    print(f"Cache format: {_detect_cache_format(data)}")
    for key, label in (("event_sequence", "Event count"), ("trades", "Trade count")):
        if key in data:
            try:
                count = len(data[key])
            except TypeError:
                # The JSON is already written; an unsized value must not crash the summary.
                count = f"n/a ({type(data[key]).__name__})"
            print(f"{label}: {count}")
    for key, label in (
        ("source_token", "Source token"),
        ("class_id", "Class ID"),
        ("context_bucket", "Context bucket"),
        ("context_score", "Context score"),
        ("quality_score", "Quality score"),
    ):
        if key in data:
            print(f"{label}: {data[key]}")


def main():
    """Load one cached .pt sample and write it out as JSON.

    Returns:
        Process exit code: 0 on success, 1 on any reported error.
    """
    parser = argparse.ArgumentParser(description="Dump cached .pt sample to JSON")
    parser.add_argument("--index", "-i", type=int, default=0, help="Index of sample to dump (default: 0)")
    parser.add_argument("--file", "-f", type=str, default=None, help="Direct path to .pt file (overrides --index)")
    parser.add_argument("--cache_dir", "-c", type=str, default="data/cache", help="Cache directory (default: data/cache)")
    parser.add_argument("--output", "-o", type=str, default=None, help="Output JSON path (default: auto-generated)")
    parser.add_argument("--compact", action="store_true", help="Compact JSON output (no indentation)")
    args = parser.parse_args()

    filepath = _resolve_sample_path(args)
    if filepath is None:
        return 1

    print(f"Loading: {filepath}")
    # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
    # Python objects, which can execute code stored in the file. Only run
    # this on cache files you trust.
    try:
        data = torch.load(filepath, map_location="cpu", weights_only=False)
    except Exception as e:
        print(f"ERROR: Failed to load file: {e}", file=sys.stderr)
        return 1

    print("Converting to JSON-serializable format...")
    serializable_data = convert_to_serializable(data)

    output_data = {
        "__metadata__": {
            "source_file": str(filepath.absolute()),
            "dumped_at": datetime.now().isoformat(),
            "cache_format": _detect_cache_format(data),
        },
        "data": serializable_data,
    }

    if args.output:
        output_path = Path(args.output)
    else:
        # Default: <sample name>.json in the current working directory.
        output_path = Path.cwd() / filepath.with_suffix(".json").name

    print(f"Writing to: {output_path}")
    indent = None if args.compact else 2
    try:
        output_path.parent.mkdir(parents=True, exist_ok=True)
        # ensure_ascii=False writes raw non-ASCII characters, so pin the file
        # encoding instead of relying on the platform default.
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(output_data, f, indent=indent, ensure_ascii=False)
    except (OSError, TypeError, ValueError) as e:
        print(f"ERROR: Failed to write JSON: {e}", file=sys.stderr)
        return 1

    _print_summary(data)

    print(f"\nDone! JSON saved to: {output_path}")
    return 0
|
|
|
|
if __name__ == "__main__":
    # sys.exit is the supported way to set the exit status from a program;
    # the bare exit() builtin comes from the site module and is meant for the
    # interactive interpreter (it is missing under `python -S`).
    sys.exit(main())
|
|