# INV / matrix_ops.py
# Fred808's picture
# Upload 256 files
# 7a0c684 verified
"""
Matrix Operations Scheduler for GPU Processing
"""
from typing import Dict, Any, List, Optional
import numpy as np
import time
class MatrixOpMetadata:
    """Descriptor for a single matrix operation and its estimated cost.

    Holds the operation type and the input/output shapes, records the
    creation time, and offers shape-based estimates of compute cycles and
    memory accesses. ``compute_cycles`` and ``memory_accesses`` start at 0
    and are meant to be filled in from the ``estimate_*`` methods.
    """

    # Operation types that are costed per input element.
    _ELEMENTWISE_OPS = ("add", "sub", "mul", "div")

    def __init__(self, op_type: str, input_shape: tuple, output_shape: tuple):
        """Store the operation description and stamp the creation time.

        Args:
            op_type: Operation name; "matmul", "add", "sub", "mul" and "div"
                have cost models, anything else is estimated as 0.
            input_shape: Shape of the input operand.
            output_shape: Shape of the result.
        """
        self.op_type = op_type
        self.input_shape = input_shape
        self.output_shape = output_shape
        self.timestamp = time.time()  # wall-clock creation time (seconds)
        self.compute_cycles = 0
        self.memory_accesses = 0

    def _element_count(self) -> int:
        """Return the product of all input dimensions as an exact Python int."""
        # Plain int arithmetic is used on purpose: a fixed-width NumPy
        # product wraps around silently for very large shapes and returns a
        # NumPy scalar instead of the annotated ``int``.
        count = 1
        for dim in self.input_shape:
            count *= int(dim)
        return count

    def _matmul_dims(self) -> tuple:
        """Return ``(m, k, n)`` for a matmul: input is m x k, output is m x n."""
        m = self.input_shape[0]
        n = self.output_shape[1]
        k = self.input_shape[1]  # inner (contracted) dimension
        return int(m), int(k), int(n)

    def estimate_compute_cycles(self) -> int:
        """Estimate the compute cycles needed from the op type and shapes.

        Returns:
            ``m * n * k`` for "matmul" (one cycle per multiply-add), the
            number of input elements for element-wise ops (one cycle per
            element), and 0 for any other operation type.

        Raises:
            IndexError: For "matmul" when a shape has fewer than two entries.
        """
        if self.op_type == "matmul":
            m, k, n = self._matmul_dims()
            return m * n * k
        if self.op_type in self._ELEMENTWISE_OPS:
            return self._element_count()
        return 0

    def estimate_memory_accesses(self) -> int:
        """Estimate the number of memory accesses needed.

        Returns:
            ``m*k + k*n + m*n`` for "matmul" (each input element read once,
            each output element written once), twice the input element count
            for element-wise ops (read input + write output), and 0 for any
            other operation type.

        Raises:
            IndexError: For "matmul" when a shape has fewer than two entries.
        """
        if self.op_type == "matmul":
            m, k, n = self._matmul_dims()
            return m * k + k * n + m * n
        if self.op_type in self._ELEMENTWISE_OPS:
            return self._element_count() * 2
        return 0
class MatrixOpScheduler:
    """First-in, first-out queue of matrix operations with running totals.

    Operations are queued with ``schedule_op``, taken off the front with
    ``get_next_op`` (which also records the op as ``current_op``), and
    folded into the statistics with ``complete_current_op``.
    """

    def __init__(self):
        """Create an empty scheduler with zeroed statistics."""
        self.pending_ops: List["MatrixOpMetadata"] = []
        self.completed_ops: List["MatrixOpMetadata"] = []
        self.current_op: Optional["MatrixOpMetadata"] = None
        self.stats = {
            "total_compute_cycles": 0,
            "total_memory_accesses": 0,
            "ops_completed": 0,
        }

    def schedule_op(self, op: "MatrixOpMetadata") -> None:
        """Fill in the op's cost estimates and append it to the queue.

        The estimates are computed before the op is enqueued, so an op whose
        estimate raises (e.g. a shape that is too short) never ends up in
        ``pending_ops`` with missing costs; the exception propagates.
        """
        cycles = op.estimate_compute_cycles()
        accesses = op.estimate_memory_accesses()
        op.compute_cycles = cycles
        op.memory_accesses = accesses
        self.pending_ops.append(op)

    def get_next_op(self) -> Optional["MatrixOpMetadata"]:
        """Dequeue the oldest pending op and make it the current op.

        Returns:
            The dequeued operation, or ``None`` when nothing is pending. In
            the ``None`` case ``current_op`` is left untouched.
        """
        if not self.pending_ops:
            return None
        # Record the op as in-flight so complete_current_op() can account for
        # it; any previous current op that was not completed is replaced.
        self.current_op = self.pending_ops.pop(0)
        return self.current_op

    def complete_current_op(self) -> None:
        """Move the current op to the completed list and update the totals.

        Does nothing when there is no current op.
        """
        finished = self.current_op
        if not finished:
            return
        self.completed_ops.append(finished)
        self.stats["total_compute_cycles"] += finished.compute_cycles
        self.stats["total_memory_accesses"] += finished.memory_accesses
        self.stats["ops_completed"] += 1
        self.current_op = None

    def get_stats(self) -> Dict[str, Any]:
        """Return a new dict with the running totals plus queue lengths.

        The result contains every key of ``stats`` together with
        ``"pending_ops"`` and ``"completed_ops"`` (the lengths of the
        corresponding lists).
        """
        return {
            **self.stats,
            "pending_ops": len(self.pending_ops),
            "completed_ops": len(self.completed_ops),
        }