import os
from typing import Dict, Any

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
MODEL_ID = os.getenv("MODEL_ID", "ethnmcl/checkin-gpt2")
HF_TOKEN = os.getenv("HF_TOKEN")  # set in Space Secrets if repo is private
PORT = int(os.getenv("PORT", "7860"))

app = FastAPI(title="Check-in GPT-2 API", version="1.0.0")
# Allow your frontend(s)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load model and tokenizer once at startup
device = 0 if torch.cuda.is_available() else -1
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, token=HF_TOKEN)
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=device,
)
PREFIX = "INPUT: "
SUFFIX = "\nOUTPUT:"

def make_prompt(user_input: str) -> str:
    return f"{PREFIX}{user_input}{SUFFIX}"
class GenerateRequest(BaseModel):
    input: str = Field(..., min_length=1, description="Short check-in line to expand")
    max_new_tokens: int = 180
    temperature: float = 0.7
    top_p: float = 0.95
    top_k: int = 50
    repetition_penalty: float = 1.05
    do_sample: bool = True
    num_return_sequences: int = 1
class GenerateResponse(BaseModel):
    output: str
    prompt: str
    parameters: Dict[str, Any]
@app.get("/")
def root():
    return {
        "message": "Check-in GPT-2 API (POST /generate). Swagger: /docs",
        "model": MODEL_ID,
        "device": "cuda" if device == 0 else "cpu",
    }
# Simple health-check endpoint
@app.get("/health")
def health():
    return {"status": "ok"}
@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    try:
        prompt = make_prompt(req.input)
        gen = pipe(
            prompt,
            max_new_tokens=req.max_new_tokens,
            temperature=req.temperature,
            top_p=req.top_p,
            top_k=req.top_k,
            repetition_penalty=req.repetition_penalty,
            do_sample=req.do_sample,
            num_return_sequences=req.num_return_sequences,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            return_full_text=True,
        )
        text = gen[0]["generated_text"]
        # Keep only the completion after the "OUTPUT:" marker
        output = text.split("OUTPUT:", 1)[-1].strip()
        return GenerateResponse(output=output, prompt=prompt, parameters=req.model_dump())
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
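
# Entrypoint sketch (not part of the original file): PORT is read above but never
# used, so this assumes the Space starts the server by running this module directly
# with uvicorn. A Docker- or SDK-managed Space may launch uvicorn externally instead.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=PORT)

# Example call once the server is listening (hypothetical input values):
#   curl -X POST http://localhost:7860/generate \
#     -H "Content-Type: application/json" \
#     -d '{"input": "slept badly, low energy", "max_new_tokens": 120}'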