Adedoyinjames committed on
Commit
5ea1723
·
verified ·
1 Parent(s): f5af1f2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -12
app.py CHANGED
@@ -1,30 +1,37 @@
1
# --- service configuration and eager model load (old version) ---
from os import getenv

from fastapi import FastAPI
from pydantic import BaseModel
from typing import Optional
from transformers import pipeline

# Model is selectable through the MODEL_ID environment variable.
MODEL_ID = getenv("MODEL_ID", "gpt2")
# Default generation parameters applied to every request.
GEN_KWARGS = {"max_length": 64, "num_return_sequences": 1}

app = FastAPI(title="FastAPI Hugging Face Space")

# Build the text-generation pipeline once at import time so every
# request reuses the same loaded model.
generator = pipeline("text-generation", model=MODEL_ID)
14
 
15
class GenerateRequest(BaseModel):
    """Request body accepted by the /generate endpoint."""

    # Text the model should continue.
    prompt: str
    # Optional per-request override of the generation length cap.
    max_length: Optional[int] = None
18
 
 
 
 
 
 
 
 
19
  @app.get("/health")
20
  async def health():
21
- return {"status": "ok", "model": MODEL_ID}
22
 
23
  @app.post("/generate")
24
  async def generate(req: GenerateRequest):
25
- kwargs = GEN_KWARGS.copy()
26
- if req.max_length:
27
- kwargs["max_length"] = req.max_length
28
- out = generator(req.prompt, **kwargs)
29
- # pipeline returns a list with dicts containing "generated_text"
30
  return {"generated_text": out[0]["generated_text"]}
 
1
# --- service configuration and lazy model state (new version) ---
import os

from fastapi import FastAPI
from pydantic import BaseModel
from typing import Optional

# Model id and cache location are configurable through the environment.
MODEL_ID = os.getenv("MODEL_ID", "gpt2")
CACHE_DIR = os.getenv("TRANSFORMERS_CACHE", "/app/.cache")

# Make sure the cache directory exists before transformers touches it.
os.makedirs(CACHE_DIR, exist_ok=True)

app = FastAPI(title="FastAPI Hugging Face Space")

# Filled in by the startup hook once the pipeline has been loaded.
generator = None
 
15
 
16
class GenerateRequest(BaseModel):
    """Payload schema for POST /generate."""

    # Prompt text to be continued by the model.
    prompt: str
    # If set, overrides the default generation length cap.
    max_length: Optional[int] = None
19
 
20
+ @app.on_event("startup")
21
+ async def load_model():
22
+ global generator
23
+ # import here so transformers uses the configured cache
24
+ from transformers import pipeline
25
+ generator = pipeline("text-generation", model=MODEL_ID)
26
+
27
  @app.get("/health")
28
  async def health():
29
+ return {"status": "ok", "model": MODEL_ID, "cache": CACHE_DIR}
30
 
31
  @app.post("/generate")
32
  async def generate(req: GenerateRequest):
33
+ if generator is None:
34
+ return {"error": "model not loaded yet"}
35
+ max_len = req.max_length or 64
36
+ out = generator(req.prompt, max_length=max_len, num_return_sequences=1)
 
37
  return {"generated_text": out[0]["generated_text"]}