| import argparse |
| import math |
| from collections import Counter, defaultdict |
| from pathlib import Path |
|
|
| import torch |
| from tqdm import tqdm |
|
|
| from data.data_loader import summarize_context_window |
| from data.quant_ohlc_feature_schema import FEATURE_VERSION, NUM_QUANT_OHLC_FEATURES, TOKENS_PER_SEGMENT |
|
|
|
|
# Keys that every cached sample dict must contain. Files missing any of
# these are tallied per-field in `missing_fields`, counted once under
# issues["missing_required_fields"], and skipped by the audit loop.
REQUIRED_CONTEXT_FIELDS = [
    "event_sequence",
    "wallets",
    "tokens",
    "labels",
    "labels_mask",
    "quality_score",
    "class_id",
    "source_token",
    "context_bucket",
    "context_score",
    "quant_ohlc_features",
    "quant_feature_version",
]
|
|
|
|
| def _to_list(value): |
| if value is None: |
| return [] |
| if isinstance(value, torch.Tensor): |
| return value.tolist() |
| return list(value) |
|
|
|
|
| def _safe_float(value): |
| if isinstance(value, torch.Tensor): |
| if value.numel() != 1: |
| raise ValueError("Expected scalar tensor.") |
| return float(value.item()) |
| return float(value) |
|
|
|
|
def _scan_labels(labels, stats):
    """Fold label values into the min/max extremes in *stats*.

    Returns True at the first NaN label (values after the NaN are not
    folded in — preserves the original early-exit behaviour).
    """
    for value in labels:
        numeric = float(value)  # convert once; reused for NaN test and extremes
        if math.isnan(numeric):
            return True
        stats["max_label_return"] = max(stats["max_label_return"], numeric)
        stats["min_label_return"] = min(stats["min_label_return"], numeric)
    return False


def _audit_quant_segments(events, stats, issues):
    """Validate quant OHLC payloads attached to Chart_Segment events.

    Counts segments, flags non-list payloads, over-long token lists and
    malformed feature vectors (first bad vector per payload only).
    """
    chart_events = [event for event in events if event.get("event_type") == "Chart_Segment"]
    stats["quant_segments_total"] += len(chart_events)
    for event in chart_events:
        quant_payload = event.get("quant_ohlc_features")
        if not isinstance(quant_payload, list):
            stats["chart_events_missing_quant"] += 1
            continue
        if len(quant_payload) > TOKENS_PER_SEGMENT:
            issues["quant_too_many_tokens"] += 1
        for token_payload in quant_payload:
            vec = token_payload.get("feature_vector")
            if not isinstance(vec, list) or len(vec) != NUM_QUANT_OHLC_FEATURES:
                issues["quant_bad_vector_shape"] += 1
                break


def _print_report(cache_path, stats, issues, class_counts, bucket_counts,
                  class_bucket_counts, token_counts_by_class, samples_per_token,
                  missing_fields):
    """Print the audit summary tables to stdout."""
    unique_tokens_total = len(samples_per_token)
    duplicate_tokens_total = sum(1 for count in samples_per_token.values() if count > 1)

    print("\n=== Cache Audit ===")
    print(f"Cache dir: {cache_path}")
    print(f"Files audited: {stats['files_audited']}")
    print(f"Unique source tokens: {unique_tokens_total}")
    print(f"Tokens with >1 cached context: {duplicate_tokens_total}")
    print(f"Samples per token max: {max(samples_per_token.values()) if samples_per_token else 0}")

    print("\n--- Class Counts ---")
    for class_id in sorted(class_counts):
        unique_tokens = len(token_counts_by_class[class_id])
        print(f"Class {class_id}: samples={class_counts[class_id]} unique_tokens={unique_tokens}")

    print("\n--- Context Buckets ---")
    for bucket, count in sorted(bucket_counts.items()):
        print(f"{bucket}: {count}")

    print("\n--- Class x Context Bucket ---")
    for class_id in sorted(class_bucket_counts):
        bucket_summary = dict(sorted(class_bucket_counts[class_id].items()))
        print(f"Class {class_id}: {bucket_summary}")

    print("\n--- General Stats ---")
    for key, value in stats.items():
        print(f"{key}: {value}")

    print("\n--- Missing Fields ---")
    if missing_fields:
        for field, count in sorted(missing_fields.items()):
            print(f"{field}: {count}")
    else:
        print("none")

    print("\n--- Issues ---")
    if issues:
        for key, value in sorted(issues.items()):
            print(f"{key}: {value}")
    else:
        print("none")

    print("\n--- Duplicate-Heavy Tokens ---")
    # Top 20 tokens by cached-context count, ties broken alphabetically.
    heavy_tokens = sorted(samples_per_token.items(), key=lambda item: (-item[1], item[0]))[:20]
    for token, count in heavy_tokens:
        print(f"{token}: {count}")


def audit_cache(cache_dir, num_samples=None):
    """Audit cached ``sample_*.pt`` context files and print a report.

    Each file is loaded onto CPU, validated for required fields, label/mask
    sanity, quality/context-score consistency and quant OHLC payload shape,
    then aggregated into class/bucket/token tallies printed to stdout.

    Args:
        cache_dir: Directory containing ``sample_*.pt`` files.
        num_samples: If a positive int, audit only the first N files
            (sorted by name). ``None`` (or a non-positive value) audits all.
    """
    cache_path = Path(cache_dir)
    files = sorted(cache_path.glob("sample_*.pt"))
    if not files:
        print(f"No sample_*.pt files found in {cache_path}")
        return

    if num_samples is not None and num_samples > 0:
        files = files[:num_samples]

    issues = Counter()
    class_counts = Counter()
    bucket_counts = Counter()
    class_bucket_counts = defaultdict(Counter)
    token_counts_by_class = defaultdict(Counter)
    samples_per_token = Counter()
    missing_fields = Counter()

    stats = {
        "files_audited": len(files),  # includes files that later fail to load
        "empty_event_sequence": 0,
        "missing_wallets": 0,
        "missing_tokens": 0,
        "nan_labels": 0,
        "nan_masks": 0,
        "nan_quality_score": 0,
        "negative_quality_score": 0,
        # Sentinel extremes; normalised back to 0 after the loop if untouched.
        "max_label_return": -float("inf"),
        "min_label_return": float("inf"),
        "max_events": 0,
        "min_events": float("inf"),
        "contexts_with_no_valid_horizons": 0,
        "context_bucket_mismatch": 0,
        "context_score_mismatch": 0,
        "quant_feature_version_mismatch": 0,
        "chart_events_missing_quant": 0,
        "quant_segments_total": 0,
    }

    for filepath in tqdm(files, desc="Auditing cache", unit="file"):
        try:
            # weights_only=False: cached samples are trusted pickled dicts.
            data = torch.load(filepath, map_location="cpu", weights_only=False)
        except Exception:
            issues["load_error"] += 1
            continue

        if not isinstance(data, dict):
            issues["not_dict"] += 1
            continue

        missing_for_file = [field for field in REQUIRED_CONTEXT_FIELDS if field not in data]
        for field in missing_for_file:
            missing_fields[field] += 1
        if missing_for_file:
            issues["missing_required_fields"] += 1
            continue

        class_id = int(data["class_id"])
        source_token = str(data["source_token"])
        context_bucket = str(data["context_bucket"])

        class_counts[class_id] += 1
        bucket_counts[context_bucket] += 1
        class_bucket_counts[class_id][context_bucket] += 1
        token_counts_by_class[class_id][source_token] += 1
        samples_per_token[source_token] += 1

        # `or` fallbacks guard against stored None/empty values.
        events = data.get("event_sequence") or []
        wallets = data.get("wallets") or {}
        tokens = data.get("tokens") or {}
        labels = _to_list(data.get("labels"))
        masks = _to_list(data.get("labels_mask"))

        if not events:
            stats["empty_event_sequence"] += 1
        stats["max_events"] = max(stats["max_events"], len(events))
        stats["min_events"] = min(stats["min_events"], len(events))

        if not wallets:
            stats["missing_wallets"] += 1
        if not tokens:
            stats["missing_tokens"] += 1

        if _scan_labels(labels, stats):
            stats["nan_labels"] += 1

        if any(math.isnan(float(value)) for value in masks):
            stats["nan_masks"] += 1

        try:
            quality_score = _safe_float(data.get("quality_score"))
            if math.isnan(quality_score):
                stats["nan_quality_score"] += 1
            elif quality_score < 0:
                stats["negative_quality_score"] += 1
        except Exception:
            # Non-numeric / multi-element quality score.
            issues["invalid_quality_score"] += 1

        try:
            # Recompute the context summary and compare against stored values.
            summary = summarize_context_window(data.get("labels"), data.get("labels_mask"))
            if summary["valid_horizons"] == 0:
                stats["contexts_with_no_valid_horizons"] += 1
            if summary["context_bucket"] != context_bucket:
                stats["context_bucket_mismatch"] += 1
            stored_score = _safe_float(data.get("context_score"))
            if not math.isclose(summary["context_score"], stored_score, rel_tol=1e-6, abs_tol=1e-6):
                stats["context_score_mismatch"] += 1
        except Exception:
            issues["context_summary_error"] += 1

        if data.get("quant_feature_version") != FEATURE_VERSION:
            stats["quant_feature_version_mismatch"] += 1

        _audit_quant_segments(events, stats, issues)

    # Replace sentinel extremes when no sample contributed a value.
    if stats["min_events"] == float("inf"):
        stats["min_events"] = 0
    if stats["min_label_return"] == float("inf"):
        stats["min_label_return"] = 0.0
    if stats["max_label_return"] == -float("inf"):
        stats["max_label_return"] = 0.0

    _print_report(cache_path, stats, issues, class_counts, bucket_counts,
                  class_bucket_counts, token_counts_by_class, samples_per_token,
                  missing_fields)
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: audit a directory of cached sample_*.pt files.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--cache_dir", type=str, default="/workspace/apollo/data/cache")
    arg_parser.add_argument("--num", type=int, default=None, help="Audit only the first N files.")
    cli_args = arg_parser.parse_args()
    audit_cache(cli_args.cache_dir, cli_args.num)
|
|