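"""CUDA-style streams and events backed by a remote DuckDB store.

Stream mimics a CUDA stream: callables are queued, executed in order, and
tracked through Event records persisted to a DuckDB database stored alongside
a Hugging Face dataset. StreamManager coordinates multiple streams and offers
synchronization across all of them.
"""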
from typing import Callable, List, Optional
from dataclasses import dataclass
import time
import json
from queue import Queue
from threading import RLock

import duckdb
from huggingface_hub import HfApi, HfFileSystem

from config import get_hf_token_cached

# Initialize token from .env
HF_TOKEN = get_hf_token_cached()

@dataclass
class Event:
    """Represents a CUDA-like event for synchronization"""
    event_id: str
    timestamp: float
    completed: bool = False
    state_json: Optional[dict] = None

class Stream:
    """Represents a CUDA-like stream for concurrent execution"""

    DB_URL = "hf://datasets/Fred808/helium/storage.json"

    def __init__(self, stream_id: int, db_url: Optional[str] = None):
        self.stream_id = stream_id
        self.events: List[Event] = []
        self.operation_queue: Queue = Queue()
        # Reentrant lock: execute_next() calls record_event() while holding it
        self.lock = RLock()
        self.is_active = True

        # Initialize database connection
        self.db_url = db_url or self.DB_URL
        self.max_retries = 3
        self._connect_with_retries()
        self._setup_database()

    def _connect_with_retries(self):
        """Establish database connection with retry logic"""
        for attempt in range(self.max_retries):
            try:
                self.conn = self._init_db_connection()
                return
            except Exception as e:
                if attempt == self.max_retries - 1:
                    raise RuntimeError(
                        f"Failed to initialize database after {self.max_retries} attempts: {str(e)}"
                    )
                time.sleep(1)

    def _init_db_connection(self) -> duckdb.DuckDBPyConnection:
        """Initialize database connection with HuggingFace configuration"""
        # Convert the hf:// URL ("hf://datasets/{owner}/{dataset}/{file}") to an S3-style path
        _, _, _, owner, dataset, db_file = self.db_url.split('/', 5)
        db_path = f"s3://datasets-cached/{owner}/{dataset}/{db_file}"

        # Connect to remote database
        conn = duckdb.connect(db_path)
        conn.execute("INSTALL httpfs;")
        conn.execute("LOAD httpfs;")
        conn.execute("SET s3_endpoint='s3.us-east-1.amazonaws.com';")
        conn.execute("SET s3_use_ssl=true;")
        conn.execute("SET s3_url_style='path';")
        conn.execute(f"SET s3_access_key_id='{HF_TOKEN}';")
        conn.execute(f"SET s3_secret_access_key='{HF_TOKEN}';")
        return conn

    def _setup_database(self):
        """Initialize database tables"""
        # Events table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS stream_events (
                event_id VARCHAR PRIMARY KEY,
                stream_id BIGINT,
                timestamp DOUBLE,
                completed BOOLEAN DEFAULT false,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                completed_at TIMESTAMP,
                state_json JSON
            )
        """)

        # Operations table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS stream_operations (
                operation_id VARCHAR PRIMARY KEY,
                stream_id BIGINT,
                operation_type VARCHAR,
                args JSON,
                kwargs JSON,
                status VARCHAR DEFAULT 'pending',
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                started_at TIMESTAMP,
                completed_at TIMESTAMP,
                error_message VARCHAR
            )
        """)

    def record_event(self) -> Event:
        """Record an event in the stream"""
        with self.lock:
            event_id = f"event_{self.stream_id}_{time.time_ns()}"
            event = Event(event_id=event_id, timestamp=time.time())

            # Record event in database
            self.conn.execute("""
                INSERT INTO stream_events (
                    event_id, stream_id, timestamp, state_json
                ) VALUES (?, ?, ?, ?)
            """, [event_id, self.stream_id, event.timestamp, json.dumps({"status": "created"})])

            self.events.append(event)
            return event

    def wait_event(self, event: Event):
        """Wait for a specific event to complete"""
        while True:
            # Check database for completion
            result = self.conn.execute("""
                SELECT completed, state_json
                FROM stream_events
                WHERE event_id = ?
            """, [event.event_id]).fetchall()

            if result and result[0][0]:
                event.completed = True
                event.state_json = result[0][1]
                break
            if event.completed:
                break
            time.sleep(0.001)  # Small sleep to prevent busy waiting

    def synchronize(self):
        """Synchronize the stream, waiting for all operations to complete"""
        with self.lock:
            for event in self.events:
                self.wait_event(event)

            # Clear completed events
            self.conn.execute("""
                DELETE FROM stream_events
                WHERE stream_id = ? AND completed = true
            """, [self.stream_id])
            self.events.clear()

    def add_operation(self, operation: Callable, *args, **kwargs):
        """Add an operation to the stream's queue"""
        with self.lock:
            self.operation_queue.put((operation, args, kwargs))

    def execute_next(self) -> bool:
        """Execute the next operation in the queue"""
        try:
            with self.lock:
                if self.operation_queue.empty():
                    return False
                operation, args, kwargs = self.operation_queue.get()
                event = self.record_event()
                try:
                    operation(*args, **kwargs)
                finally:
                    event.completed = True
                    # Mark the event complete in the database so synchronize() can prune it
                    self.conn.execute("""
                        UPDATE stream_events
                        SET completed = true, completed_at = CURRENT_TIMESTAMP
                        WHERE event_id = ?
                    """, [event.event_id])
                return True
        except Exception as e:
            print(f"Error in stream {self.stream_id}: {str(e)}")
            return False

class StreamManager:
    """Manages multiple CUDA-like streams"""

    def __init__(self):
        self.streams: List[Stream] = []
        self.default_stream = self.create_stream()

    def create_stream(self) -> Stream:
        """Create a new stream"""
        stream_id = len(self.streams)
        stream = Stream(stream_id)
        self.streams.append(stream)
        return stream

    def get_stream(self, stream_id: int) -> Optional[Stream]:
        """Get a stream by its ID"""
        if 0 <= stream_id < len(self.streams):
            return self.streams[stream_id]
        return None

    def synchronize_all(self):
        """Synchronize all streams"""
        for stream in self.streams:
            stream.synchronize()

    def synchronize_stream(self, stream_id: int):
        """Synchronize a specific stream"""
        stream = self.get_stream(stream_id)
        if stream:
            stream.synchronize()

    def execute_streams(self):
        """Execute operations in all streams"""
        while True:
            executed = False
            for stream in self.streams:
                if stream.execute_next():
                    executed = True
            if not executed:
                break
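
# --- Illustrative usage sketch (not part of the original module) ---
# A minimal, hedged example of driving StreamManager. It assumes the
# hf://datasets/Fred808/helium/storage.json database and the token returned by
# config.get_hf_token_cached() are reachable, since Stream's constructor
# connects to the remote store. The work() helper below is hypothetical.
if __name__ == "__main__":
    manager = StreamManager()
    extra_stream = manager.create_stream()

    def work(label: str):
        # Placeholder workload; in practice this would be any callable.
        print(f"running {label}")

    # Queue operations on the default stream and on the extra stream.
    manager.default_stream.add_operation(work, "op-on-default")
    extra_stream.add_operation(work, "op-on-extra")

    # Drain every stream's queue, then block until all recorded events finish.
    manager.execute_streams()
    manager.synchronize_all()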