# main.py — uploaded to Hugging Face (commit edbf0e3, verified)
import os
from typing import List, Literal, Optional
import requests
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
# --- Configuration (overridable via environment variables) ---
APP_TITLE = "HF Chat (Fathom-R1-14B via API)"
APP_VERSION = "0.2.0"
# Hugging Face model repo id used for text generation.
MODEL_ID = os.getenv("MODEL_ID", "FractalAIResearch/Fathom-R1-14B")
# Directory served as the static frontend; mounted at "/" only if it exists.
STATIC_DIR = os.getenv("STATIC_DIR", "/app/static")
# Comma-separated list of CORS origins; empty string disables the middleware.
ALLOWED_ORIGINS = os.getenv("ALLOWED_ORIGINS", "")
# Hugging Face API token; absence is reported as a 500 by /api/chat.
HF_API_TOKEN = os.getenv("HF_API_TOKEN")
# Create the application and, when ALLOWED_ORIGINS is configured,
# attach permissive CORS for exactly those origins.
app = FastAPI(title=APP_TITLE, version=APP_VERSION)

if ALLOWED_ORIGINS:
    # Normalize the comma-separated origin list, dropping empty entries.
    origins = []
    for raw_origin in ALLOWED_ORIGINS.split(","):
        cleaned = raw_origin.strip()
        if cleaned:
            origins.append(cleaned)
    app.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
class Message(BaseModel):
    """A single chat turn with an OpenAI-style role and its text content."""

    # Only these three roles are accepted; validation is enforced by pydantic.
    role: Literal["system", "user", "assistant"]
    content: str
class ChatRequest(BaseModel):
    """Request body for /api/chat: conversation history plus sampling options."""

    messages: List[Message]
    max_new_tokens: int = 512            # upper bound on generated tokens
    temperature: float = 0.7             # sampling temperature
    top_p: float = 0.95                  # nucleus-sampling cutoff
    repetition_penalty: Optional[float] = 1.0
    stop: Optional[List[str]] = None     # optional stop sequences
class ChatResponse(BaseModel):
    """Response body for /api/chat: the generated reply and the model that produced it."""

    reply: str
    model: str
def messages_to_prompt(messages: List[Message]) -> str:
    """Flatten a chat history into a plain-text prompt.

    Each message is rendered as a "Role: content" line, and a trailing
    "Assistant:" cue is appended so the model continues as the assistant.
    """
    role_labels = {"system": "System", "user": "User"}
    # Any role other than system/user (i.e. "assistant") falls back to "Assistant".
    lines = [
        f"{role_labels.get(msg.role, 'Assistant')}: {msg.content}"
        for msg in messages
    ]
    lines.append("Assistant:")
    return "\n".join(lines)
@app.get("/api/health")
def health():
    """Liveness probe reporting the configured model and upstream source."""
    status_report = dict(
        status="ok",
        model=MODEL_ID,
        source="huggingface-inference-api",
    )
    return status_report
@app.post("/api/chat", response_model=ChatResponse)
def chat(req: ChatRequest):
    """Proxy a chat completion to the Hugging Face Inference API.

    Flattens the conversation into a text prompt, forwards it with the
    requested sampling parameters, and strips the echoed prompt from the
    generated text before returning the reply.

    Raises:
        HTTPException 500: HF_API_TOKEN is not configured on the server.
        HTTPException 400: the messages list is empty.
        HTTPException 502: the upstream request failed at the network level.
        HTTPException <status>: any non-200 status forwarded from the HF API.
    """
    if not HF_API_TOKEN:
        raise HTTPException(status_code=500, detail="HF_API_TOKEN not set")
    if not req.messages:
        raise HTTPException(status_code=400, detail="messages cannot be empty")

    prompt = messages_to_prompt(req.messages)
    headers = {
        "Authorization": f"Bearer {HF_API_TOKEN}"
    }
    parameters = {
        "max_new_tokens": req.max_new_tokens,
        "temperature": req.temperature,
        "top_p": req.top_p,
        "repetition_penalty": req.repetition_penalty,
        "return_full_text": True,
    }
    # BUG FIX: req.stop was accepted in the request model but never forwarded.
    # The HF text-generation task documents this parameter as "stop" — verify
    # against the deployed API version if stop sequences appear to be ignored.
    if req.stop:
        parameters["stop"] = req.stop
    payload = {
        "inputs": prompt,
        "parameters": parameters,
    }

    try:
        # BUG FIX: added a timeout so a stalled upstream cannot hang the
        # worker indefinitely (requests has no default timeout).
        response = requests.post(
            f"https://api-inference.huggingface.co/models/{MODEL_ID}",
            headers=headers,
            json=payload,
            timeout=120,
        )
    except requests.RequestException as exc:
        # BUG FIX: surface connection/timeout errors as a clean 502 instead
        # of an unhandled exception (raw 500 traceback).
        raise HTTPException(
            status_code=502, detail=f"Upstream request failed: {exc}"
        ) from exc

    if response.status_code != 200:
        # Forward the upstream error status and body verbatim.
        raise HTTPException(status_code=response.status_code, detail=response.text)

    result = response.json()
    if isinstance(result, list) and result and "generated_text" in result[0]:
        full = result[0]["generated_text"]
        # return_full_text=True echoes the prompt; strip it to keep only the reply.
        reply = full[len(prompt):].strip() if full.startswith(prompt) else full
    else:
        # Unexpected payload shape (e.g. an error dict) — return it verbatim.
        reply = str(result)
    if not reply:
        reply = "(No response generated.)"
    return ChatResponse(reply=reply, model=MODEL_ID)
# Serve the bundled frontend (if present) from "/". Mounted last so that the
# /api/* routes registered above take priority over static file lookups.
if os.path.isdir(STATIC_DIR):
    app.mount("/", StaticFiles(directory=STATIC_DIR, html=True), name="static")