File size: 4,368 Bytes
371efe0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
from __future__ import annotations

import os
from dataclasses import dataclass
from typing import Any, Iterator

from ollama import Client

from model_identity import (
    CLOUD_SOURCE,
    LOCAL_SOURCE,
    normalize_model_source,
    resolve_model_host,
)


@dataclass(frozen=True)
class ChatStreamEvent:
    """Immutable record describing one event of a streamed chat reply.

    ``stream_chat_events`` in this module builds one instance per stream
    chunk that carries text, a done flag, or a token count.
    """

    # Text carried by the chunk; empty string when the chunk has no text.
    content: str = ""
    # Truthiness of the chunk's "done" field.
    done: bool = False
    # The chunk's "eval_count" when it is an integral number, otherwise None.
    generated_tokens: int | None = None
    # The chunk's "prompt_eval_count" when it is an integral number, otherwise None.
    prompt_tokens: int | None = None


def _chunk_value(chunk: Any, key: str) -> Any:
    if isinstance(chunk, dict):
        return chunk.get(key)
    return getattr(chunk, key, None)


def _chunk_content(chunk: Any) -> str:
    """Extract the text payload of a stream chunk.

    The chunk's ``message`` field is consulted first (as a dict or as an
    object with a ``content`` attribute). When that yields no text, the
    chunk's ``response`` field is used instead.

    Args:
        chunk: A dict-like or attribute-style stream chunk.

    Returns:
        The text as a ``str``; empty when neither field carries any.
    """
    message = _chunk_value(chunk, "message")
    if isinstance(message, dict):
        text = str(message.get("content", "") or "")
    elif message is None:
        text = ""
    else:
        text = str(getattr(message, "content", "") or "")

    if text:
        return text
    # No message text: fall back to the chunk-level "response" field.
    return str(_chunk_value(chunk, "response") or "")


def _optional_int(value: Any) -> int | None:
    if isinstance(value, bool):
        return None
    if isinstance(value, int):
        return value
    if isinstance(value, float) and value.is_integer():
        return int(value)
    return None


def get_cloud_client(api_key: str | None = None) -> Client:
    """Build a ``Client`` that sends a bearer token with its requests.

    Args:
        api_key: Explicit API key. When blank or ``None`` the
            ``OLLAMA_API_KEY`` environment variable is used instead.

    Returns:
        A ``Client`` whose host comes from ``resolve_model_host`` (given the
        ``OLLAMA_HOST`` environment value as ``cloud_host``) and whose
        headers carry ``Authorization: Bearer <key>``.

    Raises:
        RuntimeError: If no non-blank key is available from either source.
    """
    token = str(api_key or "").strip()
    if not token:
        token = os.getenv("OLLAMA_API_KEY", "").strip()
    if not token:
        raise RuntimeError("OLLAMA_API_KEY is not set. Enter Ollama API Key to use Ollama Cloud models.")

    cloud_host = resolve_model_host(CLOUD_SOURCE, cloud_host=os.getenv("OLLAMA_HOST", ""))
    auth_headers = {"Authorization": f"Bearer {token}"}
    return Client(host=cloud_host, headers=auth_headers)


def get_local_client(host: str | None = None) -> Client:
    """Build a ``Client`` for the local model source.

    Args:
        host: Optional host override, forwarded to ``resolve_model_host``
            as ``local_host``.

    Returns:
        A ``Client`` pointed at the resolved host.
    """
    return Client(host=resolve_model_host(LOCAL_SOURCE, local_host=host))


def get_client_for_source(source: str, host: str | None = None, api_key: str | None = None) -> Client:
    """Return the client that matches a model source.

    Args:
        source: Source identifier; passed through ``normalize_model_source``.
        host: Host override, used only for the local source.
        api_key: API key, used only for the non-local (cloud) path.

    Returns:
        A local client when the normalized source equals ``LOCAL_SOURCE``;
        a cloud client for every other value.
    """
    if normalize_model_source(source) != LOCAL_SOURCE:
        return get_cloud_client(api_key=api_key)
    return get_local_client(host)


def get_client(api_key: str | None = None) -> Client:
    """Backward-compatible alias of ``get_cloud_client``.

    Kept for call sites that still use the cloud-only path.

    Args:
        api_key: Forwarded unchanged to ``get_cloud_client``.

    Returns:
        Whatever ``get_cloud_client`` returns.
    """
    cloud_client = get_cloud_client(api_key=api_key)
    return cloud_client


def list_models(client: Client, *, source: str = CLOUD_SOURCE) -> list[str]:
    """Return the sorted, de-duplicated model names reported by ``client``.

    The payload of ``client.list()`` may be a dict with a ``"models"`` entry,
    a plain list of entries, or an object with a ``models`` attribute. Each
    entry may be a dict or an object; its ``model`` field is preferred and
    ``name`` is the fallback. Entries with neither are skipped.

    Args:
        client: Object exposing a ``list()`` method.
        source: Model source identifier; normalized with
            ``normalize_model_source``.

    Returns:
        Alphabetically sorted unique model names. For the local source an
        empty list is returned when ``client.list()`` fails.

    Raises:
        Exception: Whatever ``client.list()`` raised, for non-local sources.
    """
    normalized_source = normalize_model_source(source)
    try:
        payload = client.list()
    except Exception:
        # Deliberate best-effort: a failing local listing yields "no models"
        # instead of an error; any other source propagates the failure.
        if normalized_source == LOCAL_SOURCE:
            return []
        raise

    if isinstance(payload, dict):
        raw_models = payload.get("models")
    elif isinstance(payload, list):
        raw_models = payload
    else:
        raw_models = getattr(payload, "models", None)

    names: set[str] = set()
    # `or []` guards every payload shape against a missing/None model list
    # (previously only the attribute-style payload was guarded).
    for item in raw_models or []:
        if isinstance(item, dict):
            name = item.get("model") or item.get("name")
        else:
            name = getattr(item, "model", None) or getattr(item, "name", None)
        if name:
            names.add(str(name))
    return sorted(names)


def stream_chat_events(
    client: Client,
    model: str,
    prompt: str,
    system_prompt: str = "",
) -> Iterator[ChatStreamEvent]:
    """Stream a single-turn chat and yield one event per meaningful chunk.

    Args:
        client: Object whose ``chat(model=..., messages=..., stream=True)``
            call returns an iterable of chunks.
        model: Model name passed through to ``client.chat``.
        prompt: User message; surrounding whitespace is stripped.
        system_prompt: Optional system message; sent (stripped) only when it
            is not blank.

    Yields:
        ``ChatStreamEvent`` for every chunk that has text, a truthy ``done``
        field, or an integral ``eval_count`` / ``prompt_eval_count``. Chunks
        with none of these are dropped.
    """
    system_text = system_prompt.strip()
    conversation = [{"role": "user", "content": prompt.strip()}]
    if system_text:
        # The system message, when present, precedes the user message.
        conversation.insert(0, {"role": "system", "content": system_text})

    for chunk in client.chat(model=model, messages=conversation, stream=True):
        text = _chunk_content(chunk)
        finished = bool(_chunk_value(chunk, "done"))
        produced = _optional_int(_chunk_value(chunk, "eval_count"))
        consumed = _optional_int(_chunk_value(chunk, "prompt_eval_count"))

        carries_nothing = not text and not finished and produced is None and consumed is None
        if carries_nothing:
            continue
        yield ChatStreamEvent(
            content=text,
            done=finished,
            generated_tokens=produced,
            prompt_tokens=consumed,
        )


def stream_chat(
    client: Client,
    model: str,
    prompt: str,
    system_prompt: str = "",
) -> Iterator[str]:
    """Stream a chat reply as plain text fragments.

    Wraps ``stream_chat_events`` and keeps only the text of each event.

    Args:
        client: Forwarded to ``stream_chat_events``.
        model: Forwarded to ``stream_chat_events``.
        prompt: Forwarded to ``stream_chat_events``.
        system_prompt: Forwarded to ``stream_chat_events``.

    Yields:
        The non-empty ``content`` of each event, in stream order.
    """
    events = stream_chat_events(
        client=client,
        model=model,
        prompt=prompt,
        system_prompt=system_prompt,
    )
    for event in events:
        fragment = event.content
        if not fragment:
            # Events that only carry status or token counts are skipped.
            continue
        yield fragment