ethnmcl committed on
Commit
adeaf8c
·
verified ·
1 Parent(s): 9fccefa

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +8 -12
main.py CHANGED
@@ -7,23 +7,22 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
7
  import torch
8
 
9
  MODEL_ID = os.getenv("MODEL_ID", "ethnmcl/checkin-gpt2")
10
- HF_TOKEN = os.getenv("HF_TOKEN") # set in Space Secrets if repo is private
11
- PORT = int(os.getenv("PORT", "7860"))
12
 
13
  app = FastAPI(title="Check-in GPT-2 API", version="1.0.0")
14
 
15
- # Allow your frontend(s)
16
  app.add_middleware(
17
  CORSMiddleware,
18
  allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"],
19
  )
20
 
21
- # Load model once
22
  device = 0 if torch.cuda.is_available() else -1
23
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_auth_token=HF_TOKEN)
 
 
24
  if tokenizer.pad_token is None:
25
  tokenizer.pad_token = tokenizer.eos_token
26
- model = AutoModelForCausalLM.from_pretrained(MODEL_ID, use_auth_token=HF_TOKEN)
27
 
28
  pipe = pipeline(
29
  "text-generation",
@@ -39,7 +38,7 @@ def make_prompt(user_input: str) -> str:
39
  return f"{PREFIX}{user_input}{SUFFIX}"
40
 
41
  class GenerateRequest(BaseModel):
42
- input: str = Field(..., min_length=1, description="Short check-in line to expand")
43
  max_new_tokens: int = 180
44
  temperature: float = 0.7
45
  top_p: float = 0.95
@@ -55,11 +54,7 @@ class GenerateResponse(BaseModel):
55
 
56
  @app.get("/")
57
  def root():
58
- return {
59
- "message": "Check-in GPT-2 API (POST /generate). Swagger: /docs",
60
- "model": MODEL_ID,
61
- "device": "cuda" if device == 0 else "cpu"
62
- }
63
 
64
  @app.get("/health")
65
  def health():
@@ -87,3 +82,4 @@ def generate(req: GenerateRequest):
87
  return GenerateResponse(output=output, prompt=prompt, parameters=req.model_dump())
88
  except Exception as e:
89
  raise HTTPException(status_code=500, detail=str(e))
 
 
7
  import torch
8
 
9
  MODEL_ID = os.getenv("MODEL_ID", "ethnmcl/checkin-gpt2")
10
+ HF_TOKEN = os.getenv("HF_TOKEN") # set in Space Secrets if the model repo is private
 
11
 
12
  app = FastAPI(title="Check-in GPT-2 API", version="1.0.0")
13
 
 
14
  app.add_middleware(
15
  CORSMiddleware,
16
  allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"],
17
  )
18
 
 
19
  device = 0 if torch.cuda.is_available() else -1
20
+
21
+ # ✅ use token= (not use_auth_token) and rely on HF_HOME=/data/huggingface
22
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
23
  if tokenizer.pad_token is None:
24
  tokenizer.pad_token = tokenizer.eos_token
25
+ model = AutoModelForCausalLM.from_pretrained(MODEL_ID, token=HF_TOKEN)
26
 
27
  pipe = pipeline(
28
  "text-generation",
 
38
  return f"{PREFIX}{user_input}{SUFFIX}"
39
 
40
  class GenerateRequest(BaseModel):
41
+ input: str = Field(..., min_length=1)
42
  max_new_tokens: int = 180
43
  temperature: float = 0.7
44
  top_p: float = 0.95
 
54
 
55
@app.get("/")
def root():
    """Landing endpoint.

    Returns a small JSON status object: a usage hint for the generate
    endpoint, the configured model id, and which device inference runs on.
    """
    # device == 0 means a CUDA GPU was selected at startup; -1 means CPU.
    active_device = "cuda" if device == 0 else "cpu"
    return {
        "message": "Check-in GPT-2 API. POST /generate",
        "model": MODEL_ID,
        "device": active_device,
    }
 
 
 
 
58
 
59
  @app.get("/health")
60
  def health():
 
82
  return GenerateResponse(output=output, prompt=prompt, parameters=req.model_dump())
83
  except Exception as e:
84
  raise HTTPException(status_code=500, detail=str(e))
85
+