import os
from typing import List, Literal, Optional

import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

APP_TITLE = "HF Chat (Fathom-R1-14B)"
APP_VERSION = "0.2.0"

MODEL_ID = os.getenv("MODEL_ID", "FractalAIResearch/Fathom-R1-14B")
PIPELINE_TASK = os.getenv("PIPELINE_TASK", "text-generation")
MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "8192"))
STATIC_DIR = os.getenv("STATIC_DIR", "/app/static")
ALLOWED_ORIGINS = os.getenv("ALLOWED_ORIGINS", "")
QUANTIZE = os.getenv("QUANTIZE", "auto")  # "4bit", "8bit", "auto", or anything else to disable

app = FastAPI(title=APP_TITLE, version=APP_VERSION)

if ALLOWED_ORIGINS:
    origins = [o.strip() for o in ALLOWED_ORIGINS.split(",") if o.strip()]
    app.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )


class Message(BaseModel):
    role: Literal["system", "user", "assistant"]
    content: str


class ChatRequest(BaseModel):
    messages: List[Message]
    max_new_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 0.95
    repetition_penalty: Optional[float] = 1.0
    stop: Optional[List[str]] = None


class ChatResponse(BaseModel):
    reply: str
    model: str


tokenizer = None
model = None
generator = None


def load_pipeline():
    global tokenizer, model, generator
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load tokenizer; reuse EOS as the pad token if none is defined,
    # so generation does not fail on models without one.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token

    # Determine load strategy.
    load_kwargs = {}
    dtype = torch.bfloat16 if device == "cuda" else torch.float32

    if device == "cuda":
        # Try quantization if requested. Use BitsAndBytesConfig rather than
        # the deprecated load_in_8bit/load_in_4bit kwargs on from_pretrained.
        if QUANTIZE.lower() in ("4bit", "8bit", "auto"):
            try:
                import bitsandbytes as bnb  # noqa: F401
                from transformers import BitsAndBytesConfig

                if QUANTIZE.lower() == "8bit":
                    load_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
                else:  # "4bit" or "auto" (prefer 4-bit)
                    load_kwargs["quantization_config"] = BitsAndBytesConfig(
                        load_in_4bit=True,
                        bnb_4bit_compute_dtype=torch.bfloat16,
                    )
            except Exception:
                # bitsandbytes not available; fall back to full precision on GPU.
                pass
        load_kwargs.setdefault("torch_dtype", dtype)
        load_kwargs.setdefault("device_map", "auto")
    else:
        # CPU fallback.
        load_kwargs.setdefault("torch_dtype", dtype)

    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)

    # The model is already placed (device_map="auto" on GPU), so let the
    # pipeline infer the device from the model instead of trying to move it,
    # which raises an error for accelerate-dispatched models.
    generator = pipeline(PIPELINE_TASK, model=model, tokenizer=tokenizer)


@app.on_event("startup")
def _startup():
    load_pipeline()


def messages_to_prompt(messages: List[Message]) -> str:
    """
    Prefer the tokenizer's chat template (Qwen-based models ship one).
    Fall back to a simple transcript.
    """
    try:
        # Convert to the HF chat format: a list of dicts with role/content.
        chat = [{"role": m.role, "content": m.content} for m in messages]
        return tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    except Exception:
        # Fallback formatting.
        parts = []
        for m in messages:
            if m.role == "system":
                parts.append(f"System: {m.content}")
            elif m.role == "user":
                parts.append(f"User: {m.content}")
            else:
                parts.append(f"Assistant: {m.content}")
        parts.append("Assistant:")
        return "\n".join(parts)


def truncate_prompt(prompt: str, max_tokens: int) -> str:
    # Keep only the most recent max_tokens tokens. Re-decoding a trimmed
    # chat-templated prompt can drop special tokens; acceptable as a fallback.
    ids = tokenizer(prompt, return_tensors="pt", truncation=False)["input_ids"][0]
    if len(ids) <= max_tokens:
        return prompt
    trimmed = ids[-max_tokens:]
    return tokenizer.decode(trimmed, skip_special_tokens=True)


@app.get("/api/health")
def health():
    device = next(model.parameters()).device.type if model is not None else "N/A"
    return {"status": "ok", "model": MODEL_ID, "task": PIPELINE_TASK, "device": device}


@app.post("/api/chat", response_model=ChatResponse)
def chat(req: ChatRequest):
    if generator is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    if not req.messages:
        raise HTTPException(status_code=400, detail="messages cannot be empty")

    raw_prompt = messages_to_prompt(req.messages)
    prompt = truncate_prompt(raw_prompt, MAX_INPUT_TOKENS)

    do_sample = req.temperature > 0
    gen_kwargs = {
        "max_new_tokens": req.max_new_tokens,
        "do_sample": do_sample,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.pad_token_id,
        "return_full_text": True,
    }
    if do_sample:
        # Only pass sampling parameters when sampling, to avoid
        # transformers warnings about unused generation flags.
        gen_kwargs["temperature"] = req.temperature
        gen_kwargs["top_p"] = req.top_p
    if req.repetition_penalty is not None:
        gen_kwargs["repetition_penalty"] = req.repetition_penalty

    outputs = generator(prompt, **gen_kwargs)
    if isinstance(outputs, list) and outputs and "generated_text" in outputs[0]:
        full = outputs[0]["generated_text"]
        reply = full[len(prompt):].strip() if full.startswith(prompt) else full
    else:
        reply = str(outputs)

    # The HF text-generation pipeline has no `stop` kwarg, so truncate the
    # reply at the first occurrence of any requested stop string instead.
    if req.stop:
        for s in req.stop:
            idx = reply.find(s)
            if idx != -1:
                reply = reply[:idx].rstrip()

    if not reply:
        reply = "(No response generated.)"
    return ChatResponse(reply=reply, model=MODEL_ID)


# Serve the frontend build (if present).
if os.path.isdir(STATIC_DIR):
    app.mount("/", StaticFiles(directory=STATIC_DIR, html=True), name="static")