um41r commited on
Commit
b423626
·
verified ·
1 Parent(s): 0b06727

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -20
app.py CHANGED
@@ -1,57 +1,87 @@
 
1
  import torch
2
- import gradio as gr
3
  import numpy as np
4
  from PIL import Image
5
  from safetensors.torch import load_file
 
 
 
 
 
 
 
6
 
7
  from birefnet import BiRefNet
8
  from BiRefNet_config import BiRefNetConfig
9
 
10
-
11
  device = "cpu"
12
-
13
- # Load config + model
14
  config = BiRefNetConfig()
15
  model = BiRefNet(config)
16
-
17
  state_dict = load_file("model.safetensors")
18
  model.load_state_dict(state_dict, strict=False)
19
-
20
  model.to(device)
21
  model.eval()
22
-
23
  print("✅ BiRefNet Lite loaded")
24
 
25
-
26
  def preprocess(img: Image.Image):
27
  img = img.convert("RGB").resize((1024, 1024))
28
  arr = np.array(img).astype(np.float32) / 255.0
29
  arr = arr.transpose(2, 0, 1)
30
  return torch.from_numpy(arr).unsqueeze(0)
31
 
32
-
33
  @torch.no_grad()
34
- def remove_bg(image):
35
- image = Image.fromarray(image)
36
  x = preprocess(image).to(device)
37
-
38
  pred = model(x)[0]
39
  pred = torch.sigmoid(pred)
40
-
41
  mask = pred.squeeze().cpu().numpy()
42
  mask = (mask * 255).astype(np.uint8)
43
  mask = Image.fromarray(mask).resize(image.size)
44
-
45
  out = image.convert("RGBA")
46
  out.putalpha(mask)
47
  return out
48
 
 
49
 
50
- demo = gr.Interface(
51
- fn=remove_bg,
52
- inputs=gr.Image(type="numpy"),
53
- outputs=gr.Image(type="pil"),
54
- title="BiRefNet Lite – Background Remover",
 
 
55
  )
56
 
57
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py - FastAPI BiRefNet Background Remover API
2
  import torch
 
3
  import numpy as np
4
  from PIL import Image
5
  from safetensors.torch import load_file
6
+ import io
7
+ from typing import Optional
8
+
9
+ from fastapi import FastAPI, File, UploadFile, HTTPException
10
+ from fastapi.middleware.cors import CORSMiddleware
11
+ from fastapi.responses import StreamingResponse
12
+ import uvicorn
13
 
14
  from birefnet import BiRefNet
15
  from BiRefNet_config import BiRefNetConfig
16
 
17
+ # Global model (loaded once)
18
  device = "cpu"
 
 
19
  config = BiRefNetConfig()
20
  model = BiRefNet(config)
 
21
  state_dict = load_file("model.safetensors")
22
  model.load_state_dict(state_dict, strict=False)
 
23
  model.to(device)
24
  model.eval()
 
25
  print("✅ BiRefNet Lite loaded")
26
 
 
27
  def preprocess(img: Image.Image):
28
  img = img.convert("RGB").resize((1024, 1024))
29
  arr = np.array(img).astype(np.float32) / 255.0
30
  arr = arr.transpose(2, 0, 1)
31
  return torch.from_numpy(arr).unsqueeze(0)
32
 
 
33
  @torch.no_grad()
34
+ def remove_bg(image: Image.Image) -> Image.Image:
 
35
  x = preprocess(image).to(device)
 
36
  pred = model(x)[0]
37
  pred = torch.sigmoid(pred)
 
38
  mask = pred.squeeze().cpu().numpy()
39
  mask = (mask * 255).astype(np.uint8)
40
  mask = Image.fromarray(mask).resize(image.size)
 
41
  out = image.convert("RGBA")
42
  out.putalpha(mask)
43
  return out
44
 
45
+ app = FastAPI(title="BiRefNet Background Remover API")
46
 
47
+ # CORS for NextJS/Vercel
48
+ app.add_middleware(
49
+ CORSMiddleware,
50
+ allow_origins=["*"], # Update with your domain in production
51
+ allow_credentials=True,
52
+ allow_methods=["*"],
53
+ allow_headers=["*"],
54
  )
55
 
56
+ @app.get("/")
57
+ async def root():
58
+ return {"message": "BiRefNet Background Remover API", "status": "ready"}
59
+
60
+ @app.post("/remove-bg")
61
+ async def remove_background(
62
+ file: UploadFile = File(..., description="Image file (PNG/JPG)")
63
+ ):
64
+ if not file.content_type.startswith("image/"):
65
+ raise HTTPException(400, detail="File must be an image")
66
+
67
+ try:
68
+ # Read and process image
69
+ contents = await file.read()
70
+ image = Image.open(io.BytesIO(contents))
71
+ result = remove_bg(image)
72
+
73
+ # Save to bytes
74
+ img_byte_arr = io.BytesIO()
75
+ result.save(img_byte_arr, format="PNG")
76
+ img_byte_arr.seek(0)
77
+
78
+ return StreamingResponse(
79
+ img_byte_arr,
80
+ media_type="image/png",
81
+ headers={"Content-Disposition": "inline; filename=removed-bg.png"}
82
+ )
83
+ except Exception as e:
84
+ raise HTTPException(500, detail=f"Processing failed: {str(e)}")
85
+
86
+ if __name__ == "__main__":
87
+ uvicorn.run(app, host="0.0.0.0", port=7860)