File size: 8,244 Bytes
15d27ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
# openai_server.py
from __future__ import annotations
import asyncio, json, time, uuid, math, logging
from typing import Any, AsyncIterable, Dict, List, Optional

import aio_pika

logger = logging.getLogger(__name__)

# --------------------------- Helpers ---------------------------

def _now() -> int:
    return int(time.time())

def _chunk_text(s: str, sz: int = 120) -> List[str]:
    if not s:
        return []
    return [s[i:i+sz] for i in range(0, len(s), sz)]

def _last_user_text(messages: List[Dict[str, Any]]) -> str:
    # Accept either string or multimodal parts [{type:"text"/"image_url"/...}]
    for m in reversed(messages or []):
        if (m or {}).get("role") == "user":
            c = m.get("content", "")
            if isinstance(c, str):
                return c
            if isinstance(c, list):
                texts = [p.get("text","") for p in c if p.get("type") == "text"]
                return " ".join([t for t in texts if t])
    return ""

# --------------------------- Backends ---------------------------
# You can replace DummyChatBackend with a real LLM (OpenAI/HF/local).
class ChatBackend:
    """Base interface for the chat backends used by ChatCompletionsServer.

    Subclasses implement ``stream`` as an async generator that yields
    OpenAI-shaped ``chat.completion.chunk`` dicts for a single request.
    """

    async def stream(self, request: Dict[str, Any]) -> AsyncIterable[Dict[str, Any]]:
        """Yield streaming chunks for ``request``. Must be overridden.

        Raises:
            NotImplementedError: on first iteration, when a subclass did not
                override this method.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement stream()")
        # Unreachable, but the `yield` makes this an async generator. A caller's
        # `async for` then surfaces the NotImplementedError above, instead of
        # "TypeError: 'async for' requires an object with __aiter__" on a bare
        # coroutine (plus a "coroutine was never awaited" RuntimeWarning).
        yield {}  # pragma: no cover

class DummyChatBackend(ChatBackend):
    """Echo backend: replays the last user message as an OpenAI-style stream.

    It never calls a model and never emits tool_calls. This keeps the server
    wiring testable without a real LLM.
    """

    async def stream(self, request: Dict[str, Any]) -> AsyncIterable[Dict[str, Any]]:
        """
        Yield OpenAI-shaped *streaming* chunks for ``request``.

        Frame sequence: one role-only delta (``{"role": "assistant"}``), then the
        answer ``"Echo (RabbitMQ): <last user text>"`` in content deltas of at
        most 140 characters, then an empty delta with ``finish_reason="stop"``.
        Every frame shares the same id, created timestamp, and model.
        """
        completion_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
        model_name = request.get("model", "gpt-4o-mini")
        user_text = _last_user_text(request.get("messages", [])) or "(empty)"
        reply = f"Echo (RabbitMQ): {user_text}"
        created_at = _now()

        def frame(delta: Dict[str, Any], finish: Optional[str] = None) -> Dict[str, Any]:
            # Build one chat.completion.chunk around ``delta``; a fresh dict per call.
            return {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created_at,
                "model": model_name,
                "choices": [{"index": 0, "delta": delta, "finish_reason": finish}],
            }

        yield frame({"role": "assistant"})
        for segment in _chunk_text(reply, 140):
            yield frame({"content": segment})
        yield frame({}, "stop")

class ImagesBackend:
    """Placeholder image generator used when no real images backend is supplied."""

    async def generate_b64(self, request: Dict[str, Any]) -> str:
        """
        Return a base64-encoded image for ``request`` (stub implementation).

        ``request`` is ignored; the same fixed 1x1 PNG is returned every time.
        Replace with a real generator (e.g. SDXL or OpenAI gpt-image-1).
        """
        placeholder_png_b64 = (
            "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGP4BwQACfsD/etCJH0AAAAASUVORK5CYII="
        )
        return placeholder_png_b64

# --------------------------- Servers ---------------------------

class ChatCompletionsServer:
    """
    Consumes OpenAI Chat Completions requests from exchange `oa.chat.create`,
    routing-key `default`, and streams OpenAI-shaped chunks back to `reply_to`.

    Every reply goes to the default exchange with routing key ``reply_to`` and
    the request's ``correlation_id``:

    * each chunk yielded by the backend, JSON-encoded;
    * if processing fails, one ``{"error": {"message": ..., "type": "server_error"}}``
      object;
    * finally, always, the ``{"object":"stream.end"}`` sentinel. This releases a
      client reading the stream even when the request fails.

    Call ``start()`` to begin consuming and ``close()`` to release the connection.
    """

    # End-of-stream marker published after every answerable request.
    _STREAM_END = b'{"object":"stream.end"}'

    def __init__(self, amqp_url: str, *, exchange_name: str = "oa.chat.create", routing_key: str = "default", backend: Optional[ChatBackend] = None):
        self._amqp_url = amqp_url
        self._exchange_name = exchange_name
        self._routing_key = routing_key
        # Falls back to the echo backend when no real backend is supplied.
        self._backend = backend or DummyChatBackend()
        self._conn: Optional[aio_pika.RobustConnection] = None
        self._ch: Optional[aio_pika.RobustChannel] = None
        self._ex: Optional[aio_pika.Exchange] = None
        # One durable queue per exchange/routing-key pair, e.g. "oa.chat.create.default".
        self._queue_name = f"{exchange_name}.{routing_key}"

    async def start(self):
        """Connect, declare and bind the durable direct exchange/queue, and start consuming.

        This is a no-op if the server is already started. If any setup step
        fails, the newly opened connection is closed before the exception
        propagates, so a failed start does not leak it.
        """
        if self._conn is not None:
            return
        conn = await aio_pika.connect_robust(self._amqp_url)
        try:
            ch = await conn.channel()
            ex = await ch.declare_exchange(self._exchange_name, aio_pika.ExchangeType.DIRECT, durable=True)
            q = await ch.declare_queue(self._queue_name, durable=True)
            await q.bind(ex, routing_key=self._routing_key)
            # Store the handles before consuming: _on_message replies via self._ch.
            self._conn, self._ch, self._ex = conn, ch, ex
            await q.consume(self._on_message)
        except BaseException:
            self._conn = self._ch = self._ex = None
            await conn.close()
            raise
        logger.info("ChatCompletionsServer listening on %s/%s -> %s", self._exchange_name, self._routing_key, self._queue_name)

    async def close(self) -> None:
        """Close the connection opened by start(). Safe to call more than once."""
        conn = self._conn
        self._conn = self._ch = self._ex = None
        if conn is not None:
            await conn.close()

    async def _on_message(self, msg: aio_pika.IncomingMessage):
        """Answer one request: stream backend chunks, then always send the end sentinel.

        Requests without reply_to/correlation_id cannot be answered and are
        dropped with a warning. Every other exception is logged and reported to
        the client as an error object, so it never propagates out of
        ``msg.process()``.
        """
        async with msg.process(ignore_processed=True):
            reply_to = msg.reply_to
            corr_id = msg.correlation_id
            if not reply_to or not corr_id:
                logger.warning("Missing reply_to/correlation_id; dropping.")
                return

            try:
                req = json.loads(msg.body.decode("utf-8", errors="replace"))
                async for chunk in self._backend.stream(req):
                    await self._publish(json.dumps(chunk).encode("utf-8"), corr_id, reply_to)
            except Exception as exc:
                logger.exception("ChatCompletionsServer: failed to process message")
                # Tell the client why the stream ended instead of letting it hang.
                await self._publish_quietly(self._error_body(exc), corr_id, reply_to)

            # Sent unconditionally so the client's stream reader always terminates.
            await self._publish_quietly(self._STREAM_END, corr_id, reply_to)

    async def _publish(self, body: bytes, corr_id: str, reply_to: str) -> None:
        """Publish one transient JSON reply to ``reply_to`` via the default exchange."""
        await self._ch.default_exchange.publish(
            aio_pika.Message(
                body=body,
                correlation_id=corr_id,
                content_type="application/json",
                delivery_mode=aio_pika.DeliveryMode.NOT_PERSISTENT,
            ),
            routing_key=reply_to,
        )

    async def _publish_quietly(self, body: bytes, corr_id: str, reply_to: str) -> None:
        """Best-effort ``_publish`` for error/cleanup paths: logs instead of raising."""
        try:
            await self._publish(body, corr_id, reply_to)
        except Exception:
            logger.exception("ChatCompletionsServer: failed to publish reply to %s", reply_to)

    @staticmethod
    def _error_body(exc: BaseException) -> bytes:
        """Serialize an OpenAI-style error object. Only the exception class name is exposed."""
        return json.dumps({
            "error": {
                "message": f"Failed to process chat completion request ({type(exc).__name__}).",
                "type": "server_error",
            }
        }).encode("utf-8")

class ImagesServer:
    """
    Consumes OpenAI Images API requests from exchange `oa.images.generate`,
    routing-key `default`, and replies once with {data:[{b64_json:...}], created:...}.

    If generation fails, the single reply is instead
    ``{"error": {"message": ..., "type": "server_error"}}``, so the caller is
    answered rather than left waiting. Call ``start()`` to begin consuming and
    ``close()`` to release the connection.
    """

    def __init__(self, amqp_url: str, *, exchange_name: str = "oa.images.generate", routing_key: str = "default", backend: Optional[ImagesBackend] = None):
        self._amqp_url = amqp_url
        self._exchange_name = exchange_name
        self._routing_key = routing_key
        # Falls back to the stub backend when no real generator is supplied.
        self._backend = backend or ImagesBackend()
        self._conn: Optional[aio_pika.RobustConnection] = None
        self._ch: Optional[aio_pika.RobustChannel] = None
        self._ex: Optional[aio_pika.Exchange] = None
        # One durable queue per exchange/routing-key pair, e.g. "oa.images.generate.default".
        self._queue_name = f"{exchange_name}.{routing_key}"

    async def start(self):
        """Connect, declare and bind the durable direct exchange/queue, and start consuming.

        This is a no-op if the server is already started. If any setup step
        fails, the newly opened connection is closed before the exception
        propagates, so a failed start does not leak it.
        """
        if self._conn is not None:
            return
        conn = await aio_pika.connect_robust(self._amqp_url)
        try:
            ch = await conn.channel()
            ex = await ch.declare_exchange(self._exchange_name, aio_pika.ExchangeType.DIRECT, durable=True)
            q = await ch.declare_queue(self._queue_name, durable=True)
            await q.bind(ex, routing_key=self._routing_key)
            # Store the handles before consuming: _on_message replies via self._ch.
            self._conn, self._ch, self._ex = conn, ch, ex
            await q.consume(self._on_message)
        except BaseException:
            self._conn = self._ch = self._ex = None
            await conn.close()
            raise
        logger.info("ImagesServer listening on %s/%s -> %s", self._exchange_name, self._routing_key, self._queue_name)

    async def close(self) -> None:
        """Close the connection opened by start(). Safe to call more than once."""
        conn = self._conn
        self._conn = self._ch = self._ex = None
        if conn is not None:
            await conn.close()

    async def _on_message(self, msg: aio_pika.IncomingMessage):
        """Answer one request with exactly one reply: the image payload or an error object.

        Requests without reply_to/correlation_id cannot be answered and are
        dropped with a warning. Every other exception is logged, and it never
        propagates out of ``msg.process()``.
        """
        async with msg.process(ignore_processed=True):
            reply_to = msg.reply_to
            corr_id = msg.correlation_id
            if not reply_to or not corr_id:
                logger.warning("Missing reply_to/correlation_id; dropping.")
                return

            try:
                req = json.loads(msg.body.decode("utf-8", errors="replace"))
                b64_img = await self._backend.generate_b64(req)
                # Serialize inside the try: a backend returning a non-JSON type is a processing failure.
                body = json.dumps({"created": _now(), "data": [{"b64_json": b64_img}]}).encode("utf-8")
            except Exception as exc:
                logger.exception("ImagesServer: failed to process message")
                body = json.dumps({
                    "error": {
                        "message": f"Failed to generate image ({type(exc).__name__}).",
                        "type": "server_error",
                    }
                }).encode("utf-8")

            try:
                await self._ch.default_exchange.publish(
                    aio_pika.Message(
                        body=body,
                        correlation_id=corr_id,
                        content_type="application/json",
                    ),
                    routing_key=reply_to,
                )
            except Exception:
                logger.exception("ImagesServer: failed to publish reply to %s", reply_to)