from typing import Callable, List, Optional
from dataclasses import dataclass
import time
import json
from queue import Queue
from threading import Lock

import duckdb

from config import get_hf_token_cached  # Reads the token from .env


@dataclass
class Event:
    """Represents a CUDA-like event for synchronization"""
    event_id: str
    timestamp: float
    completed: bool = False
    state_json: Optional[dict] = None


class Stream:
    """Represents a CUDA-like stream for concurrent execution"""

    DB_URL = "hf://datasets/Fred808/helium/storage.json"

    def __init__(self, stream_id: int, db_url: Optional[str] = None):
        self.stream_id = stream_id
        self.events: List[Event] = []
        self.operation_queue: Queue = Queue()
        self.lock = Lock()
        self.is_active = True
        # Initialize database connection
        self.db_url = db_url or self.DB_URL
        self.hf_token = get_hf_token_cached()
        self.max_retries = 3
        self._connect_with_retries()
        self._setup_database()

    def _connect_with_retries(self):
        """Establish database connection with retry logic"""
        for attempt in range(self.max_retries):
            try:
                self.conn = self._init_db_connection()
                return
            except Exception as e:
                if attempt == self.max_retries - 1:
                    raise RuntimeError(
                        f"Failed to initialize database after "
                        f"{self.max_retries} attempts: {e}"
                    )
                time.sleep(1)

    def _init_db_connection(self) -> duckdb.DuckDBPyConnection:
        """Initialize database connection with HuggingFace configuration"""
        # Convert the hf:// URL into an S3-style path.
        # "hf://datasets/<owner>/<dataset>/<file>" splits into six parts,
        # so maxsplit must be 5 (split('/', 4) would misalign
        # owner/dataset/file).
        _, _, _, owner, dataset, db_file = self.db_url.split('/', 5)
        self.db_path = f"s3://datasets-cached/{owner}/{dataset}/{db_file}"

        # duckdb.connect() cannot open an s3:// path directly, and httpfs
        # must be installed, loaded, and configured before any remote
        # access, so open an in-memory database first and keep self.db_path
        # for httpfs reads/writes.
        conn = duckdb.connect()
        conn.execute("INSTALL httpfs;")
        conn.execute("LOAD httpfs;")
        conn.execute("SET s3_endpoint='s3.us-east-1.amazonaws.com';")
        conn.execute("SET s3_use_ssl=true;")
        conn.execute("SET s3_url_style='path';")
        conn.execute(f"SET s3_access_key_id='{self.hf_token}';")
        conn.execute(f"SET s3_secret_access_key='{self.hf_token}';")
        return conn

    def _setup_database(self):
        """Initialize database tables"""
        # Events table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS stream_events (
                event_id VARCHAR PRIMARY KEY,
                stream_id BIGINT,
                timestamp DOUBLE,
                completed BOOLEAN DEFAULT false,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                completed_at TIMESTAMP,
                state_json JSON
            )
        """)
        # Operations table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS stream_operations (
                operation_id VARCHAR PRIMARY KEY,
                stream_id BIGINT,
                operation_type VARCHAR,
                args JSON,
                kwargs JSON,
                status VARCHAR DEFAULT 'pending',
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                started_at TIMESTAMP,
                completed_at TIMESTAMP,
                error_message VARCHAR
            )
        """)

    def record_event(self) -> Event:
        """Record an event in the stream"""
        with self.lock:
            event_id = f"event_{self.stream_id}_{time.time_ns()}"
            event = Event(event_id=event_id, timestamp=time.time())
            # Record the event in the database (state serialized as JSON text)
            self.conn.execute("""
                INSERT INTO stream_events (
                    event_id, stream_id, timestamp, state_json
                ) VALUES (?, ?, ?, ?)
            """, [event_id, self.stream_id, event.timestamp,
                  json.dumps({"status": "created"})])
            self.events.append(event)
            return event

    def wait_event(self, event: Event):
        """Wait for a specific event to complete"""
        while True:
            # Check the database for completion
            result = self.conn.execute("""
                SELECT completed, state_json
                FROM stream_events
                WHERE event_id = ?
""", [event.event_id]).fetchall() if result and result[0][0]: event.completed = True event.state_json = result[0][1] break if event.completed: break time.sleep(0.001) # Small sleep to prevent busy waiting def synchronize(self): """Synchronize the stream, waiting for all operations to complete""" with self.lock: for event in self.events: self.wait_event(event) # Clear completed events self.conn.execute(""" DELETE FROM stream_events WHERE stream_id = ? AND completed = true """, [self.stream_id]) self.events.clear() def add_operation(self, operation: callable, *args, **kwargs): """Add an operation to the stream's queue""" with self.lock: self.operation_queue.put((operation, args, kwargs)) def execute_next(self) -> bool: """Execute the next operation in the queue""" try: with self.lock: if self.operation_queue.empty(): return False operation, args, kwargs = self.operation_queue.get() event = self.record_event() try: operation(*args, **kwargs) finally: event.completed = True return True except Exception as e: print(f"Error in stream {self.stream_id}: {str(e)}") return False class StreamManager: """Manages multiple CUDA-like streams""" def __init__(self): self.streams: List[Stream] = [] self.default_stream = self.create_stream() def create_stream(self) -> Stream: """Create a new stream""" stream_id = len(self.streams) stream = Stream(stream_id) self.streams.append(stream) return stream def get_stream(self, stream_id: int) -> Optional[Stream]: """Get a stream by its ID""" if 0 <= stream_id < len(self.streams): return self.streams[stream_id] return None def synchronize_all(self): """Synchronize all streams""" for stream in self.streams: stream.synchronize() def synchronize_stream(self, stream_id: int): """Synchronize a specific stream""" stream = self.get_stream(stream_id) if stream: stream.synchronize() def execute_streams(self): """Execute operations in all streams""" while True: executed = False for stream in self.streams: if stream.execute_next(): executed = True if not executed: break