"""
This module handles loading and saving of LLaMA models with efficient quantization.
This is already implemented and ready to use -- you don't need to modify this file.
Key Features:
- Loads LLaMA models from Hugging Face or local storage
- Implements 4-bit quantization for memory efficiency
- Provides save/load functionality for model persistence
- Handles model loading errors gracefully
Example Usage:
from model import load_model, save_model
# Load a model (will download if not found locally)
model, tokenizer = load_model()  # uses MODEL_NAME, or the saved local copy
# Save model after making changes
save_model(model, tokenizer)
"""
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
# Hugging Face model id downloaded when no saved local copy exists.
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" # Change this to your preferred model
# Other options (larger models need more GPU memory, even 4-bit quantized):
# MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
# MODEL_NAME = "openlm-research/open_llama_3b"
# Directory where save_model() writes and load_model() looks for a local copy.
MODEL_SAVE_PATH = "models/school_chatbot"
def save_model(model, tokenizer, save_directory=None):
    """
    Persist the model and tokenizer to a local directory.

    Args:
        model: Hugging Face model exposing ``save_pretrained``.
        tokenizer: Matching tokenizer, also exposing ``save_pretrained``.
        save_directory: Target directory. Defaults to ``MODEL_SAVE_PATH``
            (resolved at call time so the save path cannot drift from the
            path ``load_model`` reads — the literal used to be duplicated).
    """
    if save_directory is None:
        save_directory = MODEL_SAVE_PATH
    # Create directory if it doesn't exist
    os.makedirs(save_directory, exist_ok=True)
    # Save model and tokenizer side by side so load_model finds both.
    model.save_pretrained(save_directory)
    tokenizer.save_pretrained(save_directory)
    print(f"Model and tokenizer saved to {save_directory}")
def load_model(model_name=None):
    """
    Load a causal LM and its tokenizer with 4-bit (nf4) quantization.

    Args:
        model_name: Optional Hugging Face model id. When given, the model is
            always fetched from the hub (and then cached locally). When
            omitted, a previously saved copy in ``MODEL_SAVE_PATH`` is
            preferred, falling back to downloading ``MODEL_NAME``.

    Returns:
        ``(model, tokenizer)`` on success, or ``(None, None)`` on failure.
    """
    try:
        # 4-bit quantization to reduce memory usage.
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,                    # Enable 4-bit quantization
            bnb_4bit_compute_dtype=torch.float16, # Compute dtype
            bnb_4bit_quant_type="nf4",            # Normalized float 4 format
            bnb_4bit_use_double_quant=True,       # Use nested quantization
        )
        # Prefer the saved local copy only when the caller did not pin a
        # specific model id.
        use_local = model_name is None and os.path.exists(MODEL_SAVE_PATH)
        if use_local:
            print("Loading quantized model from local storage...")
            source = MODEL_SAVE_PATH
        else:
            print("Downloading and quantizing model from Hugging Face...")
            source = model_name if model_name is not None else MODEL_NAME
        # Single load path for both branches (was duplicated verbatim).
        tokenizer = AutoTokenizer.from_pretrained(source)
        model = AutoModelForCausalLM.from_pretrained(
            source,
            quantization_config=quantization_config,
            device_map="auto"
        )
        if not use_local:
            # Save the freshly downloaded model for future runs.
            save_model(model, tokenizer)
        return model, tokenizer
    except Exception as e:
        # Boundary-level catch: callers are promised graceful failure
        # (None, None) rather than an exception.
        print(f"Error loading model: {e}")
        return None, None
|