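"""OpenAI-compatible chat API that proxies a Falcon H1R Gradio Space.

Translates /v1/chat/completions requests into gradio_client calls against
the Space's /new_chat and /add_message endpoints, then reshapes the reply
into the OpenAI chat-completion format (optionally as simulated SSE).
"""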
from __future__ import annotations
import os, json, time, uuid, asyncio, logging
from typing import AsyncGenerator
from contextlib import asynccontextmanager
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
from gradio_client import Client
load_dotenv()
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
API_KEY = os.getenv("API_KEY", "")
HF_SPACE_URL = os.getenv("HF_SPACE_URL", "")
MODEL_ID = os.getenv("MODEL_ID", "")
DEFAULT_TEMP = float(os.getenv("DEFAULT_TEMPERATURE", "0.6"))
DEFAULT_TOP_P = float(os.getenv("DEFAULT_TOP_P", "0.95"))
DEFAULT_TOKENS = int(os.getenv("DEFAULT_MAX_TOKENS", "1024"))
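# Example .env (illustrative values only):
#   API_KEY=change-me          # optional; requests are unauthenticated when empty
#   HF_SPACE_URL=https://<owner>-<space>.hf.space
#   MODEL_ID=<model-id>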
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
log = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Gradio client (singleton)
# ---------------------------------------------------------------------------
_client: Client | None = None
async def get_client() -> Client:
    global _client
    if _client is None:
        log.info("Connecting to %s", HF_SPACE_URL)
        _client = await asyncio.to_thread(Client, HF_SPACE_URL)
        log.info("Connected.")
    return _client
# ---------------------------------------------------------------------------
# Pydantic schemas
# ---------------------------------------------------------------------------
class Message(BaseModel):
    role: str
    content: str | list[dict] = ""
    name: str | None = None

class ChatCompletionRequest(BaseModel):
    model: str = MODEL_ID
    messages: list[Message]
    temperature: float = DEFAULT_TEMP
    top_p: float = DEFAULT_TOP_P
    max_tokens: int = DEFAULT_TOKENS
    stream: bool = False
    frequency_penalty: float = 0
    presence_penalty: float = 0
    stop: str | list[str] | None = None
    seed: int | None = None
    user: str | None = None
# ---------------------------------------------------------------------------
# Auth
# ---------------------------------------------------------------------------
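# Expects "Authorization: Bearer <API_KEY>"; auth is disabled when API_KEY is unset.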
async def verify_key(request: Request) -> None:
    if not API_KEY:
        return
    auth = request.headers.get("Authorization", "")
    if not auth.startswith("Bearer ") or auth[7:] != API_KEY:
        raise HTTPException(status_code=401, detail="Invalid or missing API key")
# ---------------------------------------------------------------------------
# Lifespan context manager (modern FastAPI pattern)
# ---------------------------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup
    log.info("Starting up - connecting to Gradio client...")
    await get_client()
    log.info("Startup complete.")
    yield
    # Shutdown (if needed)
    log.info("Shutting down.")
# ---------------------------------------------------------------------------
# App
# ---------------------------------------------------------------------------
app = FastAPI(
    title="Falcon H1R API",
    version="3.1.0",
    lifespan=lifespan,
)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ---------------------------------------------------------------------------
# Business logic - EXACTLY like the HTML chatbot
# ---------------------------------------------------------------------------
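# OpenAI message content may be a plain string or a list of typed parts,
# e.g. content=[{"type": "text", "text": "Hi"}] -> "Hi" (non-text parts are ignored).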
def _content_str(m: Message) -> str:
    if isinstance(m.content, str):
        return m.content
    return "".join(p.get("text", "") for p in m.content if p.get("type") == "text")

def _build_prompt(messages: list[Message]) -> str:
    """Flatten messages into a single prompt string."""
    system, parts = [], []
    for m in messages:
        c = _content_str(m)
        if m.role == "system": system.append(c)
        elif m.role == "user": parts.append(c)
        elif m.role == "assistant": parts.append(f"[ASSISTANT]\n{c}")
    prefix = "[SYSTEM]\n" + "\n".join(system) + "\n[/SYSTEM]\n" if system else ""
    return prefix + "\n".join(parts)
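# Example: [system "Be brief.", user "Hi"] flattens to:
#   [SYSTEM]
#   Be brief.
#   [/SYSTEM]
#   Hi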
def _extract_text(result) -> str:
    """
    HTML chatbot does:
        const last = res.data[5].value.at(-1);
        const text = Array.isArray(last.content)
            ? last.content.filter(p => p.type === 'text').map(p => p.content.trim()).join('')
            : last.content;
    """
    try:
        # res.data is a list, index 5 contains the chatbot component
        chatbot_data = result.data[5]
        # chatbot_data is a dict with 'value' key
        conversation = chatbot_data["value"]
        # last message
        last = conversation[-1]
        content = last["content"]
        if isinstance(content, list):
            # Filter type='text' blocks
            return "".join(
                p["content"].strip()
                for p in content
                if p.get("type") == "text"
            )
        return str(content)
    except Exception as e:
        log.error("_extract_text failed: %s | raw data: %s", e, result.data)
        raise ValueError(f"Failed to extract text: {e}") from e
async def _call_falcon(prompt: str, req: ChatCompletionRequest) -> str:
    """
    Exact replica of the HTML submit() function:
    1. client.predict('/add_message', { input_value: msg, settings_form_value: PARAMS })
    2. Extract res.data[5].value.at(-1).content
    """
    client = await get_client()
    settings = {
        "model": req.model,
        "temperature": req.temperature,
        "max_new_tokens": req.max_tokens,
        "top_p": req.top_p,
    }
    # Step 1: Reset chat (like boot() does once, but we do it per request for isolation)
    await asyncio.to_thread(
        client.predict,
        api_name="/new_chat",
    )
    # Step 2: Send the message - EXACTLY like the HTML client
    result = await asyncio.to_thread(
        client.predict,
        input_value=prompt,
        settings_form_value=settings,
        api_name="/add_message",
    )
    return _extract_text(result)
def _make_response(text: str, req: ChatCompletionRequest) -> dict:
    # Rough token estimate: ~4 characters per token
    pt = sum(len(_content_str(m)) for m in req.messages) // 4
    ct = len(text) // 4
    return {
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": req.model,
        "system_fingerprint": f"fp_{uuid.uuid4().hex[:8]}",
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": text,
                "tool_calls": None,
                "function_call": None,
            },
            "finish_reason": "stop",
            "logprobs": None,
        }],
        "usage": {
            "prompt_tokens": pt,
            "completion_tokens": ct,
            "total_tokens": pt + ct,
        },
    }
async def _stream_sse(text: str, req: ChatCompletionRequest) -> AsyncGenerator[str, None]:
    """Simulate streaming by chunking the full response."""
    cid = f"chatcmpl-{uuid.uuid4().hex}"
    created = int(time.time())
    # Stream in small chunks
    for i in range(0, len(text), 6):
        chunk = {
            "id": cid,
            "object": "chat.completion.chunk",
            "created": created,
            "model": req.model,
            "choices": [{
                "index": 0,
                "delta": {"role": "assistant", "content": text[i:i+6]},
                "finish_reason": None,
            }],
        }
        yield f"data: {json.dumps(chunk)}\n\n"
        await asyncio.sleep(0.01)
    # Final chunk
    pt = sum(len(_content_str(m)) for m in req.messages) // 4
    ct = len(text) // 4
    final = {
        "id": cid,
        "object": "chat.completion.chunk",
        "created": created,
        "model": req.model,
        "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
        "usage": {"prompt_tokens": pt, "completion_tokens": ct, "total_tokens": pt + ct},
    }
    yield f"data: {json.dumps(final)}\n\n"
    yield "data: [DONE]\n\n"
# ---------------------------------------------------------------------------
# Routes
# ---------------------------------------------------------------------------
@app.get("/")
async def root():
return {
"service": "Falcon H1R OpenAI-compatible API",
"version": "3.1.0",
"endpoints": {
"health": "/health",
"models": "/v1/models",
"chat": "/v1/chat/completions",
},
}
@app.get("/health")
async def health():
return {"status": "ok", "model": MODEL_ID, "space": HF_SPACE_URL}
@app.get("/v1/models")
async def list_models(_: None = Depends(verify_key)):
return {"object": "list", "data": [{
"id": MODEL_ID,
"object": "model",
"created": 1710000000,
"owned_by": "tiiuae",
}]}
@app.post("/v1/chat/completions")
async def chat_completions(req: ChatCompletionRequest, _: None = Depends(verify_key)):
    prompt = _build_prompt(req.messages)
    log.info("Request | model=%s temp=%.2f tokens=%d stream=%s",
             req.model, req.temperature, req.max_tokens, req.stream)
    try:
        text = await _call_falcon(prompt, req)
    except Exception as exc:
        log.exception("Falcon call failed")
        raise HTTPException(status_code=502, detail=f"Upstream error: {exc}") from exc
    if req.stream:
        return StreamingResponse(
            _stream_sse(text, req),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
        )
    return JSONResponse(content=_make_response(text, req))
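# Example client call (hypothetical host/port; requires the `openai` package):
#   from openai import OpenAI
#   client = OpenAI(base_url="http://localhost:8000/v1", api_key="<API_KEY>")
#   resp = client.chat.completions.create(
#       model="<model-id>", messages=[{"role": "user", "content": "Hello"}]
#   )
#   print(resp.choices[0].message.content)

if __name__ == "__main__":
    # Local dev entry point (assumes uvicorn is installed; a Space may launch differently)
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))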