# openLLMbenchmark/engine.py
from __future__ import annotations

import os
from dataclasses import dataclass
from typing import Any, Iterator

from ollama import Client

from model_identity import (
    CLOUD_SOURCE,
    LOCAL_SOURCE,
    normalize_model_source,
    resolve_model_host,
)
@dataclass(frozen=True)
class ChatStreamEvent:
    """One event from a streaming chat response, with token counts when available."""

    content: str = ""
    done: bool = False
    generated_tokens: int | None = None
    prompt_tokens: int | None = None
def _chunk_value(chunk: Any, key: str) -> Any:
    """Read ``key`` from a chunk that may be a dict or an attribute-style object."""
    if isinstance(chunk, dict):
        return chunk.get(key)
    return getattr(chunk, key, None)
def _chunk_content(chunk: Any) -> str:
    """Extract text from a chat chunk, falling back to a generate-style ``response`` field."""
    content = ""
    message = _chunk_value(chunk, "message")
    if isinstance(message, dict):
        content = str(message.get("content", "") or "")
    elif message is not None:
        content = str(getattr(message, "content", "") or "")
    if not content:
        content = str(_chunk_value(chunk, "response") or "")
    return content
def _optional_int(value: Any) -> int | None:
    """Coerce to int when the value is integral; reject bools, which subclass int in Python."""
    if isinstance(value, bool):
        return None
    if isinstance(value, int):
        return value
    if isinstance(value, float) and value.is_integer():
        return int(value)
    return None
def get_cloud_client(api_key: str | None = None) -> Client:
    """Build an Ollama Cloud client, preferring the explicit key over OLLAMA_API_KEY."""
    resolved_api_key = str(api_key or "").strip() or os.getenv("OLLAMA_API_KEY", "").strip()
    if not resolved_api_key:
        raise RuntimeError("OLLAMA_API_KEY is not set. Enter an Ollama API key to use Ollama Cloud models.")
    host = resolve_model_host(CLOUD_SOURCE, cloud_host=os.getenv("OLLAMA_HOST", ""))
    return Client(host=host, headers={"Authorization": f"Bearer {resolved_api_key}"})
def get_local_client(host: str | None = None) -> Client:
    """Build a client for a local Ollama daemon."""
    resolved_host = resolve_model_host(LOCAL_SOURCE, local_host=host)
    return Client(host=resolved_host)
def get_client_for_source(source: str, host: str | None = None, api_key: str | None = None) -> Client:
    """Return a local or cloud client depending on the normalized source."""
    normalized_source = normalize_model_source(source)
    if normalized_source == LOCAL_SOURCE:
        return get_local_client(host)
    return get_cloud_client(api_key=api_key)
def get_client(api_key: str | None = None) -> Client:
    # Backward-compatible alias for call sites that still use the cloud-only path.
    return get_cloud_client(api_key=api_key)
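# Usage sketch (illustrative, kept in comments so importing this module stays
# side-effect free). The "local"/"cloud" literals and the localhost URL are
# assumptions; the accepted values actually come from
# model_identity.normalize_model_source and resolve_model_host.
#
#     local_client = get_client_for_source("local", host="http://localhost:11434")
#     cloud_client = get_client_for_source("cloud", api_key="<your-key>")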
def list_models(client: Client, *, source: str = CLOUD_SOURCE) -> list[str]:
    """List model names, tolerating dict, list, or object payloads from client.list().

    An unreachable local daemon yields an empty list; cloud errors propagate.
    """
    normalized_source = normalize_model_source(source)
    try:
        payload = client.list()
    except Exception:
        if normalized_source == LOCAL_SOURCE:
            return []
        raise
    models: list[str] = []
    if isinstance(payload, dict):
        raw_models = payload.get("models", [])
    elif isinstance(payload, list):
        raw_models = payload
    else:
        raw_models = getattr(payload, "models", []) or []
    for item in raw_models:
        if isinstance(item, dict):
            name = item.get("model") or item.get("name")
        else:
            name = getattr(item, "model", None) or getattr(item, "name", None)
        if name:
            models.append(str(name))
    return sorted(set(models))
def stream_chat_events(
    client: Client,
    model: str,
    prompt: str,
    system_prompt: str = "",
) -> Iterator[ChatStreamEvent]:
    """Stream chat chunks as ChatStreamEvent objects, skipping chunks that carry no signal."""
    messages = []
    if system_prompt.strip():
        messages.append({"role": "system", "content": system_prompt.strip()})
    messages.append({"role": "user", "content": prompt.strip()})
    stream = client.chat(model=model, messages=messages, stream=True)
    for chunk in stream:
        content = _chunk_content(chunk)
        done = bool(_chunk_value(chunk, "done"))
        generated_tokens = _optional_int(_chunk_value(chunk, "eval_count"))
        prompt_tokens = _optional_int(_chunk_value(chunk, "prompt_eval_count"))
        if content or done or generated_tokens is not None or prompt_tokens is not None:
            yield ChatStreamEvent(
                content=content,
                done=done,
                generated_tokens=generated_tokens,
                prompt_tokens=prompt_tokens,
            )
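# Event-consumption sketch (illustrative; the model name is a placeholder).
# Ollama typically reports eval_count/prompt_eval_count on the final chunk,
# so throughput can be derived roughly like this:
#
#     import time
#     start = time.perf_counter()
#     generated = None
#     for event in stream_chat_events(client, "some-model", "Hi"):
#         if event.generated_tokens is not None:
#             generated = event.generated_tokens
#     elapsed = time.perf_counter() - start
#     if generated:
#         print(f"{generated / elapsed:.1f} tokens/s")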
def stream_chat(
    client: Client,
    model: str,
    prompt: str,
    system_prompt: str = "",
) -> Iterator[str]:
    """Yield only the text content of each streamed event."""
    for event in stream_chat_events(client=client, model=model, prompt=prompt, system_prompt=system_prompt):
        if event.content:
            yield event.content
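if __name__ == "__main__":
    # Minimal smoke-test sketch, assuming a local Ollama daemon is running with
    # at least one model pulled. The prompt is a placeholder; this guard keeps
    # the demo out of normal imports.
    demo_client = get_local_client()
    available = list_models(demo_client, source=LOCAL_SOURCE)
    print("local models:", available)
    if available:
        for piece in stream_chat(demo_client, model=available[0], prompt="Say hello in one sentence."):
            print(piece, end="", flush=True)
        print()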