import sys
sys.path.append(".")
import torch
import json
from pathlib import Path
from tqdm import tqdm
from collections import defaultdict
from data.data_loader import summarize_context_window
from data.quant_ohlc_feature_schema import FEATURE_VERSION
def rebuild_metadata(cache_dir="data/cache"):
    """Rebuild the class/context metadata index for a cache of sample files.

    Scans *cache_dir* for ``sample_*.pt`` files, reads each sample's class id
    and context-window summary, and writes the aggregated maps and
    distributions to ``<cache_dir>/class_metadata.json``.

    Args:
        cache_dir: Directory containing the cached ``sample_*.pt`` files.
            Defaults to ``"data/cache"``.

    Returns:
        None. Side effect: writes ``class_metadata.json`` into *cache_dir*.
    """
    cache_path = Path(cache_dir)
    print(f"Scanning {cache_path} for .pt files...")
    # glob() yields an iterator; sorted() consumes it directly (no list() wrapper).
    files = sorted(cache_path.glob("sample_*.pt"))
    if not files:
        print("No .pt files found!")
        return

    print(f"Found {len(files)} files. Reading class IDs and context summaries...")
    file_class_map = {}
    file_context_bucket_map = {}
    file_context_summary_map = {}
    class_distribution = defaultdict(int)
    context_distribution = defaultdict(lambda: defaultdict(int))

    for f in tqdm(files):
        try:
            # We only need class_id + label tensors, but torch.load deserializes
            # the whole sample anyway; map to CPU so this works without a GPU.
            # NOTE(security): weights_only=False executes arbitrary pickle code —
            # only run this against trusted, locally-written cache files.
            data = torch.load(f, map_location="cpu", weights_only=False)
            cid = data.get("class_id", 0)  # default 0 for samples missing a class id
            context_summary = summarize_context_window(data.get("labels"), data.get("labels_mask"))
            file_class_map[f.name] = cid
            file_context_bucket_map[f.name] = context_summary["context_bucket"]
            file_context_summary_map[f.name] = context_summary
            class_distribution[cid] += 1
            context_distribution[cid][context_summary["context_bucket"]] += 1
        except Exception as e:
            # Best-effort: skip unreadable/corrupt samples but keep rebuilding.
            print(f"Error reading {f.name}: {e}")

    output_data = {
        'file_class_map': file_class_map,
        'file_context_bucket_map': file_context_bucket_map,
        'file_context_summary_map': file_context_summary_map,
        # JSON object keys must be strings, so stringify the class ids.
        'class_distribution': {str(k): v for k, v in class_distribution.items()},
        # dict() flattens the nested defaultdicts into plain dicts for json.dump.
        'context_distribution': {
            str(k): dict(bucket_counts)
            for k, bucket_counts in context_distribution.items()
        },
        # These are informational, setting defaults to avoid breaking if loader checks them
        'num_workers': 1,
        'horizons_seconds': [300, 900, 1800, 3600, 7200],  # From user's pre_cache.sh
        'quantiles': [0.1, 0.5, 0.9],
        'quant_feature_version': FEATURE_VERSION,
    }

    out_file = cache_path / "class_metadata.json"
    # Use a distinct handle name so it doesn't shadow the loop variable `f` above.
    with open(out_file, 'w', encoding='utf-8') as out:
        json.dump(output_data, out, indent=2)
    print(f"Successfully rebuilt metadata for {len(file_class_map)} files.")
    print(f"Saved to: {out_file}")
if __name__ == "__main__":
    # Script entry point: rebuild metadata for the default cache directory
    # ("data/cache", per rebuild_metadata's default argument).
    rebuild_metadata()
|