"""Matrix Operations Scheduler for GPU Processing."""

import math
import time
from typing import Any, Dict, List, Optional, Tuple

import numpy as np  # noqa: F401  (pre-existing import, kept for compatibility)


class MatrixOpMetadata:
    """Description of a single matrix operation together with its cost estimates.

    Attributes:
        op_type: Operation name. "matmul" and the element-wise ops
            "add", "sub", "mul", "div" are costed; anything else costs 0.
        input_shape: Shape of the operation's input. For "matmul" the first
            two entries are read as (m, k).
        output_shape: Shape of the result. For "matmul" the second entry is
            read as n.
        timestamp: Creation time, from ``time.time()``.
        compute_cycles: Estimated compute cycles; 0 until filled in
            (``MatrixOpScheduler.schedule_op`` sets it).
        memory_accesses: Estimated memory accesses; 0 until filled in
            (``MatrixOpScheduler.schedule_op`` sets it).
    """

    # Operation types costed per element of ``input_shape``.
    _ELEMENTWISE_OPS = frozenset({"add", "sub", "mul", "div"})

    def __init__(self, op_type: str, input_shape: tuple, output_shape: tuple):
        self.op_type = op_type
        self.input_shape = input_shape
        self.output_shape = output_shape
        self.timestamp = time.time()
        self.compute_cycles = 0
        self.memory_accesses = 0

    def _matmul_dims(self) -> Tuple[int, int, int]:
        """Return ``(m, k, n)`` for an (m, k) input producing an (m, n) output.

        Raises:
            IndexError: if either shape has fewer than two entries.
        """
        m, k = self.input_shape[0], self.input_shape[1]  # k is the inner dimension
        n = self.output_shape[1]
        return int(m), int(k), int(n)

    def _input_element_count(self) -> int:
        """Return the number of elements in ``input_shape`` as a plain int.

        ``math.prod`` gives an exact Python int, whereas ``np.prod`` returns
        numpy.int64 (float 1.0 for an empty shape).
        """
        return int(math.prod(self.input_shape))

    def estimate_compute_cycles(self) -> int:
        """Estimate compute cycles from the operation type and shapes.

        Returns:
            m * n * k for "matmul" (one cycle per multiply-add). For the
            element-wise ops, one cycle per input element. 0 for any other
            op type.
        """
        if self.op_type == "matmul":
            m, k, n = self._matmul_dims()
            return m * n * k
        if self.op_type in self._ELEMENTWISE_OPS:
            return self._input_element_count()
        return 0

    def estimate_memory_accesses(self) -> int:
        """Estimate memory accesses from the operation type and shapes.

        Returns:
            For "matmul", m*k + k*n + m*n: each input element is read once
            and each output element is written once. For the element-wise
            ops, two accesses per input element (read input + write output).
            0 for any other op type.
        """
        if self.op_type == "matmul":
            m, k, n = self._matmul_dims()
            return m * k + k * n + m * n
        if self.op_type in self._ELEMENTWISE_OPS:
            return self._input_element_count() * 2
        return 0


class MatrixOpScheduler:
    """FIFO queue of matrix operations with running cost statistics.

    Lifecycle: ``schedule_op`` -> ``get_next_op`` (becomes ``current_op``)
    -> ``complete_current_op`` (moved to ``completed_ops`` and counted in
    ``stats``).
    """

    def __init__(self):
        self.pending_ops: List[MatrixOpMetadata] = []
        self.completed_ops: List[MatrixOpMetadata] = []
        self.current_op: Optional[MatrixOpMetadata] = None
        self.stats: Dict[str, int] = {
            "total_compute_cycles": 0,
            "total_memory_accesses": 0,
            "ops_completed": 0,
        }

    def schedule_op(self, op: MatrixOpMetadata) -> None:
        """Compute ``op``'s cost estimates and append it to the queue.

        Raises:
            IndexError: if ``op`` is a matmul whose shapes have fewer than two
                entries. In that case the op is NOT enqueued.
        """
        # Estimate first, so a failure leaves the queue untouched.
        op.compute_cycles = op.estimate_compute_cycles()
        op.memory_accesses = op.estimate_memory_accesses()
        self.pending_ops.append(op)

    def get_next_op(self) -> Optional[MatrixOpMetadata]:
        """Dequeue the oldest pending op and make it ``current_op``.

        Returns:
            The dequeued op, or None if the queue is empty. When the queue is
            empty, ``current_op`` is left unchanged. Any op already held in
            ``current_op`` is replaced without being recorded, so call
            ``complete_current_op`` first.
        """
        if not self.pending_ops:
            return None
        # Track the op so complete_current_op() can later account for it.
        self.current_op = self.pending_ops.pop(0)
        return self.current_op

    def complete_current_op(self) -> None:
        """Mark the current operation as complete and update stats.

        No-op when there is no current operation.
        """
        op = self.current_op
        if op is None:
            return
        self.completed_ops.append(op)
        self.stats["total_compute_cycles"] += op.compute_cycles
        self.stats["total_memory_accesses"] += op.memory_accesses
        self.stats["ops_completed"] += 1
        self.current_op = None

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot of the stats plus current queue sizes."""
        return {
            **self.stats,
            "pending_ops": len(self.pending_ops),
            "completed_ops": len(self.completed_ops),
        }