| import torch |
| import importlib |
| from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig |
| from transformers.models.qwen3 import modeling_qwen3 |
| |
| try: |
| from lxt.efficient import monkey_patch |
| except ImportError: |
| monkey_patch = None |
| print("Warning: lxt package not available. LRP attribution methods will be disabled.") |
| import gc |
| from .factory import get_decomposer |
|
|
class ModelManager:
    """
    Manages model loading, quantization, and patching.

    An instance holds at most one model/tokenizer pair. Calling ``load_model``
    again drops the previous pair before the new one is loaded.
    """

    # dtype string accepted by load_model -> (torch_dtype passed to
    # from_pretrained, bnb_4bit_compute_dtype used for 4-bit quantization).
    # Anything else resolves to ("auto", torch.bfloat16).
    _DTYPES = {
        "float16": (torch.float16, torch.float16),
        "bfloat16": (torch.bfloat16, torch.bfloat16),
        "float32": (torch.float32, torch.float32),
    }

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.model_name = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.decomposer = None

        # Arguments of the most recent load_model() call. They are recorded
        # before loading starts, so they reflect the request even if the
        # load itself fails.
        self.current_model_path = None
        self.current_dtype = None
        self.current_lrp_rule = None
        self.current_quantization = False
        self.current_revision = None

    def load_model(self, model_path="Qwen/Qwen3-0.6B", quantization_4bit=False, dtype="auto", revision=None, lrp_rule=None):
        """
        Load the model and tokenizer, applying LRP monkey patches if requested.

        Args:
            model_path: Hugging Face hub id or local directory of the model.
            quantization_4bit: If True, load the weights in 4-bit through a
                bitsandbytes ``BitsAndBytesConfig``.
            dtype: "auto", "float16", "bfloat16" or "float32". Any other
                value falls back to "auto" (with a warning).
            revision: Model revision; "" and "null" are treated as None.
            lrp_rule: None (no LRP), "Attn-LRP", or "CP-LRP" (Conservative
                Propagation). Any non-None value other than "CP-LRP" selects
                the Attn-LRP patch map.

        Returns:
            The short model name (last path component of ``model_path``).
        """
        if revision in ("", "null"):
            revision = None

        print(f"Loading model from {model_path} with revision={revision} and rule={lrp_rule}...")

        self.current_model_path = model_path
        self.current_dtype = dtype
        self.current_lrp_rule = lrp_rule
        self.current_quantization = quantization_4bit
        self.current_revision = revision

        self._release_model()

        # Strip trailing slashes so a local directory such as "models/X/"
        # yields "X" rather than an empty name.
        self.model_name = model_path.rstrip("/").split("/")[-1]
        self.decomposer = get_decomposer(self.model_name)

        if lrp_rule is not None:
            self._apply_lrp_patches(model_path, lrp_rule)
        else:
            self._remove_lrp_patches(model_path)
            print("LRP not enabled - model loaded without attribution patches")

        torch_dtype, bnb_dtype = self._resolve_dtypes(dtype)

        quantization_config = None
        if quantization_4bit:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=bnb_dtype,
            )

        if "qwen3" in model_path.lower():
            # Instantiate through the module object that the helpers above
            # reloaded (and possibly patched).
            model_cls = modeling_qwen3.Qwen3ForCausalLM
        else:
            model_cls = AutoModelForCausalLM

        self.model = model_cls.from_pretrained(
            model_path,
            device_map=self.device,
            torch_dtype=torch_dtype,
            quantization_config=quantization_config,
            revision=revision,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, revision=revision)

        # Train mode is deliberate (the original called eval() and then
        # immediately train(), so eval() had no effect). Presumably this is
        # because HF gradient checkpointing only takes effect while
        # module.training is True. TODO confirm.
        # NOTE(review): train mode also activates any dropout layers; confirm
        # the loaded configs use zero dropout so attributions are deterministic.
        self.model.train()
        self.model.gradient_checkpointing_enable()

        # Freeze every weight; presumably only input gradients are needed.
        for param in self.model.parameters():
            param.requires_grad = False

        print(f"Model {self.model_name} loaded successfully on {self.device}")
        return self.model_name

    def _release_model(self):
        """Drop the current model/tokenizer references and release cached CUDA memory."""
        if self.model is None:
            return
        # Rebind instead of `del`, so the attributes still exist (as None)
        # if the subsequent load fails.
        self.model = None
        self.tokenizer = None
        # Collect first, so tensors kept alive only by reference cycles are
        # actually freed before the caching allocator is asked to release
        # unused blocks.
        gc.collect()
        torch.cuda.empty_cache()

    @classmethod
    def _resolve_dtypes(cls, dtype):
        """Return (torch_dtype, bnb_compute_dtype) for a ``dtype`` string."""
        if isinstance(dtype, str) and dtype in cls._DTYPES:
            return cls._DTYPES[dtype]
        if dtype not in ("auto", None):
            print(f"Warning: unknown dtype {dtype!r}; falling back to 'auto'.")
        return "auto", torch.bfloat16

    def _apply_lrp_patches(self, model_path, lrp_rule):
        """Reload the model family's transformers module and monkey-patch it for LRP."""
        if monkey_patch is None:
            print("Warning: lxt package not available. Cannot apply LRP patches. Loading model without LRP.")
            return

        if lrp_rule not in ("Attn-LRP", "CP-LRP"):
            print(f"Warning: unrecognised lrp_rule {lrp_rule!r}; treating it as Attn-LRP.")

        target_module = None
        patch_map = None
        lower_path = model_path.lower()

        # Branch order matters: "qwen3" is tested before "qwen2". In each
        # branch the transformers module is reloaded before the lxt module,
        # so the lxt module is re-executed after the fresh transformers module.
        if "qwen3" in lower_path:
            importlib.reload(modeling_qwen3)
            target_module = modeling_qwen3
            try:
                import lxt.efficient.models.qwen3 as lxt_qwen3
                importlib.reload(lxt_qwen3)
                patch_map = lxt_qwen3.cp_LRP if lrp_rule == "CP-LRP" else lxt_qwen3.attnLRP
            except ImportError as e:
                print(f"Warning: Could not import lxt.efficient.models.qwen3: {e}")
        elif "olmo" in lower_path:
            try:
                from transformers.models.olmo3 import modeling_olmo3
                importlib.reload(modeling_olmo3)
                target_module = modeling_olmo3
                import lxt.efficient.models.olmo3 as lxt_olmo3
                importlib.reload(lxt_olmo3)
                patch_map = lxt_olmo3.cp_LRP if lrp_rule == "CP-LRP" else lxt_olmo3.attnLRP
            except ImportError as e:
                print(f"Warning: Could not import modeling_olmo3 or lxt module. LRP might fail. Error: {e}")
        elif "qwen2" in lower_path:
            try:
                from transformers.models.qwen2 import modeling_qwen2
                importlib.reload(modeling_qwen2)
                target_module = modeling_qwen2
                import lxt.efficient.models.qwen2 as lxt_qwen2
                importlib.reload(lxt_qwen2)
                patch_map = lxt_qwen2.cp_LRP if lrp_rule == "CP-LRP" else lxt_qwen2.attnLRP
            except ImportError as e:
                print(f"Warning: Could not import qwen2 or lxt: {e}")

        if target_module is None:
            # Previously this fell through silently, so the model loaded
            # unpatched without any warning.
            print(f"Warning: LRP patches not applied for {model_path}; loading model without LRP.")
            return

        if patch_map:
            monkey_patch(target_module, patch_map=patch_map, verbose=True)
            print(f"Applied LRP patches with rule: {lrp_rule}")
        else:
            # The lxt family module was unavailable; use monkey_patch's default map.
            monkey_patch(target_module, verbose=True)
            print("Applied default LRP patches")

    def _remove_lrp_patches(self, model_path):
        """Reload the model family's transformers module to undo earlier LRP patches."""
        lower_path = model_path.lower()
        if "qwen3" in lower_path:
            importlib.reload(modeling_qwen3)
            print("Reloaded modeling_qwen3 to remove LRP patches")
        elif "olmo" in lower_path:
            try:
                from transformers.models.olmo3 import modeling_olmo3
                importlib.reload(modeling_olmo3)
                print("Reloaded modeling_olmo3 to remove LRP patches")
            except ImportError:
                # Best effort: olmo3 is absent from this transformers install,
                # so nothing can have been patched.
                pass
        elif "qwen2" in lower_path:
            try:
                from transformers.models.qwen2 import modeling_qwen2
                importlib.reload(modeling_qwen2)
                print("Reloaded modeling_qwen2 to remove LRP patches")
            except ImportError:
                pass

    def get_model(self):
        """Return the loaded model, or None if nothing has been loaded."""
        return self.model

    def get_tokenizer(self):
        """Return the loaded tokenizer, or None if nothing has been loaded."""
        return self.tokenizer
|
|