| """Utility functions for handling model attributes with DataParallel""" | |
| import torch.nn as nn | |
def get_model_attr(model: nn.Module, attr_name: str):
    """
    Get an attribute from a model, looking through a data-parallel wrapper.

    ``nn.DataParallel`` and ``nn.parallel.DistributedDataParallel`` both keep
    the wrapped model in their ``.module`` attribute. Custom attributes defined
    on the model are therefore not reachable on the wrapper itself.

    Args:
        model: Model, possibly wrapped in DataParallel or
            DistributedDataParallel.
        attr_name: Name of the attribute to read.

    Returns:
        The attribute value taken from the underlying (unwrapped) model.

    Raises:
        AttributeError: If the underlying model has no such attribute.
    """
    # Both wrapper types expose the wrapped model as ``.module``.
    wrappers = (nn.DataParallel, nn.parallel.DistributedDataParallel)
    target = model.module if isinstance(model, wrappers) else model
    return getattr(target, attr_name)
def set_model_attr(model: nn.Module, attr_name: str, value) -> None:
    """
    Set an attribute on a model, looking through a data-parallel wrapper.

    ``nn.DataParallel`` and ``nn.parallel.DistributedDataParallel`` both keep
    the wrapped model in their ``.module`` attribute. Setting the attribute on
    the underlying model keeps it visible to code that reads ``.module``.

    Args:
        model: Model, possibly wrapped in DataParallel or
            DistributedDataParallel.
        attr_name: Name of the attribute to set.
        value: Value to assign.
    """
    # Write to the wrapped model, not the wrapper, so the value is not
    # stranded on the wrapper object.
    wrappers = (nn.DataParallel, nn.parallel.DistributedDataParallel)
    target = model.module if isinstance(model, wrappers) else model
    setattr(target, attr_name, value)
def call_model_method(model: nn.Module, method_name: str, *args, **kwargs):
    """
    Call a method on a model, looking through a data-parallel wrapper.

    ``nn.DataParallel`` and ``nn.parallel.DistributedDataParallel`` both keep
    the wrapped model in their ``.module`` attribute. The method is looked up
    and invoked on that underlying model.

    Args:
        model: Model, possibly wrapped in DataParallel or
            DistributedDataParallel.
        method_name: Name of the method to call.
        *args: Positional arguments forwarded to the method.
        **kwargs: Keyword arguments forwarded to the method.

    Returns:
        Whatever the method returns.

    Raises:
        AttributeError: If the underlying model has no such attribute.
    """
    # Resolve on the wrapped model. Otherwise, methods the wrapper inherits
    # from nn.Module (e.g. extra_repr) would silently run on the wrapper.
    wrappers = (nn.DataParallel, nn.parallel.DistributedDataParallel)
    target = model.module if isinstance(model, wrappers) else model
    return getattr(target, method_name)(*args, **kwargs)