File size: 3,207 Bytes
958f33b edbf0e3 958f33b edbf0e3 958f33b edbf0e3 958f33b edbf0e3 958f33b edbf0e3 958f33b edbf0e3 958f33b edbf0e3 958f33b edbf0e3 958f33b edbf0e3 958f33b edbf0e3 958f33b edbf0e3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import os
from typing import List, Literal, Optional
import requests
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
# Service identity surfaced in FastAPI's generated OpenAPI metadata.
APP_TITLE = "HF Chat (Fathom-R1-14B via API)"
APP_VERSION = "0.2.0"
# Hugging Face model repo to query; overridable via the MODEL_ID env var.
MODEL_ID = os.getenv("MODEL_ID", "FractalAIResearch/Fathom-R1-14B")
# Directory of static frontend assets, mounted at "/" only if it exists.
STATIC_DIR = os.getenv("STATIC_DIR", "/app/static")
# Comma-separated CORS origins; empty string leaves CORS middleware disabled.
ALLOWED_ORIGINS = os.getenv("ALLOWED_ORIGINS", "")
# Bearer token for the HF Inference API; absence is rejected per-request in chat().
HF_API_TOKEN = os.getenv("HF_API_TOKEN")
app = FastAPI(title=APP_TITLE, version=APP_VERSION)

# Attach CORS middleware only when origins are explicitly configured;
# an unset/empty ALLOWED_ORIGINS means same-origin deployment (no CORS needed).
if ALLOWED_ORIGINS:
    # Split on commas and drop empty fragments (handles trailing commas/spaces).
    origins = [o.strip() for o in ALLOWED_ORIGINS.split(",") if o.strip()]
    app.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
class Message(BaseModel):
    """One turn of a chat conversation in OpenAI-style role/content form."""

    # Restricted to the three roles messages_to_prompt() knows how to render.
    role: Literal["system", "user", "assistant"]
    content: str
class ChatRequest(BaseModel):
    """Request body for POST /api/chat: a message history plus sampling knobs."""

    messages: List[Message]
    # Generation parameters forwarded to the HF Inference API.
    max_new_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 0.95
    repetition_penalty: Optional[float] = 1.0
    # Optional stop sequences requested by the client.
    stop: Optional[List[str]] = None
class ChatResponse(BaseModel):
    """Response body for POST /api/chat."""

    # The assistant's generated continuation (prompt echo stripped).
    reply: str
    # The model ID that produced the reply, for client display/debugging.
    model: str
def messages_to_prompt(messages: List[Message]) -> str:
    """Flatten a chat history into one plain-text prompt.

    Each message becomes a "<Role>: <content>" line, and a trailing
    "Assistant:" line cues the model to produce the next assistant turn.
    """
    labels = {"system": "System", "user": "User"}
    # Any role other than system/user renders as "Assistant", matching the
    # three-way Literal on Message.role.
    rendered = [f"{labels.get(msg.role, 'Assistant')}: {msg.content}" for msg in messages]
    rendered.append("Assistant:")
    return "\n".join(rendered)
@app.get("/api/health")
def health():
    """Liveness probe: reports service status and the proxied model."""
    return {
        "status": "ok",
        "model": MODEL_ID,
        "source": "huggingface-inference-api",
    }
@app.post("/api/chat", response_model=ChatResponse)
def chat(req: ChatRequest):
    """Proxy a chat request to the Hugging Face Inference API.

    Flattens the message history into a single prompt, forwards it with the
    requested sampling parameters, and returns the generated continuation.

    Raises:
        HTTPException(500): HF_API_TOKEN is not configured.
        HTTPException(400): the message list is empty.
        HTTPException(502): network failure or timeout reaching the HF API.
        HTTPException(<upstream status>): non-200 reply from the HF API.
    """
    if not HF_API_TOKEN:
        raise HTTPException(status_code=500, detail="HF_API_TOKEN not set")
    if not req.messages:
        raise HTTPException(status_code=400, detail="messages cannot be empty")
    prompt = messages_to_prompt(req.messages)
    headers = {
        "Authorization": f"Bearer {HF_API_TOKEN}"
    }
    parameters = {
        "max_new_tokens": req.max_new_tokens,
        "temperature": req.temperature,
        "top_p": req.top_p,
        "repetition_penalty": req.repetition_penalty,
        "return_full_text": True,
    }
    # Fix: req.stop was declared on ChatRequest but never forwarded upstream.
    if req.stop:
        parameters["stop"] = req.stop
    payload = {
        "inputs": prompt,
        "parameters": parameters,
    }
    try:
        # Fix: without a timeout a stalled upstream hangs this worker forever.
        response = requests.post(
            f"https://api-inference.huggingface.co/models/{MODEL_ID}",
            headers=headers,
            json=payload,
            timeout=120,
        )
    except requests.RequestException as exc:
        # Fix: surface network failures as a clean 502 instead of a raw 500.
        raise HTTPException(status_code=502, detail=f"Upstream request failed: {exc}") from exc
    if response.status_code != 200:
        raise HTTPException(status_code=response.status_code, detail=response.text)
    result = response.json()
    # Expected shape: [{"generated_text": "<prompt + continuation>"}].
    # Since return_full_text=True, strip the echoed prompt to keep only the
    # assistant continuation; fall back to the raw text if it doesn't echo.
    if isinstance(result, list) and result and "generated_text" in result[0]:
        full = result[0]["generated_text"]
        reply = full[len(prompt):].strip() if full.startswith(prompt) else full
    else:
        reply = str(result)
    if not reply:
        reply = "(No response generated.)"
    return ChatResponse(reply=reply, model=MODEL_ID)
# Serve the frontend SPA at "/" when the static directory was baked into the
# image; mounted last so the /api routes above take precedence.
if os.path.isdir(STATIC_DIR):
    app.mount("/", StaticFiles(directory=STATIC_DIR, html=True), name="static")