File size: 3,289 Bytes
9a43741
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# server.py
import os
import json
from typing import List, Optional, AsyncGenerator
from fastapi import FastAPI, HTTPException, Header
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import httpx

SERVER_API_KEY = os.getenv("OPENROUTER_API_KEY")  # optional fallback
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1/chat/completions"
DEFAULT_MODEL = "openai/gpt-4o-mini"  # change as desired

app = FastAPI(title="OpenRouter Chat Backend")


class Message(BaseModel):
    """One chat message in the OpenAI-style schema OpenRouter expects."""

    role: str  # "system" | "user" | "assistant"
    content: str  # the message text


class ChatRequest(BaseModel):
    """Request body for POST /chat; fields are forwarded to OpenRouter."""

    messages: List[Message]  # conversation history, forwarded verbatim
    model: Optional[str] = None  # OpenRouter model id; falls back to DEFAULT_MODEL
    temperature: Optional[float] = 0.7  # sampling temperature passed through
    max_tokens: Optional[int] = None  # completion token cap; None = provider default


@app.post("/chat")
async def chat(
    req: ChatRequest,
    x_openrouter_api_key: Optional[str] = Header(default=None, alias="X-OpenRouter-Api-Key"),
    http_referer: Optional[str] = Header(default=None, alias="HTTP-Referer"),
):
    """Proxy a chat-completion request to OpenRouter and relay the reply as SSE.

    The upstream response is streamed line-by-line; each content delta is
    re-emitted as an SSE ``data:`` event ``{"content": ...}``, followed by an
    ``event: done`` marker when the upstream sends ``[DONE]``.

    Raises:
        HTTPException: 401 when neither the ``X-OpenRouter-Api-Key`` header
            nor the server-level ``OPENROUTER_API_KEY`` env var provides a key.
    """
    model = req.model or DEFAULT_MODEL
    # Prefer per-request key from client; fallback to server key if available
    api_key = x_openrouter_api_key or SERVER_API_KEY
    if not api_key:
        raise HTTPException(status_code=401, detail="Missing OpenRouter API key")

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
        # Attribution headers (optional)
        "HTTP-Referer": http_referer or "http://localhost:8501",
        "X-Title": "OpenRouter Streamlit Chat",
    }
    payload = {
        "model": model,
        "messages": [m.dict() for m in req.messages],
        "temperature": req.temperature,
        "stream": True,
    }
    # Omit max_tokens entirely when unset instead of sending an explicit null.
    if req.max_tokens is not None:
        payload["max_tokens"] = req.max_tokens

    async def stream() -> AsyncGenerator[bytes, None]:
        # BUG FIX: once this generator starts, StreamingResponse has already
        # sent the response headers, so raising HTTPException here cannot
        # become a proper HTTP error — the client would just see a truncated
        # stream. Errors are therefore reported in-band as SSE "error" events.
        try:
            # timeout=None: completions can stream for a long time between chunks
            async with httpx.AsyncClient(timeout=None) as client:
                async with client.stream("POST", OPENROUTER_BASE_URL, headers=headers, json=payload) as r:
                    if r.status_code >= 400:
                        text = await r.aread()
                        detail = {"status": r.status_code, "detail": text.decode("utf-8", "replace")}
                        yield f"event: error\ndata: {json.dumps(detail)}\n\n".encode("utf-8")
                        return
                    async for line in r.aiter_lines():
                        if not line:
                            continue  # SSE keep-alive / event separators
                        if line.startswith("data: "):
                            data = line[len("data: "):]
                            if data.strip() == "[DONE]":
                                yield b"event: done\ndata: done\n\n"
                                break
                            try:
                                obj = json.loads(data)
                                delta = obj.get("choices", [{}])[0].get("delta", {})
                                content = delta.get("content")
                                if content:
                                    yield f"data: {json.dumps({'content': content})}\n\n".encode("utf-8")
                            except Exception:
                                # Non-JSON or unexpected shape: forward raw so the client can inspect it
                                yield f"data: {json.dumps({'raw': data})}\n\n".encode("utf-8")
        except Exception as e:
            # Network/transport failure mid-stream: surface in-band, same as above
            yield f"event: error\ndata: {json.dumps({'detail': str(e)})}\n\n".encode("utf-8")

    return StreamingResponse(stream(), media_type="text/event-stream")