File size: 8,217 Bytes
89280a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
import torch
import importlib
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from transformers.models.qwen3 import modeling_qwen3
# Import other models as needed via conditional imports or a mapping
try:
    from lxt.efficient import monkey_patch
except ImportError:
    monkey_patch = None
    print("Warning: lxt package not available. LRP attribution methods will be disabled.")
import gc
from .factory import get_decomposer

class ModelManager:
    """
    Manages model loading, quantization, and patching.

    Holds at most one model/tokenizer pair at a time; ``load_model`` releases
    the previous pair before loading a new one. When an LRP rule is requested,
    the matching ``transformers`` modeling module is reloaded and patched by
    ``lxt`` *before* the model is instantiated, so the model is built from the
    patched classes.
    """

    # LRP-capable model families, matched in order against the lower-cased
    # model path (first match wins):
    #   (path substring, transformers modeling module, lxt patch-map module)
    # Add an entry here to support another model family covered by lxt.
    _LRP_FAMILIES = (
        ("qwen3", "transformers.models.qwen3.modeling_qwen3", "lxt.efficient.models.qwen3"),
        ("olmo", "transformers.models.olmo3.modeling_olmo3", "lxt.efficient.models.olmo3"),
        ("qwen2", "transformers.models.qwen2.modeling_qwen2", "lxt.efficient.models.qwen2"),
    )

    # dtype names accepted by ``load_model``; any other string means "auto".
    _TORCH_DTYPES = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }

    # Revision values treated as "no explicit revision".
    _EMPTY_REVISIONS = ("", "null")

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.model_name = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.decomposer = None

        # Configuration of the most recent successful load, kept for reloading.
        self.current_model_path = None
        self.current_dtype = None
        self.current_lrp_rule = None
        self.current_quantization = False
        self.current_revision = None

    def load_model(self, model_path="Qwen/Qwen3-0.6B", quantization_4bit=False, dtype="auto", revision=None, lrp_rule=None):
        """
        Load the model and tokenizer, applying LRP monkey patches if requested.

        Args:
            model_path: Hugging Face hub id or local directory.
            quantization_4bit: Load weights in 4-bit via bitsandbytes.
            dtype: "auto", "float16", "bfloat16", "float32", or a ``torch.dtype``.
                Unrecognised strings fall back to "auto".
            revision: Hub revision; None, "" and "null" all mean "default".
            lrp_rule: None (no LRP), "Attn-LRP", or "CP-LRP" (Conservative
                Propagation). Any non-None value other than "CP-LRP" selects
                the Attn-LRP patch map.

        Returns:
            The short model name (last component of ``model_path``).

        Raises:
            Exceptions from ``get_decomposer`` or ``from_pretrained`` propagate
            unchanged. By then the previous model has already been released,
            so ``get_model()`` / ``get_tokenizer()`` return None.
        """
        revision = self._normalize_revision(revision)

        print(f"Loading model from {model_path} with revision={revision} and rule={lrp_rule}...")

        # Free up memory if reloading
        self._release_model()

        self.model_name = self._model_name_from_path(model_path)
        self.decomposer = get_decomposer(self.model_name)

        # Patching must happen before instantiation: the model is built from
        # whatever classes the modeling module holds at from_pretrained time.
        if lrp_rule is not None:
            self._apply_lrp_patches(model_path, lrp_rule)
        else:
            self._remove_lrp_patches(model_path)

        torch_dtype, bnb_dtype = self._resolve_dtypes(dtype)

        quantization_config = None
        if quantization_4bit:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=bnb_dtype,
            )

        load_kwargs = dict(
            device_map=self.device,
            torch_dtype=torch_dtype,
            quantization_config=quantization_config,
            revision=revision,
        )

        # Load into locals and publish on self only once everything succeeded,
        # so a failure never leaves a half-initialised model/tokenizer pair.
        # The tokenizer is loaded first because it is cheap and fails fast.
        tokenizer = AutoTokenizer.from_pretrained(model_path, revision=revision)
        if "qwen3" in model_path.lower():
            # Resolve the class through the module attribute so the (possibly
            # just patched or reloaded) Qwen3 class is used.
            model = modeling_qwen3.Qwen3ForCausalLM.from_pretrained(model_path, **load_kwargs)
        else:
            model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)

        self._prepare_for_lrp(model)

        self.model = model
        self.tokenizer = tokenizer

        self.current_model_path = model_path
        self.current_dtype = dtype
        self.current_lrp_rule = lrp_rule
        self.current_quantization = quantization_4bit
        self.current_revision = revision

        print(f"Model {self.model_name} loaded successfully on {self.device}")
        return self.model_name

    def get_model(self):
        return self.model

    def get_tokenizer(self):
        return self.tokenizer

    # ------------------------------------------------------------------ helpers

    @classmethod
    def _normalize_revision(cls, revision):
        """Map the "no revision" placeholders to None; pass others through."""
        return None if revision in cls._EMPTY_REVISIONS else revision

    @staticmethod
    def _model_name_from_path(model_path):
        """Return the last path component, tolerating trailing and back slashes."""
        return model_path.replace("\\", "/").rstrip("/").split("/")[-1]

    @classmethod
    def _resolve_dtypes(cls, dtype):
        """Return ``(torch_dtype, bnb_compute_dtype)`` for the requested dtype."""
        if isinstance(dtype, torch.dtype):
            return dtype, dtype
        explicit = cls._TORCH_DTYPES.get(dtype)
        if explicit is None:
            # "auto" (or unknown): let transformers pick; 4-bit compute in bf16.
            return "auto", torch.bfloat16
        return explicit, explicit

    @classmethod
    def _lrp_family(cls, model_path):
        """Return the first ``_LRP_FAMILIES`` entry matching the path, else None."""
        lower_path = model_path.lower()
        return next((family for family in cls._LRP_FAMILIES if family[0] in lower_path), None)

    def _release_model(self):
        """Drop references to the current model/tokenizer and reclaim memory."""
        if self.model is None and self.tokenizer is None:
            return
        # Rebind instead of `del` so both attributes always exist, even if the
        # following load fails part-way.
        self.model = None
        self.tokenizer = None
        # Collect first so tensors kept alive by reference cycles are actually
        # freed; only then can the CUDA caching allocator release their blocks.
        gc.collect()
        torch.cuda.empty_cache()

    def _apply_lrp_patches(self, model_path, lrp_rule):
        """Reload the model family's modeling module and patch it with lxt."""
        if monkey_patch is None:
            print("Warning: lxt package not available. Cannot apply LRP patches. Loading model without LRP.")
            return

        family = self._lrp_family(model_path)
        if family is None:
            print(f"Warning: no LRP patches registered for {model_path}; loading model without LRP.")
            return
        _, modeling_name, lxt_name = family

        try:
            # Reload to reset to the original classes, so patches from an
            # earlier load (possibly with another rule) do not stack.
            target_module = importlib.reload(importlib.import_module(modeling_name))
        except ImportError as e:
            print(f"Warning: Could not import {modeling_name}. LRP might fail. Error: {e}")
            return

        patch_map = None
        try:
            lxt_module = importlib.import_module(lxt_name)
            # Reload so the patch map's class references point at the freshly
            # reloaded modeling module rather than the stale one.
            importlib.reload(lxt_module)
            patch_map = lxt_module.cp_LRP if lrp_rule == "CP-LRP" else lxt_module.attnLRP
        except ImportError as e:
            print(f"Warning: Could not import {lxt_name}: {e}")

        if patch_map:
            monkey_patch(target_module, patch_map=patch_map, verbose=True)
            print(f"Applied LRP patches with rule: {lrp_rule}")
        else:
            monkey_patch(target_module, verbose=True)  # Fallback to lxt's default patch map
            print("Applied default LRP patches")

    def _remove_lrp_patches(self, model_path):
        """Reload the family's modeling module so vanilla classes are used."""
        family = self._lrp_family(model_path)
        if family is not None:
            modeling_name = family[1]
            try:
                importlib.reload(importlib.import_module(modeling_name))
                print(f"Reloaded {modeling_name.rsplit('.', 1)[-1]} to remove LRP patches")
            except ImportError:
                # Best effort: a module that cannot be imported was never patched.
                pass
        print("LRP not enabled - model loaded without attribution patches")

    @staticmethod
    def _prepare_for_lrp(model):
        """Enable gradient checkpointing and freeze all parameters."""
        # train() is what HF gradient checkpointing keys on (per the lxt recipe).
        # NOTE(review): train() also activates any dropout layers; confirm the
        # loaded configs use zero dropout so attributions stay deterministic.
        model.train()
        model.gradient_checkpointing_enable()
        # Gradients are needed w.r.t. inputs only, not parameters.
        for param in model.parameters():
            param.requires_grad = False