| """Model quantization utilities for reduced VRAM usage. | |
| Phase 8: Supports 4-bit and 8-bit quantization via bitsandbytes. | |
| Halves or quarters VRAM requirements for MedGemma models. | |
| Set ``MODEL_QUANTIZATION_ENABLED=true`` and ``MODEL_QUANTIZATION_BITS=4`` | |
| in config or .env to enable. | |
| """ | |
| import logging | |
| from typing import Any, Dict, Optional | |
| from app.config import settings | |
| logger = logging.getLogger(__name__) | |
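
# The fields referenced on ``settings`` below are assumed to live in
# ``app.config`` roughly as sketched here (illustrative only; field names
# match the attribute accesses in this module, defaults are guesses):
#
#   class Settings(BaseSettings):
#       model_quantization_enabled: bool = False
#       model_quantization_bits: int = 4
#       medgemma_model: str = "google/medgemma-4b-it"
#       enable_gpu: bool = True
#       hf_token: Optional[str] = None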


def get_quantization_config() -> Optional[Any]:
    """Build a BitsAndBytesConfig if quantization is enabled.

    Returns None if quantization is disabled or bitsandbytes is not installed.
    """
    if not settings.model_quantization_enabled:
        return None

    try:
        from transformers import BitsAndBytesConfig
    except ImportError:
        logger.warning("transformers BitsAndBytesConfig not available; skipping quantization")
        return None

    try:
        import bitsandbytes  # noqa: F401
    except ImportError:
        logger.warning("bitsandbytes not installed; quantization disabled")
        return None

    bits = settings.model_quantization_bits
    if bits == 4:
        config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype="bfloat16",
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )
        logger.info("Quantization: 4-bit NF4 with double quantization enabled")
    elif bits == 8:
        config = BitsAndBytesConfig(
            load_in_8bit=True,
        )
        logger.info("Quantization: 8-bit enabled")
    else:
        logger.warning("Unsupported quantization bits=%d; use 4 or 8", bits)
        return None

    return config


def get_model_load_kwargs() -> Dict[str, Any]:
    """Get kwargs to pass to AutoModelForCausalLM.from_pretrained().

    Handles quantization config, device mapping, and dtype selection.
    """
    import torch

    kwargs: Dict[str, Any] = {
        "low_cpu_mem_usage": True,
        "token": settings.hf_token if settings.hf_token else None,
    }

    quant_config = get_quantization_config()
    if quant_config is not None:
        kwargs["quantization_config"] = quant_config
        kwargs["device_map"] = "auto"
        logger.info("Model will load with quantization + device_map=auto")
    elif settings.enable_gpu and torch.cuda.is_available():
        kwargs["torch_dtype"] = torch.bfloat16
        kwargs["device_map"] = "auto"
    else:
        kwargs["torch_dtype"] = torch.float32

    return kwargs
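

# Usage sketch (illustrative, not part of this module's behavior): the kwargs
# above are intended to be splatted into ``from_pretrained``; the checkpoint
# id comes from whatever ``settings.medgemma_model`` is configured to.
#
#   from transformers import AutoModelForCausalLM
#
#   model = AutoModelForCausalLM.from_pretrained(
#       settings.medgemma_model,
#       **get_model_load_kwargs(),
#   )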


def estimate_vram_usage() -> Dict[str, Any]:
    """Estimate VRAM usage for the current quantization setting."""
    import torch

    model_name = settings.medgemma_model
    param_estimate_b = 4.0  # ~4B params for medgemma-4b
    if "27b" in model_name.lower():
        param_estimate_b = 27.0
    elif "7b" in model_name.lower():
        param_estimate_b = 7.0
    elif "2b" in model_name.lower():
        param_estimate_b = 2.0

    bits = settings.model_quantization_bits if settings.model_quantization_enabled else 16
    bytes_per_param = bits / 8
    estimated_gb = (param_estimate_b * 1e9 * bytes_per_param) / (1024 ** 3)

    gpu_available = torch.cuda.is_available()
    gpu_total_gb = 0.0
    if gpu_available:
        gpu_total_gb = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
    return {
        "model": model_name,
        "estimated_params_billion": param_estimate_b,
        "quantization_bits": bits,
        "estimated_vram_gb": round(estimated_gb, 2),
        "gpu_available": gpu_available,
        "gpu_total_vram_gb": round(gpu_total_gb, 2),
        "fits_in_vram": gpu_total_gb >= estimated_gb if gpu_available else False,
    }
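

# Quick manual check (sketch): running the module directly logs the current
# estimate. Assumes ``app.config`` is importable from the working directory;
# in normal use, logging configuration is left to the caller.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    logger.info("VRAM estimate: %s", estimate_vram_usage())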