import os
from typing import Any, Dict

import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Configuration (overridable via environment / Space Secrets).
MODEL_ID = os.getenv("MODEL_ID", "ethnmcl/checkin-gpt2")
HF_TOKEN = os.getenv("HF_TOKEN")  # set in Space Secrets if repo is private
PORT = int(os.getenv("PORT", "7860"))  # NOTE(review): unused below; presumably read by the launcher — confirm

app = FastAPI(title="Check-in GPT-2 API", version="1.0.0")

# Allow your frontend(s).
# NOTE(review): browsers reject the wildcard origin "*" when
# allow_credentials=True; pin concrete origins if cookies/auth are needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load the model once at import time so every request reuses it.
device = 0 if torch.cuda.is_available() else -1
# NOTE(review): use_auth_token is deprecated in recent transformers
# releases in favor of token=...; kept as-is for compatibility.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_auth_token=HF_TOKEN)
if tokenizer.pad_token is None:
    # GPT-2 ships without a pad token; reuse EOS so padding works.
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, use_auth_token=HF_TOKEN)

pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=device,
)

# Prompt template the model was fine-tuned on.
PREFIX = "INPUT: "
SUFFIX = "\nOUTPUT:"


def make_prompt(user_input: str) -> str:
    """Wrap a raw check-in line in the INPUT/OUTPUT prompt template."""
    return f"{PREFIX}{user_input}{SUFFIX}"


class GenerateRequest(BaseModel):
    """Request body for POST /generate (sampling knobs have sane defaults)."""

    input: str = Field(..., min_length=1, description="Short check-in line to expand")
    max_new_tokens: int = 180
    temperature: float = 0.7
    top_p: float = 0.95
    top_k: int = 50
    repetition_penalty: float = 1.05
    do_sample: bool = True
    num_return_sequences: int = 1


class GenerateResponse(BaseModel):
    """Response body: generated text plus the prompt and parameters used."""

    output: str
    prompt: str
    parameters: Dict[str, Any]


@app.get("/")
def root():
    """Service banner with model and device info.

    BUG FIX: the "message" string previously contained a raw line break
    inside a plain (non-triple-quoted) string literal — a syntax error.
    """
    return {
        "message": "Check-in GPT-2 API (POST /generate). Swagger: /docs",
        "model": MODEL_ID,
        "device": "cuda" if device == 0 else "cpu",
    }


@app.get("/health")
def health():
    """Liveness probe."""
    return {"status": "ok"}


@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    """Expand a short check-in line into a full generated entry.

    Returns the text after the "OUTPUT:" marker, the full prompt sent to
    the model, and the sampling parameters used. Any failure is surfaced
    as an HTTP 500 whose detail is the underlying error message.
    """
    try:
        prompt = make_prompt(req.input)
        gen = pipe(
            prompt,
            max_new_tokens=req.max_new_tokens,
            temperature=req.temperature,
            top_p=req.top_p,
            top_k=req.top_k,
            repetition_penalty=req.repetition_penalty,
            do_sample=req.do_sample,
            num_return_sequences=req.num_return_sequences,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            return_full_text=True,
        )
        # return_full_text=True echoes the prompt, so keep only the text
        # after the first "OUTPUT:" marker.
        text = gen[0]["generated_text"]
        output = text.split("OUTPUT:", 1)[-1].strip()
        return GenerateResponse(output=output, prompt=prompt, parameters=req.model_dump())
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))