|
|
from http_storage import LocalStorage
|
|
|
from matrix_ops import MatrixOpScheduler, MatrixOpMetadata
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
from typing import Dict, Any, Optional, List
|
|
|
import time
|
|
|
import threading
|
|
|
import json
|
|
|
import hashlib
|
|
|
import logging
|
|
|
from config import get_db_url
|
|
|
|
|
|
# Module-wide logging setup: INFO level, timestamped and level-tagged messages.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
|
|
|
|
|
|
class MatrixOpLock:
    """Enhanced locking mechanism for matrix operations.

    Tracks one lock plus a metadata record per operation id, and mirrors the
    lock state into the storage backend so other components can observe it.
    """

    def __init__(self, sm_id: int, chip_id: int, storage):
        self.sm_id = sm_id
        self.chip_id = chip_id
        self.storage = storage
        # Guards op_locks / op_metadata against concurrent acquire/release.
        self.lock = threading.Lock()
        # op_id -> threading.Lock; presence of the key marks the op as locked.
        self.op_locks = {}
        # op_id -> metadata dict persisted to the storage backend.
        self.op_metadata = {}

    def acquire_matrix_op(self, op_id: str, matrix_info: Dict[str, Any]) -> bool:
        """Acquire lock for matrix operation with metadata.

        Returns False when the op is already locked, or when persisting the
        lock state fails (local bookkeeping is rolled back in that case).
        """
        with self.lock:
            if op_id in self.op_locks:
                return False

            self.op_locks[op_id] = threading.Lock()
            self.op_metadata[op_id] = {
                **matrix_info,
                'sm_id': self.sm_id,
                'chip_id': self.chip_id,
                'acquired_time': time.time_ns(),
                'status': 'locked'
            }

            try:
                self.storage.store_state(
                    f"matrix_op_{self.chip_id}_{self.sm_id}_{op_id}",
                    'lock_state',
                    self.op_metadata[op_id]
                )
                return True
            except Exception:
                # Roll back so a later acquire can retry cleanly.
                del self.op_locks[op_id]
                del self.op_metadata[op_id]
                return False

    def release_matrix_op(self, op_id: str):
        """Release matrix operation lock.

        Best-effort: storage errors are logged rather than propagated (fix —
        previously they escaped to the caller, unlike acquire which swallows
        them), and local state is always cleaned up.
        """
        with self.lock:
            if op_id in self.op_locks:
                self.op_metadata[op_id]['status'] = 'released'
                self.op_metadata[op_id]['release_time'] = time.time_ns()

                try:
                    self.storage.store_state(
                        f"matrix_op_{self.chip_id}_{self.sm_id}_{op_id}",
                        'lock_state',
                        self.op_metadata[op_id]
                    )
                except Exception as e:
                    logging.error(f"Failed to persist release of {op_id}: {str(e)}")
                finally:
                    del self.op_locks[op_id]
                    del self.op_metadata[op_id]
|
|
|
|
|
|
class StreamingMultiprocessor:
    """Simulated GPU streaming multiprocessor whose state lives in a remote
    storage backend."""

    def __init__(self, sm_id: int, chip_id: int = 0, num_cores: int = 128, storage=None):
        """Connect to storage (with retries) and initialize all SM state.

        Args:
            sm_id: index of this SM on its chip.
            chip_id: index of the owning chip.
            num_cores: number of CUDA-style cores modeled by this SM.
            storage: optional pre-built storage backend; when None a
                LocalStorage is created from the configured DB URL.

        Raises:
            RuntimeError: when the storage backend cannot be reached after
                the configured number of retries.
        """
        self.sm_id = sm_id
        self.chip_id = chip_id
        self.num_cores = num_cores

        # Connect to the storage backend, retrying a few times before giving up.
        max_retries = 3
        retry_delay = 1.0

        for attempt in range(max_retries):
            try:
                self.storage = storage or LocalStorage(db_url=get_db_url())
                if not self.storage.wait_for_connection(timeout=10):
                    raise RuntimeError("Storage connection timeout")
                logging.info(f"SM {sm_id} on chip {chip_id}: Connected to storage backend")
                break
            except Exception as e:
                if attempt == max_retries - 1:
                    raise RuntimeError(f"Failed to initialize storage after {max_retries} attempts: {str(e)}")
                logging.warning(f"Storage initialization attempt {attempt + 1} failed: {str(e)}")
                time.sleep(retry_delay)

        self.matrix_op_scheduler = MatrixOpScheduler(num_sms=1, cores_per_sm=8)
        self.matrix_op_lock = MatrixOpLock(sm_id, chip_id, self.storage)
        # Fix: tensor_core_matmul synchronizes on this attribute but it was
        # never created (only matrix_op_lock, a different object, existed).
        self.matrix_ops_lock = threading.Lock()
        self.current_tensor_ops = {}
        self.tensor_op_history = []
        # Fix: RLock instead of Lock — several methods call store_sm_state()
        # (which re-acquires this lock) while already holding it, which
        # deadlocked on a non-reentrant Lock.
        self.state_lock = threading.RLock()

        # Short stable identifier used as the storage state_id for this SM.
        self.sm_key = hashlib.md5(f"sm_{chip_id}_{sm_id}".encode()).hexdigest()[:16]
        self.sm_state = {
            "sm_id": sm_id,
            "chip_id": chip_id,
            "num_cores": num_cores,
            "active_warps": {},
            "shared_memory": {},
            "register_file": {},
            "l1_cache": {},
            "tensor_cores": {
                "count": 8,
                "active": True,
                "operations": ["matmul", "conv2d", "attention"],
                "current_ops": {},
                "op_history": [],
                "locks": {},
                "utilization": {
                    "ops_completed": 0,
                    "ops_failed": 0,
                    "total_execution_time": 0,
                    "last_operation_time": None
                }
            },
            "warp_scheduler_state": {
                "active": True,
                "current_warp": 0,
                "active_warps": [],
                "completed_warps": [],
                "warp_dependencies": {},
                "warp_priorities": {},
                "blocked_warps": {},
                "warp_sync_points": {}
            },
            "matrix_operations": {
                "active_locks": {},
                "operation_history": [],
                "resource_usage": {
                    "shared_memory_usage": {},
                    "register_allocation": {}
                }
            },
            # Fix: these two sub-dicts are read by tensor_core_matmul,
            # store_sm_state and read_shared_memory but were never created,
            # causing KeyErrors at first use.
            "tensor_operations": {},
            "storage_state": {"cache_hits": 0, "cache_misses": 0, "last_sync": None}
        }
        self.store_sm_state()
|
|
|
|
|
|
    def tensor_core_matmul(self, A: np.ndarray, B: np.ndarray, tensor_core_id: int = 0) -> Optional[np.ndarray]:
        """Execute matrix multiplication on tensor core

        NOTE(review): this definition is shadowed by a later, identical
        redefinition of tensor_core_matmul in this class, so this copy is
        dead code; one of the two should be removed.

        NOTE(review): __init__ assigns self.matrix_op_lock (a MatrixOpLock),
        never self.matrix_ops_lock — the `with` below looks like it would
        raise AttributeError; confirm where matrix_ops_lock is created.

        NOTE(review): sm_state has no "tensor_operations" key in __init__ —
        the indexing below looks like it would raise KeyError; verify.
        """
        op_id = f"tensor_op_{time.time_ns()}"

        with self.matrix_ops_lock:
            # Reject core ids beyond the configured tensor-core count.
            if tensor_core_id >= self.sm_state["tensor_cores"]["count"]:
                logging.error(f"Invalid tensor core ID: {tensor_core_id}")
                return None

            try:
                # Record the op as running and persist before computing.
                self.sm_state["tensor_operations"][op_id] = {
                    "type": "matmul",
                    "tensor_core_id": tensor_core_id,
                    "status": "running",
                    "start_time": time.time()
                }
                self.store_sm_state()

                result = np.matmul(A, B)

                # Mark completion and persist again.
                self.sm_state["tensor_operations"][op_id]["status"] = "completed"
                self.sm_state["tensor_operations"][op_id]["end_time"] = time.time()
                self.store_sm_state()

                return result

            except Exception as e:
                logging.error(f"Tensor core matmul failed: {str(e)}")
                if op_id in self.sm_state["tensor_operations"]:
                    self.sm_state["tensor_operations"][op_id]["status"] = "failed"
                    self.sm_state["tensor_operations"][op_id]["error"] = str(e)
                    self.store_sm_state()
                return None
|
|
|
|
|
|
    def read_matrix_from_shared_memory(self, addr: int, n: int, m: int) -> np.ndarray:
        """Read a matrix from shared memory

        NOTE(review): shadowed by an identical redefinition of this method
        later in the class; this copy is dead code.

        Reads an n-by-m matrix stored row-major under string keys
        str(addr + i*m + j) in sm_state["shared_memory"]; cells that were
        never written read as 0.0.
        """
        matrix = np.zeros((n, m))
        for i in range(n):
            for j in range(m):
                key = f"{addr + i * m + j}"
                matrix[i, j] = self.sm_state["shared_memory"].get(key, 0.0)
        return matrix
|
|
|
|
|
|
    def write_matrix_to_shared_memory(self, addr: int, matrix: np.ndarray) -> None:
        """Write a matrix to shared memory

        NOTE(review): shadowed by an identical redefinition of this method
        later in the class; this copy is dead code.

        Writes each cell row-major under the string key
        str(addr + i*m + j), then persists the updated SM state.
        """
        n, m = matrix.shape
        for i in range(n):
            for j in range(m):
                key = f"{addr + i * m + j}"
                self.sm_state["shared_memory"][key] = float(matrix[i, j])
        self.store_sm_state()
|
|
|
|
|
|
    def tensor_core_matmul_from_memory(self, addr_A: int, shape_A: tuple,
                                       addr_B: int, shape_B: tuple,
                                       addr_C: int, tensor_core_id: int = 0) -> bool:
        """Execute matrix multiplication using data from shared memory

        NOTE(review): shadowed by an extended redefinition of this method
        later in the class (which adds a warp_id parameter and scheduler /
        lock integration); this copy is dead code.

        Reads A and B from shared memory, multiplies them on the given
        tensor core, and writes the product to addr_C. Returns True on
        success, False on any failure.
        """
        try:
            # Load both operands from shared memory.
            A = self.read_matrix_from_shared_memory(addr_A, *shape_A)
            B = self.read_matrix_from_shared_memory(addr_B, *shape_B)

            C = self.tensor_core_matmul(A, B, tensor_core_id)
            if C is None:
                return False

            self.write_matrix_to_shared_memory(addr_C, C)
            return True

        except Exception as e:
            logging.error(f"Tensor core matmul from memory failed: {str(e)}")
            return False
|
|
|
|
|
|
def tensor_core_matmul(self, A: np.ndarray, B: np.ndarray, tensor_core_id: int = 0) -> Optional[np.ndarray]:
|
|
|
"""Execute matrix multiplication on tensor core"""
|
|
|
op_id = f"tensor_op_{time.time_ns()}"
|
|
|
|
|
|
with self.matrix_ops_lock:
|
|
|
|
|
|
if tensor_core_id >= self.sm_state["tensor_cores"]["count"]:
|
|
|
logging.error(f"Invalid tensor core ID: {tensor_core_id}")
|
|
|
return None
|
|
|
|
|
|
try:
|
|
|
|
|
|
self.sm_state["tensor_operations"][op_id] = {
|
|
|
"type": "matmul",
|
|
|
"tensor_core_id": tensor_core_id,
|
|
|
"status": "running",
|
|
|
"start_time": time.time()
|
|
|
}
|
|
|
self.store_sm_state()
|
|
|
|
|
|
|
|
|
result = np.matmul(A, B)
|
|
|
|
|
|
|
|
|
self.sm_state["tensor_operations"][op_id]["status"] = "completed"
|
|
|
self.sm_state["tensor_operations"][op_id]["end_time"] = time.time()
|
|
|
self.store_sm_state()
|
|
|
|
|
|
return result
|
|
|
|
|
|
except Exception as e:
|
|
|
logging.error(f"Tensor core matmul failed: {str(e)}")
|
|
|
if op_id in self.sm_state["tensor_operations"]:
|
|
|
self.sm_state["tensor_operations"][op_id]["status"] = "failed"
|
|
|
self.sm_state["tensor_operations"][op_id]["error"] = str(e)
|
|
|
self.store_sm_state()
|
|
|
return None
|
|
|
|
|
|
def read_matrix_from_shared_memory(self, addr: int, n: int, m: int) -> np.ndarray:
|
|
|
"""Read a matrix from shared memory"""
|
|
|
matrix = np.zeros((n, m))
|
|
|
for i in range(n):
|
|
|
for j in range(m):
|
|
|
key = f"{addr + i * m + j}"
|
|
|
matrix[i, j] = self.sm_state["shared_memory"].get(key, 0.0)
|
|
|
return matrix
|
|
|
|
|
|
def write_matrix_to_shared_memory(self, addr: int, matrix: np.ndarray) -> None:
|
|
|
"""Write a matrix to shared memory"""
|
|
|
n, m = matrix.shape
|
|
|
for i in range(n):
|
|
|
for j in range(m):
|
|
|
key = f"{addr + i * m + j}"
|
|
|
self.sm_state["shared_memory"][key] = float(matrix[i, j])
|
|
|
self.store_sm_state()
|
|
|
|
|
|
def tensor_core_matmul_from_memory(self, addr_A: int, shape_A: tuple,
|
|
|
addr_B: int, shape_B: tuple,
|
|
|
addr_C: int, tensor_core_id: int = 0,
|
|
|
warp_id: Optional[str] = None) -> bool:
|
|
|
"""Execute matrix multiplication using data from shared memory with enhanced tracking"""
|
|
|
try:
|
|
|
|
|
|
op_metadata = self.matrix_op_scheduler.schedule_operation(
|
|
|
op_type="matmul",
|
|
|
input_shapes=[shape_A, shape_B],
|
|
|
warp_id=warp_id
|
|
|
)
|
|
|
|
|
|
if op_metadata is None:
|
|
|
logging.error("Failed to schedule matrix operation - resources unavailable")
|
|
|
return False
|
|
|
|
|
|
try:
|
|
|
|
|
|
A = self.read_matrix_from_shared_memory(addr_A, *shape_A)
|
|
|
B = self.read_matrix_from_shared_memory(addr_B, *shape_B)
|
|
|
|
|
|
|
|
|
if not self.matrix_op_lock.acquire_matrix_op(op_metadata.op_id, {
|
|
|
"type": "matmul",
|
|
|
"input_shapes": [shape_A, shape_B],
|
|
|
"warp_id": warp_id,
|
|
|
"tensor_core_id": tensor_core_id
|
|
|
}):
|
|
|
raise RuntimeError("Failed to acquire matrix operation lock")
|
|
|
|
|
|
try:
|
|
|
|
|
|
C = self.tensor_core_matmul(A, B, tensor_core_id, warp_id)
|
|
|
if C is None:
|
|
|
raise RuntimeError("Matrix multiplication failed")
|
|
|
|
|
|
|
|
|
self.write_matrix_to_shared_memory(addr_C, C)
|
|
|
|
|
|
|
|
|
self.matrix_op_scheduler.complete_operation(
|
|
|
op_metadata,
|
|
|
output_shape=C.shape,
|
|
|
success=True
|
|
|
)
|
|
|
|
|
|
|
|
|
self.tensor_op_history.append({
|
|
|
"op_id": op_metadata.op_id,
|
|
|
"type": "matmul",
|
|
|
"input_shapes": [shape_A, shape_B],
|
|
|
"output_shape": C.shape,
|
|
|
"warp_id": warp_id,
|
|
|
"tensor_core_id": tensor_core_id,
|
|
|
"start_time": op_metadata.start_time,
|
|
|
"end_time": time.time_ns(),
|
|
|
"status": "completed"
|
|
|
})
|
|
|
|
|
|
return True
|
|
|
|
|
|
finally:
|
|
|
|
|
|
self.matrix_op_lock.release_matrix_op(op_metadata.op_id)
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
self.matrix_op_scheduler.complete_operation(
|
|
|
op_metadata,
|
|
|
output_shape=None,
|
|
|
success=False,
|
|
|
error=str(e)
|
|
|
)
|
|
|
raise
|
|
|
|
|
|
except Exception as e:
|
|
|
logging.error(f"Tensor core matmul from memory failed: {str(e)}")
|
|
|
return False
|
|
|
|
|
|
def store_sm_state(self):
|
|
|
"""Store SM state in remote storage"""
|
|
|
with self.state_lock:
|
|
|
|
|
|
state_data = {
|
|
|
"sm_state": self.sm_state,
|
|
|
"timestamp": time.time_ns(),
|
|
|
"chip_id": self.chip_id,
|
|
|
"sm_id": self.sm_id,
|
|
|
"sm_key": self.sm_key
|
|
|
}
|
|
|
|
|
|
try:
|
|
|
|
|
|
success = self.storage.store_state(
|
|
|
component=f"sm_{self.chip_id}_{self.sm_id}",
|
|
|
state_id=self.sm_key,
|
|
|
state_data=state_data
|
|
|
)
|
|
|
|
|
|
if not success:
|
|
|
logging.error(f"Failed to store state for SM {self.sm_id} on chip {self.chip_id}")
|
|
|
return False
|
|
|
|
|
|
|
|
|
self.sm_state["storage_state"]["last_sync"] = time.time_ns()
|
|
|
return True
|
|
|
|
|
|
except Exception as e:
|
|
|
logging.error(f"Error storing SM state: {str(e)}")
|
|
|
return False
|
|
|
|
|
|
def allocate_shared_memory(self, size: int, block_id: str) -> str:
|
|
|
"""Allocate shared memory block in remote storage"""
|
|
|
shared_id = f"shared_{self.chip_id}_{self.sm_id}_{block_id}_{time.time_ns()}"
|
|
|
|
|
|
with self.state_lock:
|
|
|
|
|
|
memory_block = {
|
|
|
"size": size,
|
|
|
"block_id": block_id,
|
|
|
"allocated_at": time.time_ns(),
|
|
|
"sm_key": self.sm_key,
|
|
|
"shared_id": shared_id
|
|
|
}
|
|
|
|
|
|
|
|
|
self.sm_state["shared_memory"][shared_id] = memory_block
|
|
|
|
|
|
try:
|
|
|
|
|
|
empty_tensor = np.zeros(size, dtype=np.float32)
|
|
|
self.storage.store_tensor(shared_id, empty_tensor, {
|
|
|
"sm_key": self.sm_key,
|
|
|
"block_id": block_id,
|
|
|
"allocated_at": time.time_ns(),
|
|
|
"size": size,
|
|
|
"status": "allocated"
|
|
|
})
|
|
|
|
|
|
|
|
|
self.store_sm_state()
|
|
|
return shared_id
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
del self.sm_state["shared_memory"][shared_id]
|
|
|
logging.error(f"Failed to allocate shared memory: {str(e)}")
|
|
|
raise RuntimeError(f"Shared memory allocation failed: {str(e)}")
|
|
|
|
|
|
def write_shared_memory(self, shared_id: str, data: np.ndarray):
|
|
|
"""Write to shared memory using remote storage"""
|
|
|
with self.state_lock:
|
|
|
if shared_id not in self.sm_state["shared_memory"]:
|
|
|
raise ValueError(f"Shared memory block {shared_id} not allocated")
|
|
|
|
|
|
try:
|
|
|
|
|
|
success = self.storage.store_tensor(shared_id, data, {
|
|
|
"sm_key": self.sm_key,
|
|
|
"block_id": self.sm_state["shared_memory"][shared_id]["block_id"],
|
|
|
"last_write": time.time_ns(),
|
|
|
"shape": data.shape,
|
|
|
"dtype": str(data.dtype),
|
|
|
"status": "written"
|
|
|
})
|
|
|
|
|
|
if not success:
|
|
|
raise RuntimeError("Failed to store tensor data")
|
|
|
|
|
|
|
|
|
self.sm_state["shared_memory"][shared_id]["last_accessed"] = time.time_ns()
|
|
|
self.sm_state["shared_memory"][shared_id]["last_write"] = time.time_ns()
|
|
|
self.store_sm_state()
|
|
|
|
|
|
return True
|
|
|
|
|
|
except Exception as e:
|
|
|
logging.error(f"Error writing to shared memory: {str(e)}")
|
|
|
return False
|
|
|
|
|
|
def read_shared_memory(self, shared_id: str) -> Optional[np.ndarray]:
|
|
|
"""Read from shared memory using remote storage"""
|
|
|
with self.state_lock:
|
|
|
if shared_id not in self.sm_state["shared_memory"]:
|
|
|
raise ValueError(f"Shared memory block {shared_id} not allocated")
|
|
|
|
|
|
try:
|
|
|
|
|
|
result = self.storage.load_tensor(shared_id)
|
|
|
|
|
|
if result is not None:
|
|
|
data, metadata = result
|
|
|
|
|
|
self.sm_state["storage_state"]["cache_hits"] += 1
|
|
|
|
|
|
self.sm_state["shared_memory"][shared_id]["last_accessed"] = time.time_ns()
|
|
|
return data
|
|
|
else:
|
|
|
self.sm_state["storage_state"]["cache_misses"] += 1
|
|
|
logging.warning(f"Cache miss for shared memory block {shared_id}")
|
|
|
return None
|
|
|
|
|
|
except Exception as e:
|
|
|
logging.error(f"Error reading from shared memory: {str(e)}")
|
|
|
return None
|
|
|
|
|
|
finally:
|
|
|
|
|
|
self.sm_state["shared_memory"][shared_id]["last_accessed"] = time.time_ns()
|
|
|
self.store_sm_state()
|
|
|
|
|
|
def schedule_warp(self, warp_id: str, warp_state: Dict[str, Any]):
|
|
|
"""Schedule a warp for execution with enhanced state tracking and resource management"""
|
|
|
with self.state_lock:
|
|
|
|
|
|
warp_key = f"warp_{self.chip_id}_{self.sm_id}_{warp_id}_{time.time_ns()}"
|
|
|
|
|
|
try:
|
|
|
|
|
|
resource_state = self._check_warp_resources(warp_id, warp_state)
|
|
|
if not resource_state['available']:
|
|
|
logging.warning(f"Resources not available for warp {warp_id}: {resource_state['reason']}")
|
|
|
self.sm_state["warp_scheduler_state"]["blocked_warps"][warp_id] = {
|
|
|
"reason": resource_state['reason'],
|
|
|
"blocking_resources": resource_state['blocking_resources'],
|
|
|
"timestamp": time.time_ns()
|
|
|
}
|
|
|
return False
|
|
|
|
|
|
|
|
|
dependencies = warp_state.get('dependencies', [])
|
|
|
if dependencies:
|
|
|
for dep_id in dependencies:
|
|
|
if dep_id not in self.sm_state["warp_scheduler_state"]["completed_warps"]:
|
|
|
self.sm_state["warp_scheduler_state"]["warp_dependencies"][warp_id] = dependencies
|
|
|
logging.info(f"Warp {warp_id} waiting for dependencies: {dependencies}")
|
|
|
return False
|
|
|
|
|
|
|
|
|
enhanced_warp_state = {
|
|
|
**warp_state,
|
|
|
"warp_key": warp_key,
|
|
|
"scheduled_at": time.time_ns(),
|
|
|
"resources": resource_state['allocated_resources'],
|
|
|
"priority": warp_state.get('priority', 0),
|
|
|
"expected_duration": warp_state.get('expected_duration'),
|
|
|
"matrix_ops": [],
|
|
|
"sync_points": []
|
|
|
}
|
|
|
|
|
|
|
|
|
success = self.storage.store_state(
|
|
|
component=f"warp_{self.chip_id}_{self.sm_id}",
|
|
|
state_id=warp_key,
|
|
|
state_data={
|
|
|
"warp_id": warp_id,
|
|
|
"warp_state": enhanced_warp_state,
|
|
|
"sm_key": self.sm_key,
|
|
|
"scheduled_at": time.time_ns(),
|
|
|
"status": "scheduled",
|
|
|
"resource_state": resource_state
|
|
|
}
|
|
|
)
|
|
|
|
|
|
if not success:
|
|
|
raise RuntimeError("Failed to store warp state")
|
|
|
|
|
|
|
|
|
self.sm_state["warp_scheduler_state"]["active_warps"].append(warp_id)
|
|
|
self.sm_state["warp_scheduler_state"]["warp_priorities"][warp_id] = enhanced_warp_state["priority"]
|
|
|
|
|
|
|
|
|
self.sm_state["active_warps"][warp_id] = enhanced_warp_state
|
|
|
|
|
|
|
|
|
if warp_id in self.sm_state["warp_scheduler_state"]["blocked_warps"]:
|
|
|
del self.sm_state["warp_scheduler_state"]["blocked_warps"][warp_id]
|
|
|
|
|
|
|
|
|
self.store_sm_state()
|
|
|
logging.info(f"Successfully scheduled warp {warp_id} with priority {enhanced_warp_state['priority']}")
|
|
|
return True
|
|
|
|
|
|
except Exception as e:
|
|
|
logging.error(f"Error scheduling warp {warp_id}: {str(e)}")
|
|
|
|
|
|
if warp_id in self.sm_state["active_warps"]:
|
|
|
del self.sm_state["active_warps"][warp_id]
|
|
|
if warp_id in self.sm_state["warp_scheduler_state"]["active_warps"]:
|
|
|
self.sm_state["warp_scheduler_state"]["active_warps"].remove(warp_id)
|
|
|
if warp_id in self.sm_state["warp_scheduler_state"]["warp_priorities"]:
|
|
|
del self.sm_state["warp_scheduler_state"]["warp_priorities"][warp_id]
|
|
|
return False
|
|
|
|
|
|
def _check_warp_resources(self, warp_id: str, warp_state: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
"""Check and allocate resources for a warp"""
|
|
|
needed_resources = warp_state.get('resource_requirements', {})
|
|
|
|
|
|
|
|
|
if 'tensor_cores' in needed_resources:
|
|
|
num_cores_needed = needed_resources['tensor_cores']
|
|
|
available_cores = self.sm_state["tensor_cores"]["count"] - len(self.sm_state["tensor_cores"]["current_ops"])
|
|
|
if available_cores < num_cores_needed:
|
|
|
return {
|
|
|
'available': False,
|
|
|
'reason': 'insufficient_tensor_cores',
|
|
|
'blocking_resources': {'tensor_cores': num_cores_needed - available_cores}
|
|
|
}
|
|
|
|
|
|
|
|
|
if 'shared_memory' in needed_resources:
|
|
|
memory_needed = needed_resources['shared_memory']
|
|
|
memory_used = sum(self.sm_state["matrix_operations"]["resource_usage"]["shared_memory_usage"].values())
|
|
|
if memory_used + memory_needed > self._get_max_shared_memory():
|
|
|
return {
|
|
|
'available': False,
|
|
|
'reason': 'insufficient_shared_memory',
|
|
|
'blocking_resources': {'shared_memory': memory_needed}
|
|
|
}
|
|
|
|
|
|
|
|
|
allocated_resources = {
|
|
|
'tensor_cores': [],
|
|
|
'shared_memory': 0,
|
|
|
'allocation_time': time.time_ns()
|
|
|
}
|
|
|
|
|
|
return {
|
|
|
'available': True,
|
|
|
'allocated_resources': allocated_resources,
|
|
|
'allocation_id': f"alloc_{warp_id}_{time.time_ns()}"
|
|
|
}
|
|
|
|
|
|
def complete_warp(self, warp_id: str):
|
|
|
"""Mark a warp as completed using remote storage"""
|
|
|
with self.state_lock:
|
|
|
if warp_id in self.sm_state["active_warps"]:
|
|
|
try:
|
|
|
|
|
|
warp_state = self.sm_state["active_warps"][warp_id]
|
|
|
warp_key = warp_state.get("warp_key")
|
|
|
|
|
|
if warp_key:
|
|
|
|
|
|
success = self.storage.store_state(
|
|
|
component=f"warp_{self.chip_id}_{self.sm_id}",
|
|
|
state_id=warp_key,
|
|
|
state_data={
|
|
|
"warp_id": warp_id,
|
|
|
"warp_state": warp_state,
|
|
|
"sm_key": self.sm_key,
|
|
|
"completed_at": time.time_ns(),
|
|
|
"status": "completed"
|
|
|
}
|
|
|
)
|
|
|
|
|
|
if not success:
|
|
|
logging.error(f"Failed to store completed state for warp {warp_id}")
|
|
|
|
|
|
|
|
|
self.sm_state["warp_scheduler_state"]["active_warps"].remove(warp_id)
|
|
|
self.sm_state["warp_scheduler_state"]["completed_warps"].append(warp_id)
|
|
|
self.sm_state["active_warps"].pop(warp_id)
|
|
|
|
|
|
|
|
|
self.store_sm_state()
|
|
|
return True
|
|
|
|
|
|
except Exception as e:
|
|
|
logging.error(f"Error completing warp {warp_id}: {str(e)}")
|
|
|
return False
|
|
|
|
|
|
return False
|
|
|
|
|
|
def write_register(self, warp_id: str, reg_id: str, data: np.ndarray):
|
|
|
"""Write to register file using remote storage"""
|
|
|
reg_key = f"reg_{self.chip_id}_{self.sm_id}_{warp_id}_{reg_id}_{time.time_ns()}"
|
|
|
|
|
|
try:
|
|
|
|
|
|
success = self.storage.store_tensor(reg_key, data, {
|
|
|
"warp_id": warp_id,
|
|
|
"reg_id": reg_id,
|
|
|
"sm_key": self.sm_key,
|
|
|
"chip_id": self.chip_id,
|
|
|
"written_at": time.time_ns(),
|
|
|
"shape": data.shape,
|
|
|
"dtype": str(data.dtype)
|
|
|
})
|
|
|
|
|
|
if success:
|
|
|
|
|
|
self.sm_state["register_file"][reg_key] = {
|
|
|
"warp_id": warp_id,
|
|
|
"reg_id": reg_id,
|
|
|
"last_accessed": time.time_ns(),
|
|
|
"storage_key": reg_key
|
|
|
}
|
|
|
self.store_sm_state()
|
|
|
return True
|
|
|
|
|
|
return False
|
|
|
|
|
|
except Exception as e:
|
|
|
logging.error(f"Error writing to register {reg_id} for warp {warp_id}: {str(e)}")
|
|
|
return False
|
|
|
|
|
|
def read_register(self, warp_id: str, reg_id: str) -> Optional[np.ndarray]:
|
|
|
"""Read from register file using remote storage"""
|
|
|
|
|
|
reg_keys = [k for k in self.sm_state["register_file"].keys()
|
|
|
if k.startswith(f"reg_{self.chip_id}_{self.sm_id}_{warp_id}_{reg_id}")]
|
|
|
|
|
|
if not reg_keys:
|
|
|
return None
|
|
|
|
|
|
|
|
|
latest_key = max(reg_keys, key=lambda k: self.sm_state["register_file"][k]["last_accessed"])
|
|
|
|
|
|
try:
|
|
|
|
|
|
result = self.storage.load_tensor(latest_key)
|
|
|
|
|
|
if result is not None:
|
|
|
data, metadata = result
|
|
|
|
|
|
self.sm_state["register_file"][latest_key]["last_accessed"] = time.time_ns()
|
|
|
self.store_sm_state()
|
|
|
return data
|
|
|
|
|
|
return None
|
|
|
|
|
|
except Exception as e:
|
|
|
logging.error(f"Error reading register {reg_id} for warp {warp_id}: {str(e)}")
|
|
|
return None
|
|
|
|
|
|
def get_stats(self) -> Dict[str, Any]:
|
|
|
"""Get SM statistics"""
|
|
|
return {
|
|
|
"sm_id": self.sm_id,
|
|
|
"num_cores": self.num_cores,
|
|
|
"active_warps": len(self.sm_state["active_warps"]),
|
|
|
"shared_memory_blocks": len(self.sm_state["shared_memory"]),
|
|
|
"register_file_entries": len(self.sm_state["register_file"]),
|
|
|
"completed_warps": len(self.sm_state["warp_scheduler_state"]["completed_warps"])
|
|
|
}
|
|
|
|