"""
MediGuard AI — Ollama Client
Production-grade wrapper for the Ollama API with health checks,
streaming, and LangChain integration.
"""
from __future__ import annotations
import logging
from collections.abc import Iterator
from functools import lru_cache
from typing import Any
import httpx
from src.exceptions import OllamaConnectionError, OllamaModelNotFoundError
from src.settings import get_settings
logger = logging.getLogger(__name__)
class OllamaClient:
    """Thin wrapper around the Ollama REST API.

    Provides health probes, model listing, synchronous and streaming text
    generation, and a LangChain ``ChatOllama`` factory.  Owns a single
    ``httpx.Client``; use the instance as a context manager (or call
    :meth:`close`) to release the underlying connection pool.
    """

    def __init__(self, base_url: str, *, timeout: int = 120):
        """Bind the client to *base_url*.

        Args:
            base_url: Root URL of the Ollama server (trailing ``/`` stripped),
                e.g. ``http://localhost:11434``.
            timeout: Per-request timeout in seconds for every API call.
        """
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self._http = httpx.Client(base_url=self.base_url, timeout=timeout)

    # ── Context manager ──────────────────────────────────────────────────
    def __enter__(self) -> "OllamaClient":
        return self

    def __exit__(self, *exc_info: object) -> None:
        self.close()

    # ── Health ───────────────────────────────────────────────────────────
    def ping(self) -> bool:
        """Return ``True`` iff ``/api/version`` answers with HTTP 200."""
        try:
            resp = self._http.get("/api/version")
            return resp.status_code == 200
        except Exception:
            # Best-effort probe: any transport failure just means "not up".
            return False

    def health(self) -> dict[str, Any]:
        """Return the ``/api/version`` payload.

        Raises:
            OllamaConnectionError: if the server is unreachable or returns
                a non-2xx status.
        """
        try:
            resp = self._http.get("/api/version")
            resp.raise_for_status()
            return resp.json()
        except Exception as exc:
            raise OllamaConnectionError(f"Cannot reach Ollama: {exc}") from exc

    def list_models(self) -> list[str]:
        """Return the names of models installed on the server.

        Best-effort: any failure is logged and an empty list is returned so
        callers can treat "no models" and "unknown" uniformly.
        """
        try:
            resp = self._http.get("/api/tags")
            resp.raise_for_status()
            return [m["name"] for m in resp.json().get("models", [])]
        except Exception as exc:
            logger.warning("Failed to list Ollama models: %s", exc)
            return []

    # ── Generation ───────────────────────────────────────────────────────
    def _payload(
        self,
        prompt: str,
        model: str | None,
        system: str,
        temperature: float,
        num_ctx: int,
        *,
        stream: bool,
    ) -> tuple[str, dict[str, Any]]:
        """Resolve the model name and build the ``/api/generate`` body.

        Shared by :meth:`generate` and :meth:`generate_stream` so the two
        endpoints cannot drift apart.
        """
        model = model or get_settings().ollama.model
        payload: dict[str, Any] = {
            "model": model,
            "prompt": prompt,
            "stream": stream,
            "options": {"temperature": temperature, "num_ctx": num_ctx},
        }
        if system:
            payload["system"] = system
        return model, payload

    def generate(
        self,
        prompt: str,
        *,
        model: str | None = None,
        system: str = "",
        temperature: float = 0.0,
        num_ctx: int = 8192,
    ) -> dict[str, Any]:
        """Synchronous generation — returns the full response dict.

        Args:
            prompt: User prompt to complete.
            model: Model name; falls back to the configured default.
            system: Optional system prompt (omitted from the payload if empty).
            temperature: Sampling temperature (0.0 = deterministic).
            num_ctx: Context window size in tokens.

        Raises:
            OllamaModelNotFoundError: if the server reports 404 for *model*.
            OllamaConnectionError: for any other transport or HTTP failure.
        """
        model, payload = self._payload(
            prompt, model, system, temperature, num_ctx, stream=False
        )
        try:
            resp = self._http.post("/api/generate", json=payload)
            resp.raise_for_status()
            return resp.json()
        except httpx.HTTPStatusError as exc:
            if exc.response.status_code == 404:
                raise OllamaModelNotFoundError(
                    f"Model '{model}' not found on Ollama server"
                ) from exc
            raise OllamaConnectionError(str(exc)) from exc
        except Exception as exc:
            raise OllamaConnectionError(str(exc)) from exc

    def generate_stream(
        self,
        prompt: str,
        *,
        model: str | None = None,
        system: str = "",
        temperature: float = 0.0,
        num_ctx: int = 8192,
    ) -> Iterator[str]:
        """Streaming generation — yields text tokens as they arrive.

        Mirrors :meth:`generate` for parameters and error mapping: a 404
        raises :class:`OllamaModelNotFoundError`, anything else raises
        :class:`OllamaConnectionError`.
        """
        import json  # local import keeps module import time minimal

        model, payload = self._payload(
            prompt, model, system, temperature, num_ctx, stream=True
        )
        try:
            with self._http.stream("POST", "/api/generate", json=payload) as resp:
                resp.raise_for_status()
                # Ollama streams newline-delimited JSON objects.
                for line in resp.iter_lines():
                    if not line:
                        continue
                    data = json.loads(line)
                    token = data.get("response", "")
                    if token:
                        yield token
                    if data.get("done", False):
                        break
        except httpx.HTTPStatusError as exc:
            if exc.response.status_code == 404:
                raise OllamaModelNotFoundError(
                    f"Model '{model}' not found on Ollama server"
                ) from exc
            raise OllamaConnectionError(str(exc)) from exc
        except Exception as exc:
            raise OllamaConnectionError(str(exc)) from exc

    # ── LangChain integration ────────────────────────────────────────────
    def get_langchain_model(
        self,
        *,
        model: str | None = None,
        temperature: float = 0.0,
        json_mode: bool = False,
    ):
        """Return a LangChain ``ChatOllama`` instance bound to this server.

        Args:
            model: Model name; falls back to the configured default.
            temperature: Sampling temperature for the chat model.
            json_mode: When ``True``, request strict-JSON output mode.
        """
        model = model or get_settings().ollama.model
        try:
            from langchain_ollama import ChatOllama
        except ImportError:
            # Older installs ship ChatOllama under langchain_community.
            from langchain_community.chat_models import ChatOllama
        return ChatOllama(
            model=model,
            temperature=temperature,
            base_url=self.base_url,
            format="json" if json_mode else None,
        )

    def close(self) -> None:
        """Release the underlying HTTP connection pool."""
        self._http.close()
@lru_cache(maxsize=1)
def make_ollama_client() -> OllamaClient:
    """Return the process-wide ``OllamaClient`` singleton.

    The ``lru_cache`` wrapper means the client (and its HTTP connection
    pool) is constructed exactly once per process; subsequent calls return
    the same instance.  Reachability of the configured server is probed
    once at construction time and logged.
    """
    settings = get_settings()
    client = OllamaClient(
        base_url=settings.ollama.host,
        timeout=settings.ollama.timeout,
    )
    reachable = client.ping()
    emit = logger.info if reachable else logger.warning
    template = (
        "Ollama connected at %s" if reachable else "Ollama not reachable at %s"
    )
    emit(template, settings.ollama.host)
    return client