# oracle/scripts/rebuild_metadata.py
# Uploaded via huggingface_hub (commit d195287).
import sys
sys.path.append(".")
import torch
import json
from pathlib import Path
from tqdm import tqdm
from collections import defaultdict
from data.data_loader import summarize_context_window
from data.quant_ohlc_feature_schema import FEATURE_VERSION
def rebuild_metadata(cache_dir="data/cache"):
    """Rebuild ``class_metadata.json`` for a directory of cached samples.

    Scans *cache_dir* for ``sample_*.pt`` files, reads each sample's
    ``class_id`` and context-window summary, aggregates per-class and
    per-bucket distributions, and writes the combined metadata to
    ``<cache_dir>/class_metadata.json``.

    Args:
        cache_dir: Directory containing the cached ``sample_*.pt`` files.

    Returns:
        None. Side effect: writes ``class_metadata.json`` into *cache_dir*
        (skipped entirely when no ``.pt`` files are found).
    """
    cache_path = Path(cache_dir)
    print(f"Scanning {cache_path} for .pt files...")
    # sorted() accepts any iterable; no need to materialize a list first.
    files = sorted(cache_path.glob("sample_*.pt"))
    if not files:
        print("No .pt files found!")
        return
    print(f"Found {len(files)} files. Reading class IDs and context summaries...")

    file_class_map = {}
    file_context_bucket_map = {}
    file_context_summary_map = {}
    class_distribution = defaultdict(int)
    context_distribution = defaultdict(lambda: defaultdict(int))
    error_count = 0

    for sample_file in tqdm(files):
        try:
            # torch.load has no partial-load mode, so the whole sample is read.
            # map_location="cpu" keeps this safe on machines without a GPU.
            # NOTE(review): weights_only=False unpickles arbitrary objects —
            # only run this against trusted, locally-produced cache files.
            data = torch.load(sample_file, map_location="cpu", weights_only=False)
            cid = data.get("class_id", 0)
            context_summary = summarize_context_window(
                data.get("labels"), data.get("labels_mask")
            )
            file_class_map[sample_file.name] = cid
            file_context_bucket_map[sample_file.name] = context_summary["context_bucket"]
            file_context_summary_map[sample_file.name] = context_summary
            class_distribution[cid] += 1
            context_distribution[cid][context_summary["context_bucket"]] += 1
        except Exception as e:
            # Best-effort rebuild: skip unreadable samples but keep a tally so
            # the final report does not silently under-count.
            error_count += 1
            print(f"Error reading {sample_file.name}: {e}")

    output_data = {
        'file_class_map': file_class_map,
        'file_context_bucket_map': file_context_bucket_map,
        'file_context_summary_map': file_context_summary_map,
        # JSON object keys must be strings, so int class ids are stringified.
        'class_distribution': {str(k): v for k, v in class_distribution.items()},
        'context_distribution': {
            str(k): dict(bucket_counts)
            for k, bucket_counts in context_distribution.items()
        },
        # These are informational, setting defaults to avoid breaking if loader checks them
        'num_workers': 1,
        'horizons_seconds': [300, 900, 1800, 3600, 7200],  # From user's pre_cache.sh
        'quantiles': [0.1, 0.5, 0.9],
        'quant_feature_version': FEATURE_VERSION,
    }
    out_file = cache_path / "class_metadata.json"
    with open(out_file, 'w') as out_fh:
        json.dump(output_data, out_fh, indent=2)

    if error_count:
        print(f"Skipped {error_count} unreadable files.")
    print(f"Successfully rebuilt metadata for {len(file_class_map)} files.")
    print(f"Saved to: {out_file}")
if __name__ == "__main__":
rebuild_metadata()