"""FastAPI service that proxies chat requests to the Hugging Face Inference API.

Flattens a list of chat messages into a plain-text prompt, forwards it to the
hosted model identified by MODEL_ID, and returns the generated continuation.
Optionally serves a static frontend from STATIC_DIR.
"""

import os
from typing import List, Literal, Optional

import requests
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

APP_TITLE = "HF Chat (Fathom-R1-14B via API)"
APP_VERSION = "0.2.0"
MODEL_ID = os.getenv("MODEL_ID", "FractalAIResearch/Fathom-R1-14B")
STATIC_DIR = os.getenv("STATIC_DIR", "/app/static")
ALLOWED_ORIGINS = os.getenv("ALLOWED_ORIGINS", "")
HF_API_TOKEN = os.getenv("HF_API_TOKEN")
# Upper bound on how long we wait for the upstream inference call; without a
# timeout a stalled upstream would hang this worker indefinitely.
HF_TIMEOUT_SECONDS = float(os.getenv("HF_TIMEOUT_SECONDS", "120"))

app = FastAPI(title=APP_TITLE, version=APP_VERSION)

if ALLOWED_ORIGINS:
    # ALLOWED_ORIGINS is a comma-separated list; blank entries are dropped.
    origins = [o.strip() for o in ALLOWED_ORIGINS.split(",") if o.strip()]
    app.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )


class Message(BaseModel):
    """A single chat turn."""

    role: Literal["system", "user", "assistant"]
    content: str


class ChatRequest(BaseModel):
    """Request body for /api/chat; generation knobs mirror HF text-generation."""

    messages: List[Message]
    max_new_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 0.95
    repetition_penalty: Optional[float] = 1.0
    stop: Optional[List[str]] = None


class ChatResponse(BaseModel):
    """Response body: the assistant reply plus the model that produced it."""

    reply: str
    model: str


def messages_to_prompt(messages: List[Message]) -> str:
    """Flatten chat messages into a plain-text prompt.

    Each turn is rendered as "Role: content" on its own line, and a trailing
    "Assistant:" line cues the model to generate the next assistant turn.
    """
    parts: List[str] = []
    for m in messages:
        if m.role == "system":
            parts.append(f"System: {m.content}")
        elif m.role == "user":
            parts.append(f"User: {m.content}")
        else:
            parts.append(f"Assistant: {m.content}")
    parts.append("Assistant:")
    return "\n".join(parts)


@app.get("/api/health")
def health():
    """Liveness probe reporting the configured model and backend source."""
    return {"status": "ok", "model": MODEL_ID, "source": "huggingface-inference-api"}


@app.post("/api/chat", response_model=ChatResponse)
def chat(req: ChatRequest) -> ChatResponse:
    """Forward a chat request to the HF Inference API and return the reply.

    Raises:
        HTTPException 500: HF_API_TOKEN is not configured.
        HTTPException 400: the message list is empty.
        HTTPException 502: the upstream request failed at the network level.
        HTTPException <upstream>: the upstream returned a non-200 status.
    """
    if not HF_API_TOKEN:
        raise HTTPException(status_code=500, detail="HF_API_TOKEN not set")
    if not req.messages:
        raise HTTPException(status_code=400, detail="messages cannot be empty")

    prompt = messages_to_prompt(req.messages)
    headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
    parameters = {
        "max_new_tokens": req.max_new_tokens,
        "temperature": req.temperature,
        "top_p": req.top_p,
        "repetition_penalty": req.repetition_penalty,
        "return_full_text": True,
    }
    if req.stop:
        # BUG FIX: `stop` was accepted on ChatRequest but never forwarded, so
        # client-supplied stop sequences were silently ignored. TGI-backed
        # endpoints accept `stop` as a list of strings — confirm against the
        # deployed backend's parameter schema.
        parameters["stop"] = req.stop
    payload = {"inputs": prompt, "parameters": parameters}

    try:
        response = requests.post(
            f"https://api-inference.huggingface.co/models/{MODEL_ID}",
            headers=headers,
            json=payload,
            timeout=HF_TIMEOUT_SECONDS,  # fail fast instead of hanging forever
        )
    except requests.RequestException as exc:
        # DNS/connection/timeout errors previously surfaced as opaque 500s.
        raise HTTPException(status_code=502, detail=f"Upstream request failed: {exc}") from exc

    if response.status_code != 200:
        # Pass the upstream status and body through so the client can see
        # model-loading (503) and rate-limit errors directly.
        raise HTTPException(status_code=response.status_code, detail=response.text)

    result = response.json()
    if isinstance(result, list) and result and "generated_text" in result[0]:
        full = result[0]["generated_text"]
        # return_full_text=True echoes the prompt; strip it to get the reply.
        reply = full[len(prompt):].strip() if full.startswith(prompt) else full
    else:
        # Unexpected shape — surface it verbatim rather than dropping it.
        reply = str(result)

    if not reply:
        reply = "(No response generated.)"
    return ChatResponse(reply=reply, model=MODEL_ID)


# Serve the frontend (if present) from the root path; mounted last so the
# /api/* routes above take precedence.
if os.path.isdir(STATIC_DIR):
    app.mount("/", StaticFiles(directory=STATIC_DIR, html=True), name="static")