|
|
from typing import List, Optional
|
|
|
from dataclasses import dataclass
|
|
|
import time
|
|
|
import json
|
|
|
from queue import Queue
|
|
|
from threading import Lock
|
|
|
import duckdb
|
|
|
from huggingface_hub import HfApi, HfFileSystem
|
|
|
from config import get_hf_token_cached
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class Event:
    """Represents a CUDA-like event for synchronization.

    Mirrors a CUDA event: it is recorded on a Stream, can be waited on via
    Stream.wait_event, and flips to completed once its operation finishes.
    """

    # Unique id; Stream.record_event builds it as "event_{stream_id}_{time_ns}".
    event_id: str
    # Wall-clock creation time (time.time()).
    timestamp: float
    # True once the associated operation has finished (set in-process by
    # Stream.execute_next, or observed from the stream_events DB row).
    completed: bool = False
    # Optional payload mirrored in the stream_events.state_json JSON column.
    state_json: Optional[dict] = None
|
|
|
|
|
|
class Stream:
    """Represents a CUDA-like stream for concurrent execution.

    Operations are queued FIFO via add_operation and executed one at a time
    by execute_next; each execution records an Event that can be waited on.
    Event state is mirrored in a DuckDB database reached through the httpfs
    extension with HuggingFace credentials.
    """

    # Default storage location, format: "hf://datasets/{owner}/{dataset}/{file}".
    DB_URL = "hf://datasets/Fred808/helium/storage.json"

    def __init__(self, stream_id: int, db_url: Optional[str] = None):
        """Create a stream and open (with retries) its database connection.

        Args:
            stream_id: Numeric identifier of this stream.
            db_url: Optional override for the class-level DB_URL.

        Raises:
            RuntimeError: if the database connection cannot be established
                after max_retries attempts.
        """
        self.stream_id = stream_id
        self.events: List["Event"] = []
        self.operation_queue: Queue = Queue()
        # Guards the events list and queue handoff.  Note: threading.Lock is
        # NOT reentrant — methods must not call each other while holding it.
        self.lock = Lock()
        self.is_active = True

        self.db_url = db_url or self.DB_URL
        self.max_retries = 3
        # Fix: _init_db_connection referenced self.HF_TOKEN, but it was never
        # assigned anywhere; fetch it from the (already imported) config helper.
        self.HF_TOKEN = get_hf_token_cached()
        self._connect_with_retries()
        self._setup_database()

    def _connect_with_retries(self):
        """Establish the database connection, retrying up to max_retries times.

        Raises:
            RuntimeError: when every attempt fails (original cause chained).
        """
        for attempt in range(self.max_retries):
            try:
                self.conn = self._init_db_connection()
                return
            except Exception as e:
                if attempt == self.max_retries - 1:
                    raise RuntimeError(
                        f"Failed to initialize database after {self.max_retries} attempts: {str(e)}"
                    ) from e
                time.sleep(1)  # brief back-off before the next attempt

    # Forward-reference annotation: avoids evaluating `duckdb` at class
    # creation time if the module is imported lazily elsewhere.
    def _init_db_connection(self) -> "duckdb.DuckDBPyConnection":
        """Open a DuckDB connection configured for HuggingFace/S3 access.

        Returns:
            A live DuckDB connection with the httpfs extension loaded and
            S3 credentials set from the HF token.
        """
        # db_url format: "hf://datasets/{owner}/{dataset}/{db_file}".
        # Fix: the original split('/', 4) produced owner='datasets' and folded
        # the dataset name into db_file; maxsplit=5 yields the intended parts.
        _, _, _, owner, dataset, db_file = self.db_url.split('/', 5)
        db_path = f"s3://datasets-cached/{owner}/{dataset}/{db_file}"

        conn = duckdb.connect(db_path)
        conn.execute("INSTALL httpfs;")
        conn.execute("LOAD httpfs;")
        conn.execute("SET s3_endpoint='s3.us-east-1.amazonaws.com';")
        conn.execute("SET s3_use_ssl=true;")
        conn.execute("SET s3_url_style='path';")
        # NOTE(review): the token is interpolated straight into SQL.  That is
        # tolerable for a trusted local token, but do not extend this pattern
        # to untrusted values.
        conn.execute(f"SET s3_access_key_id='{self.HF_TOKEN}';")
        conn.execute(f"SET s3_secret_access_key='{self.HF_TOKEN}';")
        return conn

    def _setup_database(self):
        """Create the stream_events / stream_operations tables if absent."""
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS stream_events (
                event_id VARCHAR PRIMARY KEY,
                stream_id BIGINT,
                timestamp DOUBLE,
                completed BOOLEAN DEFAULT false,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                completed_at TIMESTAMP,
                state_json JSON
            )
        """)

        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS stream_operations (
                operation_id VARCHAR PRIMARY KEY,
                stream_id BIGINT,
                operation_type VARCHAR,
                args JSON,
                kwargs JSON,
                status VARCHAR DEFAULT 'pending',
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                started_at TIMESTAMP,
                completed_at TIMESTAMP,
                error_message VARCHAR
            )
        """)

    def record_event(self) -> "Event":
        """Record a new event on this stream and persist it.

        Returns:
            The newly created, not-yet-completed Event.
        """
        with self.lock:
            # time_ns() gives a unique-enough per-stream suffix.
            event_id = f"event_{self.stream_id}_{time.time_ns()}"
            event = Event(event_id=event_id, timestamp=time.time())

            # Fix: serialize the payload explicitly — the JSON column expects
            # text, not a raw Python dict (json was imported but unused).
            self.conn.execute("""
                INSERT INTO stream_events (
                    event_id, stream_id, timestamp, state_json
                ) VALUES (?, ?, ?, ?)
            """, [event_id, self.stream_id, event.timestamp,
                  json.dumps({"status": "created"})])

            self.events.append(event)
            return event

    def wait_event(self, event: "Event"):
        """Block until *event* completes, polling every millisecond.

        Completion is observed either from the stream_events row (completed
        column) or from the in-memory flag set by execute_next.
        """
        while True:
            row = self.conn.execute("""
                SELECT completed, state_json
                FROM stream_events
                WHERE event_id = ?
            """, [event.event_id]).fetchone()

            if row and row[0]:
                event.completed = True
                raw = row[1]
                # JSON columns come back as text; decode so state_json matches
                # the Optional[dict] type declared on Event.
                event.state_json = json.loads(raw) if isinstance(raw, str) else raw
                break

            if event.completed:
                break

            time.sleep(0.001)  # 1 ms poll interval

    def synchronize(self):
        """Wait for all recorded events to complete, then purge them.

        Completed rows are deleted from stream_events and the local event list
        is cleared (cudaStreamSynchronize-style semantics).
        """
        with self.lock:
            for event in self.events:
                self.wait_event(event)

            self.conn.execute("""
                DELETE FROM stream_events
                WHERE stream_id = ? AND completed = true
            """, [self.stream_id])

            self.events.clear()

    def add_operation(self, operation: callable, *args, **kwargs):
        """Enqueue `operation(*args, **kwargs)` for later execution."""
        with self.lock:
            self.operation_queue.put((operation, args, kwargs))

    def execute_next(self) -> bool:
        """Execute the next queued operation, if any.

        Returns:
            True if an operation ran to completion; False when the queue is
            empty or the operation (or bookkeeping) raised.
        """
        try:
            # Only the emptiness check is done under the lock: record_event
            # re-acquires self.lock, which is non-reentrant.
            with self.lock:
                if self.operation_queue.empty():
                    return False

            operation, args, kwargs = self.operation_queue.get()
            event = self.record_event()

            try:
                operation(*args, **kwargs)
            finally:
                event.completed = True
                # Fix: persist completion — previously the DB `completed`
                # column was never set true, so synchronize()'s DELETE removed
                # nothing and completed_at was never written.
                self.conn.execute("""
                    UPDATE stream_events
                    SET completed = true, completed_at = CURRENT_TIMESTAMP
                    WHERE event_id = ?
                """, [event.event_id])

            return True
        except Exception as e:
            print(f"Error in stream {self.stream_id}: {str(e)}")
            return False
|
|
|
|
|
|
class StreamManager:
    """Manages multiple CUDA-like streams."""

    def __init__(self):
        # Stream 0 is created eagerly and doubles as the default stream.
        self.streams: List[Stream] = []
        self.default_stream = self.create_stream()

    def create_stream(self) -> Stream:
        """Create, register, and return a stream with the next free id."""
        new_stream = Stream(len(self.streams))
        self.streams.append(new_stream)
        return new_stream

    def get_stream(self, stream_id: int) -> Optional[Stream]:
        """Return the stream with the given id, or None if out of range."""
        if stream_id < 0:
            return None  # reject negatives explicitly (no wrap-around indexing)
        try:
            return self.streams[stream_id]
        except IndexError:
            return None

    def synchronize_all(self):
        """Synchronize every managed stream."""
        for managed in self.streams:
            managed.synchronize()

    def synchronize_stream(self, stream_id: int):
        """Synchronize one stream by id; silently ignore unknown ids."""
        target = self.get_stream(stream_id)
        if target is not None:
            target.synchronize()

    def execute_streams(self):
        """Drain all streams round-robin until a full pass executes nothing."""
        while True:
            # Run one operation per stream; no short-circuiting, so every
            # stream gets a turn on each pass.
            progressed = [managed.execute_next() for managed in self.streams]
            if not any(progressed):
                break
|
|
|
|