um41r commited on
Commit
839de5d
Β·
verified Β·
1 Parent(s): a045e4c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -79
app.py CHANGED
@@ -1,49 +1,10 @@
1
- # app.py - Fixed FastAPI for HF Spaces
2
- import torch
3
- import numpy as np
4
- from PIL import Image
5
- from safetensors.torch import load_file
6
- import io
7
-
8
- from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request
9
  from fastapi.middleware.cors import CORSMiddleware
10
- from fastapi.responses import StreamingResponse
11
- import uvicorn
12
-
13
- from birefnet import BiRefNet
14
- from BiRefNet_config import BiRefNetConfig
15
-
16
- # Global model
17
- device = "cpu"
18
- config = BiRefNetConfig()
19
- model = BiRefNet(config)
20
- state_dict = load_file("model.safetensors")
21
- model.load_state_dict(state_dict, strict=False)
22
- model.to(device)
23
- model.eval()
24
- print("βœ… BiRefNet Lite loaded")
25
-
26
- def preprocess(img: Image.Image):
27
- img = img.convert("RGB").resize((1024, 1024))
28
- arr = np.array(img).astype(np.float32) / 255.0
29
- arr = arr.transpose(2, 0, 1)
30
- return torch.from_numpy(arr).unsqueeze(0)
31
-
32
- @torch.no_grad()
33
- def remove_bg(image: Image.Image) -> Image.Image:
34
- x = preprocess(image).to(device)
35
- pred = model(x)[0]
36
- pred = torch.sigmoid(pred)
37
- mask = pred.squeeze().cpu().numpy()
38
- mask = (mask * 255).astype(np.uint8)
39
- mask = Image.fromarray(mask).resize(image.size)
40
- out = image.convert("RGBA")
41
- out.putalpha(mask)
42
- return out
43
 
44
- app = FastAPI(title="BiRefNet Background Remover API")
45
 
46
- # CORS - allow all for testing (restrict in production)
47
  app.add_middleware(
48
  CORSMiddleware,
49
  allow_origins=["*"],
@@ -52,50 +13,43 @@ app.add_middleware(
52
  allow_headers=["*"],
53
  )
54
 
55
- @app.get("/")
56
- async def root():
57
- return {"message": "BiRefNet Background Remover API βœ…", "endpoints": ["/remove-bg"]}
58
 
59
- # FIXED: Accept both UploadFile and raw bytes
60
  @app.post("/remove-bg")
61
- async def remove_background(
62
- file: UploadFile = File(None),
63
- request: Request = None
64
- ):
65
  try:
66
- contents = None
 
 
 
 
 
 
 
 
 
67
 
68
- # Try UploadFile first
69
- if file:
70
- contents = await file.read()
71
- print(f"βœ… File received: {file.filename}, {len(contents)} bytes")
72
- else:
73
- # Fallback: check raw body
74
- body = await request.body()
75
- contents = body
76
- print(f"βœ… Raw body received: {len(contents)} bytes")
77
 
78
- if not contents or len(contents) == 0:
79
- raise HTTPException(400, detail="No file uploaded")
80
 
81
- image = Image.open(io.BytesIO(contents))
82
- print(f"βœ… Image loaded: {image.size}, format: {image.format}")
83
 
84
- result = remove_bg(image)
 
 
85
 
86
- img_byte_arr = io.BytesIO()
87
- result.save(img_byte_arr, format="PNG")
88
- img_byte_arr.seek(0)
89
 
90
- print("βœ… Processing complete")
91
- return StreamingResponse(
92
- img_byte_arr,
93
- media_type="image/png",
94
- headers={"Content-Disposition": "inline; filename=removed-bg.png"}
95
- )
96
  except Exception as e:
97
- print(f"❌ Detailed error: {str(e)}")
98
- raise HTTPException(400, detail=f"Processing failed: {str(e)}")
99
 
100
- if __name__ == "__main__":
101
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
1
+ from fastapi import FastAPI, File, UploadFile, Form, Request, HTTPException
 
 
 
 
 
 
 
2
  from fastapi.middleware.cors import CORSMiddleware
3
+ import io
4
+ from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
+ app = FastAPI()
7
 
 
8
  app.add_middleware(
9
  CORSMiddleware,
10
  allow_origins=["*"],
 
13
  allow_headers=["*"],
14
  )
15
 
16
+ # Your model loading code here...
17
+ print("βœ… Model loaded")
 
18
 
 
19
  @app.post("/remove-bg")
20
+ async def remove_bg_endpoint(request: Request):
21
+ """Accepts ANY form field name - "file", "image", etc."""
 
 
22
  try:
23
+ form = await request.form()
24
+ print("πŸ“‹ All form fields:", list(form.keys()))
25
+
26
+ # Try common field names
27
+ file_data = None
28
+ for field_name in ['file', 'image', 'data', 'upload']:
29
+ if field_name in form:
30
+ file_data = await form[field_name].read()
31
+ print(f"βœ… Found {field_name}: {len(file_data)} bytes")
32
+ break
33
 
34
+ if not file_data:
35
+ raise HTTPException(400, "No image file found. Send as 'file' or 'image'")
 
 
 
 
 
 
 
36
 
37
+ image = Image.open(io.BytesIO(file_data))
38
+ print(f"βœ… Image: {image.size}")
39
 
40
+ # Your remove_bg(image) function here
41
+ result = remove_bg(image) # Your existing function
42
 
43
+ img_bytes = io.BytesIO()
44
+ result.save(img_bytes, "PNG")
45
+ img_bytes.seek(0)
46
 
47
+ return StreamingResponse(img_bytes, media_type="image/png")
 
 
48
 
 
 
 
 
 
 
49
  except Exception as e:
50
+ print(f"❌ Error: {e}")
51
+ raise HTTPException(400, f"Failed: {str(e)}")
52
 
53
+ @app.get("/")
54
+ def root():
55
+ return {"status": "OK", "endpoint": "/remove-bg POST with file=image or file=file"}