Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -16,7 +16,7 @@ from functools import partial
|
|
| 16 |
# -------------------------
|
| 17 |
MODEL_DIR = "models/BiRefNet"
|
| 18 |
os.makedirs(MODEL_DIR, exist_ok=True)
|
| 19 |
-
device = "
|
| 20 |
birefnet = None # will initialize on startup
|
| 21 |
|
| 22 |
# -------------------------
|
|
@@ -26,7 +26,7 @@ birefnet = None # will initialize on startup
|
|
| 26 |
async def lifespan(app: FastAPI):
|
| 27 |
global birefnet
|
| 28 |
if birefnet is None:
|
| 29 |
-
print("Loading BiRefNet model...")
|
| 30 |
birefnet = AutoModelForImageSegmentation.from_pretrained(
|
| 31 |
"ZhengPeng7/BiRefNet",
|
| 32 |
cache_dir=MODEL_DIR,
|
|
@@ -72,10 +72,6 @@ def process_image(image: Image.Image) -> Image.Image:
|
|
| 72 |
# -------------------------
|
| 73 |
@app.post("/remove-background")
|
| 74 |
async def remove_background(file: UploadFile = File(None), image_url: str = Form(None)):
|
| 75 |
-
"""
|
| 76 |
-
Accept either an uploaded file or an image URL.
|
| 77 |
-
Returns PNG with transparent background.
|
| 78 |
-
"""
|
| 79 |
try:
|
| 80 |
if file:
|
| 81 |
image = Image.open(BytesIO(await file.read())).convert("RGB")
|
|
|
|
| 16 |
# -------------------------
|
| 17 |
MODEL_DIR = "models/BiRefNet"
|
| 18 |
os.makedirs(MODEL_DIR, exist_ok=True)
|
| 19 |
+
device = "cpu" # force CPU usage
|
| 20 |
birefnet = None # will initialize on startup
|
| 21 |
|
| 22 |
# -------------------------
|
|
|
|
| 26 |
async def lifespan(app: FastAPI):
|
| 27 |
global birefnet
|
| 28 |
if birefnet is None:
|
| 29 |
+
print("Loading BiRefNet model on CPU...")
|
| 30 |
birefnet = AutoModelForImageSegmentation.from_pretrained(
|
| 31 |
"ZhengPeng7/BiRefNet",
|
| 32 |
cache_dir=MODEL_DIR,
|
|
|
|
| 72 |
# -------------------------
|
| 73 |
@app.post("/remove-background")
|
| 74 |
async def remove_background(file: UploadFile = File(None), image_url: str = Form(None)):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
try:
|
| 76 |
if file:
|
| 77 |
image = Image.open(BytesIO(await file.read())).convert("RGB")
|