rag-chatbot / components /llm_handler.py
Mobiworks's picture
Sync from GitHub via hub-sync
4cf1913 verified
Raw
History Blame Contribute Delete
4.81 kB
"""
llm_handler.py
--------------
Loads and runs the open-source LLM (Phi-2 GGUF) via llama-cpp-python.
Step 3 Enhancement:
- Added generate_stream() which yields tokens one by one for streaming UI.
- generate() kept unchanged — still used by non-streaming code paths.
Design decisions
----------------
* GGUF 4-bit quantisation (Q4_K_M) keeps RAM usage low.
* The model is first looked up at LLM_CACHE_DIR/LLM_MODEL_FILE; if absent it is
  downloaded via the HuggingFace Hub global cache (~/.cache/huggingface/hub/),
  which persists between Space restarts on code-only pushes — no re-download.
* GPU layers default to 0 (CPU-only) but can be set via LLM_N_GPU_LAYERS env var.
"""
import logging
import os
from pathlib import Path
from typing import Generator
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from app.config import (
LLM_CACHE_DIR,
LLM_CONTEXT_LEN,
LLM_MAX_TOKENS,
LLM_MODEL_FILE,
LLM_MODEL_REPO,
LLM_N_GPU_LAYERS,
LLM_N_THREADS,
LLM_TEMPERATURE,
)
# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
class LLMHandler:
    """
    Wraps llama-cpp-python to provide generate() and generate_stream() interfaces.

    The model is lazily loaded on the first call to avoid blocking UI startup.
    Both public methods build their completion settings from _completion_kwargs(),
    so the streaming and non-streaming paths always use the same max_tokens,
    temperature and stop sequences.
    """

    # Generation halts when the model starts a "Sources:" section or emits "</s>".
    # Stored as an immutable tuple; a fresh list is built for every call.
    _STOP_SEQUENCES = ("Sources:", "</s>")

    def __init__(self) -> None:
        # Populated on first use by _get_or_load_model().
        self._llm: Llama | None = None

    # ── Public API ───────────────────────────────────────────────────────────
    def generate(self, prompt: str) -> str:
        """
        Run inference on the given prompt and return the full generated text.

        Args:
            prompt: Fully formatted RAG prompt string.

        Returns:
            Generated answer string (stripped of surrounding whitespace).
        """
        llm = self._get_or_load_model()
        logger.debug("Running LLM inference (prompt length=%d chars) …", len(prompt))
        output = llm(prompt, **self._completion_kwargs())
        answer = output["choices"][0]["text"].strip()
        logger.debug("LLM generated %d chars.", len(answer))
        return answer

    def generate_stream(self, prompt: str) -> Generator[str, None, None]:
        """
        Run inference and yield tokens one by one as the model generates them.

        Used by chat_stream() in chatbot.py to enable word-by-word UI streaming.
        Uses exactly the same settings as generate(), plus stream=True.

        Args:
            prompt: Fully formatted RAG prompt string.

        Yields:
            Individual non-empty token strings as the model produces them.
        """
        llm = self._get_or_load_model()
        logger.debug(
            "Running streaming LLM inference (prompt length=%d chars) …", len(prompt)
        )
        chunks = llm(prompt, **self._completion_kwargs(), stream=True)
        for chunk in chunks:
            token = chunk["choices"][0]["text"]
            if token:  # skip empty strings llama-cpp occasionally emits
                yield token

    # ── Private helpers ──────────────────────────────────────────────────────
    def _completion_kwargs(self) -> dict:
        """Return the completion settings shared by generate() and generate_stream()."""
        return {
            "max_tokens": LLM_MAX_TOKENS,
            "temperature": LLM_TEMPERATURE,
            # New list per call so the shared class-level value is never exposed mutably.
            "stop": list(self._STOP_SEQUENCES),
            "echo": False,
        }

    def _get_or_load_model(self) -> Llama:
        """
        Return the Llama model, constructing it on the first call.

        The model file is resolved by _download_model(); the context length,
        thread count and GPU-layer count come from app.config. The instance is
        cached on self._llm and reused by later calls.
        """
        if self._llm is None:
            model_path = self._download_model()
            logger.info("Loading LLM from '%s' …", model_path)
            self._llm = Llama(
                model_path=str(model_path),
                n_ctx=LLM_CONTEXT_LEN,
                n_threads=LLM_N_THREADS,
                n_gpu_layers=LLM_N_GPU_LAYERS,
                verbose=False,
            )
            logger.info("LLM ready.")
        return self._llm

    @staticmethod
    def _download_model() -> Path:
        """
        Locate the GGUF model file, downloading it only when not found locally.

        First checks LLM_CACHE_DIR / LLM_MODEL_FILE on disk and returns it if it
        exists. Otherwise fetches LLM_MODEL_FILE from LLM_MODEL_REPO with
        hf_hub_download(). No cache_dir is passed, so the download goes to the
        default HuggingFace Hub cache rather than LLM_CACHE_DIR. The HF_TOKEN
        environment variable is forwarded as the token (None when unset).

        Returns:
            Path to the model file.
        """
        # Prefer a model already placed in LLM_CACHE_DIR — no network access needed.
        local_path = Path(LLM_CACHE_DIR) / LLM_MODEL_FILE
        if local_path.exists():
            logger.info("Model found locally at '%s'.", local_path)
            return local_path
        # Fallback — download from HuggingFace Hub if not found locally.
        logger.info("Local model not found, downloading from HuggingFace Hub …")
        downloaded = hf_hub_download(
            repo_id=LLM_MODEL_REPO,
            filename=LLM_MODEL_FILE,
            token=os.environ.get("HF_TOKEN"),
        )
        logger.info("Model downloaded to '%s'.", downloaded)
        return Path(downloaded)