Adedoyinjames committed on
Commit
1514e88
·
verified ·
1 Parent(s): e3e42ae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -11
app.py CHANGED
@@ -2,23 +2,23 @@ import os
2
  from fastapi import FastAPI
3
  from pydantic import BaseModel
4
  from typing import Optional
 
5
 
6
- MODEL_ID = os.getenv("MODEL_ID", "distilgpt2") # smaller model for CPU
7
  CACHE_DIR = os.getenv("HF_HOME", "/app/.cache")
8
  os.makedirs(CACHE_DIR, exist_ok=True)
9
 
10
  app = FastAPI(title="FastAPI Hugging Face Space")
11
- generator = None
 
12
 
13
  class GenerateRequest(BaseModel):
14
  prompt: str
15
  max_length: Optional[int] = 64
16
 
17
- @app.on_event("startup")
18
- async def load_model():
19
- global generator
20
- from transformers import pipeline
21
- generator = pipeline("text-generation", model=MODEL_ID, cache_dir=CACHE_DIR)
22
 
23
  @app.get("/health")
24
  async def health():
@@ -26,7 +26,9 @@ async def health():
26
 
27
  @app.post("/generate")
28
  async def generate(req: GenerateRequest):
29
- if generator is None:
30
- return {"error": "model not loaded yet"}
31
- out = generator(req.prompt, max_length=req.max_length, num_return_sequences=1)
32
- return {"generated_text": out[0]["generated_text"]}
 
 
 
2
  from fastapi import FastAPI
3
  from pydantic import BaseModel
4
  from typing import Optional
5
+ from transformers import pipeline
6
 
7
# Model and cache configuration, overridable via environment variables.
# distilgpt2 is the default because it is small enough for CPU-only Spaces.
MODEL_ID = os.getenv("MODEL_ID", "distilgpt2")
CACHE_DIR = os.getenv("HF_HOME", "/app/.cache")
os.makedirs(CACHE_DIR, exist_ok=True)  # ensure the cache dir exists before the model download

app = FastAPI(title="FastAPI Hugging Face Space")

# Load the text-generation pipeline eagerly at import time so the first
# request does not pay the model-load cost (replaces the old on-startup hook).
# NOTE(review): this blocks process startup until the model is downloaded —
# confirm that is acceptable for the Space's health-check timeout.
generator = pipeline("text-generation", model=MODEL_ID, cache_dir=CACHE_DIR)
14
 
15
class GenerateRequest(BaseModel):
    """Request body for POST /generate."""

    # Text prompt fed to the text-generation pipeline.
    prompt: str
    # Cap on the generated sequence length passed straight to the pipeline
    # as `max_length`; defaults to 64.
    max_length: Optional[int] = 64
18
 
19
@app.get("/")
async def root():
    """Landing route so GET / returns a hint instead of a 404.

    Returns a static JSON message pointing clients at POST /generate.
    """
    return {"message": "API running. Use POST /generate to generate text."}
 
 
22
 
23
  @app.get("/health")
24
  async def health():
 
26
 
27
@app.post("/generate")
async def generate(req: GenerateRequest):
    """Generate text from ``req.prompt`` with the module-level pipeline.

    Returns a JSON object with the single generated sequence
    (``num_return_sequences=1``, so only ``result[0]`` exists).
    NOTE(review): ``max_length`` counts prompt tokens plus generated tokens —
    confirm callers don't expect ``max_new_tokens`` semantics instead.
    """
    result = generator(req.prompt, max_length=req.max_length, num_return_sequences=1)
    return {"generated_text": result[0]["generated_text"]}
31
+
32
# Local/dev entry point; port 7860 is the port Hugging Face Spaces expects.
if __name__ == "__main__":
    import uvicorn  # local import: only needed when run as a script

    uvicorn.run(app, host="0.0.0.0", port=7860)