"""Audit a directory of cached ``sample_*.pt`` context files and print a report.

Each file is loaded with ``torch.load`` and checked for required fields,
NaN labels/masks, quality-score sanity, agreement between the stored context
bucket/score and ``summarize_context_window``, and the shape of the quant
OHLC payload attached to ``Chart_Segment`` events.
"""

import argparse
import math
from collections import Counter, defaultdict
from collections.abc import Mapping
from pathlib import Path

import torch
from tqdm import tqdm

from data.data_loader import summarize_context_window
from data.quant_ohlc_feature_schema import (
    FEATURE_VERSION,
    NUM_QUANT_OHLC_FEATURES,
    TOKENS_PER_SEGMENT,
)

# Keys a cached sample dict must contain; files missing any of them are
# counted under "missing_required_fields" and not audited further.
REQUIRED_CONTEXT_FIELDS = [
    "event_sequence",
    "wallets",
    "tokens",
    "labels",
    "labels_mask",
    "quality_score",
    "class_id",
    "source_token",
    "context_bucket",
    "context_score",
    "quant_ohlc_features",
    "quant_feature_version",
]


def _to_list(value):
    """Return *value* as a Python list (``None`` -> ``[]``, tensor -> ``tolist()``)."""
    if value is None:
        return []
    if isinstance(value, torch.Tensor):
        return value.tolist()
    return list(value)


def _safe_float(value):
    """Convert a scalar tensor or plain number to ``float``.

    Raises:
        ValueError: if *value* is a tensor with more than one element.
    """
    if isinstance(value, torch.Tensor):
        if value.numel() != 1:
            raise ValueError("Expected scalar tensor.")
        return float(value.item())
    return float(value)


def _scan_numeric(values):
    """Scan *values* once and return ``(has_nan, minimum, maximum)``.

    NaN entries are skipped for the range but do not stop the scan, so the
    reported range never depends on where a NaN happens to sit. When there is
    no non-NaN entry the range is ``(inf, -inf)``, which is neutral under
    ``min``/``max`` merging.

    Raises:
        TypeError, ValueError: if an entry cannot be converted with ``float``.
    """
    has_nan = False
    lowest = float("inf")
    highest = -float("inf")
    for raw in values:
        number = float(raw)
        if math.isnan(number):
            has_nan = True
            continue
        lowest = min(lowest, number)
        highest = max(highest, number)
    return has_nan, lowest, highest


def _new_stats(files_audited):
    """Build the per-run statistics dict with all counters zeroed."""
    return {
        "files_audited": files_audited,
        "empty_event_sequence": 0,
        "missing_wallets": 0,
        "missing_tokens": 0,
        "nan_labels": 0,
        "nan_masks": 0,
        "nan_quality_score": 0,
        "negative_quality_score": 0,
        "max_label_return": -float("inf"),
        "min_label_return": float("inf"),
        "max_events": 0,
        "min_events": float("inf"),
        "contexts_with_no_valid_horizons": 0,
        "context_bucket_mismatch": 0,
        "context_score_mismatch": 0,
        "quant_feature_version_mismatch": 0,
        "chart_events_missing_quant": 0,
        "quant_segments_total": 0,
    }


def _audit_chart_events(events, stats, issues):
    """Validate the quant payload of every ``Chart_Segment`` event in *events*.

    Malformed entries are tallied in *issues* instead of raising, so one bad
    sample cannot abort the audit of the remaining files.
    """
    for event in events:
        if not isinstance(event, Mapping):
            issues["event_not_mapping"] += 1
            continue
        if event.get("event_type") != "Chart_Segment":
            continue
        stats["quant_segments_total"] += 1

        quant_payload = event.get("quant_ohlc_features")
        if not isinstance(quant_payload, list):
            stats["chart_events_missing_quant"] += 1
            continue
        if len(quant_payload) > TOKENS_PER_SEGMENT:
            issues["quant_too_many_tokens"] += 1

        for token_payload in quant_payload:
            vec = (
                token_payload.get("feature_vector")
                if isinstance(token_payload, Mapping)
                else None
            )
            if not isinstance(vec, list) or len(vec) != NUM_QUANT_OHLC_FEATURES:
                # Count at most one shape problem per chart event.
                issues["quant_bad_vector_shape"] += 1
                break


def _print_counter_or_none(title, counter):
    """Print *title*, then the sorted ``key: value`` pairs of *counter* or ``none``."""
    print(title)
    if counter:
        for key, value in sorted(counter.items()):
            print(f"{key}: {value}")
    else:
        print("none")


def audit_cache(cache_dir, num_samples=None):
    """Audit cached context samples under *cache_dir* and print a summary.

    Args:
        cache_dir: Directory searched (non-recursively) for ``sample_*.pt``.
        num_samples: If a positive integer, only the first N files (in sorted
            name order) are audited; ``None`` or a non-positive value audits
            every file.

    Returns:
        None. All results are written to stdout.
    """
    cache_path = Path(cache_dir)
    files = sorted(cache_path.glob("sample_*.pt"))
    if not files:
        print(f"No sample_*.pt files found in {cache_path}")
        return
    if num_samples is not None and num_samples > 0:
        files = files[:num_samples]

    issues = Counter()
    class_counts = Counter()
    bucket_counts = Counter()
    class_bucket_counts = defaultdict(Counter)
    token_counts_by_class = defaultdict(Counter)
    samples_per_token = Counter()
    missing_fields = Counter()
    stats = _new_stats(len(files))

    for filepath in tqdm(files, desc="Auditing cache", unit="file"):
        try:
            # NOTE(security): weights_only=False unpickles arbitrary Python
            # objects and can execute code embedded in the file. Only point
            # this tool at cache directories you trust.
            data = torch.load(filepath, map_location="cpu", weights_only=False)
        except Exception:
            # Deliberate catch-all: any load failure is an audit finding.
            issues["load_error"] += 1
            continue

        if not isinstance(data, dict):
            issues["not_dict"] += 1
            continue

        missing_for_file = [
            field for field in REQUIRED_CONTEXT_FIELDS if field not in data
        ]
        if missing_for_file:
            missing_fields.update(missing_for_file)
            issues["missing_required_fields"] += 1
            continue

        try:
            class_id = int(data["class_id"])
        except (TypeError, ValueError, RuntimeError):
            # e.g. None, a non-numeric string, or a multi-element tensor.
            issues["invalid_class_id"] += 1
            continue
        source_token = str(data["source_token"])
        context_bucket = str(data["context_bucket"])

        class_counts[class_id] += 1
        bucket_counts[context_bucket] += 1
        class_bucket_counts[class_id][context_bucket] += 1
        token_counts_by_class[class_id][source_token] += 1
        samples_per_token[source_token] += 1

        events = data.get("event_sequence") or []
        wallets = data.get("wallets") or {}
        tokens = data.get("tokens") or {}

        if not events:
            stats["empty_event_sequence"] += 1
        stats["max_events"] = max(stats["max_events"], len(events))
        stats["min_events"] = min(stats["min_events"], len(events))
        if not wallets:
            stats["missing_wallets"] += 1
        if not tokens:
            stats["missing_tokens"] += 1

        try:
            has_nan_label, label_min, label_max = _scan_numeric(
                _to_list(data.get("labels"))
            )
        except (TypeError, ValueError):
            issues["invalid_labels"] += 1
        else:
            if has_nan_label:
                stats["nan_labels"] += 1
            stats["max_label_return"] = max(stats["max_label_return"], label_max)
            stats["min_label_return"] = min(stats["min_label_return"], label_min)

        try:
            has_nan_mask, _, _ = _scan_numeric(_to_list(data.get("labels_mask")))
        except (TypeError, ValueError):
            issues["invalid_masks"] += 1
        else:
            if has_nan_mask:
                stats["nan_masks"] += 1

        try:
            quality_score = _safe_float(data.get("quality_score"))
            if math.isnan(quality_score):
                stats["nan_quality_score"] += 1
            elif quality_score < 0:
                stats["negative_quality_score"] += 1
        except Exception:
            issues["invalid_quality_score"] += 1

        try:
            summary = summarize_context_window(
                data.get("labels"), data.get("labels_mask")
            )
            if summary["valid_horizons"] == 0:
                stats["contexts_with_no_valid_horizons"] += 1
            if summary["context_bucket"] != context_bucket:
                stats["context_bucket_mismatch"] += 1
            stored_score = _safe_float(data.get("context_score"))
            if not math.isclose(
                summary["context_score"], stored_score, rel_tol=1e-6, abs_tol=1e-6
            ):
                stats["context_score_mismatch"] += 1
        except Exception:
            # The helper lives outside this file; treat any failure as a finding.
            issues["context_summary_error"] += 1

        if data.get("quant_feature_version") != FEATURE_VERSION:
            stats["quant_feature_version_mismatch"] += 1

        _audit_chart_events(events, stats, issues)

    # Replace untouched +/-inf sentinels so the report shows plain zeros.
    if stats["min_events"] == float("inf"):
        stats["min_events"] = 0
    if stats["min_label_return"] == float("inf"):
        stats["min_label_return"] = 0.0
    if stats["max_label_return"] == -float("inf"):
        stats["max_label_return"] = 0.0

    unique_tokens_total = len(samples_per_token)
    duplicate_tokens_total = sum(
        1 for count in samples_per_token.values() if count > 1
    )

    print("\n=== Cache Audit ===")
    print(f"Cache dir: {cache_path}")
    print(f"Files audited: {stats['files_audited']}")
    print(f"Unique source tokens: {unique_tokens_total}")
    print(f"Tokens with >1 cached context: {duplicate_tokens_total}")
    print(f"Samples per token max: {max(samples_per_token.values()) if samples_per_token else 0}")

    print("\n--- Class Counts ---")
    for class_id in sorted(class_counts):
        unique_tokens = len(token_counts_by_class[class_id])
        print(f"Class {class_id}: samples={class_counts[class_id]} unique_tokens={unique_tokens}")

    print("\n--- Context Buckets ---")
    for bucket, count in sorted(bucket_counts.items()):
        print(f"{bucket}: {count}")

    print("\n--- Class x Context Bucket ---")
    for class_id in sorted(class_bucket_counts):
        bucket_summary = dict(sorted(class_bucket_counts[class_id].items()))
        print(f"Class {class_id}: {bucket_summary}")

    print("\n--- General Stats ---")
    for key, value in stats.items():
        print(f"{key}: {value}")

    _print_counter_or_none("\n--- Missing Fields ---", missing_fields)
    _print_counter_or_none("\n--- Issues ---", issues)

    print("\n--- Duplicate-Heavy Tokens ---")
    # Highest sample count first; ties broken by token name, top 20 only.
    heavy_tokens = sorted(
        samples_per_token.items(), key=lambda item: (-item[1], item[0])
    )[:20]
    for token, count in heavy_tokens:
        print(f"{token}: {count}")


def main():
    """Parse command-line arguments and run the cache audit."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--cache_dir", type=str, default="/workspace/apollo/data/cache")
    parser.add_argument("--num", type=int, default=None, help="Audit only the first N files.")
    args = parser.parse_args()
    audit_cache(args.cache_dir, args.num)


if __name__ == "__main__":
    main()