"""
Loads and wraps the base pretrained model.
Supported architectures:
    - google/flan-t5-xl (recommended, 3B)
    - google/flan-t5-large (780M, for resource-constrained setups)
    - facebook/bart-large (400M, strong denoising pretraining)
    - meta-llama/Meta-Llama-3.1-8B-Instruct (8B, best quality)
"""
from transformers import (
    AutoTokenizer, AutoModelForSeq2SeqLM,
    AutoModelForCausalLM, BitsAndBytesConfig,
)
from peft import get_peft_model, LoraConfig, TaskType
import torch
from loguru import logger

ENCODER_DECODER_MODELS = {
    "flan-t5-xl": "google/flan-t5-xl",
    "flan-t5-large": "google/flan-t5-large",
    "flan-t5-base": "google/flan-t5-base",
    "flan-t5-small": "google/flan-t5-small",
    "bart-large": "facebook/bart-large",
}
DECODER_ONLY_MODELS = {
    "llama-3.1-8b": "meta-llama/Meta-Llama-3.1-8B-Instruct",
}

def load_model_and_tokenizer(model_key: str, quantize: bool = False, use_lora: bool = True,
                             lora_config_dict: dict | None = None):
"""
Load a pretrained model with optional LoRA and quantization.
Args:
model_key: Key from ENCODER_DECODER_MODELS or DECODER_ONLY_MODELS
quantize: Whether to use 4-bit quantization
use_lora: Whether to apply LoRA adapters
lora_config_dict: Optional dict with LoRA hyperparams (r, lora_alpha, etc.)
Returns:
Tuple of (model, tokenizer, is_seq2seq)
"""
    # Determine model type and Hugging Face identifier
    is_seq2seq = model_key in ENCODER_DECODER_MODELS
    is_causal = model_key in DECODER_ONLY_MODELS
    if not is_seq2seq and not is_causal:
        raise ValueError(
            f"Unknown model key: '{model_key}'. "
            f"Available: {list(ENCODER_DECODER_MODELS.keys()) + list(DECODER_ONLY_MODELS.keys())}"
        )
    model_name = ENCODER_DECODER_MODELS.get(model_key) or DECODER_ONLY_MODELS.get(model_key)
    logger.info(f"Loading model: {model_name} (seq2seq={is_seq2seq}, quantize={quantize}, lora={use_lora})")
    # Load tokenizer; some causal models (e.g. Llama) ship without a pad token
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Pick dtype by device: bfloat16 on Ampere+ GPUs, float16 on older GPUs,
    # float32 on CPU for numerical stability
    model_kwargs = {
        "torch_dtype": torch.float32,
    }
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
        if torch.cuda.get_device_capability()[0] >= 8:  # Ampere or newer
            model_kwargs["torch_dtype"] = torch.bfloat16
        else:
            model_kwargs["torch_dtype"] = torch.float16
    # Configure 4-bit quantization if requested (CUDA only)
    if quantize and device == "cuda":
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
            bnb_4bit_use_double_quant=True,
        )
        model_kwargs["quantization_config"] = bnb_config
        # bitsandbytes loads need accelerate to place quantized weights on GPU
        model_kwargs["device_map"] = "auto"
        logger.info("Using 4-bit NF4 quantization")
    elif quantize and device == "cpu":
        logger.warning("Quantization requested but no GPU available, skipping")
    # Load model with the appropriate head
    if is_seq2seq:
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **model_kwargs)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
    # Move to device if not quantized (quantized weights are placed by device_map)
    if not quantize or device == "cpu":
        model = model.to(device)
    logger.info(f"Model loaded on {device} with dtype {model_kwargs.get('torch_dtype')}")
    # Apply LoRA if requested
    if use_lora:
        lora_cfg = lora_config_dict or {}
        task_type = TaskType.SEQ_2_SEQ_LM if is_seq2seq else TaskType.CAUSAL_LM
        # Default target modules based on model architecture
        default_targets = {
            "flan-t5-xl": ["q", "v", "k", "o", "wi_0", "wi_1", "wo"],
            "flan-t5-large": ["q", "v", "k", "o", "wi_0", "wi_1", "wo"],
            "flan-t5-base": ["q", "v", "k", "o", "wi_0", "wi_1", "wo"],
            "flan-t5-small": ["q", "v", "k", "o", "wi_0", "wi_1", "wo"],
            "bart-large": ["q_proj", "v_proj", "k_proj", "out_proj"],
            "llama-3.1-8b": ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        }
        lora_config = LoraConfig(
            task_type=task_type,
            r=lora_cfg.get("r", 16),
            lora_alpha=lora_cfg.get("lora_alpha", 32),
            lora_dropout=lora_cfg.get("lora_dropout", 0.05),
            target_modules=lora_cfg.get("target_modules", default_targets.get(model_key, ["q", "v"])),
            bias="none",
        )
        model = get_peft_model(model, lora_config)
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total_params = sum(p.numel() for p in model.parameters())
        logger.info(
            f"LoRA applied: {trainable_params:,} trainable params / {total_params:,} total "
            f"({100 * trainable_params / total_params:.2f}%)"
        )
    return model, tokenizer, is_seq2seq
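
# Minimal usage sketch (an illustrative addition, not part of the original
# module): load the smallest supported model with a custom LoRA rank, then run
# one generation as a smoke test. The prompt and hyperparameters are arbitrary.
if __name__ == "__main__":
    model, tokenizer, is_seq2seq = load_model_and_tokenizer(
        "flan-t5-small",
        quantize=False,
        use_lora=True,
        lora_config_dict={"r": 8, "lora_alpha": 16},  # halve the default rank
    )
    prompt = "Translate to German: How old are you?"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))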