File size: 16,055 Bytes
7a0c684
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
from typing import List, Dict, Any, Optional
import time
import json
import logging
import duckdb
from huggingface_hub import HfApi, HfFileSystem
from tensor_core import TensorCore
from config import get_hf_token_cached

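# NOTE: no token is initialized at module level; the only token helper in
# scope is `get_hf_token_cached`, imported from `config` above.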



class TensorOps:
    """Manages tensor operations with remote state tracking"""
    DB_URL = "hf://datasets/Fred808/helium/storage.json"
    
    def __init__(self, db_url: Optional[str] = None):
        self.db_url = db_url or self.DB_URL
        self.max_retries = 3
        self._connect_with_retries()
        self._setup_database()
        
    def _connect_with_retries(self):
        """Establish database connection with retry logic"""
        for attempt in range(self.max_retries):
            try:
                self.conn = self._init_db_connection()
                return
            except Exception as e:
                if attempt == self.max_retries - 1:
                    raise RuntimeError(f"Failed to initialize database after {self.max_retries} attempts: {str(e)}")
                time.sleep(1)

    def _init_db_connection(self) -> duckdb.DuckDBPyConnection:
        """Open the DuckDB connection and configure httpfs/S3 access.

        ``self.db_url`` is rewritten into an ``s3://datasets-cached/...`` path,
        a DuckDB connection is opened on it, the ``httpfs`` extension is
        installed and loaded, and the S3 settings are applied using the
        HuggingFace token as both access key and secret.

        Returns:
            The configured DuckDB connection.

        Raises:
            ValueError: If ``self.db_url`` has fewer than five ``/``-separated
                parts (tuple unpacking fails).
            RuntimeError: If no HuggingFace token can be obtained.
        """
        # Convert HF URL to S3 path.
        # NOTE(review): for "hf://datasets/<owner>/<dataset>/<file>" the split
        # yields ['hf:', '', 'datasets', '<owner>', '<dataset>/<file>'], so
        # `owner` receives the literal "datasets" segment and `dataset` the
        # real owner. The resulting path is left exactly as before; confirm the
        # intended bucket layout before changing it.
        _, _, owner, dataset, db_file = self.db_url.split('/', 4)
        db_path = f"s3://datasets-cached/{owner}/{dataset}/{db_file}"

        # `HF_TOKEN` is never assigned by this class, so reading it directly
        # raised AttributeError on every connect. Honour it when a subclass or
        # caller provides it, otherwise fall back to the imported helper.
        # Assumes get_hf_token_cached() takes no arguments and returns the
        # token string - TODO confirm against config.py.
        token = getattr(self, "HF_TOKEN", None) or get_hf_token_cached()
        if not token:
            raise RuntimeError("No HuggingFace token available for the database connection")

        # The SET statements are built by string interpolation, so escape
        # single quotes to keep the SQL literal well-formed.
        token_literal = str(token).replace("'", "''")

        # Connect to remote database
        conn = duckdb.connect(db_path)
        try:
            conn.execute("INSTALL httpfs;")
            conn.execute("LOAD httpfs;")
            conn.execute("SET s3_endpoint='s3.us-east-1.amazonaws.com';")
            conn.execute("SET s3_use_ssl=true;")
            conn.execute("SET s3_url_style='path';")
            conn.execute(f"SET s3_access_key_id='{token_literal}';")
            conn.execute(f"SET s3_secret_access_key='{token_literal}';")
        except Exception:
            # Do not leak a half-configured connection; the caller retries.
            conn.close()
            raise
        return conn

    def _setup_database(self):
        """Initialize database tables"""
        # Tensor operations table
        self.conn.execute("""

            CREATE TABLE IF NOT EXISTS tensor_operations (

                operation_id VARCHAR PRIMARY KEY,

                operation_type VARCHAR,

                inputs JSON,

                output_shape VARCHAR,

                chip_id INTEGER,

                stream_id INTEGER,

                warp_id VARCHAR,

                status VARCHAR DEFAULT 'pending',

                result_address BIGINT,

                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,

                started_at TIMESTAMP,

                completed_at TIMESTAMP,

                error_message VARCHAR,

                state_json JSON

            )

        """)

    def execute_tensor_op(self, operation: str, inputs: List[Dict[str, Any]], 

                        output_shape: Optional[tuple] = None,

                        chip_id: Optional[int] = None,

                        stream_id: Optional[int] = None,

                        warp_id: Optional[str] = None) -> Optional[int]:
        """

        Execute a tensor operation with enhanced tracking and coordination

        Args:

            operation: Operation type (matmul, conv2d, etc.)

            inputs: List of input tensors with metadata

            output_shape: Expected output shape (for pre-allocation)

            chip_id: Target GPU chip (if None, will be automatically selected)

            stream_id: Execution stream ID (if None, uses default stream)

            warp_id: ID of warp to execute on (if None, automatically scheduled)

        Returns:

            Address of output tensor or None if operation fails

        """
        operation_id = None
        try:
            # Generate operation ID
            operation_id = f"op_{time.time_ns()}"
            
            # Choose optimal GPU if not specified
            if chip_id is None:
                # Query least loaded GPU
                result = self.conn.execute("""

                    SELECT chip_id

                    FROM tensor_operations

                    WHERE status = 'running'

                    GROUP BY chip_id

                    ORDER BY COUNT(*) ASC

                    LIMIT 1

                """).fetchall()
                
                chip_id = result[0][0] if result else 0
            
            # Create operation record
            self.conn.execute("""

                INSERT INTO tensor_operations (

                    operation_id, operation_type, inputs, output_shape,

                    chip_id, stream_id, warp_id, status, state_json

                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)

            """, [
                operation_id,
                operation,
                inputs,
                str(output_shape) if output_shape else None,
                chip_id,
                stream_id,
                warp_id,
                'pending',
                {
                    "status": "initialized",
                    "timestamp": time.time_ns()
                }
            ])
            
            # Initialize tensor core
            tensor_core = TensorCore()
            
            # Execute operation
            # Update status to running
            self.conn.execute("""

                UPDATE tensor_operations

                SET status = 'running',

                    started_at = CURRENT_TIMESTAMP,

                    state_json = ?

                WHERE operation_id = ?

            """, [{"status": "running"}, operation_id])
            
            # Execute based on operation type
            result_address = None
            if operation == 'matmul':
                result_address = tensor_core.matmul(
                    inputs[0]['data'],
                    inputs[1]['data'],
                    warp_id=warp_id
                )
            elif operation == 'conv2d':
                result_address = tensor_core.conv2d(
                    inputs[0]['data'],
                    inputs[1]['data'],
                    warp_id=warp_id
                )
                
            # Update operation status to completed
            self.conn.execute("""

                UPDATE tensor_operations

                SET status = 'completed',

                    completed_at = CURRENT_TIMESTAMP,

                    result_address = ?,

                    state_json = ?

                WHERE operation_id = ?

            """, [
                result_address,
                {"status": "completed", "result": result_address},
                operation_id
            ])
            
            return result_address
            
        except Exception as e:
            if operation_id:
                # Update operation status to failed
                self.conn.execute("""

                    UPDATE tensor_operations

                    SET status = 'failed',

                        completed_at = CURRENT_TIMESTAMP,

                        error_message = ?,

                        state_json = ?

                    WHERE operation_id = ?

                """, [
                    str(e),
                    {"status": "failed", "error": str(e)},
                    operation_id
                ])
            logging.error(f"Tensor operation failed: {str(e)}")
            return None
                
    def get_operation_status(self, operation_id: str) -> Dict[str, Any]:
        """Get the current status of a tensor operation"""
        try:
            result = self.conn.execute("""

                SELECT status, result_address, error_message, state_json

                FROM tensor_operations

                WHERE operation_id = ?

            """, [operation_id]).fetchall()
            
            if not result:
                return {"status": "not_found"}
                
            row = result[0]
            return {
                "status": row[0],
                "result_address": row[1],
                "error_message": row[2],
                "state": row[3]
            }
            
        except Exception as e:
            logging.error(f"Failed to get operation status: {str(e)}")
            return {"status": "error", "error": str(e)}
            
    def wait_for_operation(self, operation_id: str, timeout: Optional[float] = None) -> Dict[str, Any]:
        """Wait for a tensor operation to complete"""
        start_time = time.time()
        while True:
            status = self.get_operation_status(operation_id)
            
            if status["status"] in ["completed", "failed"]:
                return status
                
            if timeout and (time.time() - start_time) > timeout:
                return {"status": "timeout"}
                
            time.sleep(0.001)
            
    def synchronize_operations(self, operation_ids: List[str]) -> Dict[str, Any]:
        """Synchronize multiple tensor operations"""
        try:
            results = {}
            for op_id in operation_ids:
                results[op_id] = self.wait_for_operation(op_id)
                
            return {
                "status": "completed",
                "operations": results
            }
            
        except Exception as e:
            logging.error(f"Failed to synchronize tensor operations: {str(e)}")
            return {
                "status": "error",
                "error": str(e)
            }
            
            # Get warp if not specified
            if warp_id is None:
                available_warps = [
                    w for w in self.warps[chip_id][target_sm_id]
                    if len(w.get_active_threads()) > 0
                ]
                if not available_warps:
                    raise RuntimeError("No available warps")
                warp = available_warps[0]
                warp_id = str(warp.warp_id)
                op_info["warp_id"] = warp_id
            
            # Schedule operation
            op_metadata = target_sm.matrix_op_scheduler.schedule_operation(
                op_type=operation,
                input_shapes=[inp.get("shape") for inp in inputs],
                warp_id=warp_id
            )
            
            if op_metadata is None:
                raise RuntimeError("Failed to schedule matrix operation")
            
            try:
                # Acquire matrix operation lock
                if not target_sm.matrix_op_lock.acquire_matrix_op(
                    op_metadata.op_id,
                    op_info
                ):
                    raise RuntimeError("Failed to acquire matrix operation lock")
                
                # Execute operation based on type
                result = None
                if operation == "matmul":
                    A = self.memory_manager.read_tensor(inputs[0]["address"])
                    B = self.memory_manager.read_tensor(inputs[1]["address"])
                    result = target_sm.tensor_core_matmul(A, B, warp_id=warp_id)
                    
                elif operation == "conv2d":
                    input_tensor = self.memory_manager.read_tensor(inputs[0]["address"])
                    kernel = self.memory_manager.read_tensor(inputs[1]["address"])
                    result = target_sm.tensor_core_conv2d(input_tensor, kernel, warp_id=warp_id)
                    
                if result is None:
                    raise RuntimeError(f"Failed to execute {operation}")
                    
                # Allocate output and store result
                output_addr = self.allocate_memory(
                    result.nbytes,
                    chip_id=chip_id,
                    tensor_shape=result.shape,
                    dtype=result.dtype
                )
                
                self.memory_manager.write_tensor(output_addr, result)
                
                # Complete operation successfully
                target_sm.matrix_op_scheduler.complete_operation(
                    op_metadata,
                    output_shape=result.shape,
                    success=True
                )
                
                # Update operation history
                target_sm.tensor_op_history.append({
                    **op_info,
                    "op_id": op_metadata.op_id,
                    "output_shape": result.shape,
                    "output_address": output_addr,
                    "end_time": time.time_ns(),
                    "status": "completed"
                })
                
                return output_addr
                
            except Exception as e:
                # Handle operation failure
                if op_metadata:
                    target_sm.matrix_op_scheduler.complete_operation(
                        op_metadata,
                        output_shape=None,
                        success=False,
                        error=str(e)
                    )
                raise
                
            finally:
                # Always release the matrix operation lock
                if op_metadata:
                    target_sm.matrix_op_lock.release_matrix_op(op_metadata.op_id)
                
        except Exception as e:
            logging.error(f"Tensor operation failed: {str(e)}")
            return None
            
def get_tensor_op_status(self, chip_id: int, sm_id: int, op_id: str) -> Dict[str, Any]:
        """Report the state of one tensor operation on a given SM.

        Active operations of ``self.streaming_multiprocessors[chip_id][sm_id]``
        are searched first; the operation history is consulted only when no
        active operation matches.

        Args:
            chip_id: First index into ``self.streaming_multiprocessors``.
            sm_id: Second index into ``self.streaming_multiprocessors``.
            op_id: Compared against each operation's ``op_id`` attribute.

        Returns:
            ``{"status": "running", "metadata": <op.__dict__>}`` for an active
            match, ``{"status": <op.status>, "metadata": <op.__dict__>}`` for a
            history match, ``{"status": "not_found", "metadata": None}`` when
            neither matches, and
            ``{"status": "error", "metadata": {"error": <message>}}`` if the
            lookup raises.
        """
        try:
            target = self.streaming_multiprocessors[chip_id][sm_id]

            # Anything still listed as active is reported as running.
            running = target.matrix_op_scheduler.coordinator.get_active_operations()
            for candidate in running:
                if candidate.op_id != op_id:
                    continue
                return {"status": "running", "metadata": candidate.__dict__}

            # Otherwise fall back to the recorded history.
            finished = target.matrix_op_scheduler.coordinator.get_operation_history()
            for candidate in finished:
                if candidate.op_id != op_id:
                    continue
                return {"status": candidate.status, "metadata": candidate.__dict__}

            return {"status": "not_found", "metadata": None}

        except Exception as e:
            reason = str(e)
            logging.error(f"Failed to get operation status: {reason}")
            return {"status": "error", "metadata": {"error": reason}}
            
def sync_tensor_ops(self, chip_id: int, sm_id: int, warp_id: Optional[str] = None):
        """Wait until the active tensor operations on one SM have left the
        'running'/'scheduled' states.

        The active operations of ``self.streaming_multiprocessors[chip_id][sm_id]``
        are snapshotted once, optionally filtered by ``warp_id``, and each one
        is then polled through ``self.get_tensor_op_status``.

        Args:
            chip_id: First index into ``self.streaming_multiprocessors``.
            sm_id: Second index into ``self.streaming_multiprocessors``.
            warp_id: When given, only operations whose ``warp_id`` attribute
                equals this value are waited for.

        Returns:
            True once every waited-on operation reports a status other than
            'running' or 'scheduled'.

        NOTE(review): this function is defined at module level but takes
        ``self`` and calls ``self.get_tensor_op_status`` - presumably it is
        meant to be a method of a class exposing ``streaming_multiprocessors``;
        confirm. No return statement is visible after the error log in the
        failure path, so as shown that path returns None - confirm against
        the full file.
        """
        try:
            sm = self.streaming_multiprocessors[chip_id][sm_id]

            # Get relevant operations (snapshot taken once, before waiting)
            if warp_id is not None:
                active_ops = [
                    op for op in sm.matrix_op_scheduler.coordinator.get_active_operations()
                    if op.warp_id == warp_id
                ]
            else:
                active_ops = sm.matrix_op_scheduler.coordinator.get_active_operations()

            # Wait for operations to complete; there is no timeout, so an
            # operation that stays 'running'/'scheduled' keeps this loop polling
            for op in active_ops:
                while True:
                    status = self.get_tensor_op_status(chip_id, sm_id, op.op_id)
                    if status["status"] not in ["running", "scheduled"]:
                        break
                    time.sleep(0.001)  # Small delay to prevent busy waiting

            return True

        except Exception as e:
            # Failures are logged rather than propagated to the caller.
            logging.error(f"Failed to synchronize tensor operations: {str(e)}")