# elysium/backend/model_loader.py
# (Hugging Face Hub page residue: uploaded by pmrinal2005 with huggingface_hub,
#  commit 0cebb30, verified, 2.92 kB.)
"""Loads the fine-tuned GGUF via llama-cpp-python.
Pattern follows the HF ZeroGPU + small-talk reference:
- hf_hub_download files at module import (warm cache)
- instantiate Llama inside @spaces.GPU function on each request
Robustness:
- Uses HF_TOKEN if available (avoids 401 on rate-limited / gated lookups)
- Treats MMPROJ_FILE as optional. If unset OR download fails, vision is
disabled gracefully (text-only chat still works).
- Fails loudly with a helpful message if the main GGUF cannot be fetched.
"""
import os
import traceback
from huggingface_hub import hf_hub_download
from .config import MODEL_REPO, GGUF_FILE, MMPROJ_FILE, HF_TOKEN
def _download(repo_id: str, filename: str):
    """Fetch one file from the Hub via ``hf_hub_download``.

    Args:
        repo_id: Hub repository identifier, forwarded as ``repo_id``.
        filename: Name of the file inside that repository, forwarded as
            ``filename``.

    Returns:
        Whatever ``hf_hub_download`` returns (the path of the downloaded
        file).

    Raises:
        Any exception raised by ``hf_hub_download`` propagates unchanged.
    """
    # Only send a token when one is configured; an empty/None HF_TOKEN means
    # the call is made without a ``token`` argument at all.
    if not HF_TOKEN:
        return hf_hub_download(repo_id=repo_id, filename=filename)
    return hf_hub_download(repo_id=repo_id, filename=filename, token=HF_TOKEN)
# Fetch the main GGUF at import time. MODEL_PATH is only bound on success;
# any failure is logged with a configuration hint and then re-raised.
print(f"[model_loader] downloading {MODEL_REPO}/{GGUF_FILE} …")
try:
    MODEL_PATH = _download(MODEL_REPO, GGUF_FILE)
    print(f"[model_loader] main model ready: {MODEL_PATH}")
except Exception as download_error:
    # Re-raise on purpose: the Space should fail fast with a clear message
    # instead of coming up without a model. The original traceback is kept.
    print(
        f"[model_loader] FAILED to download {MODEL_REPO}/{GGUF_FILE}: {download_error}",
        "[model_loader] Check that ELYSIUM_MODEL_REPO and ELYSIUM_GGUF_FILE "
        "point at a real public file, and (for private repos) that HF_TOKEN is set.",
        sep="\n",
    )
    raise
# ─── mmproj (vision projector) is OPTIONAL ──────────────────────────────────
# MMPROJ_PATH ends up as the downloaded projector path, or None when
# MMPROJ_FILE is falsy or the download raises. None means "vision disabled".
MMPROJ_PATH = None
if MMPROJ_FILE:
    print(f"[model_loader] attempting mmproj {MODEL_REPO}/{MMPROJ_FILE} …")
    try:
        MMPROJ_PATH = _download(MODEL_REPO, MMPROJ_FILE)
        print(f"[model_loader] mmproj ready: {MMPROJ_PATH}")
    except Exception as e:
        # Deliberate best-effort: a missing projector must not take down
        # text-only chat, so the error is logged and swallowed.
        print(f"[model_loader] mmproj unavailable ({e}) — vision disabled")
        MMPROJ_PATH = None
else:
    print("[model_loader] ELYSIUM_MMPROJ_FILE not set — vision disabled (text-only mode)")
def make_llm():
    """Build and return a new ``Llama`` instance for the downloaded model.

    Intended to be called from within a GPU context (see the module
    docstring). ``llama_cpp`` is imported lazily inside the function rather
    than at module import.

    When ``MMPROJ_PATH`` is set, a ``MiniCPMv26ChatHandler`` is created from
    it to enable vision. If that import or construction fails, the error and
    traceback are printed and the model is built without a chat handler.

    Returns:
        A ``Llama`` object loaded from ``MODEL_PATH``.
    """
    from llama_cpp import Llama

    vision_handler = None
    if MMPROJ_PATH:
        try:
            from llama_cpp.llama_chat_format import MiniCPMv26ChatHandler

            vision_handler = MiniCPMv26ChatHandler(
                clip_model_path=MMPROJ_PATH,
                verbose=False,
            )
        except Exception as exc:
            # Best-effort: log loudly, then fall back to text-only.
            print(f"[model_loader] vision chat handler failed: {exc}")
            traceback.print_exc()
            vision_handler = None

    llama_options = {
        "model_path": MODEL_PATH,
        "chat_handler": vision_handler,
        # NOTE(review): -1 presumably requests offloading every layer to the
        # GPU; confirm against the llama-cpp-python docs.
        "n_gpu_layers": -1,
        "n_ctx": 8192,
        "flash_attn": True,
        "verbose": False,
    }
    return Llama(**llama_options)