File size: 2,319 Bytes
4fe7b26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
from transformers import (
    AutoModelForSeq2SeqLM, AutoTokenizer, BitsAndBytesConfig
)
from peft import PeftModel

def load_model(model_name, finetune_type):
    """Load a fine-tuned seq2seq model and its tokenizer from the Hugging Face Hub.

    The Hub repository is looked up in ``MODEL_REPOS[model_name][finetune_type]``.
    The tokenizer always comes from that repository; how the model itself is
    built depends on the model family:

    * ``"mT5"``: the ``google/mt5-xl`` base checkpoint is loaded with 4-bit
      NF4 quantization (float16 compute, double quantization,
      ``device_map="auto"``), then the adapter stored in the repository is
      attached with ``PeftModel.from_pretrained``.
    * ``"mBART50"``: the repository is loaded directly as a full checkpoint
      and moved to ``"cuda"`` when available, otherwise ``"cpu"``.

    Args:
        model_name: Top-level key of ``MODEL_REPOS``.
        finetune_type: Key of ``MODEL_REPOS[model_name]``.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        ValueError: If ``model_name`` or ``finetune_type`` is not registered
            in ``MODEL_REPOS``, or if ``model_name`` is registered but has no
            loading strategy here.
    """
    # Validate both lookup keys up front so the caller sees the valid choices.
    if model_name not in MODEL_REPOS:
        raise ValueError(f"Invalid model name. Choose from: {list(MODEL_REPOS.keys())}")

    available_variants = MODEL_REPOS[model_name]
    if finetune_type not in available_variants:
        raise ValueError(f"Invalid finetune type. Choose from: {list(available_variants.keys())}")

    hub_repo = available_variants[finetune_type]

    # Only the mBART50 path places the model by hand; mT5 uses device_map="auto".
    if torch.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"

    # The tokenizer is loaded from the fine-tuned repository for every family.
    tokenizer = AutoTokenizer.from_pretrained(hub_repo)

    if model_name == "mBART50":
        # Regular fine-tune: the repository holds the complete model.
        print(f"Loading {model_name} with {finetune_type} fine-tuning...")

        model = AutoModelForSeq2SeqLM.from_pretrained(hub_repo)
        model.to(target_device)

    elif model_name == "mT5":
        # QLoRA fine-tune: quantized base model plus an adapter from the repo.
        print(f"Loading {model_name} with {finetune_type} finetuning, 4-bit quantization, and QLoRA...")

        four_bit_settings = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.float16,
        )
        quantized_base = AutoModelForSeq2SeqLM.from_pretrained(
            "google/mt5-xl",
            quantization_config=four_bit_settings,
            device_map="auto",
        )

        # Wrap the quantized base with the fine-tuned adapter weights.
        model = PeftModel.from_pretrained(quantized_base, hub_repo)

    else:
        # Reached only if MODEL_REPOS gains a family without a loader above.
        raise ValueError(f"Unknown model: {model_name}")

    print(f"{model_name} ({finetune_type}) loaded successfully!")
    return model, tokenizer

# Hub repository for each supported (model family, fine-tuning variant) pair.
# Every repository is named "darpanaswal/<family>-<variant>-finetuned", so the
# table is generated from the two key lists instead of being spelled out.
MODEL_REPOS = {
    family: {
        variant: f"darpanaswal/{family}-{variant}-finetuned"
        for variant in ("english", "multilingual", "crosslingual")
    }
    for family in ("mT5", "mBART50")
}