"""FastAPI inference service for a fine-tuned GPT-2 "check-in" model.

Loads the tokenizer/model once at import time and exposes:
  GET  /         — service metadata (usage hint, model id, device)
  GET  /health   — liveness probe
  POST /generate — prompted text generation
"""

import os
from typing import Any, Dict

import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Model repo to serve; HF_TOKEN is set in Space Secrets when the repo is private.
MODEL_ID = os.getenv("MODEL_ID", "ethnmcl/checkin-gpt2")
HF_TOKEN = os.getenv("HF_TOKEN")

app = FastAPI(title="Check-in GPT-2 API", version="1.0.0")

# NOTE(review): per the CORS spec, browsers ignore allow_credentials=True when
# allow_origins is the wildcard "*" — pin explicit origins if cookie/auth-header
# credentials are ever needed. Left as-is to preserve current behavior.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# transformers pipeline device convention: 0 = first CUDA GPU, -1 = CPU.
device = 0 if torch.cuda.is_available() else -1

# ✅ use token= (not the deprecated use_auth_token) and rely on
# HF_HOME=/data/huggingface for the cache location.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
if tokenizer.pad_token is None:
    # GPT-2 ships without a pad token; reuse EOS so padding doesn't crash.
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(MODEL_ID, token=HF_TOKEN)

pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=device,
)

# Prompt template the model was fine-tuned on: "INPUT: <text>\nOUTPUT:".
PREFIX = "INPUT: "
SUFFIX = "\nOUTPUT:"


def make_prompt(user_input: str) -> str:
    """Wrap raw user text in the INPUT/OUTPUT template the model expects."""
    return f"{PREFIX}{user_input}{SUFFIX}"


class GenerateRequest(BaseModel):
    """Generation request: the user text plus standard sampling knobs."""

    input: str = Field(..., min_length=1)
    max_new_tokens: int = 180
    temperature: float = 0.7
    top_p: float = 0.95
    top_k: int = 50
    repetition_penalty: float = 1.05
    do_sample: bool = True
    num_return_sequences: int = 1


class GenerateResponse(BaseModel):
    """Generation result: model output, the prompt used, and echoed params."""

    output: str
    prompt: str
    parameters: Dict[str, Any]


@app.get("/")
def root():
    """Service metadata: usage hint, model id, and the active device."""
    # FIX: the message literal was broken by a raw (unescaped) newline in the
    # source — reconstructed here as a single-line string.
    return {
        "message": "Check-in GPT-2 API. POST /generate",
        "model": MODEL_ID,
        "device": "cuda" if device == 0 else "cpu",
    }


@app.get("/health")
def health():
    """Liveness probe."""
    return {"status": "ok"}


@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    """Generate a completion for the wrapped user input.

    Returns only the text after the "OUTPUT:" marker. Any pipeline failure
    surfaces as HTTP 500 carrying the underlying error message.
    """
    try:
        prompt = make_prompt(req.input)
        gen = pipe(
            prompt,
            max_new_tokens=req.max_new_tokens,
            temperature=req.temperature,
            top_p=req.top_p,
            top_k=req.top_k,
            repetition_penalty=req.repetition_penalty,
            do_sample=req.do_sample,
            num_return_sequences=req.num_return_sequences,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            return_full_text=True,
        )
        # return_full_text=True includes the prompt; keep only what follows
        # the "OUTPUT:" marker.
        text = gen[0]["generated_text"]
        output = text.split("OUTPUT:", 1)[-1].strip()
        return GenerateResponse(
            output=output, prompt=prompt, parameters=req.model_dump()
        )
    except Exception as e:  # boundary handler: map any failure to HTTP 500
        raise HTTPException(status_code=500, detail=str(e))