import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, QuantoConfig
from optimum.quanto import quantize, freeze, qint8, qint4, qint2, qfloat8
from enum import Enum
from typing import Tuple, Any


class QuantizationType(Enum):
    """Supported quantization types."""
    NONE = "none"
    INT8 = "int8"
    INT4 = "int4"
    INT2 = "int2"
    FLOAT8 = "float8"


class ModelLoader:
    """Handles model loading with different quantization strategies."""

    @staticmethod
    def load_standard(model_name: str, device: str) -> Tuple[Any, Any]:
        """Load model without quantization."""
        print(f"Loading {model_name} (standard)")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            device_map=device if device != "cpu" else None,
        )
        if device == "cpu":
            model = model.to(device)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        return model, tokenizer

    @staticmethod
    def load_quantized_transformers(model_name: str, quant_type: QuantizationType) -> Tuple[Any, Any]:
        """Load model using the Transformers QuantoConfig integration."""
        print(f"Loading {model_name} with {quant_type.value} quantization (Transformers)")
        quant_config = QuantoConfig(weights=quant_type.value)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype="auto",
            device_map="auto",
            quantization_config=quant_config,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        return model, tokenizer

    @staticmethod
    def load_quantized_direct(model_name: str, quant_type: QuantizationType, device: str) -> Tuple[Any, Any]:
        """Load model using the direct optimum.quanto quantization API."""
        print(f"Loading {model_name} with {quant_type.value} quantization (Direct API)")
        # Load the base model in full/half precision first
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            device_map=device if device != "cpu" else None,
        )
        if device == "cpu":
            model = model.to(device)
        # Map our enum onto quanto's weight dtypes. NONE is intentionally
        # absent: callers should use load_standard for unquantized loading.
        quant_map = {
            QuantizationType.INT8: qint8,
            QuantizationType.INT4: qint4,
            QuantizationType.INT2: qint2,
            QuantizationType.FLOAT8: qfloat8,
        }
        # quantize() swaps eligible modules for quantized equivalents;
        # freeze() then converts the float weights to their quantized form
        quantize(model, weights=quant_map[quant_type])
        freeze(model)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        return model, tokenizer
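

# Example usage: a minimal sketch of the direct-API path. The model name and
# generation prompt below are placeholder assumptions, not part of the original
# module; any causal LM on the Hugging Face Hub should work the same way.
if __name__ == "__main__":
    model_name = "gpt2"  # hypothetical checkpoint, substitute your own
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load with int8 weights via the optimum.quanto direct API
    model, tokenizer = ModelLoader.load_quantized_direct(
        model_name, QuantizationType.INT8, device
    )

    # Run a short generation to confirm the quantized model is usable
    inputs = tokenizer("Quantization reduces memory by", return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=20)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))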