File size: 2,693 Bytes
f53b3ee
 
 
 
 
 
 
c471f42
d195287
f53b3ee
 
 
 
 
 
 
 
 
 
c471f42
f53b3ee
 
c471f42
 
f53b3ee
c471f42
f53b3ee
 
 
 
 
 
 
c471f42
f53b3ee
c471f42
 
f53b3ee
c471f42
f53b3ee
 
 
 
 
c471f42
 
f53b3ee
c471f42
 
 
 
f53b3ee
 
 
 
d195287
f53b3ee
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import sys
sys.path.append(".")
import torch
import json
from pathlib import Path
from tqdm import tqdm
from collections import defaultdict
from data.data_loader import summarize_context_window
from data.quant_ohlc_feature_schema import FEATURE_VERSION

def rebuild_metadata(cache_dir="data/cache"):
    """Rebuild ``class_metadata.json`` by scanning cached ``sample_*.pt`` files.

    Loads every ``sample_*.pt`` under *cache_dir*, extracts each sample's
    class id and context-window summary, aggregates per-class and per-bucket
    distributions, and writes the result to ``<cache_dir>/class_metadata.json``.
    Unreadable samples are skipped with a printed warning (best-effort).

    Args:
        cache_dir: Directory containing the cached ``sample_*.pt`` files.
    """
    cache_path = Path(cache_dir)
    print(f"Scanning {cache_path} for .pt files...")

    # glob() already yields an iterable; no intermediate list() needed.
    files = sorted(cache_path.glob("sample_*.pt"))
    if not files:
        print("No .pt files found!")
        return

    print(f"Found {len(files)} files. Reading class IDs and context summaries...")

    file_class_map = {}
    file_context_bucket_map = {}
    file_context_summary_map = {}
    class_distribution = defaultdict(int)
    context_distribution = defaultdict(lambda: defaultdict(int))

    # Named `sample_file` (not `f`) so it cannot be confused with the JSON
    # file handle opened below.
    for sample_file in tqdm(files):
        try:
            # torch.load reads the full sample; map to CPU so this also works
            # without a GPU. SECURITY NOTE(review): weights_only=False runs
            # arbitrary pickle code on load -- acceptable only because these
            # caches are produced locally; never point this at untrusted files.
            data = torch.load(sample_file, map_location="cpu", weights_only=False)
            cid = data.get("class_id", 0)  # default 0 when a sample lacks a class id
            context_summary = summarize_context_window(data.get("labels"), data.get("labels_mask"))
            file_class_map[sample_file.name] = cid
            file_context_bucket_map[sample_file.name] = context_summary["context_bucket"]
            file_context_summary_map[sample_file.name] = context_summary
            class_distribution[cid] += 1
            context_distribution[cid][context_summary["context_bucket"]] += 1
        except Exception as e:
            # Best-effort rebuild: report and skip broken samples, keep going.
            print(f"Error reading {sample_file.name}: {e}")

    output_data = {
        'file_class_map': file_class_map,
        'file_context_bucket_map': file_context_bucket_map,
        'file_context_summary_map': file_context_summary_map,
        # JSON object keys must be strings; class ids are ints, so stringify.
        'class_distribution': {str(k): v for k, v in class_distribution.items()},
        'context_distribution': {
            str(k): dict(bucket_counts)
            for k, bucket_counts in context_distribution.items()
        },
        # These are informational, setting defaults to avoid breaking if loader checks them
        'num_workers': 1,
        'horizons_seconds': [300, 900, 1800, 3600, 7200], # From user's pre_cache.sh
        'quantiles': [0.1, 0.5, 0.9],
        'quant_feature_version': FEATURE_VERSION,
    }

    out_file = cache_path / "class_metadata.json"
    with open(out_file, 'w') as fh:
        json.dump(output_data, fh, indent=2)

    print(f"Successfully rebuilt metadata for {len(file_class_map)} files.")
    print(f"Saved to: {out_file}")

# Script entry point: rebuild the cache metadata using the default
# cache directory ("data/cache").
if __name__ == "__main__":
    rebuild_metadata()