| """ |
| Loads and wraps the base pretrained model. |
| Supported architectures: |
| - google/flan-t5-xl (recommended, 3B) |
| - google/flan-t5-large (780M, resource-constrained) |
| - facebook/bart-large (400M, excellent denoiser) |
| - meta-llama/Meta-Llama-3.1-8B-Instruct (8B, best quality) |
| """ |
|
|
import torch
from loguru import logger
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)


ENCODER_DECODER_MODELS = {
    "flan-t5-xl": "google/flan-t5-xl",
    "flan-t5-large": "google/flan-t5-large",
    "flan-t5-base": "google/flan-t5-base",
    "flan-t5-small": "google/flan-t5-small",
    "bart-large": "facebook/bart-large",
}

DECODER_ONLY_MODELS = {
    "llama-3.1-8b": "meta-llama/Meta-Llama-3.1-8B-Instruct",
}

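# New checkpoints can be registered in the dicts above; for LoRA to wrap the
# right layers, a matching entry should also be added to `default_targets`
# inside load_model_and_tokenizer() below (otherwise a generic ["q", "v"]
# fallback is used).
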
def load_model_and_tokenizer(model_key: str, quantize: bool = False, use_lora: bool = True,
                             lora_config_dict: dict | None = None):
    """
    Load a pretrained model with optional LoRA and quantization.

    Args:
        model_key: Key from ENCODER_DECODER_MODELS or DECODER_ONLY_MODELS.
        quantize: Whether to use 4-bit quantization (CUDA only).
        use_lora: Whether to apply LoRA adapters.
        lora_config_dict: Optional dict with LoRA hyperparams (r, lora_alpha, etc.).

    Returns:
        Tuple of (model, tokenizer, is_seq2seq).
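
    Example (illustrative; assumes the checkpoint can be fetched from the
    Hugging Face Hub):
        >>> model, tok, is_s2s = load_model_and_tokenizer("flan-t5-small", use_lora=False)
        >>> ids = tok("summarize: The cat sat on the mat.", return_tensors="pt")
        >>> out = model.generate(**ids, max_new_tokens=16)
        >>> print(tok.decode(out[0], skip_special_tokens=True))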
| """ |
| |
| is_seq2seq = model_key in ENCODER_DECODER_MODELS |
| is_causal = model_key in DECODER_ONLY_MODELS |
|
|
| if not is_seq2seq and not is_causal: |
| raise ValueError( |
| f"Unknown model key: '{model_key}'. " |
| f"Available: {list(ENCODER_DECODER_MODELS.keys()) + list(DECODER_ONLY_MODELS.keys())}" |
| ) |
|
|
| model_name = ENCODER_DECODER_MODELS.get(model_key) or DECODER_ONLY_MODELS.get(model_key) |
| logger.info(f"Loading model: {model_name} (seq2seq={is_seq2seq}, quantize={quantize}, lora={use_lora})") |
|
|
| |
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        # Decoder-only checkpoints such as Llama ship without a pad token;
        # reuse EOS so that batched padding works.
        tokenizer.pad_token = tokenizer.eos_token

    model_kwargs = {
        "torch_dtype": torch.float32,
    }

    # Pick the widest-supported dtype: bf16 on Ampere or newer GPUs
    # (compute capability >= 8), fp16 on older GPUs, fp32 on CPU.
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
        if torch.cuda.get_device_capability()[0] >= 8:
            model_kwargs["torch_dtype"] = torch.bfloat16
        else:
            model_kwargs["torch_dtype"] = torch.float16

    if quantize and device == "cuda":
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
            bnb_4bit_use_double_quant=True,
        )
        model_kwargs["quantization_config"] = bnb_config
        # Let accelerate place the quantized weights; a bitsandbytes model
        # must not be moved afterwards with .to().
        model_kwargs["device_map"] = "auto"
        logger.info("Using 4-bit NF4 quantization")
    elif quantize and device == "cpu":
        logger.warning("Quantization requested but no GPU available, skipping")

    if is_seq2seq:
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **model_kwargs)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)

    # Quantized models are already placed via device_map; move only
    # full-precision models (or the CPU fallback) explicitly.
    if not quantize or device == "cpu":
        model = model.to(device)

    logger.info(f"Model loaded on {device} with dtype {model_kwargs.get('torch_dtype')}")

    if use_lora:
        lora_cfg = lora_config_dict or {}
        task_type = TaskType.SEQ_2_SEQ_LM if is_seq2seq else TaskType.CAUSAL_LM

        if quantize and device == "cuda":
            # Standard QLoRA preparation: casts norm/embedding weights to fp32
            # and enables gradient checkpointing for stable 4-bit training.
            model = prepare_model_for_kbit_training(model)

        # Per-architecture module names to wrap with adapters: T5 blocks use
        # short names (q, v, ...); BART and Llama use *_proj names.
        default_targets = {
            "flan-t5-xl": ["q", "v", "k", "o", "wi_0", "wi_1", "wo"],
            "flan-t5-large": ["q", "v", "k", "o", "wi_0", "wi_1", "wo"],
            "flan-t5-base": ["q", "v", "k", "o", "wi_0", "wi_1", "wo"],
            "flan-t5-small": ["q", "v", "k", "o", "wi_0", "wi_1", "wo"],
            "bart-large": ["q_proj", "v_proj", "k_proj", "out_proj"],
            "llama-3.1-8b": ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        }

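        # A quick, illustrative way to list candidate target modules for a
        # new architecture (run once in a REPL):
        #   sorted({n.split(".")[-1] for n, m in model.named_modules()
        #           if isinstance(m, torch.nn.Linear)})
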
        lora_config = LoraConfig(
            task_type=task_type,
            r=lora_cfg.get("r", 16),
            lora_alpha=lora_cfg.get("lora_alpha", 32),
            lora_dropout=lora_cfg.get("lora_dropout", 0.05),
            target_modules=lora_cfg.get("target_modules", default_targets.get(model_key, ["q", "v"])),
            bias="none",
        )

        model = get_peft_model(model, lora_config)
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total_params = sum(p.numel() for p in model.parameters())
        logger.info(
            f"LoRA applied: {trainable_params:,} trainable params / {total_params:,} total "
            f"({100 * trainable_params / total_params:.2f}%)"
        )

    return model, tokenizer, is_seq2seq
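

if __name__ == "__main__":
    # Minimal smoke test (a sketch; assumes the flan-t5-small checkpoint can
    # be downloaded from the Hugging Face Hub and fits in memory).
    model, tokenizer, is_seq2seq = load_model_and_tokenizer("flan-t5-small")
    batch = tokenizer("summarize: LoRA adds small low-rank adapters.", return_tensors="pt")
    batch = {k: v.to(model.device) for k, v in batch.items()}
    output_ids = model.generate(**batch, max_new_tokens=16)
    logger.info(tokenizer.decode(output_ids[0], skip_special_tokens=True))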