from typing import List, Optional, Dict, Any from dataclasses import dataclass import time from queue import Queue from threading import Lock import logging import json import duckdb from huggingface_hub import HfApi, HfFileSystem from config import get_db_url, get_hf_token_cached @dataclass class CrossGPUEvent: """Represents a cross-GPU synchronization event""" event_id: str timestamp: float source_gpu: int target_gpu: Optional[int] completed: bool = False transfer_size: Optional[int] = None nvlink_path: Optional[List[str]] = None class CrossGPUStream: """Represents a stream that can execute operations across multiple GPUs""" DB_URL = "hf://datasets/Fred808/helium/storage.json" def __init__(self, stream_id: int, db_url: Optional[str] = None): self.stream_id = stream_id self.events: List[CrossGPUEvent] = [] self.operation_queue: Queue = Queue() self.lock = Lock() self.is_active = True self.db_url = db_url or self.DB_URL self.max_retries = 3 self._connect_with_retries() self._setup_database() def _connect_with_retries(self): """Establish database connection with retry logic""" for attempt in range(self.max_retries): try: self.conn = self._init_db_connection() return except Exception as e: if attempt == self.max_retries - 1: raise RuntimeError(f"Failed to initialize database after {self.max_retries} attempts: {str(e)}") time.sleep(1) def _init_db_connection(self) -> duckdb.DuckDBPyConnection: """Initialize database connection with HuggingFace configuration""" # Convert HF URL to S3 path _, _, owner, dataset, db_file = self.db_url.split('/', 4) db_path = f"s3://datasets-cached/{owner}/{dataset}/{db_file}" # Connect to remote database conn = duckdb.connect(db_path) conn.execute("INSTALL httpfs;") conn.execute("LOAD httpfs;") conn.execute("SET s3_endpoint='s3.us-east-1.amazonaws.com';") conn.execute("SET s3_use_ssl=true;") conn.execute("SET s3_url_style='path';") conn.execute(f"SET s3_access_key_id='{self.HF_TOKEN}';") conn.execute(f"SET s3_secret_access_key='{self.HF_TOKEN}';") return conn def _setup_database(self): """Initialize database tables""" # Events table self.conn.execute(""" CREATE TABLE IF NOT EXISTS stream_events ( event_id VARCHAR PRIMARY KEY, stream_id BIGINT, timestamp DOUBLE, source_gpu INTEGER, target_gpu INTEGER, completed BOOLEAN, transfer_size BIGINT, nvlink_path JSON, state_json JSON ) """) # Operations table self.conn.execute(""" CREATE TABLE IF NOT EXISTS stream_operations ( operation_id BIGINT PRIMARY KEY, stream_id BIGINT, operation_data JSON, status VARCHAR, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, completed_at TIMESTAMP, error_message VARCHAR ) """) # Sync points table self.conn.execute(""" CREATE TABLE IF NOT EXISTS sync_points ( sync_id VARCHAR PRIMARY KEY, stream_id BIGINT, gpu_ids VARCHAR, status VARCHAR DEFAULT 'pending', created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, completed_at TIMESTAMP, error_message VARCHAR ) """) def record_event(self, source_gpu: int, target_gpu: Optional[int] = None, transfer_size: Optional[int] = None) -> CrossGPUEvent: """Record a cross-GPU event in the stream""" with self.lock: event_id = f"event_{self.stream_id}_{len(self.events)}_{time.time_ns()}" event = CrossGPUEvent( event_id=event_id, timestamp=time.time(), source_gpu=source_gpu, target_gpu=target_gpu, transfer_size=transfer_size ) # Store event in database self.conn.execute(""" INSERT INTO stream_events ( event_id, stream_id, timestamp, source_gpu, target_gpu, completed, transfer_size, nvlink_path, state_json ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) """, [ event_id, self.stream_id, event.timestamp, source_gpu, target_gpu, False, transfer_size, None, {"status": "created"} ]) self.events.append(event) return event def wait_event(self, event: CrossGPUEvent): """Wait for a cross-GPU event to complete""" while True: # Check database for event completion result = self.conn.execute(""" SELECT completed FROM stream_events WHERE event_id = ? """, [event.event_id]).fetchall() if result and result[0][0]: event.completed = True break if event.completed: break time.sleep(0.001) # Small sleep to prevent busy waiting def synchronize(self): """Synchronize all operations in the stream""" with self.lock: for event in self.events: self.wait_event(event) self.events.clear() # Clear completed events from database self.conn.execute(""" DELETE FROM stream_events WHERE stream_id = ? AND completed = true """, [self.stream_id]) def add_operation(self, operation: Dict[str, Any]): """Add a cross-GPU operation to the stream""" with self.lock: # Record operation in database self.conn.execute(""" INSERT INTO stream_operations ( stream_id, operation_data, status ) VALUES (?, ?, ?) """, [self.stream_id, operation, "pending"]) # If operation involves data transfer, find optimal path if operation.get('type') == 'transfer': src_gpu = operation['source_gpu'] dst_gpu = operation['target_gpu'] size = operation['size'] # Query for best transfer path paths = self.conn.execute(""" WITH RECURSIVE gpu_path(src, dst, path, total_bandwidth, hops) AS ( -- Direct connections SELECT chip_a_id, chip_b_id, chip_a_id || ',' || chip_b_id, bandwidth_tbps, 1 FROM optical_interconnects WHERE (chip_a_id = ? AND chip_b_id = ?) OR (chip_a_id = ? AND chip_b_id = ?) UNION ALL -- Multi-hop paths SELECT p.src, i.chip_b_id, p.path || ',' || i.chip_b_id, LEAST(p.total_bandwidth, i.bandwidth_tbps), p.hops + 1 FROM gpu_path p JOIN optical_interconnects i ON p.dst = i.chip_a_id WHERE p.hops < 3 -- Limit path length AND i.chip_b_id != ALL(string_to_array(p.path, ',')::integer[]) ) SELECT path, total_bandwidth FROM gpu_path WHERE src = ? AND dst = ? ORDER BY total_bandwidth DESC, hops ASC LIMIT 1 """, (src_gpu, dst_gpu, dst_gpu, src_gpu, src_gpu, dst_gpu)).fetchone() if paths: operation['nvlink_path'] = paths[0].split(',') operation['expected_bandwidth'] = float(paths[1]) self.operation_queue.put(operation) def execute_next(self) -> bool: """Execute the next operation in the queue""" try: with self.lock: if self.operation_queue.empty(): return False operation = self.operation_queue.get() event = self.record_event( source_gpu=operation.get('source_gpu'), target_gpu=operation.get('target_gpu'), transfer_size=operation.get('size') ) try: if operation.get('type') == 'transfer': self._execute_transfer(operation) elif operation.get('type') == 'compute': self._execute_compute(operation) elif operation.get('type') == 'sync': self._execute_sync(operation) # Mark event as completed event.completed = True if self.hal: self.hal.execute(""" UPDATE stream_events SET completed = TRUE, nvlink_path = ?, state_json = ? WHERE event_id = ? """, ( ','.join(operation.get('nvlink_path', [])), json.dumps({"status": "completed"}), event.event_id )) return True except Exception as e: logging.error(f"Error in stream {self.stream_id}: {str(e)}") if self.hal: self.hal.execute(""" UPDATE stream_events SET state_json = ? WHERE event_id = ? """, ( json.dumps({"status": "error", "error": str(e)}), event.event_id )) return False except Exception as e: logging.error(f"Error in stream {self.stream_id}: {str(e)}") return False def _execute_transfer(self, operation: Dict[str, Any]): """Execute a data transfer operation using NVLink path""" if not self.hal: raise RuntimeError("No HAL connection available for transfer") src_gpu = operation['source_gpu'] dst_gpu = operation['target_gpu'] size = operation['size'] nvlink_path = operation.get('nvlink_path', []) # Record transfer in HAL database self.hal.execute(""" INSERT INTO memory_transfers ( source_chip, target_chip, size_bytes, nvlink_path, start_time, bandwidth_used ) VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP, ?) """, ( src_gpu, dst_gpu, size, ','.join(map(str, nvlink_path)), operation.get('expected_bandwidth', 0) )) def _execute_compute(self, operation: Dict[str, Any]): """Execute a compute operation on specified GPU""" if not self.hal: raise RuntimeError("No HAL connection available for compute") gpu_id = operation['gpu_id'] # Update GPU state in HAL self.hal.execute(""" UPDATE gpu_chips SET state_json = json_patch( state_json, json_object( 'current_operation', json(?), 'last_updated', json(CURRENT_TIMESTAMP) ) ) WHERE chip_id = ? """, (json.dumps(operation), gpu_id)) def _execute_sync(self, operation: Dict[str, Any]): """Execute a synchronization operation""" gpu_ids = operation.get('gpu_ids', []) if not gpu_ids: return # Create sync point in database sync_id = f"sync_{self.stream_id}_{time.time_ns()}" self.conn.execute(""" INSERT INTO sync_points ( sync_id, stream_id, gpu_ids ) VALUES (?, ?, ?) """, [sync_id, self.stream_id, ','.join(map(str, gpu_ids))]) # Wait for all GPUs to reach sync point while True: result = self.conn.execute(""" SELECT COUNT(*) FROM gpu_chips WHERE chip_id = ANY(string_split(?, ',')::INTEGER[]) AND state_json->>'sync_id' = ? """, [','.join(map(str, gpu_ids)), sync_id]).fetchall() if result[0][0] == len(gpu_ids): # All GPUs reached sync point self.conn.execute(""" UPDATE sync_points SET status = 'completed', completed_at = CURRENT_TIMESTAMP WHERE sync_id = ? """, [sync_id]) break time.sleep(0.001) class CrossGPUStreamManager: """Manages multiple cross-GPU streams""" DB_URL = "hf://datasets/Fred808/helium/storage.json" def __init__(self, db_url: Optional[str] = None): self.streams: List[CrossGPUStream] = [] self.db_url = db_url or self.DB_URL self.max_retries = 3 self._connect_with_retries() # Create default stream self.default_stream = self.create_stream() def _connect_with_retries(self): """Establish database connection with retry logic""" for attempt in range(self.max_retries): try: self.conn = self._init_db_connection() self._setup_database() return except Exception as e: if attempt == self.max_retries - 1: raise RuntimeError(f"Failed to initialize database after {self.max_retries} attempts: {str(e)}") time.sleep(1) def _init_db_connection(self) -> duckdb.DuckDBPyConnection: """Initialize database connection with HuggingFace configuration""" # Convert HF URL to S3 path _, _, owner, dataset, db_file = self.db_url.split('/', 4) db_path = f"s3://datasets-cached/{owner}/{dataset}/{db_file}" # Connect to remote database conn = duckdb.connect(db_path) conn.execute("INSTALL httpfs;") conn.execute("LOAD httpfs;") conn.execute("SET s3_endpoint='s3.us-east-1.amazonaws.com';") conn.execute("SET s3_use_ssl=true;") conn.execute("SET s3_url_style='path';") conn.execute(f"SET s3_access_key_id='{self.HF_TOKEN}';") conn.execute(f"SET s3_secret_access_key='{self.HF_TOKEN}';") return conn def _setup_database(self): """Initialize database tables""" self.conn.execute(""" CREATE TABLE IF NOT EXISTS streams ( stream_id BIGINT PRIMARY KEY, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, is_active BOOLEAN DEFAULT true, state_json JSON ) """) def create_stream(self) -> CrossGPUStream: """Create a new cross-GPU stream""" stream_id = len(self.streams) # Register stream in database self.conn.execute(""" INSERT INTO streams (stream_id, state_json) VALUES (?, ?) """, [stream_id, {"status": "created"}]) # Create stream object with same DB connection stream = CrossGPUStream(stream_id, self.db_url) self.streams.append(stream) return stream def get_stream(self, stream_id: int) -> Optional[CrossGPUStream]: """Get a stream by its ID""" if 0 <= stream_id < len(self.streams): return self.streams[stream_id] return None def synchronize_all(self): """Synchronize all streams""" for stream in self.streams: stream.synchronize() def synchronize_stream(self, stream_id: int): """Synchronize a specific stream""" stream = self.get_stream(stream_id) if stream: stream.synchronize() def execute_streams(self): """Execute operations in all streams""" while True: executed = False for stream in self.streams: if stream.execute_next(): executed = True if not executed: break def get_optimal_transfer_path(self, src_gpu: int, dst_gpu: int, size: int) -> List[str]: """Get optimal NVLink path for data transfer""" result = self.conn.execute(""" WITH RECURSIVE gpu_path(src, dst, path, total_bandwidth, hops) AS ( -- Direct connections SELECT chip_a_id, chip_b_id, CAST(chip_a_id AS VARCHAR) || ',' || CAST(chip_b_id AS VARCHAR), bandwidth_tbps, 1 FROM optical_interconnects WHERE (chip_a_id = ? AND chip_b_id = ?) OR (chip_a_id = ? AND chip_b_id = ?) UNION ALL -- Multi-hop paths SELECT p.src, i.chip_b_id, p.path || ',' || CAST(i.chip_b_id AS VARCHAR), LEAST(p.total_bandwidth, i.bandwidth_tbps), p.hops + 1 FROM gpu_path p JOIN optical_interconnects i ON p.dst = i.chip_a_id WHERE p.hops < 3 -- Limit path length AND NOT list_contains(string_split(p.path, ','), CAST(i.chip_b_id AS VARCHAR)) ) SELECT path, total_bandwidth FROM gpu_path WHERE src = ? AND dst = ? ORDER BY total_bandwidth DESC, hops ASC LIMIT 1 """, [src_gpu, dst_gpu, dst_gpu, src_gpu, src_gpu, dst_gpu]).fetchall() if result: return result[0][0].split(',') return [] def add_cross_gpu_operation(self, stream_id: int, operation: Dict[str, Any]): """Add a cross-GPU operation to a specific stream""" stream = self.get_stream(stream_id) if not stream: raise ValueError(f"Invalid stream ID: {stream_id}") # If it's a transfer operation, find optimal path if operation.get('type') == 'transfer': path = self.get_optimal_transfer_path( operation['source_gpu'], operation['target_gpu'], operation['size'] ) operation['nvlink_path'] = path stream.add_operation(operation)