File size: 4,592 Bytes
bf707d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# openai_backend.py
from __future__ import annotations
import os, json, base64, logging, asyncio
from typing import Any, AsyncIterable, Dict, Optional

from openai import AsyncOpenAI
from openai._types import NOT_GIVEN

from config import settings
from oa_server import ChatBackend, ImagesBackend  # reuse your ABCs

log = logging.getLogger(__name__)

def _pick_api_key() -> str:
    """Return the first non-empty API key found in config or the environment.

    Lookup order (first truthy value wins):
      1. ``settings.LlmHFKey``, then the ``LlmHFKey`` environment variable
         (HF/OpenAI-compatible key from appsettings → env);
      2. ``settings.OpenAIApiKey``, then the ``OpenAIApiKey`` environment variable;
      3. the ``OPENAI_API_KEY`` environment variable, or ``""`` if unset.
    """
    for source_name in ("LlmHFKey", "OpenAIApiKey"):
        candidate = getattr(settings, source_name, None) or os.getenv(source_name)
        if candidate:
            return candidate
    return os.getenv("OPENAI_API_KEY", "")

def _pick_base_url() -> Optional[str]:
    """Return a custom OpenAI-compatible endpoint URL (e.g. Novita), or None.

    ``settings.LlmHFUrl`` takes precedence over the ``LlmHFUrl`` environment
    variable. An empty value is normalised to ``None`` so the SDK falls back
    to its own default base URL.
    """
    endpoint = getattr(settings, "LlmHFUrl", None)
    if not endpoint:
        endpoint = os.getenv("LlmHFUrl")
    return endpoint if endpoint else None

def _pick_default_model(incoming: Dict[str, Any]) -> str:
    """Choose the model id for a chat request.

    Precedence (first truthy value wins):
      1. ``incoming["model"]`` (honor the caller's request);
      2. ``settings.LlmHFModelID`` (HF / OpenAI-compatible model id);
      3. ``settings.LlmGptModel`` (OpenAI model from config);
      4. the literal ``"gpt-4o-mini"``.

    Empty or ``None`` values at every level are skipped, so the result is
    never falsy. Previously a ``LlmGptModel`` attribute that existed but was
    ``None``/``""`` was returned as-is.
    """
    return (
        incoming.get("model")
        or getattr(settings, "LlmHFModelID", None)
        or getattr(settings, "LlmGptModel", None)
        or "gpt-4o-mini"
    )

class OpenAICompatChatBackend(ChatBackend):
    """
    Streams Chat Completions from any OpenAI-compatible server.
    - Passes 'tools'/'tool_choice' straight through (function-calling).
    - Accepts multimodal 'messages[*].content' with text+image_url.
    - Streams ChatCompletionChunk objects; we convert to plain dicts.
    """

    # Envelope/routing keys added by our messaging layer. They are not part of
    # the OpenAI payload, and some providers reject unknown keys.
    # 'reply_key' is handled by the caller.
    _INTERNAL_KEYS = ("reply_key", "ExchangeName", "FuncName", "MessageTimeout", "RoutingKeys")

    def __init__(self):
        api_key = _pick_api_key()
        base_url = _pick_base_url()
        if not api_key:
            log.warning("No API key found; requests will fail.")
        self.client = AsyncOpenAI(api_key=api_key, base_url=base_url)

    @staticmethod
    def _chunk_to_dict(chunk: Any) -> Dict[str, Any]:
        """Convert one streamed chunk to a plain dict for serialization over MQ."""
        if hasattr(chunk, "model_dump_json"):
            return json.loads(chunk.model_dump_json())
        if hasattr(chunk, "to_dict"):
            return chunk.to_dict()
        return chunk  # already a dict

    async def stream(self, request: Dict[str, Any]) -> AsyncIterable[Dict[str, Any]]:
        """Forward ``request`` to the provider and yield each chunk as a dict.

        Internal envelope keys are stripped. ``stream`` is forced to True and
        a usable ``model`` is filled in. The provider stream is closed when
        iteration ends, fails, or is abandoned early by the consumer.
        """
        req = {k: v for k, v in request.items() if k not in self._INTERNAL_KEYS}
        # Always stream. A caller-supplied stream=False would make create()
        # return a non-iterable ChatCompletion and break the loop below.
        req["stream"] = True
        # Replace a missing, None, or empty model with the configured default.
        req["model"] = _pick_default_model(req)

        stream = await self.client.chat.completions.create(**req)
        try:
            async for chunk in stream:
                yield self._chunk_to_dict(chunk)
        finally:
            # Release the HTTP connection even if the consumer stops early.
            close = getattr(stream, "close", None)
            if close is not None:
                await close()

class OpenAIImagesBackend(ImagesBackend):
    """
    Generates base64 images via an OpenAI-compatible Images API.

    Optional non-standard request fields (see ``_PASSTHROUGH_KEYS``) go to the
    provider through ``extra_body``. They are not passed as keyword arguments,
    which ``images.generate`` would reject with ``TypeError`` when its
    signature lacks them.
    """

    # 1x1 PNG (base64) returned when the provider yields no base64 payload.
    _PLACEHOLDER_PNG_B64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGP4BwQACfsD/etCJH0AAAAASUVORK5CYII="

    # Request keys forwarded to the provider only when the caller sets them.
    _PASSTHROUGH_KEYS = ("background", "transparent_background")

    def __init__(self):
        api_key = _pick_api_key()
        base_url = _pick_base_url()
        if not api_key:
            log.warning("No API key found; requests will fail.")
        self.client = AsyncOpenAI(api_key=api_key, base_url=base_url)

    async def generate_b64(self, request: Dict[str, Any]) -> str:
        """Generate an image and return the first result as a base64 string.

        Reads OpenAI ``images.generate``-style fields from ``request``:
          - ``model``: falls back to ``settings.LlmHFModelID``, then ``"gpt-image-1"``;
          - ``prompt``: defaults to ``""``;
          - ``size``: e.g. ``"1024x1024"`` (the default);
          - ``n``: number of images, coerced with ``int()`` (default 1);
          - any key in ``_PASSTHROUGH_KEYS``, forwarded verbatim when present.

        Returns a 1x1 placeholder PNG (and logs a warning) when the provider
        returns no data or returns a URL instead of base64. Fetching URLs is
        out of scope here.
        """
        model = request.get("model") or getattr(settings, "LlmHFModelID", None) or "gpt-image-1"
        size = request.get("size", "1024x1024")
        n = int(request.get("n", 1))
        extra_body = {k: request[k] for k in self._PASSTHROUGH_KEYS if k in request}

        resp = await self.client.images.generate(
            model=model,
            prompt=request.get("prompt", ""),
            size=size,
            n=n,
            extra_body=extra_body or None,
        )

        data = getattr(resp, "data", None)
        if not data:
            log.warning("Images API returned no image data; returning 1x1 transparent pixel.")
            return self._PLACEHOLDER_PNG_B64

        b64 = getattr(data[0], "b64_json", None)
        if b64:
            return b64
        log.warning("Images API returned URL instead of b64; returning 1x1 transparent pixel.")
        return self._PLACEHOLDER_PNG_B64