videopix commited on
Commit
5cc75c9
·
verified ·
1 Parent(s): 483291c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -18
app.py CHANGED
@@ -1,28 +1,43 @@
 
1
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException
2
  from fastapi.responses import StreamingResponse, HTMLResponse
3
  from PIL import Image
4
- import torch, numpy as np
 
5
  from transformers import AutoModelForImageSegmentation
6
  from io import BytesIO
7
  from loadimg import load_img
8
- import os
9
 
10
  # -------------------------
11
  # Model Setup
12
  # -------------------------
13
  MODEL_DIR = "models/BiRefNet"
14
  os.makedirs(MODEL_DIR, exist_ok=True)
15
-
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
 
18
- print("Loading BiRefNet model...")
19
- birefnet = AutoModelForImageSegmentation.from_pretrained(
20
- "ZhengPeng7/BiRefNet",
21
- cache_dir=MODEL_DIR,
22
- trust_remote_code=True
23
- )
24
- birefnet.to(device).eval()
25
- print("Model loaded.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  # -------------------------
28
  # Image Preprocessing
@@ -49,12 +64,7 @@ def process_image(image: Image.Image) -> Image.Image:
49
  return image
50
 
51
  # -------------------------
52
- # FastAPI App
53
- # -------------------------
54
- app = FastAPI(title="Background Removal API")
55
-
56
- # -------------------------
57
- # Single /remove-background endpoint
58
  # -------------------------
59
  @app.post("/remove-background")
60
  async def remove_background(file: UploadFile = File(None), image_url: str = Form(None)):
@@ -79,7 +89,7 @@ async def remove_background(file: UploadFile = File(None), image_url: str = Form
79
  raise HTTPException(status_code=500, detail=str(e))
80
 
81
  # -------------------------
82
- # Simple HTML frontend
83
  # -------------------------
84
  @app.get("/", response_class=HTMLResponse)
85
  async def index():
 
1
+ import os
2
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException
3
  from fastapi.responses import StreamingResponse, HTMLResponse
4
  from PIL import Image
5
+ import torch
6
+ import numpy as np
7
  from transformers import AutoModelForImageSegmentation
8
  from io import BytesIO
9
  from loadimg import load_img
 
10
 
11
  # -------------------------
12
  # Model Setup
13
  # -------------------------
14
  MODEL_DIR = "models/BiRefNet"
15
  os.makedirs(MODEL_DIR, exist_ok=True)
 
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
 
18
+ birefnet = None # will initialize on startup
19
+
20
+ # -------------------------
21
+ # FastAPI App
22
+ # -------------------------
23
+ app = FastAPI(title="Background Removal API")
24
+
25
+ # -------------------------
26
+ # Startup Event
27
+ # -------------------------
28
+ @app.on_event("startup")
29
+ async def load_model():
30
+ global birefnet
31
+ if birefnet is None:
32
+ print("Loading BiRefNet model...")
33
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
34
+ "ZhengPeng7/BiRefNet",
35
+ cache_dir=MODEL_DIR,
36
+ trust_remote_code=True,
37
+ revision="main" # pin revision to avoid unexpected updates
38
+ )
39
+ birefnet.to(device).eval()
40
+ print("Model loaded successfully.")
41
 
42
  # -------------------------
43
  # Image Preprocessing
 
64
  return image
65
 
66
  # -------------------------
67
+ # Remove Background Endpoint
 
 
 
 
 
68
  # -------------------------
69
  @app.post("/remove-background")
70
  async def remove_background(file: UploadFile = File(None), image_url: str = Form(None)):
 
89
  raise HTTPException(status_code=500, detail=str(e))
90
 
91
  # -------------------------
92
+ # Web Interface
93
  # -------------------------
94
  @app.get("/", response_class=HTMLResponse)
95
  async def index():