File size: 1,439 Bytes
8e263cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
"""Utility functions for handling model attributes with DataParallel"""

import torch.nn as nn

def get_model_attr(model, attr_name):
    """
    Read an attribute from a model, looking through a DataParallel wrapper.

    When ``model`` is an ``nn.DataParallel`` instance the lookup is done on
    the wrapped ``model.module``; otherwise it is done on ``model`` itself.

    Args:
        model: Model (possibly wrapped in DataParallel)
        attr_name: Name of the attribute to read

    Returns:
        The value of the attribute

    Raises:
        AttributeError: If the resolved object has no such attribute
    """
    # DataParallel keeps the user's model under ``.module``.
    target = model.module if isinstance(model, nn.DataParallel) else model
    return getattr(target, attr_name)

def set_model_attr(model, attr_name, value):
    """
    Assign an attribute on a model, looking through a DataParallel wrapper.

    When ``model`` is an ``nn.DataParallel`` instance the attribute is set on
    the wrapped ``model.module``; otherwise it is set on ``model`` itself.

    Args:
        model: Model (possibly wrapped in DataParallel)
        attr_name: Name of the attribute to assign
        value: Value to store under that name
    """
    recipient = model
    if isinstance(model, nn.DataParallel):
        # Write to the wrapped model, not to the wrapper.
        recipient = model.module
    setattr(recipient, attr_name, value)

def call_model_method(model, method_name, *args, **kwargs):
    """
    Invoke a method on a model, looking through a DataParallel wrapper.

    When ``model`` is an ``nn.DataParallel`` instance the method is looked up
    on the wrapped ``model.module``; otherwise it is looked up on ``model``.

    Args:
        model: Model (possibly wrapped in DataParallel)
        method_name: Name of the method to invoke
        *args, **kwargs: Forwarded unchanged to the method

    Returns:
        Whatever the invoked method returns

    Raises:
        AttributeError: If the resolved object has no such method
    """
    owner = model
    if isinstance(model, nn.DataParallel):
        # Resolve the method on the wrapped model, not on the wrapper.
        owner = model.module
    method = getattr(owner, method_name)
    return method(*args, **kwargs)