File size: 7,484 Bytes
7a0c684
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
from typing import List, Optional
from dataclasses import dataclass
import time
import json
from queue import Queue
from threading import Lock
import duckdb
from huggingface_hub import HfApi, HfFileSystem
from config import get_hf_token_cached

# NOTE(review): no token is actually initialized here — get_hf_token_cached
# is imported from config but never called anywhere in this module; confirm
# where the HF token is meant to be set up.



@dataclass
class Event:
    """A lightweight synchronization marker, modeled on CUDA events.

    An Event is recorded on a Stream and later waited upon; ``completed``
    flips to True once the associated operation has finished, and
    ``state_json`` carries whatever state payload was stored for it.
    """
    event_id: str                       # unique id, e.g. "event_<stream>_<ns>"
    timestamp: float                    # wall-clock time the event was recorded
    completed: bool = False             # set True once the operation finishes
    state_json: Optional[dict] = None   # state payload read back from storage

class Stream:
    """Represents a CUDA-like stream for concurrent execution.

    Operations are queued and executed in order; each executed operation
    records an :class:`Event` whose completion state is mirrored into a
    remote DuckDB database hosted on HuggingFace.
    """
    DB_URL = "hf://datasets/Fred808/helium/storage.json"

    def __init__(self, stream_id: int, db_url: Optional[str] = None):
        self.stream_id = stream_id
        self.events: List[Event] = []
        self.operation_queue: Queue = Queue()
        self.lock = Lock()
        self.is_active = True

        # BUG FIX: HF_TOKEN was read in _init_db_connection but never
        # assigned anywhere, which raised AttributeError on first connect.
        # get_hf_token_cached was imported but unused — this is clearly
        # where it was meant to be called.
        self.HF_TOKEN = get_hf_token_cached()

        # Initialize database connection
        self.db_url = db_url or self.DB_URL
        self.max_retries = 3
        self._connect_with_retries()
        self._setup_database()

    def _connect_with_retries(self):
        """Establish the database connection, retrying transient failures.

        Raises:
            RuntimeError: if every attempt fails (chained to the last error).
        """
        last_error: Optional[Exception] = None
        for attempt in range(self.max_retries):
            try:
                self.conn = self._init_db_connection()
                return
            except Exception as e:
                last_error = e
                if attempt < self.max_retries - 1:
                    time.sleep(1)  # brief pause before retrying
        # Chain the underlying cause instead of discarding it.
        raise RuntimeError(
            f"Failed to initialize database after {self.max_retries} attempts: {last_error}"
        ) from last_error

    def _init_db_connection(self) -> duckdb.DuckDBPyConnection:
        """Initialize database connection with HuggingFace configuration."""
        # Convert the hf:// URL into an S3-style path.
        # Expected URL shape: hf://datasets/<owner>/<dataset>/<file>
        _, _, owner, dataset, db_file = self.db_url.split('/', 4)
        db_path = f"s3://datasets-cached/{owner}/{dataset}/{db_file}"

        # NOTE(review): duckdb.connect() receives the s3:// path before the
        # httpfs extension is installed/configured on that connection —
        # confirm this actually opens the remote file; otherwise the httpfs
        # setup must happen on a local connection that then ATTACHes.
        conn = duckdb.connect(db_path)
        conn.execute("INSTALL httpfs;")
        conn.execute("LOAD httpfs;")
        conn.execute("SET s3_endpoint='s3.us-east-1.amazonaws.com';")
        conn.execute("SET s3_use_ssl=true;")
        conn.execute("SET s3_url_style='path';")
        # NOTE(review): an HF token is not an AWS key pair — verify the
        # endpoint accepts it for both credential fields.
        conn.execute(f"SET s3_access_key_id='{self.HF_TOKEN}';")
        conn.execute(f"SET s3_secret_access_key='{self.HF_TOKEN}';")
        return conn

    def _setup_database(self):
        """Create the events and operations tables if they do not exist."""
        # Events table: one row per recorded Event.
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS stream_events (
                event_id VARCHAR PRIMARY KEY,
                stream_id BIGINT,
                timestamp DOUBLE,
                completed BOOLEAN DEFAULT false,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                completed_at TIMESTAMP,
                state_json JSON
            )
        """)

        # Operations table: bookkeeping for queued operations.
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS stream_operations (
                operation_id VARCHAR PRIMARY KEY,
                stream_id BIGINT,
                operation_type VARCHAR,
                args JSON,
                kwargs JSON,
                status VARCHAR DEFAULT 'pending',
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                started_at TIMESTAMP,
                completed_at TIMESTAMP,
                error_message VARCHAR
            )
        """)

    def record_event(self) -> Event:
        """Record a new event on this stream and persist it.

        Returns:
            The newly created (not yet completed) Event.
        """
        with self.lock:
            event_id = f"event_{self.stream_id}_{time.time_ns()}"
            event = Event(event_id=event_id, timestamp=time.time())

            # BUG FIX: serialize the state payload — the duckdb Python API
            # binds a plain dict as a STRUCT value, not as JSON text, so the
            # JSON column insert would not behave as intended.
            self.conn.execute("""
                INSERT INTO stream_events (
                    event_id, stream_id, timestamp, state_json
                ) VALUES (?, ?, ?, ?)
            """, [event_id, self.stream_id, event.timestamp,
                  json.dumps({"status": "created"})])

            self.events.append(event)
            return event

    def wait_event(self, event: Event, timeout: Optional[float] = None):
        """Block until ``event`` completes.

        Args:
            event: the event to wait for.
            timeout: optional maximum seconds to wait; None (the default,
                matching the original behavior) waits forever.

        Raises:
            TimeoutError: if ``timeout`` elapses before completion.
        """
        deadline = None if timeout is None else time.monotonic() + timeout
        while True:
            # Local completion (set by execute_next) short-circuits the
            # database round-trip.
            if event.completed:
                return

            row = self.conn.execute("""
                SELECT completed, state_json
                FROM stream_events
                WHERE event_id = ?
            """, [event.event_id]).fetchone()

            if row and row[0]:
                event.completed = True
                # state_json comes back as JSON text; decode when possible.
                event.state_json = (
                    json.loads(row[1]) if isinstance(row[1], str) else row[1]
                )
                return

            if deadline is not None and time.monotonic() >= deadline:
                raise TimeoutError(f"Timed out waiting for event {event.event_id}")
            time.sleep(0.001)  # small sleep to prevent busy waiting

    def synchronize(self):
        """Wait for all recorded events, then purge completed rows."""
        with self.lock:
            for event in self.events:
                self.wait_event(event)

            # Clear completed events for this stream from storage.
            self.conn.execute("""
                DELETE FROM stream_events
                WHERE stream_id = ? AND completed = true
            """, [self.stream_id])

            self.events.clear()

    def add_operation(self, operation: callable, *args, **kwargs):
        """Queue ``operation(*args, **kwargs)`` for later execution."""
        with self.lock:
            self.operation_queue.put((operation, args, kwargs))

    def execute_next(self) -> bool:
        """Execute the next queued operation, if any.

        Returns:
            True if an operation was run, False if the queue was empty or
            an internal error occurred (the error is printed, not raised).
        """
        try:
            # BUG FIX: the original called record_event() while already
            # holding self.lock; threading.Lock is not reentrant, so the
            # first queued operation deadlocked. Pop the operation under
            # the lock, then record and run it outside of it.
            with self.lock:
                if self.operation_queue.empty():
                    return False
                operation, args, kwargs = self.operation_queue.get()

            event = self.record_event()
            try:
                operation(*args, **kwargs)
            finally:
                # Mark completion locally so wait_event() can observe it
                # even though the database row is not updated here.
                event.completed = True
            return True
        except Exception as e:
            print(f"Error in stream {self.stream_id}: {str(e)}")
            return False

class StreamManager:
    """Owns a pool of CUDA-like streams and coordinates work across them."""

    def __init__(self):
        # Stream 0 is created eagerly and serves as the default stream.
        self.streams: List[Stream] = []
        self.default_stream = self.create_stream()

    def create_stream(self) -> Stream:
        """Allocate a new stream whose id is its index in the pool."""
        new_stream = Stream(len(self.streams))
        self.streams.append(new_stream)
        return new_stream

    def get_stream(self, stream_id: int) -> Optional[Stream]:
        """Look up a stream by id, returning None for out-of-range ids."""
        in_range = 0 <= stream_id < len(self.streams)
        return self.streams[stream_id] if in_range else None

    def synchronize_all(self):
        """Block until every stream has drained its recorded events."""
        for s in self.streams:
            s.synchronize()

    def synchronize_stream(self, stream_id: int):
        """Synchronize one stream; unknown ids are silently ignored."""
        target = self.get_stream(stream_id)
        if target is not None:
            target.synchronize()

    def execute_streams(self):
        """Round-robin over streams, running queued ops until all are idle."""
        made_progress = True
        while made_progress:
            made_progress = False
            for s in self.streams:
                # Note: execute_next() must run for every stream each pass.
                made_progress = s.execute_next() or made_progress