videopix commited on
Commit
1bfaffb
·
verified ·
1 Parent(s): d8cb31a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -10
app.py CHANGED
@@ -7,7 +7,6 @@ import numpy as np
7
  from transformers import AutoModelForImageSegmentation
8
  from io import BytesIO
9
  from loadimg import load_img
10
- from contextlib import asynccontextmanager
11
 
12
  # -------------------------
13
  # Model Setup
@@ -16,13 +15,18 @@ MODEL_DIR = "models/BiRefNet"
16
  os.makedirs(MODEL_DIR, exist_ok=True)
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
 
19
- birefnet = None # Will be initialized in lifespan
20
 
21
  # -------------------------
22
- # Lifespan (startup/shutdown)
23
  # -------------------------
24
- @asynccontextmanager
25
- async def lifespan(app: FastAPI):
 
 
 
 
 
26
  global birefnet
27
  if birefnet is None:
28
  print("Loading BiRefNet model...")
@@ -34,10 +38,6 @@ async def lifespan(app: FastAPI):
34
  )
35
  birefnet.to(device).eval()
36
  print("Model loaded successfully.")
37
- yield # Hand control back to FastAPI
38
- # No special shutdown actions needed
39
-
40
- app = FastAPI(title="Background Removal API", lifespan=lifespan)
41
 
42
  # -------------------------
43
  # Image Preprocessing
@@ -64,7 +64,7 @@ def process_image(image: Image.Image) -> Image.Image:
64
  return image
65
 
66
  # -------------------------
67
- # API Endpoint
68
  # -------------------------
69
  @app.post("/remove-background")
70
  async def remove_background(file: UploadFile = File(None), image_url: str = Form(None)):
 
7
  from transformers import AutoModelForImageSegmentation
8
  from io import BytesIO
9
  from loadimg import load_img
 
10
 
11
  # -------------------------
12
  # Model Setup
 
15
  os.makedirs(MODEL_DIR, exist_ok=True)
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
 
18
+ birefnet = None # will initialize on startup
19
 
20
  # -------------------------
21
+ # FastAPI App
22
  # -------------------------
23
+ app = FastAPI(title="Background Removal API")
24
+
25
+ # -------------------------
26
+ # Startup Event
27
+ # -------------------------
28
+ @app.on_event("startup")
29
+ async def load_model():
30
  global birefnet
31
  if birefnet is None:
32
  print("Loading BiRefNet model...")
 
38
  )
39
  birefnet.to(device).eval()
40
  print("Model loaded successfully.")
 
 
 
 
41
 
42
  # -------------------------
43
  # Image Preprocessing
 
64
  return image
65
 
66
  # -------------------------
67
+ # Remove Background Endpoint
68
  # -------------------------
69
  @app.post("/remove-background")
70
  async def remove_background(file: UploadFile = File(None), image_url: str = Form(None)):