quant-eval-agent-arena / model_loader.py
pbhappliedsystems's picture
Upload 3 files
813f493 verified
# model_loader.py
# PBH Applied Systems — GGUF model loading via llama-cpp-python.
# ZeroGPU-safe: models are loaded inside @spaces.GPU decorated scope.
#
# The cu121 llama-cpp-python wheel links libllama.so against CUDA 12 SONAMEs
# (libcudart.so.12, libcublasLt.so.12). ZeroGPU runs CUDA 13 system-wide,
# so these files are not present in the system library paths.
#
# Fix: requirements.txt installs nvidia-cuda-runtime-cu12 and nvidia-cublas-cu12
# as Python packages that ship libcudart.so.12 and libcublas*.so.12 inside
# site-packages. model_loader.py finds them there and preloads them with
# ctypes.RTLD_GLOBAL before `from llama_cpp import Llama` fires.
import os
import glob
import ctypes
import logging
import site
import sysconfig
from pathlib import Path
from eval_data import MODELS, pair_is_feasible
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Jinja2 patch — register {% generation %} as a known extension.
# ---------------------------------------------------------------------------
def _patch_jinja2_generation_tag() -> None:
    """Make every new ``jinja2.Environment`` accept ``{% generation %}``.

    ``jinja2.Environment.__init__`` is wrapped so that a small extension
    handling ``{% generation %} ... {% endgeneration %}`` is appended to the
    ``extensions`` keyword argument, unless an extension *class* that already
    declares the ``generation`` tag was passed in. The extension wraps the
    enclosed statements in a ``Scope`` node, so the tag itself produces no
    output of its own.

    The patch is best-effort: any failure (including jinja2 not being
    importable) is logged as a warning and swallowed. Calling this function
    again after a successful patch is a no-op, so reloading the module does
    not stack wrappers on top of each other.
    """
    try:
        import functools

        import jinja2
        from jinja2 import nodes
        from jinja2.ext import Extension

        _orig_init = jinja2.Environment.__init__
        if getattr(_orig_init, "_pbh_generation_patch", False):
            # Already wrapped by an earlier call; wrapping again would only
            # add a redundant layer.
            logger.debug("Jinja2 {% generation %} extension already registered.")
            return

        class _GenerationExtension(Extension):
            """Parse ``{% generation %}...{% endgeneration %}`` as a plain scope."""

            tags = frozenset(["generation"])

            def parse(self, parser):
                # The current token is the ``generation`` name itself; consume
                # it and keep its line number for the node we return.
                lineno = next(parser.stream).lineno
                body = parser.parse_statements(
                    ["name:endgeneration"], drop_needle=True
                )
                return nodes.Scope(body, lineno=lineno)

        @functools.wraps(_orig_init)  # keep the real signature/doc visible
        def _patched_init(env_self, *args, **kwargs):
            # NOTE(review): only the ``extensions`` keyword is inspected; an
            # extensions list passed positionally is not handled here.
            exts = list(kwargs.get("extensions", []))
            declared_tags = {
                tag
                for ext in exts
                # Extensions may also be given as import strings; those are
                # left for jinja2 to resolve and are not inspected.
                if isinstance(ext, type) and hasattr(ext, "tags")
                for tag in ext.tags
            }
            if "generation" not in declared_tags:
                exts.append(_GenerationExtension)
            kwargs["extensions"] = exts
            _orig_init(env_self, *args, **kwargs)

        _patched_init._pbh_generation_patch = True
        jinja2.Environment.__init__ = _patched_init
        logger.info("Jinja2 {% generation %} extension registered.")
    except Exception as exc:
        # Deliberately broad: this runs at import time and must never prevent
        # the module from loading.
        logger.warning("Jinja2 generation tag patch failed (non-fatal): %s", exc)
# Installed at import time, i.e. before load_model() performs its lazy
# `from llama_cpp import Llama`.
_patch_jinja2_generation_tag()
# ---------------------------------------------------------------------------
# spaces import
# ---------------------------------------------------------------------------
try:
    import spaces

    ZEROGPU_AVAILABLE = True
except ImportError:
    ZEROGPU_AVAILABLE = False

    class spaces:  # noqa: N801 - deliberately shadows the missing module name
        """Stand-in used when the ``spaces`` package is not installed.

        Only ``GPU`` is provided, as a decorator that leaves the wrapped
        function untouched.
        """

        @staticmethod
        def GPU(fn=None, duration=None, **_ignored_options):
            """No-op replacement for the ``spaces.GPU`` decorator.

            Supports both ``@spaces.GPU`` and ``@spaces.GPU(duration=...)``.
            ``duration`` and any further keyword options are accepted and
            ignored, so code written against the real decorator (which takes
            more options than ``duration``) does not raise ``TypeError`` when
            run without the package.

            Returns:
                ``fn`` itself when used bare, otherwise a decorator that
                returns its argument unchanged.
            """
            if fn is not None:
                return fn

            def decorator(f):
                return f

            return decorator
from huggingface_hub import hf_hub_download
# Directory passed to hf_hub_download() as cache_dir. It lives under $HF_HOME
# when that variable is set, otherwise under /tmp/hf_cache, and is created
# eagerly at import time.
CACHE_DIR = Path(os.environ.get("HF_HOME", "/tmp/hf_cache")) / "pbh_gguf"
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Defaults forwarded to the Llama constructor in load_model().
# DEFAULT_N_CTX is also the upper bound applied by get_model_n_ctx().
DEFAULT_N_CTX = 8192
# Passed as n_gpu_layers; per the llama-cpp-python docs -1 requests offloading
# all layers to the GPU.
DEFAULT_N_GPU_LAYERS = -1
# Passed as n_threads.
DEFAULT_N_THREADS = 4

# Explicit chat_format per model key. Keys that are absent here make
# load_model() pass chat_format=None (logged as "auto").
CHAT_FORMAT_OVERRIDES = {
    "ministral-14b-instruct": "mistral-instruct",
    "ministral-14b-reasoning": "mistral-instruct",
    "phi4-reasoning-plus": "chatml",
    "mistral-nemo": "mistral-instruct",
}

# Llama instances created by load_model(), keyed by model key.
_model_cache: dict = {}
# ---------------------------------------------------------------------------
# CUDA runtime preload for llama-cpp-python cu121 wheel
# ---------------------------------------------------------------------------
def _iter_cuda_library_dirs() -> list:
    """Return candidate CUDA library directories visible to this process.

    The result is ordered by priority and contains only directories that
    exist, without duplicates:

    1. ``nvidia/cuda_runtime/lib`` and ``nvidia/cublas/lib`` under each
       site-packages root (system site-packages, user site-packages, then
       sysconfig's ``purelib`` and ``platlib``),
    2. ``lib64`` and ``lib`` under the first of ``CUDA_HOME``, ``CUDA_PATH``
       or ``CUDADIR`` that is set,
    3. a fixed list of system locations (glob matches in sorted order).

    The order is deterministic, so repeated calls in the same environment
    yield the same list.
    """
    found = []

    def add(path):
        if path and os.path.isdir(path) and path not in found:
            found.append(path)

    # 1. Python-packaged NVIDIA CUDA 12 libraries from requirements.txt.
    site_roots = []
    # Not every ``site`` module provides these helpers (some virtualenv
    # variants replace it), so look them up defensively.
    get_site_packages = getattr(site, "getsitepackages", None)
    if get_site_packages is not None:
        site_roots.extend(get_site_packages())
    get_user_site = getattr(site, "getusersitepackages", None)
    if get_user_site is not None:
        site_roots.append(get_user_site())
    install_paths = sysconfig.get_paths()
    # Packages that ship shared objects install into platlib, which can differ
    # from purelib on some layouts, so consider both.
    site_roots.extend(install_paths.get(key) for key in ("purelib", "platlib"))

    # dict.fromkeys() removes duplicates while keeping first-seen order.
    for root in dict.fromkeys(r for r in site_roots if r):
        add(os.path.join(root, "nvidia", "cuda_runtime", "lib"))
        add(os.path.join(root, "nvidia", "cublas", "lib"))

    # 2. ZeroGPU/runtime CUDA locations.
    cuda_home = (
        os.environ.get("CUDA_HOME")
        or os.environ.get("CUDA_PATH")
        or os.environ.get("CUDADIR")
    )
    if cuda_home:
        add(os.path.join(cuda_home, "lib64"))
        add(os.path.join(cuda_home, "lib"))

    # 3. Common system paths, including ZeroGPU's mounted CUDA image.
    for pattern in (
        "/cuda-image/usr/local/cuda*/lib64",
        "/usr/local/cuda*/lib64",
        "/usr/lib/x86_64-linux-gnu",
        "/usr/lib64",
    ):
        # glob() makes no ordering promise; sort for a stable priority.
        for match in sorted(glob.glob(pattern)):
            add(match)

    return found
def _prepend_ld_library_path(dirs: list) -> None:
    """Put *dirs* in front of the current ``LD_LIBRARY_PATH`` entries.

    Empty entries are dropped and duplicates are removed, keeping the first
    occurrence. The merged value is written back to
    ``os.environ["LD_LIBRARY_PATH"]`` joined with ``":"``.
    """
    current = os.environ.get("LD_LIBRARY_PATH", "").split(":")
    # dict.fromkeys() keeps insertion order, so the first occurrence wins.
    ordered = dict.fromkeys(entry for entry in dirs + current if entry)
    os.environ["LD_LIBRARY_PATH"] = ":".join(ordered)
def _preload_shared_library(filename: str, dirs: list, required: bool = True) -> bool:
    """Load a shared library by absolute path with RTLD_GLOBAL.

    Each directory in *dirs* is tried in order; the first candidate file that
    loads successfully ends the search.

    Args:
        filename: File name of the library, e.g. ``"libcudart.so.12"``.
        dirs: Directories to search, highest priority first.
        required: Only affects logging. When true, a final warning is emitted
            if nothing could be loaded. No exception is raised either way.

    Returns:
        True if a candidate was loaded, False otherwise.
    """
    load_failures = 0
    for lib_dir in dirs:
        candidate = os.path.join(lib_dir, filename)
        if not os.path.exists(candidate):
            continue
        try:
            ctypes.CDLL(candidate, mode=ctypes.RTLD_GLOBAL)
        except OSError as e:
            # The file is there but the loader rejected it; keep looking in
            # the remaining directories.
            load_failures += 1
            logger.warning("Failed to preload %s from %s: %s", filename, candidate, e)
            continue
        logger.info("Preloaded %s from %s", filename, candidate)
        return True

    if required:
        if load_failures:
            # Distinguish this case from a missing file so the log does not
            # send the reader looking for an installation problem.
            logger.warning(
                "Required CUDA library %s was found but could not be loaded (searched: %s)",
                filename,
                dirs,
            )
        else:
            logger.warning(
                "Required CUDA library not found: %s (searched: %s)", filename, dirs
            )
    return False
def _ensure_cuda_compat() -> None:
    """
    Prepare the process so the cu121 llama-cpp-python wheel can load libllama.so.

    The wheel is linked against CUDA 12 SONAMEs (e.g. libcudart.so.12). This
    function collects the candidate library directories, prepends them to
    LD_LIBRARY_PATH, and preloads libcudart, libcublasLt and libcublas (in
    that order). It never raises for a missing library; a warning is logged
    when libcudart.so.12 could not be preloaded.
    """
    search_dirs = _iter_cuda_library_dirs()
    logger.info("CUDA library search dirs: %s", search_dirs)
    _prepend_ld_library_path(search_dirs)

    # Preload with RTLD_GLOBAL so the symbols are already available when the
    # dynamic linker resolves libllama.so's DT_NEEDED entries.
    cudart_ok = _preload_shared_library("libcudart.so.12", search_dirs, required=True)
    for optional_lib in ("libcublasLt.so.12", "libcublas.so.12"):
        _preload_shared_library(optional_lib, search_dirs, required=False)

    if cudart_ok:
        return
    logger.warning(
        "libcudart.so.12 was not preloaded. llama_cpp import may fail. "
        "Check that nvidia-cuda-runtime-cu12 is installed."
    )
# ---------------------------------------------------------------------------
# GGUF download
# ---------------------------------------------------------------------------
def _download_gguf(model_key: str) -> str:
    """Fetch the GGUF file for *model_key* and return its local path.

    The repository and file name come from ``MODELS[model_key]`` (entries
    ``hf_repo`` and ``hf_filename``) and are handed to ``hf_hub_download``
    with ``CACHE_DIR`` as the cache directory. An unknown key or a missing
    entry propagates the lookup error from ``MODELS``.
    """
    spec = MODELS[model_key]
    logger.info(f"Downloading {spec['hf_filename']} from {spec['hf_repo']}...")

    repo = spec["hf_repo"]
    gguf_name = spec["hf_filename"]
    downloaded = hf_hub_download(
        repo_id=repo,
        filename=gguf_name,
        cache_dir=str(CACHE_DIR),
    )

    logger.info(f"GGUF at: {downloaded}")
    return downloaded
# ---------------------------------------------------------------------------
# Model loading — lazy llama_cpp import inside @spaces.GPU scope
# ---------------------------------------------------------------------------
# Set once _ensure_cuda_compat() has run in this process; the preload only has
# to happen before the first llama_cpp import.
_cuda_compat_done = False

# n_ctx each cached model was created with, keyed like _model_cache. Used only
# to warn when a cache hit is served for a different requested n_ctx.
_model_cache_n_ctx: dict = {}


def load_model(model_key: str, n_ctx: int = DEFAULT_N_CTX):
    """
    Load a GGUF model by key. Returns cached instance if already loaded.
    Must be called within a @spaces.GPU decorated function on ZeroGPU.

    Args:
        model_key: Key into ``MODELS``.
        n_ctx: Context size passed to the ``Llama`` constructor. It only takes
            effect when the model is actually created: a cache hit returns the
            existing instance unchanged, and a warning is logged if that
            instance was created with a different ``n_ctx``.

    Returns:
        The ``llama_cpp.Llama`` instance for *model_key*.

    Raises:
        ValueError: If *model_key* is neither cached nor present in ``MODELS``.
    """
    global _cuda_compat_done

    if model_key in _model_cache:
        created_with = _model_cache_n_ctx.get(model_key)
        if created_with is not None and created_with != n_ctx:
            # The cache is keyed by model only; make the mismatch visible
            # instead of silently handing back a different context size.
            logger.warning(
                "Cache hit for %s: requested n_ctx=%s but cached instance "
                "was created with n_ctx=%s; returning the cached instance.",
                model_key,
                n_ctx,
                created_with,
            )
        logger.info("Cache hit: %s", model_key)
        return _model_cache[model_key]

    if model_key not in MODELS:
        raise ValueError(f"Unknown model key: {model_key}")

    # Preload CUDA 12 runtime libs before first llama_cpp import
    if not _cuda_compat_done:
        _ensure_cuda_compat()
        _cuda_compat_done = True

    # Lazy import — runs inside @spaces.GPU where GPU is allocated
    from llama_cpp import Llama

    m = MODELS[model_key]
    logger.info("Loading %s (n_ctx=%s)...", m["display_name"], n_ctx)
    gguf_path = _download_gguf(model_key)

    # None lets llama-cpp-python pick the chat format itself.
    chat_format = CHAT_FORMAT_OVERRIDES.get(model_key)
    llm = Llama(
        model_path=gguf_path,
        n_ctx=n_ctx,
        n_gpu_layers=DEFAULT_N_GPU_LAYERS,
        n_threads=DEFAULT_N_THREADS,
        verbose=False,
        flash_attn=True,
        chat_format=chat_format,
    )

    _model_cache[model_key] = llm
    _model_cache_n_ctx[model_key] = n_ctx
    logger.info(
        "Loaded and cached: %s (chat_format=%s)",
        model_key,
        chat_format or "auto",
    )
    return llm
def validate_pair(model_key_a: str, model_key_b: str) -> tuple:
    """Return what ``pair_is_feasible`` reports for the two model keys.

    This is a thin pass-through; the result is returned unchanged.
    """
    verdict = pair_is_feasible(model_key_a, model_key_b)
    return verdict
def get_model_n_ctx(model_key: str) -> int:
    """Return the context size to use for *model_key*.

    ``"qwen2.5-14b-1m"`` is pinned to 8192. For every other key the result is
    the model's ``context_window`` entry from ``MODELS`` capped at
    ``DEFAULT_N_CTX``; keys that are unknown or have no ``context_window``
    yield ``DEFAULT_N_CTX``.
    """
    if model_key == "qwen2.5-14b-1m":
        return 8192

    entry = MODELS.get(model_key, {})
    window = entry.get("context_window", DEFAULT_N_CTX)
    return min(DEFAULT_N_CTX, window)