"""Inspect MemecoinCollator outputs on cached samples and dump them to JSON."""

import argparse
import datetime
import json
import os
import sys
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional

import torch

# Ensure repo root is on sys.path
REPO_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(REPO_ROOT))

from data.data_loader import OracleDataset
from data.data_fetcher import DataFetcher
from data.data_collator import MemecoinCollator
import models.vocabulary as vocab

# Batch keys dumped verbatim (when present in the collated batch).
_POINTER_KEYS = (
    "wallet_indices",
    "token_indices",
    "quote_token_indices",
    "trending_token_indices",
    "boosted_token_indices",
    "dest_wallet_indices",
    "original_author_indices",
    "ohlc_indices",
    "holder_snapshot_indices",
    "textual_event_indices",
)

# Batch keys dumped together with a non-zero element count.
_NUMERICAL_FEATURE_KEYS = (
    "transfer_numerical_features",
    "trade_numerical_features",
    "deployer_trade_numerical_features",
    "smart_wallet_trade_numerical_features",
    "pool_created_numerical_features",
    "liquidity_change_numerical_features",
    "fee_collected_numerical_features",
    "token_burn_numerical_features",
    "supply_lock_numerical_features",
    "onchain_snapshot_numerical_features",
    "trending_token_numerical_features",
    "boosted_token_numerical_features",
    "dexboost_paid_numerical_features",
    "dexprofile_updated_flags",
    "global_trending_numerical_features",
    "chainsnapshot_numerical_features",
    "lighthousesnapshot_numerical_features",
)

# Batch keys dumped together with a non-zero element count.
_CATEGORICAL_FEATURE_KEYS = (
    "trade_dex_ids",
    "trade_direction_ids",
    "trade_mev_protection_ids",
    "trade_is_bundle_ids",
    "pool_created_protocol_ids",
    "liquidity_change_type_ids",
    "trending_token_source_ids",
    "trending_token_timeframe_ids",
    "lighthousesnapshot_protocol_ids",
    "lighthousesnapshot_timeframe_ids",
    "migrated_protocol_ids",
    "alpha_group_ids",
    "channel_ids",
    "exchange_ids",
)

# Optional per-sample tensors copied into the dump when not None.
_LABEL_KEYS = ("labels", "labels_mask", "quality_score")


def _decode_events(event_type_ids: torch.Tensor) -> List[str]:
    """Translate event type ids into event names.

    Id ``0`` is rendered as ``"__PAD__"``; ids missing from
    ``vocab.ID_TO_EVENT`` are rendered as ``"UNK_<id>"``.

    Args:
        event_type_ids: Tensor of ids; expected to be 1-D (one sequence),
            since each element of ``tolist()`` is used as a dict key.

    Returns:
        One name per id, in order.
    """
    return [
        "__PAD__" if eid == 0 else vocab.ID_TO_EVENT.get(eid, f"UNK_{eid}")
        for eid in event_type_ids.tolist()
    ]


def _tensor_to_list(t: torch.Tensor) -> List:
    """Return ``t`` as a (nested) Python list, detached and moved to CPU."""
    return t.detach().cpu().tolist()


def _optional_tensor_list(inputs: Dict[str, Any], key: str) -> List:
    """Return ``inputs[key]`` as a list, or ``[]`` when the key is absent."""
    return _tensor_to_list(inputs[key]) if key in inputs else []


def _dump_tensors(
    batch: Dict[str, Any],
    keys: Iterable[str],
    dump: Dict[str, Any],
    nonzero_summary: Optional[Dict[str, int]] = None,
) -> None:
    """Copy each tensor in ``keys`` that exists in ``batch`` into ``dump``.

    When ``nonzero_summary`` is given, the number of non-zero elements of
    each copied tensor is recorded there under the same key.
    """
    for key in keys:
        if key not in batch:
            continue
        tensor = batch[key]
        dump[key] = _tensor_to_list(tensor)
        if nonzero_summary is not None:
            nonzero_summary[key] = int(torch.count_nonzero(tensor).item())


def _count_event_types(item: Dict[str, Any]) -> Dict[str, int]:
    """Count events per ``event_type`` in one raw sample ("UNKNOWN" if unset)."""
    counts: Dict[str, int] = {}
    for event in item.get("event_sequence", []):
        event_type = event.get("event_type", "UNKNOWN")
        counts[event_type] = counts.get(event_type, 0) + 1
    return counts


def _json_default(o: Any) -> Any:
    """Fallback serializer for objects ``json`` cannot encode natively.

    Dates/datetimes become ISO-8601 strings; everything else is stringified.
    ``json`` only invokes this hook for non-primitive values, so no
    pass-through branch for str/int/float/bool/None is needed.
    """
    if isinstance(o, (datetime.datetime, datetime.date)):
        return o.isoformat()
    try:
        return str(o)
    except Exception:
        # Deliberate best-effort: a broken __str__ must not abort the dump.
        return ""


def _build_dump(batch: Dict[str, Any], batch_items: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Assemble the JSON-friendly view of a collated batch.

    Events are not truncated. Keys are inserted in a fixed order so the
    resulting JSON layout is stable.

    Args:
        batch: Output of the collator for ``batch_items``.
        batch_items: Raw dataset items that were collated.

    Returns:
        A dict ready to be passed to ``json.dump`` (with ``_json_default``).
    """
    dump: Dict[str, Any] = {
        "batch_size": len(batch_items),
        "token_addresses": batch.get("token_addresses"),
        "t_cutoffs": batch.get("t_cutoffs"),
        "sample_indices": batch.get("sample_indices"),
        "raw_events": [item.get("event_sequence", []) for item in batch_items],
    }

    # Raw event type counts
    dump["raw_event_counts"] = [_count_event_types(item) for item in batch_items]

    # Core sequence + features (full length)
    dump["event_type_ids"] = _tensor_to_list(batch["event_type_ids"])
    dump["event_type_names"] = [
        _decode_events(row.cpu()) for row in batch["event_type_ids"]
    ]
    dump["timestamps_float"] = _tensor_to_list(batch["timestamps_float"])
    dump["relative_ts"] = _tensor_to_list(batch["relative_ts"])
    dump["attention_mask"] = _tensor_to_list(batch["attention_mask"])
    dump["wallet_addr_to_batch_idx"] = batch.get("wallet_addr_to_batch_idx", {})

    # Pointer tensors
    _dump_tensors(batch, _POINTER_KEYS, dump)

    # Numerical and categorical feature tensors (with non-zero counts)
    nonzero_summary: Dict[str, int] = {}
    _dump_tensors(batch, _NUMERICAL_FEATURE_KEYS, dump, nonzero_summary)
    _dump_tensors(batch, _CATEGORICAL_FEATURE_KEYS, dump, nonzero_summary)

    # Labels
    for key in _LABEL_KEYS:
        if batch.get(key) is not None:
            dump[key] = _tensor_to_list(batch[key])

    dump["nonzero_summary"] = nonzero_summary

    # Raw wallet/token feature payloads used by encoders
    wallet_inputs = batch.get("wallet_encoder_inputs", {})
    token_inputs = batch.get("token_encoder_inputs", {})

    dump["wallet_encoder_inputs"] = {
        "profile_rows": wallet_inputs.get("profile_rows", []),
        "social_rows": wallet_inputs.get("social_rows", []),
        "holdings_batch": wallet_inputs.get("holdings_batch", []),
        "username_embed_indices": _optional_tensor_list(wallet_inputs, "username_embed_indices"),
    }
    dump["token_encoder_inputs"] = {
        "addresses_for_lookup": token_inputs.get("_addresses_for_lookup", []),
        "protocol_ids": _optional_tensor_list(token_inputs, "protocol_ids"),
        "is_vanity_flags": _optional_tensor_list(token_inputs, "is_vanity_flags"),
        "name_embed_indices": _optional_tensor_list(token_inputs, "name_embed_indices"),
        "symbol_embed_indices": _optional_tensor_list(token_inputs, "symbol_embed_indices"),
        "image_embed_indices": _optional_tensor_list(token_inputs, "image_embed_indices"),
    }
    dump["wallet_set_encoder_inputs"] = {
        "holdings_batch": wallet_inputs.get("holdings_batch", []),
        "token_vibe_lookup_keys": token_inputs.get("_addresses_for_lookup", []),
    }
    return dump


def main() -> None:
    """CLI entry point: collate the requested cached samples and write a JSON dump.

    Arguments (command line):
        --cache_dir: Directory of cached samples (default ``data/cache``).
        --idx: One or more sample indices to inspect (default ``0``).
        --max_seq_len: Passed to both the dataset and the collator.
        --out: Output JSON path (default ``collator_dump.json``).

    Database connection settings are read from the environment (after
    loading a ``.env`` file) with localhost defaults. The ClickHouse client
    and Neo4j driver are always constructed and are released on exit.
    """
    parser = argparse.ArgumentParser(description="Inspect MemecoinCollator outputs on cached samples.")
    parser.add_argument("--cache_dir", type=str, default="data/cache")
    parser.add_argument("--idx", type=int, nargs="+", default=[0], help="Sample indices to inspect")
    parser.add_argument("--max_seq_len", type=int, default=16000)
    parser.add_argument("--out", type=str, default="collator_dump.json")
    args = parser.parse_args()

    cache_dir = Path(args.cache_dir)

    # Third-party DB clients are imported lazily so `--help` works without them.
    from dotenv import load_dotenv
    from clickhouse_driver import Client as ClickHouseClient
    from neo4j import GraphDatabase

    load_dotenv()

    clickhouse_host = os.getenv("CLICKHOUSE_HOST", "localhost")
    # CLICKHOUSE_NATIVE_PORT takes precedence over CLICKHOUSE_PORT.
    clickhouse_port = int(os.getenv("CLICKHOUSE_NATIVE_PORT", os.getenv("CLICKHOUSE_PORT", "9000")))
    neo4j_uri = os.getenv("NEO4J_URI", "bolt://localhost:7687")
    neo4j_user = os.getenv("NEO4J_USER", "neo4j")
    neo4j_password = os.getenv("NEO4J_PASSWORD", "password")

    clickhouse_client = ClickHouseClient(host=clickhouse_host, port=clickhouse_port)
    neo4j_driver = GraphDatabase.driver(neo4j_uri, auth=(neo4j_user, neo4j_password))

    try:
        data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)

        dataset = OracleDataset(
            data_fetcher=data_fetcher,
            cache_dir=str(cache_dir),
            horizons_seconds=[30, 60, 120, 240, 420],
            quantiles=[0.1, 0.5, 0.9],
            max_samples=None,
            max_seq_len=args.max_seq_len,
        )

        if hasattr(dataset, "init_fetcher"):
            dataset.init_fetcher()

        collator = MemecoinCollator(
            event_type_to_id=vocab.EVENT_TO_ID,
            device=torch.device("cpu"),
            dtype=torch.float32,
            max_seq_len=args.max_seq_len,
        )

        batch_items = [dataset[i] for i in args.idx]
        batch = collator(batch_items)

        # Build JSON-friendly dump (no truncation of events)
        dump = _build_dump(batch, batch_items)

        out_path = Path(args.out)
        # Allow --out to point into a directory that does not exist yet.
        out_path.parent.mkdir(parents=True, exist_ok=True)
        with out_path.open("w", encoding="utf-8") as f:
            json.dump(dump, f, indent=2, default=_json_default)

        print(f"Wrote collator dump to {out_path.resolve()}")
    finally:
        # Release DB connections on both success and failure; the nested
        # finally guarantees the ClickHouse client is disconnected even if
        # closing the Neo4j driver raises.
        try:
            neo4j_driver.close()
        finally:
            clickhouse_client.disconnect()


if __name__ == "__main__":
    main()