CheckInAPI / main.py
ethnmcl's picture
Create main.py
c7d9782 verified
raw
history blame
2.75 kB
import logging
import os
from typing import Any, Dict

import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Runtime configuration, read from the environment.
# `os.getenv(name, default)` returns "" (not the default) when a variable is
# set but empty, so `or` is used to treat empty values as "unset".
MODEL_ID = os.getenv("MODEL_ID") or "ethnmcl/checkin-gpt2"
# Hugging Face access token; set it in Space Secrets if the model repo is private.
# None (rather than "") lets the Hub client fall back to its own credential lookup.
HF_TOKEN = os.getenv("HF_TOKEN") or None
# 7860 is the default Hugging Face Spaces app port. NOTE(review): not referenced
# elsewhere in this module; presumably consumed by the server launcher — confirm.
PORT = int(os.getenv("PORT") or "7860")
app = FastAPI(title="Check-in GPT-2 API", version="1.0.0")

# Allow your frontend(s): comma-separated origins in CORS_ORIGINS (default "*",
# i.e. any origin, matching the previous hard-coded behaviour).
# NOTE: wildcard origins combined with allow_credentials=True are discouraged by
# the Starlette CORSMiddleware docs; list explicit origins in production.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ORIGINS", "*").split(",")
    if origin.strip()
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins or ["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load tokenizer, model and generation pipeline once at import time so every
# request reuses the same objects.
# transformers `pipeline` device convention: 0 = first CUDA GPU, -1 = CPU.
device = 0 if torch.cuda.is_available() else -1

# `token=` replaces the deprecated `use_auth_token=` argument of from_pretrained.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
if tokenizer.pad_token is None:
    # GPT-2-style tokenizers ship without a pad token; reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(MODEL_ID, token=HF_TOKEN)

pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=device,
)
# Prompt template wrapped around every check-in line; presumably the format the
# model was fine-tuned on (NOTE(review): confirm against the training data).
PREFIX = "INPUT: "
SUFFIX = "\nOUTPUT:"


def make_prompt(user_input: str) -> str:
    """Return *user_input* framed as ``PREFIX + user_input + SUFFIX``.

    The model is expected to continue the text right after SUFFIX.
    """
    return "".join((PREFIX, user_input, SUFFIX))
class GenerateRequest(BaseModel):
    """Request body for POST /generate; sampling fields are forwarded to the pipeline.

    Bounds mirror the preconditions of the transformers generation processors,
    so invalid values are rejected with a 422 instead of failing inside the model.
    """

    input: str = Field(..., min_length=1, description="Short check-in line to expand")
    # Upper bound assumes a GPT-2-sized context window (1024 positions); adjust if
    # MODEL_ID points at a model with a different limit.
    max_new_tokens: int = Field(
        180, ge=1, le=1024, description="Maximum number of tokens to generate"
    )
    # 0 is allowed so greedy callers (do_sample=false) can keep passing it.
    temperature: float = Field(0.7, ge=0.0, description="Sampling temperature")
    top_p: float = Field(
        0.95, ge=0.0, le=1.0, description="Nucleus-sampling cumulative probability"
    )
    top_k: int = Field(50, ge=0, description="Top-k cutoff; 0 disables top-k filtering")
    repetition_penalty: float = Field(
        1.05, gt=0.0, description="Penalty applied to already-generated tokens"
    )
    do_sample: bool = True
    num_return_sequences: int = Field(
        1, ge=1, description="Sequences to generate; the endpoint returns the first"
    )
# Response body for POST /generate:
#   output     - the generated continuation, whitespace-stripped
#   prompt     - the exact prompt string sent to the model
#   parameters - echo of the request fields used for generation
class GenerateResponse(BaseModel):
    output: str = Field(...)
    prompt: str = Field(...)
    parameters: Dict[str, Any] = Field(...)
@app.get("/")
def root():
    """Describe the service: usage hint, configured model id and inference device."""
    # `device` follows the pipeline convention: 0 means CUDA, anything else CPU.
    device_name = "cpu" if device != 0 else "cuda"
    return {
        "message": "Check-in GPT-2 API (POST /generate). Swagger: /docs",
        "model": MODEL_ID,
        "device": device_name,
    }
@app.get("/health")
def health():
    """Liveness probe: always answers with a static ok status."""
    status_payload = {"status": "ok"}
    return status_payload
@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    """Expand a short check-in line with the text-generation pipeline.

    Args:
        req: The check-in text plus sampling parameters forwarded to the pipeline.

    Returns:
        GenerateResponse with the stripped continuation of the first generated
        sequence, the prompt that was sent, and an echo of the request fields.

    Raises:
        HTTPException: 500 with the underlying error message if generation fails.
    """
    prompt = make_prompt(req.input)
    try:
        gen = pipe(
            prompt,
            max_new_tokens=req.max_new_tokens,
            temperature=req.temperature,
            top_p=req.top_p,
            top_k=req.top_k,
            repetition_penalty=req.repetition_penalty,
            do_sample=req.do_sample,
            num_return_sequences=req.num_return_sequences,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            # Return only the continuation. Splitting the full text on "OUTPUT:"
            # broke when the user's input itself contained "OUTPUT:".
            return_full_text=False,
        )
        # Only the first sequence is returned even if several were requested.
        output = gen[0]["generated_text"].strip()
    except Exception as e:
        # Boundary handler: keep the traceback in the server logs, not just the detail.
        logging.getLogger(__name__).exception("Text generation failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    return GenerateResponse(output=output, prompt=prompt, parameters=req.model_dump())