"""Loads and wraps the base pretrained model.

Supported architectures:
- google/flan-t5-xl (recommended, 3B)
- google/flan-t5-large (780M, resource-constrained)
- facebook/bart-large (400M, excellent denoiser)
- meta-llama/Meta-Llama-3.1-8B-Instruct (8B, best quality)
"""
from typing import Optional

import torch
from loguru import logger
from peft import (
    LoraConfig,
    TaskType,
    get_peft_model,
    prepare_model_for_kbit_training,
)
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)

# Short key -> HuggingFace Hub identifier, encoder-decoder (seq2seq) family.
ENCODER_DECODER_MODELS = {
    "flan-t5-xl": "google/flan-t5-xl",
    "flan-t5-large": "google/flan-t5-large",
    "flan-t5-base": "google/flan-t5-base",
    "flan-t5-small": "google/flan-t5-small",
    "bart-large": "facebook/bart-large",
}

# Short key -> HuggingFace Hub identifier, decoder-only (causal LM) family.
DECODER_ONLY_MODELS = {
    "llama-3.1-8b": "meta-llama/Meta-Llama-3.1-8B-Instruct",
}

# Per-architecture LoRA target modules: T5 blocks name their projections
# q/v/k/o + wi_0/wi_1/wo, while BART/Llama use the *_proj convention.
_DEFAULT_LORA_TARGETS = {
    "flan-t5-xl": ["q", "v", "k", "o", "wi_0", "wi_1", "wo"],
    "flan-t5-large": ["q", "v", "k", "o", "wi_0", "wi_1", "wo"],
    "flan-t5-base": ["q", "v", "k", "o", "wi_0", "wi_1", "wo"],
    "flan-t5-small": ["q", "v", "k", "o", "wi_0", "wi_1", "wo"],
    "bart-large": ["q_proj", "v_proj", "k_proj", "out_proj"],
    "llama-3.1-8b": [
        "q_proj", "v_proj", "k_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
}


def _select_device_and_dtype() -> tuple:
    """Pick the compute device and matching dtype.

    Returns (device, dtype): float32 on CPU for numerical stability,
    bfloat16 on Ampere-or-newer GPUs (compute capability >= 8),
    float16 on older CUDA GPUs.
    """
    if not torch.cuda.is_available():
        return "cpu", torch.float32
    if torch.cuda.get_device_capability()[0] >= 8:
        return "cuda", torch.bfloat16
    return "cuda", torch.float16


def _apply_lora(model, model_key: str, is_seq2seq: bool,
                lora_cfg: dict, quantized: bool):
    """Wrap *model* in PEFT LoRA adapters and log the trainable-param ratio.

    When the base model was loaded in 4-bit, it is first passed through
    ``prepare_model_for_kbit_training`` (casts norms/embeddings to fp32 and
    enables input grads) — required for stable QLoRA training.
    """
    if quantized:
        model = prepare_model_for_kbit_training(model)

    lora_config = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM if is_seq2seq else TaskType.CAUSAL_LM,
        r=lora_cfg.get("r", 16),
        lora_alpha=lora_cfg.get("lora_alpha", 32),
        lora_dropout=lora_cfg.get("lora_dropout", 0.05),
        target_modules=lora_cfg.get(
            "target_modules",
            _DEFAULT_LORA_TARGETS.get(model_key, ["q", "v"]),
        ),
        bias="none",
    )
    model = get_peft_model(model, lora_config)

    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    logger.info(
        f"LoRA applied: {trainable_params:,} trainable params / {total_params:,} total "
        f"({100 * trainable_params / total_params:.2f}%)"
    )
    return model


def load_model_and_tokenizer(model_key: str,
                             quantize: bool = False,
                             use_lora: bool = True,
                             lora_config_dict: Optional[dict] = None):
    """
    Load a pretrained model with optional LoRA and quantization.

    Args:
        model_key: Key from ENCODER_DECODER_MODELS or DECODER_ONLY_MODELS
        quantize: Whether to use 4-bit quantization (GPU only; silently
            skipped with a warning on CPU)
        use_lora: Whether to apply LoRA adapters
        lora_config_dict: Optional dict with LoRA hyperparams
            (r, lora_alpha, lora_dropout, target_modules)

    Returns:
        Tuple of (model, tokenizer, is_seq2seq)

    Raises:
        ValueError: If *model_key* is not in either model registry.
    """
    # Determine model type and HuggingFace identifier.
    is_seq2seq = model_key in ENCODER_DECODER_MODELS
    is_causal = model_key in DECODER_ONLY_MODELS
    if not is_seq2seq and not is_causal:
        raise ValueError(
            f"Unknown model key: '{model_key}'. "
            f"Available: {list(ENCODER_DECODER_MODELS.keys()) + list(DECODER_ONLY_MODELS.keys())}"
        )

    model_name = ENCODER_DECODER_MODELS.get(model_key) or DECODER_ONLY_MODELS.get(model_key)
    logger.info(
        f"Loading model: {model_name} "
        f"(seq2seq={is_seq2seq}, quantize={quantize}, lora={use_lora})"
    )

    # Load tokenizer; causal models (e.g. Llama) often ship without a pad
    # token, so fall back to EOS for batching.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    device, dtype = _select_device_and_dtype()
    model_kwargs = {"torch_dtype": dtype}

    quantized = quantize and device == "cuda"
    if quantized:
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=dtype,
            bnb_4bit_use_double_quant=True,
        )
        # bitsandbytes-quantized weights must be placed by accelerate,
        # not by a post-hoc .to(); "auto" is the documented pattern.
        model_kwargs["device_map"] = "auto"
        logger.info("Using 4-bit NF4 quantization")
    elif quantize and device == "cpu":
        logger.warning("Quantization requested but no GPU available, skipping")

    # Load model with the architecture-appropriate auto class.
    if is_seq2seq:
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **model_kwargs)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)

    # Move to device only when not quantized (quantized models are already
    # placed on the GPU by the device_map above).
    if not quantized:
        model = model.to(device)

    logger.info(f"Model loaded on {device} with dtype {model_kwargs.get('torch_dtype')}")

    if use_lora:
        model = _apply_lora(
            model, model_key, is_seq2seq, lora_config_dict or {}, quantized
        )

    return model, tokenizer, is_seq2seq