from __future__ import annotations
import os
import json
import time
import uuid
import asyncio
import logging
from typing import Any, AsyncGenerator
from contextlib import asynccontextmanager
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
from gradio_client import Client
load_dotenv()
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
API_KEY = os.getenv("API_KEY", "")
HF_SPACE_URL = os.getenv("HF_SPACE_URL", "")
MODEL_ID = os.getenv("MODEL_ID", "")
DEFAULT_TEMP = float(os.getenv("DEFAULT_TEMPERATURE", "0.6"))
DEFAULT_TOP_P = float(os.getenv("DEFAULT_TOP_P", "0.95"))
DEFAULT_TOKENS = int(os.getenv("DEFAULT_MAX_TOKENS", "16000"))
REQUEST_TIMEOUT = int(os.getenv("REQUEST_TIMEOUT", "120"))
MAX_RETRIES = int(os.getenv("MAX_RETRIES", "3"))
RETRY_BASE_DELAY = float(os.getenv("RETRY_BASE_DELAY", "1.5"))
MAX_INPUT_TOKENS = 16000  # fixed cap on the condensed prompt size
# rough conversion: 1 token ~ 4 characters
AVG_CHARS_PER_TOKEN = 4
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
log = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Gradio client (singleton)
# ---------------------------------------------------------------------------
_client: Client | None = None
async def get_client() -> Client:
global _client
if _client is None:
log.info("Connecting to %s", HF_SPACE_URL)
_client = await asyncio.to_thread(Client, HF_SPACE_URL)
log.info("Connected.")
return _client
# ---------------------------------------------------------------------------
# Schemas
# ---------------------------------------------------------------------------
class Message(BaseModel):
role: str
content: str | list[dict] = ""
name: str | None = None
class ChatCompletionRequest(BaseModel):
model: str = MODEL_ID
messages: list[Message]
temperature: float = DEFAULT_TEMP
top_p: float = DEFAULT_TOP_P
max_tokens: int = DEFAULT_TOKENS
stream: bool = False
frequency_penalty: float = 0
presence_penalty: float = 0
stop: str | list[str] | None = None
seed: int | None = None
user: str | None = None
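# Example request body this schema accepts (a sketch with a hypothetical
# model id; unset fields fall back to the env-driven defaults above):
#
#   {
#     "model": "tiiuae/falcon-h1-34b",
#     "messages": [
#       {"role": "system", "content": "You are helpful."},
#       {"role": "user", "content": "Hello!"}
#     ],
#     "temperature": 0.6,
#     "stream": false
#   }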
# ---------------------------------------------------------------------------
# Auth
# ---------------------------------------------------------------------------
async def verify_key(request: Request) -> None:
if not API_KEY:
return
auth = request.headers.get("Authorization", "")
if not auth.startswith("Bearer ") or auth[7:] != API_KEY:
raise HTTPException(status_code=401, detail="Invalid or missing API key")
# ---------------------------------------------------------------------------
# Lifespan
# ---------------------------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
log.info("Startup: connecting to Gradio client...")
await get_client()
yield
log.info("Shutdown.")
# ---------------------------------------------------------------------------
# Utilities
# ---------------------------------------------------------------------------
def _content_str(m: Message) -> str:
if isinstance(m.content, str):
return m.content
    text_parts = []
    for p in m.content:
        if isinstance(p, dict) and p.get("type") == "text":
            text_parts.append(p.get("text", "").strip())
    # Join with newlines so stripped parts don't run together into one word.
    return "\n".join(text_parts)
def _token_count(text: str) -> int:
return max(1, len(text) // AVG_CHARS_PER_TOKEN)
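# Worked example of the heuristic (a floor-division estimate, not a real
# tokenizer): _token_count("hello world") == max(1, 11 // 4) == 2, and the
# max(1, ...) guard means even an empty string counts as one token.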
def _condense_messages(messages: list[Message], max_tokens: int) -> str:
system_msgs = [m for m in messages if m.role == "system"]
user_assistant = [m for m in messages if m.role in ("user", "assistant")]
condensed_parts = []
for m in system_msgs:
condensed_parts.append(_content_str(m))
tokens_so_far = sum(_token_count(part) for part in condensed_parts)
for m in user_assistant:
text = _content_str(m)
tcount = _token_count(text)
        if tokens_so_far + tcount > max_tokens:
            remaining_tokens = max_tokens - tokens_so_far
            if remaining_tokens <= 0:
                break  # budget exhausted; no later message can fit either
            # Keep the tail of the overflowing message, sized to the remaining budget.
            approx_chars = remaining_tokens * AVG_CHARS_PER_TOKEN
            text = text[-approx_chars:]
            tcount = _token_count(text)
condensed_parts.append(text)
tokens_so_far += tcount
return "\n".join(condensed_parts)
def _build_prompt(messages: list[Message]) -> str:
prompt = _condense_messages(messages, MAX_INPUT_TOKENS)
log.info("Final prompt token count: ~%d", _token_count(prompt))
return prompt
# ---------------------------------------------------------------------------
# Extraction
# ---------------------------------------------------------------------------
def _extract_text(result: Any) -> str:
if isinstance(result, tuple):
data = result
elif hasattr(result, "data"):
data = result.data
else:
data = [result]
conversation = None
for item in data:
if isinstance(item, dict) and "value" in item and isinstance(item["value"], list):
conversation = item["value"]
break
elif isinstance(item, list):
conversation = item
break
if not conversation:
raise ValueError("Cannot extract conversation from result")
last = conversation[-1]
if isinstance(last, dict):
content = last.get("content", "")
elif isinstance(last, (list, tuple)) and len(last) >= 2:
content = last[1] or ""
else:
content = str(last)
if isinstance(content, list):
parts = []
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
parts.append(block.get("content", block.get("text", "")))
return "".join(parts).strip()
return str(content).strip()
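# Result shapes handled above, based on what gradio_client typically returns
# (the exact structure depends on the Space's output components, so treat
# these as assumptions rather than a contract):
#   - a tuple of outputs, or an object exposing a .data list;
#   - a Chatbot-style {"value": [...]} dict, or a bare list of turns;
#   - a final turn that is a {"role", "content"} dict, a (user, assistant)
#     pair, or a plain string; list-valued content is flattened from its
#     "text"/"content" blocks.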
# ---------------------------------------------------------------------------
# Retry wrapper
# ---------------------------------------------------------------------------
async def _call_with_retries(prompt: str, req: ChatCompletionRequest) -> str:
    last_error: Exception | None = None
for attempt in range(1, MAX_RETRIES + 1):
try:
return await asyncio.wait_for(_call_falcon_once(prompt, req), timeout=REQUEST_TIMEOUT)
except Exception as e:
last_error = e
if attempt == MAX_RETRIES:
break
            # Exponential backoff: RETRY_BASE_DELAY ** attempt (1.5s, 2.25s, ... with defaults).
            delay = RETRY_BASE_DELAY ** attempt
log.warning("Attempt %d failed: %s | retrying in %.2fs", attempt, str(e), delay)
await asyncio.sleep(delay)
    assert last_error is not None  # the loop runs at least once when MAX_RETRIES >= 1
    raise last_error
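# With the defaults (MAX_RETRIES=3, RETRY_BASE_DELAY=1.5) the sleep schedule
# between attempts is 1.5 ** 1 = 1.5s, then 1.5 ** 2 = 2.25s, with no sleep
# after the final failure. A multiplicative schedule such as
# RETRY_BASE_DELAY * 2 ** (attempt - 1) is a common alternative if the base
# should equal the first delay regardless of its value.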
# ---------------------------------------------------------------------------
# Falcon call with explicit api_name
# ---------------------------------------------------------------------------
async def _call_falcon_once(prompt: str, req: ChatCompletionRequest) -> str:
client = await get_client()
settings = {
"model": req.model,
"temperature": req.temperature,
"max_new_tokens": req.max_tokens,
"top_p": req.top_p,
}
# Reset chat session
await asyncio.to_thread(client.predict, api_name="/new_chat")
# Add message with explicit api_name and settings
result = await asyncio.to_thread(
client.predict,
        prompt,  # first positional argument
settings_form_value=settings,
api_name="/add_message", # <-- tutaj musi być endpoint z View API
)
return _extract_text(result)
# ---------------------------------------------------------------------------
# Streaming
# ---------------------------------------------------------------------------
async def _stream_sse(text: str, req: ChatCompletionRequest) -> AsyncGenerator[str, None]:
cid = f"chatcmpl-{uuid.uuid4().hex}"
created = int(time.time())
    # Emit the text in 8-character deltas to simulate token streaming.
    for i in range(0, len(text), 8):
chunk = {
"id": cid,
"object": "chat.completion.chunk",
"created": created,
"model": req.model,
"choices": [{"index": 0, "delta": {"content": text[i:i+8]}, "finish_reason": None}]
}
yield f"data: {json.dumps(chunk)}\n\n"
await asyncio.sleep(0.01)
final_chunk = {
"id": cid,
"object": "chat.completion.chunk",
"created": created,
"model": req.model,
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]
}
yield f"data: {json.dumps(final_chunk)}\n\n"
yield "data: [DONE]\n\n"
# ---------------------------------------------------------------------------
# OpenAI-compatible response
# ---------------------------------------------------------------------------
def _make_response(text: str, req: ChatCompletionRequest) -> dict:
    # Rough usage estimate, reusing the chars-per-token heuristic from above.
    pt = sum(len(_content_str(m)) for m in req.messages) // AVG_CHARS_PER_TOKEN
    ct = len(text) // AVG_CHARS_PER_TOKEN
return {
"id": f"chatcmpl-{uuid.uuid4().hex}",
"object": "chat.completion",
"created": int(time.time()),
"model": req.model,
"choices": [{"index": 0, "message": {"role": "assistant", "content": text}, "finish_reason": "stop"}],
"usage": {"prompt_tokens": pt, "completion_tokens": ct, "total_tokens": pt + ct},
}
# ---------------------------------------------------------------------------
# Routes
# ---------------------------------------------------------------------------
app = FastAPI(title="Foc", version="5.0.0", lifespan=lifespan)
# Note: wildcard origins together with allow_credentials is very permissive;
# consider pinning concrete origins in production.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"])
@app.get("/")
async def root():
return {
"service": "FOC API",
"version": "5.0.0",
"endpoints": {
"health": "/health",
"models": "/v1/models",
"chat": "/v1/chat/completions"
}
}
@app.get("/health")
async def health():
return {"status": "ok", "model": MODEL_ID, "space": HF_SPACE_URL}
@app.get("/v1/models")
async def list_models(_: None = Depends(verify_key)):
return {"object": "list", "data": [{"id": MODEL_ID, "object": "model", "created": 1710000000, "owned_by": "tiiuae"}]}
@app.post("/v1/chat/completions")
async def chat_completions(req: ChatCompletionRequest, _: None = Depends(verify_key)):
prompt = _build_prompt(req.messages)
log.info("Request | model=%s temp=%.2f tokens=%d stream=%s", req.model, req.temperature, req.max_tokens, req.stream)
try:
text = await _call_with_retries(prompt, req)
except Exception:
log.exception("Falcon failed after retries")
raise HTTPException(status_code=502, detail="Model temporarily unavailable. Please try again.")
if req.stream:
return StreamingResponse(
_stream_sse(text, req),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no", "Connection": "keep-alive"},
)
return JSONResponse(content=_make_response(text, req))
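# Example invocation (a sketch; assumes the app is served with
# `uvicorn main:app --port 8000` and that API_KEY is set in the environment):
#
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Authorization: Bearer $API_KEY" \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hi"}], "stream": false}'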