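"""Check-in GPT-2 API.

FastAPI service for the fine-tuned ethnmcl/checkin-gpt2 model, meant to run
in a Hugging Face Space: POST /generate wraps a short check-in line in the
model's "INPUT: ... OUTPUT:" prompt format and returns the generated
expansion.
"""
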
import os
from typing import Dict, Any
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch

MODEL_ID = os.getenv("MODEL_ID", "ethnmcl/checkin-gpt2")
HF_TOKEN = os.getenv("HF_TOKEN")  # set in Space Secrets if repo is private
PORT = int(os.getenv("PORT", "7860"))  # Spaces default; used by the entrypoint below

app = FastAPI(title="Check-in GPT-2 API", version="1.0.0")

# Allow your frontend(s). Browsers reject a wildcard origin when credentials
# are allowed, so list explicit origins here in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load the model once at import time so every request reuses it
device = 0 if torch.cuda.is_available() else -1  # pipeline device index: GPU 0 or CPU
# `use_auth_token` is deprecated in recent transformers; `token` replaces it
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, token=HF_TOKEN)

pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=device
)

PREFIX = "INPUT: "
SUFFIX = "\nOUTPUT:"

def make_prompt(user_input: str) -> str:
    """Wrap a raw check-in line in the model's INPUT:/OUTPUT: prompt format."""
    return f"{PREFIX}{user_input}{SUFFIX}"
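# e.g. make_prompt("slept badly, low energy")  # hypothetical check-in line
#   -> 'INPUT: slept badly, low energy\nOUTPUT:'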

class GenerateRequest(BaseModel):
    input: str = Field(..., min_length=1, description="Short check-in line to expand")
    # Sampling defaults; each one can be overridden per request
    max_new_tokens: int = 180
    temperature: float = 0.7
    top_p: float = 0.95
    top_k: int = 50
    repetition_penalty: float = 1.05
    do_sample: bool = True
    num_return_sequences: int = 1

class GenerateResponse(BaseModel):
    output: str
    prompt: str
    parameters: Dict[str, Any]

@app.get("/")
def root():
    return {
        "message": "Check-in GPT-2 API (POST /generate). Swagger: /docs",
        "model": MODEL_ID,
        "device": "cuda" if device == 0 else "cpu"
    }

@app.get("/health")
def health():
    return {"status": "ok"}
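
# Example request (hypothetical input; any GenerateRequest field can be
# overridden in the JSON body):
#   curl -s -X POST http://localhost:7860/generate \
#     -H "Content-Type: application/json" \
#     -d '{"input": "slept badly, skipped my walk", "max_new_tokens": 120}'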

@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    try:
        prompt = make_prompt(req.input)
        gen = pipe(
            prompt,
            max_new_tokens=req.max_new_tokens,
            temperature=req.temperature,
            top_p=req.top_p,
            top_k=req.top_k,
            repetition_penalty=req.repetition_penalty,
            do_sample=req.do_sample,
            num_return_sequences=req.num_return_sequences,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            return_full_text=True
        )
        # return_full_text=True keeps the prompt, so slice off everything up
        # to the OUTPUT: marker; only the first sequence is returned even if
        # num_return_sequences > 1
        text = gen[0]["generated_text"]
        output = text.split("OUTPUT:", 1)[-1].strip()
        return GenerateResponse(output=output, prompt=prompt, parameters=req.model_dump())
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
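

if __name__ == "__main__":
    # Minimal entrypoint so the file runs standalone; assumes uvicorn is
    # installed (the usual ASGI server for FastAPI). PORT is read from the
    # environment above and defaults to 7860, the Hugging Face Spaces port.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=PORT)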