"""Rebuild class_metadata.json for a sample cache by scanning sample_*.pt files."""
import sys

sys.path.append(".")

import json
from collections import defaultdict
from pathlib import Path

import torch
from tqdm import tqdm

from data.data_loader import summarize_context_window
from data.quant_ohlc_feature_schema import FEATURE_VERSION


def rebuild_metadata(cache_dir="data/cache"):
    """Scan *cache_dir* for ``sample_*.pt`` files and rewrite ``class_metadata.json``.

    For every sample file, loads it on CPU, extracts its ``class_id`` and a
    context-window summary (via ``summarize_context_window``), then aggregates
    per-class and per-context-bucket distributions and writes them — together
    with per-file maps — to ``<cache_dir>/class_metadata.json``.

    Unreadable samples are skipped with a printed error (best-effort rebuild).

    Args:
        cache_dir: Directory containing the cached ``sample_*.pt`` tensors.
    """
    cache_path = Path(cache_dir)
    print(f"Scanning {cache_path} for .pt files...")

    files = sorted(cache_path.glob("sample_*.pt"))
    if not files:
        print("No .pt files found!")
        return

    print(f"Found {len(files)} files. Reading class IDs and context summaries...")

    file_class_map = {}
    file_context_bucket_map = {}
    file_context_summary_map = {}
    class_distribution = defaultdict(int)
    context_distribution = defaultdict(lambda: defaultdict(int))

    for f in tqdm(files):
        try:
            # We only need the class_id, no need to load the whole extensive
            # tensor data if possible. But torch.load loads everything.
            # To be safe/fast, we just load on CPU.
            data = torch.load(f, map_location="cpu", weights_only=False)
            cid = data.get("class_id", 0)
            context_summary = summarize_context_window(
                data.get("labels"), data.get("labels_mask")
            )

            file_class_map[f.name] = cid
            file_context_bucket_map[f.name] = context_summary["context_bucket"]
            file_context_summary_map[f.name] = context_summary
            class_distribution[cid] += 1
            context_distribution[cid][context_summary["context_bucket"]] += 1
        except Exception as e:
            # Deliberate best-effort: report the bad sample and keep going.
            print(f"Error reading {f.name}: {e}")

    output_data = {
        'file_class_map': file_class_map,
        'file_context_bucket_map': file_context_bucket_map,
        'file_context_summary_map': file_context_summary_map,
        # JSON keys must be strings, so stringify the (int) class ids.
        'class_distribution': {str(k): v for k, v in class_distribution.items()},
        'context_distribution': {
            str(k): dict(bucket_counts)
            for k, bucket_counts in context_distribution.items()
        },
        # These are informational, setting defaults to avoid breaking if loader checks them
        'num_workers': 1,
        'horizons_seconds': [300, 900, 1800, 3600, 7200],  # From user's pre_cache.sh
        'quantiles': [0.1, 0.5, 0.9],
        'quant_feature_version': FEATURE_VERSION,
    }

    out_file = cache_path / "class_metadata.json"
    with open(out_file, 'w') as f:
        json.dump(output_data, f, indent=2)

    print(f"Successfully rebuilt metadata for {len(file_class_map)} files.")
    print(f"Saved to: {out_file}")


if __name__ == "__main__":
    rebuild_metadata()