File size: 3,813 Bytes
a4b70d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from __future__ import annotations

import json
import requests
import os

from ..needs_auth.OpenaiAPI import OpenaiAPI
from ...requests import StreamSession, raise_for_status
from ...providers.response import Usage, Reasoning
from ...tools.run_tools import AuthManager
from ...typing import AsyncResult, Messages

class Ollama(OpenaiAPI):
    """Provider for Ollama: hosted ollama.com API plus a local Ollama server.

    Model discovery merges hosted models (when an API key is available) with
    models reported by the local server. Generation routes local models (or
    key-less requests) through the OpenAI-compatible base class, and everything
    else through the hosted ``/api/chat`` streaming endpoint.
    """
    label = "Ollama 🦙"
    url = "https://ollama.com"
    login_url = "https://ollama.com/settings/keys"
    api_endpoint = "https://ollama.com/api/chat"
    needs_auth = False
    working = True
    active_by_default = True
    # Models discovered on the local Ollama instance (as opposed to hosted ones).
    local_models: list[str] = []
    model_aliases = {
        "gpt-oss-120b": "gpt-oss:120b",
        "gpt-oss-20b": "gpt-oss:20b"
    }

    @classmethod
    def get_models(cls, api_key: str = None, api_base: str = None, **kwargs) -> list[str]:
        """Return available model names (hosted + local), cached on the class.

        Args:
            api_key: Bearer token for ollama.com; loaded via AuthManager if omitted.
            api_base: Base URL of a local OpenAI-compatible endpoint; defaults to
                OLLAMA_HOST/OLLAMA_PORT env vars (localhost:11434).

        Returns:
            List of model names. Network failures degrade gracefully to
            whatever subset could be fetched.
        """
        if not cls.models:
            cls.models = []
            if not api_key:
                api_key = AuthManager.load_api_key(cls)
            if api_key:
                try:
                    # BUGFIX: the header dict was previously passed positionally,
                    # i.e. as `params=`, so the Authorization header was never
                    # actually sent. It must be the `headers=` keyword argument.
                    hosted = requests.get(
                        "https://ollama.com/api/tags",
                        headers={"Authorization": f"Bearer {api_key}"},
                        timeout=10,
                    ).json().get("models", [])
                except requests.exceptions.RequestException:
                    hosted = []  # hosted listing is best-effort; fall through to local
                if hosted:
                    cls.live += 1
                cls.models = [model["name"] for model in hosted]
            if api_base is None:
                host = os.getenv("OLLAMA_HOST", "localhost")
                port = os.getenv("OLLAMA_PORT", "11434")
                url = f"http://{host}:{port}/api/tags"
            else:
                url = api_base.replace("/v1", "/api/tags")
            try:
                models = requests.get(url, timeout=10).json().get("models", [])
            except requests.exceptions.RequestException:
                # Local server unreachable: still expose hosted models (if any)
                # and make sure default_model is set for them.
                cls.default_model = next(iter(cls.models), None)
                return cls.models
            if cls.live == 0 and models:
                cls.live += 1
            cls.local_models = [model["name"] for model in models]
            cls.models = cls.models + cls.local_models
            cls.default_model = next(iter(cls.models), None)
        return cls.models

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        api_key: str = None,
        api_base: str = None,
        proxy: str = None,
        **kwargs
    ) -> AsyncResult:
        """Stream a chat completion, preferring the local server when possible.

        Local models (or calls without an API key) delegate to the
        OpenAI-compatible base implementation against ``api_base``; otherwise
        the hosted ollama.com streaming chat endpoint is used.

        Yields:
            Reasoning chunks (model "thinking"), content strings, and a final
            Usage record built from the last streamed JSON object.
        """
        if api_base is None:
            host = os.getenv("OLLAMA_HOST", "localhost")
            port = os.getenv("OLLAMA_PORT", "11434")
            api_base: str = f"http://{host}:{port}/v1"
        if model in cls.local_models or not api_key:
            # No key or a locally-served model: use the OpenAI-compatible path.
            async for chunk in super().create_async_generator(
                model, messages, api_base=api_base, proxy=proxy, **kwargs
            ):
                yield chunk
        else:
            async with StreamSession(headers={"Authorization": f"Bearer {api_key}"}, proxy=proxy) as session:
                async with session.post(cls.api_endpoint, json={
                    "model": model,
                    "messages": messages,
                }) as response:
                    await raise_for_status(response)
                    last_data = {}
                    async for chunk in response.iter_lines():
                        if not chunk:
                            # Skip blank keep-alive lines; json.loads("") raises.
                            continue
                        data = json.loads(chunk)
                        last_data = data
                        message = data.get("message", {})
                        thinking = message.get("thinking", "")
                        if thinking:
                            yield Reasoning(thinking)
                        content = message.get("content", "")
                        if content:
                            yield content
                    # Token counts come from the final streamed object.
                    yield Usage(
                        prompt_tokens=last_data.get("prompt_eval_count", 0),
                        completion_tokens=last_data.get("eval_count", 0),
                        total_tokens=last_data.get("prompt_eval_count", 0) + last_data.get("eval_count", 0),
                    )