oracle / scripts /migrate_cache_v2.py
zirobtc's picture
Upload folder using huggingface_hub
e605733
#!/usr/bin/env python3
"""
Cache Migration Script v2: Adds wallet/graph/image data to existing cache files.
This script updates existing cache files to include:
1. Wallet profiles and socials
2. Wallet holdings
3. Graph links (from Neo4j)
4. Token images (as bytes)
This enables fully offline training with zero DB calls during __getitem__.
Usage:
python scripts/migrate_cache_v2.py \
--cache_dir /workspace/apollo/data/cache \
--num_workers 8 \
--batch_size 100
Requirements:
- ClickHouse and Neo4j must be reachable from every worker process
- Connection settings are read from CLICKHOUSE_HOST, CLICKHOUSE_PORT, NEO4J_URI, NEO4J_USER and NEO4J_PASSWORD; any unset variable falls back to a local-development default
"""
import os
import sys
import argparse
import torch
import datetime
import requests
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor, as_completed
from collections import defaultdict
from tqdm import tqdm
from io import BytesIO
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))
from data.data_fetcher import DataFetcher
from clickhouse_driver import Client as ClickHouseClient
from neo4j import GraphDatabase
def create_fetcher():
    """Build a ``DataFetcher`` backed by fresh ClickHouse and Neo4j connections.

    Settings come from CLICKHOUSE_HOST, CLICKHOUSE_PORT, NEO4J_URI,
    NEO4J_USER and NEO4J_PASSWORD; each falls back to a local default
    when the variable is unset.
    """
    env = os.environ
    clickhouse = ClickHouseClient(
        host=env.get("CLICKHOUSE_HOST", "localhost"),
        port=int(env.get("CLICKHOUSE_PORT", 9000)),
    )
    neo4j_uri = env.get("NEO4J_URI", "bolt://localhost:7687")
    neo4j_auth = (
        env.get("NEO4J_USER", "neo4j"),
        env.get("NEO4J_PASSWORD", "password"),
    )
    graph_driver = GraphDatabase.driver(neo4j_uri, auth=neo4j_auth)
    return DataFetcher(clickhouse_client=clickhouse, neo4j_driver=graph_driver)
def _timestamp_to_order_value(ts_value) -> float:
"""Convert various timestamp formats to float."""
if isinstance(ts_value, datetime.datetime):
if ts_value.tzinfo is None:
ts_value = ts_value.replace(tzinfo=datetime.timezone.utc)
return ts_value.timestamp()
try:
return float(ts_value)
except (TypeError, ValueError):
return 0.0
# Public IPFS gateways tried in order when resolving token metadata and images.
_IPFS_GATEWAYS = (
    "https://pump.mypinata.cloud/ipfs/",
    "https://dweb.link/ipfs/",
    "https://cloudflare-ipfs.com/ipfs/",
)
# Per-request timeout (seconds) for all image/metadata HTTP calls.
_HTTP_TIMEOUT_S = 5


def _collect_wallets(raw_data):
    """Return every distinct, non-empty wallet address referenced by a cache sample.

    Sources: ``creator_address``, trade makers, transfer sources/destinations,
    pool creators, liquidity providers, and holders in ``holder_snapshots_list``.
    Keys that are missing or set to ``None`` are treated as empty lists.
    """
    wallets = {raw_data.get('creator_address')}

    def _add(value):
        if value:
            wallets.add(value)

    for trade in raw_data.get('trades') or []:
        _add(trade.get('maker'))
    for transfer in raw_data.get('transfers') or []:
        _add(transfer.get('source'))
        _add(transfer.get('destination'))
    for pool in raw_data.get('pool_creations') or []:
        _add(pool.get('creator_address'))
    for liq in raw_data.get('liquidity_changes') or []:
        _add(liq.get('lp_provider'))
    for snapshot in raw_data.get('holder_snapshots_list') or []:
        if not isinstance(snapshot, dict):
            continue
        for holder in snapshot.get('holders') or []:
            if isinstance(holder, dict):
                _add(holder.get('wallet_address'))

    # creator_address is added unconditionally above, so drop empty placeholders.
    wallets.discard(None)
    wallets.discard('')
    return list(wallets)


def _compute_max_t_cutoff(trades):
    """Return the latest trade timestamp as an aware UTC datetime.

    Falls back to the current UTC time when no trade has a truthy ``timestamp``.
    """
    ts_values = [
        _timestamp_to_order_value(trade.get('timestamp'))
        for trade in trades
        if trade.get('timestamp')
    ]
    if not ts_values:
        return datetime.datetime.now(datetime.timezone.utc)
    return datetime.datetime.fromtimestamp(max(ts_values), tz=datetime.timezone.utc)


def _fetch_db_data(fetcher_config, wallet_list, max_T_cutoff):
    """Fetch wallet and graph data, always closing the ClickHouse and Neo4j connections.

    Returns ``(wallet_data, graph_data)`` in the layout stored in the cache file.
    Fetcher exceptions propagate to the caller.
    """
    ch_client = ClickHouseClient(
        host=fetcher_config["clickhouse_host"],
        port=fetcher_config["clickhouse_port"],
    )
    try:
        neo4j_driver = GraphDatabase.driver(
            fetcher_config["neo4j_uri"],
            auth=(fetcher_config["neo4j_user"], fetcher_config["neo4j_password"]),
        )
        try:
            fetcher = DataFetcher(clickhouse_client=ch_client, neo4j_driver=neo4j_driver)
            profiles, socials = fetcher.fetch_wallet_profiles_and_socials(wallet_list, max_T_cutoff)
            holdings = fetcher.fetch_wallet_holdings(wallet_list, max_T_cutoff)
            entities, links = fetcher.fetch_graph_links(wallet_list, max_T_cutoff, max_degrees=1)
        finally:
            neo4j_driver.close()
    finally:
        ch_client.disconnect()

    wallet_data = {'profiles': profiles, 'socials': socials, 'holdings': holdings}
    graph_data = {'entities': entities, 'links': links}
    return wallet_data, graph_data


def _fetch_token_image(http_session, token_address, token_uri):
    """Best-effort download of a token image; returns the raw bytes or ``None``.

    Tries the Bullx image endpoint first. If that fails and ``token_uri`` is an
    IPFS URI, fetches the metadata JSON through the gateways and downloads the
    IPFS ``image`` it points to.
    """
    bullx_image_url = f"https://image.bullx.io/1399811149/{token_address}?retry=0"
    try:
        resp = http_session.get(bullx_image_url, timeout=_HTTP_TIMEOUT_S)
        if resp.status_code == 200:
            return resp.content
    except requests.RequestException:
        pass  # Network failure on Bullx: fall through to the IPFS route.

    token_uri = str(token_uri) if token_uri else ''
    if 'ipfs/' not in token_uri:
        return None
    metadata_hash = token_uri.split('ipfs/')[-1]

    for gateway in _IPFS_GATEWAYS:
        try:
            metadata_resp = http_session.get(f"{gateway}{metadata_hash}", timeout=_HTTP_TIMEOUT_S)
            if metadata_resp.status_code != 200:
                continue
            metadata = metadata_resp.json()
        except (requests.RequestException, ValueError):
            # ValueError covers an unparseable JSON body.
            continue
        if not isinstance(metadata, dict):
            continue

        # The first gateway that serves usable metadata decides the outcome.
        image_url = metadata.get('image', '')
        if not isinstance(image_url, str) or 'ipfs/' not in image_url:
            return None
        image_hash = image_url.split('ipfs/')[-1]
        for img_gateway in _IPFS_GATEWAYS:
            try:
                img_resp = http_session.get(f"{img_gateway}{image_hash}", timeout=_HTTP_TIMEOUT_S)
            except requests.RequestException:
                continue
            if img_resp.status_code == 200:
                return img_resp.content
        return None
    return None


def _atomic_torch_save(obj, filepath):
    """Save *obj* to *filepath* via a sibling temp file and ``os.replace``.

    An interrupted write therefore never truncates the existing cache file,
    and the temp file is removed if saving fails.
    """
    tmp_path = f"{os.fspath(filepath)}.tmp"
    try:
        torch.save(obj, tmp_path)
        os.replace(tmp_path, filepath)
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def migrate_single_file(args):
    """Migrate one cache file to the v2 format in place.

    Args:
        args: ``(filepath, fetcher_config)`` tuple, packed as a single argument
            so it can be submitted to a process pool. ``fetcher_config`` must
            provide ``clickhouse_host``, ``clickhouse_port``, ``neo4j_uri``,
            ``neo4j_user`` and ``neo4j_password``.

    Returns:
        ``(filepath, status, message)``. ``status`` is ``"success"``,
        ``"skipped"`` (the v2 keys are already present) or ``"error"``.
        Exceptions are reported through ``status``, never raised. On error the
        file on disk is left unchanged, so a later run retries it.
    """
    filepath, fetcher_config = args
    try:
        # NOTE: weights_only=False unpickles arbitrary objects; only run this on
        # cache files produced by this project.
        raw_data = torch.load(filepath, map_location='cpu', weights_only=False)

        if 'cached_wallet_data' in raw_data and 'cached_graph_data' in raw_data:
            return filepath, "skipped", "Already migrated"

        # Validate before opening any DB connection.
        token_address = raw_data.get('token_address')
        if not token_address:
            return filepath, "error", "Missing token_address"

        wallet_list = _collect_wallets(raw_data)
        max_T_cutoff = _compute_max_t_cutoff(raw_data.get('trades') or [])

        # DB failures propagate on purpose. Writing empty placeholders would add
        # the v2 keys, and the skip check above would never retry this file.
        wallet_data, graph_data = _fetch_db_data(fetcher_config, wallet_list, max_T_cutoff)

        # The image is optional: _fetch_token_image returns None when unavailable.
        with requests.Session() as http_session:
            cached_image_bytes = _fetch_token_image(
                http_session, token_address, raw_data.get('token_uri')
            )

        raw_data['cached_wallet_data'] = wallet_data
        raw_data['cached_graph_data'] = graph_data
        raw_data['cached_image_bytes'] = cached_image_bytes
        raw_data['cached_max_T_cutoff'] = max_T_cutoff.timestamp()

        # Build the message before saving, so a failure here cannot report
        # "error" for a file that was actually written.
        message = f"Migrated {len(wallet_list)} wallets, {len(graph_data['links'])} graph links"
        _atomic_torch_save(raw_data, filepath)
        return filepath, "success", message
    except Exception as e:
        return filepath, "error", f"{type(e).__name__}: {e}"
def main():
    """Migrate every ``sample_*.pt`` file in ``--cache_dir`` to the v2 format in parallel.

    Files are ordered by their numeric sample index (``sample_<n>.pt``) so that
    ``--start_idx`` can resume an interrupted run; ``--max_files`` caps how many
    files are processed. Prints a summary (including up to 10 error messages).
    Exits with status 1 if the cache directory does not exist.
    """
    parser = argparse.ArgumentParser(description="Migrate cache files to v2 format with complete offline data")
    parser.add_argument("--cache_dir", type=str, required=True, help="Path to cache directory")
    parser.add_argument("--num_workers", type=int, default=4, help="Number of parallel workers")
    # NOTE(review): --batch_size is accepted for CLI compatibility but is not used below.
    parser.add_argument("--batch_size", type=int, default=50, help="Files per batch for progress reporting")
    parser.add_argument("--start_idx", type=int, default=0, help="Start from this file index (for resume)")
    parser.add_argument("--max_files", type=int, default=None, help="Max files to process (for testing)")
    args = parser.parse_args()

    cache_dir = Path(args.cache_dir)
    if not cache_dir.exists():
        print(f"ERROR: Cache directory not found: {cache_dir}")
        sys.exit(1)

    def _sample_sort_key(path):
        # Numeric order for "sample_<n>.pt"; names without an integer index sort
        # last (by name) instead of aborting the whole run with ValueError.
        try:
            return (0, int(path.stem.split('_')[1]), path.name)
        except (ValueError, IndexError):
            return (1, 0, path.name)

    cache_files = sorted(cache_dir.glob("sample_*.pt"), key=_sample_sort_key)
    if args.start_idx > 0:
        cache_files = cache_files[args.start_idx:]
    # Compare against None so an explicit `--max_files 0` means "none", not "all".
    if args.max_files is not None:
        cache_files = cache_files[:args.max_files]
    print(f"Found {len(cache_files)} cache files to process")
    if not cache_files:
        return

    # Each worker receives these settings and opens its own connections.
    fetcher_config = {
        "clickhouse_host": os.getenv("CLICKHOUSE_HOST", "localhost"),
        "clickhouse_port": int(os.getenv("CLICKHOUSE_PORT", 9000)),
        "neo4j_uri": os.getenv("NEO4J_URI", "bolt://localhost:7687"),
        "neo4j_user": os.getenv("NEO4J_USER", "neo4j"),
        "neo4j_password": os.getenv("NEO4J_PASSWORD", "password"),
    }
    work_items = [(f, fetcher_config) for f in cache_files]

    stats = {"success": 0, "skipped": 0, "error": 0}
    errors = []

    with ProcessPoolExecutor(max_workers=args.num_workers) as executor:
        futures = {executor.submit(migrate_single_file, item): item[0] for item in work_items}
        with tqdm(total=len(cache_files), desc="Migrating cache files") as pbar:
            for future in as_completed(futures):
                filepath = futures[future]
                try:
                    _, status, message = future.result()
                    stats[status] += 1
                    if status == "error":
                        errors.append((filepath, message))
                except Exception as e:
                    # Failures outside the worker's own status reporting
                    # (e.g. a worker process that died).
                    stats["error"] += 1
                    errors.append((filepath, str(e)))
                pbar.update(1)
                pbar.set_postfix(success=stats["success"], skipped=stats["skipped"], errors=stats["error"])

    print("\n" + "="*60)
    print("MIGRATION COMPLETE")
    print("="*60)
    print(f"  Success: {stats['success']}")
    print(f"  Skipped (already migrated): {stats['skipped']}")
    print(f"  Errors: {stats['error']}")
    if errors:
        print("\nErrors:")
        for filepath, message in errors[:10]:
            print(f"  {filepath}: {message}")
        if len(errors) > 10:
            print(f"  ... and {len(errors) - 10} more errors")
# Run the CLI only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()