# CheckInAPI / main.py
import os
from typing import Dict, Any
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch

MODEL_ID = os.getenv("MODEL_ID", "ethnmcl/checkin-gpt2")
HF_TOKEN = os.getenv("HF_TOKEN")  # set in Space Secrets if the model repo is private

app = FastAPI(title="Check-in GPT-2 API", version="1.0.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
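# NOTE: wildcard origins together with allow_credentials=True is very permissive;
# consider pinning allow_origins to known frontends in production.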

device = 0 if torch.cuda.is_available() else -1

# ✅ use token= (not use_auth_token) and rely on HF_HOME=/data/huggingface
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
# GPT-2 has no pad token by default, so reuse the EOS token for padding
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, token=HF_TOKEN)

pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=device,
)

PREFIX = "INPUT: "
SUFFIX = "\nOUTPUT:"

def make_prompt(user_input: str) -> str:
    return f"{PREFIX}{user_input}{SUFFIX}"
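# e.g. make_prompt("How was your day?") -> "INPUT: How was your day?\nOUTPUT:"
# (the INPUT:/OUTPUT: template presumably mirrors the fine-tuning data format)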

class GenerateRequest(BaseModel):
    input: str = Field(..., min_length=1)
    max_new_tokens: int = 180
    temperature: float = 0.7
    top_p: float = 0.95
    top_k: int = 50
    repetition_penalty: float = 1.05
    do_sample: bool = True
    num_return_sequences: int = 1

class GenerateResponse(BaseModel):
    output: str
    prompt: str
    parameters: Dict[str, Any]
@app.get("/")
def root():
return {"message": "Check-in GPT-2 API. POST /generate", "model": MODEL_ID, "device": "cuda" if device == 0 else "cpu"}
@app.get("/health")
def health():
return {"status": "ok"}
@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
try:
prompt = make_prompt(req.input)
gen = pipe(
prompt,
max_new_tokens=req.max_new_tokens,
temperature=req.temperature,
top_p=req.top_p,
top_k=req.top_k,
repetition_penalty=req.repetition_penalty,
do_sample=req.do_sample,
num_return_sequences=req.num_return_sequences,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id,
return_full_text=True
)
text = gen[0]["generated_text"]
output = text.split("OUTPUT:", 1)[-1].strip()
return GenerateResponse(output=output, prompt=prompt, parameters=req.model_dump())
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
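
# Example request (hypothetical host/port; a custom Docker Space typically
# serves on port 7860):
#   curl -X POST http://localhost:7860/generate \
#     -H "Content-Type: application/json" \
#     -d '{"input": "Feeling a bit overwhelmed at work today.", "max_new_tokens": 120}'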