"""Live inference / ablation evaluation for the Oracle model.

Builds contexts on the fly from ClickHouse + Neo4j, runs the model on them,
and optionally measures how much predictions move when signal families are
ablated (``--ablation``) or the OHLC series is perturbed (``ohlc_probe``).
"""
import argparse
import copy
import math
import os
import random
import sys
from pathlib import Path

import torch

# Add project root to sys.path so we can import data and models
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Provide standard defaults
from accelerate import Accelerator
from torch.utils.data import DataLoader, Subset
from data.data_loader import OracleDataset
from data.data_collator import MemecoinCollator
from data.context_targets import MOVEMENT_ID_TO_CLASS
from models.multi_modal_processor import MultiModalEncoder
from models.helper_encoders import ContextualTimeEncoder
from models.token_encoder import TokenEncoder
from models.wallet_encoder import WalletEncoder
from models.graph_updater import GraphUpdater
from models.ohlc_embedder import OHLCEmbedder
from models.quant_ohlc_embedder import QuantOHLCEmbedder
from models.model import Oracle
import models.vocabulary as vocab
from data.quant_ohlc_feature_schema import FEATURE_GROUPS, NUM_QUANT_OHLC_FEATURES, TOKENS_PER_SEGMENT, group_feature_indices
from train import create_balanced_split
from dotenv import load_dotenv
from clickhouse_driver import Client as ClickHouseClient
from neo4j import GraphDatabase
from data.data_fetcher import DataFetcher
from scripts.analyze_distribution import get_return_class_map


# Ablation families ranked by ``--ablation sweep``.
ABLATION_SWEEP_MODES = [
    "wallet",
    "graph",
    "social",
    "token",
    "holder",
    "ohlc",
    "ohlc_wallet",
    "trade",
    "onchain",
    "wallet_graph",
    "quant_ohlc",
    "quant_levels",
    "quant_trendline",
    "quant_breaks",
    "quant_rolling",
]

# OHLC perturbations evaluated by ``--ablation ohlc_probe``.
OHLC_PROBE_MODES = [
    "ohlc_reverse",
    "ohlc_shuffle_chunks",
    "ohlc_mask_recent",
    "ohlc_trend_only",
    "ohlc_summary_shuffle",
    "ohlc_detrend",
    "ohlc_smooth",
]

# Quant-OHLC ablation modes -> feature groups (names passed to
# ``group_feature_indices``) whose columns are zeroed.
_QUANT_ABLATION_GROUPS = {
    "quant_ohlc": list(FEATURE_GROUPS.keys()),
    "quant_levels": ["levels_breaks"],
    "quant_trendline": ["trendlines"],
    "quant_breaks": ["relative_structure", "levels_breaks"],
    "quant_rolling": ["rolling_quant"],
}


def unlog_transform(tensor):
    """Invert the signed log1p transform applied to labels during training.

    Training used ``labels = sign(x) * log1p(|x|)``; this returns
    ``sign(y) * (exp(|y|) - 1)``, computed with ``expm1`` for accuracy near 0.
    """
    return torch.sign(tensor) * torch.expm1(torch.abs(tensor))


def parse_args():
    """Parse the command-line options for this evaluation script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", type=str, default="checkpoints/checkpoint-90000", help="Path to checkpoint dir")
    parser.add_argument("--sample_idx", type=str, default=None, help="Specific sample index or Mint Address to evaluate")
    parser.add_argument("--mixed_precision", type=str, default="bf16")
    parser.add_argument("--horizons_seconds", type=int, nargs="+", default=[300, 900, 1800, 3600, 7200])
    parser.add_argument("--quantiles", type=float, nargs="+", default=[0.1, 0.5, 0.9])
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--min_class", type=int, default=3, help="Filter out tokens with return class beneath this ID (e.g., 1 for >= 3x returns)")
    parser.add_argument("--cutoff_trade_idx", type=int, default=200, help="Force the T_cutoff at this exact trade index (e.g., 10 = right after the 10th trade)")
    parser.add_argument("--num_samples", type=int, default=1, help="Number of valid samples to evaluate and aggregate.")
    parser.add_argument("--max_retries", type=int, default=100, help="Maximum attempts to find valid contexts across samples.")
    parser.add_argument("--show_each", action="store_true", help="Print per-sample details for every evaluated sample.")
    parser.add_argument(
        "--ablation",
        type=str,
        default="none",
        choices=[
            "none", "wallet", "graph", "wallet_graph", "social", "token", "holder",
            "ohlc", "ohlc_wallet", "trade", "onchain", "all", "sweep", "ohlc_probe",
            "quant_ohlc", "quant_levels", "quant_trendline", "quant_breaks", "quant_rolling",
        ],
        help="Run inference with selected signal families removed, or use 'sweep' to rank multiple families.",
    )
    return parser.parse_args()


def clone_batch(batch):
    """Return a copy of ``batch``: tensors are cloned, everything else deep-copied."""
    cloned = {}
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            cloned[key] = value.clone()
        else:
            cloned[key] = copy.deepcopy(value)
    return cloned


def _empty_wallet_encoder_inputs(device):
    """Wallet-encoder input dict with no wallets (used when ablating wallets)."""
    return {
        'username_embed_indices': torch.tensor([], device=device, dtype=torch.long),
        'profile_rows': [],
        'social_rows': [],
        'holdings_batch': [],
    }


def _empty_token_encoder_inputs(device):
    """Token-encoder input dict with no tokens (used when ablating tokens)."""
    return {
        'name_embed_indices': torch.tensor([], device=device, dtype=torch.long),
        'symbol_embed_indices': torch.tensor([], device=device, dtype=torch.long),
        'image_embed_indices': torch.tensor([], device=device, dtype=torch.long),
        'protocol_ids': torch.tensor([], device=device, dtype=torch.long),
        'is_vanity_flags': torch.tensor([], device=device, dtype=torch.bool),
        '_addresses_for_lookup': [],
    }


def apply_ablation(batch, mode, device):
    """Return a copy of ``batch`` with the signal family ``mode`` removed.

    ``mode == "none"`` returns ``batch`` itself (no copy). ``"all"`` removes
    every family handled here, including the on-chain snapshot features.
    Index tensors are zeroed, raw-data lists/dicts are emptied, and numeric /
    categorical feature tensors are replaced with zeros. Keys missing from the
    batch are skipped.
    """
    if mode == "none":
        return batch
    ablated = clone_batch(batch)

    if mode in {"wallet", "wallet_graph", "ohlc_wallet", "all"}:
        for key in ("wallet_indices", "dest_wallet_indices", "original_author_indices", "holder_snapshot_indices"):
            if key in ablated:
                ablated[key].zero_()
        ablated["wallet_encoder_inputs"] = _empty_wallet_encoder_inputs(device)
        ablated["wallet_addr_to_batch_idx"] = {}
        ablated["holder_snapshot_raw_data"] = []
        ablated["graph_updater_links"] = {}

    if mode in {"graph", "wallet_graph", "all"}:
        ablated["graph_updater_links"] = {}

    if mode in {"social", "all"}:
        if "textual_event_indices" in ablated:
            ablated["textual_event_indices"].zero_()
        ablated["textual_event_data"] = []

    if mode in {"token", "all"}:
        for key in ("token_indices", "quote_token_indices", "trending_token_indices", "boosted_token_indices"):
            if key in ablated:
                ablated[key].zero_()
        ablated["token_encoder_inputs"] = _empty_token_encoder_inputs(device)

    if mode in {"holder", "all"}:
        if "holder_snapshot_indices" in ablated:
            ablated["holder_snapshot_indices"].zero_()
        ablated["holder_snapshot_raw_data"] = []

    if mode in {"ohlc", "ohlc_wallet", "all"}:
        if "ohlc_indices" in ablated:
            ablated["ohlc_indices"].zero_()
        for key in ("ohlc_price_tensors", "ohlc_interval_ids", "quant_ohlc_feature_tensors", "quant_ohlc_feature_mask"):
            if key in ablated:
                ablated[key] = torch.zeros_like(ablated[key])

    quant_groups = _QUANT_ABLATION_GROUPS.get(mode)
    if quant_groups is not None and "quant_ohlc_feature_tensors" in ablated:
        idxs = group_feature_indices(quant_groups)
        if idxs:
            # Zero only the selected feature columns (last axis).
            ablated["quant_ohlc_feature_tensors"][:, :, idxs] = 0

    if mode in {"trade", "all"}:
        numerical_keys = (
            "trade_numerical_features",
            "deployer_trade_numerical_features",
            "smart_wallet_trade_numerical_features",
            "transfer_numerical_features",
            "pool_created_numerical_features",
            "liquidity_change_numerical_features",
            "fee_collected_numerical_features",
            "token_burn_numerical_features",
            "supply_lock_numerical_features",
            "boosted_token_numerical_features",
            "trending_token_numerical_features",
            "dexboost_paid_numerical_features",
            "global_trending_numerical_features",
            "chainsnapshot_numerical_features",
            "lighthousesnapshot_numerical_features",
            "dexprofile_updated_flags",
        )
        categorical_keys = (
            "trade_dex_ids",
            "trade_direction_ids",
            "trade_mev_protection_ids",
            "trade_is_bundle_ids",
            "pool_created_protocol_ids",
            "liquidity_change_type_ids",
            "trending_token_source_ids",
            "trending_token_timeframe_ids",
            "lighthousesnapshot_protocol_ids",
            "lighthousesnapshot_timeframe_ids",
            "migrated_protocol_ids",
            "alpha_group_ids",
            "channel_ids",
            "exchange_ids",
        )
        for key in numerical_keys + categorical_keys:
            if key in ablated:
                ablated[key] = torch.zeros_like(ablated[key])

    # Previously checked with ``mode == "onchain"``, which left on-chain
    # features intact under "all" unlike every other family.
    if mode in {"onchain", "all"}:
        if "onchain_snapshot_numerical_features" in ablated:
            ablated["onchain_snapshot_numerical_features"] = torch.zeros_like(ablated["onchain_snapshot_numerical_features"])

    return ablated


def _chunk_permutation_indices(length, chunk_size):
    """Indices ``0..length-1`` with consecutive ``chunk_size`` chunks in reverse order.

    Order inside each chunk is kept. Returns ``[]`` for ``length <= 0`` and
    the identity order when there is only one chunk.
    """
    if length <= 0:
        return []
    chunks = [list(range(i, min(i + chunk_size, length))) for i in range(0, length, chunk_size)]
    if len(chunks) <= 1:
        return list(range(length))
    return [idx for chunk in reversed(chunks) for idx in chunk]


def _moving_average_1d(series, kernel_size):
    """Centered moving average of a 1-D tensor with edge-replicate padding.

    Output has the same length as ``series``; ``kernel_size <= 1`` or an
    empty series returns ``series`` unchanged.
    """
    if kernel_size <= 1 or series.numel() == 0:
        return series
    pad = kernel_size // 2
    kernel = torch.ones(1, 1, kernel_size, device=series.device, dtype=series.dtype) / float(kernel_size)
    x = series.reshape(1, 1, -1)
    x = torch.nn.functional.pad(x, (pad, pad), mode="replicate")
    smoothed = torch.nn.functional.conv1d(x, kernel)
    # Even kernels produce one extra sample; trim back to the input length.
    return smoothed.view(-1)[: series.numel()]


def _linear_trend(series):
    """Straight line from ``series[0]`` to ``series[-1]`` over the same length."""
    if series.numel() <= 1:
        return series.clone()
    start = series[0]
    end = series[-1]
    steps = torch.linspace(0.0, 1.0, series.numel(), device=series.device, dtype=series.dtype)
    return start + (end - start) * steps


def _summary_preserving_shuffle(series, chunk_size=20):
    """Reverse the order of interior chunks while keeping the first/last values fixed.

    The multiset of values and both endpoints are unchanged; series of length
    <= 2 or with a single interior chunk are returned as-is.
    """
    length = series.numel()
    if length <= 2:
        return series
    interior_end = length - 1
    chunks = [series[i:min(i + chunk_size, interior_end)].clone() for i in range(1, interior_end, chunk_size)]
    if len(chunks) <= 1:
        return series
    out = series.clone()
    cursor = 1
    for chunk in reversed(chunks):
        out[cursor:cursor + chunk.numel()] = chunk
        cursor += chunk.numel()
    out[0] = series[0]
    out[-1] = series[-1]
    return out


def _apply_per_series(ohlc, transform_fn):
    """Apply ``transform_fn`` to every ``ohlc[b, c]`` 1-D series of a cloned tensor.

    Assumes ``ohlc`` is rank 3 with time on the last axis (indexed as
    ``[batch, channel]``) — TODO confirm against the collator.
    """
    out = ohlc.clone()
    for batch_idx in range(out.shape[0]):
        for channel_idx in range(out.shape[1]):
            out[batch_idx, channel_idx] = transform_fn(out[batch_idx, channel_idx])
    return out


def apply_ohlc_probe(batch, mode):
    """Return a copy of ``batch`` whose ``ohlc_price_tensors`` are perturbed per ``mode``.

    The probes (see ``OHLC_PROBE_MODES``) reverse time, reorder chunks, hold
    the last 60 steps flat, replace with a linear trend, shuffle interior
    chunks, detrend, or smooth. Batches with no/empty OHLC tensors are
    returned as a plain copy; unknown modes leave the OHLC tensor unchanged.
    """
    probed = clone_batch(batch)
    if "ohlc_price_tensors" not in probed or probed["ohlc_price_tensors"].numel() == 0:
        return probed

    ohlc = probed["ohlc_price_tensors"].clone()
    seq_len = ohlc.shape[-1]

    if mode == "ohlc_reverse":
        probed["ohlc_price_tensors"] = torch.flip(ohlc, dims=[-1])
    elif mode == "ohlc_shuffle_chunks":
        perm = _chunk_permutation_indices(seq_len, chunk_size=30)
        idx = torch.tensor(perm, device=ohlc.device, dtype=torch.long)
        probed["ohlc_price_tensors"] = ohlc.index_select(-1, idx)
    elif mode == "ohlc_mask_recent":
        keep = max(seq_len - 60, 0)
        if 0 < keep < seq_len:
            # Hold the last kept value flat over the masked tail.
            ohlc[..., keep:] = ohlc[..., keep - 1:keep].expand_as(ohlc[..., keep:])
        elif keep == 0:
            ohlc.zero_()
        probed["ohlc_price_tensors"] = ohlc
    elif mode == "ohlc_trend_only":
        probed["ohlc_price_tensors"] = _apply_per_series(ohlc, _linear_trend)
    elif mode == "ohlc_summary_shuffle":
        probed["ohlc_price_tensors"] = _apply_per_series(
            ohlc,
            lambda series: _summary_preserving_shuffle(series, chunk_size=20),
        )
    elif mode == "ohlc_detrend":
        def detrend(series):
            # Remove the endpoint-to-endpoint line, anchored at the first value.
            detrended = series - _linear_trend(series) + series[0]
            detrended[0] = series[0]
            detrended[-1] = series[0]
            return detrended

        probed["ohlc_price_tensors"] = _apply_per_series(ohlc, detrend)
    elif mode == "ohlc_smooth":
        probed["ohlc_price_tensors"] = _apply_per_series(
            ohlc,
            lambda series: _moving_average_1d(series, kernel_size=11),
        )
    return probed


def run_inference(model, batch):
    """Run ``model`` on ``batch`` without gradients and return sample 0's outputs.

    Returns ``(quantile_logits, quality_logits or None, movement_logits or None)``
    as float32 CPU tensors. Casting to float32 keeps the later ``exp``/softmax
    from being computed in bf16/fp16; the cast itself is exact.
    """
    with torch.no_grad():
        outputs = model(batch)

    def _first(key):
        return outputs[key][0].detach().float().cpu() if key in outputs else None

    preds = outputs["quantile_logits"][0].detach().float().cpu()
    return preds, _first("quality_logits"), _first("movement_logits")


def _quantile_label(q):
    """Percent label for quantile ``q`` (rounded, so 0.29 -> 29 rather than 28)."""
    return int(round(q * 100))


def print_results(title, batch, preds, quality_pred, movement_pred, gt_labels, gt_mask, gt_quality, horizons_seconds, quantiles, reference_preds=None, reference_quality=None):
    """Print quality, movement-class and per-horizon quantile predictions for one sample.

    ``preds`` is a flat log-space vector indexed ``h_idx * len(quantiles) + q_idx``.
    When ``reference_preds`` / ``reference_quality`` are given, deltas against
    them ("vs Full") are appended to each line.
    """
    real_preds = unlog_transform(preds)
    # Hoisted: the reference curve is the same for every (horizon, quantile) cell.
    real_reference = unlog_transform(reference_preds) if reference_preds is not None else None
    num_quantiles = len(quantiles)
    num_gt_horizons = len(gt_mask)

    print(f"\n================== {title} ==================")
    print(f"Token Address: {batch.get('token_addresses', ['Unknown'])[0]}")

    if gt_quality is not None:
        pred_txt = quality_pred.item() if quality_pred is not None else 'N/A'
        quality_line = f"Quality Score: GT = {gt_quality:.4f} | Pred = {pred_txt}"
        if reference_quality is not None and quality_pred is not None:
            quality_delta = quality_pred.item() - reference_quality.item()
            quality_line += f" | Delta vs Full = {quality_delta:+.6f}"
        print(quality_line)

    if movement_pred is not None:
        movement_targets = batch.get("movement_class_targets")
        movement_mask = batch.get("movement_class_mask")
        print("Movement Classes:")
        for h_idx, horizon in enumerate(horizons_seconds):
            if h_idx >= movement_pred.shape[0]:
                break
            has_target = (
                movement_targets is not None
                and movement_mask is not None
                and h_idx < movement_mask.shape[1]
                and bool(movement_mask[0, h_idx].item())
            )
            target_txt = "N/A"
            if has_target:
                target_txt = MOVEMENT_ID_TO_CLASS.get(int(movement_targets[0, h_idx].item()), "unknown")
            pred_class = int(movement_pred[h_idx].argmax().item())
            pred_name = MOVEMENT_ID_TO_CLASS.get(pred_class, "unknown")
            pred_prob = float(torch.softmax(movement_pred[h_idx], dim=-1)[pred_class].item())
            print(
                f" {horizon:>4}s GT = {target_txt:<12} | "
                f"Pred = {pred_name:<12} | "
                f"Conf = {pred_prob:.4f}"
            )

    if "context_class_name" in batch:
        print(f"Context Class: {batch['context_class_name'][0]}")

    print("\nReturns per Horizon:")
    for h_idx, horizon in enumerate(horizons_seconds):
        horizon_min = horizon // 60
        print(f"\n--- Horizon: {horizon}s ({horizon_min}m) ---")
        # Exactly one ground-truth line per horizon.
        if h_idx >= num_gt_horizons:
            print(" [No Ground Truth Available for this Horizon - Not in Dataset]")
        elif not gt_mask[h_idx].item():
            print(" [No Ground Truth Available for this Horizon - Masked]")
        else:
            gt_ret = gt_labels[h_idx].item()
            print(f" Ground Truth: {gt_ret * 100:.2f}%")

        print(" Predictions:")
        for q_idx, q in enumerate(quantiles):
            flat_idx = h_idx * num_quantiles + q_idx
            pred_ret = real_preds[flat_idx].item()
            log_pred = preds[flat_idx].item()
            line = f" - p{_quantile_label(q):02d}: {pred_ret * 100:>8.2f}% (raw log-val: {log_pred:7.4f})"
            if real_reference is not None:
                ref_ret = real_reference[flat_idx].item()
                line += f" | Delta vs Full: {(pred_ret - ref_ret) * 100:+7.2f}%"
            print(line)
    print("=============================================\n")


def resolve_sample_index(dataset, sample_idx_arg, rng):
    """Map ``--sample_idx`` (index or mint address) to a dataset index, or draw one with ``rng``.

    Raises:
        ValueError: the mint address is not in ``dataset.sampled_mints`` or the
            index is ``>= len(dataset)``.
    """
    if sample_idx_arg is not None:
        if isinstance(sample_idx_arg, str) and not sample_idx_arg.isdigit():
            found_idx = next((i for i, m in enumerate(dataset.sampled_mints) if m['mint_address'] == sample_idx_arg), None)
            if found_idx is None:
                raise ValueError(f"Mint address {sample_idx_arg} not found in filtered dataset")
            return found_idx
        resolved = int(sample_idx_arg)
        if resolved >= len(dataset):
            raise ValueError(f"Sample index {resolved} out of range")
        return resolved
    return rng.randint(0, len(dataset.sampled_mints) - 1)


def move_batch_to_device(batch, device):
    """Move tensors (and lists of tensors) in ``batch`` to ``device`` in place.

    Also fills in empty ``textual_event_indices`` (zeros shaped like
    ``event_type_ids``) and ``textual_event_data`` when absent.
    """
    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            batch[k] = v.to(device)
        elif isinstance(v, list) and len(v) > 0 and isinstance(v[0], torch.Tensor):
            batch[k] = [t.to(device) for t in v]
    if 'textual_event_indices' not in batch:
        B, L = batch['event_type_ids'].shape
        batch['textual_event_indices'] = torch.zeros((B, L), dtype=torch.long, device=device)
    if 'textual_event_data' not in batch:
        batch['textual_event_data'] = []
    return batch


def init_aggregate(horizons_seconds, quantiles):
    """Create an empty running-sum accumulator keyed by ``(horizon, quantile)``."""
    return {
        "count": 0,
        "quality_full_sum": 0.0,
        "quality_abl_sum": 0.0,
        "quality_delta_sum": 0.0,
        "gt_quality_sum": 0.0,
        "per_hq": {
            (h, q): {
                "full_sum": 0.0,
                "abl_sum": 0.0,
                "delta_sum": 0.0,
                "abs_delta_sum": 0.0,
                "gt_sum": 0.0,
                "valid_count": 0,
            }
            for h in horizons_seconds
            for q in quantiles
        },
    }


def update_aggregate(stats, full_preds, gt_labels, gt_mask, gt_quality, horizons_seconds, quantiles, ablated_preds=None, full_quality=None, ablated_quality=None):
    """Add one sample's (unlogged) predictions and ground truth into ``stats`` in place.

    Ablated sums/deltas are only accumulated when ``ablated_preds`` is given;
    ground truth only for horizons whose mask is set.
    """
    stats["count"] += 1
    if gt_quality is not None:
        stats["gt_quality_sum"] += float(gt_quality)
    if full_quality is not None:
        stats["quality_full_sum"] += float(full_quality.item())
    if ablated_quality is not None:
        stats["quality_abl_sum"] += float(ablated_quality.item())
    if full_quality is not None and ablated_quality is not None:
        stats["quality_delta_sum"] += float(ablated_quality.item() - full_quality.item())

    full_real = unlog_transform(full_preds)
    ablated_real = unlog_transform(ablated_preds) if ablated_preds is not None else None
    num_quantiles = len(quantiles)

    for h_idx, horizon in enumerate(horizons_seconds):
        valid = h_idx < len(gt_mask) and bool(gt_mask[h_idx].item())
        gt_ret = float(gt_labels[h_idx].item()) if valid else math.nan
        for q_idx, q in enumerate(quantiles):
            flat_idx = h_idx * num_quantiles + q_idx
            bucket = stats["per_hq"][(horizon, q)]
            full_val = float(full_real[flat_idx].item())
            bucket["full_sum"] += full_val
            if ablated_real is not None:
                abl_val = float(ablated_real[flat_idx].item())
                delta = abl_val - full_val
                bucket["abl_sum"] += abl_val
                bucket["delta_sum"] += delta
                bucket["abs_delta_sum"] += abs(delta)
            if valid:
                bucket["gt_sum"] += gt_ret
                bucket["valid_count"] += 1


def print_aggregate_summary(stats, horizons_seconds, quantiles, ablation_mode):
    """Print means accumulated in ``stats``; ablated columns appear unless mode is ``"none"``."""
    n = stats["count"]
    print("\n================== Aggregate Summary ==================")
    print(f"Evaluated Samples: {n}")
    if n == 0:
        print("No valid samples collected.")
        print("=======================================================\n")
        return

    if ablation_mode != "none":
        print(
            f"Quality Mean: full={stats['quality_full_sum'] / n:.6f} | "
            f"ablated={stats['quality_abl_sum'] / n:.6f} | "
            f"delta={stats['quality_delta_sum'] / n:+.6f}"
        )

    for horizon in horizons_seconds:
        horizon_min = horizon // 60
        print(f"\n--- Horizon: {horizon}s ({horizon_min}m) ---")
        valid_counts = [stats["per_hq"][(horizon, q)]["valid_count"] for q in quantiles]
        valid_count = max(valid_counts) if valid_counts else 0
        if valid_count > 0:
            gt_mean = stats["per_hq"][(horizon, quantiles[0])]["gt_sum"] / valid_count
            print(f" Mean Ground Truth over valid labels: {gt_mean * 100:.2f}% (n={valid_count})")
        else:
            print(" Mean Ground Truth over valid labels: N/A")

        for q in quantiles:
            bucket = stats["per_hq"][(horizon, q)]
            full_mean = bucket["full_sum"] / n
            line = f" p{_quantile_label(q):02d} mean full: {full_mean * 100:>8.2f}%"
            if ablation_mode != "none":
                abl_mean = bucket["abl_sum"] / n
                delta_mean = bucket["delta_sum"] / n
                abs_delta_mean = bucket["abs_delta_sum"] / n
                line += (
                    f" | ablated: {abl_mean * 100:>8.2f}%"
                    f" | delta: {delta_mean * 100:+8.2f}%"
                    f" | mean|delta|: {abs_delta_mean * 100:>8.2f}%"
                )
            print(line)
    print("=======================================================\n")


def summarize_influence_score(stats, horizons_seconds, quantiles):
    """Mean absolute prediction delta, averaged over all (horizon, quantile) cells."""
    n = stats["count"]
    if n == 0:
        return 0.0
    total = 0.0
    denom = 0
    for horizon in horizons_seconds:
        for q in quantiles:
            total += stats["per_hq"][(horizon, q)]["abs_delta_sum"] / n
            denom += 1
    return total / max(denom, 1)


def print_probe_summary(mode_to_stats, horizons_seconds, quantiles):
    """Rank OHLC probes by influence score, then print each probe's aggregate summary."""
    rankings = [
        (mode, summarize_influence_score(mode_to_stats[mode], horizons_seconds, quantiles))
        for mode in OHLC_PROBE_MODES
    ]
    rankings.sort(key=lambda x: x[1], reverse=True)
    print("\n================== OHLC Probe Ranking ==================")
    for rank, (mode, score) in enumerate(rankings, start=1):
        print(f"{rank:>2}. {mode:<20} mean|delta| = {score * 100:8.2f}%")
    print("========================================================\n")
    for mode, _ in rankings:
        print_aggregate_summary(mode_to_stats[mode], horizons_seconds, quantiles, mode)


def get_latest_checkpoint(checkpoint_dir):
    """Return the most recently modified sub-directory of ``checkpoint_dir`` (or ``None``)."""
    ckpt_dir = Path(checkpoint_dir)
    if ckpt_dir.exists():
        dirs = [d for d in ckpt_dir.iterdir() if d.is_dir()]
        if dirs:
            dirs.sort(key=lambda x: x.stat().st_mtime)
            return str(dirs[-1])
    return None


def _resolve_init_dtype(accelerator):
    """Parameter dtype matching the accelerator's mixed-precision setting."""
    if accelerator.mixed_precision == 'bf16':
        return torch.bfloat16
    if accelerator.mixed_precision == 'fp16':
        return torch.float16
    return torch.float32


def _connect_databases():
    """Create ClickHouse and Neo4j clients from environment variables (with local defaults)."""
    clickhouse_host = os.getenv("CLICKHOUSE_HOST", "localhost")
    clickhouse_port = int(os.getenv("CLICKHOUSE_PORT", 9000))
    neo4j_uri = os.getenv("NEO4J_URI", "bolt://localhost:7687")
    neo4j_user = os.getenv("NEO4J_USER", "neo4j")
    neo4j_password = os.getenv("NEO4J_PASSWORD", "password")

    clickhouse_client = ClickHouseClient(host=clickhouse_host, port=clickhouse_port)
    neo4j_driver = GraphDatabase.driver(neo4j_uri, auth=(neo4j_user, neo4j_password))
    return clickhouse_client, neo4j_driver


def _filter_dataset_by_return_class(dataset, clickhouse_client, min_class):
    """Keep only classified, non-manipulated tokens with class ``>= min_class`` (in place).

    Raises:
        ValueError: if no tokens survive the filter.
    """
    from models.vocabulary import MANIPULATED_CLASS_ID

    print("INFO: Fetching Return Classification Map...")
    return_class_map, _ = get_return_class_map(clickhouse_client)
    min_class_thresh = min_class if min_class is not None else 0

    def _keep(mint):
        cls = return_class_map.get(mint['mint_address'])
        return cls is not None and cls != MANIPULATED_CLASS_ID and cls >= min_class_thresh

    original_len = len(dataset.sampled_mints)
    dataset.sampled_mints = [m for m in dataset.sampled_mints if _keep(m)]
    dataset.num_samples = len(dataset.sampled_mints)
    print(f"INFO: Filtered tokens. {original_len} -> {len(dataset.sampled_mints)} valid tokens (class >= {min_class_thresh}).")

    if len(dataset) == 0:
        raise ValueError("Dataset is empty. Are ClickHouse data and trade pipelines populated? (Check if min_return filtered everything out)")


def _build_model_and_collator(args, device, init_dtype):
    """Construct encoders, collator and the Oracle model.

    Returns ``(model, multi_modal_encoder, collator)``; the multi-modal encoder
    is also needed to compile contexts.
    """
    print("Initializing encoders...")
    multi_modal_encoder = MultiModalEncoder(dtype=init_dtype, device=device)
    time_encoder = ContextualTimeEncoder(dtype=init_dtype)
    token_encoder = TokenEncoder(multi_dim=multi_modal_encoder.embedding_dim, dtype=init_dtype)
    wallet_encoder = WalletEncoder(encoder=multi_modal_encoder, dtype=init_dtype)
    graph_updater = GraphUpdater(time_encoder=time_encoder, dtype=init_dtype)
    ohlc_embedder = OHLCEmbedder(num_intervals=vocab.NUM_OHLC_INTERVALS, dtype=init_dtype)
    quant_ohlc_embedder = QuantOHLCEmbedder(
        num_features=NUM_QUANT_OHLC_FEATURES,
        sequence_length=TOKENS_PER_SEGMENT,
        dtype=init_dtype,
    )
    collator = MemecoinCollator(
        event_type_to_id=vocab.EVENT_TO_ID,
        device=device,
        dtype=init_dtype,
        max_seq_len=4096,
    )

    print("Initializing model...")
    model = Oracle(
        token_encoder=token_encoder,
        wallet_encoder=wallet_encoder,
        graph_updater=graph_updater,
        ohlc_embedder=ohlc_embedder,
        quant_ohlc_embedder=quant_ohlc_embedder,
        time_encoder=time_encoder,
        num_event_types=vocab.NUM_EVENT_TYPES,
        multi_modal_dim=multi_modal_encoder.embedding_dim,
        event_pad_id=vocab.EVENT_TO_ID["__PAD__"],
        event_type_to_id=vocab.EVENT_TO_ID,
        model_config_name="llama3-12l-768d-gqa4-8k-random",
        quantiles=args.quantiles,
        horizons_seconds=args.horizons_seconds,
        dtype=init_dtype,
    )
    if hasattr(model.model, 'embed_tokens'):
        del model.model.embed_tokens
    return model, multi_modal_encoder, collator


def _load_checkpoint(accelerator, model, checkpoint_arg):
    """Prepare ``model`` with ``accelerator`` and load weights from ``checkpoint_arg``.

    A path ending in ``latest`` resolves to the newest sibling checkpoint dir.
    Tries ``accelerator.load_state`` first, then ``pytorch_model.bin`` /
    ``model.safetensors``. If nothing loads, a warning/error is printed and
    the model keeps random weights (no exception).

    Returns:
        The accelerator-prepared model.
    """
    ckpt_path = checkpoint_arg
    if ckpt_path.endswith("latest"):
        found = get_latest_checkpoint(Path(ckpt_path).parent)
        if found:
            ckpt_path = found

    if not os.path.exists(ckpt_path):
        print(f"Warning: Checkpoint {ckpt_path} not found. Running with random weights!")
        return accelerator.prepare(model)

    print(f"Loading checkpoint from {ckpt_path}...")
    model = accelerator.prepare(model)
    try:
        accelerator.load_state(ckpt_path)
        print("Successfully loaded accelerator state.")
        return model
    except Exception as e:  # best-effort: fall back to raw weight files below
        print(f"Could not load using accelerate.load_state: {e}")

    print("Trying to load model weights directly...")
    model_file = os.path.join(ckpt_path, "pytorch_model.bin")
    if not os.path.exists(model_file):
        model_file = os.path.join(ckpt_path, "model.safetensors")
    if not os.path.exists(model_file):
        print(f"Error: model weights not found in {ckpt_path}")
        return model

    if model_file.endswith(".safetensors"):
        from safetensors.torch import load_file
        state_dict = load_file(model_file)
    else:
        # NOTE: torch.load unpickles the file; only use trusted checkpoints.
        state_dict = torch.load(model_file, map_location="cpu")
    load_result = accelerator.unwrap_model(model).load_state_dict(state_dict, strict=False)
    print("Successfully loaded weights directly.")
    # strict=False silently tolerates mismatches; surface them.
    if load_result.missing_keys or load_result.unexpected_keys:
        print(
            f"Warning: non-strict load had {len(load_result.missing_keys)} missing and "
            f"{len(load_result.unexpected_keys)} unexpected keys."
        )
    return model


def _selected_ablation_modes(ablation):
    """Ablation families to run per sample for the ``--ablation`` value."""
    if ablation in ("none", "ohlc_probe"):
        return []
    if ablation == "sweep":
        return list(ABLATION_SWEEP_MODES)
    return [ablation]


def _print_sample_report(args, batch, gt_labels, gt_mask, gt_quality, full_preds, full_quality, full_direction, ablation_outputs, probe_outputs):
    """Print the full-model results and, depending on ``--ablation``, the variant results."""
    common = dict(
        gt_labels=gt_labels,
        gt_mask=gt_mask,
        gt_quality=gt_quality,
        horizons_seconds=args.horizons_seconds,
        quantiles=args.quantiles,
    )
    print_results(title="Full Results", batch=batch, preds=full_preds, quality_pred=full_quality, movement_pred=full_direction, **common)

    if args.ablation == "none":
        return
    if args.ablation == "sweep":
        print(f"Collected full predictions for {len(ablation_outputs)} ablation families on this sample. Aggregate ranking will be printed at the end.")
        return

    if args.ablation == "ohlc_probe":
        reports = [(f"OHLC Probe ({mode})", probe_outputs[mode]) for mode in OHLC_PROBE_MODES]
    else:
        reports = [(f"Ablation Results ({args.ablation})", ablation_outputs[args.ablation])]

    for title, (variant_batch, variant_preds, variant_quality, variant_direction) in reports:
        print_results(
            title=title,
            batch=variant_batch,
            preds=variant_preds,
            quality_pred=variant_quality,
            movement_pred=variant_direction,
            reference_preds=full_preds,
            reference_quality=full_quality,
            **common,
        )


def _accumulate_sample(args, stats, mode_to_stats, probe_to_stats, full_preds, full_quality, gt_labels, gt_mask, gt_quality, ablation_outputs, probe_outputs):
    """Fold one sample's full, ablated and probed predictions into the aggregates."""
    common = dict(
        full_preds=full_preds,
        gt_labels=gt_labels,
        gt_mask=gt_mask,
        gt_quality=gt_quality,
        horizons_seconds=args.horizons_seconds,
        quantiles=args.quantiles,
    )
    update_aggregate(stats=stats, full_quality=full_quality, **common)
    for target_stats, outputs in ((mode_to_stats, ablation_outputs), (probe_to_stats, probe_outputs)):
        for mode, (_, variant_preds, variant_quality, _) in outputs.items():
            update_aggregate(
                stats=target_stats[mode],
                ablated_preds=variant_preds,
                full_quality=full_quality,
                ablated_quality=variant_quality,
                **common,
            )


def _print_final_summaries(args, stats, mode_to_stats, probe_to_stats, selected_modes):
    """Print the end-of-run aggregate summary appropriate to ``--ablation``."""
    if args.ablation == "none":
        print_aggregate_summary(stats, args.horizons_seconds, args.quantiles, args.ablation)
        return
    if args.ablation == "ohlc_probe":
        print_probe_summary(probe_to_stats, args.horizons_seconds, args.quantiles)
        return
    if args.ablation != "sweep":
        print_aggregate_summary(mode_to_stats[args.ablation], args.horizons_seconds, args.quantiles, args.ablation)
        return

    rankings = [
        (mode, summarize_influence_score(mode_to_stats[mode], args.horizons_seconds, args.quantiles))
        for mode in selected_modes
    ]
    rankings.sort(key=lambda x: x[1], reverse=True)
    print("\n================== Influence Ranking ==================")
    for rank, (mode, score) in enumerate(rankings, start=1):
        # Width 16 fits the longest sweep mode name ("quant_trendline").
        print(f"{rank:>2}. {mode:<16} mean|delta| = {score * 100:8.2f}%")
    print("=======================================================\n")
    for mode, _ in rankings:
        print_aggregate_summary(mode_to_stats[mode], args.horizons_seconds, args.quantiles, mode)


def _run_evaluation(args, rng, accelerator, device, init_dtype, clickhouse_client, neo4j_driver):
    """Build dataset and model, evaluate up to ``--num_samples`` tokens, print summaries."""
    data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)

    print("Loading live dataset generator...")
    # We inject the data fetcher directly. No cache directories are used.
    dataset = OracleDataset(
        data_fetcher=data_fetcher,
        fetcher_config=None,
        horizons_seconds=args.horizons_seconds,
        quantiles=args.quantiles,
        cache_dir=None,
    )
    # Filter out manipulated/broken tokens and optionally enforce min_class
    _filter_dataset_by_return_class(dataset, clickhouse_client, args.min_class)

    # Initialize encoders and model FIRST because we need multi_modal_encoder to compile context
    model, multi_modal_encoder, collator = _build_model_and_collator(args, device, init_dtype)
    model = _load_checkpoint(accelerator, model, args.checkpoint)
    model.eval()

    selected_modes = _selected_ablation_modes(args.ablation)
    stats = init_aggregate(args.horizons_seconds, args.quantiles)
    mode_to_stats = {mode: init_aggregate(args.horizons_seconds, args.quantiles) for mode in selected_modes}
    probe_to_stats = (
        {mode: init_aggregate(args.horizons_seconds, args.quantiles) for mode in OHLC_PROBE_MODES}
        if args.ablation == "ohlc_probe"
        else {}
    )

    max_target_samples = max(1, args.num_samples)
    retries = 0
    collected = 0
    seen_indices = set()

    # Every draw (successful or not) counts against --max_retries.
    while collected < max_target_samples and retries < args.max_retries:
        sample_idx = resolve_sample_index(dataset, args.sample_idx, rng)
        if args.sample_idx is None and sample_idx in seen_indices and len(seen_indices) < len(dataset.sampled_mints):
            # Random draw repeated a token while unseen tokens remain.
            retries += 1
            continue
        seen_indices.add(sample_idx)

        sample_mint_addr = dataset.sampled_mints[sample_idx]['mint_address']
        print(f"Trying Token Address: {sample_mint_addr}")
        contexts = dataset.__cacheitem_context__(
            sample_idx,
            num_samples_per_token=1,
            encoder=multi_modal_encoder,
            forced_cutoff_trade_idx=args.cutoff_trade_idx,
        )
        if not contexts or contexts[0] is None:
            print(" [Failed to generate valid context pattern, skipping...]")
            retries += 1
            if args.sample_idx is not None:
                print("Specific sample requested but failed to generate context. Exiting.")
                return
            continue

        batch = move_batch_to_device(collator([contexts[0]]), device)
        gt_labels = batch["labels"][0].cpu()
        gt_mask = batch["labels_mask"][0].cpu().bool()
        gt_quality = batch["quality_score"][0].item() if "quality_score" in batch else None

        show_details = collected == 0 or args.show_each
        if show_details:
            print(f"\nEvaluating sample {collected + 1}/{max_target_samples} on Token Address: {sample_mint_addr}")
            print("\n--- Running Inference ---")

        full_preds, full_quality, full_direction = run_inference(model, batch)

        # Each entry: (batch_used, preds, quality, movement_logits).
        ablation_outputs = {}
        for mode in selected_modes:
            ablated_batch = apply_ablation(batch, mode, device)
            ablation_outputs[mode] = (ablated_batch, *run_inference(model, ablated_batch))

        probe_outputs = {}
        if args.ablation == "ohlc_probe":
            for mode in OHLC_PROBE_MODES:
                probe_batch = apply_ohlc_probe(batch, mode)
                probe_outputs[mode] = (probe_batch, *run_inference(model, probe_batch))

        if show_details:
            _print_sample_report(
                args, batch, gt_labels, gt_mask, gt_quality,
                full_preds, full_quality, full_direction,
                ablation_outputs, probe_outputs,
            )

        _accumulate_sample(
            args, stats, mode_to_stats, probe_to_stats,
            full_preds, full_quality, gt_labels, gt_mask, gt_quality,
            ablation_outputs, probe_outputs,
        )

        collected += 1
        retries += 1
        if args.sample_idx is not None:
            break

    if collected == 0:
        print(f"Could not find a valid context after {args.max_retries} attempts.")
        return
    if collected < max_target_samples:
        print(f"WARNING: Requested {max_target_samples} samples but only evaluated {collected}.")

    _print_final_summaries(args, stats, mode_to_stats, probe_to_stats, selected_modes)


def main():
    """Script entry point: parse args, seed RNGs, connect to DBs and run the evaluation.

    Database connections are closed on every exit path, including early
    returns and exceptions.
    """
    load_dotenv()
    args = parse_args()

    rng = random.Random(args.seed)
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)

    accelerator = Accelerator(mixed_precision=args.mixed_precision)
    device = accelerator.device
    init_dtype = _resolve_init_dtype(accelerator)

    print("INFO: Initializing DB Connections for LIVE evaluation...")
    clickhouse_client, neo4j_driver = _connect_databases()
    try:
        _run_evaluation(args, rng, accelerator, device, init_dtype, clickhouse_client, neo4j_driver)
    finally:
        try:
            neo4j_driver.close()
        finally:
            clickhouse_client.disconnect()


if __name__ == "__main__":
    main()