from __future__ import annotations

import os, json, time, uuid, asyncio, logging
from typing import Any, AsyncGenerator
from contextlib import asynccontextmanager

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
from gradio_client import Client

load_dotenv()

# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
API_KEY = os.getenv("API_KEY", "")
HF_SPACE_URL = os.getenv("HF_SPACE_URL", "")
MODEL_ID = os.getenv("MODEL_ID", "")
DEFAULT_TEMP = float(os.getenv("DEFAULT_TEMPERATURE", "0.6"))
DEFAULT_TOP_P = float(os.getenv("DEFAULT_TOP_P", "0.95"))
DEFAULT_TOKENS = int(os.getenv("DEFAULT_MAX_TOKENS", "1024"))

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
log = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Gradio client (singleton)
# ---------------------------------------------------------------------------
_client: Client | None = None


async def get_client() -> Client:
    """Lazily connect to the Space; the lifespan hook warms this at startup."""
    global _client
    if _client is None:
        log.info("Connecting to %s", HF_SPACE_URL)
        _client = await asyncio.to_thread(Client, HF_SPACE_URL)
        log.info("Connected.")
    return _client


# ---------------------------------------------------------------------------
# Pydantic schemas
# ---------------------------------------------------------------------------
class Message(BaseModel):
    role: str
    content: str | list[dict] = ""
    name: str | None = None


class ChatCompletionRequest(BaseModel):
    model: str = MODEL_ID
    messages: list[Message]
    temperature: float = DEFAULT_TEMP
    top_p: float = DEFAULT_TOP_P
    max_tokens: int = DEFAULT_TOKENS
    stream: bool = False
    # Accepted for OpenAI compatibility; not forwarded to the Space.
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    stop: str | list[str] | None = None
    seed: int | None = None
    user: str | None = None


# ---------------------------------------------------------------------------
# Auth
# ---------------------------------------------------------------------------
async def verify_key(request: Request) -> None:
    if not API_KEY:  # auth disabled when no key is configured
        return
    auth = request.headers.get("Authorization", "")
    if not auth.startswith("Bearer ") or auth[7:] != API_KEY:
        raise HTTPException(status_code=401, detail="Invalid or missing API key")


# ---------------------------------------------------------------------------
# Lifespan context manager (modern FastAPI pattern)
# ---------------------------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup
    log.info("Starting up - connecting to Gradio client...")
    await get_client()
    log.info("Startup complete.")
    yield
    # Shutdown (if needed)
    log.info("Shutting down.")


# ---------------------------------------------------------------------------
# App
# ---------------------------------------------------------------------------
app = FastAPI(
    title="Falcon H1R API",
    version="3.1.0",
    lifespan=lifespan,
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# ---------------------------------------------------------------------------
# Business logic - EXACTLY like the HTML chatbot
# ---------------------------------------------------------------------------
def _content_str(m: Message) -> str:
    """Flatten a message's content to plain text."""
    if isinstance(m.content, str):
        return m.content
    return "".join(p.get("text", "") for p in m.content if p.get("type") == "text")
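
# Illustrative sketch (comments only, not executed): how _content_str handles
# OpenAI-style multimodal content. Given a hypothetical message such as
#
#     Message(role="user", content=[
#         {"type": "text", "text": "Describe "},
#         {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
#         {"type": "text", "text": "this image."},
#     ])
#
# it returns "Describe this image." - non-text parts are silently dropped,
# since the upstream Space accepts plain text only.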

def _build_prompt(messages: list[Message]) -> str:
    """Flatten the message list into a single prompt string."""
    system, parts = [], []
    for m in messages:
        c = _content_str(m)
        if m.role == "system":
            system.append(c)
        elif m.role == "user":
            parts.append(c)
        elif m.role == "assistant":
            parts.append(f"[ASSISTANT]\n{c}")
    prefix = "[SYSTEM]\n" + "\n".join(system) + "\n[/SYSTEM]\n" if system else ""
    return prefix + "\n".join(parts)


def _extract_text(result: Any) -> str:
    """
    Mirrors what the HTML chatbot does:

        const last = res.data[5].value.at(-1);
        const text = Array.isArray(last.content)
            ? last.content.filter(p => p.type === 'text').map(p => p.content.trim()).join('')
            : last.content;
    """
    try:
        # result.data is a list; index 5 holds the chatbot component
        chatbot_data = result.data[5]
        # chatbot_data is a dict whose 'value' key holds the conversation
        conversation = chatbot_data["value"]
        # Last message in the conversation
        last = conversation[-1]
        content = last["content"]
        if isinstance(content, list):
            # Keep only type='text' blocks
            return "".join(
                p.get("content", "").strip() for p in content if p.get("type") == "text"
            )
        return str(content)
    except Exception as e:
        log.error("_extract_text failed: %s | raw data: %s", e, result.data)
        raise ValueError(f"Failed to extract text: {e}") from e


# The Space holds a single shared chat session, so the reset+send pair must
# not interleave across concurrent requests.
_falcon_lock = asyncio.Lock()


async def _call_falcon(prompt: str, req: ChatCompletionRequest) -> str:
    """
    Replica of the HTML chatbot's submit() flow:

    1. client.predict('/add_message', { input_value: msg, settings_form_value: PARAMS })
    2. Extract res.data[5].value.at(-1).content
    """
    client = await get_client()
    settings = {
        "model": req.model,
        "temperature": req.temperature,
        "max_new_tokens": req.max_tokens,
        "top_p": req.top_p,
    }
    async with _falcon_lock:
        # Step 1: reset the chat (boot() does this once; we do it per request
        # so each completion starts from a clean session)
        await asyncio.to_thread(
            client.predict,
            api_name="/new_chat",
        )
        # Step 2: send the message, exactly like the HTML client
        result = await asyncio.to_thread(
            client.predict,
            input_value=prompt,
            settings_form_value=settings,
            api_name="/add_message",
        )
    return _extract_text(result)


def _make_response(text: str, req: ChatCompletionRequest) -> dict:
    # Rough usage estimate: ~4 characters per token
    pt = sum(len(_content_str(m)) for m in req.messages) // 4
    ct = len(text) // 4
    return {
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": req.model,
        "system_fingerprint": f"fp_{uuid.uuid4().hex[:8]}",
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": text,
                "tool_calls": None,
                "function_call": None,
            },
            "finish_reason": "stop",
            "logprobs": None,
        }],
        "usage": {
            "prompt_tokens": pt,
            "completion_tokens": ct,
            "total_tokens": pt + ct,
        },
    }


async def _stream_sse(text: str, req: ChatCompletionRequest) -> AsyncGenerator[str, None]:
    """Simulate streaming by chunking the already-complete response."""
    cid = f"chatcmpl-{uuid.uuid4().hex}"
    created = int(time.time())
    # Stream in small chunks
    for i in range(0, len(text), 6):
        delta: dict = {"content": text[i:i+6]}
        if i == 0:
            # OpenAI convention: the role is sent only in the first delta
            delta["role"] = "assistant"
        chunk = {
            "id": cid,
            "object": "chat.completion.chunk",
            "created": created,
            "model": req.model,
            "choices": [{
                "index": 0,
                "delta": delta,
                "finish_reason": None,
            }],
        }
        yield f"data: {json.dumps(chunk)}\n\n"
        await asyncio.sleep(0.01)
    # Final chunk carries the usage estimate (~4 chars per token)
    pt = sum(len(_content_str(m)) for m in req.messages) // 4
    ct = len(text) // 4
    final = {
        "id": cid,
        "object": "chat.completion.chunk",
        "created": created,
        "model": req.model,
        "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
        "usage": {"prompt_tokens": pt, "completion_tokens": ct, "total_tokens": pt + ct},
    }
    yield f"data: {json.dumps(final)}\n\n"
    yield "data: [DONE]\n\n"
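
# Consumption sketch (comments only; assumes the server runs at
# http://localhost:8000 with API_KEY="secret" and the `openai` Python SDK
# installed - all illustrative, none mandated by this module):
#
#     from openai import OpenAI
#     client = OpenAI(base_url="http://localhost:8000/v1", api_key="secret")
#     stream = client.chat.completions.create(
#         model="my-model-id",  # hypothetical; use your MODEL_ID
#         messages=[{"role": "user", "content": "Hello"}],
#         stream=True,
#     )
#     for chunk in stream:
#         print(chunk.choices[0].delta.content or "", end="")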

# ---------------------------------------------------------------------------
# Routes
# ---------------------------------------------------------------------------
@app.get("/")
async def root():
    return {
        "service": "Falcon H1R OpenAI-compatible API",
        "version": "3.1.0",
        "endpoints": {
            "health": "/health",
            "models": "/v1/models",
            "chat": "/v1/chat/completions",
        },
    }


@app.get("/health")
async def health():
    return {"status": "ok", "model": MODEL_ID, "space": HF_SPACE_URL}


@app.get("/v1/models")
async def list_models(_: None = Depends(verify_key)):
    return {"object": "list", "data": [{
        "id": MODEL_ID,
        "object": "model",
        "created": 1710000000,
        "owned_by": "tiiuae",
    }]}


@app.post("/v1/chat/completions")
async def chat_completions(req: ChatCompletionRequest, _: None = Depends(verify_key)):
    prompt = _build_prompt(req.messages)
    log.info("Request | model=%s temp=%.2f tokens=%d stream=%s",
             req.model, req.temperature, req.max_tokens, req.stream)
    try:
        text = await _call_falcon(prompt, req)
    except Exception as exc:
        log.exception("Falcon call failed")
        raise HTTPException(status_code=502, detail=f"Upstream error: {exc}") from exc
    if req.stream:
        return StreamingResponse(
            _stream_sse(text, req),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
        )
    return JSONResponse(content=_make_response(text, req))
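
# Minimal local entrypoint - a sketch, assuming uvicorn is installed and that
# a PORT env var (illustrative; not read elsewhere in this module) may be set:
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "8000")))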