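"""FastAPI backend that proxies chat requests to the Hugging Face Inference API.

Exposes /api/health and /api/chat, forwards prompts to the configured
MODEL_ID (FractalAIResearch/Fathom-R1-14B by default), and serves the
static frontend from STATIC_DIR when that directory exists.
"""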
import os
from typing import List, Literal, Optional
import requests

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

APP_TITLE = "HF Chat (Fathom-R1-14B via API)"
APP_VERSION = "0.2.0"

# Runtime configuration via environment variables; HF_API_TOKEN has no
# default because the token is a secret and is required by /api/chat.
MODEL_ID = os.getenv("MODEL_ID", "FractalAIResearch/Fathom-R1-14B")
STATIC_DIR = os.getenv("STATIC_DIR", "/app/static")
ALLOWED_ORIGINS = os.getenv("ALLOWED_ORIGINS", "")  # comma-separated origin list
HF_API_TOKEN = os.getenv("HF_API_TOKEN")

app = FastAPI(title=APP_TITLE, version=APP_VERSION)

# Enable CORS only when ALLOWED_ORIGINS is set.
if ALLOWED_ORIGINS:
    origins = [o.strip() for o in ALLOWED_ORIGINS.split(",") if o.strip()]
    app.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

# One chat turn in the OpenAI-style role/content shape.
class Message(BaseModel):
    role: Literal["system", "user", "assistant"]
    content: str

# Full conversation history plus sampling parameters (with safe defaults).
class ChatRequest(BaseModel):
    messages: List[Message]
    max_new_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 0.95
    repetition_penalty: Optional[float] = 1.0
    stop: Optional[List[str]] = None

class ChatResponse(BaseModel):
    reply: str
    model: str

def messages_to_prompt(messages: List[Message]) -> str:
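    """Flatten the chat history into a plain-text transcript.

    Each turn becomes a "Role: content" line, and a trailing "Assistant:"
    line cues the model to continue the conversation. This is a generic
    fallback format; it does not apply the model's own chat template.
    """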
    parts = []
    for m in messages:
        if m.role == "system":
            parts.append(f"System: {m.content}")
        elif m.role == "user":
            parts.append(f"User: {m.content}")
        else:
            parts.append(f"Assistant: {m.content}")
    parts.append("Assistant:")
    return "\n".join(parts)

@app.get("/api/health")
def health():
    return {"status": "ok", "model": MODEL_ID, "source": "huggingface-inference-api"}

@app.post("/api/chat", response_model=ChatResponse)
def chat(req: ChatRequest):
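    """Proxy the conversation to the hosted model and return its reply."""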
    if not HF_API_TOKEN:
        raise HTTPException(status_code=500, detail="HF_API_TOKEN not set")

    if not req.messages:
        raise HTTPException(status_code=400, detail="messages cannot be empty")

    prompt = messages_to_prompt(req.messages)

    headers = {
        "Authorization": f"Bearer {HF_API_TOKEN}"
    }

    # Parameters follow the Inference API text-generation schema; with
    # return_full_text=True the prompt is echoed back and stripped below.
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": req.max_new_tokens,
            "temperature": req.temperature,
            "top_p": req.top_p,
            "repetition_penalty": req.repetition_penalty,
            "return_full_text": True,
        }
    }
    # `stop` is optional; pass it through only when provided (supported by
    # TGI-backed endpoints, which is an assumption for this model).
    if req.stop:
        payload["parameters"]["stop"] = req.stop

    # A timeout keeps a cold or unresponsive endpoint from hanging the worker.
    try:
        response = requests.post(
            f"https://api-inference.huggingface.co/models/{MODEL_ID}",
            headers=headers,
            json=payload,
            timeout=120,
        )
    except requests.RequestException as exc:
        raise HTTPException(status_code=502, detail=str(exc)) from exc

    # Surface upstream failures (e.g. 503 while the model loads) verbatim.
    if response.status_code != 200:
        raise HTTPException(status_code=response.status_code, detail=response.text)

    result = response.json()
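    # Successful responses look like [{"generated_text": "..."}]; strip the
    # echoed prompt so only the newly generated completion is returned.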
    if isinstance(result, list) and result and "generated_text" in result[0]:
        full = result[0]["generated_text"]
        reply = full[len(prompt):].strip() if full.startswith(prompt) else full
    else:
        reply = str(result)

    if not reply:
        reply = "(No response generated.)"

    return ChatResponse(reply=reply, model=MODEL_ID)

# Serve the frontend last so the catch-all mount cannot shadow the /api routes.
if os.path.isdir(STATIC_DIR):
    app.mount("/", StaticFiles(directory=STATIC_DIR, html=True), name="static")
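
# Usage sketch (assumes this file is saved as main.py and uvicorn is installed):
#
#   export HF_API_TOKEN=hf_...   # your Hugging Face API token
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
#   curl -s http://localhost:8000/api/chat \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hello!"}]}'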