ethnmcl commited on
Commit
c7d9782
·
verified ·
1 Parent(s): 6a6674d

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +89 -0
main.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Dict, Any
3
+ from fastapi import FastAPI, HTTPException
4
+ from fastapi.middleware.cors import CORSMiddleware
5
+ from pydantic import BaseModel, Field
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
7
+ import torch
8
+
9
# --- Configuration (overridable via environment variables) ---
MODEL_ID = os.getenv("MODEL_ID", "ethnmcl/checkin-gpt2")
HF_TOKEN = os.getenv("HF_TOKEN")  # set in Space Secrets if repo is private
PORT = int(os.getenv("PORT", "7860"))

app = FastAPI(title="Check-in GPT-2 API", version="1.0.0")

# Allow your frontend(s).
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# maximally permissive; per the CORS spec a wildcard origin cannot be sent
# with credentialed responses, so browsers may reject credentialed requests.
# If cookies/auth headers are needed, list the actual frontend origins here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- Load model once at import time so requests don't pay the startup cost ---
# transformers pipeline convention: device 0 = first GPU, -1 = CPU.
device = 0 if torch.cuda.is_available() else -1

# `token=` replaces the deprecated `use_auth_token=` argument
# (deprecated since transformers 4.32, removed in v5).
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
if tokenizer.pad_token is None:
    # GPT-2 ships without a pad token; reuse EOS so padding works.
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, token=HF_TOKEN)

pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=device,
)
34
+
35
# Template pieces matching the fine-tuning format of the check-in model.
PREFIX = "INPUT: "
SUFFIX = "\nOUTPUT:"

def make_prompt(user_input: str) -> str:
    """Wrap *user_input* in the INPUT/OUTPUT template the model was trained on."""
    return PREFIX + user_input + SUFFIX
40
+
41
class GenerateRequest(BaseModel):
    """Request body for POST /generate: the input line plus sampling knobs."""

    input: str = Field(..., min_length=1, description="Short check-in line to expand")
    max_new_tokens: int = 180  # cap on generated tokens (prompt excluded)
    temperature: float = 0.7  # softmax temperature; lower = more deterministic
    top_p: float = 0.95  # nucleus-sampling probability cutoff
    top_k: int = 50  # sample only from the k most likely tokens
    repetition_penalty: float = 1.05  # >1.0 discourages repeating tokens
    do_sample: bool = True  # False = greedy decoding (sampling knobs then unused)
    num_return_sequences: int = 1  # candidates generated; only the first is returned
50
+
51
class GenerateResponse(BaseModel):
    """Response body for POST /generate."""

    output: str  # text after the "OUTPUT:" marker, stripped
    prompt: str  # the full prompt that was fed to the model
    parameters: Dict[str, Any]  # echo of the request's generation parameters
55
+
56
+ @app.get("/")
57
+ def root():
58
+ return {
59
+ "message": "Check-in GPT-2 API (POST /generate). Swagger: /docs",
60
+ "model": MODEL_ID,
61
+ "device": "cuda" if device == 0 else "cpu"
62
+ }
63
+
64
+ @app.get("/health")
65
+ def health():
66
+ return {"status": "ok"}
67
+
68
@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    """Expand a short check-in line into a longer entry with the loaded model.

    Builds the INPUT/OUTPUT prompt, runs the text-generation pipeline with the
    request's sampling parameters, and returns everything after the first
    "OUTPUT:" marker. Any failure is surfaced as an HTTP 500 whose detail is
    the underlying error message.
    """
    try:
        prompt = make_prompt(req.input)
        results = pipe(
            prompt,
            max_new_tokens=req.max_new_tokens,
            temperature=req.temperature,
            top_p=req.top_p,
            top_k=req.top_k,
            repetition_penalty=req.repetition_penalty,
            do_sample=req.do_sample,
            num_return_sequences=req.num_return_sequences,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            return_full_text=True,
        )
        full_text = results[0]["generated_text"]
        # The model's answer is whatever follows the first "OUTPUT:" marker.
        answer = full_text.split("OUTPUT:", 1)[-1].strip()
        return GenerateResponse(
            output=answer,
            prompt=prompt,
            parameters=req.model_dump(),
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))