#!/usr/bin/env python3
"""
Cache Migration Script v2: Adds wallet/graph/image data to existing cache files.

This script updates existing cache files to include:
1. Wallet profiles and socials
2. Wallet holdings
3. Graph links (from Neo4j)
4. Token images (as bytes)

This enables fully offline training with zero DB calls during __getitem__.

Usage:
    python scripts/migrate_cache_v2.py \
        --cache_dir /workspace/apollo/data/cache \
        --num_workers 8 \
        --batch_size 100

Requirements:
    - ClickHouse and Neo4j must be accessible
    - Set environment variables: CLICKHOUSE_HOST, CLICKHOUSE_PORT,
      NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD
"""

import os
import sys
import argparse
import torch
import datetime
import requests
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor, as_completed
from collections import defaultdict
from tqdm import tqdm
from io import BytesIO

# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))

from data.data_fetcher import DataFetcher
from clickhouse_driver import Client as ClickHouseClient
from neo4j import GraphDatabase

# IPFS gateways tried in order when the Bullx CDN does not have the image.
IPFS_GATEWAYS = [
    "https://pump.mypinata.cloud/ipfs/",
    "https://dweb.link/ipfs/",
    "https://cloudflare-ipfs.com/ipfs/",
]


def create_fetcher():
    """Create a DataFetcher instance with DB connections.

    Connection parameters come from the CLICKHOUSE_* / NEO4J_* environment
    variables, with localhost defaults for development.
    """
    ch_client = ClickHouseClient(
        host=os.getenv("CLICKHOUSE_HOST", "localhost"),
        port=int(os.getenv("CLICKHOUSE_PORT", 9000)),
    )
    neo4j_driver = GraphDatabase.driver(
        os.getenv("NEO4J_URI", "bolt://localhost:7687"),
        auth=(
            os.getenv("NEO4J_USER", "neo4j"),
            os.getenv("NEO4J_PASSWORD", "password")
        )
    )
    return DataFetcher(clickhouse_client=ch_client, neo4j_driver=neo4j_driver)


def _timestamp_to_order_value(ts_value) -> float:
    """Convert various timestamp formats (datetime, numeric, str) to a float.

    Naive datetimes are assumed to be UTC; unparseable values map to 0.0 so
    they sort before any real timestamp.
    """
    if isinstance(ts_value, datetime.datetime):
        if ts_value.tzinfo is None:
            ts_value = ts_value.replace(tzinfo=datetime.timezone.utc)
        return ts_value.timestamp()
    try:
        return float(ts_value)
    except (TypeError, ValueError):
        return 0.0


def _collect_wallets(raw_data: dict) -> list:
    """Return every unique wallet address referenced by the cached events.

    Scans trades, transfers, pool creations, liquidity changes and holder
    snapshots; None/empty entries are dropped.
    """
    all_wallets = set()
    all_wallets.add(raw_data.get('creator_address'))
    for trade in raw_data.get('trades', []):
        if trade.get('maker'):
            all_wallets.add(trade['maker'])
    for transfer in raw_data.get('transfers', []):
        if transfer.get('source'):
            all_wallets.add(transfer['source'])
        if transfer.get('destination'):
            all_wallets.add(transfer['destination'])
    for pool in raw_data.get('pool_creations', []):
        if pool.get('creator_address'):
            all_wallets.add(pool['creator_address'])
    for liq in raw_data.get('liquidity_changes', []):
        if liq.get('lp_provider'):
            all_wallets.add(liq['lp_provider'])
    for snapshot in raw_data.get('holder_snapshots_list', []):
        if isinstance(snapshot, dict):
            for holder in snapshot.get('holders', []):
                if holder.get('wallet_address'):
                    all_wallets.add(holder['wallet_address'])
    all_wallets.discard(None)
    all_wallets.discard('')
    return list(all_wallets)


def _max_trade_cutoff(raw_data: dict) -> datetime.datetime:
    """Return the timestamp of the last trade as an aware UTC datetime.

    Falls back to "now" when the sample has no trades with usable timestamps.
    """
    trades = raw_data.get('trades', [])
    if trades:
        trade_ts_values = [
            _timestamp_to_order_value(t.get('timestamp'))
            for t in trades if t.get('timestamp')
        ]
        if trade_ts_values:
            return datetime.datetime.fromtimestamp(
                max(trade_ts_values), tz=datetime.timezone.utc
            )
    return datetime.datetime.now(datetime.timezone.utc)


def _fetch_image_bytes(http_session, token_address, token_uri):
    """Best-effort download of the token image.

    Tries the Bullx CDN first, then resolves the token's IPFS metadata via a
    list of public gateways. Returns the raw image bytes, or None when no
    gateway yields an image (all network errors are swallowed by design —
    the image is optional).
    """
    try:
        bullx_image_url = f"https://image.bullx.io/1399811149/{token_address}?retry=0"
        resp = http_session.get(bullx_image_url, timeout=5)
        if resp.status_code == 200:
            return resp.content

        # Fallback to IPFS: token_uri -> metadata JSON -> image hash.
        uri = str(token_uri) if token_uri else ''
        if 'ipfs/' in uri:
            metadata_hash = uri.split('ipfs/')[-1]
            for gateway in IPFS_GATEWAYS:
                try:
                    metadata_resp = http_session.get(f"{gateway}{metadata_hash}", timeout=5)
                    if metadata_resp.status_code != 200:
                        continue
                    metadata = metadata_resp.json()
                    image_url = metadata.get('image', '')
                    if image_url and 'ipfs/' in image_url:
                        image_hash = image_url.split('ipfs/')[-1]
                        for img_gateway in IPFS_GATEWAYS:
                            try:
                                img_resp = http_session.get(f"{img_gateway}{image_hash}", timeout=5)
                                if img_resp.status_code == 200:
                                    return img_resp.content
                            except Exception:
                                continue
                    # First gateway that served metadata is authoritative;
                    # do not retry metadata on other gateways.
                    break
                except Exception:
                    continue
    except Exception:
        pass
    return None


def migrate_single_file(args):
    """Migrate a single cache file to v2 format.

    Args:
        args: (filepath, fetcher_config) tuple — packed as one argument so
            the function can be submitted to a ProcessPoolExecutor.

    Returns:
        (filepath, status, message) where status is one of
        "success", "skipped", "error".
    """
    filepath, fetcher_config = args
    ch_client = None
    neo4j_driver = None
    try:
        # Load existing cache file
        raw_data = torch.load(filepath, map_location='cpu', weights_only=False)

        # Check if already migrated
        if 'cached_wallet_data' in raw_data and 'cached_graph_data' in raw_data:
            return filepath, "skipped", "Already migrated"

        # Create DB clients for this worker (processes cannot share sockets).
        ch_client = ClickHouseClient(
            host=fetcher_config["clickhouse_host"],
            port=fetcher_config["clickhouse_port"],
        )
        neo4j_driver = GraphDatabase.driver(
            fetcher_config["neo4j_uri"],
            auth=(fetcher_config["neo4j_user"], fetcher_config["neo4j_password"])
        )
        fetcher = DataFetcher(clickhouse_client=ch_client, neo4j_driver=neo4j_driver)

        token_address = raw_data.get('token_address')
        if not token_address:
            return filepath, "error", "Missing token_address"

        wallet_list = _collect_wallets(raw_data)
        max_T_cutoff = _max_trade_cutoff(raw_data)

        # Each fetch below is best-effort: a DB failure degrades to empty
        # caches rather than failing the whole file.
        try:
            cached_profiles, cached_socials = fetcher.fetch_wallet_profiles_and_socials(
                wallet_list, max_T_cutoff
            )
        except Exception:
            cached_profiles, cached_socials = {}, {}

        try:
            cached_holdings = fetcher.fetch_wallet_holdings(wallet_list, max_T_cutoff)
        except Exception:
            cached_holdings = {}

        try:
            cached_graph_entities, cached_graph_links = fetcher.fetch_graph_links(
                wallet_list, max_T_cutoff, max_degrees=1
            )
        except Exception:
            cached_graph_entities, cached_graph_links = {}, {}

        # Token image is optional; the session is closed when the block exits.
        with requests.Session() as http_session:
            cached_image_bytes = _fetch_image_bytes(
                http_session, token_address, raw_data.get('token_uri')
            )

        # Store all cached data
        raw_data['cached_wallet_data'] = {
            'profiles': cached_profiles,
            'socials': cached_socials,
            'holdings': cached_holdings,
        }
        raw_data['cached_graph_data'] = {
            'entities': cached_graph_entities,
            'links': cached_graph_links,
        }
        raw_data['cached_image_bytes'] = cached_image_bytes
        raw_data['cached_max_T_cutoff'] = max_T_cutoff.timestamp()

        # Atomic write: save to a sibling temp file, then rename over the
        # original so an interrupted run never leaves a corrupt cache file.
        tmp_path = str(filepath) + ".tmp"
        torch.save(raw_data, tmp_path)
        os.replace(tmp_path, filepath)

        return filepath, "success", f"Migrated {len(wallet_list)} wallets, {len(cached_graph_links)} graph links"

    except Exception as e:
        return filepath, "error", str(e)
    finally:
        # Always release DB connections, even on the error path.
        if neo4j_driver is not None:
            try:
                neo4j_driver.close()
            except Exception:
                pass
        if ch_client is not None:
            try:
                ch_client.disconnect()
            except Exception:
                pass


def main():
    """Parse CLI args and migrate all cache files in parallel."""
    parser = argparse.ArgumentParser(description="Migrate cache files to v2 format with complete offline data")
    parser.add_argument("--cache_dir", type=str, required=True, help="Path to cache directory")
    parser.add_argument("--num_workers", type=int, default=4, help="Number of parallel workers")
    parser.add_argument("--batch_size", type=int, default=50, help="Files per batch for progress reporting")
    parser.add_argument("--start_idx", type=int, default=0, help="Start from this file index (for resume)")
    parser.add_argument("--max_files", type=int, default=None, help="Max files to process (for testing)")
    args = parser.parse_args()

    cache_dir = Path(args.cache_dir)
    if not cache_dir.exists():
        print(f"ERROR: Cache directory not found: {cache_dir}")
        sys.exit(1)

    # Find all cache files, sorted by their numeric suffix (sample_<N>.pt).
    cache_files = sorted(cache_dir.glob("sample_*.pt"), key=lambda p: int(p.stem.split('_')[1]))
    if args.start_idx > 0:
        cache_files = cache_files[args.start_idx:]
    if args.max_files:
        cache_files = cache_files[:args.max_files]
    print(f"Found {len(cache_files)} cache files to process")

    # Connection parameters passed to each worker process (drivers themselves
    # are not picklable, so workers reconnect from this config).
    fetcher_config = {
        "clickhouse_host": os.getenv("CLICKHOUSE_HOST", "localhost"),
        "clickhouse_port": int(os.getenv("CLICKHOUSE_PORT", 9000)),
        "neo4j_uri": os.getenv("NEO4J_URI", "bolt://localhost:7687"),
        "neo4j_user": os.getenv("NEO4J_USER", "neo4j"),
        "neo4j_password": os.getenv("NEO4J_PASSWORD", "password"),
    }

    work_items = [(f, fetcher_config) for f in cache_files]

    stats = {"success": 0, "skipped": 0, "error": 0}
    errors = []

    # Process files in parallel
    with ProcessPoolExecutor(max_workers=args.num_workers) as executor:
        futures = {executor.submit(migrate_single_file, item): item[0] for item in work_items}
        with tqdm(total=len(cache_files), desc="Migrating cache files") as pbar:
            for future in as_completed(futures):
                filepath = futures[future]
                try:
                    _, status, message = future.result()
                    stats[status] += 1
                    if status == "error":
                        errors.append((filepath, message))
                except Exception as e:
                    # Worker crashed outside migrate_single_file's own handler.
                    stats["error"] += 1
                    errors.append((filepath, str(e)))
                pbar.update(1)
                pbar.set_postfix(success=stats["success"], skipped=stats["skipped"], errors=stats["error"])

    # Print summary
    print("\n" + "="*60)
    print("MIGRATION COMPLETE")
    print("="*60)
    print(f"  Success: {stats['success']}")
    print(f"  Skipped (already migrated): {stats['skipped']}")
    print(f"  Errors: {stats['error']}")
    if errors:
        print("\nErrors:")
        for filepath, message in errors[:10]:
            print(f"  {filepath}: {message}")
        if len(errors) > 10:
            print(f"  ... and {len(errors) - 10} more errors")


if __name__ == "__main__":
    main()