# optimization-engineer/models/quantization.py
# Provenance: uploaded by AIguysingstoo ("Upload 9 files", commit e9bb6c3, verified).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, QuantoConfig
from optimum.quanto import quantize, freeze, qint8, qint4, qint2, qfloat8
from enum import Enum
from typing import Tuple, Any, Optional
class QuantizationType(Enum):
    """Supported quantization types.

    The string values are passed directly as the ``weights=`` argument of
    ``transformers.QuantoConfig`` by the loaders in this module, so they
    must match the names that config accepts.
    """
    NONE = "none"      # no quantization (full-precision weights)
    INT8 = "int8"      # 8-bit integer weights
    INT4 = "int4"      # 4-bit integer weights
    INT2 = "int2"      # 2-bit integer weights
    FLOAT8 = "float8"  # 8-bit float weights
class ModelLoader:
    """Handles model loading with different quantization strategies.

    Every loader returns a ``(model, tokenizer)`` pair; tokenizer setup is
    shared via :meth:`_load_tokenizer` instead of being repeated per loader.
    """

    @staticmethod
    def _load_tokenizer(model_name: str) -> Any:
        """Load the tokenizer for *model_name*, guaranteeing a pad token.

        Some causal-LM tokenizers ship without a pad token, which breaks
        padded batching, so we fall back to the EOS token as pad.
        """
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        return tokenizer

    @staticmethod
    def load_standard(model_name: str, device: str) -> Tuple[Any, Any]:
        """Load model without quantization.

        Args:
            model_name: Hugging Face model identifier or local path.
            device: Target device string ("cpu", "cuda", ...).

        Returns:
            Tuple of ``(model, tokenizer)``.
        """
        print(f"Loading {model_name} (standard)")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            # fp16 only makes sense on CUDA; keep fp32 on CPU
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            # for CPU we skip device_map and move the model explicitly below
            device_map=device if device != "cpu" else None
        )
        if device == "cpu":
            model = model.to(device)
        return model, ModelLoader._load_tokenizer(model_name)

    @staticmethod
    def load_quantized_transformers(model_name: str, quant_type: QuantizationType) -> Tuple[Any, Any]:
        """Load model using Transformers QuantoConfig integration.

        Args:
            model_name: Hugging Face model identifier or local path.
            quant_type: Weight quantization to apply. ``NONE`` loads the
                model without a quantization config (previously this would
                have built an invalid ``QuantoConfig(weights="none")``).

        Returns:
            Tuple of ``(model, tokenizer)``.
        """
        print(f"Loading {model_name} with {quant_type.value} quantization (Transformers)")
        # QuantoConfig has no "none" weights value; omit the config entirely.
        quant_config = (
            None if quant_type is QuantizationType.NONE
            else QuantoConfig(weights=quant_type.value)
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype="auto",
            device_map="auto",
            quantization_config=quant_config
        )
        return model, ModelLoader._load_tokenizer(model_name)

    @staticmethod
    def load_quantized_direct(model_name: str, quant_type: QuantizationType, device: str) -> Tuple[Any, Any]:
        """Load model using direct quanto quantization API.

        Args:
            model_name: Hugging Face model identifier or local path.
            quant_type: Weight quantization to apply. ``NONE`` skips
                quantization (previously this raised a bare ``KeyError``).
            device: Target device string ("cpu", "cuda", ...).

        Returns:
            Tuple of ``(model, tokenizer)``.
        """
        print(f"Loading {model_name} with {quant_type.value} quantization (Direct API)")
        # Load base model
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            device_map=device if device != "cpu" else None
        )
        if device == "cpu":
            model = model.to(device)
        # Apply quantization; NONE falls through so all enum members are valid.
        quant_map = {
            QuantizationType.INT8: qint8,
            QuantizationType.INT4: qint4,
            QuantizationType.INT2: qint2,
            QuantizationType.FLOAT8: qfloat8
        }
        if quant_type is not QuantizationType.NONE:
            quantize(model, weights=quant_map[quant_type])
            # freeze materializes the quantized weights (inference-only after this)
            freeze(model)
        return model, ModelLoader._load_tokenizer(model_name)