"""Tensor operation execution with status tracking in a DuckDB table."""

from typing import List, Dict, Any, Optional
import time
import json
import logging

import duckdb
from huggingface_hub import HfApi, HfFileSystem

from tensor_core import TensorCore
from config import get_hf_token_cached


class TensorOps:
    """Manages tensor operations with remote state tracking.

    Every operation is recorded as a row of the ``tensor_operations`` table
    and moves through the statuses ``pending`` -> ``running`` ->
    ``completed`` / ``failed``.
    """

    DB_URL = "hf://datasets/Fred808/helium/storage.json"

    # Token used for the S3 credentials. Left as None here and resolved in
    # __init__, so a subclass can pre-set it and skip the config lookup.
    HF_TOKEN: Optional[str] = None

    # Statuses after which polling an operation can never observe progress.
    _TERMINAL_STATUSES = ("completed", "failed", "not_found", "error")

    def __init__(self, db_url: Optional[str] = None):
        """Connect to the tracking database and create its tables.

        Args:
            db_url: Location of the database; defaults to ``DB_URL``.

        Raises:
            RuntimeError: If no connection could be established within
                ``max_retries`` attempts.
        """
        self.db_url = db_url or self.DB_URL
        self.max_retries = 3
        if self.HF_TOKEN is None:
            # Initialize token from .env
            # NOTE(review): assumes get_hf_token_cached() takes no arguments
            # and returns the token string - confirm against the config module.
            self.HF_TOKEN = get_hf_token_cached()
        self._connect_with_retries()
        self._setup_database()

    def _connect_with_retries(self):
        """Establish database connection with retry logic.

        Raises:
            RuntimeError: After ``max_retries`` failed attempts; the last
                underlying error is attached as ``__cause__``.
        """
        for attempt in range(self.max_retries):
            try:
                self.conn = self._init_db_connection()
                return
            except Exception as e:
                if attempt == self.max_retries - 1:
                    raise RuntimeError(
                        f"Failed to initialize database after {self.max_retries} attempts: {str(e)}"
                    ) from e
                time.sleep(1)

    def _init_db_connection(self) -> duckdb.DuckDBPyConnection:
        """Initialize database connection with HuggingFace configuration.

        Returns:
            An open DuckDB connection with the httpfs extension loaded and
            the S3 settings applied.

        Raises:
            ValueError: If ``db_url`` does not have the form
                ``<scheme>://<repo_type>/<owner>/<dataset>/<file>``.
        """
        # Convert HF URL to S3 path. The scheme and the repo-type segment
        # ("datasets") are dropped first; splitting the raw URL on "/" would
        # bind "datasets" to `owner` and shift every later component.
        _scheme, _, remainder = self.db_url.partition("://")
        _repo_type, owner, dataset, db_file = remainder.split("/", 3)
        db_path = f"s3://datasets-cached/{owner}/{dataset}/{db_file}"

        # Connect to remote database
        # NOTE(review): the connection is opened on the s3:// path before
        # httpfs is loaded - confirm this order works on the deployed DuckDB.
        conn = duckdb.connect(db_path)
        conn.execute("INSTALL httpfs;")
        conn.execute("LOAD httpfs;")
        conn.execute("SET s3_endpoint='s3.us-east-1.amazonaws.com';")
        conn.execute("SET s3_use_ssl=true;")
        conn.execute("SET s3_url_style='path';")

        # Double any single quote so the token cannot terminate the literal.
        token = str(self.HF_TOKEN).replace("'", "''")
        conn.execute(f"SET s3_access_key_id='{token}';")
        conn.execute(f"SET s3_secret_access_key='{token}';")

        return conn

    def _setup_database(self):
        """Initialize database tables."""
        # Tensor operations table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS tensor_operations (
                operation_id VARCHAR PRIMARY KEY,
                operation_type VARCHAR,
                inputs JSON,
                output_shape VARCHAR,
                chip_id INTEGER,
                stream_id INTEGER,
                warp_id VARCHAR,
                status VARCHAR DEFAULT 'pending',
                result_address BIGINT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                started_at TIMESTAMP,
                completed_at TIMESTAMP,
                error_message VARCHAR,
                state_json JSON
            )
        """)

    def execute_tensor_op(self,
                          operation: str,
                          inputs: List[Dict[str, Any]],
                          output_shape: Optional[tuple] = None,
                          chip_id: Optional[int] = None,
                          stream_id: Optional[int] = None,
                          warp_id: Optional[str] = None) -> Optional[int]:
        """
        Execute a tensor operation with enhanced tracking and coordination

        Args:
            operation: Operation type; 'matmul' and 'conv2d' are supported.
            inputs: List of input tensors with metadata; the first two
                entries must carry a 'data' key.
            output_shape: Expected output shape; stored as text in the record.
            chip_id: Target GPU chip (if None, will be automatically selected)
            stream_id: Execution stream ID; stored in the record.
            warp_id: ID of warp to execute on; forwarded to the tensor core.

        Returns:
            Address of output tensor or None if operation fails. Failures
            (including an unsupported ``operation``) are logged and recorded
            on the operation row rather than raised.
        """
        operation_id = None
        try:
            # Generate operation ID
            operation_id = f"op_{time.time_ns()}"

            # Choose optimal GPU if not specified
            if chip_id is None:
                # Query least loaded GPU.
                # NOTE(review): only chips that currently have running rows
                # appear in this result, so an idle chip is never selected
                # here; chip 0 is used when nothing is running.
                result = self.conn.execute("""
                    SELECT chip_id
                    FROM tensor_operations
                    WHERE status = 'running'
                    GROUP BY chip_id
                    ORDER BY COUNT(*) ASC
                    LIMIT 1
                """).fetchall()
                chip_id = result[0][0] if result else 0

            # Create operation record
            self.conn.execute("""
                INSERT INTO tensor_operations (
                    operation_id, operation_type, inputs, output_shape,
                    chip_id, stream_id, warp_id, status, state_json
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, [
                operation_id,
                operation,
                inputs,
                str(output_shape) if output_shape else None,
                chip_id,
                stream_id,
                warp_id,
                'pending',
                {
                    "status": "initialized",
                    "timestamp": time.time_ns()
                }
            ])

            # Initialize tensor core
            tensor_core = TensorCore()

            # Update status to running
            self.conn.execute("""
                UPDATE tensor_operations
                SET status = 'running',
                    started_at = CURRENT_TIMESTAMP,
                    state_json = ?
                WHERE operation_id = ?
            """, [{"status": "running"}, operation_id])

            # Execute based on operation type
            if operation == 'matmul':
                result_address = tensor_core.matmul(
                    inputs[0]['data'],
                    inputs[1]['data'],
                    warp_id=warp_id
                )
            elif operation == 'conv2d':
                result_address = tensor_core.conv2d(
                    inputs[0]['data'],
                    inputs[1]['data'],
                    warp_id=warp_id
                )
            else:
                # An unknown type must not be recorded as 'completed' with an
                # empty result; raising routes it through the failure path.
                raise ValueError(f"Unsupported tensor operation: {operation}")

            # Update operation status to completed
            self.conn.execute("""
                UPDATE tensor_operations
                SET status = 'completed',
                    completed_at = CURRENT_TIMESTAMP,
                    result_address = ?,
                    state_json = ?
                WHERE operation_id = ?
            """, [
                result_address,
                {"status": "completed", "result": result_address},
                operation_id
            ])

            return result_address

        except Exception as e:
            if operation_id:
                self._mark_operation_failed(operation_id, e)
            logging.error(f"Tensor operation failed: {str(e)}")
            return None

    def _mark_operation_failed(self, operation_id: str, error: Exception) -> None:
        """Record a failure on the operation row without raising.

        The update is best-effort: if the database itself is the reason the
        operation failed, this write may fail too, and that must not escape
        the caller's error path.
        """
        try:
            # Update operation status to failed
            self.conn.execute("""
                UPDATE tensor_operations
                SET status = 'failed',
                    completed_at = CURRENT_TIMESTAMP,
                    error_message = ?,
                    state_json = ?
                WHERE operation_id = ?
            """, [
                str(error),
                {"status": "failed", "error": str(error)},
                operation_id
            ])
        except Exception as update_error:
            logging.error(
                "Failed to record tensor operation failure: %s", update_error
            )

    def get_operation_status(self, operation_id: str) -> Dict[str, Any]:
        """Get the current status of a tensor operation.

        Returns:
            A dict with keys ``status``, ``result_address``,
            ``error_message`` and ``state`` for a known operation;
            ``{"status": "not_found"}`` when no row matches; or
            ``{"status": "error", "error": ...}`` when the query fails.
        """
        try:
            rows = self.conn.execute("""
                SELECT status, result_address, error_message, state_json
                FROM tensor_operations
                WHERE operation_id = ?
            """, [operation_id]).fetchall()

            if not rows:
                return {"status": "not_found"}

            status, result_address, error_message, state = rows[0]
            return {
                "status": status,
                "result_address": result_address,
                "error_message": error_message,
                "state": state
            }

        except Exception as e:
            logging.error(f"Failed to get operation status: {str(e)}")
            return {"status": "error", "error": str(e)}

    def wait_for_operation(self, operation_id: str, timeout: Optional[float] = None) -> Dict[str, Any]:
        """Wait for a tensor operation to complete.

        Polls ``get_operation_status`` every millisecond.

        Args:
            operation_id: Operation to wait for.
            timeout: Maximum seconds to wait; None waits indefinitely. A
                value of 0 is honoured as "check once, then time out".

        Returns:
            The status dict once the operation is ``completed`` or
            ``failed``. A ``not_found`` or ``error`` status is returned
            immediately as well, since polling cannot resolve either.
            ``{"status": "timeout"}`` is returned when the timeout elapses.
        """
        # A monotonic clock keeps the deadline immune to wall-clock changes.
        deadline = None if timeout is None else time.monotonic() + timeout
        while True:
            status = self.get_operation_status(operation_id)
            if status["status"] in self._TERMINAL_STATUSES:
                return status

            if deadline is not None and time.monotonic() > deadline:
                return {"status": "timeout"}

            time.sleep(0.001)

    def synchronize_operations(self, operation_ids: List[str]) -> Dict[str, Any]:
        """Synchronize multiple tensor operations.

        Waits for each operation in turn, without a timeout.

        Returns:
            ``{"status": "completed", "operations": {op_id: status_dict}}``,
            or ``{"status": "error", "error": ...}`` if waiting raised. The
            outer status does not reflect whether individual operations
            failed; inspect the per-operation dicts for that.
        """
        try:
            results = {
                op_id: self.wait_for_operation(op_id)
                for op_id in operation_ids
            }
            return {
                "status": "completed",
                "operations": results
            }
        except Exception as e:
            logging.error(f"Failed to synchronize tensor operations: {str(e)}")
            return {
                "status": "error",
                "error": str(e)
            }

    def _execute_on_sm(self,
                       operation: str,
                       inputs: List[Dict[str, Any]],
                       chip_id: int,
                       target_sm_id: int,
                       target_sm: Any,
                       op_info: Dict[str, Any],
                       warp_id: Optional[str] = None) -> Optional[Any]:
        """Schedule and run a matrix operation on a streaming multiprocessor.

        NOTE(review): this logic sat as an orphaned fragment after
        ``synchronize_operations``. It was unreachable and referenced
        ``target_sm_id``, ``target_sm`` and ``op_info`` without binding them.
        It is kept as a helper with those names as parameters; the parameter
        list is a reconstruction - confirm against the original author's
        intent. It also relies on ``self.warps``, ``self.memory_manager`` and
        ``self.allocate_memory``, none of which this class sets up.

        Args:
            operation: 'matmul' or 'conv2d'.
            inputs: Input descriptors; the first two must carry 'address',
                and 'shape' is read when present.
            chip_id: Chip used to look up warps and allocate the output.
            target_sm_id: Index into ``self.warps[chip_id]``.
            target_sm: Object providing the scheduler, the lock and the
                tensor-core entry points.
            op_info: Operation info dict; updated with the chosen warp and
                copied into the history entry.
            warp_id: Warp to run on; when None the first warp with active
                threads is chosen.

        Returns:
            Address of the stored result, or None if any step fails.
        """
        try:
            # Get warp if not specified
            if warp_id is None:
                available_warps = [
                    w for w in self.warps[chip_id][target_sm_id]
                    if len(w.get_active_threads()) > 0
                ]
                if not available_warps:
                    raise RuntimeError("No available warps")
                warp = available_warps[0]
                warp_id = str(warp.warp_id)
                op_info["warp_id"] = warp_id

            # Schedule operation
            op_metadata = target_sm.matrix_op_scheduler.schedule_operation(
                op_type=operation,
                input_shapes=[inp.get("shape") for inp in inputs],
                warp_id=warp_id
            )

            if op_metadata is None:
                raise RuntimeError("Failed to schedule matrix operation")

            try:
                # Acquire matrix operation lock
                if not target_sm.matrix_op_lock.acquire_matrix_op(
                    op_metadata.op_id, op_info
                ):
                    raise RuntimeError("Failed to acquire matrix operation lock")

                # Execute operation based on type
                result = None
                if operation == "matmul":
                    A = self.memory_manager.read_tensor(inputs[0]["address"])
                    B = self.memory_manager.read_tensor(inputs[1]["address"])
                    result = target_sm.tensor_core_matmul(A, B, warp_id=warp_id)

                elif operation == "conv2d":
                    input_tensor = self.memory_manager.read_tensor(inputs[0]["address"])
                    kernel = self.memory_manager.read_tensor(inputs[1]["address"])
                    result = target_sm.tensor_core_conv2d(input_tensor, kernel, warp_id=warp_id)

                if result is None:
                    raise RuntimeError(f"Failed to execute {operation}")

                # Allocate output and store result
                output_addr = self.allocate_memory(
                    result.nbytes,
                    chip_id=chip_id,
                    tensor_shape=result.shape,
                    dtype=result.dtype
                )
                self.memory_manager.write_tensor(output_addr, result)

                # Complete operation successfully
                target_sm.matrix_op_scheduler.complete_operation(
                    op_metadata,
                    output_shape=result.shape,
                    success=True
                )

                # Update operation history
                target_sm.tensor_op_history.append({
                    **op_info,
                    "op_id": op_metadata.op_id,
                    "output_shape": result.shape,
                    "output_address": output_addr,
                    "end_time": time.time_ns(),
                    "status": "completed"
                })

                return output_addr

            except Exception as e:
                # Handle operation failure
                if op_metadata:
                    target_sm.matrix_op_scheduler.complete_operation(
                        op_metadata,
                        output_shape=None,
                        success=False,
                        error=str(e)
                    )
                raise

            finally:
                # Always release the matrix operation lock.
                # NOTE(review): this also runs when the acquire above failed;
                # confirm release_matrix_op tolerates an unheld lock.
                if op_metadata:
                    target_sm.matrix_op_lock.release_matrix_op(op_metadata.op_id)

        except Exception as e:
            logging.error(f"Tensor operation failed: {str(e)}")
            return None

    def get_tensor_op_status(self, chip_id: int, sm_id: int, op_id: str) -> Dict[str, Any]:
        """Get status and metadata for a tensor operation.

        Looks the operation up among the scheduler coordinator's active
        operations first, then in its operation history.

        NOTE(review): reads ``self.streaming_multiprocessors``, which this
        class never assigns - confirm where it is expected to come from.

        Returns:
            ``{"status": ..., "metadata": ...}`` where status is "running"
            for an active operation, the recorded status for a historical
            one, "not_found" when neither matches, or "error" if the lookup
            raised.
        """
        try:
            sm = self.streaming_multiprocessors[chip_id][sm_id]
            coordinator = sm.matrix_op_scheduler.coordinator

            # Check active operations
            for op in coordinator.get_active_operations():
                if op.op_id == op_id:
                    return {
                        "status": "running",
                        "metadata": op.__dict__
                    }

            # Check operation history
            for op in coordinator.get_operation_history():
                if op.op_id == op_id:
                    return {
                        "status": op.status,
                        "metadata": op.__dict__
                    }

            return {
                "status": "not_found",
                "metadata": None
            }

        except Exception as e:
            logging.error(f"Failed to get operation status: {str(e)}")
            return {
                "status": "error",
                "metadata": {"error": str(e)}
            }

    def sync_tensor_ops(self, chip_id: int, sm_id: int, warp_id: Optional[str] = None):
        """Synchronize pending tensor operations.

        Blocks until every currently active operation on the given SM
        (optionally restricted to one warp) is no longer reported as
        "running" or "scheduled".

        NOTE(review): reads ``self.streaming_multiprocessors``, which this
        class never assigns - confirm where it is expected to come from.

        Returns:
            True once all selected operations have left the running and
            scheduled states. On an exception the error is logged.
        """
        try:
            sm = self.streaming_multiprocessors[chip_id][sm_id]

            # Get relevant operations
            active_ops = sm.matrix_op_scheduler.coordinator.get_active_operations()
            if warp_id is not None:
                active_ops = [op for op in active_ops if op.warp_id == warp_id]

            # Wait for operations to complete
            for op in active_ops:
                while True:
                    status = self.get_tensor_op_status(chip_id, sm_id, op.op_id)
                    if status["status"] not in ["running", "scheduled"]:
                        break
                    time.sleep(0.001)  # Small delay to prevent busy waiting

            return True

        except Exception as e:
            # NOTE(review): the visible source ends at this log call.
            logging.error(f"Failed to synchronize tensor operations: {str(e)}")