File size: 2,915 Bytes
abf7d79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""
This module handles loading and saving of LLaMA-family chat models with 4-bit quantization.
It is already implemented and ready to use -- apart from optionally changing MODEL_NAME,
you don't need to modify this file.

Key Features:
- Loads the model named by MODEL_NAME from Hugging Face, or from MODEL_SAVE_PATH if a
  local copy was saved previously
- Uses 4-bit (NF4) quantization via bitsandbytes to reduce memory usage
- Provides save/load functionality for model persistence
- Handles loading errors by printing them and returning (None, None)

Example Usage:
    from model import load_model, save_model

    # Load the model configured in MODEL_NAME (downloads it if not saved locally)
    model, tokenizer = load_model()

    # Save model after making changes
    save_model(model, tokenizer)
"""

import os
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

# Choose a model: the Hugging Face model ID that load_model() downloads when
# no saved copy exists in MODEL_SAVE_PATH.
# NOTE: load_model() prefers MODEL_SAVE_PATH whenever that directory exists, so
# after changing MODEL_NAME, delete that directory (or change MODEL_SAVE_PATH)
# for the new model to be used.
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # Change this to your preferred model
# Other options:
# MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
# MODEL_NAME = "openlm-research/open_llama_3b"

# Directory (relative to the current working directory) that load_model()
# checks for a saved copy before downloading, and where it saves a new download.
MODEL_SAVE_PATH = "models/school_chatbot"


def save_model(model, tokenizer, save_directory=None):
    """
    Save the model and tokenizer to a local directory.

    Args:
        model: Object exposing ``save_pretrained(directory)`` (e.g. a
            transformers model).
        tokenizer: Object exposing ``save_pretrained(directory)``.
        save_directory: Target directory; created (with parents) if missing.
            Defaults to MODEL_SAVE_PATH, so the save location always matches
            the directory load_model() checks for a local copy.
    """
    if save_directory is None:
        # Resolved at call time (not as a hard-coded default) so the save
        # location cannot drift away from MODEL_SAVE_PATH if it is edited.
        save_directory = MODEL_SAVE_PATH

    # Create directory if it doesn't exist
    os.makedirs(save_directory, exist_ok=True)

    # Save model and tokenizer side by side in the same directory
    model.save_pretrained(save_directory)
    tokenizer.save_pretrained(save_directory)

    print(f"Model and tokenizer saved to {save_directory}")


def load_model():
    """
    Load the model and tokenizer with 4-bit quantization.

    If MODEL_SAVE_PATH holds a saved copy (detected by the config.json that
    ``save_pretrained`` writes), the model is loaded from there. Otherwise
    MODEL_NAME is downloaded from Hugging Face, quantized on load, and saved
    to MODEL_SAVE_PATH for future runs.

    Returns:
        (model, tokenizer) on success, or (None, None) if loading fails
        (the error is printed). A failure while *saving* a freshly downloaded
        model is reported as a warning but does not discard the loaded model.
    """
    try:
        # Use quantization to reduce memory usage
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,                     # Enable 4-bit quantization
            bnb_4bit_compute_dtype=torch.float16,  # Compute dtype
            bnb_4bit_quant_type="nf4",             # Normalized float 4 format
            bnb_4bit_use_double_quant=True,        # Use nested quantization
        )

        # An empty or half-written directory (e.g. left by an interrupted or
        # failed save) has no config.json; treat it as "not cached" rather
        # than trying, and failing, to load from it on every run.
        has_local_copy = os.path.isfile(os.path.join(MODEL_SAVE_PATH, "config.json"))

        if has_local_copy:
            print("Loading quantized model from local storage...")
            source = MODEL_SAVE_PATH
        else:
            print("Downloading and quantizing model from Hugging Face...")
            source = MODEL_NAME

        tokenizer = AutoTokenizer.from_pretrained(source)
        model = AutoModelForCausalLM.from_pretrained(
            source,
            quantization_config=quantization_config,
            device_map="auto",
        )
    except Exception as e:
        print(f"Error loading model: {e}")
        return None, None

    if not has_local_copy:
        # Caching is best-effort: a failed save (disk full, permissions,
        # serialization not supported) must not throw away a usable model.
        # Pass the path explicitly so it always matches the check above.
        try:
            save_model(model, tokenizer, MODEL_SAVE_PATH)
        except Exception as e:
            print(f"Warning: could not save model to {MODEL_SAVE_PATH}: {e}")

    return model, tokenizer