Spaces:
Sleeping
Sleeping
Venkatesh Rajagopal
REFRAME: live CBT studio — fine-tuned Gemma 12B on Modal + Cohere voice (ZeroGPU)
4ae4ae8 | """Dual-backend inference: Ollama (local) or llama-cpp-python (HF Spaces).""" | |
| from __future__ import annotations | |
| import json | |
| from collections.abc import Generator | |
| import config | |
def _get_ollama_client():
    """Create a new httpx client aimed at the configured Ollama server.

    httpx is imported inside the function, so it is only required when the
    Ollama backend is actually used. A fresh client is returned on every call;
    the caller owns it and is responsible for closing it.
    """
    import httpx

    # Only the read timeout is generous: the first request may sit for 60s+
    # while Ollama cold-loads the model, and tokens are streamed afterwards.
    # Connect/write/pool waits stay short so an unreachable server fails fast.
    return httpx.Client(
        base_url=config.OLLAMA_BASE_URL,
        timeout=httpx.Timeout(read=300.0, connect=10.0, write=10.0, pool=10.0),
    )
def _get_llamacpp_model():
    """Construct the llama-cpp-python ``Llama`` model used on HF Spaces.

    The GGUF file comes from ``config.GGUF_LOCAL_PATH`` when that is set
    (truthy). Otherwise ``config.GGUF_FILENAME`` is fetched from the
    ``config.GGUF_REPO_ID`` Hub repository with ``hf_hub_download``, and the
    local path it returns is used. Both third-party imports are deferred, so
    neither package is needed unless this backend is loaded.
    """
    from llama_cpp import Llama

    local_path = config.GGUF_LOCAL_PATH
    if local_path:
        gguf_path = local_path
    else:
        from huggingface_hub import hf_hub_download

        gguf_path = hf_hub_download(
            repo_id=config.GGUF_REPO_ID,
            filename=config.GGUF_FILENAME,
        )

    llm_kwargs = dict(
        model_path=gguf_path,
        n_ctx=4096,  # context window, in tokens
        n_gpu_layers=-1,  # -1 offloads every layer to the GPU
        verbose=False,
    )
    return Llama(**llm_kwargs)
# Module-level cache for the llama-cpp model: loading a GGUF is expensive, so
# the first loaded instance is stored here and reused by later calls.
_llm_model = None


def _get_model():
    """Return the cached llama-cpp ``Llama`` model, loading it on first use.

    The load condition deliberately mirrors ``stream_response``'s dispatch,
    where every backend other than ``"ollama"`` is served by llama-cpp. Loading
    only for the exact string ``"llamacpp"`` would make this return ``None``
    for any other non-Ollama backend value. ``_stream_llamacpp`` would then
    crash calling a method on ``None``.

    Returns:
        The shared ``Llama`` instance, or ``None`` when ``config.BACKEND`` is
        ``"ollama"`` (no local model is loaded in that case).
    """
    global _llm_model
    # NOTE(review): no lock guards this check-then-load; if requests can run
    # concurrently, two first calls could each load the model — confirm.
    if _llm_model is None and config.BACKEND != "ollama":
        _llm_model = _get_llamacpp_model()
    return _llm_model
def stream_response(
    user_message: str,
    history: list[dict],
    system_prompt: str,
) -> Generator[str, None, None]:
    """Stream the assistant's reply from the backend chosen by ``config.BACKEND``.

    ``"ollama"`` routes to the local Ollama server; any other value routes to
    llama-cpp-python.

    Args:
        user_message: The newest user turn.
        history: Earlier turns as ``{"role": ..., "content": ...}`` dicts.
        system_prompt: Complete system prompt, including session context.

    Yields:
        The reply so far — each value is the full accumulated text, not a delta.
    """
    backend_stream = _stream_ollama if config.BACKEND == "ollama" else _stream_llamacpp
    yield from backend_stream(user_message, history, system_prompt)
| def _build_messages(user_message: str, history: list[dict], system_prompt: str) -> list[dict]: | |
| """Build the messages list for the model.""" | |
| messages = [{"role": "system", "content": system_prompt}] | |
| for msg in history: | |
| role = msg.get("role", "user") | |
| content = msg.get("content", "") | |
| # Gradio 6 may store content as a list of part-dicts; flatten to text. | |
| if isinstance(content, list): | |
| content = " ".join( | |
| str(p.get("text", "")) if isinstance(p, dict) else str(p) | |
| for p in content | |
| ).strip() | |
| if isinstance(content, str) and content.strip(): | |
| messages.append({"role": role, "content": content}) | |
| messages.append({"role": "user", "content": user_message}) | |
| return messages | |
def _stream_ollama(
    user_message: str,
    history: list[dict],
    system_prompt: str,
) -> Generator[str, None, None]:
    """Stream a reply from the local Ollama server's ``/api/chat`` endpoint.

    Ollama streams newline-delimited JSON objects. Each object's
    ``message.content`` is appended to the running reply, and the reply so
    far is yielded after every non-empty token. Blank and non-JSON lines are
    skipped, and streaming stops at the object whose ``done`` flag is true.

    The HTTP client is opened per call inside ``with``. It is therefore closed
    when the generator finishes, raises, or is closed early by its consumer,
    so repeated chats do not leak connections.

    Raises:
        RuntimeError: Ollama answered with an HTTP error status, or sent an
            ``error`` object mid-stream. Both cases would otherwise produce a
            silently empty reply.

    httpx transport errors (connection refused, timeouts) propagate unchanged.
    """
    messages = _build_messages(user_message, history, system_prompt)
    payload = {
        "model": config.OLLAMA_MODEL,
        "messages": messages,
        "stream": True,
        "keep_alive": config.OLLAMA_KEEP_ALIVE,
        "options": {
            "temperature": config.TEMPERATURE,
            "top_p": config.TOP_P,
            "num_predict": config.MAX_TOKENS,  # Ollama's max-new-tokens option
            "repeat_penalty": config.REPEAT_PENALTY,
        },
    }
    response = ""
    with _get_ollama_client() as client, client.stream(
        "POST", "/api/chat", json=payload
    ) as stream:
        if stream.is_error:
            # Read the (small) error body so Ollama's message, typically
            # {"error": "..."}, reaches the caller instead of an empty reply.
            stream.read()
            raise RuntimeError(
                f"Ollama /api/chat failed with HTTP {stream.status_code}: "
                f"{stream.text.strip()}"
            )
        for line in stream.iter_lines():
            if not line:
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                continue  # tolerate stray non-JSON lines
            if not isinstance(data, dict):
                continue
            if data.get("error"):
                raise RuntimeError(f"Ollama error: {data['error']}")
            token = (data.get("message") or {}).get("content") or ""
            if token:
                response += token
                yield response
            if data.get("done", False):
                break
def _stream_llamacpp(
    user_message: str,
    history: list[dict],
    system_prompt: str,
) -> Generator[str, None, None]:
    """Stream a reply from the in-process llama-cpp-python model (HF Spaces).

    Each streamed chunk is read as ``chunk["choices"][0]["delta"]["content"]``.
    Whenever a chunk contributes non-empty text, the full reply accumulated so
    far is yielded.
    """
    chat = _build_messages(user_message, history, system_prompt)
    # NOTE(review): _get_model() may return None depending on config.BACKEND;
    # the call below would then fail with AttributeError.
    llm = _get_model()
    completion_stream = llm.create_chat_completion(
        messages=chat,
        max_tokens=config.MAX_TOKENS,
        temperature=config.TEMPERATURE,
        top_p=config.TOP_P,
        repeat_penalty=config.REPEAT_PENALTY,
        stream=True,
    )
    so_far = ""
    for chunk in completion_stream:
        first_choice = chunk.get("choices", [{}])[0]
        piece = first_choice.get("delta", {}).get("content", "")
        if not piece:
            continue
        so_far += piece
        yield so_far