"""
Loads and wraps the base pretrained model.
Supported architectures:
  - google/flan-t5-xl          (recommended, 3B)
  - google/flan-t5-large       (780M, resource-constrained)
  - facebook/bart-large        (400M, excellent denoiser)
  - meta-llama/Meta-Llama-3.1-8B-Instruct (8B, best quality)
"""

from transformers import (
    AutoTokenizer, AutoModelForSeq2SeqLM,
    AutoModelForCausalLM, BitsAndBytesConfig
)
from peft import get_peft_model, LoraConfig, TaskType
import torch
from loguru import logger


ENCODER_DECODER_MODELS = {
    "flan-t5-xl": "google/flan-t5-xl",
    "flan-t5-large": "google/flan-t5-large",
    "flan-t5-base": "google/flan-t5-base",
    "flan-t5-small": "google/flan-t5-small",
    "bart-large": "facebook/bart-large",
}

DECODER_ONLY_MODELS = {
    "llama-3.1-8b": "meta-llama/Meta-Llama-3.1-8B-Instruct",
}


def load_model_and_tokenizer(model_key: str, quantize: bool = False, use_lora: bool = True,
                             lora_config_dict: dict | None = None):
    """
    Load a pretrained model with optional LoRA and quantization.

    Args:
        model_key: Key from ENCODER_DECODER_MODELS or DECODER_ONLY_MODELS
        quantize: Whether to use 4-bit quantization
        use_lora: Whether to apply LoRA adapters
        lora_config_dict: Optional dict with LoRA hyperparams (r, lora_alpha, etc.)

    Returns:
        Tuple of (model, tokenizer, is_seq2seq)
    """
    # Determine model type and HuggingFace identifier
    is_seq2seq = model_key in ENCODER_DECODER_MODELS
    is_causal = model_key in DECODER_ONLY_MODELS

    if not is_seq2seq and not is_causal:
        raise ValueError(
            f"Unknown model key: '{model_key}'. "
            f"Available: {list(ENCODER_DECODER_MODELS.keys()) + list(DECODER_ONLY_MODELS.keys())}"
        )

    model_name = ENCODER_DECODER_MODELS.get(model_key) or DECODER_ONLY_MODELS.get(model_key)
    logger.info(f"Loading model: {model_name} (seq2seq={is_seq2seq}, quantize={quantize}, lora={use_lora})")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Configure quantization if requested
    model_kwargs = {
        "torch_dtype": torch.float32,  # CPU-optimised: use float32 for stability
    }

    # Detect device
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
        # Use bfloat16 if Ampere+, else float16
        if torch.cuda.get_device_capability()[0] >= 8:
            model_kwargs["torch_dtype"] = torch.bfloat16
        else:
            model_kwargs["torch_dtype"] = torch.float16

    if quantize and device == "cuda":
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
            bnb_4bit_use_double_quant=True,
        )
        model_kwargs["quantization_config"] = bnb_config
        model_kwargs["device_map"] = "auto"  # let accelerate place the quantized weights
        logger.info("Using 4-bit NF4 quantization")
    elif quantize and device == "cpu":
        logger.warning("Quantization requested but no GPU available, skipping")

    # Load model
    if is_seq2seq:
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **model_kwargs)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)

    # Move to device if not quantized (quantized models are already on device)
    if not quantize or device == "cpu":
        model = model.to(device)

    logger.info(f"Model loaded on {device} with dtype {model_kwargs.get('torch_dtype')}")

    # Apply LoRA if requested
    if use_lora:
        lora_cfg = lora_config_dict or {}
        task_type = TaskType.SEQ_2_SEQ_LM if is_seq2seq else TaskType.CAUSAL_LM

        # Default target modules based on model architecture
        default_targets = {
            "flan-t5-xl": ["q", "v", "k", "o", "wi_0", "wi_1", "wo"],
            "flan-t5-large": ["q", "v", "k", "o", "wi_0", "wi_1", "wo"],
            "flan-t5-base": ["q", "v", "k", "o", "wi_0", "wi_1", "wo"],
            "flan-t5-small": ["q", "v", "k", "o", "wi_0", "wi_1", "wo"],
            "bart-large": ["q_proj", "v_proj", "k_proj", "out_proj"],
            "llama-3.1-8b": ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        }

        lora_config = LoraConfig(
            task_type=task_type,
            r=lora_cfg.get("r", 16),
            lora_alpha=lora_cfg.get("lora_alpha", 32),
            lora_dropout=lora_cfg.get("lora_dropout", 0.05),
            target_modules=lora_cfg.get("target_modules", default_targets.get(model_key, ["q", "v"])),
            bias="none",
        )

        model = get_peft_model(model, lora_config)
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total_params = sum(p.numel() for p in model.parameters())
        logger.info(
            f"LoRA applied: {trainable_params:,} trainable params / {total_params:,} total "
            f"({100 * trainable_params / total_params:.2f}%)"
        )

    return model, tokenizer, is_seq2seq
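

# Minimal usage sketch (illustrative, not part of the core module): smoke-tests the
# loader on CPU with the smallest supported checkpoint. The model key "flan-t5-small"
# and the prompt below are arbitrary example choices.
if __name__ == "__main__":
    model, tokenizer, is_seq2seq = load_model_and_tokenizer(
        "flan-t5-small", quantize=False, use_lora=True
    )
    prompt = "Summarize: The quick brown fox jumps over the lazy dog."
    device = next(model.parameters()).device
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))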