Spaces:
Sleeping
Sleeping
File size: 955 Bytes
# Common configuration for all models, including device and dtype settings.
import os
import torch
# Value of the HF_TOKEN environment variable; None when it is not set.
TOKEN = os.getenv("HF_TOKEN")

# True only when QUANTIZE_4_BIT is set to "true" (compared case-insensitively);
# an unset variable defaults to "false".
QUANTIZE_4_BIT = os.getenv("QUANTIZE_4_BIT", "false").lower() == "true"

# Use CUDA when torch reports it as available, otherwise fall back to CPU.
if torch.cuda.is_available():
    torch_device = "cuda"
else:
    torch_device = "cpu"

# bfloat16 for "cuda"/"mps", float32 for everything else.
# NOTE(review): torch_device is only ever "cuda" or "cpu" here, so the "mps"
# entry can never match -- confirm whether MPS detection was intended.
if torch_device in ("cuda", "mps"):
    torch_dtype = torch.bfloat16
else:
    torch_dtype = torch.float32

print(f"Using {torch_device} with dtype {torch_dtype}...")

# Shared keyword settings for models: dtype, explicit device and token.
model_config = {
    "torch_dtype": torch_dtype,
    "device_map": torch_device,
    "token": TOKEN,
}

# Shared keyword settings for tokenizers: only the token.
tokenizer_config = {"token": TOKEN}

# Shared keyword settings for pipelines. Unlike model_config, device_map is
# "auto" instead of torch_device, and no token is included.
pipeline_config = {
    "torch_dtype": torch_dtype,
    "device_map": "auto",
}
def enable_quantization():
    """Register a 4-bit quantization config in ``model_config``.

    Prints a notice, builds ``BitsAndBytesConfig(load_in_4bit=True)`` and
    stores it under the ``"quantization_config"`` key of the module-level
    ``model_config`` dict, replacing any value already stored there.
    Returns ``None``.
    """
    print("Enabling 4-bit quantization for compatible models...")
    # Function-scope import: transformers is only imported when this runs.
    from transformers import BitsAndBytesConfig

    model_config["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)


# Apply the opt-in at import time when the QUANTIZE_4_BIT flag is set.
if QUANTIZE_4_BIT:
    enable_quantization()
|