| import sys |
| sys.path.append(".") |
| import torch |
| import json |
| from pathlib import Path |
| from tqdm import tqdm |
| from collections import defaultdict |
| from data.data_loader import summarize_context_window |
| from data.quant_ohlc_feature_schema import FEATURE_VERSION |
|
|
def rebuild_metadata(cache_dir="data/cache"):
    """Rebuild class/context metadata for cached training samples.

    Scans *cache_dir* for ``sample_*.pt`` files, reads each sample's
    ``class_id`` and context-window summary, and writes the aggregated
    maps and distributions to ``class_metadata.json`` in the same
    directory.

    Args:
        cache_dir: Directory containing the cached ``sample_*.pt`` files.

    Returns:
        None. Side effect: writes ``class_metadata.json``; returns early
        (after printing a message) when no cache files are found.
    """
    cache_path = Path(cache_dir)
    print(f"Scanning {cache_path} for .pt files...")

    # Path.glob yields an iterator; sorted() consumes it directly —
    # no intermediate list() needed.
    files = sorted(cache_path.glob("sample_*.pt"))
    if not files:
        print("No .pt files found!")
        return

    print(f"Found {len(files)} files. Reading class IDs and context summaries...")

    file_class_map = {}
    file_context_bucket_map = {}
    file_context_summary_map = {}
    class_distribution = defaultdict(int)
    context_distribution = defaultdict(lambda: defaultdict(int))

    for sample_file in tqdm(files):
        try:
            # NOTE(security): weights_only=False unpickles arbitrary Python
            # objects — only load caches produced by this project's own
            # pipeline, never untrusted files.
            data = torch.load(sample_file, map_location="cpu", weights_only=False)
            # Missing class_id falls back to class 0 (best-effort rebuild).
            cid = data.get("class_id", 0)
            context_summary = summarize_context_window(data.get("labels"), data.get("labels_mask"))
            file_class_map[sample_file.name] = cid
            file_context_bucket_map[sample_file.name] = context_summary["context_bucket"]
            file_context_summary_map[sample_file.name] = context_summary
            class_distribution[cid] += 1
            context_distribution[cid][context_summary["context_bucket"]] += 1
        except Exception as e:
            # Best-effort: report and skip unreadable/corrupt files, keep going.
            print(f"Error reading {sample_file.name}: {e}")

    output_data = {
        'file_class_map': file_class_map,
        'file_context_bucket_map': file_context_bucket_map,
        'file_context_summary_map': file_context_summary_map,
        # JSON object keys must be strings; class IDs may be ints.
        'class_distribution': {str(k): v for k, v in class_distribution.items()},
        'context_distribution': {
            # dict() converts the inner defaultdict to a plain dict for JSON.
            str(k): dict(bucket_counts)
            for k, bucket_counts in context_distribution.items()
        },
        # Static config fields carried along for consumers of this file.
        # NOTE(review): horizons/quantiles are hard-coded here — presumably
        # they must match the training config that produced the cache; verify.
        'num_workers': 1,
        'horizons_seconds': [300, 900, 1800, 3600, 7200],
        'quantiles': [0.1, 0.5, 0.9],
        'quant_feature_version': FEATURE_VERSION,
    }

    out_file = cache_path / "class_metadata.json"
    # Distinct handle name (fh) — the original reused `f`, shadowing the
    # loop variable above.
    with open(out_file, 'w') as fh:
        json.dump(output_data, fh, indent=2)

    print(f"Successfully rebuilt metadata for {len(file_class_map)} files.")
    print(f"Saved to: {out_file}")
|
|
| if __name__ == "__main__": |
| rebuild_metadata() |
|
|