"""
MediGuard AI - Ollama Client
Production-grade wrapper for the Ollama API with health checks,
streaming, and LangChain integration.
"""
from __future__ import annotations
import json
import logging
from collections.abc import Iterator
from functools import lru_cache
from typing import Any
import httpx
from src.exceptions import OllamaConnectionError, OllamaModelNotFoundError
from src.settings import get_settings
logger = logging.getLogger(__name__)
class OllamaClient:
"""Wrapper around the Ollama REST API."""
    def __init__(self, base_url: str, *, timeout: int = 120):
        """Create a client bound to base_url; a trailing slash is stripped."""
        self.base_url = base_url.rstrip("/")
self.timeout = timeout
self._http = httpx.Client(base_url=self.base_url, timeout=timeout)
# ── Health ───────────────────────────────────────────────────────────
    def ping(self) -> bool:
        """Return True if the Ollama server responds on /api/version."""
        try:
            resp = self._http.get("/api/version")
            return resp.status_code == 200
        except httpx.HTTPError:
            return False
    def health(self) -> dict[str, Any]:
        """Return the server's /api/version payload; raise OllamaConnectionError if unreachable."""
try:
resp = self._http.get("/api/version")
resp.raise_for_status()
return resp.json()
except Exception as exc:
raise OllamaConnectionError(f"Cannot reach Ollama: {exc}") from exc
    def list_models(self) -> list[str]:
        """Return the names of models available on the server; empty list if the call fails."""
try:
resp = self._http.get("/api/tags")
resp.raise_for_status()
return [m["name"] for m in resp.json().get("models", [])]
except Exception as exc:
logger.warning("Failed to list Ollama models: %s", exc)
return []
# ── Generation ───────────────────────────────────────────────────────
def generate(
self,
prompt: str,
*,
model: str | None = None,
system: str = "",
temperature: float = 0.0,
num_ctx: int = 8192,
) -> dict[str, Any]:
"""Synchronous generation β€” returns the full response dict."""
model = model or get_settings().ollama.model
payload: dict[str, Any] = {
"model": model,
"prompt": prompt,
"stream": False,
"options": {"temperature": temperature, "num_ctx": num_ctx},
}
if system:
payload["system"] = system
try:
resp = self._http.post("/api/generate", json=payload)
resp.raise_for_status()
return resp.json()
        except httpx.HTTPStatusError as exc:
            if exc.response.status_code == 404:
                raise OllamaModelNotFoundError(
                    f"Model '{model}' not found on Ollama server"
                ) from exc
            raise OllamaConnectionError(str(exc)) from exc
except Exception as exc:
raise OllamaConnectionError(str(exc)) from exc
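    # Illustrative usage (a sketch; assumes a local Ollama server with the
    # requested model already pulled; the prompt is a placeholder):
    #
    #     client = OllamaClient("http://localhost:11434")
    #     result = client.generate("Summarise the patient note.", temperature=0.2)
    #     print(result["response"])   # generated text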
def generate_stream(
self,
prompt: str,
*,
model: str | None = None,
system: str = "",
temperature: float = 0.0,
num_ctx: int = 8192,
) -> Iterator[str]:
"""Streaming generation β€” yields text tokens."""
model = model or get_settings().ollama.model
payload: dict[str, Any] = {
"model": model,
"prompt": prompt,
"stream": True,
"options": {"temperature": temperature, "num_ctx": num_ctx},
}
if system:
payload["system"] = system
try:
            with self._http.stream("POST", "/api/generate", json=payload) as resp:
                resp.raise_for_status()
                # Ollama streams newline-delimited JSON: one object per chunk.
                for line in resp.iter_lines():
                    if not line:
                        continue
                    data = json.loads(line)
                    if token := data.get("response", ""):
                        yield token
                    if data.get("done", False):
                        break
except Exception as exc:
raise OllamaConnectionError(str(exc)) from exc
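    # Illustrative consumption of the stream (a sketch; the prompt is a placeholder):
    #
    #     for token in client.generate_stream("List three drug interactions."):
    #         print(token, end="", flush=True)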
# ── LangChain integration ────────────────────────────────────────────
def get_langchain_model(
self,
*,
model: str | None = None,
temperature: float = 0.0,
json_mode: bool = False,
    ) -> Any:
"""Return a LangChain ChatOllama instance."""
model = model or get_settings().ollama.model
        try:
            from langchain_ollama import ChatOllama
        except ImportError:
            # Fall back to the older community implementation if the
            # dedicated langchain-ollama package is not installed.
            from langchain_community.chat_models import ChatOllama
return ChatOllama(
model=model,
temperature=temperature,
base_url=self.base_url,
format="json" if json_mode else None,
)
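    # Illustrative usage (a sketch; requires langchain-ollama, or the community
    # fallback, to be installed):
    #
    #     llm = client.get_langchain_model(json_mode=True)
    #     msg = llm.invoke("Reply with a JSON object: {\"risk\": ...}")
    #     print(msg.content)  # a JSON string, since format="json" was requested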
    def close(self) -> None:
        """Close the underlying HTTP connection pool."""
        self._http.close()
@lru_cache(maxsize=1)
def make_ollama_client() -> OllamaClient:
    """Build and memoise the process-wide Ollama client from settings."""
    settings = get_settings()
client = OllamaClient(
base_url=settings.ollama.host,
timeout=settings.ollama.timeout,
)
if client.ping():
logger.info("Ollama connected at %s", settings.ollama.host)
else:
logger.warning("Ollama not reachable at %s", settings.ollama.host)
return client
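# Minimal smoke test (illustrative; assumes Ollama is reachable at the host
# configured in settings and that the configured model has been pulled):
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    client = make_ollama_client()
    if client.ping():
        print("models:", client.list_models())
        print(client.generate("Say hello in one word.")["response"])
    else:
        print("Ollama is not reachable; try `ollama serve` first.")
    client.close()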