from __future__ import annotations

import asyncio
import json
import logging
import os
import time
import uuid
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator

from dotenv import load_dotenv
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from gradio_client import Client
from pydantic import BaseModel

load_dotenv()

# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
API_KEY = os.getenv("API_KEY", "")
HF_SPACE_URL = os.getenv("HF_SPACE_URL", "")
MODEL_ID = os.getenv("MODEL_ID", "")
DEFAULT_TEMP = float(os.getenv("DEFAULT_TEMPERATURE", "0.6"))
DEFAULT_TOP_P = float(os.getenv("DEFAULT_TOP_P", "0.95"))
DEFAULT_TOKENS = int(os.getenv("DEFAULT_MAX_TOKENS", "16000"))
REQUEST_TIMEOUT = int(os.getenv("REQUEST_TIMEOUT", "120"))
MAX_RETRIES = int(os.getenv("MAX_RETRIES", "3"))
RETRY_BASE_DELAY = float(os.getenv("RETRY_BASE_DELAY", "1.5"))

MAX_INPUT_TOKENS = 16000  # fixed value
# rough approximation: 1 token ~ 4 characters
AVG_CHARS_PER_TOKEN = 4
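# A sample .env for local development. Every value below is a placeholder
# chosen for illustration, not a shipped default; HF_SPACE_URL in particular
# must point at your own Gradio Space:
#
#   API_KEY=change-me
#   HF_SPACE_URL=https://username-my-space.hf.space
#   MODEL_ID=tiiuae/falcon-180B-chat
#   DEFAULT_TEMPERATURE=0.6
#   DEFAULT_TOP_P=0.95
#   DEFAULT_MAX_TOKENS=16000
#   REQUEST_TIMEOUT=120
#   MAX_RETRIES=3
#   RETRY_BASE_DELAY=1.5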
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
log = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Gradio client (singleton)
# ---------------------------------------------------------------------------
_client: Client | None = None


async def get_client() -> Client:
    """Lazily connect to the Gradio Space; the blocking constructor runs in a thread."""
    global _client
    if _client is None:
        log.info("Connecting to %s", HF_SPACE_URL)
        _client = await asyncio.to_thread(Client, HF_SPACE_URL)
        log.info("Connected.")
    return _client


# ---------------------------------------------------------------------------
# Schemas
# ---------------------------------------------------------------------------
class Message(BaseModel):
    role: str
    content: str | list[dict] = ""
    name: str | None = None


class ChatCompletionRequest(BaseModel):
    model: str = MODEL_ID
    messages: list[Message]
    temperature: float = DEFAULT_TEMP
    top_p: float = DEFAULT_TOP_P
    max_tokens: int = DEFAULT_TOKENS
    stream: bool = False
    frequency_penalty: float = 0
    presence_penalty: float = 0
    stop: str | list[str] | None = None
    seed: int | None = None
    user: str | None = None


# ---------------------------------------------------------------------------
# Auth
# ---------------------------------------------------------------------------
async def verify_key(request: Request) -> None:
    if not API_KEY:  # auth is disabled when no key is configured
        return
    auth = request.headers.get("Authorization", "")
    if not auth.startswith("Bearer ") or auth[7:] != API_KEY:
        raise HTTPException(status_code=401, detail="Invalid or missing API key")


# ---------------------------------------------------------------------------
# Lifespan
# ---------------------------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    log.info("Startup: connecting to Gradio client...")
    await get_client()
    yield
    log.info("Shutdown.")


# ---------------------------------------------------------------------------
# Utilities
# ---------------------------------------------------------------------------
def _content_str(m: Message) -> str:
    """Flatten OpenAI-style content (plain string or list of content parts) to text."""
    if isinstance(m.content, str):
        return m.content
    text_parts = []
    for p in m.content:
        if isinstance(p, dict) and p.get("type") == "text":
            text_parts.append(p.get("text", "").strip())
    return "".join(text_parts)


def _token_count(text: str) -> int:
    return max(1, len(text) // AVG_CHARS_PER_TOKEN)


def _condense_messages(messages: list[Message], max_tokens: int) -> str:
    """Keep all system messages, then fit user/assistant turns into the budget,
    truncating an over-long message from the left so its most recent tail survives."""
    system_msgs = [m for m in messages if m.role == "system"]
    user_assistant = [m for m in messages if m.role in ("user", "assistant")]

    condensed_parts = [_content_str(m) for m in system_msgs]
    tokens_so_far = sum(_token_count(part) for part in condensed_parts)

    for m in user_assistant:
        text = _content_str(m)
        tcount = _token_count(text)
        if tokens_so_far + tcount > max_tokens:
            remaining_tokens = max_tokens - tokens_so_far
            if remaining_tokens <= 0:
                continue
            approx_chars = remaining_tokens * AVG_CHARS_PER_TOKEN
            text = text[-approx_chars:]
            tcount = _token_count(text)
        condensed_parts.append(text)
        tokens_so_far += tcount
    return "\n".join(condensed_parts)


def _build_prompt(messages: list[Message]) -> str:
    prompt = _condense_messages(messages, MAX_INPUT_TOKENS)
    log.info("Final prompt token count: ~%d", _token_count(prompt))
    return prompt


# ---------------------------------------------------------------------------
# Extraction
# ---------------------------------------------------------------------------
def _extract_text(result: Any) -> str:
    """Pull the assistant's last message out of the Gradio predict() result,
    whatever shape (tuple, object with .data, or bare value) it arrives in."""
    if isinstance(result, tuple):
        data = result
    elif hasattr(result, "data"):
        data = result.data
    else:
        data = [result]

    conversation = None
    for item in data:
        if isinstance(item, dict) and "value" in item and isinstance(item["value"], list):
            conversation = item["value"]
            break
        elif isinstance(item, list):
            conversation = item
            break
    if not conversation:
        raise ValueError("Cannot extract conversation from result")

    last = conversation[-1]
    if isinstance(last, dict):
        content = last.get("content", "")
    elif isinstance(last, (list, tuple)) and len(last) >= 2:
        content = last[1] or ""
    else:
        content = str(last)

    if isinstance(content, list):
        parts = []
        for block in content:
            if isinstance(block, dict) and block.get("type") == "text":
                parts.append(block.get("content", block.get("text", "")))
        return "".join(parts).strip()
    return str(content).strip()


# ---------------------------------------------------------------------------
# Retry wrapper
# ---------------------------------------------------------------------------
async def _call_with_retries(prompt: str, req: ChatCompletionRequest) -> str:
    last_error: Exception | None = None
    for attempt in range(1, MAX_RETRIES + 1):
        try:
            return await asyncio.wait_for(_call_falcon_once(prompt, req), timeout=REQUEST_TIMEOUT)
        except Exception as e:
            last_error = e
            if attempt == MAX_RETRIES:
                break
            delay = RETRY_BASE_DELAY ** attempt  # exponential backoff
            log.warning("Attempt %d failed: %s | retrying in %.2fs", attempt, str(e), delay)
            await asyncio.sleep(delay)
    assert last_error is not None
    raise last_error
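# With the defaults above (MAX_RETRIES=3, RETRY_BASE_DELAY=1.5), the retry
# schedule works out to: attempt 1 fails, sleep 1.5**1 = 1.5 s; attempt 2
# fails, sleep 1.5**2 = 2.25 s; attempt 3 fails, the last error is re-raised.
# Note that REQUEST_TIMEOUT (120 s) applies per attempt, not to the loop as
# a whole.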
# ---------------------------------------------------------------------------
# Falcon call with explicit api_name
# ---------------------------------------------------------------------------
async def _call_falcon_once(prompt: str, req: ChatCompletionRequest) -> str:
    client = await get_client()
    settings = {
        "model": req.model,
        "temperature": req.temperature,
        "max_new_tokens": req.max_tokens,
        "top_p": req.top_p,
    }
    # Reset the chat session
    await asyncio.to_thread(client.predict, api_name="/new_chat")
    # Add the message with an explicit api_name and settings
    result = await asyncio.to_thread(
        client.predict,
        prompt,  # first positional argument
        settings_form_value=settings,
        api_name="/add_message",  # <-- must match the endpoint name from the Space's "View API" page
    )
    return _extract_text(result)


# ---------------------------------------------------------------------------
# Streaming
# ---------------------------------------------------------------------------
async def _stream_sse(text: str, req: ChatCompletionRequest) -> AsyncGenerator[str, None]:
    """Pseudo-streaming: the full completion is already generated; re-chunk it
    into OpenAI-style SSE events, 8 characters at a time."""
    cid = f"chatcmpl-{uuid.uuid4().hex}"
    created = int(time.time())
    for i in range(0, len(text), 8):
        chunk = {
            "id": cid,
            "object": "chat.completion.chunk",
            "created": created,
            "model": req.model,
            "choices": [{"index": 0, "delta": {"content": text[i:i + 8]}, "finish_reason": None}],
        }
        yield f"data: {json.dumps(chunk)}\n\n"
        await asyncio.sleep(0.01)
    final_chunk = {
        "id": cid,
        "object": "chat.completion.chunk",
        "created": created,
        "model": req.model,
        "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
    }
    yield f"data: {json.dumps(final_chunk)}\n\n"
    yield "data: [DONE]\n\n"


# ---------------------------------------------------------------------------
# OpenAI-compatible response
# ---------------------------------------------------------------------------
def _make_response(text: str, req: ChatCompletionRequest) -> dict:
    # usage is an estimate based on the chars-per-token approximation above
    pt = sum(len(_content_str(m)) for m in req.messages) // AVG_CHARS_PER_TOKEN
    ct = len(text) // AVG_CHARS_PER_TOKEN
    return {
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": req.model,
        "choices": [{"index": 0, "message": {"role": "assistant", "content": text}, "finish_reason": "stop"}],
        "usage": {"prompt_tokens": pt, "completion_tokens": ct, "total_tokens": pt + ct},
    }


# ---------------------------------------------------------------------------
# Routes
# ---------------------------------------------------------------------------
app = FastAPI(title="FOC API", version="5.0.0", lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
async def root():
    return {
        "service": "FOC API",
        "version": "5.0.0",
        "endpoints": {
            "health": "/health",
            "models": "/v1/models",
            "chat": "/v1/chat/completions",
        },
    }


@app.get("/health")
async def health():
    return {"status": "ok", "model": MODEL_ID, "space": HF_SPACE_URL}


@app.get("/v1/models")
async def list_models(_: None = Depends(verify_key)):
    return {
        "object": "list",
        "data": [{"id": MODEL_ID, "object": "model", "created": 1710000000, "owned_by": "tiiuae"}],
    }


@app.post("/v1/chat/completions")
async def chat_completions(req: ChatCompletionRequest, _: None = Depends(verify_key)):
    prompt = _build_prompt(req.messages)
    log.info(
        "Request | model=%s temp=%.2f tokens=%d stream=%s",
        req.model, req.temperature, req.max_tokens, req.stream,
    )
    try:
        text = await _call_with_retries(prompt, req)
    except Exception:
        log.exception("Falcon failed after retries")
        raise HTTPException(status_code=502, detail="Model temporarily unavailable. Please try again.")
    if req.stream:
        return StreamingResponse(
            _stream_sse(text, req),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no", "Connection": "keep-alive"},
        )
    return JSONResponse(content=_make_response(text, req))
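# ---------------------------------------------------------------------------
# Local entry point (a minimal sketch: it assumes uvicorn is installed, and
# the host/port are arbitrary choices, not part of the original service)
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request once the server is running (the Bearer token must match
# API_KEY; omit the header entirely if API_KEY is unset):
#
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Authorization: Bearer change-me" \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hello"}]}'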