|
|
"""
|
|
|
Command processor for handling GPU commands including thread management.
|
|
|
"""
|
|
|
from typing import Dict, Any, List
|
|
|
from threading import Lock
|
|
|
import time
|
|
|
|
|
|
class CommandProcessor:
    """Queue GPU commands and dispatch them to a hardware abstraction layer.

    Commands are buffered with :meth:`add_command` and executed in FIFO order
    by :meth:`submit_commands`. Each command yields one result: a status dict
    for most command types, or whatever ``hal.matmul`` returns for ``matmul``.
    """

    def __init__(self, hal, memory_manager):
        """Create a processor bound to ``hal`` and ``memory_manager``.

        Args:
            hal: Object exposing ``block_barrier``, ``core_barrier``,
                ``global_barrier`` and ``matmul`` (the methods called below).
            memory_manager: Stored on the instance; not referenced by any
                method of this class shown here.
        """
        self.hal = hal
        self.memory_manager = memory_manager
        # Pending commands as {"type", "params", "timestamp"} dicts.
        # Guarded by queue_lock.
        self.command_queue: List[Dict[str, Any]] = []
        self.queue_lock = Lock()

    def add_command(self, command_type: str, **kwargs):
        """Append a command to the queue.

        Args:
            command_type: Command name, e.g. ``"execute_kernel"``,
                ``"block_barrier"``, ``"core_barrier"``, ``"matmul"`` or
                ``"global_barrier"``. Unknown names are accepted here and
                reported as errors by :meth:`submit_commands`.
            **kwargs: Command parameters, passed to the handler as a dict.
        """
        with self.queue_lock:
            self.command_queue.append({
                "type": command_type,
                "params": kwargs,
                "timestamp": time.time_ns(),
            })

    def clear_commands(self):
        """Discard all pending commands without executing them."""
        with self.queue_lock:
            self.command_queue.clear()

    def submit_commands(self, chip_id: int = 0) -> List[Any]:
        """Execute every queued command in FIFO order and empty the queue.

        The queue is snapshotted and cleared while holding ``queue_lock``.
        The commands are then executed *without* the lock. ``queue_lock`` is
        a non-reentrant ``threading.Lock``, so holding it while running user
        kernels would deadlock any kernel that calls :meth:`add_command` or
        :meth:`clear_commands`. It would also block every other producer for
        the whole batch. Commands enqueued while a batch runs are left in the
        queue for the next call.

        Args:
            chip_id: Currently unused; each command carries its own
                ``chip_id`` in its params. Kept for API compatibility.

        Returns:
            One result per command, in submission order. An unknown command
            type yields ``{"status": "error", "message": ...}``.
        """
        with self.queue_lock:
            # Copy, then clear in place, so the queue list object is not
            # replaced.
            pending = list(self.command_queue)
            self.command_queue.clear()

        handlers = {
            "execute_kernel": self._execute_kernel_command,
            "block_barrier": self._handle_block_barrier,
            "core_barrier": self._handle_core_barrier,
            "matmul": self._handle_matmul,
            "global_barrier": self._handle_global_barrier,
        }

        results = []
        for cmd in pending:
            handler = handlers.get(cmd["type"])
            if handler is None:
                result = {"status": "error", "message": f"Unknown command type: {cmd['type']}"}
            else:
                result = handler(cmd["params"])
            results.append(result)
        return results

    def _execute_kernel_command(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Run ``kernel_func`` once per thread over a simulated launch grid.

        Threads run sequentially, block by block, in the calling thread.
        Each thread call is isolated: an exception raised by one thread is
        recorded in that thread's entry and the remaining threads still run.

        Args:
            params: Must contain ``chip_id``, ``sm_id``, ``core_id``,
                ``thread_block_config`` and ``kernel_func``.
                ``thread_block_config`` must provide ``grid_dim`` and
                ``block_dim`` (only their first three entries are multiplied)
                and ``shared_memory_size``. Optional ``args`` (sequence) and
                ``kwargs`` (mapping) are forwarded to every kernel call.

        Returns:
            On success, a dict with keys ``status``, ``blocks_executed``,
            ``total_threads`` and ``results``. ``results`` holds one list of
            per-thread entries per block. Any error outside the per-thread
            calls (e.g. a missing key) yields
            ``{'status': 'error', 'message': ...}``.
        """
        try:
            # chip/sm/core must be present (a missing key is reported as an
            # error below) but are not otherwise used by this simulation.
            chip_id = params["chip_id"]
            sm_id = params["sm_id"]
            core_id = params["core_id"]
            thread_config = params["thread_block_config"]
            kernel_func = params["kernel_func"]
            args = params.get("args", [])
            kwargs = params.get("kwargs", {})

            grid_dim = thread_config['grid_dim']
            block_dim = thread_config['block_dim']
            blocks_per_grid = grid_dim[0] * grid_dim[1] * grid_dim[2]
            threads_per_block = block_dim[0] * block_dim[1] * block_dim[2]

            blocks = [
                {
                    'id': block_idx,
                    'threads': threads_per_block,
                    'shared_memory_size': thread_config['shared_memory_size'],
                    'results': [],
                }
                for block_idx in range(blocks_per_grid)
            ]

            for block in blocks:
                for thread_idx in range(block['threads']):
                    # Linear thread index across the whole grid.
                    thread_id = block['id'] * block['threads'] + thread_idx
                    try:
                        # Python binds *args positionally before any keyword
                        # arguments, so the extra args fill the kernel's
                        # leading parameters. thread_id/block_id are always
                        # passed by keyword.
                        result = kernel_func(
                            *args,
                            thread_id=thread_id,
                            block_id=block['id'],
                            **kwargs
                        )
                        block['results'].append({
                            'thread_id': thread_id,
                            'result': result,
                            'status': 'success',
                        })
                    except Exception as e:
                        block['results'].append({
                            'thread_id': thread_id,
                            'error': str(e),
                            'status': 'error',
                        })

            return {
                'status': 'success',
                'blocks_executed': len(blocks),
                'total_threads': blocks_per_grid * threads_per_block,
                'results': [b['results'] for b in blocks],
            }

        except Exception as e:
            return {
                'status': 'error',
                'message': f'Kernel execution failed: {str(e)}',
            }

    def _handle_block_barrier(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Call ``hal.block_barrier(chip_id, sm_id, core_id, block_id)``.

        Returns a success dict naming the block, or an error dict if a key
        is missing or the HAL call raises.
        """
        try:
            chip_id = params["chip_id"]
            sm_id = params["sm_id"]
            core_id = params["core_id"]
            block_id = params["block_id"]

            self.hal.block_barrier(chip_id, sm_id, core_id, block_id)

            return {
                'status': 'success',
                'message': f'Block barrier completed for block {block_id}',
            }
        except Exception as e:
            return {
                'status': 'error',
                'message': f'Block barrier failed: {str(e)}',
            }

    def _handle_core_barrier(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Call ``hal.core_barrier(chip_id, sm_id, core_id)``.

        Returns a success dict naming the core, or an error dict if a key
        is missing or the HAL call raises.
        """
        try:
            chip_id = params["chip_id"]
            sm_id = params["sm_id"]
            core_id = params["core_id"]

            self.hal.core_barrier(chip_id, sm_id, core_id)

            return {
                'status': 'success',
                'message': f'Core barrier completed for core {core_id}',
            }
        except Exception as e:
            return {
                'status': 'error',
                'message': f'Core barrier failed: {str(e)}',
            }

    def _handle_matmul(self, params: Dict[str, Any]) -> Any:
        """Call ``hal.matmul(chip_id, sm_id, A, B)`` and return its result.

        The HAL's return value is passed through unchanged; only failures
        (a missing key or an exception from the HAL) are wrapped in an
        error dict.
        """
        try:
            return self.hal.matmul(
                params["chip_id"],
                params["sm_id"],
                params["A"],
                params["B"],
            )
        except Exception as e:
            return {
                'status': 'error',
                'message': f'Matrix multiplication failed: {str(e)}',
            }

    def _handle_global_barrier(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Call ``hal.global_barrier(chip_id)``.

        Returns a success dict naming the chip, or an error dict if
        ``chip_id`` is missing or the HAL call raises.
        """
        try:
            chip_id = params["chip_id"]

            self.hal.global_barrier(chip_id)

            return {
                'status': 'success',
                'message': f'Global barrier completed for chip {chip_id}',
            }
        except Exception as e:
            return {
                'status': 'error',
                'message': f'Global barrier failed: {str(e)}',
            }
|
|
|
|