"""Dual-backend inference: Ollama (local) or llama-cpp-python (HF Spaces)."""

from __future__ import annotations

import json
from collections.abc import Generator

import config


def _get_ollama_client():
    """Lazy import and create an Ollama HTTP client.

    The caller owns the returned ``httpx.Client`` and must close it (e.g. by
    using it as a context manager) so its connection pool is released.
    """
    import httpx

    # Large timeout: model cold-load can take 60s+, generation is streamed
    timeout = httpx.Timeout(connect=10.0, read=300.0, write=10.0, pool=10.0)
    return httpx.Client(base_url=config.OLLAMA_BASE_URL, timeout=timeout)


def _get_llamacpp_model():
    """Lazy-load the llama-cpp-python model (downloads the GGUF if needed).

    Uses ``config.GGUF_LOCAL_PATH`` when it is set; otherwise downloads
    ``config.GGUF_FILENAME`` from the ``config.GGUF_REPO_ID`` Hub repository.
    """
    from llama_cpp import Llama

    model_path = config.GGUF_LOCAL_PATH
    if not model_path:
        from huggingface_hub import hf_hub_download

        model_path = hf_hub_download(
            repo_id=config.GGUF_REPO_ID,
            filename=config.GGUF_FILENAME,
        )
    return Llama(
        model_path=model_path,
        n_ctx=4096,
        n_gpu_layers=-1,  # Use all available GPU layers
        verbose=False,
    )


# Module-level cache for the llama-cpp model; populated on first use.
_llm_model = None


def _get_model():
    """Return the cached llama-cpp model, loading it on first call.

    Returns None when ``config.BACKEND`` is not ``"llamacpp"`` and no model
    has been loaded yet.
    """
    global _llm_model
    if _llm_model is None and config.BACKEND == "llamacpp":
        _llm_model = _get_llamacpp_model()
    return _llm_model


def stream_response(
    user_message: str,
    history: list[dict],
    system_prompt: str,
) -> Generator[str, None, None]:
    """Stream model response token by token.

    Args:
        user_message: The latest user message.
        history: List of {"role": ..., "content": ...} dicts (prior turns).
        system_prompt: Full system prompt with session context.

    Yields:
        Partial response strings (accumulating).

    Raises:
        ValueError: If ``config.BACKEND`` is neither "ollama" nor "llamacpp".
        RuntimeError: If the selected backend reports an error.
    """
    if config.BACKEND == "ollama":
        yield from _stream_ollama(user_message, history, system_prompt)
    elif config.BACKEND == "llamacpp":
        yield from _stream_llamacpp(user_message, history, system_prompt)
    else:
        # _get_model() only loads a model for "llamacpp"; any other value used
        # to fall through and crash later with an opaque AttributeError on None.
        raise ValueError(
            f"Unsupported backend {config.BACKEND!r}; expected 'ollama' or 'llamacpp'"
        )


def _part_to_text(part) -> str:
    """Return the text of one content part (a dict with "text", or any value)."""
    if isinstance(part, dict):
        text = part.get("text")
        return "" if text is None else str(text)
    return str(part)


def _build_messages(user_message: str, history: list[dict], system_prompt: str) -> list[dict]:
    """Build the chat messages list for the model.

    The system prompt comes first, then every prior turn with non-blank text
    content, then the new user message. Turns without a "role" default to "user".
    """
    messages = [{"role": "system", "content": system_prompt}]
    for msg in history:
        role = msg.get("role", "user")
        content = msg.get("content", "")
        # Gradio 6 may store content as a list of part-dicts; flatten to text.
        if isinstance(content, list):
            texts = (_part_to_text(p) for p in content)
            content = " ".join(t for t in texts if t).strip()
        # Non-text content (e.g. file/media dicts) and blank turns are dropped.
        if isinstance(content, str) and content.strip():
            messages.append({"role": role, "content": content})
    messages.append({"role": "user", "content": user_message})
    return messages


def _ollama_error_detail(body: str) -> str:
    """Extract Ollama's ``{"error": ...}`` message from a body, else return it raw."""
    try:
        data = json.loads(body)
    except json.JSONDecodeError:
        return body.strip()
    if isinstance(data, dict) and data.get("error"):
        return str(data["error"])
    return body.strip()


def _stream_ollama(
    user_message: str,
    history: list[dict],
    system_prompt: str,
) -> Generator[str, None, None]:
    """Stream from a local Ollama instance via its /api/chat endpoint.

    Yields the accumulated response after each non-empty token.

    Raises:
        RuntimeError: If Ollama answers with an HTTP error status or sends an
            ``{"error": ...}`` object mid-stream.
    """
    messages = _build_messages(user_message, history, system_prompt)
    payload = {
        "model": config.OLLAMA_MODEL,
        "messages": messages,
        "stream": True,
        "keep_alive": config.OLLAMA_KEEP_ALIVE,
        "options": {
            "temperature": config.TEMPERATURE,
            "top_p": config.TOP_P,
            "num_predict": config.MAX_TOKENS,
            "repeat_penalty": config.REPEAT_PENALTY,
        },
    }
    response = ""
    # Closing the client (not just the stream) releases its connection pool.
    # The ``with`` also runs if the consumer stops iterating early.
    with _get_ollama_client() as client, client.stream(
        "POST", "/api/chat", json=payload
    ) as stream:
        if stream.is_error:
            # e.g. 404 for an unknown model: surface it instead of yielding nothing.
            stream.read()
            raise RuntimeError(
                f"Ollama request failed (HTTP {stream.status_code}): "
                f"{_ollama_error_detail(stream.text)}"
            )
        for line in stream.iter_lines():
            if not line:
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                continue  # Best-effort: skip malformed/partial lines.
            if not isinstance(data, dict):
                continue
            if data.get("error"):
                raise RuntimeError(f"Ollama error: {data['error']}")
            token = (data.get("message") or {}).get("content") or ""
            if token:
                response += token
                yield response
            if data.get("done", False):
                break


def _stream_llamacpp(
    user_message: str,
    history: list[dict],
    system_prompt: str,
) -> Generator[str, None, None]:
    """Stream from llama-cpp-python (for HF Spaces).

    Yields the accumulated response after each non-empty token.

    Raises:
        RuntimeError: If no llama-cpp model is available for the configured backend.
    """
    messages = _build_messages(user_message, history, system_prompt)
    model = _get_model()
    if model is None:
        raise RuntimeError(
            f"llama-cpp model is not loaded (config.BACKEND={config.BACKEND!r})"
        )
    response = ""
    for chunk in model.create_chat_completion(
        messages=messages,
        max_tokens=config.MAX_TOKENS,
        temperature=config.TEMPERATURE,
        top_p=config.TOP_P,
        repeat_penalty=config.REPEAT_PENALTY,
        stream=True,
    ):
        # Guard against an empty "choices" list and role-only deltas (no content).
        choices = chunk.get("choices") or [{}]
        delta = choices[0].get("delta") or {}
        token = delta.get("content") or ""
        if token:
            response += token
            yield response