reframe / inference.py
Venkatesh Rajagopal
REFRAME: live CBT studio — fine-tuned Gemma 12B on Modal + Cohere voice (ZeroGPU)
4ae4ae8
Raw
History Blame Contribute Delete
4.64 kB
"""Dual-backend inference: Ollama (local) or llama-cpp-python (HF Spaces)."""
from __future__ import annotations
import json
from collections.abc import Generator
import config
def _get_ollama_client():
    """Return a new ``httpx.Client`` bound to ``config.OLLAMA_BASE_URL``.

    httpx is imported inside the function, so it is only imported when a
    client is actually requested. Every call builds a fresh client; closing
    it is the caller's responsibility.
    """
    import httpx

    # Generous read timeout: a cold model load in Ollama can take a minute or
    # more before the first streamed bytes arrive. Generation itself is
    # streamed, so the other phases can stay short.
    client_timeouts = httpx.Timeout(
        connect=10.0,
        read=300.0,
        write=10.0,
        pool=10.0,
    )
    return httpx.Client(
        base_url=config.OLLAMA_BASE_URL,
        timeout=client_timeouts,
    )
def _get_llamacpp_model():
    """Construct the llama-cpp-python ``Llama`` instance for the GGUF model.

    The model file is taken from ``config.GGUF_LOCAL_PATH`` when that value is
    truthy. Otherwise ``config.GGUF_FILENAME`` is fetched from the Hugging
    Face Hub repo ``config.GGUF_REPO_ID`` via ``hf_hub_download``, which
    returns a local file path. ``llama_cpp`` and ``huggingface_hub`` are
    imported here rather than at module level, so they are only needed when
    this loader runs.
    """
    from llama_cpp import Llama

    local_path = config.GGUF_LOCAL_PATH
    if local_path:
        gguf_path = local_path
    else:
        from huggingface_hub import hf_hub_download

        gguf_path = hf_hub_download(
            repo_id=config.GGUF_REPO_ID,
            filename=config.GGUF_FILENAME,
        )

    llama_options = {
        "model_path": gguf_path,
        "n_ctx": 4096,  # context window size
        "n_gpu_layers": -1,  # -1 = offload every layer to the GPU
        "verbose": False,
    }
    return Llama(**llama_options)
# Module-level cache for the llama.cpp model; filled lazily by _get_model().
_llm_model = None


def _get_model():
    """Return the cached llama.cpp model, loading it on first use.

    A model is only loaded when ``config.BACKEND == "llamacpp"``. For any
    other backend this returns whatever is cached, which is ``None`` unless
    something has assigned ``_llm_model``.
    """
    global _llm_model
    if _llm_model is not None:
        return _llm_model
    if config.BACKEND == "llamacpp":
        _llm_model = _get_llamacpp_model()
    return _llm_model
def stream_response(
    user_message: str,
    history: list[dict],
    system_prompt: str,
) -> Generator[str, None, None]:
    """Stream the model's reply, yielding the accumulated text so far.

    Dispatches on ``config.BACKEND``:
      - ``"ollama"`` streams from an Ollama server over HTTP;
      - ``"llamacpp"`` streams from an in-process llama-cpp-python model.

    Args:
        user_message: The latest user message.
        history: List of {"role": ..., "content": ...} dicts (prior turns).
        system_prompt: Full system prompt with session context.

    Yields:
        Partial response strings (accumulating), as produced by the selected
        backend streamer.

    Raises:
        ValueError: if ``config.BACKEND`` is neither "ollama" nor "llamacpp".
            Because this is a generator, the error surfaces on the first
            iteration.
    """
    backend = config.BACKEND
    if backend == "ollama":
        yield from _stream_ollama(user_message, history, system_prompt)
    elif backend == "llamacpp":
        yield from _stream_llamacpp(user_message, history, system_prompt)
    else:
        # The model loader only loads for exactly "llamacpp". Any other value
        # would otherwise fail deep inside the llama.cpp path with an opaque
        # AttributeError on None, so reject misconfiguration up front.
        raise ValueError(
            f"Unsupported inference backend {backend!r}; "
            "expected 'ollama' or 'llamacpp'."
        )
def _build_messages(user_message: str, history: list[dict], system_prompt: str) -> list[dict]:
    """Assemble the chat ``messages`` list sent to either backend.

    The result is, in order: the system prompt, every prior turn from
    *history* whose content has non-blank text, and *user_message* as the
    final user turn.

    A history entry's ``content`` may be a plain string, or a list of parts
    (Gradio 6 may store it that way). List content is flattened to text:
      - dict parts contribute their ``"text"`` value;
      - any other part is converted with ``str()``;
      - parts with no text (``None``, non-text dicts such as images, blank
        strings) are skipped;
      - the remaining pieces are joined with single spaces.

    Entries whose content ends up blank, or is neither a string nor a list,
    are dropped. A missing ``role`` defaults to ``"user"``.
    """
    messages = [{"role": "system", "content": system_prompt}]
    for msg in history:
        role = msg.get("role", "user")
        content = msg.get("content", "")
        if isinstance(content, list):
            pieces = []
            for part in content:
                text = part.get("text") if isinstance(part, dict) else part
                # Skip text-less parts so they don't inject "None" or leave
                # doubled spaces between the real text pieces.
                if text is None:
                    continue
                text = str(text)
                if text.strip():
                    pieces.append(text)
            content = " ".join(pieces).strip()
        if isinstance(content, str) and content.strip():
            messages.append({"role": role, "content": content})
    messages.append({"role": "user", "content": user_message})
    return messages
def _stream_ollama(
    user_message: str,
    history: list[dict],
    system_prompt: str,
) -> Generator[str, None, None]:
    """Stream a chat completion from an Ollama server via ``POST /api/chat``.

    The response body is read line by line as JSON objects. Each object may
    carry a ``message.content`` fragment, and ``"done": true`` ends the
    stream.

    Yields:
        The accumulated response text after each non-empty fragment.

    Raises:
        httpx.HTTPStatusError: if the server answers with a 4xx/5xx status.
        RuntimeError: if the server reports ``{"error": ...}`` inside the
            stream.
    """
    messages = _build_messages(user_message, history, system_prompt)
    payload = {
        "model": config.OLLAMA_MODEL,
        "messages": messages,
        "stream": True,
        "keep_alive": config.OLLAMA_KEEP_ALIVE,
        "options": {
            "temperature": config.TEMPERATURE,
            "top_p": config.TOP_P,
            "num_predict": config.MAX_TOKENS,
            "repeat_penalty": config.REPEAT_PENALTY,
        },
    }
    response = ""
    # _get_ollama_client() builds a new client per call. Close it (and its
    # connection pool) when streaming ends or the consumer stops iterating.
    with _get_ollama_client() as client, client.stream(
        "POST", "/api/chat", json=payload
    ) as stream:
        if stream.is_error:
            # Load the body so the server's error detail is available on
            # the raised exception's .response.
            stream.read()
            stream.raise_for_status()
        for line in stream.iter_lines():
            if not line:
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                continue  # tolerate a malformed/partial line
            if not isinstance(data, dict):
                continue
            if "error" in data:
                raise RuntimeError(f"Ollama error: {data['error']}")
            message = data.get("message") or {}
            token = message.get("content") or ""
            if token:
                response += token
                yield response
            if data.get("done", False):
                break
def _stream_llamacpp(
    user_message: str,
    history: list[dict],
    system_prompt: str,
) -> Generator[str, None, None]:
    """Stream a chat completion from the in-process llama-cpp-python model.

    Yields:
        The accumulated response text after each non-empty content delta.

    Raises:
        RuntimeError: if no llama.cpp model is available. ``_get_model()``
            returns ``None`` unless ``config.BACKEND == "llamacpp"``.
    """
    messages = _build_messages(user_message, history, system_prompt)
    model = _get_model()
    if model is None:
        raise RuntimeError(
            "llama.cpp model is not loaded; it is only loaded when "
            f"config.BACKEND == 'llamacpp' (got {config.BACKEND!r})."
        )
    response = ""
    for chunk in model.create_chat_completion(
        messages=messages,
        max_tokens=config.MAX_TOKENS,
        temperature=config.TEMPERATURE,
        top_p=config.TOP_P,
        repeat_penalty=config.REPEAT_PENALTY,
        stream=True,
    ):
        # Guard against chunks with an empty/missing choices list or a null
        # delta. Role-only and final chunks simply carry no content.
        choices = chunk.get("choices") or [{}]
        delta = choices[0].get("delta") or {}
        token = delta.get("content") or ""
        if token:
            response += token
            yield response