"""
Loads and wraps the base pretrained model.
Supported architectures:
- google/flan-t5-xl (recommended, 3B)
- google/flan-t5-large (780M, resource-constrained)
- facebook/bart-large (400M, excellent denoiser)
- meta-llama/Meta-Llama-3.1-8B-Instruct (8B, best quality)
"""

from transformers import (
    AutoTokenizer, AutoModelForSeq2SeqLM,
    AutoModelForCausalLM, BitsAndBytesConfig
)
from peft import get_peft_model, LoraConfig, TaskType
import torch
from loguru import logger

ENCODER_DECODER_MODELS = {
    "flan-t5-xl": "google/flan-t5-xl",
    "flan-t5-large": "google/flan-t5-large",
    "flan-t5-base": "google/flan-t5-base",
    "flan-t5-small": "google/flan-t5-small",
    "bart-large": "facebook/bart-large",
}

DECODER_ONLY_MODELS = {
    "llama-3.1-8b": "meta-llama/Meta-Llama-3.1-8B-Instruct",
}


def load_model_and_tokenizer(model_key: str, quantize: bool = False, use_lora: bool = True,
                             lora_config_dict: dict = None):
"""
Load a pretrained model with optional LoRA and quantization.
Args:
model_key: Key from ENCODER_DECODER_MODELS or DECODER_ONLY_MODELS
quantize: Whether to use 4-bit quantization
use_lora: Whether to apply LoRA adapters
lora_config_dict: Optional dict with LoRA hyperparams (r, lora_alpha, etc.)
Returns:
Tuple of (model, tokenizer, is_seq2seq)
"""
    # Determine model type and HuggingFace identifier
    is_seq2seq = model_key in ENCODER_DECODER_MODELS
    is_causal = model_key in DECODER_ONLY_MODELS
    if not is_seq2seq and not is_causal:
        raise ValueError(
            f"Unknown model key: '{model_key}'. "
            f"Available: {list(ENCODER_DECODER_MODELS.keys()) + list(DECODER_ONLY_MODELS.keys())}"
        )
    model_name = ENCODER_DECODER_MODELS.get(model_key) or DECODER_ONLY_MODELS.get(model_key)
    logger.info(f"Loading model: {model_name} (seq2seq={is_seq2seq}, quantize={quantize}, lora={use_lora})")
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
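    # Some tokenizers (e.g. Llama's) ship without a pad token; reusing EOS as padding
    # keeps batched tokenization and the data collator from failing.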
    # Base kwargs for from_pretrained; dtype is picked per device and quantization is added below
    model_kwargs = {
        "torch_dtype": torch.float32,  # CPU default: float32 for numerical stability
    }
    # Detect device
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
        # Use bfloat16 if Ampere+, else float16
        if torch.cuda.get_device_capability()[0] >= 8:
            model_kwargs["torch_dtype"] = torch.bfloat16
        else:
            model_kwargs["torch_dtype"] = torch.float16
    if quantize and device == "cuda":
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
            bnb_4bit_use_double_quant=True,
        )
        model_kwargs["quantization_config"] = bnb_config
        logger.info("Using 4-bit NF4 quantization")
    elif quantize and device == "cpu":
        logger.warning("Quantization requested but no GPU available, skipping")
    # Load model
    if is_seq2seq:
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **model_kwargs)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
    # Move to device if not quantized (quantized models are already on device)
    if not quantize or device == "cpu":
        model = model.to(device)
    logger.info(f"Model loaded on {device} with dtype {model_kwargs.get('torch_dtype')}")
    # Apply LoRA if requested
    if use_lora:
        lora_cfg = lora_config_dict or {}
        task_type = TaskType.SEQ_2_SEQ_LM if is_seq2seq else TaskType.CAUSAL_LM
        # Default target modules based on model architecture
        default_targets = {
            "flan-t5-xl": ["q", "v", "k", "o", "wi_0", "wi_1", "wo"],
            "flan-t5-large": ["q", "v", "k", "o", "wi_0", "wi_1", "wo"],
            "flan-t5-base": ["q", "v", "k", "o", "wi_0", "wi_1", "wo"],
            "flan-t5-small": ["q", "v", "k", "o", "wi_0", "wi_1", "wo"],
            "bart-large": ["q_proj", "v_proj", "k_proj", "out_proj"],
            "llama-3.1-8b": ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        }
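        # T5 blocks expose short layer names ("q", "wi_0", ...), while BART and Llama use
        # the "*_proj" convention, hence the per-architecture target lists above.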
        lora_config = LoraConfig(
            task_type=task_type,
            r=lora_cfg.get("r", 16),
            lora_alpha=lora_cfg.get("lora_alpha", 32),
            lora_dropout=lora_cfg.get("lora_dropout", 0.05),
            target_modules=lora_cfg.get("target_modules", default_targets.get(model_key, ["q", "v"])),
            bias="none",
        )
        model = get_peft_model(model, lora_config)
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total_params = sum(p.numel() for p in model.parameters())
        logger.info(
            f"LoRA applied: {trainable_params:,} trainable params / {total_params:,} total "
            f"({100 * trainable_params / total_params:.2f}%)"
        )
    return model, tokenizer, is_seq2seq
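

# Example usage: a minimal sketch assuming local/network access to the weights.
# The prompt and the LoRA overrides below are illustrative, not values required
# elsewhere in this repo.
if __name__ == "__main__":
    model, tokenizer, is_seq2seq = load_model_and_tokenizer(
        "flan-t5-base",
        quantize=False,
        use_lora=True,
        lora_config_dict={"r": 8, "lora_alpha": 16},
    )
    prompt = "Rewrite formally: he dont know nothing about it"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))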