"""Utility functions for handling model attributes with DataParallel.

The helpers reach through ``nn.DataParallel`` and
``nn.parallel.DistributedDataParallel`` wrappers (both keep the real model in
their ``.module`` attribute). Callers can therefore read, write and invoke
attributes of the underlying model without caring whether it was wrapped.
"""

import torch.nn as nn

# Wrapper types that hold the user's model in ``.module``. Attributes set on
# the wrapper itself are not visible to the inner model, so these must be
# stripped before any getattr/setattr.
_PARALLEL_WRAPPERS = (nn.DataParallel, nn.parallel.DistributedDataParallel)


def _unwrap_model(model):
    """Return the innermost model, stripping any parallel wrappers.

    Args:
        model: Model, possibly wrapped (even repeatedly) in DataParallel /
            DistributedDataParallel.

    Returns:
        The first object in the ``.module`` chain that is not one of the
        wrapper types; ``model`` itself if it is not wrapped.
    """
    # Loop rather than a single check so nested wrappers are handled too.
    while isinstance(model, _PARALLEL_WRAPPERS):
        model = model.module
    return model


def get_model_attr(model, attr_name):
    """
    Get attribute from model, handling DataParallel/DistributedDataParallel wrappers

    Args:
        model: Model (possibly wrapped in DataParallel or DistributedDataParallel)
        attr_name: Attribute name to get

    Returns:
        Attribute value read from the unwrapped model

    Raises:
        AttributeError: If the unwrapped model has no such attribute.
    """
    return getattr(_unwrap_model(model), attr_name)


def set_model_attr(model, attr_name, value):
    """
    Set attribute on model, handling DataParallel/DistributedDataParallel wrappers

    The value is written to the unwrapped model, never to the wrapper, so it
    is visible to code that holds a reference to the inner model.

    Args:
        model: Model (possibly wrapped in DataParallel or DistributedDataParallel)
        attr_name: Attribute name to set
        value: Value to set
    """
    setattr(_unwrap_model(model), attr_name, value)


def call_model_method(model, method_name, *args, **kwargs):
    """
    Call method on model, handling DataParallel/DistributedDataParallel wrappers

    Note that this invokes the method on the unwrapped model directly, so the
    wrapper's own logic (e.g. its ``forward`` scatter/sync) is bypassed.

    Args:
        model: Model (possibly wrapped in DataParallel or DistributedDataParallel)
        method_name: Method name to call
        *args, **kwargs: Arguments to pass to method

    Returns:
        Method return value

    Raises:
        AttributeError: If the unwrapped model has no such attribute.
    """
    return getattr(_unwrap_model(model), method_name)(*args, **kwargs)