Adedoyinjames committed on
Commit
49784f7
·
verified ·
1 Parent(s): c7b862f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -10
app.py CHANGED
@@ -3,26 +3,22 @@ from fastapi import FastAPI
3
  from pydantic import BaseModel
4
  from typing import Optional
5
 
6
- MODEL_ID = os.getenv("MODEL_ID", "gpt2")
7
- CACHE_DIR = os.getenv("TRANSFORMERS_CACHE", "/app/.cache")
8
-
9
- # ensure cache dir exists
10
  os.makedirs(CACHE_DIR, exist_ok=True)
11
 
12
  app = FastAPI(title="FastAPI Hugging Face Space")
13
-
14
  generator = None
15
 
16
  class GenerateRequest(BaseModel):
17
  prompt: str
18
- max_length: Optional[int] = None
19
 
20
  @app.on_event("startup")
21
  async def load_model():
22
  global generator
23
- # import here so transformers uses the configured cache
24
  from transformers import pipeline
25
- generator = pipeline("text-generation", model=MODEL_ID)
26
 
27
  @app.get("/health")
28
  async def health():
@@ -32,6 +28,5 @@ async def health():
32
  async def generate(req: GenerateRequest):
33
  if generator is None:
34
  return {"error": "model not loaded yet"}
35
- max_len = req.max_length or 64
36
- out = generator(req.prompt, max_length=max_len, num_return_sequences=1)
37
  return {"generated_text": out[0]["generated_text"]}
 
3
  from pydantic import BaseModel
4
  from typing import Optional
5
 
6
# Runtime configuration, overridable via environment variables.
MODEL_ID = os.environ.get("MODEL_ID", "distilgpt2")  # small default model, fine for CPU-only Spaces
CACHE_DIR = os.environ.get("HF_HOME", "/app/.cache")

# Make sure the model cache directory exists before anything downloads into it.
os.makedirs(CACHE_DIR, exist_ok=True)

app = FastAPI(title="FastAPI Hugging Face Space")

# Populated by the startup hook once the transformers pipeline has loaded.
generator = None
12
 
13
class GenerateRequest(BaseModel):
    """Request body for the text-generation endpoint."""

    # Input text the model continues from.
    prompt: str
    # Upper bound (in tokens) for the generated sequence; clients may omit it.
    # NOTE(review): Optional also admits an explicit JSON null — the handler
    # receives None in that case, not 64.
    max_length: Optional[int] = 64
16
 
17
@app.on_event("startup")
async def load_model():
    """Load the Hugging Face text-generation pipeline once at app startup.

    Stores the pipeline in the module-level ``generator`` so request
    handlers can reach it; until this completes, ``generator`` stays None.
    """
    global generator
    # Imported lazily so the cache/env configuration at module top is in
    # effect before transformers is first loaded.
    from transformers import pipeline
    # NOTE(review): `cache_dir` is passed through pipeline kwargs — confirm
    # the installed transformers version accepts it here; otherwise use
    # model_kwargs={"cache_dir": CACHE_DIR}.
    generator = pipeline("text-generation", model=MODEL_ID, cache_dir=CACHE_DIR)
22
 
23
  @app.get("/health")
24
  async def health():
 
28
async def generate(req: GenerateRequest):
    """Generate a continuation of ``req.prompt`` with the loaded pipeline.

    Returns ``{"generated_text": ...}`` on success, or an error dict if the
    startup hook has not finished loading the model yet.
    """
    if generator is None:
        # Startup hook still running (or failed); report instead of crashing.
        return {"error": "model not loaded yet"}
    # max_length is Optional, so a client may send an explicit null; fall
    # back to the documented default of 64 rather than passing None through
    # to the pipeline (this guard existed before and was dropped in the diff).
    max_len = req.max_length or 64
    out = generator(req.prompt, max_length=max_len, num_return_sequences=1)
    return {"generated_text": out[0]["generated_text"]}