| """ | |
| MediGuard AI β Ollama Client | |
| Production-grade wrapper for the Ollama API with health checks, | |
| streaming, and LangChain integration. | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| from collections.abc import Iterator | |
| from functools import lru_cache | |
| from typing import Any | |
| import httpx | |
| from src.exceptions import OllamaConnectionError, OllamaModelNotFoundError | |
| from src.settings import get_settings | |
| logger = logging.getLogger(__name__) | |


class OllamaClient:
    """Wrapper around the Ollama REST API."""

    def __init__(self, base_url: str, *, timeout: int = 120):
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self._http = httpx.Client(base_url=self.base_url, timeout=timeout)

    # ── Health ──────────────────────────────────────────────────────────
    def ping(self) -> bool:
        try:
            resp = self._http.get("/api/version")
            return resp.status_code == 200
        except Exception:
            return False

    def health(self) -> dict[str, Any]:
        """Return Ollama version info, raising OllamaConnectionError if unreachable."""
        try:
            resp = self._http.get("/api/version")
            resp.raise_for_status()
            return resp.json()
        except Exception as exc:
            raise OllamaConnectionError(f"Cannot reach Ollama: {exc}") from exc

    def list_models(self) -> list[str]:
        """Return the names of models available on the server (empty list on failure)."""
        try:
            resp = self._http.get("/api/tags")
            resp.raise_for_status()
            return [m["name"] for m in resp.json().get("models", [])]
        except Exception as exc:
            logger.warning("Failed to list Ollama models: %s", exc)
            return []

    # ── Generation ──────────────────────────────────────────────────────
    def generate(
        self,
        prompt: str,
        *,
        model: str | None = None,
        system: str = "",
        temperature: float = 0.0,
        num_ctx: int = 8192,
    ) -> dict[str, Any]:
        """Synchronous generation – returns the full response dict."""
        model = model or get_settings().ollama.model
        payload: dict[str, Any] = {
            "model": model,
            "prompt": prompt,
            "stream": False,
            "options": {"temperature": temperature, "num_ctx": num_ctx},
        }
        if system:
            payload["system"] = system
        try:
            resp = self._http.post("/api/generate", json=payload)
            resp.raise_for_status()
            return resp.json()
        except httpx.HTTPStatusError as exc:
            if exc.response.status_code == 404:
                raise OllamaModelNotFoundError(
                    f"Model '{model}' not found on Ollama server"
                ) from exc
            raise OllamaConnectionError(str(exc)) from exc
        except Exception as exc:
            raise OllamaConnectionError(str(exc)) from exc

    def generate_stream(
        self,
        prompt: str,
        *,
        model: str | None = None,
        system: str = "",
        temperature: float = 0.0,
        num_ctx: int = 8192,
    ) -> Iterator[str]:
        """Streaming generation – yields text tokens."""
        model = model or get_settings().ollama.model
        payload: dict[str, Any] = {
            "model": model,
            "prompt": prompt,
            "stream": True,
            "options": {"temperature": temperature, "num_ctx": num_ctx},
        }
        if system:
            payload["system"] = system
        try:
            with self._http.stream("POST", "/api/generate", json=payload) as resp:
                resp.raise_for_status()
                # Ollama streams newline-delimited JSON objects; each carries a
                # "response" text fragment and a "done" flag on the final object.
                for line in resp.iter_lines():
                    if not line:
                        continue
                    data = json.loads(line)
                    token = data.get("response", "")
                    if token:
                        yield token
                    if data.get("done", False):
                        break
        except Exception as exc:
            raise OllamaConnectionError(str(exc)) from exc

    # ── LangChain integration ───────────────────────────────────────────
    def get_langchain_model(
        self,
        *,
        model: str | None = None,
        temperature: float = 0.0,
        json_mode: bool = False,
    ):
        """Return a LangChain ChatOllama instance."""
        model = model or get_settings().ollama.model
        try:
            from langchain_ollama import ChatOllama
        except ImportError:
            # Fall back to the older community package if langchain-ollama
            # is not installed.
            from langchain_community.chat_models import ChatOllama
        return ChatOllama(
            model=model,
            temperature=temperature,
            base_url=self.base_url,
            format="json" if json_mode else None,
        )

    def close(self) -> None:
        self._http.close()


def make_ollama_client() -> OllamaClient:
    """Build an OllamaClient from application settings and log reachability."""
    settings = get_settings()
    client = OllamaClient(
        base_url=settings.ollama.host,
        timeout=settings.ollama.timeout,
    )
    if client.ping():
        logger.info("Ollama connected at %s", settings.ollama.host)
    else:
        logger.warning("Ollama not reachable at %s", settings.ollama.host)
    return client
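

# Usage sketch (illustrative only, not part of the client itself): assumes a
# local Ollama server is running and that get_settings() resolves ollama.host,
# ollama.timeout, and ollama.model from the project settings. The prompts are
# hypothetical placeholders.
if __name__ == "__main__":
    client = make_ollama_client()
    try:
        if client.ping():
            print("Available models:", client.list_models())

            # Blocking call – returns the full response dict.
            result = client.generate("Summarise the common side effects of ibuprofen.")
            print(result.get("response", ""))

            # Streaming call – tokens are printed as they arrive.
            for token in client.generate_stream("List three common drug interaction checks."):
                print(token, end="", flush=True)
            print()
    finally:
        client.close()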