# Source: GradLLM / oa_server.py (commit 9f3b48c, ~4.75 kB) by johnbridges.
# NOTE(review): scraped page-listing residue converted to a comment so the
# file parses as Python.
# oa_server.py
from __future__ import annotations
import json, time, uuid, logging
from typing import Any, Dict, List, AsyncIterable, Optional
from rabbit_repo import RabbitRepo
logger = logging.getLogger(__name__)
# ------------------ helpers ------------------
def _now() -> int: return int(time.time())
def _chunk_text(s: str, sz: int = 140) -> List[str]:
return [s[i:i+sz] for i in range(0, len(s or ""), sz)] if s else []
def _last_user_text(messages: List[Dict[str, Any]]) -> str:
for m in reversed(messages or []):
if (m or {}).get("role") == "user":
c = m.get("content", "")
if isinstance(c, str):
return c
if isinstance(c, list):
texts = [p.get("text","") for p in c if p.get("type") == "text"]
return " ".join([t for t in texts if t])
return ""
# ------------------ backends (replace later) ------------------
class ChatBackend:
    """Interface for chat-completion backends.

    Implementations provide ``stream``, an async generator that yields
    OpenAI-style ``chat.completion.chunk`` dicts for a single request.
    """

    async def stream(self, request: Dict[str, Any]) -> AsyncIterable[Dict[str, Any]]:
        """Yield response chunks for *request*; subclasses must override."""
        raise NotImplementedError
class DummyChatBackend(ChatBackend):
    """Echo backend: streams the last user message back as OpenAI chunks."""

    async def stream(self, request: Dict[str, Any]) -> AsyncIterable[Dict[str, Any]]:
        """Yield a role chunk, the echoed content in pieces, then a stop chunk."""
        reply_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
        model_name = request.get("model", "gpt-4o-mini")
        prompt = _last_user_text(request.get("messages", [])) or "(empty)"
        answer = f"Echo (Rabbit): {prompt}"
        created = _now()

        def make_chunk(delta: Dict[str, Any], finish: Optional[str]) -> Dict[str, Any]:
            # All chunks share the same id/model/timestamp envelope.
            return {
                "id": reply_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": model_name,
                "choices": [{"index": 0, "delta": delta, "finish_reason": finish}],
            }

        # Opening chunk carries only the assistant role.
        yield make_chunk({"role": "assistant"}, None)
        # Content is streamed in pieces of at most 140 characters.
        for piece in _chunk_text(answer, 140):
            yield make_chunk({"content": piece}, None)
        # Terminal chunk: empty delta, finish_reason "stop".
        yield make_chunk({}, "stop")
class ImagesBackend:
    """Stub image backend: always returns a fixed 1x1 transparent PNG."""

    # Base64 of a 1x1 transparent PNG, used as the stand-in result.
    _PIXEL_B64 = (
        "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGP4BwQA"
        "CfsD/etCJH0AAAAASUVORK5CYII="
    )

    async def generate_b64(self, request: Dict[str, Any]) -> str:
        """Return base64-encoded PNG bytes; *request* is currently ignored."""
        return self._PIXEL_B64
# ------------------ handler class ------------------
class OpenAIServers:
    """
    Handlers you can register in RabbitListenerBase:
      - 'oaChatCreate'       -> handle_chat_create
      - 'oaImagesGenerate'   -> handle_images_generate
    Uses RabbitRepo.publish(...) to emit CloudEvent-wrapped OpenAI JSON.
    """

    def __init__(self, publisher: RabbitRepo,
                 *, chat_backend: Optional[ChatBackend] = None,
                 images_backend: Optional[ImagesBackend] = None):
        # Fall back to the built-in stub backends when none are injected.
        self._pub = publisher
        self._chat = chat_backend or DummyChatBackend()
        self._img = images_backend or ImagesBackend()

    # -------- Chat Completions --------
    async def handle_chat_create(self, data: Dict[str, Any]) -> None:
        """Stream chat chunks for *data* to exchange 'oa.chat.reply'.

        *data* is an OpenAI chat request plus a 'reply_key' string used as
        the routing key. Invalid payloads are logged and dropped.
        """
        if not isinstance(data, dict):
            logger.warning("oaChatCreate: data is not a dict")
            return

        routing = data.get("reply_key")
        if not routing:
            logger.error("oaChatCreate: missing reply_key")
            return

        try:
            # Forward each backend chunk as a CloudEvent-wrapped OpenAI chunk.
            async for delta in self._chat.stream(data):
                await self._pub.publish("oa.chat.reply", delta, routing_key=routing)
            # Optional end-of-stream sentinel for consumers.
            # NOTE(review): no sentinel is sent if streaming raises mid-way;
            # confirm consumers handle an unterminated stream.
            await self._pub.publish("oa.chat.reply", {"object": "stream.end"}, routing_key=routing)
        except Exception:
            logger.exception("oaChatCreate: streaming failed")

    # -------- Images (generations) --------
    async def handle_images_generate(self, data: Dict[str, Any]) -> None:
        """Publish one images response for *data* to exchange 'oa.images.reply'.

        *data* is an OpenAI images.generate request plus a 'reply_key' string.
        """
        if not isinstance(data, dict):
            logger.warning("oaImagesGenerate: data is not a dict")
            return

        routing = data.get("reply_key")
        if not routing:
            logger.error("oaImagesGenerate: missing reply_key")
            return

        try:
            image_b64 = await self._img.generate_b64(data)
            payload = {"created": _now(), "data": [{"b64_json": image_b64}]}
            await self._pub.publish("oa.images.reply", payload, routing_key=routing)
        except Exception:
            logger.exception("oaImagesGenerate: generation failed")