""" Cross-GPU Stream Manager for coordinating operations across multiple GPUs Uses DuckDB with HuggingFace for persistent storage """ from typing import Dict, List, Optional, Any import threading import time import logging import duckdb import json from config import get_hf_token_cached # Initialize token from .env class CrossGPUStreamManager: def __init__(self, db_path: str = "hf://datasets/Fred808/helium/storage.json"): self.stream_lock = threading.Lock() self.transfer_lock = threading.Lock() self.db_path = db_path self.con = self._init_db() def _init_db(self) -> duckdb.DuckDBPyConnection: """Initialize database connection and schema""" con = duckdb.connect(self.db_path) # Configure HuggingFace access con.execute("INSTALL httpfs;") con.execute("LOAD httpfs;") con.execute("SET s3_endpoint='hf.co';") con.execute("SET s3_use_ssl=true;") con.execute("SET s3_url_style='path';") con.execute(f"SET s3_access_key_id='{self.HF_TOKEN}';") con.execute(f"SET s3_secret_access_key='{self.HF_TOKEN}';") # Create streams table con.execute(""" CREATE TABLE IF NOT EXISTS streams ( stream_id VARCHAR PRIMARY KEY, source_gpu VARCHAR, target_gpu VARCHAR, state VARCHAR, created_at TIMESTAMP, last_active TIMESTAMP, transfer_count INTEGER, total_bytes_transferred BIGINT ) """) # Create transfers table con.execute(""" CREATE TABLE IF NOT EXISTS transfers ( transfer_id VARCHAR PRIMARY KEY, stream_id VARCHAR, transfer_size BIGINT, started_at TIMESTAMP, completed_at TIMESTAMP, state VARCHAR ) """) return con def create_stream(self, stream_id: str, source_gpu: str, target_gpu: str) -> Dict[str, Any]: """Create a new cross-GPU stream""" with self.stream_lock: # Check if stream exists result = self.con.execute(""" SELECT * FROM streams WHERE stream_id = ? """, [stream_id]).fetchone() if result: return dict(zip(['id', 'source_gpu', 'target_gpu', 'state', 'created_at', 'last_active', 'transfer_count', 'total_bytes_transferred'], result)) # Create new stream now = time.time() self.con.execute(""" INSERT INTO streams ( stream_id, source_gpu, target_gpu, state, created_at, last_active, transfer_count, total_bytes_transferred ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) """, [stream_id, source_gpu, target_gpu, 'initialized', now, now, 0, 0]) logging.info(f"Created cross-GPU stream {stream_id} from {source_gpu} to {target_gpu}") return { 'id': stream_id, 'source_gpu': source_gpu, 'target_gpu': target_gpu, 'state': 'initialized', 'created_at': now, 'last_active': now, 'transfer_count': 0, 'total_bytes_transferred': 0 } def start_transfer(self, stream_id: str, transfer_size: int) -> str: """Start a new data transfer on a stream""" with self.stream_lock: # Check if stream exists result = self.con.execute(""" SELECT transfer_count FROM streams WHERE stream_id = ? """, [stream_id]).fetchone() if not result: raise ValueError(f"Stream {stream_id} does not exist") transfer_count = result[0] transfer_id = f"transfer_{stream_id}_{transfer_count}" now = time.time() # Create transfer record with self.transfer_lock: self.con.execute(""" INSERT INTO transfers ( transfer_id, stream_id, transfer_size, started_at, state ) VALUES (?, ?, ?, ?, ?) """, [transfer_id, stream_id, transfer_size, now, 'in_progress']) # Update stream self.con.execute(""" UPDATE streams SET transfer_count = transfer_count + 1, last_active = ? WHERE stream_id = ? """, [now, stream_id]) logging.info(f"Started transfer {transfer_id} on stream {stream_id}") return transfer_id def complete_transfer(self, transfer_id: str) -> None: """Mark a transfer as complete and update statistics""" with self.transfer_lock: # Get transfer info transfer = self.con.execute(""" SELECT stream_id, transfer_size FROM transfers WHERE transfer_id = ? AND state = 'in_progress' """, [transfer_id]).fetchone() if not transfer: raise ValueError(f"Transfer {transfer_id} not found or not in progress") stream_id, transfer_size = transfer now = time.time() # Update transfer status self.con.execute(""" UPDATE transfers SET state = 'completed', completed_at = ? WHERE transfer_id = ? """, [now, transfer_id]) # Update stream statistics with self.stream_lock: self.con.execute(""" UPDATE streams SET total_bytes_transferred = total_bytes_transferred + ?, last_active = ? WHERE stream_id = ? """, [transfer_size, now, stream_id]) logging.info(f"Completed transfer {transfer_id}") def get_stream_stats(self, stream_id: str) -> Dict[str, Any]: """Get statistics for a specific stream""" with self.stream_lock: result = self.con.execute(""" SELECT transfer_count, total_bytes_transferred, created_at, last_active FROM streams WHERE stream_id = ? """, [stream_id]).fetchone() if not result: raise ValueError(f"Stream {stream_id} does not exist") transfer_count, total_bytes, created_at, last_active = result now = time.time() return { 'transfer_count': transfer_count, 'total_bytes_transferred': total_bytes, 'uptime': now - created_at, 'last_active_ago': now - last_active } def cleanup_inactive_streams(self, timeout: float = 300.0) -> List[str]: """Remove streams that have been inactive for the specified timeout""" current_time = time.time() cutoff_time = current_time - timeout with self.stream_lock: # Find inactive streams results = self.con.execute(""" DELETE FROM streams WHERE last_active < ? RETURNING stream_id """, [cutoff_time]).fetchall() cleaned_streams = [r[0] for r in results] for stream_id in cleaned_streams: logging.info(f"Cleaned up inactive stream {stream_id}") return cleaned_streams