| |
| """ |
| Cache Migration Script v2: Adds wallet/graph/image data to existing cache files. |
| |
| This script updates existing cache files to include: |
| 1. Wallet profiles and socials |
| 2. Wallet holdings |
| 3. Graph links (from Neo4j) |
| 4. Token images (as bytes) |
| |
| This enables fully offline training with zero DB calls during __getitem__. |
| |
| Usage: |
| python scripts/migrate_cache_v2.py \ |
| --cache_dir /workspace/apollo/data/cache \ |
| --num_workers 8 \ |
| --batch_size 100 |
| |
| Requirements: |
| - ClickHouse and Neo4j must be accessible |
| - Set environment variables: CLICKHOUSE_HOST, CLICKHOUSE_PORT, NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD |
| """ |
|
|
| import os |
| import sys |
| import argparse |
| import torch |
| import datetime |
| import requests |
| from pathlib import Path |
| from concurrent.futures import ProcessPoolExecutor, as_completed |
| from collections import defaultdict |
| from tqdm import tqdm |
| from io import BytesIO |
|
|
| |
| sys.path.insert(0, str(Path(__file__).parent.parent)) |
|
|
| from data.data_fetcher import DataFetcher |
| from clickhouse_driver import Client as ClickHouseClient |
| from neo4j import GraphDatabase |
|
|
|
|
def create_fetcher():
    """Build a DataFetcher wired to ClickHouse and Neo4j.

    Connection settings come from the CLICKHOUSE_* / NEO4J_* environment
    variables, falling back to local-development defaults.
    """
    clickhouse = ClickHouseClient(
        host=os.getenv("CLICKHOUSE_HOST", "localhost"),
        port=int(os.getenv("CLICKHOUSE_PORT", 9000)),
    )
    graph_auth = (
        os.getenv("NEO4J_USER", "neo4j"),
        os.getenv("NEO4J_PASSWORD", "password"),
    )
    graph_driver = GraphDatabase.driver(
        os.getenv("NEO4J_URI", "bolt://localhost:7687"),
        auth=graph_auth,
    )
    return DataFetcher(clickhouse_client=clickhouse, neo4j_driver=graph_driver)
|
|
|
|
| def _timestamp_to_order_value(ts_value) -> float: |
| """Convert various timestamp formats to float.""" |
| if isinstance(ts_value, datetime.datetime): |
| if ts_value.tzinfo is None: |
| ts_value = ts_value.replace(tzinfo=datetime.timezone.utc) |
| return ts_value.timestamp() |
| try: |
| return float(ts_value) |
| except (TypeError, ValueError): |
| return 0.0 |
|
|
|
|
def _collect_wallets(raw_data):
    """Gather every wallet address referenced by a cached sample.

    Scans the creator, trades, transfers, pool creations, liquidity changes,
    and holder snapshots; drops None/empty entries before returning a list.
    """
    wallets = {raw_data.get('creator_address')}
    for trade in raw_data.get('trades', []):
        if trade.get('maker'):
            wallets.add(trade['maker'])
    for transfer in raw_data.get('transfers', []):
        for side in ('source', 'destination'):
            if transfer.get(side):
                wallets.add(transfer[side])
    for pool in raw_data.get('pool_creations', []):
        if pool.get('creator_address'):
            wallets.add(pool['creator_address'])
    for liq in raw_data.get('liquidity_changes', []):
        if liq.get('lp_provider'):
            wallets.add(liq['lp_provider'])
    for snapshot in raw_data.get('holder_snapshots_list', []):
        if isinstance(snapshot, dict):
            for holder in snapshot.get('holders', []):
                if holder.get('wallet_address'):
                    wallets.add(holder['wallet_address'])
    wallets.discard(None)
    wallets.discard('')
    return list(wallets)


def _compute_cutoff(raw_data):
    """Return the max-T cutoff: the latest trade timestamp, else "now" (UTC)."""
    ts_values = [
        _timestamp_to_order_value(t.get('timestamp'))
        for t in raw_data.get('trades', [])
        if t.get('timestamp')
    ]
    if ts_values:
        return datetime.datetime.fromtimestamp(max(ts_values), tz=datetime.timezone.utc)
    return datetime.datetime.now(datetime.timezone.utc)


def _fetch_token_image(token_address, token_uri):
    """Best-effort download of the token image; returns bytes or None.

    Tries the BullX CDN first, then falls back to resolving the token's IPFS
    metadata through several public gateways. Network errors are swallowed
    deliberately — the image is optional and must not fail the migration —
    but only request/JSON errors are caught (no bare except).
    """
    ipfs_gateways = (
        "https://pump.mypinata.cloud/ipfs/",
        "https://dweb.link/ipfs/",
        "https://cloudflare-ipfs.com/ipfs/",
    )
    with requests.Session() as http_session:
        try:
            bullx_image_url = f"https://image.bullx.io/1399811149/{token_address}?retry=0"
            resp = http_session.get(bullx_image_url, timeout=5)
            if resp.status_code == 200:
                return resp.content
        except requests.RequestException:
            # CDN unreachable — fall through to the IPFS path.
            pass

        # token_uri may be a non-str object; coerce once before splitting
        # (the original str()-checked but then split the raw object).
        uri = str(token_uri) if token_uri else ''
        if 'ipfs/' not in uri:
            return None
        metadata_hash = uri.split('ipfs/')[-1]
        for gateway in ipfs_gateways:
            try:
                metadata_resp = http_session.get(f"{gateway}{metadata_hash}", timeout=5)
                if metadata_resp.status_code != 200:
                    continue
                metadata = metadata_resp.json()
                image_url = metadata.get('image', '')
                if image_url and 'ipfs/' in image_url:
                    image_hash = image_url.split('ipfs/')[-1]
                    for img_gateway in ipfs_gateways:
                        try:
                            img_resp = http_session.get(f"{img_gateway}{image_hash}", timeout=5)
                            if img_resp.status_code == 200:
                                return img_resp.content
                        except requests.RequestException:
                            continue
                # Metadata resolved (image found or not) — stop trying gateways.
                return None
            except (requests.RequestException, ValueError):
                # ValueError covers a non-JSON metadata response.
                continue
    return None


def migrate_single_file(args):
    """Migrate a single cache file to the v2 offline format.

    Args:
        args: Tuple of (filepath, fetcher_config) — the cache file to upgrade
            and the picklable DB connection settings for this worker process.

    Returns:
        Tuple of (filepath, status, message) where status is one of
        "success", "skipped", or "error". Never raises: any failure is
        reported via the "error" status so the pool keeps running.
    """
    filepath, fetcher_config = args

    try:
        # weights_only=False: cache files hold arbitrary pickled dicts/lists,
        # not just tensors. Only ever load trusted, locally-produced files.
        raw_data = torch.load(filepath, map_location='cpu', weights_only=False)

        if 'cached_wallet_data' in raw_data and 'cached_graph_data' in raw_data:
            return filepath, "skipped", "Already migrated"

        token_address = raw_data.get('token_address')
        if not token_address:
            return filepath, "error", "Missing token_address"

        # Each worker opens its own connections: client objects cannot be
        # shared across process boundaries.
        ch_client = ClickHouseClient(
            host=fetcher_config["clickhouse_host"],
            port=fetcher_config["clickhouse_port"],
        )
        neo4j_driver = GraphDatabase.driver(
            fetcher_config["neo4j_uri"],
            auth=(fetcher_config["neo4j_user"], fetcher_config["neo4j_password"]),
        )
        try:
            fetcher = DataFetcher(clickhouse_client=ch_client, neo4j_driver=neo4j_driver)

            wallet_list = _collect_wallets(raw_data)
            max_T_cutoff = _compute_cutoff(raw_data)

            # Each fetch is best-effort: a DB hiccup degrades this sample to
            # empty caches instead of failing the whole migration.
            try:
                cached_profiles, cached_socials = fetcher.fetch_wallet_profiles_and_socials(
                    wallet_list, max_T_cutoff
                )
            except Exception:
                cached_profiles, cached_socials = {}, {}

            try:
                cached_holdings = fetcher.fetch_wallet_holdings(wallet_list, max_T_cutoff)
            except Exception:
                cached_holdings = {}

            try:
                cached_graph_entities, cached_graph_links = fetcher.fetch_graph_links(
                    wallet_list, max_T_cutoff, max_degrees=1
                )
            except Exception:
                cached_graph_entities, cached_graph_links = {}, {}

            cached_image_bytes = _fetch_token_image(token_address, raw_data.get('token_uri'))

            raw_data['cached_wallet_data'] = {
                'profiles': cached_profiles,
                'socials': cached_socials,
                'holdings': cached_holdings,
            }
            raw_data['cached_graph_data'] = {
                'entities': cached_graph_entities,
                'links': cached_graph_links,
            }
            raw_data['cached_image_bytes'] = cached_image_bytes
            raw_data['cached_max_T_cutoff'] = max_T_cutoff.timestamp()

            torch.save(raw_data, filepath)
        finally:
            # Close connections on every path — the original leaked both
            # clients whenever an exception fired before the explicit close,
            # and never disconnected the ClickHouse client at all.
            neo4j_driver.close()
            ch_client.disconnect()

        return filepath, "success", f"Migrated {len(wallet_list)} wallets, {len(cached_graph_links)} graph links"

    except Exception as e:
        return filepath, "error", str(e)
|
|
|
|
def main():
    """CLI entry point: migrate all cache files using a process pool."""
    parser = argparse.ArgumentParser(description="Migrate cache files to v2 format with complete offline data")
    parser.add_argument("--cache_dir", type=str, required=True, help="Path to cache directory")
    parser.add_argument("--num_workers", type=int, default=4, help="Number of parallel workers")
    parser.add_argument("--batch_size", type=int, default=50, help="Files per batch for progress reporting")
    parser.add_argument("--start_idx", type=int, default=0, help="Start from this file index (for resume)")
    parser.add_argument("--max_files", type=int, default=None, help="Max files to process (for testing)")
    args = parser.parse_args()

    cache_dir = Path(args.cache_dir)
    if not cache_dir.exists():
        print(f"ERROR: Cache directory not found: {cache_dir}")
        sys.exit(1)

    # Sort numerically by the sample index embedded in "sample_<idx>.pt".
    cache_files = sorted(cache_dir.glob("sample_*.pt"), key=lambda p: int(p.stem.split('_')[1]))
    if args.start_idx > 0:
        cache_files = cache_files[args.start_idx:]
    if args.max_files:
        cache_files = cache_files[:args.max_files]

    print(f"Found {len(cache_files)} cache files to process")

    # Plain dict of connection settings: picklable, so it can cross the
    # worker-process boundary (live client objects cannot).
    fetcher_config = {
        "clickhouse_host": os.getenv("CLICKHOUSE_HOST", "localhost"),
        "clickhouse_port": int(os.getenv("CLICKHOUSE_PORT", 9000)),
        "neo4j_uri": os.getenv("NEO4J_URI", "bolt://localhost:7687"),
        "neo4j_user": os.getenv("NEO4J_USER", "neo4j"),
        "neo4j_password": os.getenv("NEO4J_PASSWORD", "password"),
    }

    stats = {"success": 0, "skipped": 0, "error": 0}
    errors = []

    with ProcessPoolExecutor(max_workers=args.num_workers) as executor:
        future_to_path = {
            executor.submit(migrate_single_file, (path, fetcher_config)): path
            for path in cache_files
        }

        with tqdm(total=len(cache_files), desc="Migrating cache files") as pbar:
            for future in as_completed(future_to_path):
                path = future_to_path[future]
                try:
                    _, status, message = future.result()
                    stats[status] += 1
                    if status == "error":
                        errors.append((path, message))
                except Exception as e:
                    # A worker crash (not a handled per-file error) counts
                    # as an error for that file.
                    stats["error"] += 1
                    errors.append((path, str(e)))

                pbar.update(1)
                pbar.set_postfix(success=stats["success"], skipped=stats["skipped"], errors=stats["error"])

    print("\n" + "=" * 60)
    print("MIGRATION COMPLETE")
    print("=" * 60)
    print(f" Success: {stats['success']}")
    print(f" Skipped (already migrated): {stats['skipped']}")
    print(f" Errors: {stats['error']}")

    if errors:
        print("\nErrors:")
        for path, message in errors[:10]:
            print(f" {path}: {message}")
        if len(errors) > 10:
            print(f" ... and {len(errors) - 10} more errors")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|