# oracle/scripts/dump_cache_sample.py
# (origin: uploaded via huggingface_hub, revision d195287)
#!/usr/bin/env python3
"""
Dump a cached .pt sample to JSON for manual debugging.
Usage:
python scripts/dump_cache_sample.py # Dump first sample
python scripts/dump_cache_sample.py --index 5 # Dump sample at index 5
python scripts/dump_cache_sample.py --file data/cache/sample_ABC123.pt # Dump specific file
python scripts/dump_cache_sample.py --output debug.json # Custom output path
"""
import argparse
import json
import sys
import os
# Add project root to path so torch.load can find project modules when unpickling
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
import numpy as np
from pathlib import Path
from datetime import datetime
def convert_to_serializable(obj):
    """Recursively convert *obj* into JSON-serializable primitives.

    Tensors, ndarrays, datetimes, bytes, and sets become tagged dicts of the
    form ``{"__type__": ..., ...}``; large tensors are truncated to their
    first 20 elements for readability. Dicts/lists/tuples are converted
    element-wise. Unknown objects fall back to a truncated ``repr`` string.

    Args:
        obj: Any object loaded from a cached ``.pt`` sample.

    Returns:
        A structure composed only of JSON-serializable types
        (None, str, int, float, bool, dict, list).
    """
    if obj is None:
        return None
    if isinstance(obj, (str, int, float, bool)):
        return obj
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return {"__type__": "ndarray", "shape": list(obj.shape), "dtype": str(obj.dtype), "data": obj.tolist()}
    if isinstance(obj, torch.Tensor):
        data = obj.tolist()
        # Truncate large tensors so the JSON dump stays readable.
        if obj.numel() > 50:
            flat = obj.flatten().tolist()
            data = flat[:20] + [f"... ({obj.numel()} elements total)"]
        return {"__type__": "tensor", "shape": list(obj.shape), "dtype": str(obj.dtype), "data": data}
    # Match EmbeddingPooler by class name so this script does not need to
    # import the project class (torch.load already unpickled it).
    if type(obj).__name__ == 'EmbeddingPooler':
        try:
            items = obj.get_all_items()
            return {
                "__type__": "EmbeddingPooler",
                "count": len(items),
                "items": [convert_to_serializable(item) for item in items]
            }
        except Exception:
            # Best-effort: fall back to the repr on any failure
            # (was a bare `except:`, which also swallowed SystemExit).
            return {"__type__": "EmbeddingPooler", "repr": str(obj)}
    if isinstance(obj, datetime):
        return {"__type__": "datetime", "value": obj.isoformat()}
    if isinstance(obj, bytes):
        return {"__type__": "bytes", "length": len(obj), "preview": obj[:100].hex() if len(obj) > 0 else ""}
    if isinstance(obj, dict):
        return {str(k): convert_to_serializable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [convert_to_serializable(item) for item in obj]
    if isinstance(obj, set):
        # Bug fix: recurse into members too — previously `list(obj)` could
        # leak non-serializable values (e.g. a set of np.int64) into the JSON.
        return {"__type__": "set", "data": [convert_to_serializable(item) for item in obj]}
    # Fallback: tagged, truncated string representation.
    try:
        return {"__type__": type(obj).__name__, "repr": str(obj)[:500]}
    except Exception:
        return {"__type__": "unknown", "repr": "<not serializable>"}
def main():
    """Load one cached ``.pt`` sample, dump it to JSON, and print a summary.

    Returns:
        int: 0 on success, 1 on a user-facing error (missing file or
        directory, index out of range, or a failed ``torch.load``).
    """
    parser = argparse.ArgumentParser(description="Dump cached .pt sample to JSON")
    parser.add_argument("--index", "-i", type=int, default=0, help="Index of sample to dump (default: 0)")
    parser.add_argument("--file", "-f", type=str, default=None, help="Direct path to .pt file (overrides --index)")
    parser.add_argument("--cache_dir", "-c", type=str, default="data/cache", help="Cache directory (default: data/cache)")
    parser.add_argument("--output", "-o", type=str, default=None, help="Output JSON path (default: auto-generated)")
    parser.add_argument("--compact", action="store_true", help="Compact JSON output (no indentation)")
    args = parser.parse_args()

    # Determine which file to load: explicit --file wins over --index lookup.
    if args.file:
        filepath = Path(args.file)
        if not filepath.exists():
            print(f"ERROR: File not found: {filepath}")
            return 1
    else:
        cache_dir = Path(args.cache_dir)
        if not cache_dir.is_dir():
            print(f"ERROR: Cache directory not found: {cache_dir}")
            return 1
        cached_files = sorted(cache_dir.glob("sample_*.pt"))
        if not cached_files:
            print(f"ERROR: No sample_*.pt files found in {cache_dir}")
            return 1
        # Bug fix: also reject negative indices — Python's negative indexing
        # would otherwise silently wrap around to the end of the list.
        if args.index < 0 or args.index >= len(cached_files):
            print(f"ERROR: Index {args.index} out of range. Found {len(cached_files)} files.")
            return 1
        filepath = cached_files[args.index]

    print(f"Loading: {filepath}")
    # weights_only=False is required because cached samples may contain
    # arbitrary pickled project objects (e.g. EmbeddingPooler).
    # NOTE(review): unpickling is only safe on trusted cache files.
    try:
        data = torch.load(filepath, map_location="cpu", weights_only=False)
    except Exception as e:
        print(f"ERROR: Failed to load file: {e}")
        return 1

    print("Converting to JSON-serializable format...")
    serializable_data = convert_to_serializable(data)

    # Wrap the converted sample with provenance metadata.
    output_data = {
        "__metadata__": {
            "source_file": str(filepath.absolute()),
            "dumped_at": datetime.now().isoformat(),
            "cache_format": "context" if isinstance(data, dict) and "event_sequence" in data else "legacy"
        },
        "data": serializable_data
    }

    # Determine output path.
    if args.output:
        output_path = Path(args.output)
    else:
        # Default: save to current directory (root) instead of inside cache dir.
        output_path = Path.cwd() / filepath.with_suffix(".json").name

    print(f"Writing to: {output_path}")
    indent = None if args.compact else 2
    # Explicit UTF-8 so the dump is portable regardless of locale
    # (ensure_ascii=False may emit non-ASCII characters).
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(output_data, f, indent=indent, ensure_ascii=False)

    _print_summary(data)
    print(f"\nDone! JSON saved to: {output_path}")
    return 0


def _print_summary(data):
    """Print a human-readable summary of well-known keys in a sample dict."""
    if not isinstance(data, dict):
        return
    print("\n=== Summary ===")
    print(f"Top-level keys: {list(data.keys())}")
    print(f"Cache format: {'context' if 'event_sequence' in data else 'legacy'}")
    if 'event_sequence' in data:
        print(f"Event count: {len(data['event_sequence'])}")
    if 'trades' in data:
        print(f"Trade count: {len(data['trades'])}")
    # Scalar fields: print each one only when present.
    for key, label in (
        ('source_token', 'Source token'),
        ('class_id', 'Class ID'),
        ('context_bucket', 'Context bucket'),
        ('context_score', 'Context score'),
        ('quality_score', 'Quality score'),
    ):
        if key in data:
            print(f"{label}: {data[key]}")
if __name__ == "__main__":
    # Use sys.exit rather than the site-provided builtin exit(), which is
    # intended for interactive use and may be absent under `python -S`.
    sys.exit(main())