key_word_Fast_API / services /model_loader.py
ihtesham0345's picture
feat: Add Grammar Correction API v2.1.0
ff08af5
Raw
History Blame Contribute Delete
4.81 kB
import os
import torch
from transformers import pipeline, BitsAndBytesConfig
from dotenv import load_dotenv
from pathlib import Path
# The .env file sits in the parent of the directory that contains this module.
env_path = Path(__file__).resolve().parents[1] / ".env"
load_dotenv(dotenv_path=env_path)

# Settings read from the environment (after loading .env), with defaults.
MODEL_ID = os.getenv("MODEL_ID", "Qwen/Qwen2.5-0.5B-Instruct")
QUANTIZATION = os.getenv("QUANTIZATION", "auto")
# Only the literal "true" (case-insensitive) enables double quantization.
USE_DOUBLE_QUANT = os.getenv("USE_DOUBLE_QUANT", "true").lower() == "true"

# Module-level cache: the loaded pipeline and the id of the model it holds.
_pipe = None
_current_model = None
def _log(msg: str):
    """Print *msg* to stdout, prefixed with the ``[ModelLoader]`` tag."""
    print("[ModelLoader] {}".format(msg))
def _has_gpu() -> bool:
    """Report whether PyTorch says CUDA is available."""
    cuda_available = torch.cuda.is_available()
    return cuda_available
def _gpu_name() -> str:
    """Return the name of CUDA device 0, or the string ``"None"`` without a GPU."""
    return torch.cuda.get_device_name(0) if _has_gpu() else "None"
def _gpu_memory_gb() -> float:
    """Return the total memory of CUDA device 0 in gigabytes (bytes / 1e9).

    Returns:
        The device's total memory in GB, or ``0.0`` when no GPU is available
        or the device properties cannot be read.
    """
    if not _has_gpu():
        return 0.0
    try:
        # The device-properties attribute is ``total_memory``. Reading a
        # non-existent ``total_mem`` raises AttributeError, which would make
        # this function report 0 GB on every GPU.
        return torch.cuda.get_device_properties(0).total_memory / 1e9
    except Exception as exc:
        # Best effort: VRAM detection only steers quantization selection, so
        # a failure here is reported and treated as "no usable VRAM".
        _log(f"[WARN] Could not read GPU memory: {exc}")
        return 0.0
def _select_quantization() -> str:
    """Pick the quantization mode from QUANTIZATION, MODEL_ID and the hardware.

    Any QUANTIZATION value other than "auto" (case-insensitive) is returned
    lower-cased as-is. In "auto" mode the result is one of "4bit", "8bit",
    "cpu_fallback_8bit" or "none".
    """
    requested = QUANTIZATION.lower()
    if requested != "auto":
        # Explicit choice (including "none"): honour it without inspection.
        return requested

    # Auto mode: decide from the model size marker in MODEL_ID and the GPU.
    if "7B" in MODEL_ID:
        gpu_is_big_enough = _has_gpu() and _gpu_memory_gb() >= 5.5
        if not gpu_is_big_enough:
            _log("7B model requested but no GPU with 5.5GB+ VRAM — falling back to 1.5B 8-bit")
            return "cpu_fallback_8bit"
        _log(f"7B model detected, GPU {_gpu_name()} ({_gpu_memory_gb():.1f}GB) — using 4-bit")
        return "4bit"

    if "1.5B" in MODEL_ID:
        if not _has_gpu():
            _log("1.5B model detected, CPU only — using bfloat16")
            return "none"
        _log("1.5B model detected, GPU available — using 8-bit")
        return "8bit"

    # Any other model id is loaded without quantization.
    return "none"
def _build_model_kwargs(quant_mode: str) -> dict:
    """Build the keyword arguments passed to ``pipeline`` for *quant_mode*.

    Args:
        quant_mode: "4bit", "8bit" or "cpu_fallback_8bit" select a
            BitsAndBytesConfig; any other value loads in bfloat16.

    Returns:
        A dict that always contains ``trust_remote_code`` and ``device_map``,
        plus either ``quantization_config`` or ``torch_dtype``.
    """
    # Every mode shares these two settings.
    kwargs = {
        "trust_remote_code": True,
        "device_map": "auto",
    }

    if quant_mode == "4bit":
        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=USE_DOUBLE_QUANT,
            bnb_4bit_quant_type="nf4",
        )
        # Report the double-quant setting actually in effect rather than
        # always claiming it is on.
        double_quant_label = "double quant" if USE_DOUBLE_QUANT else "no double quant"
        _log(f"[OK] 4-bit quantization enabled (NF4, {double_quant_label})")
    elif quant_mode in ("8bit", "cpu_fallback_8bit"):
        # Both modes use the same 8-bit config; only the log line differs.
        # NOTE(review): bitsandbytes 8-bit loading has typically needed a CUDA
        # device — confirm "cpu_fallback_8bit" really loads on CPU-only hosts.
        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_8bit=True,
        )
        if quant_mode == "8bit":
            _log("[OK] 8-bit quantization enabled")
        else:
            _log("[OK] CPU fallback 8-bit for 1.5B model")
    else:
        # "none" and any unrecognised mode: plain bfloat16 weights.
        kwargs["torch_dtype"] = torch.bfloat16
        _log(f"[OK] Loading {MODEL_ID} in bfloat16 (CPU-friendly)")

    return kwargs
def get_pipe():
    """Return the cached text-generation pipeline, loading it on first use.

    The model id comes from MODEL_ID, except that the "cpu_fallback_8bit"
    mode swaps in "Qwen/Qwen2.5-1.5B-Instruct". If loading fails with an
    ImportError that mentions bitsandbytes, one retry is made in plain
    bfloat16 without quantization.

    Returns:
        The loaded pipeline, or ``None`` when every load attempt failed.
        A failed load is not cached, so the next call tries again.
    """
    global _pipe, _current_model
    if _pipe is not None:
        return _pipe

    actual_model_id = MODEL_ID
    quant_mode = _select_quantization()

    # Handle CPU fallback for 7B → 1.5B
    if quant_mode == "cpu_fallback_8bit":
        actual_model_id = "Qwen/Qwen2.5-1.5B-Instruct"
        _log(f"[FALLBACK] loading {actual_model_id} instead of {MODEL_ID}")

    _log(f"Loading {actual_model_id} (quantization: {quant_mode})")
    _log(f" Hardware: GPU={_gpu_name()}, VRAM={_gpu_memory_gb():.1f}GB, CUDA={_has_gpu()}")

    try:
        kwargs = _build_model_kwargs(quant_mode)
        _pipe = pipeline(
            "text-generation",
            model=actual_model_id,
            **kwargs
        )
        _current_model = actual_model_id
        _log("[DONE] Model loaded successfully!")
    except ImportError as e:
        if "bitsandbytes" in str(e):
            _log("[ERROR] bitsandbytes not installed. Falling back to CPU bfloat16.")
            # The retry runs inside an except handler, so the outer
            # ``except Exception`` cannot catch its errors. Guard it here so
            # this function still returns None on failure instead of raising.
            try:
                _pipe = pipeline(
                    "text-generation",
                    model=actual_model_id,
                    torch_dtype=torch.bfloat16,
                    device_map="auto",
                    trust_remote_code=True,
                )
                _current_model = actual_model_id
                _log("[DONE] Model loaded with CPU fallback")
            except Exception as fallback_error:
                _log(f"[ERROR] Fallback model load failed: {fallback_error}")
                _pipe = None
        else:
            _log(f"[ERROR] Model load failed: {e}")
            _pipe = None
    except Exception as e:
        _log(f"❌ Model load failed: {e}")
        _pipe = None

    return _pipe
def generate_text(messages, temperature=0.3, max_new_tokens=2000):
    """Generate a reply for *messages* with the shared pipeline.

    Args:
        messages: Input passed straight to the pipeline. The return indexing
            expects the pipeline to echo a list of message dicts — presumably
            chat-style role/content messages; confirm against callers.
        temperature: Sampling temperature. A value of 0 or below selects
            greedy decoding instead of sampling.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The ``content`` of the last message in the first output, or ``None``
        when the model could not be loaded.
    """
    pipe = get_pipe()
    if pipe is None:
        return None

    generation_kwargs = {"max_new_tokens": max_new_tokens}
    if temperature is not None and temperature <= 0:
        # Sampling requires a strictly positive temperature, so a request for
        # temperature <= 0 is served deterministically with greedy decoding.
        generation_kwargs["do_sample"] = False
    else:
        generation_kwargs.update(
            do_sample=True,
            temperature=temperature,
            top_p=0.9,
        )

    outputs = pipe(messages, **generation_kwargs)
    return outputs[0]["generated_text"][-1]["content"]