# inference.py
import torch
import traceback
import time
# Import all the necessary components from our project
from models.model import Oracle
from data.data_collator import MemecoinCollator
from models.multi_modal_processor import MultiModalEncoder
from data.data_loader import OracleDataset
from data.data_fetcher import DataFetcher
from models.helper_encoders import ContextualTimeEncoder
from models.token_encoder import TokenEncoder
from models.wallet_encoder import WalletEncoder
from models.graph_updater import GraphUpdater
from models.ohlc_embedder import OHLCEmbedder
from models.quant_ohlc_embedder import QuantOHLCEmbedder
import models.vocabulary as vocab
from data.quant_ohlc_feature_schema import NUM_QUANT_OHLC_FEATURES, TOKENS_PER_SEGMENT
# --- NEW: Import database clients ---
from clickhouse_driver import Client as ClickHouseClient
from neo4j import GraphDatabase
# =============================================================================
# Inference/Test Script for the Oracle Model
# This script replicates the test logic previously in model.py
# =============================================================================
if __name__ == "__main__":
    print("--- Oracle Inference Script (Full Pipeline Test) ---")

    # --- 1. Define Configs ---
    OHLC_SEQ_LEN = 300
    print(f"Using {vocab.NUM_EVENT_TYPES} event types from vocabulary.")

    # Pick the compute device first, then derive the matching dtype:
    # bf16 on GPUs that support it, fp16 on other GPUs, fp32 on CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
        dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    else:
        device = torch.device("cpu")
        dtype = torch.float32
    print(f"Using device: {device}, dtype: {dtype}")

    # Quantile/horizon grid used by both the dataset and the model head.
    _test_quantiles = [0.1, 0.5, 0.9]
    _test_horizons = [30, 60, 120, 240, 420]
    _test_num_outputs = len(_test_quantiles) * len(_test_horizons)
# --- 2. Instantiate ALL Encoders ---
print("Instantiating encoders (using defaults)...")
try:
multi_modal_encoder = MultiModalEncoder(dtype=dtype)
real_time_enc = ContextualTimeEncoder(dtype=dtype)
real_token_enc = TokenEncoder(
multi_dim=multi_modal_encoder.embedding_dim,
dtype=dtype
)
real_wallet_enc = WalletEncoder(encoder=multi_modal_encoder, dtype=dtype)
real_graph_upd = GraphUpdater(time_encoder=real_time_enc, dtype=dtype)
real_ohlc_emb = OHLCEmbedder(
num_intervals=vocab.NUM_OHLC_INTERVALS,
dtype=dtype
)
real_quant_ohlc_emb = QuantOHLCEmbedder(
num_features=NUM_QUANT_OHLC_FEATURES,
sequence_length=TOKENS_PER_SEGMENT,
dtype=dtype
)
print(f"TokenEncoder default output_dim: {real_token_enc.output_dim}")
print(f"WalletEncoder default d_model: {real_wallet_enc.d_model}")
print(f"OHLCEmbedder default output_dim: {real_ohlc_emb.output_dim}")
print("Encoders instantiated.")
except Exception as e:
print(f"Failed to instantiate encoders: {e}")
traceback.print_exc()
exit()
# --- 3. Instantiate the Collator ---
collator = MemecoinCollator(
event_type_to_id=vocab.EVENT_TO_ID,
device=device,
multi_modal_encoder=multi_modal_encoder,
dtype=dtype,
ohlc_seq_len=OHLC_SEQ_LEN,
max_seq_len=50
)
print("MemecoinCollator (fast batcher) instantiated.")
# --- 4. Instantiate the Oracle Model ---
print("Instantiating Oracle (full pipeline)...")
model = Oracle(
token_encoder=real_token_enc,
wallet_encoder=real_wallet_enc,
graph_updater=real_graph_upd,
time_encoder=real_time_enc,
multi_modal_dim=multi_modal_encoder.embedding_dim,
num_event_types=vocab.NUM_EVENT_TYPES,
event_pad_id=vocab.EVENT_TO_ID['__PAD__'],
event_type_to_id=vocab.EVENT_TO_ID,
model_config_name="Qwen/Qwen3-0.6B",
quantiles=_test_quantiles,
horizons_seconds=_test_horizons,
dtype=dtype,
ohlc_embedder=real_ohlc_emb,
quant_ohlc_embedder=real_quant_ohlc_emb
).to(device)
model.eval()
print(f"Oracle d_model: {model.d_model}")
# --- 5. Create Dataset and run pre-collation step ---
print("Creating Dataset...")
# --- NEW: Initialize real database clients and DataFetcher ---
try:
print("Connecting to databases...")
# ClickHouse running locally on port 8123 with no auth
clickhouse_client = ClickHouseClient(host='localhost', port=9000)
# Neo4j running locally on port 7687 with no auth
neo4j_driver = GraphDatabase.driver("bolt://localhost:7687", auth=None)
data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)
print("Database clients and DataFetcher initialized.")
# --- Fetch mints to get the first token for processing ---
all_mints = data_fetcher.get_all_mints()
if not all_mints:
print("\n❌ No mints found in the database. Exiting test.")
exit()
# --- FIXED: Instantiate the dataset in REAL mode, removing is_test flag ---
dataset = OracleDataset(
data_fetcher=data_fetcher,
horizons_seconds=_test_horizons,
quantiles=_test_quantiles,
max_samples=57)
except Exception as e:
print(f"FATAL: Could not initialize database connections or dataset: {e}")
traceback.print_exc()
exit()
# --- PRODUCTION-READY: Process a full batch of items from the dataset ---
print(f"\n--- Processing a batch of up to {len(dataset)} items from the dataset ---")
batch_items = []
for i in range(len(dataset)):
token_addr = dataset.sampled_mints[i].get('mint_address', 'unknown')
print(f" - Attempting to process sample {i+1}/{len(dataset)} ({token_addr})...")
fetch_start = time.time()
sample = dataset[i]
fetch_elapsed = time.time() - fetch_start
print(f" ... fetch completed in {fetch_elapsed:.2f}s")
if sample is not None:
batch_items.append(sample)
print(f" ... Success! Sample added to batch.")
if not batch_items:
print("\n❌ No valid samples could be generated from the dataset. Exiting.")
exit()
    # --- 6. Run Collator AND Model ---
    # End-to-end smoke test: collate the fetched samples into one batch,
    # sanity-check every tensor the collator emits, then run a single
    # no-grad forward pass through the Oracle and report its outputs.
    print("\n--- Testing Pipeline (Collator + Model.forward) ---")
    try:
        # 1. Collator
        collate_start = time.time()
        collated_batch = collator(batch_items)
        collate_elapsed = time.time() - collate_start
        print("Collation successful!")
        print(f"Collation time for batch of {len(batch_items)} tokens: {collate_elapsed:.2f}s")
        # --- Check collator output ---
        B = len(batch_items)  # batch size
        L = collated_batch['attention_mask'].shape[1]  # padded sequence length
        assert 'ohlc_price_tensors' in collated_batch
        ohlc_price_tensors = collated_batch['ohlc_price_tensors']
        # Expected layout: (num_segments, 2 [open/close rows], OHLC_SEQ_LEN),
        # with one interval id per segment and the pipeline-wide dtype.
        assert ohlc_price_tensors.dim() == 3, f"Expected 3D OHLC tensor, got shape {tuple(ohlc_price_tensors.shape)}"
        assert ohlc_price_tensors.shape[1] == 2, f"Expected OHLC tensor with 2 rows (open/close), got {ohlc_price_tensors.shape[1]}"
        assert ohlc_price_tensors.shape[2] == OHLC_SEQ_LEN, f"Expected OHLC seq len {OHLC_SEQ_LEN}, got {ohlc_price_tensors.shape[2]}"
        assert collated_batch['ohlc_interval_ids'].shape[0] == ohlc_price_tensors.shape[0], "Interval ids must align with OHLC segments"
        assert ohlc_price_tensors.dtype == dtype, f"OHLC tensor dtype {ohlc_price_tensors.dtype} != expected {dtype}"
        print(f"Collator produced {ohlc_price_tensors.shape[0]} OHLC segment(s).")
        # --- FIXED: Update assertions for event-specific data which is mostly empty for now ---
        # Each per-event-type tensor is (B, L, n_features); the trailing dim
        # encodes that event type's feature count.
        assert collated_batch['dest_wallet_indices'].shape == (B, L)
        assert collated_batch['transfer_numerical_features'].shape == (B, L, 4)
        assert collated_batch['trade_numerical_features'].shape == (B, L, 8) # Corrected from 10
        assert collated_batch['deployer_trade_numerical_features'].shape == (B, L, 8) # Corrected from 10
        assert collated_batch['smart_wallet_trade_numerical_features'].shape == (B, L, 8) # Corrected from 10
        assert collated_batch['pool_created_numerical_features'].shape == (B, L, 2)
        assert collated_batch['liquidity_change_numerical_features'].shape == (B, L, 1)
        assert collated_batch['fee_collected_numerical_features'].shape == (B, L, 1)
        assert collated_batch['token_burn_numerical_features'].shape == (B, L, 2)
        assert collated_batch['supply_lock_numerical_features'].shape == (B, L, 2)
        assert collated_batch['onchain_snapshot_numerical_features'].shape == (B, L, 14)
        assert collated_batch['trending_token_numerical_features'].shape == (B, L, 1)
        assert collated_batch['boosted_token_numerical_features'].shape == (B, L, 2)
        # assert len(collated_batch['holder_snapshot_raw_data']) == 1 # No holder snapshots yet
        # assert len(collated_batch['textual_event_data']) == 8 # No textual events yet
        assert collated_batch['dexboost_paid_numerical_features'].shape == (B, L, 2)
        print("Collator correctly processed all event-specific numerical data into their respective tensors.")
        # --- NEW: Comprehensive Debugging Output ---
        print("\n--- Collated Batch Debug Output ---")
        print(f"Batch Size: {B}, Max Sequence Length: {L}")
        # Print shapes of key tensors
        print("\n[Core Tensors]")
        print(f" event_type_ids: {collated_batch['event_type_ids'].shape}")
        print(f" attention_mask: {collated_batch['attention_mask'].shape}")
        print(f" timestamps_float: {collated_batch['timestamps_float'].shape}")
        print("\n[Pointer Tensors]")
        print(f" wallet_indices: {collated_batch['wallet_indices'].shape}")
        print(f" token_indices: {collated_batch['token_indices'].shape}")
        print("\n[Encoder Inputs]")
        print(f" embedding_pool: {collated_batch['embedding_pool'].shape}")
        # --- FIXED: Check for a key that still exists after removing address embeddings ---
        if collated_batch['token_encoder_inputs']['name_embed_indices'].numel() > 0:
            print(f" token_encoder_inputs contains {collated_batch['token_encoder_inputs']['name_embed_indices'].shape[0]} tokens.")
        else:
            print(" token_encoder_inputs is empty.")
        if collated_batch['wallet_encoder_inputs']['profile_rows']:
            print(f" wallet_encoder_inputs contains {len(collated_batch['wallet_encoder_inputs']['profile_rows'])} wallets.")
        else:
            print(" wallet_encoder_inputs is empty.")
        print("\n[Graph Links]")
        if collated_batch['graph_updater_links']:
            # presumably each link entry carries a (2, E) edge_index — the
            # shape[1] read below only assumes E is the second dim.
            for link_name, data in collated_batch['graph_updater_links'].items():
                print(f" - {link_name}: {data['edge_index'].shape[1]} edges")
        else:
            print(" No graph links in this batch.")
        print("--- End Debug Output ---\n")
        print("Embedding pool size:", collated_batch["embedding_pool"].shape[0])
        print("Max name_emb_idx:", collated_batch["token_encoder_inputs"]["name_embed_indices"].max().item())
        # 2. Model Forward Pass
        with torch.no_grad():
            model_outputs = model(collated_batch)
        quantile_logits = model_outputs["quantile_logits"]
        hidden_states = model_outputs["hidden_states"]
        attention_mask = model_outputs["attention_mask"]
        pooled_states = model_outputs["pooled_states"]
        print("Model forward pass successful!")
        # --- 7. Verify Output ---
        print("\n--- Test Results ---")
        D_MODEL = model.d_model
        print(f"Final hidden_states shape: {hidden_states.shape}")
        print(f"Final attention_mask shape: {attention_mask.shape}")
        assert hidden_states.shape == (B, L, D_MODEL)
        assert attention_mask.shape == (B, L)
        assert hidden_states.dtype == dtype
        print(f"Output mean (sanity check): {hidden_states.mean().item()}")
        print(f"Pooled state shape: {pooled_states.shape}")
        print(f"Quantile logits shape: {quantile_logits.shape}")
        # NOTE(review): the view assumes quantile_logits is (B, H*Q) laid out
        # horizon-major (all quantiles of horizon 0 first) — confirm against
        # the Oracle head's output ordering.
        quantile_grid = quantile_logits.view(B, len(_test_horizons), len(_test_quantiles))
        print("\n[Quantile Predictions]")
        for b_idx in range(B):
            print(f" Sample {b_idx}:")
            for h_idx, horizon in enumerate(_test_horizons):
                row = quantile_grid[b_idx, h_idx]
                print(f" Horizon {horizon}s -> " + ", ".join(
                    f"q={q:.2f}: {row[q_idx].item():.6f}"
                    for q_idx, q in enumerate(_test_quantiles)
                ))
        print("\n✅ **Test Passed!** Full ENCODING pipeline is working.")
    except Exception as e:
        # Top-level boundary for the smoke test: report the failure with a
        # full traceback instead of crashing silently.
        print(f"\n❌ Error during pipeline test: {e}")
        traceback.print_exc()
|