File size: 8,217 Bytes
89280a9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 | import torch
import importlib
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from transformers.models.qwen3 import modeling_qwen3
# Import other models as needed via conditional imports or a mapping
try:
from lxt.efficient import monkey_patch
except ImportError:
monkey_patch = None
print("Warning: lxt package not available. LRP attribution methods will be disabled.")
import gc
from .factory import get_decomposer
class ModelManager:
    """
    Manages model loading, quantization, and patching.

    A manager holds at most one model/tokenizer pair. Calling ``load_model``
    again releases the previous pair before the new one is loaded. The
    ``current_*`` attributes record the arguments of the most recent
    ``load_model`` call so the same configuration can be reloaded later.
    """

    # Model families that have LRP patch modules, tested in this order against
    # the lower-cased model path (first match wins, like an if/elif chain).
    # Each entry: (path keyword, transformers modeling module, lxt patch
    # module, warning template printed when one of the modules cannot be
    # imported while applying LRP patches).
    # NOTE(review): the "olmo" keyword also matches paths of earlier OLMo
    # generations but always selects the olmo3 modules - confirm intended.
    _LRP_FAMILIES = (
        (
            "qwen3",
            "transformers.models.qwen3.modeling_qwen3",
            "lxt.efficient.models.qwen3",
            "Warning: Could not import lxt.efficient.models.qwen3: {}",
        ),
        (
            "olmo",
            "transformers.models.olmo3.modeling_olmo3",
            "lxt.efficient.models.olmo3",
            "Warning: Could not import modeling_olmo3 or lxt module. LRP might fail. Error: {}",
        ),
        (
            "qwen2",
            "transformers.models.qwen2.modeling_qwen2",
            "lxt.efficient.models.qwen2",
            "Warning: Could not import qwen2 or lxt: {}",
        ),
    )

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.model_name = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.decomposer = None

        # Track active configuration for reloading
        self.current_model_path = None
        self.current_dtype = None
        self.current_lrp_rule = None
        self.current_quantization = False
        self.current_revision = None

    def load_model(
        self,
        model_path="Qwen/Qwen3-0.6B",
        quantization_4bit=False,
        dtype="auto",
        revision=None,
        lrp_rule=None,
    ):
        """
        Loads the model and tokenizer, applies monkey patches for LRP if lrp_rule is specified.

        Args:
            model_path: Hub id or local path handed to ``from_pretrained``. Its
                last ``/``-separated component becomes ``model_name``.
            quantization_4bit: When true, load with a 4-bit ``BitsAndBytesConfig``.
            dtype: "float16", "bfloat16" or "float32"; any other value (such as
                the default "auto") lets ``from_pretrained`` choose and uses
                bfloat16 as the 4-bit compute dtype.
            revision: Model revision; "" and "null" are treated as None.
            lrp_rule: None (no LRP), "Attn-LRP", or "CP-LRP" (Conservative
                Propagation). Any non-None value other than "CP-LRP" selects
                the attnLRP patch map.

        Returns:
            The derived model name (last component of ``model_path``).
        """
        if revision in ("", "null"):
            revision = None

        print(f"Loading model from {model_path} with revision={revision} and rule={lrp_rule}...")

        # Store active configuration
        self.current_model_path = model_path
        self.current_dtype = dtype
        self.current_lrp_rule = lrp_rule
        self.current_quantization = quantization_4bit
        self.current_revision = revision

        # Free up memory if reloading
        self._release_model()

        # Strip trailing slashes so "dir/model/" still yields "model" rather
        # than an empty name.
        self.model_name = model_path.rstrip('/').split('/')[-1]

        # Initialize Decomposer
        self.decomposer = get_decomposer(self.model_name)

        # Bring the modeling module into the state matching lrp_rule *before*
        # the model is instantiated, so the model is built from those classes.
        family = self._match_family(model_path)
        if lrp_rule is None:
            self._remove_lrp_patches(family)
        elif monkey_patch is None:
            print("Warning: lxt package not available. Cannot apply LRP patches. Loading model without LRP.")
        else:
            self._apply_lrp_patches(family, model_path, lrp_rule)

        torch_dtype, bnb_dtype = self._resolve_dtypes(dtype)

        # Quantization Config
        quantization_config = None
        if quantization_4bit:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=bnb_dtype,
            )

        # Qwen3 is instantiated from the (possibly reloaded/patched) module
        # attribute; every other model goes through the generic Auto class.
        if "qwen3" in model_path.lower():
            loader = modeling_qwen3.Qwen3ForCausalLM
        else:
            loader = AutoModelForCausalLM

        # Load into locals first: the manager only exposes the new pair once
        # both model and tokenizer loaded and the model is fully prepared.
        model = loader.from_pretrained(
            model_path,
            device_map=self.device,
            torch_dtype=torch_dtype,
            quantization_config=quantization_config,
            revision=revision,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_path, revision=revision)

        # Prepare model for LRP.
        # "model.train()" is often needed for Gradient Checkpointing to work in HF
        # (a preceding eval() call would be overridden immediately, so none is made).
        # NOTE(review): train() mode also activates dropout layers if the model
        # config enables any - confirm that is acceptable for attribution runs.
        model.train()
        model.gradient_checkpointing_enable()

        # Deactivate gradients on parameters
        for param in model.parameters():
            param.requires_grad = False

        self.model = model
        self.tokenizer = tokenizer

        print(f"Model {self.model_name} loaded successfully on {self.device}")
        return self.model_name

    def get_model(self):
        """Return the loaded model, or None when no model is loaded."""
        return self.model

    def get_tokenizer(self):
        """Return the loaded tokenizer, or None when no model is loaded."""
        return self.tokenizer

    def _release_model(self):
        """Drop the current model/tokenizer and hand their memory back."""
        if self.model is None:
            return
        # Rebind to None instead of `del`: the attributes must keep existing so
        # get_model()/get_tokenizer() still work if the following load fails.
        self.model = None
        self.tokenizer = None
        # Collect first so tensors kept alive only by reference cycles are
        # freed before the CUDA caching allocator is asked to release blocks.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @classmethod
    def _match_family(cls, model_path):
        """Return the first ``_LRP_FAMILIES`` entry whose keyword occurs in the path, else None."""
        lower_path = model_path.lower()
        for family in cls._LRP_FAMILIES:
            if family[0] in lower_path:
                return family
        return None

    @staticmethod
    def _resolve_dtypes(dtype):
        """Map the dtype string to (torch_dtype for from_pretrained, 4-bit compute dtype)."""
        if dtype in ("float16", "bfloat16", "float32"):
            resolved = getattr(torch, dtype)
            return resolved, resolved
        # Unrecognised values fall back to automatic selection, with bfloat16
        # as the default 4-bit compute dtype.
        return "auto", torch.bfloat16

    def _apply_lrp_patches(self, family, model_path, lrp_rule):
        """Reload the family's modeling module and monkey patch it for LRP."""
        if family is None:
            print(f"Warning: No LRP patches known for model path '{model_path}'. Loading model without LRP.")
            return

        _, modeling_name, lxt_name, import_warning = family
        target_module = None
        patch_map = None
        try:
            modeling_module = importlib.import_module(modeling_name)
            importlib.reload(modeling_module)  # Reset to original classes to remove previous patches
            target_module = modeling_module
            lxt_module = importlib.import_module(lxt_name)
            importlib.reload(lxt_module)  # Reload to update class references from the new modeling module
            patch_map = lxt_module.cp_LRP if lrp_rule == "CP-LRP" else lxt_module.attnLRP
        except ImportError as e:
            print(import_warning.format(e))

        if target_module is None:
            # The modeling module itself is unavailable; nothing to patch.
            return
        if patch_map:
            monkey_patch(target_module, patch_map=patch_map, verbose=True)
            print(f"Applied LRP patches with rule: {lrp_rule}")
        else:
            # lxt patch module missing: fall back to monkey_patch's default map.
            monkey_patch(target_module, verbose=True)  # Fallback to default
            print("Applied default LRP patches")

    def _remove_lrp_patches(self, family):
        """Reload the family's modeling module so the model uses vanilla (unpatched) classes."""
        if family is not None:
            modeling_name = family[1]
            try:
                importlib.reload(importlib.import_module(modeling_name))
            except ImportError:
                # Best effort: a module that cannot be imported cannot carry
                # patches from an earlier load either.
                pass
            else:
                print(f"Reloaded {modeling_name.rsplit('.', 1)[-1]} to remove LRP patches")
        print("LRP not enabled - model loaded without attribution patches")
|