from __future__ import annotations

import asyncio
import hmac
import json
import logging
import os
import time
import uuid
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from gradio_client import Client
from pydantic import BaseModel
| |
|
# Read configuration from a local .env file (no-op if the file is absent).
load_dotenv()


# Optional bearer token; when empty, authentication is disabled (see verify_key).
API_KEY = os.getenv("API_KEY", "")
# Gradio Space to proxy — presumably "user/space-name" or a full URL; confirm format.
HF_SPACE_URL = os.getenv("HF_SPACE_URL", "")
# Model identifier reported back in responses and by /v1/models.
MODEL_ID = os.getenv("MODEL_ID", "")

# Default sampling parameters applied when the client request omits them.
DEFAULT_TEMP = float(os.getenv("DEFAULT_TEMPERATURE", "0.6"))
DEFAULT_TOP_P = float(os.getenv("DEFAULT_TOP_P", "0.95"))
DEFAULT_TOKENS = int(os.getenv("DEFAULT_MAX_TOKENS", "16000"))

# Upstream call handling: per-attempt timeout (seconds), attempt count, and the
# base of the exponential backoff (delay = RETRY_BASE_DELAY ** attempt).
REQUEST_TIMEOUT = int(os.getenv("REQUEST_TIMEOUT", "120"))
MAX_RETRIES = int(os.getenv("MAX_RETRIES", "3"))
RETRY_BASE_DELAY = float(os.getenv("RETRY_BASE_DELAY", "1.5"))

# Hard cap on the (estimated) prompt size fed to the Space.
MAX_INPUT_TOKENS = 16000

# Crude chars-per-token heuristic shared by every token estimate in this file.
AVG_CHARS_PER_TOKEN = 4

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
log = logging.getLogger(__name__)
| |
|
| | |
| | |
| | |
| |
|
# Lazily-created Gradio client shared by every request (module-level singleton).
_client: Client | None = None


async def get_client() -> Client:
    """Return the cached Gradio client, connecting on first use.

    The blocking ``Client(...)`` constructor runs in a worker thread via
    ``asyncio.to_thread`` so the event loop is not stalled.

    NOTE(review): there is no lock around the lazy init, so two concurrent
    first calls could each build a client; in practice lifespan() connects
    once at startup before requests arrive — confirm if startup changes.
    """
    global _client
    if _client is None:
        log.info("Connecting to %s", HF_SPACE_URL)
        _client = await asyncio.to_thread(Client, HF_SPACE_URL)
        log.info("Connected.")
    return _client
| |
|
| | |
| | |
| | |
| |
|
class Message(BaseModel):
    """One chat message in OpenAI chat-completions format."""

    # Expected values: "system", "user", "assistant" (other roles are
    # dropped by _condense_messages).
    role: str
    # Either plain text or a list of content blocks such as
    # {"type": "text", "text": ...}; non-text blocks are ignored downstream.
    content: str | list[dict] = ""
    # Optional participant name, accepted for OpenAI compatibility; unused here.
    name: str | None = None
| |
|
class ChatCompletionRequest(BaseModel):
    """OpenAI-compatible ``/v1/chat/completions`` request body.

    Defaults for model/temperature/top_p/max_tokens are captured from
    environment variables at import time. ``frequency_penalty``,
    ``presence_penalty``, ``stop``, ``seed`` and ``user`` are accepted for
    client compatibility but are not forwarded to the Space (see
    _call_falcon_once, which only sends model/temperature/max_new_tokens/top_p).
    """

    model: str = MODEL_ID
    messages: list[Message]
    temperature: float = DEFAULT_TEMP
    top_p: float = DEFAULT_TOP_P
    max_tokens: int = DEFAULT_TOKENS
    # When true, the full reply is generated first, then re-chunked as SSE.
    stream: bool = False
    frequency_penalty: float = 0
    presence_penalty: float = 0
    stop: str | list[str] | None = None
    seed: int | None = None
    user: str | None = None
| |
|
| | |
| | |
| | |
| |
|
async def verify_key(request: Request) -> None:
    """FastAPI dependency enforcing the optional bearer-token API key.

    Args:
        request: Incoming HTTP request; only the ``Authorization`` header is read.

    Raises:
        HTTPException: 401 when a key is configured and the header is missing,
            malformed, or carries the wrong token.
    """
    if not API_KEY:
        # Auth is disabled when no key is configured.
        return
    auth = request.headers.get("Authorization", "")
    scheme, _, token = auth.partition(" ")
    # hmac.compare_digest is constant-time, so the comparison does not leak
    # how much of the key prefix matched (the previous `!=` check did).
    # Compare as bytes: compare_digest on str rejects non-ASCII input.
    if scheme != "Bearer" or not hmac.compare_digest(token.encode(), API_KEY.encode()):
        raise HTTPException(status_code=401, detail="Invalid or missing API key")
| |
|
| | |
| | |
| | |
| |
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: warm up the Gradio client before serving.

    Connecting at startup means the first user request does not pay the
    connection latency; get_client() caches the client in a module global.
    """
    log.info("Startup: connecting to Gradio client...")
    await get_client()
    yield
    log.info("Shutdown.")
| |
|
| | |
| | |
| | |
| |
|
| | def _content_str(m: Message) -> str: |
| | if isinstance(m.content, str): |
| | return m.content |
| | text_parts = [] |
| | for p in m.content: |
| | if isinstance(p, dict) and p.get("type") == "text": |
| | text_parts.append(p.get("text", "").strip()) |
| | return "".join(text_parts) |
| |
|
def _token_count(text: str) -> int:
    """Estimate the token count of *text* (chars-per-token heuristic, floor 1)."""
    estimate = len(text) // AVG_CHARS_PER_TOKEN
    return estimate if estimate > 0 else 1
| |
|
def _condense_messages(messages: list[Message], max_tokens: int) -> str:
    """Flatten a message list into one newline-joined prompt within a budget.

    System messages are always included first, in full. User/assistant
    messages follow in order; a message that would overflow the budget is
    truncated to its tail (most recent text), and once the budget is spent
    the remaining messages are skipped. Roles outside system/user/assistant
    are dropped entirely. Budgeting uses the same chars-per-token heuristic
    as _token_count, so it is approximate.
    """
    parts = [_content_str(m) for m in messages if m.role == "system"]
    used = sum(_token_count(p) for p in parts)

    for msg in messages:
        if msg.role not in ("user", "assistant"):
            continue
        text = _content_str(msg)
        cost = _token_count(text)
        if used + cost > max_tokens:
            headroom = max_tokens - used
            if headroom <= 0:
                continue
            # Keep the tail of the message — the most recent text wins.
            text = text[-(headroom * AVG_CHARS_PER_TOKEN):]
            cost = _token_count(text)
        parts.append(text)
        used += cost

    return "\n".join(parts)
| |
|
def _build_prompt(messages: list[Message]) -> str:
    """Condense *messages* into a single prompt capped at MAX_INPUT_TOKENS."""
    condensed = _condense_messages(messages, MAX_INPUT_TOKENS)
    log.info("Final prompt token count: ~%d", _token_count(condensed))
    return condensed
| |
|
| | |
| | |
| | |
| |
|
| | def _extract_text(result: Any) -> str: |
| | if isinstance(result, tuple): |
| | data = result |
| | elif hasattr(result, "data"): |
| | data = result.data |
| | else: |
| | data = [result] |
| |
|
| | conversation = None |
| | for item in data: |
| | if isinstance(item, dict) and "value" in item and isinstance(item["value"], list): |
| | conversation = item["value"] |
| | break |
| | elif isinstance(item, list): |
| | conversation = item |
| | break |
| |
|
| | if not conversation: |
| | raise ValueError("Cannot extract conversation from result") |
| |
|
| | last = conversation[-1] |
| |
|
| | if isinstance(last, dict): |
| | content = last.get("content", "") |
| | elif isinstance(last, (list, tuple)) and len(last) >= 2: |
| | content = last[1] or "" |
| | else: |
| | content = str(last) |
| |
|
| | if isinstance(content, list): |
| | parts = [] |
| | for block in content: |
| | if isinstance(block, dict) and block.get("type") == "text": |
| | parts.append(block.get("content", block.get("text", ""))) |
| | return "".join(parts).strip() |
| | return str(content).strip() |
| |
|
| | |
| | |
| | |
| |
|
async def _call_with_retries(prompt: str, req: ChatCompletionRequest) -> str:
    """Call the Space with a per-attempt timeout and exponential backoff.

    Args:
        prompt: Flattened prompt text to send upstream.
        req: Original request; sampling parameters are forwarded unchanged.

    Returns:
        The model's reply text.

    Raises:
        RuntimeError: if MAX_RETRIES is configured below 1 (no attempt made).
        Exception: the last upstream error once all attempts are exhausted
            (asyncio.TimeoutError when an attempt exceeded REQUEST_TIMEOUT).
    """
    last_error: Exception | None = None
    for attempt in range(1, MAX_RETRIES + 1):
        try:
            # Bound each attempt so a hung Space cannot stall the request forever.
            return await asyncio.wait_for(_call_falcon_once(prompt, req), timeout=REQUEST_TIMEOUT)
        except Exception as e:
            last_error = e
            if attempt == MAX_RETRIES:
                break
            # Exponential backoff: RETRY_BASE_DELAY ** attempt grows per round.
            delay = RETRY_BASE_DELAY ** attempt
            log.warning("Attempt %d failed: %s | retrying in %.2fs", attempt, str(e), delay)
            await asyncio.sleep(delay)
    if last_error is None:
        # Bug fix: with MAX_RETRIES < 1 the loop never ran and `raise last_error`
        # raised None, producing "TypeError: exceptions must derive from
        # BaseException" instead of a meaningful error.
        raise RuntimeError("MAX_RETRIES must be >= 1; no call attempt was made")
    raise last_error
| |
|
| | |
| | |
| | |
| |
|
async def _call_falcon_once(prompt: str, req: ChatCompletionRequest) -> str:
    """Make one round-trip to the Gradio Space and return the reply text.

    Both ``predict()`` calls are blocking, so they are moved off the event
    loop with ``asyncio.to_thread``.

    Raises:
        ValueError: if _extract_text cannot find a conversation in the result.
    """
    client = await get_client()
    # Sampling settings forwarded to the Space; note that the client's
    # max_tokens maps to the Gradio-side name "max_new_tokens".
    settings = {
        "model": req.model,
        "temperature": req.temperature,
        "max_new_tokens": req.max_tokens,
        "top_p": req.top_p,
    }

    # Reset the Space's server-side conversation so each API call is stateless.
    # NOTE(review): assumes the Space exposes "/new_chat" and "/add_message"
    # endpoints with these exact names/signatures — confirm against the Space.
    await asyncio.to_thread(client.predict, api_name="/new_chat")

    # Send the flattened prompt as a single user message and collect the result.
    result = await asyncio.to_thread(
        client.predict,
        prompt,
        settings_form_value=settings,
        api_name="/add_message",
    )
    return _extract_text(result)
| |
|
| | |
| | |
| | |
| |
|
| | async def _stream_sse(text: str, req: ChatCompletionRequest) -> AsyncGenerator[str, None]: |
| | cid = f"chatcmpl-{uuid.uuid4().hex}" |
| | created = int(time.time()) |
| | for i in range(0, len(text), 8): |
| | chunk = { |
| | "id": cid, |
| | "object": "chat.completion.chunk", |
| | "created": created, |
| | "model": req.model, |
| | "choices": [{"index": 0, "delta": {"content": text[i:i+8]}, "finish_reason": None}] |
| | } |
| | yield f"data: {json.dumps(chunk)}\n\n" |
| | await asyncio.sleep(0.01) |
| | final_chunk = { |
| | "id": cid, |
| | "object": "chat.completion.chunk", |
| | "created": created, |
| | "model": req.model, |
| | "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}] |
| | } |
| | yield f"data: {json.dumps(final_chunk)}\n\n" |
| | yield "data: [DONE]\n\n" |
| |
|
| | |
| | |
| | |
| |
|
def _make_response(text: str, req: ChatCompletionRequest) -> dict:
    """Build a non-streaming OpenAI-style ``chat.completion`` payload.

    Usage token counts use the same chars-per-token heuristic as the rest of
    this file, not a real tokenizer — treat them as estimates only.
    """
    # Consistency fix: use the shared AVG_CHARS_PER_TOKEN constant instead of a
    # magic `4`, so the estimate tracks _token_count/_condense_messages if the
    # constant ever changes. (Values are unchanged while the constant is 4.)
    prompt_tokens = sum(len(_content_str(m)) for m in req.messages) // AVG_CHARS_PER_TOKEN
    completion_tokens = len(text) // AVG_CHARS_PER_TOKEN
    return {
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": req.model,
        "choices": [{"index": 0, "message": {"role": "assistant", "content": text}, "finish_reason": "stop"}],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    }
| |
|
| | |
| | |
| | |
| |
|
app = FastAPI(title="Foc", version="5.0.0", lifespan=lifespan)
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is the
# widest possible CORS policy — confirm this is intended for production.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"])
| |
|
@app.get("/")
async def root():
    """Service index: name, version, and the available endpoint paths."""
    endpoints = {
        "health": "/health",
        "models": "/v1/models",
        "chat": "/v1/chat/completions",
    }
    return {"service": "FOC API", "version": "5.0.0", "endpoints": endpoints}
| |
|
@app.get("/health")
async def health():
    """Liveness probe exposing the configured model and upstream Space."""
    report = {"status": "ok"}
    report["model"] = MODEL_ID
    report["space"] = HF_SPACE_URL
    return report
| |
|
@app.get("/v1/models")
async def list_models(_: None = Depends(verify_key)):
    """OpenAI-compatible model listing; this proxy serves exactly one model."""
    entry = {"id": MODEL_ID, "object": "model", "created": 1710000000, "owned_by": "tiiuae"}
    return {"object": "list", "data": [entry]}
| |
|
@app.post("/v1/chat/completions")
async def chat_completions(req: ChatCompletionRequest, _: None = Depends(verify_key)):
    """OpenAI-compatible chat completions endpoint.

    Flattens the message history into one prompt, calls the Space with
    retries, and returns either a JSON completion or (when ``stream`` is set)
    the full reply re-chunked as server-sent events.
    """
    prompt = _build_prompt(req.messages)
    log.info("Request | model=%s temp=%.2f tokens=%d stream=%s", req.model, req.temperature, req.max_tokens, req.stream)

    try:
        completion = await _call_with_retries(prompt, req)
    except Exception:
        log.exception("Falcon failed after retries")
        raise HTTPException(status_code=502, detail="Model temporarily unavailable. Please try again.")

    if not req.stream:
        return JSONResponse(content=_make_response(completion, req))

    sse_headers = {"Cache-Control": "no-cache", "X-Accel-Buffering": "no", "Connection": "keep-alive"}
    return StreamingResponse(
        _stream_sse(completion, req),
        media_type="text/event-stream",
        headers=sse_headers,
    )