import numpy as np
import time
from typing import Dict, Any, Optional, Tuple, Union, List
from enum import Enum

from tensor_core import TensorCoreArray


class VectorOperation(Enum):
    """Enumeration of supported vector operations."""
    ADD = "add"
    SUBTRACT = "subtract"
    MULTIPLY = "multiply"
    DIVIDE = "divide"
    DOT_PRODUCT = "dot_product"
    CROSS_PRODUCT = "cross_product"
    NORMALIZE = "normalize"
    MAGNITUDE = "magnitude"

class AIAccelerator:
    """
    AI Accelerator that simulates GPU-based AI computations.

    This class leverages NumPy's optimized operations to simulate the parallel
    processing capabilities of the vGPU for AI workloads.
    """

    def __init__(self, vram=None, num_sms: int = 800, cores_per_sm: int = 222, storage=None):
        """Initialize AI Accelerator with electron-speed awareness and shared WebSocket storage."""
        from electron_speed import TARGET_SWITCHES_PER_SEC, TRANSISTORS_ON_CHIP, drift_velocity

        self.storage = storage  # Use the shared storage instance
        if self.storage is None:
            from websocket_storage import WebSocketGPUStorage
            self.storage = WebSocketGPUStorage()  # Only create new if not provided
            if not self.storage.wait_for_connection():
                raise RuntimeError("Could not connect to GPU storage server")

        self.vram = vram
        self.num_sms = num_sms
        self.cores_per_sm = cores_per_sm
        self.total_cores = num_sms * cores_per_sm

        # Configure for maximum parallel processing at electron speed
        total_tensor_cores = num_sms * cores_per_sm  # Use ALL cores for tensor operations
        self.tensor_core_array = TensorCoreArray(
            num_tensor_cores=total_tensor_cores,
            bits=32,
            bandwidth_tbps=drift_velocity / 1e-12  # Bandwidth scaled to electron drift speed
        )
        self.tensor_cores_initialized = False

        # Initialize model, tensor, and tokenizer tracking
        self.model_registry: Dict[str, Dict[str, Any]] = {}  # Track loaded models
        self.tensor_registry: Dict[str, Dict[str, Any]] = {}  # Track tensor metadata
        self.tokenizer_registry: Dict[str, Any] = {}  # Track tokenizers
        self.resource_monitor = {
            'vram_used': 0,
            'active_tensors': 0,
            'loaded_models': set()
        }

        # AI operation statistics
        self.operations_performed = 0
        self.total_compute_time = 0.0
        self.flops_performed = 0

        # WebSocket-based memory management
        self.matrix_registry = {}  # Track loaded matrices
        self.matrix_counter = 0
        self.activation_cache: Dict[str, str] = {}  # Cache activation outputs
        self.weight_cache: Dict[str, Any] = {}  # Cache preprocessed weights

        # Model configuration registry
        self.model_configs: Dict[str, Any] = {}  # Store model architectures
        self.model_loaded = False

        # Batch processing configuration
        self.max_batch_size = 64
        self.min_batch_size = 4
        self.dynamic_batching = True  # Enable automatic batch size adjustment
    def _serialize_model_config(self, config: Any) -> Any:
        """Convert model config to a serializable format."""
        # Handle None case first
        if config is None:
            return None

        # Handle Florence2LanguageConfig specifically
        if config.__class__.__name__ == "Florence2LanguageConfig":
            try:
                return {
                    "type": "Florence2LanguageConfig",
                    "model_type": getattr(config, "model_type", ""),
                    "architectures": getattr(config, "architectures", []),
                    "hidden_size": getattr(config, "hidden_size", 0),
                    "num_attention_heads": getattr(config, "num_attention_heads", 0),
                    "num_hidden_layers": getattr(config, "num_hidden_layers", 0),
                    "intermediate_size": getattr(config, "intermediate_size", 0),
                    "max_position_embeddings": getattr(config, "max_position_embeddings", 0),
                    "layer_norm_eps": getattr(config, "layer_norm_eps", 1e-12),
                    "vocab_size": getattr(config, "vocab_size", 0)
                }
            except Exception as e:
                print(f"Warning: Error serializing Florence2LanguageConfig: {e}")
                return {"type": "Florence2LanguageConfig", "error": str(e)}

        # Handle standard types
        if isinstance(config, (int, float, str, bool)):
            return config

        # Handle lists and tuples
        if isinstance(config, (list, tuple)):
            return [self._serialize_model_config(item) for item in config]

        # Handle dictionaries
        if isinstance(config, dict):
            return {k: self._serialize_model_config(v) for k, v in config.items()}

        # Handle objects with __dict__
        if hasattr(config, '__dict__'):
            config_dict = {}
            for key, value in config.__dict__.items():
                try:
                    # Skip private attributes
                    if key.startswith('_'):
                        continue
                    config_dict[key] = self._serialize_model_config(value)
                except Exception as e:
                    print(f"Warning: Error serializing attribute {key}: {e}")
                    config_dict[key] = str(value)
            return config_dict

        # Fallback: convert to string representation
        try:
            return str(config)
        except Exception as e:
            return f"<Unserializable object of type {type(config).__name__}: {str(e)}>"
    def store_model_state(self, model_name: str, model_info: Dict[str, Any]) -> bool:
        """Store model state in WebSocket storage with proper serialization."""
        try:
            # Convert any non-serializable parts of model_info
            serializable_info = self._serialize_model_config(model_info)

            # Store in model registry
            self.model_registry[model_name] = serializable_info

            # Save to storage
            if self.storage:
                # Store model info
                info_success = self.storage.store_state(
                    "models",
                    f"{model_name}/info",
                    serializable_info
                )
                # Store model state
                state_success = self.storage.store_state(
                    "models",
                    f"{model_name}/state",
                    {"loaded": True, "timestamp": time.time()}
                )
                if info_success and state_success:
                    self.resource_monitor['loaded_models'].add(model_name)
                    return True
            return False
        except Exception as e:
            print(f"Error storing model state: {str(e)}")
            return False
    def initialize_tensor_cores(self):
        """Initialize tensor cores and verify they're ready for computation."""
        if self.tensor_cores_initialized:
            return True
        try:
            # Verify tensor core array is properly initialized
            if not hasattr(self, 'tensor_core_array') or self.tensor_core_array is None:
                raise RuntimeError("Tensor core array not properly initialized")

            # Initialize tensor cores if needed
            if hasattr(self.tensor_core_array, 'initialize'):
                self.tensor_core_array.initialize()

            # Verify VRAM access
            if self.vram is None:
                raise RuntimeError("VRAM not properly configured")

            # Test tensor core functionality with a small computation
            test_input = [[1.0, 2.0], [3.0, 4.0]]
            # Convert input to numpy array if needed
            if isinstance(test_input, list):
                test_input = np.array(test_input, dtype=np.float32)
            test_result = self.tensor_core_array.matmul(test_input, test_input)
            if test_result is None or not isinstance(test_result, (np.ndarray, list)) or len(test_result) == 0:
                raise RuntimeError("Tensor core test computation failed")

            self.tensor_cores_initialized = True
            return True
        except Exception as e:
            print(f"Failed to initialize tensor cores: {str(e)}")
            self.tensor_cores_initialized = False
            return False
    def set_vram(self, vram):
        """Set the VRAM reference."""
        self.vram = vram

    def allocate_matrix(self, shape: Tuple[int, ...], dtype=np.float32,
                        name: Optional[str] = None) -> str:
        """Allocate a matrix in VRAM and return its ID."""
        if not self.vram:
            raise RuntimeError("VRAM not available")

        if name is None:
            name = f"matrix_{self.matrix_counter}"
            self.matrix_counter += 1

        # Create matrix data
        matrix_data = np.zeros(shape, dtype=dtype)

        # Store in VRAM as a texture (reusing the texture storage mechanism)
        matrix_id = self.vram.load_texture(matrix_data, name)
        self.matrix_registry[name] = matrix_id
        return name

    def load_matrix(self, matrix_data: np.ndarray, name: Optional[str] = None) -> str:
        """Load matrix data into VRAM and return its ID."""
        if not self.vram:
            raise RuntimeError("VRAM not available")

        if name is None:
            name = f"matrix_{self.matrix_counter}"
            self.matrix_counter += 1

        # Store in VRAM
        matrix_id = self.vram.load_texture(matrix_data, name)
        self.matrix_registry[name] = matrix_id
        return name

    def get_matrix(self, matrix_id: str) -> Optional[np.ndarray]:
        """Retrieve matrix data from VRAM."""
        if not self.vram or matrix_id not in self.matrix_registry:
            return None
        vram_id = self.matrix_registry[matrix_id]
        return self.vram.get_texture(vram_id)
    def matrix_multiply(self, matrix_a_id: str, matrix_b_id: str,
                        result_id: Optional[str] = None) -> Optional[str]:
        """Perform matrix multiplication using simulated GPU parallelism."""
        start_time = time.time()

        # Retrieve matrices from VRAM
        matrix_a = self.get_matrix(matrix_a_id)
        matrix_b = self.get_matrix(matrix_b_id)
        if matrix_a is None or matrix_b is None:
            print(f"Error: Could not retrieve matrices {matrix_a_id} or {matrix_b_id}")
            return None

        try:
            # Check if matrices can be multiplied
            if matrix_a.shape[-1] != matrix_b.shape[0]:
                print(f"Error: Matrix dimensions incompatible for multiplication: "
                      f"{matrix_a.shape} x {matrix_b.shape}")
                return None

            # Simulate parallel processing by breaking down the operation.
            # In a real GPU, this would be distributed across SMs and cores.
            result = self._simulate_parallel_matmul(matrix_a, matrix_b)

            # Store result in VRAM
            if result_id is None:
                result_id = f"result_{self.matrix_counter}"
                self.matrix_counter += 1
            result_matrix_id = self.load_matrix(result, result_id)

            # Update statistics
            compute_time = time.time() - start_time
            self.total_compute_time += compute_time
            self.operations_performed += 1

            # Calculate FLOPs (2 * M * N * K for matrix multiplication)
            m, k = matrix_a.shape
            k2, n = matrix_b.shape
            flops = 2 * m * n * k
            self.flops_performed += flops

            print(f"Matrix multiplication completed: {matrix_a.shape} x {matrix_b.shape} "
                  f"= {result.shape} in {compute_time:.4f}s")
            print(f"Simulated {flops:,} FLOPs across {self.total_cores} cores")
            return result_matrix_id
        except Exception as e:
            print(f"Error in matrix multiplication: {e}")
            return None
    def _simulate_parallel_matmul(self, matrix_a: np.ndarray, matrix_b: np.ndarray) -> np.ndarray:
        """Simulate parallel matrix multiplication across SMs."""
        # In a real implementation, the product would be broken into blocks and
        # distributed across the simulated SMs. For demonstration, we show how
        # the work would be distributed, then compute the product.
        m, k = matrix_a.shape
        k2, n = matrix_b.shape

        # Calculate work distribution
        total_output_elements = m * n
        elements_per_sm = max(1, total_output_elements // self.num_sms)
        print(f"Distributing {total_output_elements:,} output elements across "
              f"{self.num_sms} SMs ({elements_per_sm} elements per SM)")

        # Route through the virtual TensorCoreArray when it is initialized,
        # otherwise fall back to NumPy's optimized matrix multiplication.
        if self.tensor_cores_initialized:
            result = self.tensor_core_array.matmul(matrix_a.tolist(), matrix_b.tolist())
            return np.array(result)
        return np.dot(matrix_a, matrix_b)
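
    # Illustrative usage (a minimal sketch, not executed here): it assumes an
    # accelerator `acc` whose VRAM backend exposes load_texture/get_texture,
    # which is provided outside this module. The FLOP count follows the
    # 2 * M * N * K formula used above.
    #
    #   a_id = acc.load_matrix(np.random.rand(64, 32).astype(np.float32), "A")
    #   b_id = acc.load_matrix(np.random.rand(32, 16).astype(np.float32), "B")
    #   c_id = acc.matrix_multiply(a_id, b_id)   # counts 2*64*16*32 = 65,536 FLOPs
    #   c = acc.get_matrix(c_id)                 # (64, 16) ndarray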
    def vector_operation(self, operation: VectorOperation, vector_a_id: str,
                         vector_b_id: Optional[str] = None,
                         result_id: Optional[str] = None) -> Optional[str]:
        """Perform vector operations using simulated GPU parallelism."""
        start_time = time.time()

        # Retrieve vectors from VRAM
        vector_a = self.get_matrix(vector_a_id)
        if vector_a is None:
            print(f"Error: Could not retrieve vector {vector_a_id}")
            return None

        vector_b = None
        if vector_b_id:
            vector_b = self.get_matrix(vector_b_id)
            if vector_b is None:
                print(f"Error: Could not retrieve vector {vector_b_id}")
                return None

        try:
            result = None
            flops = 0

            if operation == VectorOperation.ADD:
                if vector_b is None:
                    raise ValueError("Vector B required for addition")
                result = vector_a + vector_b
                flops = vector_a.size
            elif operation == VectorOperation.SUBTRACT:
                if vector_b is None:
                    raise ValueError("Vector B required for subtraction")
                result = vector_a - vector_b
                flops = vector_a.size
            elif operation == VectorOperation.MULTIPLY:
                if vector_b is None:
                    raise ValueError("Vector B required for multiplication")
                result = vector_a * vector_b
                flops = vector_a.size
            elif operation == VectorOperation.DIVIDE:
                if vector_b is None:
                    raise ValueError("Vector B required for division")
                result = vector_a / vector_b
                flops = vector_a.size
            elif operation == VectorOperation.DOT_PRODUCT:
                if vector_b is None:
                    raise ValueError("Vector B required for dot product")
                result = np.dot(vector_a.flatten(), vector_b.flatten())
                flops = 2 * vector_a.size
            elif operation == VectorOperation.CROSS_PRODUCT:
                if vector_b is None:
                    raise ValueError("Vector B required for cross product")
                result = np.cross(vector_a, vector_b)
                flops = 6  # Approximate for a 3D cross product
            elif operation == VectorOperation.NORMALIZE:
                magnitude = np.linalg.norm(vector_a)
                result = vector_a / magnitude if magnitude > 0 else vector_a
                flops = vector_a.size * 2  # Division + magnitude calculation
            elif operation == VectorOperation.MAGNITUDE:
                result = np.array([np.linalg.norm(vector_a)])
                flops = vector_a.size * 2  # Squares and sum
            else:
                raise ValueError(f"Unsupported vector operation: {operation}")

            # Store result in VRAM (np.asarray wraps scalar results, e.g. dot products)
            if result_id is None:
                result_id = f"vector_result_{self.matrix_counter}"
                self.matrix_counter += 1
            result_vector_id = self.load_matrix(np.asarray(result), result_id)

            # Update statistics
            compute_time = time.time() - start_time
            self.total_compute_time += compute_time
            self.operations_performed += 1
            self.flops_performed += flops

            print(f"Vector operation {operation.value} completed in {compute_time:.4f}s")
            return result_vector_id
        except Exception as e:
            print(f"Error in vector operation {operation.value}: {e}")
            return None
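
    # Illustrative usage (sketch, same assumed accelerator `acc` as above):
    #
    #   v1 = acc.load_matrix(np.array([1.0, 2.0, 3.0], dtype=np.float32), "v1")
    #   v2 = acc.load_matrix(np.array([4.0, 5.0, 6.0], dtype=np.float32), "v2")
    #   dot_id = acc.vector_operation(VectorOperation.DOT_PRODUCT, v1, v2)
    #   acc.get_matrix(dot_id)   # 32.0, since 1*4 + 2*5 + 3*6 = 32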
    def convolution_2d(self, input_id: str, kernel_id: str,
                       stride: int = 1, padding: int = 0,
                       result_id: Optional[str] = None) -> Optional[str]:
        """Perform a 2D convolution operation."""
        start_time = time.time()

        # Retrieve input and kernel from VRAM
        input_data = self.get_matrix(input_id)
        kernel = self.get_matrix(kernel_id)
        if input_data is None or kernel is None:
            print("Error: Could not retrieve input or kernel")
            return None

        try:
            # Simple 2D convolution implementation.
            # In a real GPU implementation, this would be highly optimized
            # and distributed across many cores.
            if len(input_data.shape) == 2:
                input_h, input_w = input_data.shape
                channels = 1
            else:
                input_h, input_w, channels = input_data.shape

            kernel_h, kernel_w = kernel.shape[:2]

            # Calculate output dimensions
            output_h = (input_h + 2 * padding - kernel_h) // stride + 1
            output_w = (input_w + 2 * padding - kernel_w) // stride + 1

            # Initialize output
            if channels == 1:
                output = np.zeros((output_h, output_w))
            else:
                output = np.zeros((output_h, output_w, channels))

            # Pad input if necessary
            if padding > 0:
                if channels == 1:
                    padded_input = np.pad(input_data, padding, mode='constant')
                else:
                    padded_input = np.pad(input_data,
                                          ((padding, padding), (padding, padding), (0, 0)),
                                          mode='constant')
            else:
                padded_input = input_data

            # Perform convolution
            flops = 0
            for y in range(0, output_h):
                for x in range(0, output_w):
                    y_start = y * stride
                    x_start = x * stride
                    if channels == 1:
                        patch = padded_input[y_start:y_start + kernel_h, x_start:x_start + kernel_w]
                        output[y, x] = np.sum(patch * kernel)
                        flops += kernel_h * kernel_w * 2  # Multiply and add
                    else:
                        for c in range(channels):
                            patch = padded_input[y_start:y_start + kernel_h,
                                                 x_start:x_start + kernel_w, c]
                            # Use the matching kernel channel when the kernel is 3D
                            kernel_c = kernel[:, :, c] if kernel.ndim == 3 else kernel
                            output[y, x, c] = np.sum(patch * kernel_c)
                            flops += kernel_h * kernel_w * 2

            # Store result in VRAM
            if result_id is None:
                result_id = f"conv_result_{self.matrix_counter}"
                self.matrix_counter += 1
            result_conv_id = self.load_matrix(output, result_id)

            # Update statistics
            compute_time = time.time() - start_time
            self.total_compute_time += compute_time
            self.operations_performed += 1
            self.flops_performed += flops

            print(f"2D Convolution completed: {input_data.shape} * {kernel.shape} "
                  f"= {output.shape} in {compute_time:.4f}s")
            print(f"Simulated {flops:,} FLOPs")
            return result_conv_id
        except Exception as e:
            print(f"Error in 2D convolution: {e}")
            return None
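
    # Worked size check (sketch): a 28x28 input with a 3x3 kernel, stride 1 and
    # padding 1 gives output_h = output_w = (28 + 2*1 - 3) // 1 + 1 = 28, i.e.
    # "same" size; with padding 0 the output shrinks to 26x26.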
    def get_stats(self) -> Dict[str, Any]:
        """Get AI accelerator statistics."""
        avg_compute_time = self.total_compute_time / max(1, self.operations_performed)
        flops_per_second = self.flops_performed / max(0.001, self.total_compute_time)
        return {
            "operations_performed": self.operations_performed,
            "total_compute_time": self.total_compute_time,
            "avg_compute_time": avg_compute_time,
            "flops_performed": self.flops_performed,
            "flops_per_second": flops_per_second,
            "matrices_in_memory": len(self.matrix_registry),
            "simulated_cores": self.total_cores,
            "simulated_sms": self.num_sms
        }

    def reset_stats(self) -> None:
        """Reset AI accelerator statistics."""
        self.operations_performed = 0
        self.total_compute_time = 0.0
        self.flops_performed = 0
    def optimize_attention_weights(self, weight_matrix):
        """Preprocess attention weights for faster computation."""
        # Optimize weight layout for tensor core operations
        if isinstance(weight_matrix, np.ndarray):
            # Reshape for optimal memory access
            if len(weight_matrix.shape) == 2:
                # Pad to a multiple of the tensor core tile size (8) if needed
                h, w = weight_matrix.shape
                pad_h = (8 - h % 8) if h % 8 != 0 else 0
                pad_w = (8 - w % 8) if w % 8 != 0 else 0
                if pad_h > 0 or pad_w > 0:
                    weight_matrix = np.pad(weight_matrix, ((0, pad_h), (0, pad_w)))
            return weight_matrix
        return weight_matrix
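
    # Padding example (sketch): a (30, 50) weight matrix becomes (32, 56),
    # since pad_h = 8 - 30 % 8 = 2 and pad_w = 8 - 50 % 8 = 6; matrices whose
    # dimensions are already multiples of 8 are returned unchanged.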
    def parallel_attention(self, query, key_value_weights, features_per_sm):
        """Execute multi-head attention using parallel tensor cores."""
        # Split attention heads across SMs
        num_heads = min(self.num_sms, 32)  # Max 32 attention heads
        head_dim = query.shape[-1] // num_heads

        # Parallel processing of attention heads
        attention_results = []
        for i in range(0, num_heads):
            start_idx = i * head_dim
            end_idx = (i + 1) * head_dim

            # Process attention head using tensor core
            q_head = [row[start_idx:end_idx] for row in query]
            k_head = [row[start_idx:end_idx] for row in key_value_weights]

            # Compute attention scores using tensor core
            attention_scores = self.tensor_core_array.matmul(
                q_head, k_head,
                split_size=features_per_sm
            )
            attention_results.append(attention_scores)

        # Combine attention heads
        return self.combine_attention_heads(attention_results)
    def combine_attention_heads(self, attention_heads):
        """Combine attention heads efficiently using tensor cores."""
        if not attention_heads:
            return None

        # Get dimensions
        num_heads = len(attention_heads)
        batch_size = len(attention_heads[0])
        head_dim = len(attention_heads[0][0])

        # Concatenate heads efficiently
        combined = [[0.0] * (head_dim * num_heads) for _ in range(batch_size)]
        for i in range(batch_size):
            for h in range(num_heads):
                for j in range(head_dim):
                    combined[i][h * head_dim + j] = attention_heads[h][i][j]
        return combined
    def calculate_tflops(self, model_info, batch_size, inference_time):
        """Calculate effective TFLOPS for the inference."""
        total_params = sum(np.prod(self.get_matrix(w_id).shape)
                           for w_id in model_info["weights"].values())
        ops_per_param = 2  # Multiply-add
        total_ops = total_params * batch_size * ops_per_param
        return (total_ops / inference_time) / 1e12  # Convert to TFLOPS
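
    # Worked example (sketch): a 10B-parameter model at batch size 16 with a
    # 10 ms inference gives (1e10 * 16 * 2) / 0.01 / 1e12 = 32 effective TFLOPS.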
    def _serialize_tensor(self, tensor: Any) -> Optional[np.ndarray]:
        """Convert a PyTorch tensor to a NumPy array safely."""
        try:
            if hasattr(tensor, 'detach'):
                tensor = tensor.detach()
            if hasattr(tensor, 'cpu'):
                tensor = tensor.cpu()
            if hasattr(tensor, 'numpy'):
                return tensor.numpy()
            return np.array(tensor)
        except Exception as e:
            print(f"Warning: Error converting tensor to numpy: {e}")
            return None
    def load_model(self, model_id: str, model: Any, processor: Any):
        """Loads a model directly into WebSocket storage without a CPU intermediary."""
        try:
            if model is None and processor is None:
                # Zero-copy mode
                self.model_registry[model_id] = {
                    "zero_copy": True,
                    "websocket_mapped": True
                }
                self.tokenizer_registry[model_id] = None
                self.model_loaded = True
                return

            # Verify WebSocket connection first
            if not self.storage or not self.storage.wait_for_connection():
                raise RuntimeError("WebSocket connection not available")

            # 1. Store model configuration
            try:
                config_dict = (self._serialize_model_config(model.config)
                               if hasattr(model, "config") else {})
                model_info = {
                    "architecture": model.__class__.__name__ if model else "Unknown",
                    "processor": processor.__class__.__name__ if processor else "Unknown",
                    "config": config_dict
                }
            except Exception as e:
                print(f"Warning: Error serializing model config: {e}")
                model_info = {
                    "architecture": str(type(model).__name__),
                    "error": str(e)
                }

            # Store model info with retry
            for attempt in range(3):
                try:
                    if self.storage.store_state(f"models/{model_id}/info", "info", model_info):
                        break
                    print(f"Retrying model info storage, attempt {attempt + 1}")
                    time.sleep(1)
                except Exception as e:
                    if attempt == 2:
                        raise RuntimeError(f"Failed to store model info: {e}")

            # 2. Store model weights
            if hasattr(model, "state_dict"):
                weight_registry = {}
                for name, param in model.state_dict().items():
                    # Convert tensor to numpy and store in chunks if needed
                    tensor_data = self._serialize_tensor(param)
                    if tensor_data is not None:
                        tensor_id = f"{model_id}/weights/{name}"
                        if tensor_data.nbytes > 1024 * 1024 * 1024:  # If larger than 1GB
                            # Store large tensors in chunks
                            chunks = np.array_split(tensor_data,
                                                    max(1, tensor_data.nbytes // (512 * 1024 * 1024)))
                            chunk_ids = []
                            for i, chunk in enumerate(chunks):
                                chunk_id = f"{tensor_id}/chunk_{i}"
                                if self.storage.store_tensor(chunk_id, chunk):
                                    chunk_ids.append(chunk_id)
                            weight_registry[name] = {
                                "type": "chunked",
                                "chunks": chunk_ids,
                                "shape": tensor_data.shape,
                                "dtype": str(tensor_data.dtype)
                            }
                        else:
                            # Store small tensors directly
                            if self.storage.store_tensor(tensor_id, tensor_data):
                                weight_registry[name] = {
                                    "type": "direct",
                                    "tensor_id": tensor_id,
                                    "shape": tensor_data.shape,
                                    "dtype": str(tensor_data.dtype)
                                }
                # Store weight registry
                self.storage.store_state(f"models/{model_id}/weights", "registry", weight_registry)
                self.model_registry[model_id] = {
                    "weight_registry": weight_registry,
                    "websocket_mapped": True
                }

            # Map weight tensors directly to WebSocket storage
            if model is not None and hasattr(model, "state_dict"):
                model_weights = {}
                for name, param in model.state_dict().items():
                    tensor_id = f"{model_id}/weights/{name}"
                    # Store tensor directly in WebSocket storage
                    if not self.storage.store_tensor(tensor_id, param.detach().cpu().numpy()):
                        raise RuntimeError(f"Failed to store tensor {name}")
                    model_weights[name] = tensor_id
                # Store only WebSocket references
                self.model_registry[model_id] = {
                    "weights": model_weights,
                    "architecture_id": hash(str(type(model))),
                    "websocket_mapped": True
                }
            else:
                # Store the entire model state in WebSocket storage
                tensor_id = f"{model_id}/model_state"
                if not self.storage.store_state(f"models/{model_id}/state", "state", model):
                    raise RuntimeError("Failed to store model state")
                self.model_registry[model_id] = tensor_id

            # Store tokenizer/processor
            self.tokenizer_registry[model_id] = processor
            self.model_loaded = True
            print(f"Model '{model_id}' loaded into WebSocket storage")
        except Exception as e:
            print(f"Error loading model into WebSocket storage: {str(e)}")
            raise
    def has_model(self, model_id: str) -> bool:
        """Checks if a model is loaded in the accelerator's registry."""
        return model_id in self.model_registry
    def inference(self, model_id: str, input_data: np.ndarray, idx: Optional[int] = None) -> Optional[np.ndarray]:
        """Execute pure WebSocket-based inference with zero CPU usage."""
        print(f"[DEBUG] Starting WebSocket-based inference for model_id={model_id}")
        try:
            if not self.has_model(model_id):
                print(f"[ERROR] Model {model_id} not loaded in WebSocket storage.")
                return None

            model_info = self.model_registry[model_id]
            processor = self.tokenizer_registry[model_id]

            # Store input data in WebSocket storage
            input_tensor_id = f"{model_id}/inputs/{idx if idx is not None else time.time_ns()}"
            self.storage.store_tensor(input_tensor_id, input_data)

            # Process input using tensor cores through WebSocket
            processed_data = processor(input_data, return_tensors="np")
            processed_tensor_id = f"{model_id}/processed/{idx if idx is not None else time.time_ns()}"
            self.storage.store_tensor(processed_tensor_id, processed_data["input_ids"])

            # Load weights from WebSocket storage and perform the forward pass
            if isinstance(model_info, dict) and "weights" in model_info:
                # Initialize hidden states
                hidden_states = processed_data["input_ids"]

                # Process through each layer using tensor cores
                for layer_name, weight_id in model_info["weights"].items():
                    if "weight" in layer_name:
                        # Load weights from WebSocket storage
                        weights = self.storage.load_tensor(weight_id)
                        if weights is None:
                            continue

                        # Process through tensor cores
                        if "attention" in layer_name:
                            hidden_states = self.parallel_attention(
                                hidden_states,
                                weights,
                                features_per_sm=hidden_states.shape[-1] // self.num_sms
                            )
                        else:
                            # Regular layer processing
                            hidden_states = self.tensor_core_array.matmul(
                                hidden_states.tolist(),
                                weights.tolist()
                            )
                        # Keep hidden states as an ndarray for the next layer
                        hidden_states = np.asarray(hidden_states)

                # Store final output in WebSocket storage
                output_tensor_id = f"{model_id}/outputs/{idx if idx is not None else time.time_ns()}"
                output = np.array(hidden_states)
                self.storage.store_tensor(output_tensor_id, output)
                return output
            else:
                print("[ERROR] Unsupported model format in WebSocket storage")
                return None
        except Exception as e:
            print(f"[ERROR] WebSocket-based inference failed for idx={idx}: {e}")
            return None
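

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the accelerator itself).
# It assumes a reachable GPU-storage WebSocket server and a VRAM backend
# exposing load_texture/get_texture attached via set_vram(); both live outside
# this module, so the demo is wrapped in a broad try/except and simply reports
# why it could not run when those services are absent.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    try:
        accelerator = AIAccelerator(num_sms=8, cores_per_sm=4)  # small sim for the demo

        # Matrix demo only runs if a VRAM backend has been attached elsewhere.
        if accelerator.vram is not None:
            a_id = accelerator.load_matrix(np.random.rand(16, 8).astype(np.float32), "demo_a")
            b_id = accelerator.load_matrix(np.random.rand(8, 4).astype(np.float32), "demo_b")
            c_id = accelerator.matrix_multiply(a_id, b_id)
            if c_id is not None:
                print("Result shape:", accelerator.get_matrix(c_id).shape)

        print("Stats:", accelerator.get_stats())
    except Exception as demo_error:
        print(f"Demo skipped: {demo_error}")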