um41r commited on
Commit
353cbab
Β·
verified Β·
1 Parent(s): 839de5d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -35
app.py CHANGED
@@ -1,10 +1,49 @@
1
- from fastapi import FastAPI, File, UploadFile, Form, Request, HTTPException
2
- from fastapi.middleware.cors import CORSMiddleware
3
- import io
4
  from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- app = FastAPI()
 
 
 
 
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  app.add_middleware(
9
  CORSMiddleware,
10
  allow_origins=["*"],
@@ -13,43 +52,50 @@ app.add_middleware(
13
  allow_headers=["*"],
14
  )
15
 
16
- # Your model loading code here...
17
- print("βœ… Model loaded")
 
18
 
 
19
  @app.post("/remove-bg")
20
- async def remove_bg_endpoint(request: Request):
21
- """Accepts ANY form field name - "file", "image", etc."""
22
  try:
23
- form = await request.form()
24
- print("πŸ“‹ All form fields:", list(form.keys()))
25
-
26
- # Try common field names
27
- file_data = None
28
- for field_name in ['file', 'image', 'data', 'upload']:
29
- if field_name in form:
30
- file_data = await form[field_name].read()
31
- print(f"βœ… Found {field_name}: {len(file_data)} bytes")
32
- break
 
 
33
 
34
- if not file_data:
35
- raise HTTPException(400, "No image file found. Send as 'file' or 'image'")
 
36
 
37
- image = Image.open(io.BytesIO(file_data))
38
- print(f"βœ… Image: {image.size}")
39
 
40
- # Your remove_bg(image) function here
41
- result = remove_bg(image) # Your existing function
42
-
43
- img_bytes = io.BytesIO()
44
- result.save(img_bytes, "PNG")
45
- img_bytes.seek(0)
46
-
47
- return StreamingResponse(img_bytes, media_type="image/png")
48
 
 
 
 
 
 
 
 
 
49
  except Exception as e:
50
- print(f"❌ Error: {e}")
51
- raise HTTPException(400, f"Failed: {str(e)}")
52
 
53
- @app.get("/")
54
- def root():
55
- return {"status": "OK", "endpoint": "/remove-bg POST with file=image or file=file"}
 
1
+ # app.py - Fixed FastAPI for HF Spaces
2
+ import torch
3
+ import numpy as np
4
  from PIL import Image
5
+ from safetensors.torch import load_file
6
+ import io
7
+
8
+ from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request
9
+ from fastapi.middleware.cors import CORSMiddleware
10
+ from fastapi.responses import StreamingResponse
11
+ import uvicorn
12
+
13
+ from birefnet import BiRefNet
14
+ from BiRefNet_config import BiRefNetConfig
15
+
16
+ # Global model
17
+ device = "cpu"
18
+ config = BiRefNetConfig()
19
+ model = BiRefNet(config)
20
+ state_dict = load_file("model.safetensors")
21
+ model.load_state_dict(state_dict, strict=False)
22
+ model.to(device)
23
+ model.eval()
24
+ print("βœ… BiRefNet Lite loaded")
25
 
26
+ def preprocess(img: Image.Image):
27
+ img = img.convert("RGB").resize((1024, 1024))
28
+ arr = np.array(img).astype(np.float32) / 255.0
29
+ arr = arr.transpose(2, 0, 1)
30
+ return torch.from_numpy(arr).unsqueeze(0)
31
 
32
+ @torch.no_grad()
33
+ def remove_bg(image: Image.Image) -> Image.Image:
34
+ x = preprocess(image).to(device)
35
+ pred = model(x)[0]
36
+ pred = torch.sigmoid(pred)
37
+ mask = pred.squeeze().cpu().numpy()
38
+ mask = (mask * 255).astype(np.uint8)
39
+ mask = Image.fromarray(mask).resize(image.size)
40
+ out = image.convert("RGBA")
41
+ out.putalpha(mask)
42
+ return out
43
+
44
+ app = FastAPI(title="BiRefNet Background Remover API")
45
+
46
+ # CORS - allow all for testing (restrict in production)
47
  app.add_middleware(
48
  CORSMiddleware,
49
  allow_origins=["*"],
 
52
  allow_headers=["*"],
53
  )
54
 
55
+ @app.get("/")
56
+ async def root():
57
+ return {"message": "BiRefNet Background Remover API βœ…", "endpoints": ["/remove-bg"]}
58
 
59
+ # FIXED: Accept both UploadFile and raw bytes
60
  @app.post("/remove-bg")
61
+ async def remove_background(request: Request, file: UploadFile = File(None)):
 
62
  try:
63
+ # Handle different content types
64
+ if file is None:
65
+ # Fallback: read raw body if no file param
66
+ body = await request.body()
67
+ if not body:
68
+ raise HTTPException(400, detail="No image data received")
69
+ image = Image.open(io.BytesIO(body))
70
+ else:
71
+ contents = await file.read()
72
+ if not contents:
73
+ raise HTTPException(400, detail="Empty file")
74
+ image = Image.open(io.BytesIO(contents))
75
 
76
+ # Validate image
77
+ if not image.format or image.format not in ["JPEG", "PNG", "JPG"]:
78
+ raise HTTPException(400, detail="Invalid image format")
79
 
80
+ # Process
81
+ result = remove_bg(image)
82
 
83
+ # Return PNG
84
+ img_byte_arr = io.BytesIO()
85
+ result.save(img_byte_arr, format="PNG")
86
+ img_byte_arr.seek(0)
 
 
 
 
87
 
88
+ return StreamingResponse(
89
+ img_byte_arr,
90
+ media_type="image/png",
91
+ headers={
92
+ "Content-Disposition": "inline; filename=removed-bg.png",
93
+ "Cache-Control": "public, max-age=3600"
94
+ }
95
+ )
96
  except Exception as e:
97
+ print(f"❌ Error: {str(e)}")
98
+ raise HTTPException(500, detail=f"Processing failed: {str(e)}")
99
 
100
+ if __name__ == "__main__":
101
+ uvicorn.run(app, host="0.0.0.0", port=7860)