File size: 955 Bytes
21bfda5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# Common configuration for all models, including device and dtype settings.

import os
import torch

# Hugging Face access token read from HF_TOKEN; None when the variable is unset.
TOKEN = os.getenv("HF_TOKEN")

_TRUTHY_ENV_VALUES = frozenset({"1", "true", "yes", "on"})


def env_flag(name, default=False):
    """Return whether environment variable *name* is set to a truthy value.

    The value is compared case-insensitively after stripping surrounding
    whitespace against "1", "true", "yes" and "on"; any other value
    (including "" and "false") yields False. An unset variable yields
    *default*.
    """
    value = os.getenv(name)
    if value is None:
        return default
    return value.strip().lower() in _TRUTHY_ENV_VALUES


# Opt-in 4-bit quantization. Common truthy spellings such as "1" are accepted
# so they are not silently treated as "disabled".
QUANTIZE_4_BIT = env_flag("QUANTIZE_4_BIT")

# Only CUDA is probed, so torch_device is always either "cuda" or "cpu".
# NOTE(review): the "mps" entry in the dtype check below can therefore never
# match; Apple-silicon machines currently run on CPU in float32. Supporting MPS
# would need explicit detection (torch.backends.mps.is_available()) — confirm
# bfloat16 support on the target macOS/PyTorch versions before enabling.
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
# bfloat16 on accelerators, full float32 precision on CPU.
torch_dtype = torch.bfloat16 if torch_device in ["cuda", "mps"] else torch.float32

print(f"Using {torch_device} with dtype {torch_dtype}...")

# Keyword arguments for model loading (presumably passed to `from_pretrained`;
# confirm in callers). Pins the model to the detected device. When
# QUANTIZE_4_BIT is set, enable_quantization() adds "quantization_config".
model_config = {
    "torch_dtype": torch_dtype,
    "device_map": torch_device,
    "token": TOKEN,
}

# Keyword arguments for tokenizer loading; only the access token is shared.
tokenizer_config = {
    "token": TOKEN,
}

# Keyword arguments for pipelines: same dtype as models, but "auto" device
# placement instead of pinning to torch_device.
pipeline_config = {
    "torch_dtype": torch_dtype,
    "device_map": "auto",
}


def enable_quantization():
    """Add a 4-bit bitsandbytes quantization config to ``model_config``.

    Mutates the module-level ``model_config`` dict in place by setting its
    ``"quantization_config"`` key; ``pipeline_config`` and ``tokenizer_config``
    are not touched. ``transformers`` is imported inside the function, so it is
    only imported when quantization is actually enabled.
    """
    print("Enabling 4-bit quantization for compatible models...")
    from transformers import BitsAndBytesConfig

    # Run the 4-bit matmuls in the same dtype the model is loaded in. Left
    # unset, bnb_4bit_compute_dtype defaults to float32, which is slower and
    # triggers a bitsandbytes warning when the inputs are bfloat16.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch_dtype,
    )
    # NOTE(review): bitsandbytes 4-bit loading has historically required CUDA;
    # confirm behaviour when torch_device is "cpu".
    model_config["quantization_config"] = quantization_config


# Import-time side effect: when the QUANTIZE_4_BIT flag is set, add a 4-bit
# quantization config to model_config. This must stay below the definitions
# of model_config and enable_quantization.
if QUANTIZE_4_BIT:
    enable_quantization()