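"""Streaming multiprocessor (SM) simulation backed by remote storage.

Models a single SM: tensor-core matmuls, shared memory, a register file,
and warp scheduling, persisting all state through a pluggable storage
backend.
"""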
import hashlib
import json
import logging
import threading
import time
from typing import Any, Dict, List, Optional

import numpy as np

from config import get_db_url
from http_storage import LocalStorage
from matrix_ops import MatrixOpScheduler, MatrixOpMetadata
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
class MatrixOpLock:
"""Enhanced locking mechanism for matrix operations"""
def __init__(self, sm_id: int, chip_id: int, storage):
self.sm_id = sm_id
self.chip_id = chip_id
self.storage = storage
self.lock = threading.Lock()
self.op_locks = {}
self.op_metadata = {}
def acquire_matrix_op(self, op_id: str, matrix_info: Dict[str, Any]) -> bool:
"""Acquire lock for matrix operation with metadata"""
with self.lock:
if op_id in self.op_locks:
return False
# Create operation-specific lock
self.op_locks[op_id] = threading.Lock()
self.op_metadata[op_id] = {
**matrix_info,
'sm_id': self.sm_id,
'chip_id': self.chip_id,
'acquired_time': time.time_ns(),
'status': 'locked'
}
# Store lock state
try:
self.storage.store_state(
f"matrix_op_{self.chip_id}_{self.sm_id}_{op_id}",
'lock_state',
self.op_metadata[op_id]
)
return True
except Exception:
del self.op_locks[op_id]
del self.op_metadata[op_id]
return False
def release_matrix_op(self, op_id: str):
"""Release matrix operation lock"""
with self.lock:
if op_id in self.op_locks:
self.op_metadata[op_id]['status'] = 'released'
self.op_metadata[op_id]['release_time'] = time.time_ns()
try:
self.storage.store_state(
f"matrix_op_{self.chip_id}_{self.sm_id}_{op_id}",
'lock_state',
self.op_metadata[op_id]
)
finally:
del self.op_locks[op_id]
del self.op_metadata[op_id]
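# Usage sketch for MatrixOpLock (the op id and metadata are illustrative);
# this mirrors the acquire/try/finally-release pattern used by
# tensor_core_matmul_from_memory below:
#
#     lock = MatrixOpLock(sm_id=0, chip_id=0, storage=storage)
#     if lock.acquire_matrix_op("op_1", {"type": "matmul"}):
#         try:
#             ...  # run the tensor-core operation
#         finally:
#             lock.release_matrix_op("op_1")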
class StreamingMultiprocessor:
def __init__(self, sm_id: int, chip_id: int = 0, num_cores: int = 128, storage=None):
self.sm_id = sm_id
self.chip_id = chip_id
self.num_cores = num_cores
# Initialize storage with retries
max_retries = 3
retry_delay = 1.0
for attempt in range(max_retries):
try:
self.storage = storage or LocalStorage(db_url=get_db_url())
if not self.storage.wait_for_connection(timeout=10):
raise RuntimeError("Storage connection timeout")
logging.info(f"SM {sm_id} on chip {chip_id}: Connected to storage backend")
break
except Exception as e:
if attempt == max_retries - 1:
raise RuntimeError(f"Failed to initialize storage after {max_retries} attempts: {str(e)}")
logging.warning(f"Storage initialization attempt {attempt + 1} failed: {str(e)}")
time.sleep(retry_delay)
# Initialize enhanced matrix operation tracking
self.matrix_op_scheduler = MatrixOpScheduler(num_sms=1, cores_per_sm=8) # Each SM manages its own scheduler
self.matrix_op_lock = MatrixOpLock(sm_id, chip_id, self.storage)
self.current_tensor_ops = {}
self.tensor_op_history = []
        # Reentrant: several methods call store_sm_state() (which takes this
        # lock) while already holding it.
        self.state_lock = threading.RLock()
# Initialize SM state with unique identifier
self.sm_key = hashlib.md5(f"sm_{chip_id}_{sm_id}".encode()).hexdigest()[:16]
self.sm_state = {
"sm_id": sm_id,
"chip_id": chip_id,
"num_cores": num_cores,
"active_warps": {},
"shared_memory": {},
"register_file": {},
"l1_cache": {},
"tensor_cores": {
"count": 8,
"active": True,
"operations": ["matmul", "conv2d", "attention"],
"current_ops": {},
"op_history": [],
"locks": {},
"utilization": {
"ops_completed": 0,
"ops_failed": 0,
"total_execution_time": 0,
"last_operation_time": None
}
},
"warp_scheduler_state": {
"active": True,
"current_warp": 0,
"active_warps": [],
"completed_warps": [],
"warp_dependencies": {},
"warp_priorities": {},
"blocked_warps": {},
"warp_sync_points": {}
},
"matrix_operations": {
"active_locks": {},
"operation_history": [],
"resource_usage": {
"shared_memory_usage": {},
"register_allocation": {}
}
            },
            "storage_state": {
                "last_sync": None,
                "cache_hits": 0,
                "cache_misses": 0
            }
        }
self.store_sm_state()
    def tensor_core_matmul(self, A: np.ndarray, B: np.ndarray, tensor_core_id: int = 0,
                           warp_id: Optional[str] = None) -> Optional[np.ndarray]:
        """Execute matrix multiplication on a tensor core"""
        op_id = f"tensor_op_{time.time_ns()}"
        with self.state_lock:
            # Check tensor core availability
            if tensor_core_id >= self.sm_state["tensor_cores"]["count"]:
                logging.error(f"Invalid tensor core ID: {tensor_core_id}")
                return None
            try:
                # Record the operation as running
                self.sm_state["tensor_cores"]["current_ops"][op_id] = {
                    "type": "matmul",
                    "tensor_core_id": tensor_core_id,
                    "warp_id": warp_id,
                    "status": "running",
                    "start_time": time.time()
                }
                self.store_sm_state()
                # Execute matrix multiplication
                result = np.matmul(A, B)
                # Archive the completed operation
                op = self.sm_state["tensor_cores"]["current_ops"].pop(op_id)
                op["status"] = "completed"
                op["end_time"] = time.time()
                self.sm_state["tensor_cores"]["op_history"].append(op)
                self.sm_state["tensor_cores"]["utilization"]["ops_completed"] += 1
                self.store_sm_state()
                return result
            except Exception as e:
                logging.error(f"Tensor core matmul failed: {str(e)}")
                op = self.sm_state["tensor_cores"]["current_ops"].pop(op_id, None)
                if op is not None:
                    op["status"] = "failed"
                    op["error"] = str(e)
                    self.sm_state["tensor_cores"]["op_history"].append(op)
                    self.sm_state["tensor_cores"]["utilization"]["ops_failed"] += 1
                self.store_sm_state()
                return None
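    # Note: shared memory is modeled two ways in sm_state["shared_memory"]:
    # the matrix helpers below key individual elements by the flat address
    # string f"{addr + i * m + j}", while allocate_shared_memory() stores
    # block metadata under generated shared_id keys. The key formats do not
    # collide, so the two views coexist in the same dict.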
def read_matrix_from_shared_memory(self, addr: int, n: int, m: int) -> np.ndarray:
"""Read a matrix from shared memory"""
matrix = np.zeros((n, m))
for i in range(n):
for j in range(m):
key = f"{addr + i * m + j}"
matrix[i, j] = self.sm_state["shared_memory"].get(key, 0.0)
return matrix
def write_matrix_to_shared_memory(self, addr: int, matrix: np.ndarray) -> None:
"""Write a matrix to shared memory"""
n, m = matrix.shape
for i in range(n):
for j in range(m):
key = f"{addr + i * m + j}"
self.sm_state["shared_memory"][key] = float(matrix[i, j])
self.store_sm_state()
def tensor_core_matmul_from_memory(self, addr_A: int, shape_A: tuple,
addr_B: int, shape_B: tuple,
addr_C: int, tensor_core_id: int = 0,
warp_id: Optional[str] = None) -> bool:
"""Execute matrix multiplication using data from shared memory with enhanced tracking"""
try:
# Schedule the operation
op_metadata = self.matrix_op_scheduler.schedule_operation(
op_type="matmul",
input_shapes=[shape_A, shape_B],
warp_id=warp_id
)
if op_metadata is None:
logging.error("Failed to schedule matrix operation - resources unavailable")
return False
try:
# Read input matrices
A = self.read_matrix_from_shared_memory(addr_A, *shape_A)
B = self.read_matrix_from_shared_memory(addr_B, *shape_B)
# Acquire matrix operation lock
if not self.matrix_op_lock.acquire_matrix_op(op_metadata.op_id, {
"type": "matmul",
"input_shapes": [shape_A, shape_B],
"warp_id": warp_id,
"tensor_core_id": tensor_core_id
}):
raise RuntimeError("Failed to acquire matrix operation lock")
try:
# Perform multiplication with tensor core
C = self.tensor_core_matmul(A, B, tensor_core_id, warp_id)
if C is None:
raise RuntimeError("Matrix multiplication failed")
# Write result
self.write_matrix_to_shared_memory(addr_C, C)
# Complete operation successfully
self.matrix_op_scheduler.complete_operation(
op_metadata,
output_shape=C.shape,
success=True
)
# Update operation history
self.tensor_op_history.append({
"op_id": op_metadata.op_id,
"type": "matmul",
"input_shapes": [shape_A, shape_B],
"output_shape": C.shape,
"warp_id": warp_id,
"tensor_core_id": tensor_core_id,
"start_time": op_metadata.start_time,
"end_time": time.time_ns(),
"status": "completed"
})
return True
finally:
# Always release the matrix operation lock
self.matrix_op_lock.release_matrix_op(op_metadata.op_id)
except Exception as e:
# Handle operation failure
self.matrix_op_scheduler.complete_operation(
op_metadata,
output_shape=None,
success=False,
error=str(e)
)
raise
except Exception as e:
logging.error(f"Tensor core matmul from memory failed: {str(e)}")
return False
def store_sm_state(self):
"""Store SM state in remote storage"""
with self.state_lock:
# Prepare state data with metadata
state_data = {
"sm_state": self.sm_state,
"timestamp": time.time_ns(),
"chip_id": self.chip_id,
"sm_id": self.sm_id,
"sm_key": self.sm_key
}
try:
# Store state in remote storage
success = self.storage.store_state(
component=f"sm_{self.chip_id}_{self.sm_id}",
state_id=self.sm_key,
state_data=state_data
)
if not success:
logging.error(f"Failed to store state for SM {self.sm_id} on chip {self.chip_id}")
return False
# Update last sync time
self.sm_state["storage_state"]["last_sync"] = time.time_ns()
return True
except Exception as e:
logging.error(f"Error storing SM state: {str(e)}")
return False
def allocate_shared_memory(self, size: int, block_id: str) -> str:
"""Allocate shared memory block in remote storage"""
shared_id = f"shared_{self.chip_id}_{self.sm_id}_{block_id}_{time.time_ns()}"
with self.state_lock:
# Create memory block metadata
memory_block = {
"size": size,
"block_id": block_id,
"allocated_at": time.time_ns(),
"sm_key": self.sm_key,
"shared_id": shared_id
}
# Store metadata in SM state and remote storage
self.sm_state["shared_memory"][shared_id] = memory_block
try:
# Store initial empty tensor to reserve the space
empty_tensor = np.zeros(size, dtype=np.float32)
self.storage.store_tensor(shared_id, empty_tensor, {
"sm_key": self.sm_key,
"block_id": block_id,
"allocated_at": time.time_ns(),
"size": size,
"status": "allocated"
})
# Update SM state in storage
self.store_sm_state()
return shared_id
except Exception as e:
# Cleanup on failure
del self.sm_state["shared_memory"][shared_id]
logging.error(f"Failed to allocate shared memory: {str(e)}")
raise RuntimeError(f"Shared memory allocation failed: {str(e)}")
def write_shared_memory(self, shared_id: str, data: np.ndarray):
"""Write to shared memory using remote storage"""
with self.state_lock:
if shared_id not in self.sm_state["shared_memory"]:
raise ValueError(f"Shared memory block {shared_id} not allocated")
try:
# Store data with metadata
success = self.storage.store_tensor(shared_id, data, {
"sm_key": self.sm_key,
"block_id": self.sm_state["shared_memory"][shared_id]["block_id"],
"last_write": time.time_ns(),
"shape": data.shape,
"dtype": str(data.dtype),
"status": "written"
})
if not success:
raise RuntimeError("Failed to store tensor data")
# Update access timestamp and state
self.sm_state["shared_memory"][shared_id]["last_accessed"] = time.time_ns()
self.sm_state["shared_memory"][shared_id]["last_write"] = time.time_ns()
self.store_sm_state()
return True
except Exception as e:
logging.error(f"Error writing to shared memory: {str(e)}")
return False
def read_shared_memory(self, shared_id: str) -> Optional[np.ndarray]:
"""Read from shared memory using remote storage"""
with self.state_lock:
if shared_id not in self.sm_state["shared_memory"]:
raise ValueError(f"Shared memory block {shared_id} not allocated")
try:
# Read from remote storage
result = self.storage.load_tensor(shared_id)
if result is not None:
data, metadata = result
# Update cache hit/miss stats
self.sm_state["storage_state"]["cache_hits"] += 1
return data
else:
self.sm_state["storage_state"]["cache_misses"] += 1
logging.warning(f"Cache miss for shared memory block {shared_id}")
return None
except Exception as e:
logging.error(f"Error reading from shared memory: {str(e)}")
return None
finally:
# Always update access timestamp and state
self.sm_state["shared_memory"][shared_id]["last_accessed"] = time.time_ns()
self.store_sm_state()
def schedule_warp(self, warp_id: str, warp_state: Dict[str, Any]):
"""Schedule a warp for execution with enhanced state tracking and resource management"""
with self.state_lock:
# Generate unique storage key for warp
warp_key = f"warp_{self.chip_id}_{self.sm_id}_{warp_id}_{time.time_ns()}"
try:
# Check resource availability and dependencies
resource_state = self._check_warp_resources(warp_id, warp_state)
if not resource_state['available']:
logging.warning(f"Resources not available for warp {warp_id}: {resource_state['reason']}")
self.sm_state["warp_scheduler_state"]["blocked_warps"][warp_id] = {
"reason": resource_state['reason'],
"blocking_resources": resource_state['blocking_resources'],
"timestamp": time.time_ns()
}
return False
# Check for dependencies
dependencies = warp_state.get('dependencies', [])
if dependencies:
for dep_id in dependencies:
if dep_id not in self.sm_state["warp_scheduler_state"]["completed_warps"]:
self.sm_state["warp_scheduler_state"]["warp_dependencies"][warp_id] = dependencies
logging.info(f"Warp {warp_id} waiting for dependencies: {dependencies}")
return False
# Prepare enhanced warp state with resource tracking
enhanced_warp_state = {
**warp_state,
"warp_key": warp_key,
"scheduled_at": time.time_ns(),
"resources": resource_state['allocated_resources'],
"priority": warp_state.get('priority', 0),
"expected_duration": warp_state.get('expected_duration'),
"matrix_ops": [],
"sync_points": []
}
# Store state in remote storage with resource metadata
success = self.storage.store_state(
component=f"warp_{self.chip_id}_{self.sm_id}",
state_id=warp_key,
state_data={
"warp_id": warp_id,
"warp_state": enhanced_warp_state,
"sm_key": self.sm_key,
"scheduled_at": time.time_ns(),
"status": "scheduled",
"resource_state": resource_state
}
)
if not success:
raise RuntimeError("Failed to store warp state")
# Update scheduler state
self.sm_state["warp_scheduler_state"]["active_warps"].append(warp_id)
self.sm_state["warp_scheduler_state"]["warp_priorities"][warp_id] = enhanced_warp_state["priority"]
# Update active warps with resource tracking
self.sm_state["active_warps"][warp_id] = enhanced_warp_state
# Clear any blocked state
if warp_id in self.sm_state["warp_scheduler_state"]["blocked_warps"]:
del self.sm_state["warp_scheduler_state"]["blocked_warps"][warp_id]
# Update SM state in storage
self.store_sm_state()
logging.info(f"Successfully scheduled warp {warp_id} with priority {enhanced_warp_state['priority']}")
return True
except Exception as e:
logging.error(f"Error scheduling warp {warp_id}: {str(e)}")
# Cleanup on failure
if warp_id in self.sm_state["active_warps"]:
del self.sm_state["active_warps"][warp_id]
if warp_id in self.sm_state["warp_scheduler_state"]["active_warps"]:
self.sm_state["warp_scheduler_state"]["active_warps"].remove(warp_id)
if warp_id in self.sm_state["warp_scheduler_state"]["warp_priorities"]:
del self.sm_state["warp_scheduler_state"]["warp_priorities"][warp_id]
return False
def _check_warp_resources(self, warp_id: str, warp_state: Dict[str, Any]) -> Dict[str, Any]:
"""Check and allocate resources for a warp"""
needed_resources = warp_state.get('resource_requirements', {})
# Check tensor core availability
if 'tensor_cores' in needed_resources:
num_cores_needed = needed_resources['tensor_cores']
available_cores = self.sm_state["tensor_cores"]["count"] - len(self.sm_state["tensor_cores"]["current_ops"])
if available_cores < num_cores_needed:
return {
'available': False,
'reason': 'insufficient_tensor_cores',
'blocking_resources': {'tensor_cores': num_cores_needed - available_cores}
}
# Check shared memory availability
if 'shared_memory' in needed_resources:
memory_needed = needed_resources['shared_memory']
memory_used = sum(self.sm_state["matrix_operations"]["resource_usage"]["shared_memory_usage"].values())
if memory_used + memory_needed > self._get_max_shared_memory():
return {
'available': False,
'reason': 'insufficient_shared_memory',
'blocking_resources': {'shared_memory': memory_needed}
}
# All resources available, allocate them
allocated_resources = {
'tensor_cores': [], # Will be filled when actually used
'shared_memory': 0, # Will be updated when memory is actually allocated
'allocation_time': time.time_ns()
}
return {
'available': True,
'allocated_resources': allocated_resources,
'allocation_id': f"alloc_{warp_id}_{time.time_ns()}"
}
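    def _get_max_shared_memory(self) -> int:
        """Shared-memory capacity of this SM in bytes.

        48 KB is an assumed default (a common per-SM shared-memory size);
        adjust it to match the architecture being simulated.
        """
        return 48 * 1024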
def complete_warp(self, warp_id: str):
"""Mark a warp as completed using remote storage"""
with self.state_lock:
if warp_id in self.sm_state["active_warps"]:
try:
# Get warp state and key
warp_state = self.sm_state["active_warps"][warp_id]
warp_key = warp_state.get("warp_key")
if warp_key:
# Update warp state in storage
success = self.storage.store_state(
component=f"warp_{self.chip_id}_{self.sm_id}",
state_id=warp_key,
state_data={
"warp_id": warp_id,
"warp_state": warp_state,
"sm_key": self.sm_key,
"completed_at": time.time_ns(),
"status": "completed"
}
)
if not success:
logging.error(f"Failed to store completed state for warp {warp_id}")
# Update local state
self.sm_state["warp_scheduler_state"]["active_warps"].remove(warp_id)
self.sm_state["warp_scheduler_state"]["completed_warps"].append(warp_id)
self.sm_state["active_warps"].pop(warp_id)
# Update SM state
self.store_sm_state()
return True
except Exception as e:
logging.error(f"Error completing warp {warp_id}: {str(e)}")
return False
return False
def write_register(self, warp_id: str, reg_id: str, data: np.ndarray):
"""Write to register file using remote storage"""
reg_key = f"reg_{self.chip_id}_{self.sm_id}_{warp_id}_{reg_id}_{time.time_ns()}"
try:
# Store register data with metadata
success = self.storage.store_tensor(reg_key, data, {
"warp_id": warp_id,
"reg_id": reg_id,
"sm_key": self.sm_key,
"chip_id": self.chip_id,
"written_at": time.time_ns(),
"shape": data.shape,
"dtype": str(data.dtype)
})
if success:
# Update register file state
self.sm_state["register_file"][reg_key] = {
"warp_id": warp_id,
"reg_id": reg_id,
"last_accessed": time.time_ns(),
"storage_key": reg_key
}
self.store_sm_state()
return True
return False
except Exception as e:
logging.error(f"Error writing to register {reg_id} for warp {warp_id}: {str(e)}")
return False
def read_register(self, warp_id: str, reg_id: str) -> Optional[np.ndarray]:
"""Read from register file using remote storage"""
        # Find register entries for this warp/register pair; the trailing "_"
        # prevents prefix collisions (e.g. reg "1" matching reg "10")
        reg_keys = [k for k in self.sm_state["register_file"].keys()
                    if k.startswith(f"reg_{self.chip_id}_{self.sm_id}_{warp_id}_{reg_id}_")]
if not reg_keys:
return None
# Get the latest register key
latest_key = max(reg_keys, key=lambda k: self.sm_state["register_file"][k]["last_accessed"])
try:
# Read from storage
result = self.storage.load_tensor(latest_key)
if result is not None:
data, metadata = result
# Update access timestamp
self.sm_state["register_file"][latest_key]["last_accessed"] = time.time_ns()
self.store_sm_state()
return data
return None
except Exception as e:
logging.error(f"Error reading register {reg_id} for warp {warp_id}: {str(e)}")
return None
def get_stats(self) -> Dict[str, Any]:
"""Get SM statistics"""
return {
"sm_id": self.sm_id,
"num_cores": self.num_cores,
"active_warps": len(self.sm_state["active_warps"]),
"shared_memory_blocks": len(self.sm_state["shared_memory"]),
"register_file_entries": len(self.sm_state["register_file"]),
"completed_warps": len(self.sm_state["warp_scheduler_state"]["completed_warps"])
}
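if __name__ == "__main__":
    # Minimal smoke test: a sketch, not part of the simulator proper. The
    # _InMemoryStorage stub below is hypothetical and only mimics the storage
    # calls this module makes (wait_for_connection, store_state, store_tensor,
    # load_tensor); it also assumes matrix_ops.MatrixOpScheduler constructs
    # without a live backend.
    class _InMemoryStorage:
        def __init__(self):
            self.states = {}
            self.tensors = {}

        def wait_for_connection(self, timeout=10):
            return True

        def store_state(self, component, state_id, state_data):
            self.states[(component, state_id)] = state_data
            return True

        def store_tensor(self, key, tensor, metadata=None):
            self.tensors[key] = (np.array(tensor), metadata or {})
            return True

        def load_tensor(self, key):
            return self.tensors.get(key)

    sm = StreamingMultiprocessor(sm_id=0, chip_id=0, storage=_InMemoryStorage())
    # Round-trip a 2x2 matrix through word-addressed shared memory, then
    # multiply it on a simulated tensor core.
    A = np.arange(4.0).reshape(2, 2)
    B = np.eye(2)
    sm.write_matrix_to_shared_memory(0, A)
    A_back = sm.read_matrix_from_shared_memory(0, 2, 2)
    C = sm.tensor_core_matmul(A_back, B)
    logging.info(f"matmul result:\n{C}")
    logging.info(f"SM stats: {sm.get_stats()}")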