| |
import json
import os
from typing import AsyncGenerator, List, Optional

import httpx
from fastapi import BackgroundTasks, FastAPI, Header, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
|
|
# Server-side fallback OpenRouter key; None when the variable is not set.
SERVER_API_KEY = os.environ.get("OPENROUTER_API_KEY")

# OpenRouter chat-completions endpoint that requests are proxied to.
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1/chat/completions"

# Model slug used when a request does not name one.
DEFAULT_MODEL = "openai/gpt-4o-mini"
|
|
# Application instance; HTTP routes (such as POST /chat) are registered on it.
app = FastAPI(
    title="OpenRouter Chat Backend",
)
|
|
|
|
class Message(BaseModel):
    """One chat turn, serialized and forwarded to OpenRouter as-is."""

    role: str  # speaker role string; not validated here (e.g. "user")
    content: str  # the turn's text
|
|
|
|
class ChatRequest(BaseModel):
    """JSON body accepted by ``POST /chat``."""

    # Conversation turns to send upstream.
    messages: List[Message]
    # OpenRouter model slug; a falsy value falls back to the server default.
    model: Optional[str] = None
    # Sampling temperature passed to the model.
    temperature: Optional[float] = 0.7
    # Optional cap on generated tokens; None requests no explicit cap.
    max_tokens: Optional[int] = None
|
|
|
|
# Fail fast when OpenRouter cannot be reached, but never time out while
# reading: gaps between streamed tokens can be arbitrarily long.
_UPSTREAM_TIMEOUT = httpx.Timeout(None, connect=30.0)


def _build_upstream_headers(api_key: str, http_referer: Optional[str]) -> dict:
    """Return the headers for the OpenRouter chat-completions request.

    ``HTTP-Referer`` and ``X-Title`` are OpenRouter's optional app-attribution
    headers; the referer defaults to the local Streamlit URL.
    """
    return {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
        "HTTP-Referer": http_referer or "http://localhost:8501",
        "X-Title": "OpenRouter Streamlit Chat",
    }


def _build_upstream_payload(req: "ChatRequest", model: str) -> dict:
    """Return the JSON body for a streaming chat-completions request.

    ``temperature`` and ``max_tokens`` are omitted when ``None`` instead of
    being sent as JSON ``null``, so the upstream default applies.
    """
    payload = {
        "model": model,
        "messages": [m.dict() for m in req.messages],
        "stream": True,
    }
    if req.temperature is not None:
        payload["temperature"] = req.temperature
    if req.max_tokens is not None:
        payload["max_tokens"] = req.max_tokens
    return payload


def _sse_data(obj: dict) -> bytes:
    """Encode *obj* as a single SSE ``data:`` frame."""
    return f"data: {json.dumps(obj)}\n\n".encode("utf-8")


def _delta_content(chunk: object) -> Optional[str]:
    """Return ``choices[0].delta.content`` of a parsed stream chunk, or None.

    Chunks without that shape (e.g. an empty ``choices`` list) give None
    instead of raising.
    """
    if not isinstance(chunk, dict):
        return None
    choices = chunk.get("choices")
    if not isinstance(choices, list) or not choices:
        return None
    first = choices[0]
    delta = first.get("delta") if isinstance(first, dict) else None
    content = delta.get("content") if isinstance(delta, dict) else None
    return content if isinstance(content, str) else None


async def _close_upstream(upstream: httpx.Response, client: httpx.AsyncClient) -> None:
    """Close the upstream response, then the client that owns it."""
    try:
        await upstream.aclose()
    finally:
        await client.aclose()


async def _relay_sse(upstream: httpx.Response) -> AsyncGenerator[bytes, None]:
    """Translate OpenRouter's SSE stream into this backend's SSE frames.

    Yields ``data: {"content": ...}`` for each text delta,
    ``data: {"error": ...}`` for an error object sent mid-stream,
    ``data: {"raw": ...}`` for payloads that are not valid JSON, and
    ``event: done`` once upstream sends ``[DONE]``.

    The HTTP status line has already been sent when this runs, so transport
    failures are reported in-band as an ``event: error`` frame. Closing the
    upstream connection is the caller's job.
    """
    try:
        async for line in upstream.aiter_lines():
            # Skip blank frame separators and SSE comment / keep-alive lines.
            if not line.startswith("data: "):
                continue
            data = line[len("data: "):]
            if data.strip() == "[DONE]":
                yield b"event: done\ndata: done\n\n"
                return
            try:
                chunk = json.loads(data)
            except json.JSONDecodeError:
                yield _sse_data({"raw": data})
                continue
            if isinstance(chunk, dict) and "error" in chunk:
                # Surface provider errors instead of silently dropping them.
                yield _sse_data({"error": chunk["error"]})
                continue
            content = _delta_content(chunk)
            if content:
                yield _sse_data({"content": content})
    except httpx.HTTPError as exc:
        yield b"event: error\n" + _sse_data({"error": str(exc)})


@app.post("/chat")
async def chat(
    req: ChatRequest,
    x_openrouter_api_key: Optional[str] = Header(default=None, alias="X-OpenRouter-Api-Key"),
    http_referer: Optional[str] = Header(default=None, alias="HTTP-Referer"),
):
    """Proxy a chat conversation to OpenRouter and stream the reply as SSE.

    The API key comes from the ``X-OpenRouter-Api-Key`` request header,
    falling back to ``SERVER_API_KEY``.

    Why the upstream request is opened before returning: once a
    ``StreamingResponse`` starts, the 200 status line is already on the wire.
    An ``HTTPException`` raised inside the body generator can then no longer
    reach the client as an error response.

    Raises:
        HTTPException: 401 if no API key is available; 502 if OpenRouter
            cannot be reached; OpenRouter's own status code, with its
            response body as ``detail``, if it rejects the request.
    """
    model = req.model or DEFAULT_MODEL

    api_key = x_openrouter_api_key or SERVER_API_KEY
    if not api_key:
        raise HTTPException(status_code=401, detail="Missing OpenRouter API key")

    headers = _build_upstream_headers(api_key, http_referer)
    payload = _build_upstream_payload(req, model)

    client = httpx.AsyncClient(timeout=_UPSTREAM_TIMEOUT)
    try:
        request = client.build_request("POST", OPENROUTER_BASE_URL, headers=headers, json=payload)
        upstream = await client.send(request, stream=True)
    except httpx.HTTPError as exc:
        await client.aclose()
        raise HTTPException(status_code=502, detail=f"Could not reach OpenRouter: {exc}") from exc

    if upstream.status_code >= 400:
        try:
            detail = (await upstream.aread()).decode("utf-8", errors="replace")
        except httpx.HTTPError as exc:
            detail = f"OpenRouter returned HTTP {upstream.status_code}: {exc}"
        finally:
            await _close_upstream(upstream, client)
        raise HTTPException(status_code=upstream.status_code, detail=detail)

    # The response's background task runs after streaming ends, including
    # when the client disconnects early, so the connection is always released.
    cleanup = BackgroundTasks()
    cleanup.add_task(_close_upstream, upstream, client)
    return StreamingResponse(
        _relay_sse(upstream),
        media_type="text/event-stream",
        background=cleanup,
    )