Spaces:
Sleeping
Sleeping
File size: 2,753 Bytes
c7d9782 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import os
from typing import Dict, Any
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
# Runtime configuration, read from the environment once at import time.

# Model identifier passed to the `from_pretrained` loaders further down.
MODEL_ID = os.environ.get("MODEL_ID", "ethnmcl/checkin-gpt2")

# Hub access token; set in Space Secrets if repo is private. None when unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Port number, parsed to int. NOTE(review): not referenced in the visible part
# of this file -- presumably consumed by whatever launches the server.
PORT = int(os.environ.get("PORT", "7860"))
app = FastAPI(title="Check-in GPT-2 API", version="1.0.0")

# Allow your frontend(s).
# Origins are read from the comma-separated ALLOWED_ORIGINS environment
# variable; when it is unset or empty every origin is allowed ("*").
ALLOWED_ORIGINS = [
    origin.strip()
    for origin in os.getenv("ALLOWED_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    # A wildcard origin must not be combined with credentials: it would let
    # any website make credentialed cross-origin requests. Credentials are
    # therefore enabled only when an explicit origin list is configured.
    allow_credentials="*" not in ALLOWED_ORIGINS,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load the tokenizer, model and generation pipeline once, at import time, so
# that every request reuses the same objects.

# Device index handed to the pipeline: 0 when CUDA is available, otherwise -1.
device = 0 if torch.cuda.is_available() else -1

# `token=` is the current spelling of the Hub auth argument; the older
# `use_auth_token=` is deprecated and scheduled for removal in Transformers v5.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
if tokenizer.pad_token is None:
    # The tokenizer defines no pad token; fall back to the EOS token.
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(MODEL_ID, token=HF_TOKEN)

pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=device,
)
# Fixed markers placed before and after the user's text to build the prompt.
PREFIX = "INPUT: "
SUFFIX = "\nOUTPUT:"


def make_prompt(user_input: str) -> str:
    """Wrap *user_input* between ``PREFIX`` and ``SUFFIX``.

    >>> make_prompt("hi")
    'INPUT: hi\\nOUTPUT:'
    """
    return "{}{}{}".format(PREFIX, user_input, SUFFIX)
class GenerateRequest(BaseModel):
    """Request body for ``POST /generate``.

    ``input`` is the text to expand. The remaining fields are forwarded to the
    text-generation pipeline as generation parameters. Bounds are declared so
    that values the generation code cannot use (zero or negative token counts,
    ``top_p`` outside [0, 1], ...) are rejected with a 422 validation error
    instead of surfacing as a 500 from inside the model call.
    """

    input: str = Field(..., min_length=1, description="Short check-in line to expand")
    # NOTE(review): no upper bound is enforced on max_new_tokens; consider
    # capping it to the served model's context window.
    max_new_tokens: int = Field(180, ge=1, description="Maximum number of tokens to generate")
    temperature: float = Field(0.7, ge=0.0, description="Sampling temperature")
    top_p: float = Field(0.95, ge=0.0, le=1.0, description="Nucleus sampling probability mass")
    top_k: int = Field(50, ge=0, description="Top-k sampling cutoff")
    repetition_penalty: float = Field(1.05, gt=0.0, description="Penalty applied to repeated tokens")
    do_sample: bool = True
    # The /generate handler returns only the first generated sequence.
    num_return_sequences: int = Field(1, ge=1, description="Number of sequences to generate")
# Response body of POST /generate.
# (Documented with comments rather than a class docstring so the generated
# OpenAPI schema stays exactly as it was.)
class GenerateResponse(BaseModel):
    # Generated text extracted from the model's output.
    output: str
    # The full prompt that was sent to the model.
    prompt: str
    # Echo of the request fields, as produced by `req.model_dump()`.
    parameters: Dict[str, Any]
@app.get("/")
def root():
    # Landing endpoint: returns a usage hint, the configured model id and the
    # device the pipeline was created on.
    # (Comment instead of a docstring so the OpenAPI description is unchanged.)
    device_label = "cpu"
    if device == 0:
        device_label = "cuda"
    return dict(
        message="Check-in GPT-2 API (POST /generate). Swagger: /docs",
        model=MODEL_ID,
        device=device_label,
    )
@app.get("/health")
def health():
    # Health endpoint: returns a constant payload and does not touch the model.
    # (Comment instead of a docstring so the OpenAPI description is unchanged.)
    status_payload = {"status": "ok"}
    return status_payload
def _extract_output(generated_text: str, prompt: str) -> str:
    """Return the model's completion, i.e. *generated_text* without *prompt*.

    The pipeline is called with ``return_full_text=True``, so the text is
    expected to begin with the prompt. Removing that exact prefix, rather than
    splitting on the first "OUTPUT:", keeps the result correct when the user's
    own input contains the marker.
    """
    if generated_text.startswith(prompt):
        completion = generated_text[len(prompt):]
    else:
        # Fallback for text that does not start with the prompt verbatim:
        # keep whatever follows the first output marker.
        completion = generated_text.split(SUFFIX.strip(), 1)[-1]
    return completion.strip()


@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    """Expand a short check-in line with the text-generation pipeline.

    Builds the prompt from ``req.input``, runs the pipeline with the request's
    generation parameters, and returns the first generated sequence with the
    prompt stripped off.

    Raises:
        HTTPException: status 500, with the underlying error message as
            detail, when the pipeline call fails.
    """
    prompt = make_prompt(req.input)
    try:
        gen = pipe(
            prompt,
            max_new_tokens=req.max_new_tokens,
            temperature=req.temperature,
            top_p=req.top_p,
            top_k=req.top_k,
            repetition_penalty=req.repetition_penalty,
            do_sample=req.do_sample,
            num_return_sequences=req.num_return_sequences,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            return_full_text=True,
        )
        # Only the first sequence is returned, whatever num_return_sequences is.
        text = gen[0]["generated_text"]
    except Exception as e:
        # Top-level boundary: turn any generation failure into an HTTP 500
        # while keeping the original exception chained for server-side logs.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return GenerateResponse(
        output=_extract_output(text, prompt),
        prompt=prompt,
        parameters=req.model_dump(),
    )
|