"""Base module class for the Helium deep learning framework."""

from typing import Dict, List, Optional, Union, Any, Tuple

from virtual_gpu_driver.src.driver_api import VirtualGPUDriver
from virtual_gpu_driver.src.ai.tensor_types import TensorDescriptor, DType, Device, Layout, Tensor


class HeliumModule:
    """Base class for all neural network modules in Helium.

    A module keeps three name-keyed registries:

    * ``_parameters`` -- tensors added with :meth:`register_parameter` and
      reported by :meth:`parameters`;
    * ``_buffers`` -- persistent tensors added with :meth:`register_buffer`;
      they are moved by :meth:`to` but not reported by :meth:`parameters`;
    * ``_modules`` -- child modules added with :meth:`add_module`; they are
      visited recursively by :meth:`parameters`, :meth:`to` and :meth:`train`.

    Parameter and child-module names must not contain ``"."`` because
    :meth:`parameters` joins child and parameter names with a dot.
    """

    def __init__(self):
        # Training/evaluation flag; toggled by train() and eval().
        self.training: bool = True
        # Last device passed to to(); None until to() has been called.
        self.device: Optional[Device] = None
        self._parameters: Dict[str, Tensor] = {}
        self._buffers: Dict[str, Tensor] = {}
        self._modules: Dict[str, "HeliumModule"] = {}

    def parameters(self) -> Dict[str, Tensor]:
        """Return all parameters of this module and its descendants.

        Returns:
            A new dict mapping qualified names to tensors. This module's own
            parameters keep their registered names. A child's parameters are
            prefixed with the child's name and a dot (e.g. ``"child.weight"``),
            recursively. Own parameters come first, then each child's, in
            registration order.
        """
        params: Dict[str, Tensor] = dict(self._parameters)
        for module_name, module in self._modules.items():
            for param_name, param in module.parameters().items():
                params[f"{module_name}.{param_name}"] = param
        return params

    @staticmethod
    def _move_tensor(tensor: Tensor, device: Device) -> Tensor:
        """Call ``tensor.to(device)`` and return the tensor object to keep.

        NOTE(review): the contract of ``Tensor.to`` is defined outside this
        file. Ignoring its return value leaves the tensor behind if ``to`` is
        out-of-place (as ``torch.Tensor.to`` is). Keep the result when it is a
        tensor of the same type. Otherwise (an in-place move returning None or
        some status value) keep the original object.
        """
        moved = tensor.to(device)
        return moved if isinstance(moved, type(tensor)) else tensor

    def to(self, device: Device) -> "HeliumModule":
        """Move this module's parameters, buffers and children to ``device``.

        Args:
            device: Target device; it is also recorded in ``self.device``.

        Returns:
            ``self``, to allow chaining.
        """
        self.device = device
        # Iterate over snapshots so the registries can be updated in the loop.
        for name, param in list(self._parameters.items()):
            self._parameters[name] = self._move_tensor(param, device)
        for name, buffer in list(self._buffers.items()):
            self._buffers[name] = self._move_tensor(buffer, device)
        for module in self._modules.values():
            module.to(device)
        return self

    def train(self, mode: bool = True) -> "HeliumModule":
        """Set the training flag on this module and, recursively, its children.

        Args:
            mode: True for training mode, False for evaluation mode.

        Returns:
            ``self``, to allow chaining.
        """
        self.training = mode
        for module in self._modules.values():
            module.train(mode)
        return self

    def eval(self) -> "HeliumModule":
        """Switch to evaluation mode; equivalent to ``train(False)``."""
        return self.train(False)

    @staticmethod
    def _check_name(name: str, kind: str) -> None:
        """Raise KeyError if ``name`` contains a dot.

        parameters() builds qualified names as ``f"{child}.{param}"``. A
        dotted local name could equal a child's qualified name, and one entry
        would then silently overwrite the other.
        """
        # Non-string names were accepted before; only string names are checked.
        if isinstance(name, str) and "." in name:
            raise KeyError(f"{kind} name {name!r} must not contain '.'")

    def register_parameter(self, name: str, param: Tensor) -> None:
        """Register ``param`` as a parameter of this module under ``name``.

        Raises:
            KeyError: If ``name`` contains ``"."`` or is already registered.
        """
        self._check_name(name, "Parameter")
        if name in self._parameters:
            raise KeyError(f"Parameter {name} already registered")
        self._parameters[name] = param

    def register_buffer(self, name: str, buffer: Tensor) -> None:
        """Register a persistent buffer under ``name``.

        Buffers are moved by to() but are not returned by parameters().

        Raises:
            KeyError: If ``name`` is already registered as a buffer.
        """
        if name in self._buffers:
            raise KeyError(f"Buffer {name} already registered")
        self._buffers[name] = buffer

    def add_module(self, name: str, module: "HeliumModule") -> None:
        """Add ``module`` as a child of this module under ``name``.

        Raises:
            KeyError: If ``name`` contains ``"."`` or is already in use.
        """
        self._check_name(name, "Module")
        if name in self._modules:
            raise KeyError(f"Module {name} already added")
        self._modules[name] = module