# FServe / ai.py
import numpy as np
import time
from typing import Dict, Any, Optional, Tuple, Union, List
from enum import Enum
from tensor_core import TensorCoreArray
class VectorOperation(Enum):
"""Enumeration of supported vector operations."""
ADD = "add"
SUBTRACT = "subtract"
MULTIPLY = "multiply"
DIVIDE = "divide"
DOT_PRODUCT = "dot_product"
CROSS_PRODUCT = "cross_product"
NORMALIZE = "normalize"
MAGNITUDE = "magnitude"
class AIAccelerator:
"""
AI Accelerator that simulates GPU-based AI computations.
This class leverages NumPy's optimized operations to simulate the parallel
processing capabilities of the vGPU for AI workloads.
"""
def __init__(self, vram=None, num_sms: int = 800, cores_per_sm: int = 222, storage=None):
"""Initialize AI Accelerator with electron-speed awareness and shared WebSocket storage."""
from electron_speed import TARGET_SWITCHES_PER_SEC, TRANSISTORS_ON_CHIP, drift_velocity
self.storage = storage # Use the shared storage instance
if self.storage is None:
from websocket_storage import WebSocketGPUStorage
self.storage = WebSocketGPUStorage() # Only create new if not provided
if not self.storage.wait_for_connection():
raise RuntimeError("Could not connect to GPU storage server")
self.vram = vram
self.num_sms = num_sms
self.cores_per_sm = cores_per_sm
self.total_cores = num_sms * cores_per_sm
# Configure for maximum parallel processing at electron speed
total_tensor_cores = num_sms * cores_per_sm # Use ALL cores for tensor operations
self.tensor_core_array = TensorCoreArray(
num_tensor_cores=total_tensor_cores,
bits=32,
bandwidth_tbps=drift_velocity / 1e-12 # Bandwidth scaled to electron drift speed
)
self.tensor_cores_initialized = False
# Initialize model, tensor, and tokenizer tracking
self.model_registry: Dict[str, Dict[str, Any]] = {} # Track loaded models
self.tensor_registry: Dict[str, Dict[str, Any]] = {} # Track tensor metadata
self.tokenizer_registry: Dict[str, Any] = {} # Track tokenizers
        self.resource_monitor = {
            'vram_used': 0,
            'active_tensors': 0,
            'loaded_models': set()
        }
        # AI operation statistics
        self.operations_performed = 0
        self.total_compute_time = 0.0
        self.flops_performed = 0
        # WebSocket-based memory management
        self.matrix_registry: Dict[str, Any] = {}  # Track loaded matrices
        self.matrix_counter = 0
        self.activation_cache: Dict[str, str] = {}  # Cache activation outputs
        self.weight_cache: Dict[str, Any] = {}  # Cache preprocessed weights
        self.model_configs: Dict[str, Any] = {}  # Store model architectures
        self.model_loaded = False
        # Batch processing configuration
        self.max_batch_size = 64
        self.min_batch_size = 4
        self.dynamic_batching = True  # Enable automatic batch size adjustment
    def _serialize_model_config(self, config: Any) -> Any:
        """Recursively convert a model config into a JSON-serializable structure."""
# Handle None case first
if config is None:
return None
# Handle Florence2LanguageConfig specifically
if config.__class__.__name__ == "Florence2LanguageConfig":
try:
return {
"type": "Florence2LanguageConfig",
"model_type": getattr(config, "model_type", ""),
"architectures": getattr(config, "architectures", []),
"hidden_size": getattr(config, "hidden_size", 0),
"num_attention_heads": getattr(config, "num_attention_heads", 0),
"num_hidden_layers": getattr(config, "num_hidden_layers", 0),
"intermediate_size": getattr(config, "intermediate_size", 0),
"max_position_embeddings": getattr(config, "max_position_embeddings", 0),
"layer_norm_eps": getattr(config, "layer_norm_eps", 1e-12),
"vocab_size": getattr(config, "vocab_size", 0)
}
except Exception as e:
print(f"Warning: Error serializing Florence2LanguageConfig: {e}")
return {"type": "Florence2LanguageConfig", "error": str(e)}
# Handle standard types
if isinstance(config, (int, float, str, bool)):
return config
# Handle lists and tuples
if isinstance(config, (list, tuple)):
return [self._serialize_model_config(item) for item in config]
# Handle dictionaries
if isinstance(config, dict):
return {k: self._serialize_model_config(v) for k, v in config.items()}
# Handle objects with __dict__
if hasattr(config, '__dict__'):
config_dict = {}
for key, value in config.__dict__.items():
try:
# Skip private attributes
if key.startswith('_'):
continue
config_dict[key] = self._serialize_model_config(value)
except Exception as e:
print(f"Warning: Error serializing attribute {key}: {e}")
config_dict[key] = str(value)
return config_dict
# Fallback: convert to string representation
try:
return str(config)
except Exception as e:
return f"<Unserializable object of type {type(config).__name__}: {str(e)}>"
def store_model_state(self, model_name: str, model_info: Dict[str, Any]) -> bool:
"""Store model state in WebSocket storage with proper serialization."""
try:
# Convert any non-serializable parts of model_info
serializable_info = self._serialize_model_config(model_info)
# Store in model registry
self.model_registry[model_name] = serializable_info
# Save to storage
if self.storage:
# Store model info
info_success = self.storage.store_state(
"models",
f"{model_name}/info",
serializable_info
)
# Store model state
state_success = self.storage.store_state(
"models",
f"{model_name}/state",
{"loaded": True, "timestamp": time.time()}
)
if info_success and state_success:
self.resource_monitor['loaded_models'].add(model_name)
return True
return False
except Exception as e:
print(f"Error storing model state: {str(e)}")
return False
def initialize_tensor_cores(self):
"""Initialize tensor cores and verify they're ready for computation"""
if self.tensor_cores_initialized:
return True
try:
# Verify tensor core array is properly initialized
if not hasattr(self, 'tensor_core_array') or self.tensor_core_array is None:
raise RuntimeError("Tensor core array not properly initialized")
# Initialize tensor cores if needed
if hasattr(self.tensor_core_array, 'initialize'):
self.tensor_core_array.initialize()
# Verify VRAM access
if self.vram is None:
raise RuntimeError("VRAM not properly configured")
            # Test tensor core functionality with a small computation
            test_input = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
test_result = self.tensor_core_array.matmul(test_input, test_input)
if test_result is None or not isinstance(test_result, (np.ndarray, list)) or len(test_result) == 0:
raise RuntimeError("Tensor core test computation failed")
self.tensor_cores_initialized = True
return True
except Exception as e:
print(f"Failed to initialize tensor cores: {str(e)}")
self.tensor_cores_initialized = False
return False
def set_vram(self, vram):
"""Set the VRAM reference."""
self.vram = vram
def allocate_matrix(self, shape: Tuple[int, ...], dtype=np.float32,
name: Optional[str] = None) -> str:
"""Allocate a matrix in VRAM and return its ID."""
if not self.vram:
raise RuntimeError("VRAM not available")
if name is None:
name = f"matrix_{self.matrix_counter}"
self.matrix_counter += 1
# Create matrix data
matrix_data = np.zeros(shape, dtype=dtype)
# Store in VRAM as a texture (reusing texture storage mechanism)
matrix_id = self.vram.load_texture(matrix_data, name)
self.matrix_registry[name] = matrix_id
return name
def load_matrix(self, matrix_data: np.ndarray, name: Optional[str] = None) -> str:
"""Load matrix data into VRAM and return its ID."""
if not self.vram:
raise RuntimeError("VRAM not available")
if name is None:
name = f"matrix_{self.matrix_counter}"
self.matrix_counter += 1
# Store in VRAM
matrix_id = self.vram.load_texture(matrix_data, name)
self.matrix_registry[name] = matrix_id
return name
def get_matrix(self, matrix_id: str) -> Optional[np.ndarray]:
"""Retrieve matrix data from VRAM."""
if not self.vram or matrix_id not in self.matrix_registry:
return None
vram_id = self.matrix_registry[matrix_id]
return self.vram.get_texture(vram_id)
def matrix_multiply(self, matrix_a_id: str, matrix_b_id: str,
result_id: Optional[str] = None) -> Optional[str]:
"""Perform matrix multiplication using simulated GPU parallelism."""
start_time = time.time()
# Retrieve matrices from VRAM
matrix_a = self.get_matrix(matrix_a_id)
matrix_b = self.get_matrix(matrix_b_id)
if matrix_a is None or matrix_b is None:
print(f"Error: Could not retrieve matrices {matrix_a_id} or {matrix_b_id}")
return None
try:
# Check if matrices can be multiplied
if matrix_a.shape[-1] != matrix_b.shape[0]:
print(f"Error: Matrix dimensions incompatible for multiplication: "
f"{matrix_a.shape} x {matrix_b.shape}")
return None
            # Simulate parallel processing by breaking down the operation
            # In a real GPU, this would be distributed across SMs and cores
            result = self._simulate_parallel_matmul(matrix_a, matrix_b)
# Store result in VRAM
if result_id is None:
result_id = f"result_{self.matrix_counter}"
self.matrix_counter += 1
result_matrix_id = self.load_matrix(result, result_id)
# Update statistics
compute_time = time.time() - start_time
self.total_compute_time += compute_time
self.operations_performed += 1
# Calculate FLOPs (2 * M * N * K for matrix multiplication)
m, k = matrix_a.shape
k2, n = matrix_b.shape
flops = 2 * m * n * k
self.flops_performed += flops
print(f"Matrix multiplication completed: {matrix_a.shape} x {matrix_b.shape} "
f"= {result.shape} in {compute_time:.4f}s")
print(f"Simulated {flops:,} FLOPs across {self.total_cores} cores")
return result_matrix_id
except Exception as e:
print(f"Error in matrix multiplication: {e}")
return None
def _simulate_parallel_matmul(self, matrix_a: np.ndarray, matrix_b: np.ndarray) -> np.ndarray:
"""Simulate parallel matrix multiplication across SMs."""
# Use NumPy's optimized matrix multiplication
# In a real implementation, this would be broken down into blocks
# and distributed across the simulated SMs
# For demonstration, we can show how the work would be distributed
m, k = matrix_a.shape
k2, n = matrix_b.shape
# Calculate work distribution
total_output_elements = m * n
elements_per_sm = max(1, total_output_elements // self.num_sms)
print(f"Distributing {total_output_elements:,} output elements across "
f"{self.num_sms} SMs ({elements_per_sm} elements per SM)")
# Perform the actual computation using NumPy
result = np.dot(matrix_a, matrix_b)
return result
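    # Sketch of how the blocked distribution could look if it were carried out
    # explicitly rather than delegated to np.dot (illustrative only; the tile
    # size of 128 is an assumption, not a value used elsewhere in this file):
    #
    #   tile = 128
    #   result = np.zeros((m, n), dtype=matrix_a.dtype)
    #   for row0 in range(0, m, tile):        # one row band per simulated SM
    #       for col0 in range(0, n, tile):
    #           result[row0:row0 + tile, col0:col0 + tile] = (
    #               matrix_a[row0:row0 + tile, :] @ matrix_b[:, col0:col0 + tile])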
def vector_operation(self, operation: VectorOperation, vector_a_id: str,
vector_b_id: Optional[str] = None,
result_id: Optional[str] = None) -> Optional[str]:
"""Perform vector operations using simulated GPU parallelism."""
start_time = time.time()
# Retrieve vectors from VRAM
vector_a = self.get_matrix(vector_a_id)
if vector_a is None:
print(f"Error: Could not retrieve vector {vector_a_id}")
return None
vector_b = None
if vector_b_id:
vector_b = self.get_matrix(vector_b_id)
if vector_b is None:
print(f"Error: Could not retrieve vector {vector_b_id}")
return None
try:
result = None
flops = 0
if operation == VectorOperation.ADD:
if vector_b is None:
raise ValueError("Vector B required for addition")
result = vector_a + vector_b
flops = vector_a.size
elif operation == VectorOperation.SUBTRACT:
if vector_b is None:
raise ValueError("Vector B required for subtraction")
result = vector_a - vector_b
flops = vector_a.size
elif operation == VectorOperation.MULTIPLY:
if vector_b is None:
raise ValueError("Vector B required for multiplication")
result = vector_a * vector_b
flops = vector_a.size
elif operation == VectorOperation.DIVIDE:
if vector_b is None:
raise ValueError("Vector B required for division")
result = vector_a / vector_b
flops = vector_a.size
elif operation == VectorOperation.DOT_PRODUCT:
if vector_b is None:
raise ValueError("Vector B required for dot product")
result = np.dot(vector_a.flatten(), vector_b.flatten())
flops = 2 * vector_a.size
elif operation == VectorOperation.CROSS_PRODUCT:
if vector_b is None:
raise ValueError("Vector B required for cross product")
result = np.cross(vector_a, vector_b)
flops = 6 # Approximate for 3D cross product
elif operation == VectorOperation.NORMALIZE:
magnitude = np.linalg.norm(vector_a)
result = vector_a / magnitude if magnitude > 0 else vector_a
flops = vector_a.size * 2 # Division + magnitude calculation
elif operation == VectorOperation.MAGNITUDE:
result = np.array([np.linalg.norm(vector_a)])
flops = vector_a.size * 2 # Squares and sum
else:
raise ValueError(f"Unsupported vector operation: {operation}")
# Store result in VRAM
if result_id is None:
result_id = f"vector_result_{self.matrix_counter}"
self.matrix_counter += 1
result_vector_id = self.load_matrix(result, result_id)
# Update statistics
compute_time = time.time() - start_time
self.total_compute_time += compute_time
self.operations_performed += 1
self.flops_performed += flops
print(f"Vector operation {operation.value} completed in {compute_time:.4f}s")
return result_vector_id
except Exception as e:
print(f"Error in vector operation {operation.value}: {e}")
return None
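    # Example: with two 3-element vectors already loaded under the IDs "u" and "v",
    #   accel.vector_operation(VectorOperation.DOT_PRODUCT, "u", "v")
    # computes the dot product on the simulated cores, stores the scalar result in
    # VRAM and returns the new result ID ("accel" is a hypothetical instance).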
def convolution_2d(self, input_id: str, kernel_id: str,
stride: int = 1, padding: int = 0,
result_id: Optional[str] = None) -> Optional[str]:
"""Perform 2D convolution operation."""
start_time = time.time()
# Retrieve input and kernel from VRAM
input_data = self.get_matrix(input_id)
kernel = self.get_matrix(kernel_id)
if input_data is None or kernel is None:
print(f"Error: Could not retrieve input or kernel")
return None
try:
# Simple 2D convolution implementation
# In a real GPU implementation, this would be highly optimized
# and distributed across many cores
if len(input_data.shape) == 2:
input_h, input_w = input_data.shape
channels = 1
else:
input_h, input_w, channels = input_data.shape
kernel_h, kernel_w = kernel.shape[:2]
# Calculate output dimensions
output_h = (input_h + 2 * padding - kernel_h) // stride + 1
output_w = (input_w + 2 * padding - kernel_w) // stride + 1
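            # Example: a 32x32 input with a 3x3 kernel, stride 1 and padding 1
            # gives (32 + 2*1 - 3) // 1 + 1 = 32, i.e. a "same"-sized output.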
# Initialize output
if channels == 1:
output = np.zeros((output_h, output_w))
else:
output = np.zeros((output_h, output_w, channels))
# Pad input if necessary
if padding > 0:
if channels == 1:
padded_input = np.pad(input_data, padding, mode='constant')
else:
padded_input = np.pad(input_data,
((padding, padding), (padding, padding), (0, 0)),
mode='constant')
else:
padded_input = input_data
# Perform convolution
flops = 0
for y in range(0, output_h):
for x in range(0, output_w):
y_start = y * stride
x_start = x * stride
if channels == 1:
patch = padded_input[y_start:y_start+kernel_h, x_start:x_start+kernel_w]
output[y, x] = np.sum(patch * kernel)
flops += kernel_h * kernel_w * 2 # Multiply and add
else:
for c in range(channels):
patch = padded_input[y_start:y_start+kernel_h,
x_start:x_start+kernel_w, c]
output[y, x, c] = np.sum(patch * kernel)
flops += kernel_h * kernel_w * 2
# Store result in VRAM
if result_id is None:
result_id = f"conv_result_{self.matrix_counter}"
self.matrix_counter += 1
result_conv_id = self.load_matrix(output, result_id)
# Update statistics
compute_time = time.time() - start_time
self.total_compute_time += compute_time
self.operations_performed += 1
self.flops_performed += flops
print(f"2D Convolution completed: {input_data.shape} * {kernel.shape} "
f"= {output.shape} in {compute_time:.4f}s")
print(f"Simulated {flops:,} FLOPs")
return result_conv_id
except Exception as e:
print(f"Error in 2D convolution: {e}")
return None
def get_stats(self) -> Dict[str, Any]:
"""Get AI accelerator statistics."""
avg_compute_time = self.total_compute_time / max(1, self.operations_performed)
flops_per_second = self.flops_performed / max(0.001, self.total_compute_time)
return {
"operations_performed": self.operations_performed,
"total_compute_time": self.total_compute_time,
"avg_compute_time": avg_compute_time,
"flops_performed": self.flops_performed,
"flops_per_second": flops_per_second,
"matrices_in_memory": len(self.matrix_registry),
"simulated_cores": self.total_cores,
"simulated_sms": self.num_sms
}
def reset_stats(self) -> None:
"""Reset AI accelerator statistics."""
self.operations_performed = 0
self.total_compute_time = 0.0
self.flops_performed = 0
def optimize_attention_weights(self, weight_matrix):
"""Preprocess attention weights for faster computation."""
# Optimize weight layout for tensor core operations
if isinstance(weight_matrix, np.ndarray):
# Reshape for optimal memory access
if len(weight_matrix.shape) == 2:
# Pad to multiple of tensor core size if needed
h, w = weight_matrix.shape
pad_h = (8 - h % 8) if h % 8 != 0 else 0
pad_w = (8 - w % 8) if w % 8 != 0 else 0
if pad_h > 0 or pad_w > 0:
weight_matrix = np.pad(weight_matrix, ((0, pad_h), (0, pad_w)))
return weight_matrix
return weight_matrix
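    # Example: a 300x768 weight matrix has 300 % 8 == 4, so pad_h = 4 and the
    # matrix is zero-padded to 304x768; 768 is already a multiple of 8, so the
    # width is left unchanged.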
def parallel_attention(self, query, key_value_weights, features_per_sm):
"""Execute multi-head attention using parallel tensor cores."""
# Split attention heads across SMs
num_heads = min(self.num_sms, 32) # Max 32 attention heads
head_dim = query.shape[-1] // num_heads
# Parallel processing of attention heads
attention_results = []
for i in range(0, num_heads):
start_idx = i * head_dim
end_idx = (i + 1) * head_dim
# Process attention head using tensor core
q_head = [row[start_idx:end_idx] for row in query]
k_head = [row[start_idx:end_idx] for row in key_value_weights]
# Compute attention scores using tensor core
attention_scores = self.tensor_core_array.matmul(
q_head, k_head,
split_size=features_per_sm
)
attention_results.append(attention_scores)
# Combine attention heads
return self.combine_attention_heads(attention_results)
def combine_attention_heads(self, attention_heads):
"""Combine attention heads efficiently using tensor cores."""
if not attention_heads:
return None
# Get dimensions
num_heads = len(attention_heads)
batch_size = len(attention_heads[0])
head_dim = len(attention_heads[0][0])
# Concatenate heads efficiently
combined = [[0.0] * (head_dim * num_heads) for _ in range(batch_size)]
for i in range(batch_size):
for h in range(num_heads):
for j in range(head_dim):
combined[i][h * head_dim + j] = attention_heads[h][i][j]
return combined
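    # For NumPy-backed heads the triple loop above is equivalent to
    # np.concatenate([np.asarray(h) for h in attention_heads], axis=-1).tolist(),
    # i.e. concatenating the per-head outputs along the feature axis.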
    def calculate_tflops(self, model_info, batch_size, inference_time):
        """Calculate effective TFLOPS for the inference."""
        # Model weights live in WebSocket storage, so size them from there
        weights = (self.storage.load_tensor(w_id) for w_id in model_info["weights"].values())
        total_params = sum(int(np.prod(w.shape)) for w in weights if w is not None)
        ops_per_param = 2  # Multiply-add
        total_ops = total_params * batch_size * ops_per_param
        return (total_ops / max(inference_time, 1e-9)) / 1e12  # Convert to TFLOPS
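    # Worked example (hypothetical numbers): a 100M-parameter model at batch size 8
    # performs 2 * 1e8 * 8 = 1.6e9 operations; over a 0.5 s inference that is
    # 1.6e9 / 0.5 / 1e12 = 0.0032 effective TFLOPS.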
    def _serialize_tensor(self, tensor: Any) -> Optional[np.ndarray]:
        """Safely convert a (possibly GPU-resident) PyTorch tensor to a NumPy array."""
try:
if hasattr(tensor, 'detach'):
tensor = tensor.detach()
if hasattr(tensor, 'cpu'):
tensor = tensor.cpu()
if hasattr(tensor, 'numpy'):
return tensor.numpy()
return np.array(tensor)
except Exception as e:
print(f"Warning: Error converting tensor to numpy: {e}")
return None
def load_model(self, model_id: str, model: Any, processor: Any):
"""Loads a model directly into WebSocket storage without CPU intermediary."""
try:
if model is None and processor is None:
# Zero-copy mode
self.model_registry[model_id] = {
"zero_copy": True,
"websocket_mapped": True
}
self.tokenizer_registry[model_id] = None
self.model_loaded = True
return
# Verify WebSocket connection first
if not self.storage or not self.storage.wait_for_connection():
raise RuntimeError("WebSocket connection not available")
# 1. Store model configuration
try:
config_dict = (self._serialize_model_config(model.config)
if hasattr(model, "config") else {})
model_info = {
"architecture": model.__class__.__name__ if model else "Unknown",
"processor": processor.__class__.__name__ if processor else "Unknown",
"config": config_dict
}
except Exception as e:
print(f"Warning: Error serializing model config: {e}")
model_info = {
"architecture": str(type(model).__name__),
"error": str(e)
}
# Store model info with retry
for attempt in range(3):
try:
if self.storage.store_state(f"models/{model_id}/info", "info", model_info):
break
print(f"Retrying model info storage, attempt {attempt + 1}")
time.sleep(1)
except Exception as e:
if attempt == 2:
raise RuntimeError(f"Failed to store model info: {e}")
# 2. Store model weights
if hasattr(model, "state_dict"):
weight_registry = {}
for name, param in model.state_dict().items():
# Convert tensor to numpy and store in chunks if needed
tensor_data = self._serialize_tensor(param)
if tensor_data is not None:
tensor_id = f"{model_id}/weights/{name}"
if tensor_data.nbytes > 1024*1024*1024: # If larger than 1GB
# Store large tensors in chunks
chunks = np.array_split(tensor_data,
max(1, tensor_data.nbytes // (512*1024*1024)))
chunk_ids = []
for i, chunk in enumerate(chunks):
chunk_id = f"{tensor_id}/chunk_{i}"
if self.storage.store_tensor(chunk_id, chunk):
chunk_ids.append(chunk_id)
weight_registry[name] = {
"type": "chunked",
"chunks": chunk_ids,
"shape": tensor_data.shape,
"dtype": str(tensor_data.dtype)
}
else:
# Store small tensors directly
if self.storage.store_tensor(tensor_id, tensor_data):
weight_registry[name] = {
"type": "direct",
"tensor_id": tensor_id,
"shape": tensor_data.shape,
"dtype": str(tensor_data.dtype)
}
# Store weight registry
self.storage.store_state(f"models/{model_id}/weights", "registry", weight_registry)
                self.model_registry[model_id] = {
                    "weight_registry": weight_registry,
                    # Flat name -> tensor_id map in the layout inference() expects
                    "weights": {
                        name: entry["tensor_id"]
                        for name, entry in weight_registry.items()
                        if entry.get("type") == "direct"
                    },
                    "architecture_id": hash(str(type(model))),
                    "websocket_mapped": True
                }
else:
# Store the entire model state in WebSocket storage
tensor_id = f"{model_id}/model_state"
if not self.storage.store_state(f"models/{model_id}/state", "state", model):
raise RuntimeError("Failed to store model state")
self.model_registry[model_id] = tensor_id
# Store tokenizer/processor
self.tokenizer_registry[model_id] = processor
self.model_loaded = True
print(f"Model '{model_id}' loaded into WebSocket storage")
except Exception as e:
print(f"Error loading model into WebSocket storage: {str(e)}")
raise
def has_model(self, model_id: str) -> bool:
"""Checks if a model is loaded in the accelerator's registry."""
return model_id in self.model_registry
def inference(self, model_id: str, input_data: np.ndarray, idx: Optional[int] = None) -> Optional[np.ndarray]:
"""Execute pure WebSocket-based inference with zero CPU usage."""
print(f"[DEBUG] Starting WebSocket-based inference for model_id={model_id}")
try:
if not self.has_model(model_id):
print(f"[ERROR] Model {model_id} not loaded in WebSocket storage.")
return None
model_info = self.model_registry[model_id]
            processor = self.tokenizer_registry.get(model_id)
            if processor is None:
                print(f"[ERROR] No processor/tokenizer registered for {model_id}")
                return None
            # Store input data in WebSocket storage
            input_tensor_id = f"{model_id}/inputs/{idx if idx is not None else time.time_ns()}"
            self.storage.store_tensor(input_tensor_id, input_data)
            # Tokenize/process the raw input before routing it through the tensor cores
            processed_data = processor(input_data, return_tensors="np")
processed_tensor_id = f"{model_id}/processed/{idx if idx is not None else time.time_ns()}"
self.storage.store_tensor(processed_tensor_id, processed_data["input_ids"])
# Load weights from WebSocket storage and perform forward pass
if isinstance(model_info, dict) and "weights" in model_info:
# Initialize hidden states
hidden_states = processed_data["input_ids"]
# Process through each layer using tensor cores
for layer_name, weight_id in model_info["weights"].items():
if "weight" in layer_name:
# Load weights from WebSocket storage
weights = self.storage.load_tensor(weight_id)
if weights is None:
continue
                        # Process through tensor cores
                        if "attention" in layer_name:
                            hidden_states = np.array(self.parallel_attention(
                                hidden_states,
                                weights,
                                features_per_sm=max(1, np.asarray(hidden_states).shape[-1] // self.num_sms)
                            ))
                        else:
                            # Regular layer processing on the tensor core array
                            hidden_states = np.array(self.tensor_core_array.matmul(
                                np.asarray(hidden_states).tolist(),
                                weights.tolist()
                            ))
# Store final output in WebSocket storage
output_tensor_id = f"{model_id}/outputs/{idx if idx is not None else time.time_ns()}"
output = np.array(hidden_states)
self.storage.store_tensor(output_tensor_id, output)
return output
else:
print(f"[ERROR] Unsupported model format in WebSocket storage")
return None
except Exception as e:
print(f"[ERROR] WebSocket-based inference failed for idx={idx}: {e}")
return None