# inference.py
"""End-to-end smoke test for the Oracle model.

Wires the real encoders, collator, dataset (backed by live ClickHouse and
Neo4j connections) and the Oracle model together, runs one collate + forward
pass over a batch of real samples, and sanity-checks every output tensor.
Replicates the test logic previously embedded in model.py.
"""
import atexit
import sys
import time
import traceback

import torch

# Import all the necessary components from our project
from models.model import Oracle
from data.data_collator import MemecoinCollator
from models.multi_modal_processor import MultiModalEncoder
from data.data_loader import OracleDataset
from data.data_fetcher import DataFetcher
from models.helper_encoders import ContextualTimeEncoder
from models.token_encoder import TokenEncoder
from models.wallet_encoder import WalletEncoder
from models.graph_updater import GraphUpdater
from models.ohlc_embedder import OHLCEmbedder
from models.quant_ohlc_embedder import QuantOHLCEmbedder
import models.vocabulary as vocab
from data.quant_ohlc_feature_schema import NUM_QUANT_OHLC_FEATURES, TOKENS_PER_SEGMENT

# --- NEW: Import database clients ---
from clickhouse_driver import Client as ClickHouseClient
from neo4j import GraphDatabase

# =============================================================================
# Inference/Test Script for the Oracle Model
# This script replicates the test logic previously in model.py
# =============================================================================

if __name__ == "__main__":
    print("--- Oracle Inference Script (Full Pipeline Test) ---")

    # --- 1. Define Configs ---
    OHLC_SEQ_LEN = 300
    print(f"Using {vocab.NUM_EVENT_TYPES} event types from vocabulary.")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Prefer bf16 on GPUs that support it, fall back to fp16 on older GPUs,
    # and always use fp32 on CPU (half precision is poorly supported there).
    dtype = (
        torch.bfloat16
        if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
        else torch.float16
    )
    if device.type == 'cpu':
        dtype = torch.float32
    print(f"Using device: {device}, dtype: {dtype}")
    _test_quantiles = [0.1, 0.5, 0.9]
    _test_horizons = [30, 60, 120, 240, 420]
    # Width of the quantile head: one logit per (horizon, quantile) pair.
    _test_num_outputs = len(_test_quantiles) * len(_test_horizons)

    # --- 2. Instantiate ALL Encoders ---
    print("Instantiating encoders (using defaults)...")
    try:
        multi_modal_encoder = MultiModalEncoder(dtype=dtype)
        real_time_enc = ContextualTimeEncoder(dtype=dtype)
        real_token_enc = TokenEncoder(
            multi_dim=multi_modal_encoder.embedding_dim, dtype=dtype
        )
        real_wallet_enc = WalletEncoder(encoder=multi_modal_encoder, dtype=dtype)
        real_graph_upd = GraphUpdater(time_encoder=real_time_enc, dtype=dtype)
        real_ohlc_emb = OHLCEmbedder(
            num_intervals=vocab.NUM_OHLC_INTERVALS, dtype=dtype
        )
        real_quant_ohlc_emb = QuantOHLCEmbedder(
            num_features=NUM_QUANT_OHLC_FEATURES,
            sequence_length=TOKENS_PER_SEGMENT,
            dtype=dtype
        )
        print(f"TokenEncoder default output_dim: {real_token_enc.output_dim}")
        print(f"WalletEncoder default d_model: {real_wallet_enc.d_model}")
        print(f"OHLCEmbedder default output_dim: {real_ohlc_emb.output_dim}")
        print("Encoders instantiated.")
    except Exception as e:
        print(f"Failed to instantiate encoders: {e}")
        traceback.print_exc()
        # FIXED: exit() -> sys.exit(1) so failures report a non-zero exit code
        # (bare exit() is the interactive site helper and exits with status 0).
        sys.exit(1)

    # --- 3. Instantiate the Collator ---
    collator = MemecoinCollator(
        event_type_to_id=vocab.EVENT_TO_ID,
        device=device,
        multi_modal_encoder=multi_modal_encoder,
        dtype=dtype,
        ohlc_seq_len=OHLC_SEQ_LEN,
        max_seq_len=50
    )
    print("MemecoinCollator (fast batcher) instantiated.")

    # --- 4. Instantiate the Oracle Model ---
    print("Instantiating Oracle (full pipeline)...")
    model = Oracle(
        token_encoder=real_token_enc,
        wallet_encoder=real_wallet_enc,
        graph_updater=real_graph_upd,
        time_encoder=real_time_enc,
        multi_modal_dim=multi_modal_encoder.embedding_dim,
        num_event_types=vocab.NUM_EVENT_TYPES,
        event_pad_id=vocab.EVENT_TO_ID['__PAD__'],
        event_type_to_id=vocab.EVENT_TO_ID,
        model_config_name="Qwen/Qwen3-0.6B",
        quantiles=_test_quantiles,
        horizons_seconds=_test_horizons,
        dtype=dtype,
        ohlc_embedder=real_ohlc_emb,
        quant_ohlc_embedder=real_quant_ohlc_emb
    ).to(device)
    model.eval()
    print(f"Oracle d_model: {model.d_model}")

    # --- 5. Create Dataset and run pre-collation step ---
    print("Creating Dataset...")
    # --- NEW: Initialize real database clients and DataFetcher ---
    try:
        print("Connecting to databases...")
        # FIXED comment/code mismatch: port 9000 is the ClickHouse *native*
        # protocol port used by clickhouse_driver (8123 is the HTTP port).
        clickhouse_client = ClickHouseClient(host='localhost', port=9000)
        # FIXED resource leak: close both clients on interpreter exit
        # (atexit also runs on the sys.exit(1) error paths below), otherwise
        # the Neo4j driver leaks its connection pool.
        atexit.register(clickhouse_client.disconnect)
        # Neo4j running locally on port 7687 with no auth
        neo4j_driver = GraphDatabase.driver("bolt://localhost:7687", auth=None)
        atexit.register(neo4j_driver.close)
        data_fetcher = DataFetcher(
            clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver
        )
        print("Database clients and DataFetcher initialized.")

        # --- Fetch mints to get the first token for processing ---
        all_mints = data_fetcher.get_all_mints()
        if not all_mints:
            print("\n❌ No mints found in the database. Exiting test.")
            # SystemExit is not caught by the `except Exception` below.
            sys.exit(1)

        # --- FIXED: Instantiate the dataset in REAL mode, removing is_test flag ---
        dataset = OracleDataset(
            data_fetcher=data_fetcher,
            horizons_seconds=_test_horizons,
            quantiles=_test_quantiles,
            max_samples=57)
    except Exception as e:
        print(f"FATAL: Could not initialize database connections or dataset: {e}")
        traceback.print_exc()
        sys.exit(1)

    # --- PRODUCTION-READY: Process a full batch of items from the dataset ---
    print(f"\n--- Processing a batch of up to {len(dataset)} items from the dataset ---")
    batch_items = []
    for i in range(len(dataset)):
        token_addr = dataset.sampled_mints[i].get('mint_address', 'unknown')
        print(f" - Attempting to process sample {i+1}/{len(dataset)} ({token_addr})...")
        fetch_start = time.time()
        sample = dataset[i]
        fetch_elapsed = time.time() - fetch_start
        print(f" ... fetch completed in {fetch_elapsed:.2f}s")
        if sample is not None:
            batch_items.append(sample)
            # FIXED: was an f-string with no placeholders.
            print(" ... Success! Sample added to batch.")

    if not batch_items:
        print("\n❌ No valid samples could be generated from the dataset. Exiting.")
        sys.exit(1)

    # --- 6. Run Collator AND Model ---
    print("\n--- Testing Pipeline (Collator + Model.forward) ---")
    try:
        # 1. Collator
        collate_start = time.time()
        collated_batch = collator(batch_items)
        collate_elapsed = time.time() - collate_start
        print("Collation successful!")
        print(f"Collation time for batch of {len(batch_items)} tokens: {collate_elapsed:.2f}s")

        # --- Check collator output ---
        B = len(batch_items)
        L = collated_batch['attention_mask'].shape[1]
        assert 'ohlc_price_tensors' in collated_batch
        ohlc_price_tensors = collated_batch['ohlc_price_tensors']
        assert ohlc_price_tensors.dim() == 3, f"Expected 3D OHLC tensor, got shape {tuple(ohlc_price_tensors.shape)}"
        assert ohlc_price_tensors.shape[1] == 2, f"Expected OHLC tensor with 2 rows (open/close), got {ohlc_price_tensors.shape[1]}"
        assert ohlc_price_tensors.shape[2] == OHLC_SEQ_LEN, f"Expected OHLC seq len {OHLC_SEQ_LEN}, got {ohlc_price_tensors.shape[2]}"
        assert collated_batch['ohlc_interval_ids'].shape[0] == ohlc_price_tensors.shape[0], "Interval ids must align with OHLC segments"
        assert ohlc_price_tensors.dtype == dtype, f"OHLC tensor dtype {ohlc_price_tensors.dtype} != expected {dtype}"
        print(f"Collator produced {ohlc_price_tensors.shape[0]} OHLC segment(s).")

        # --- FIXED: Update assertions for event-specific data which is mostly empty for now ---
        assert collated_batch['dest_wallet_indices'].shape == (B, L)
        assert collated_batch['transfer_numerical_features'].shape == (B, L, 4)
        assert collated_batch['trade_numerical_features'].shape == (B, L, 8)  # Corrected from 10
        assert collated_batch['deployer_trade_numerical_features'].shape == (B, L, 8)  # Corrected from 10
        assert collated_batch['smart_wallet_trade_numerical_features'].shape == (B, L, 8)  # Corrected from 10
        assert collated_batch['pool_created_numerical_features'].shape == (B, L, 2)
        assert collated_batch['liquidity_change_numerical_features'].shape == (B, L, 1)
        assert collated_batch['fee_collected_numerical_features'].shape == (B, L, 1)
        assert collated_batch['token_burn_numerical_features'].shape == (B, L, 2)
        assert collated_batch['supply_lock_numerical_features'].shape == (B, L, 2)
        assert collated_batch['onchain_snapshot_numerical_features'].shape == (B, L, 14)
        assert collated_batch['trending_token_numerical_features'].shape == (B, L, 1)
        assert collated_batch['boosted_token_numerical_features'].shape == (B, L, 2)
        # assert len(collated_batch['holder_snapshot_raw_data']) == 1  # No holder snapshots yet
        # assert len(collated_batch['textual_event_data']) == 8  # No textual events yet
        assert collated_batch['dexboost_paid_numerical_features'].shape == (B, L, 2)
        print("Collator correctly processed all event-specific numerical data into their respective tensors.")

        # --- NEW: Comprehensive Debugging Output ---
        print("\n--- Collated Batch Debug Output ---")
        print(f"Batch Size: {B}, Max Sequence Length: {L}")

        # Print shapes of key tensors
        print("\n[Core Tensors]")
        print(f" event_type_ids: {collated_batch['event_type_ids'].shape}")
        print(f" attention_mask: {collated_batch['attention_mask'].shape}")
        print(f" timestamps_float: {collated_batch['timestamps_float'].shape}")

        print("\n[Pointer Tensors]")
        print(f" wallet_indices: {collated_batch['wallet_indices'].shape}")
        print(f" token_indices: {collated_batch['token_indices'].shape}")

        print("\n[Encoder Inputs]")
        print(f" embedding_pool: {collated_batch['embedding_pool'].shape}")
        # --- FIXED: Check for a key that still exists after removing address embeddings ---
        if collated_batch['token_encoder_inputs']['name_embed_indices'].numel() > 0:
            print(f" token_encoder_inputs contains {collated_batch['token_encoder_inputs']['name_embed_indices'].shape[0]} tokens.")
        else:
            print(" token_encoder_inputs is empty.")
        if collated_batch['wallet_encoder_inputs']['profile_rows']:
            print(f" wallet_encoder_inputs contains {len(collated_batch['wallet_encoder_inputs']['profile_rows'])} wallets.")
        else:
            print(" wallet_encoder_inputs is empty.")

        print("\n[Graph Links]")
        if collated_batch['graph_updater_links']:
            for link_name, data in collated_batch['graph_updater_links'].items():
                print(f" - {link_name}: {data['edge_index'].shape[1]} edges")
        else:
            print(" No graph links in this batch.")
        print("--- End Debug Output ---\n")

        print("Embedding pool size:", collated_batch["embedding_pool"].shape[0])
        print("Max name_emb_idx:", collated_batch["token_encoder_inputs"]["name_embed_indices"].max().item())

        # 2. Model Forward Pass
        with torch.no_grad():
            model_outputs = model(collated_batch)
        quantile_logits = model_outputs["quantile_logits"]
        hidden_states = model_outputs["hidden_states"]
        attention_mask = model_outputs["attention_mask"]
        pooled_states = model_outputs["pooled_states"]
        print("Model forward pass successful!")

        # --- 7. Verify Output ---
        print("\n--- Test Results ---")
        D_MODEL = model.d_model
        print(f"Final hidden_states shape: {hidden_states.shape}")
        print(f"Final attention_mask shape: {attention_mask.shape}")
        assert hidden_states.shape == (B, L, D_MODEL)
        assert attention_mask.shape == (B, L)
        assert hidden_states.dtype == dtype
        print(f"Output mean (sanity check): {hidden_states.mean().item()}")
        print(f"Pooled state shape: {pooled_states.shape}")
        print(f"Quantile logits shape: {quantile_logits.shape}")
        # Reshape the flat head output into (batch, horizon, quantile) for display.
        quantile_grid = quantile_logits.view(B, len(_test_horizons), len(_test_quantiles))
        print("\n[Quantile Predictions]")
        for b_idx in range(B):
            print(f" Sample {b_idx}:")
            for h_idx, horizon in enumerate(_test_horizons):
                row = quantile_grid[b_idx, h_idx]
                print(f" Horizon {horizon}s -> " + ", ".join(
                    f"q={q:.2f}: {row[q_idx].item():.6f}"
                    for q_idx, q in enumerate(_test_quantiles)
                ))
        print("\n✅ **Test Passed!** Full ENCODING pipeline is working.")
    except Exception as e:
        print(f"\n❌ Error during pipeline test: {e}")
        traceback.print_exc()