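"""FastAPI chat backend serving a Hugging Face causal LM (Fathom-R1-14B by default)."""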

import os
from typing import List, Literal, Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch


APP_TITLE = "HF Chat (Fathom-R1-14B)"
APP_VERSION = "0.2.0"
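
# Configuration (overridable via environment variables).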
MODEL_ID = os.getenv("MODEL_ID", "FractalAIResearch/Fathom-R1-14B")
PIPELINE_TASK = os.getenv("PIPELINE_TASK", "text-generation")
MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "8192"))
STATIC_DIR = os.getenv("STATIC_DIR", "/app/static")
ALLOWED_ORIGINS = os.getenv("ALLOWED_ORIGINS", "")
QUANTIZE = os.getenv("QUANTIZE", "auto")

app = FastAPI(title=APP_TITLE, version=APP_VERSION)
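
# Enable CORS only when ALLOWED_ORIGINS (a comma-separated list) is set.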
if ALLOWED_ORIGINS:
    origins = [o.strip() for o in ALLOWED_ORIGINS.split(",") if o.strip()]
    app.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
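

# Request/response schemas for the chat API.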
class Message(BaseModel):
    role: Literal["system", "user", "assistant"]
    content: str


class ChatRequest(BaseModel):
    messages: List[Message]
    max_new_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 0.95
    repetition_penalty: Optional[float] = 1.0
    stop: Optional[List[str]] = None


class ChatResponse(BaseModel):
    reply: str
    model: str
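

# Globals populated once at startup by load_pipeline().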
tokenizer = None
model = None
generator = None


def load_pipeline():
    global tokenizer, model, generator
    device = "cuda" if torch.cuda.is_available() else "cpu"
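
    # Load the tokenizer; reuse EOS as the pad token when none is defined.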
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token

    load_kwargs = {}
    dtype = torch.bfloat16 if device == "cuda" else torch.float32

    if device == "cuda":
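        # Optionally quantize with bitsandbytes; fall back to unquantized bf16
        # if the library is unavailable.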
        if QUANTIZE.lower() in ("4bit", "8bit", "auto"):
            try:
                import bitsandbytes  # noqa: F401  (availability check only)
                from transformers import BitsAndBytesConfig

                # Use a BitsAndBytesConfig rather than the deprecated
                # load_in_8bit / load_in_4bit keyword arguments.
                if QUANTIZE.lower() == "8bit":
                    load_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
                else:
                    load_kwargs["quantization_config"] = BitsAndBytesConfig(
                        load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16
                    )
            except Exception:
                # bitsandbytes missing or incompatible: load unquantized.
                pass
        load_kwargs.setdefault("torch_dtype", dtype)
        load_kwargs.setdefault("device_map", "auto")
    else:
        load_kwargs.setdefault("torch_dtype", dtype)
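
    # Load the model with the selected dtype / quantization settings.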
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)

    generator = pipeline(
        PIPELINE_TASK,
        model=model,
        tokenizer=tokenizer,
        # Only pin a device when accelerate has not already dispatched the model
        # via device_map; passing both would conflict.
        device=None if "device_map" in load_kwargs else (0 if device == "cuda" else -1),
    )


@app.on_event("startup")
def _startup():
    load_pipeline()


def messages_to_prompt(messages: List[Message]) -> str:
    """
    Prefer the tokenizer's chat template (Qwen-based models ship one); fall back
    to a simple plain-text transcript if no template is available.
    """
    try:
        chat = [{"role": m.role, "content": m.content} for m in messages]
        return tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    except Exception:
        # Fallback: build a plain transcript ending with an assistant cue.
        parts = []
        for m in messages:
            if m.role == "system":
                parts.append(f"System: {m.content}\n")
            elif m.role == "user":
                parts.append(f"User: {m.content}\n")
            else:
                parts.append(f"Assistant: {m.content}\n")
        parts.append("Assistant:")
        return "\n".join(parts)


def truncate_prompt(prompt: str, max_tokens: int) -> str:
    # Keep only the most recent max_tokens tokens so long conversations still fit.
    ids = tokenizer(prompt, return_tensors="pt", truncation=False)["input_ids"][0]
    if len(ids) <= max_tokens:
        return prompt
    trimmed = ids[-max_tokens:]
    return tokenizer.decode(trimmed, skip_special_tokens=True)


@app.get("/api/health")
def health():
    device = next(model.parameters()).device.type if model is not None else "N/A"
    return {"status": "ok", "model": MODEL_ID, "task": PIPELINE_TASK, "device": device}


@app.post("/api/chat", response_model=ChatResponse)
def chat(req: ChatRequest):
    if generator is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    if not req.messages:
        raise HTTPException(status_code=400, detail="messages cannot be empty")

    raw_prompt = messages_to_prompt(req.messages)
    prompt = truncate_prompt(raw_prompt, MAX_INPUT_TOKENS)
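
    # Greedy decoding when temperature is 0; nucleus sampling otherwise.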
    gen_kwargs = {
        "max_new_tokens": req.max_new_tokens,
        "do_sample": req.temperature > 0,
        "temperature": req.temperature,
        "top_p": req.top_p,
        "repetition_penalty": req.repetition_penalty,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.pad_token_id,
        "return_full_text": True,
    }
    # The text-generation pipeline has no generic `stop` generation kwarg, so
    # client-supplied stop sequences are applied to the decoded reply below.

    outputs = generator(prompt, **gen_kwargs)
    if isinstance(outputs, list) and outputs and "generated_text" in outputs[0]:
        full = outputs[0]["generated_text"]
        # return_full_text=True echoes the prompt, so strip it to keep only the new text.
        reply = full[len(prompt):].strip() if full.startswith(prompt) else full
    else:
        reply = str(outputs)
    if req.stop:
        # Apply client stop sequences by truncating at the earliest match.
        cuts = [reply.find(s) for s in req.stop if reply.find(s) != -1]
        if cuts:
            reply = reply[: min(cuts)].rstrip()
    if not reply:
        reply = "(No response generated.)"
    return ChatResponse(reply=reply, model=MODEL_ID)
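

# Serve the bundled static frontend (if any) at the site root. The /api routes
# are registered above, so they take precedence over this mount.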
if os.path.isdir(STATIC_DIR):
    app.mount("/", StaticFiles(directory=STATIC_DIR, html=True), name="static")