um41r commited on
Commit
0df353e
·
verified ·
1 Parent(s): 353cbab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -34
app.py CHANGED
@@ -1,28 +1,57 @@
1
- # app.py - Fixed FastAPI for HF Spaces
 
2
  import torch
3
  import numpy as np
4
  from PIL import Image
5
  from safetensors.torch import load_file
6
- import io
7
 
8
- from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request
9
  from fastapi.middleware.cors import CORSMiddleware
10
  from fastapi.responses import StreamingResponse
 
11
  import uvicorn
12
 
13
  from birefnet import BiRefNet
14
  from BiRefNet_config import BiRefNetConfig
15
 
16
- # Global model
17
- device = "cpu"
 
 
 
 
 
 
 
 
 
 
 
18
  config = BiRefNetConfig()
19
  model = BiRefNet(config)
 
20
  state_dict = load_file("model.safetensors")
21
  model.load_state_dict(state_dict, strict=False)
22
- model.to(device)
23
  model.eval()
 
24
  print("✅ BiRefNet Lite loaded")
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  def preprocess(img: Image.Image):
27
  img = img.convert("RGB").resize((1024, 1024))
28
  arr = np.array(img).astype(np.float32) / 255.0
@@ -31,71 +60,86 @@ def preprocess(img: Image.Image):
31
 
32
  @torch.no_grad()
33
  def remove_bg(image: Image.Image) -> Image.Image:
34
- x = preprocess(image).to(device)
35
  pred = model(x)[0]
36
  pred = torch.sigmoid(pred)
37
  mask = pred.squeeze().cpu().numpy()
38
  mask = (mask * 255).astype(np.uint8)
39
  mask = Image.fromarray(mask).resize(image.size)
 
40
  out = image.convert("RGBA")
41
  out.putalpha(mask)
42
  return out
43
 
 
 
 
44
  app = FastAPI(title="BiRefNet Background Remover API")
45
 
46
- # CORS - allow all for testing (restrict in production)
47
  app.add_middleware(
48
  CORSMiddleware,
49
- allow_origins=["*"],
50
  allow_credentials=True,
51
  allow_methods=["*"],
52
  allow_headers=["*"],
53
  )
54
 
 
 
 
55
  @app.get("/")
56
  async def root():
57
- return {"message": "BiRefNet Background Remover API ✅", "endpoints": ["/remove-bg"]}
 
 
 
 
58
 
59
- # FIXED: Accept both UploadFile and raw bytes
60
  @app.post("/remove-bg")
61
- async def remove_background(request: Request, file: UploadFile = File(None)):
 
 
 
 
62
  try:
63
- # Handle different content types
64
  if file is None:
65
- # Fallback: read raw body if no file param
66
  body = await request.body()
67
  if not body:
68
- raise HTTPException(400, detail="No image data received")
69
  image = Image.open(io.BytesIO(body))
70
  else:
71
  contents = await file.read()
72
  if not contents:
73
- raise HTTPException(400, detail="Empty file")
74
  image = Image.open(io.BytesIO(contents))
75
-
76
- # Validate image
77
- if not image.format or image.format not in ["JPEG", "PNG", "JPG"]:
78
- raise HTTPException(400, detail="Invalid image format")
79
-
80
- # Process
81
  result = remove_bg(image)
82
-
83
- # Return PNG
84
- img_byte_arr = io.BytesIO()
85
- result.save(img_byte_arr, format="PNG")
86
- img_byte_arr.seek(0)
87
-
88
  return StreamingResponse(
89
- img_byte_arr,
90
  media_type="image/png",
91
  headers={
92
- "Content-Disposition": "inline; filename=removed-bg.png",
93
- "Cache-Control": "public, max-age=3600"
94
  }
95
  )
 
96
  except Exception as e:
97
- print(f"❌ Error: {str(e)}")
98
- raise HTTPException(500, detail=f"Processing failed: {str(e)}")
99
 
 
 
 
100
  if __name__ == "__main__":
101
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
 
 
 
1
+ import os
2
+ import io
3
  import torch
4
  import numpy as np
5
  from PIL import Image
6
  from safetensors.torch import load_file
 
7
 
8
+ from fastapi import FastAPI, File, UploadFile, HTTPException, Request, Depends
9
  from fastapi.middleware.cors import CORSMiddleware
10
  from fastapi.responses import StreamingResponse
11
+ from fastapi.security import APIKeyHeader
12
  import uvicorn
13
 
14
  from birefnet import BiRefNet
15
  from BiRefNet_config import BiRefNetConfig
16
 
17
+ # =========================
18
+ # HUGGING FACE SECRET
19
+ # =========================
20
+ API_KEY = os.getenv("BIREFNET_API_KEY")
21
+
22
+ if not API_KEY:
23
+ raise RuntimeError("❌ BIREFNET_API_KEY not found in HF Space Secrets")
24
+
25
+ DEVICE = "cpu"
26
+
27
+ # =========================
28
+ # LOAD MODEL
29
+ # =========================
30
  config = BiRefNetConfig()
31
  model = BiRefNet(config)
32
+
33
  state_dict = load_file("model.safetensors")
34
  model.load_state_dict(state_dict, strict=False)
35
+ model.to(DEVICE)
36
  model.eval()
37
+
38
  print("✅ BiRefNet Lite loaded")
39
 
40
+ # =========================
41
+ # API KEY AUTH
42
+ # =========================
43
+ api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
44
+
45
+ def verify_api_key(api_key: str = Depends(api_key_header)):
46
+ if api_key != API_KEY:
47
+ raise HTTPException(
48
+ status_code=401,
49
+ detail="Invalid or missing API key"
50
+ )
51
+
52
+ # =========================
53
+ # IMAGE PIPELINE
54
+ # =========================
55
  def preprocess(img: Image.Image):
56
  img = img.convert("RGB").resize((1024, 1024))
57
  arr = np.array(img).astype(np.float32) / 255.0
 
60
 
61
  @torch.no_grad()
62
  def remove_bg(image: Image.Image) -> Image.Image:
63
+ x = preprocess(image).to(DEVICE)
64
  pred = model(x)[0]
65
  pred = torch.sigmoid(pred)
66
  mask = pred.squeeze().cpu().numpy()
67
  mask = (mask * 255).astype(np.uint8)
68
  mask = Image.fromarray(mask).resize(image.size)
69
+
70
  out = image.convert("RGBA")
71
  out.putalpha(mask)
72
  return out
73
 
74
+ # =========================
75
+ # FASTAPI APP
76
+ # =========================
77
  app = FastAPI(title="BiRefNet Background Remover API")
78
 
 
79
  app.add_middleware(
80
  CORSMiddleware,
81
+ allow_origins=["*"], # Restrict later
82
  allow_credentials=True,
83
  allow_methods=["*"],
84
  allow_headers=["*"],
85
  )
86
 
87
+ # =========================
88
+ # ROUTES
89
+ # =========================
90
  @app.get("/")
91
  async def root():
92
+ return {
93
+ "status": "ok",
94
+ "secured": True,
95
+ "endpoint": "/remove-bg"
96
+ }
97
 
 
98
  @app.post("/remove-bg")
99
+ async def remove_background(
100
+ request: Request,
101
+ file: UploadFile = File(None),
102
+ _: str = Depends(verify_api_key)
103
+ ):
104
  try:
 
105
  if file is None:
 
106
  body = await request.body()
107
  if not body:
108
+ raise HTTPException(400, "No image data received")
109
  image = Image.open(io.BytesIO(body))
110
  else:
111
  contents = await file.read()
112
  if not contents:
113
+ raise HTTPException(400, "Empty file")
114
  image = Image.open(io.BytesIO(contents))
115
+
116
+ if image.format not in ["JPEG", "JPG", "PNG"]:
117
+ raise HTTPException(400, "Invalid image format")
118
+
 
 
119
  result = remove_bg(image)
120
+
121
+ img_bytes = io.BytesIO()
122
+ result.save(img_bytes, format="PNG")
123
+ img_bytes.seek(0)
124
+
 
125
  return StreamingResponse(
126
+ img_bytes,
127
  media_type="image/png",
128
  headers={
129
+ "Content-Disposition": "inline; filename=removed-bg.png"
 
130
  }
131
  )
132
+
133
  except Exception as e:
134
+ print("❌ Error:", e)
135
+ raise HTTPException(500, "Processing failed")
136
 
137
+ # =========================
138
+ # HF DOCKER ENTRYPOINT
139
+ # =========================
140
  if __name__ == "__main__":
141
+ uvicorn.run(
142
+ app,
143
+ host="0.0.0.0",
144
+ port=int(os.environ.get("PORT", 7860))
145
+ )