# Source: zirobtc/oracle — inference.py (uploaded via huggingface_hub, revision d195287)
# inference.py
import torch
import traceback
import time
# Import all the necessary components from our project
from models.model import Oracle
from data.data_collator import MemecoinCollator
from models.multi_modal_processor import MultiModalEncoder
from data.data_loader import OracleDataset
from data.data_fetcher import DataFetcher
from models.helper_encoders import ContextualTimeEncoder
from models.token_encoder import TokenEncoder
from models.wallet_encoder import WalletEncoder
from models.graph_updater import GraphUpdater
from models.ohlc_embedder import OHLCEmbedder
from models.quant_ohlc_embedder import QuantOHLCEmbedder
import models.vocabulary as vocab
from data.quant_ohlc_feature_schema import NUM_QUANT_OHLC_FEATURES, TOKENS_PER_SEGMENT
# --- NEW: Import database clients ---
from clickhouse_driver import Client as ClickHouseClient
from neo4j import GraphDatabase
# =============================================================================
# Inference/Test Script for the Oracle Model
# This script replicates the test logic previously in model.py
# =============================================================================
if __name__ == "__main__":
    print("--- Oracle Inference Script (Full Pipeline Test) ---")

    # --- 1. Define Configs ---
    OHLC_SEQ_LEN = 300
    print(f"Using {vocab.NUM_EVENT_TYPES} event types from vocabulary.")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Prefer bf16 on GPUs that support it, fall back to fp16 on other CUDA
    # devices, and force fp32 on CPU (half precision is poorly supported there).
    dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
    if device.type == 'cpu':
        dtype = torch.float32
    print(f"Using device: {device}, dtype: {dtype}")

    # Quantile/horizon grid used for both the dataset labels and the model head.
    _test_quantiles = [0.1, 0.5, 0.9]
    _test_horizons = [30, 60, 120, 240, 420]

    # --- 2. Instantiate ALL Encoders ---
    print("Instantiating encoders (using defaults)...")
    try:
        multi_modal_encoder = MultiModalEncoder(dtype=dtype)
        real_time_enc = ContextualTimeEncoder(dtype=dtype)
        real_token_enc = TokenEncoder(
            multi_dim=multi_modal_encoder.embedding_dim,
            dtype=dtype
        )
        real_wallet_enc = WalletEncoder(encoder=multi_modal_encoder, dtype=dtype)
        real_graph_upd = GraphUpdater(time_encoder=real_time_enc, dtype=dtype)
        real_ohlc_emb = OHLCEmbedder(
            num_intervals=vocab.NUM_OHLC_INTERVALS,
            dtype=dtype
        )
        real_quant_ohlc_emb = QuantOHLCEmbedder(
            num_features=NUM_QUANT_OHLC_FEATURES,
            sequence_length=TOKENS_PER_SEGMENT,
            dtype=dtype
        )
        print(f"TokenEncoder default output_dim: {real_token_enc.output_dim}")
        print(f"WalletEncoder default d_model: {real_wallet_enc.d_model}")
        print(f"OHLCEmbedder default output_dim: {real_ohlc_emb.output_dim}")
        print("Encoders instantiated.")
    except Exception as e:
        print(f"Failed to instantiate encoders: {e}")
        traceback.print_exc()
        raise SystemExit(1)

    # --- 3. Instantiate the Collator ---
    collator = MemecoinCollator(
        event_type_to_id=vocab.EVENT_TO_ID,
        device=device,
        multi_modal_encoder=multi_modal_encoder,
        dtype=dtype,
        ohlc_seq_len=OHLC_SEQ_LEN,
        max_seq_len=50
    )
    print("MemecoinCollator (fast batcher) instantiated.")

    # --- 4. Instantiate the Oracle Model ---
    print("Instantiating Oracle (full pipeline)...")
    model = Oracle(
        token_encoder=real_token_enc,
        wallet_encoder=real_wallet_enc,
        graph_updater=real_graph_upd,
        time_encoder=real_time_enc,
        multi_modal_dim=multi_modal_encoder.embedding_dim,
        num_event_types=vocab.NUM_EVENT_TYPES,
        event_pad_id=vocab.EVENT_TO_ID['__PAD__'],
        event_type_to_id=vocab.EVENT_TO_ID,
        model_config_name="Qwen/Qwen3-0.6B",
        quantiles=_test_quantiles,
        horizons_seconds=_test_horizons,
        dtype=dtype,
        ohlc_embedder=real_ohlc_emb,
        quant_ohlc_embedder=real_quant_ohlc_emb
    ).to(device)
    model.eval()
    print(f"Oracle d_model: {model.d_model}")

    # --- 5. Create Dataset and run pre-collation step ---
    print("Creating Dataset...")
    # --- NEW: Initialize real database clients and DataFetcher ---
    try:
        print("Connecting to databases...")
        # FIXED: the old comment claimed port 8123 (ClickHouse HTTP interface);
        # clickhouse_driver speaks the native TCP protocol, which is port 9000.
        clickhouse_client = ClickHouseClient(host='localhost', port=9000)
        # Neo4j running locally on the bolt port 7687 with no auth
        neo4j_driver = GraphDatabase.driver("bolt://localhost:7687", auth=None)
        data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)
        print("Database clients and DataFetcher initialized.")

        # --- Fetch mints to get the first token for processing ---
        all_mints = data_fetcher.get_all_mints()
        if not all_mints:
            print("\n❌ No mints found in the database. Exiting test.")
            raise SystemExit(1)

        # --- FIXED: Instantiate the dataset in REAL mode, removing is_test flag ---
        dataset = OracleDataset(
            data_fetcher=data_fetcher,
            horizons_seconds=_test_horizons,
            quantiles=_test_quantiles,
            max_samples=57)
    except Exception as e:
        print(f"FATAL: Could not initialize database connections or dataset: {e}")
        traceback.print_exc()
        raise SystemExit(1)

    # --- PRODUCTION-READY: Process a full batch of items from the dataset ---
    print(f"\n--- Processing a batch of up to {len(dataset)} items from the dataset ---")
    batch_items = []
    for i in range(len(dataset)):
        token_addr = dataset.sampled_mints[i].get('mint_address', 'unknown')
        print(f"  - Attempting to process sample {i+1}/{len(dataset)} ({token_addr})...")
        fetch_start = time.time()
        # dataset[i] returns None for tokens that cannot yield a valid sample;
        # those are skipped rather than aborting the whole batch.
        sample = dataset[i]
        fetch_elapsed = time.time() - fetch_start
        print(f"    ... fetch completed in {fetch_elapsed:.2f}s")
        if sample is not None:
            batch_items.append(sample)
            print(f"    ... Success! Sample added to batch.")
    if not batch_items:
        print("\n❌ No valid samples could be generated from the dataset. Exiting.")
        raise SystemExit(1)

    # --- 6. Run Collator AND Model ---
    print("\n--- Testing Pipeline (Collator + Model.forward) ---")
    try:
        # 1. Collator
        collate_start = time.time()
        collated_batch = collator(batch_items)
        collate_elapsed = time.time() - collate_start
        print("Collation successful!")
        print(f"Collation time for batch of {len(batch_items)} tokens: {collate_elapsed:.2f}s")

        # --- Check collator output ---
        B = len(batch_items)
        L = collated_batch['attention_mask'].shape[1]
        assert 'ohlc_price_tensors' in collated_batch
        ohlc_price_tensors = collated_batch['ohlc_price_tensors']
        assert ohlc_price_tensors.dim() == 3, f"Expected 3D OHLC tensor, got shape {tuple(ohlc_price_tensors.shape)}"
        assert ohlc_price_tensors.shape[1] == 2, f"Expected OHLC tensor with 2 rows (open/close), got {ohlc_price_tensors.shape[1]}"
        assert ohlc_price_tensors.shape[2] == OHLC_SEQ_LEN, f"Expected OHLC seq len {OHLC_SEQ_LEN}, got {ohlc_price_tensors.shape[2]}"
        assert collated_batch['ohlc_interval_ids'].shape[0] == ohlc_price_tensors.shape[0], "Interval ids must align with OHLC segments"
        assert ohlc_price_tensors.dtype == dtype, f"OHLC tensor dtype {ohlc_price_tensors.dtype} != expected {dtype}"
        print(f"Collator produced {ohlc_price_tensors.shape[0]} OHLC segment(s).")

        # --- FIXED: Update assertions for event-specific data which is mostly empty for now ---
        assert collated_batch['dest_wallet_indices'].shape == (B, L)
        assert collated_batch['transfer_numerical_features'].shape == (B, L, 4)
        assert collated_batch['trade_numerical_features'].shape == (B, L, 8)  # Corrected from 10
        assert collated_batch['deployer_trade_numerical_features'].shape == (B, L, 8)  # Corrected from 10
        assert collated_batch['smart_wallet_trade_numerical_features'].shape == (B, L, 8)  # Corrected from 10
        assert collated_batch['pool_created_numerical_features'].shape == (B, L, 2)
        assert collated_batch['liquidity_change_numerical_features'].shape == (B, L, 1)
        assert collated_batch['fee_collected_numerical_features'].shape == (B, L, 1)
        assert collated_batch['token_burn_numerical_features'].shape == (B, L, 2)
        assert collated_batch['supply_lock_numerical_features'].shape == (B, L, 2)
        assert collated_batch['onchain_snapshot_numerical_features'].shape == (B, L, 14)
        assert collated_batch['trending_token_numerical_features'].shape == (B, L, 1)
        assert collated_batch['boosted_token_numerical_features'].shape == (B, L, 2)
        # assert len(collated_batch['holder_snapshot_raw_data']) == 1  # No holder snapshots yet
        # assert len(collated_batch['textual_event_data']) == 8  # No textual events yet
        assert collated_batch['dexboost_paid_numerical_features'].shape == (B, L, 2)
        print("Collator correctly processed all event-specific numerical data into their respective tensors.")

        # --- NEW: Comprehensive Debugging Output ---
        print("\n--- Collated Batch Debug Output ---")
        print(f"Batch Size: {B}, Max Sequence Length: {L}")
        # Print shapes of key tensors
        print("\n[Core Tensors]")
        print(f"  event_type_ids: {collated_batch['event_type_ids'].shape}")
        print(f"  attention_mask: {collated_batch['attention_mask'].shape}")
        print(f"  timestamps_float: {collated_batch['timestamps_float'].shape}")
        print("\n[Pointer Tensors]")
        print(f"  wallet_indices: {collated_batch['wallet_indices'].shape}")
        print(f"  token_indices: {collated_batch['token_indices'].shape}")
        print("\n[Encoder Inputs]")
        print(f"  embedding_pool: {collated_batch['embedding_pool'].shape}")
        # --- FIXED: Check for a key that still exists after removing address embeddings ---
        if collated_batch['token_encoder_inputs']['name_embed_indices'].numel() > 0:
            print(f"  token_encoder_inputs contains {collated_batch['token_encoder_inputs']['name_embed_indices'].shape[0]} tokens.")
        else:
            print("  token_encoder_inputs is empty.")
        if collated_batch['wallet_encoder_inputs']['profile_rows']:
            print(f"  wallet_encoder_inputs contains {len(collated_batch['wallet_encoder_inputs']['profile_rows'])} wallets.")
        else:
            print("  wallet_encoder_inputs is empty.")
        print("\n[Graph Links]")
        if collated_batch['graph_updater_links']:
            for link_name, data in collated_batch['graph_updater_links'].items():
                print(f"  - {link_name}: {data['edge_index'].shape[1]} edges")
        else:
            print("  No graph links in this batch.")
        print("--- End Debug Output ---\n")
        print("Embedding pool size:", collated_batch["embedding_pool"].shape[0])
        # FIXED: .max() raises on an empty tensor — only report when indices exist,
        # mirroring the emptiness guard used in the debug output above.
        name_embed_indices = collated_batch["token_encoder_inputs"]["name_embed_indices"]
        if name_embed_indices.numel() > 0:
            print("Max name_emb_idx:", name_embed_indices.max().item())

        # 2. Model Forward Pass
        with torch.no_grad():
            model_outputs = model(collated_batch)
        quantile_logits = model_outputs["quantile_logits"]
        hidden_states = model_outputs["hidden_states"]
        attention_mask = model_outputs["attention_mask"]
        pooled_states = model_outputs["pooled_states"]
        print("Model forward pass successful!")

        # --- 7. Verify Output ---
        print("\n--- Test Results ---")
        D_MODEL = model.d_model
        print(f"Final hidden_states shape: {hidden_states.shape}")
        print(f"Final attention_mask shape: {attention_mask.shape}")
        assert hidden_states.shape == (B, L, D_MODEL)
        assert attention_mask.shape == (B, L)
        assert hidden_states.dtype == dtype
        print(f"Output mean (sanity check): {hidden_states.mean().item()}")
        print(f"Pooled state shape: {pooled_states.shape}")
        print(f"Quantile logits shape: {quantile_logits.shape}")
        # Reshape the flat head output into a (batch, horizon, quantile) grid
        # for human-readable per-horizon reporting.
        quantile_grid = quantile_logits.view(B, len(_test_horizons), len(_test_quantiles))
        print("\n[Quantile Predictions]")
        for b_idx in range(B):
            print(f"  Sample {b_idx}:")
            for h_idx, horizon in enumerate(_test_horizons):
                row = quantile_grid[b_idx, h_idx]
                print(f"    Horizon {horizon}s -> " + ", ".join(
                    f"q={q:.2f}: {row[q_idx].item():.6f}"
                    for q_idx, q in enumerate(_test_quantiles)
                ))
        print("\n✅ **Test Passed!** Full ENCODING pipeline is working.")
    except Exception as e:
        print(f"\n❌ Error during pipeline test: {e}")
        traceback.print_exc()