|
|
"""
|
|
|
Matrix Operations Scheduler for GPU Processing
|
|
|
"""
|
|
|
import math
import time
from typing import Dict, Any, List, Optional

import numpy as np
|
|
|
|
|
|
|
|
|
class MatrixOpMetadata:
    """Describe one matrix operation and estimate its compute and memory cost.

    Attributes:
        op_type: Operation name. ``"matmul"`` and the element-wise ops
            ``"add"``, ``"sub"``, ``"mul"`` and ``"div"`` have cost models;
            any other value is estimated as zero cost.
        input_shape: Shape of the input operand (for ``"matmul"``, the left
            operand ``(m, k)``).
        output_shape: Shape of the result (for ``"matmul"``, ``(m, n)``).
        timestamp: Wall-clock creation time taken from ``time.time()``.
        compute_cycles: Stored compute estimate; initialised to 0 and not
            updated by this class itself.
        memory_accesses: Stored memory estimate; initialised to 0 and not
            updated by this class itself.
    """

    # Element-wise ops that share the same per-element cost model.
    _ELEMENTWISE_OPS = frozenset({"add", "sub", "mul", "div"})

    def __init__(self, op_type: str, input_shape: tuple, output_shape: tuple):
        self.op_type = op_type
        self.input_shape = input_shape
        self.output_shape = output_shape
        self.timestamp = time.time()
        self.compute_cycles = 0
        self.memory_accesses = 0

    def _matmul_dims(self) -> tuple:
        """Return ``(m, k, n)`` for ``(m, k) @ (k, n) -> (m, n)``.

        ``m`` and ``k`` are read from ``input_shape``; ``n`` from
        ``output_shape``.

        Raises:
            IndexError: if either shape has fewer than two dimensions.
        """
        m, k = self.input_shape[0], self.input_shape[1]
        n = self.output_shape[1]
        return int(m), int(k), int(n)

    def _num_elements(self) -> int:
        """Return the number of elements in ``input_shape`` as a Python int."""
        # math.prod over Python ints is exact. np.prod uses a fixed-width
        # dtype that can silently wrap, and it returns the float 1.0 for an
        # empty (scalar) shape.
        return math.prod(int(dim) for dim in self.input_shape)

    def estimate_compute_cycles(self) -> int:
        """Estimate number of compute cycles needed based on operation type and shapes.

        Returns:
            ``m * n * k`` for ``"matmul"``, the element count of
            ``input_shape`` for element-wise ops, and 0 for any other
            ``op_type``.

        Raises:
            IndexError: for ``"matmul"`` when a shape is not at least 2-D.
        """
        if self.op_type == "matmul":
            m, k, n = self._matmul_dims()
            return m * n * k
        if self.op_type in self._ELEMENTWISE_OPS:
            return self._num_elements()
        return 0

    def estimate_memory_accesses(self) -> int:
        """Estimate number of memory accesses needed.

        Returns:
            ``m*k + k*n + m*n`` for ``"matmul"``, twice the element count of
            ``input_shape`` for element-wise ops, and 0 for any other
            ``op_type``.

        Raises:
            IndexError: for ``"matmul"`` when a shape is not at least 2-D.
        """
        if self.op_type == "matmul":
            m, k, n = self._matmul_dims()
            # Each of the three matrices (left, right, result) counted once.
            return m * k + k * n + m * n
        if self.op_type in self._ELEMENTWISE_OPS:
            # NOTE(review): counts two accesses per element; a binary op that
            # reads two inputs and writes a separate output touches three.
            # Confirm which model is intended.
            return self._num_elements() * 2
        return 0
|
|
|
|
|
|
|
|
|
class MatrixOpScheduler:
    """FIFO scheduler that tracks pending, current and completed matrix ops.

    Intended flow: ``schedule_op`` enqueues an op and stores its cost
    estimates on it; ``get_next_op`` dequeues the oldest op and makes it
    ``current_op``; ``complete_current_op`` records that op and adds its
    estimates to ``stats``.
    """

    def __init__(self):
        self.pending_ops: List["MatrixOpMetadata"] = []
        self.completed_ops: List["MatrixOpMetadata"] = []
        self.current_op: Optional["MatrixOpMetadata"] = None
        self.stats = {
            "total_compute_cycles": 0,
            "total_memory_accesses": 0,
            "ops_completed": 0,
        }

    def schedule_op(self, op: "MatrixOpMetadata") -> None:
        """Add a new matrix operation to the scheduler queue.

        The op's ``compute_cycles`` and ``memory_accesses`` are filled in from
        its estimators before it is queued.

        Raises:
            Whatever the op's estimators raise; in that case the op is not
            queued.
        """
        # Estimate before enqueueing so a failing estimate never leaves a
        # half-initialised op (with zero costs) sitting in the queue.
        op.compute_cycles = op.estimate_compute_cycles()
        op.memory_accesses = op.estimate_memory_accesses()
        self.pending_ops.append(op)

    def get_next_op(self) -> Optional["MatrixOpMetadata"]:
        """Dequeue the oldest pending operation and make it the current op.

        Returns:
            The dequeued op, or ``None`` if nothing is pending (in which case
            ``current_op`` is left unchanged). If a previous op was current
            and not yet completed, it is replaced and will not be recorded.
        """
        if not self.pending_ops:
            return None
        # NOTE: list.pop(0) is O(len(pending_ops)); fine for short queues,
        # collections.deque would make dequeueing O(1).
        op = self.pending_ops.pop(0)
        # Without this, complete_current_op() could never record anything.
        self.current_op = op
        return op

    def complete_current_op(self) -> None:
        """Mark the current operation as complete and update stats.

        Does nothing when there is no current op.
        """
        op = self.current_op
        if op is None:
            return
        self.completed_ops.append(op)
        self.stats["total_compute_cycles"] += op.compute_cycles
        self.stats["total_memory_accesses"] += op.memory_accesses
        self.stats["ops_completed"] += 1
        self.current_op = None

    def get_stats(self) -> Dict[str, Any]:
        """Get current scheduler statistics.

        Returns:
            A new dict with the accumulated ``stats`` entries plus
            ``"pending_ops"`` and ``"completed_ops"`` counts.
        """
        return {
            **self.stats,
            "pending_ops": len(self.pending_ops),
            "completed_ops": len(self.completed_ops),
        }
|
|
|
|