videopix commited on
Commit
e6e54a0
·
verified ·
1 Parent(s): 8ba660f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -27
app.py CHANGED
@@ -1,11 +1,15 @@
1
- import gradio as gr
 
 
2
  from PIL import Image
3
  from io import BytesIO
4
  import torch
5
  import numpy as np
6
  from transformers import AutoModelForImageSegmentation
7
  from loadimg import load_img
8
- import os
 
 
9
 
10
  # -------------------------
11
  # Model Setup
@@ -13,16 +17,31 @@ import os
13
  MODEL_DIR = "models/BiRefNet"
14
  os.makedirs(MODEL_DIR, exist_ok=True)
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
16
 
17
- print("Loading BiRefNet model...")
18
- birefnet = AutoModelForImageSegmentation.from_pretrained(
19
- "ZhengPeng7/BiRefNet",
20
- cache_dir=MODEL_DIR,
21
- trust_remote_code=True,
22
- revision="main"
23
- )
24
- birefnet.to(device).eval()
25
- print("Model loaded successfully.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  # -------------------------
28
  # Image Preprocessing
@@ -33,7 +52,7 @@ def transform_image(image: Image.Image) -> torch.Tensor:
33
  mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
34
  std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
35
  arr = (arr - mean) / std
36
- arr = np.transpose(arr, (2, 0, 1))
37
  tensor = torch.from_numpy(arr).unsqueeze(0).to(torch.float32).to(device)
38
  return tensor
39
 
@@ -49,22 +68,86 @@ def process_image(image: Image.Image) -> Image.Image:
49
  return image
50
 
51
  # -------------------------
52
- # Gradio Function
53
  # -------------------------
54
- def remove_background_gradio(input_img):
55
- # Gradio passes PIL images directly
56
- result = process_image(input_img.convert("RGB"))
57
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  # -------------------------
60
- # Gradio Interface
61
  # -------------------------
62
- iface = gr.Interface(
63
- fn=remove_background_gradio,
64
- inputs=gr.Image(type="pil"),
65
- outputs=gr.Image(type="pil"),
66
- title="Background Removal Tool",
67
- description="Upload an image and get a transparent background."
68
- )
69
-
70
- iface.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException
3
+ from fastapi.responses import StreamingResponse, HTMLResponse
4
  from PIL import Image
5
  from io import BytesIO
6
  import torch
7
  import numpy as np
8
  from transformers import AutoModelForImageSegmentation
9
  from loadimg import load_img
10
+ from contextlib import asynccontextmanager
11
+ import asyncio
12
+ from functools import partial
13
 
14
  # -------------------------
15
  # Model Setup
 
17
  MODEL_DIR = "models/BiRefNet"
18
  os.makedirs(MODEL_DIR, exist_ok=True)
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
+ birefnet = None # will initialize on startup
21
 
22
+ # -------------------------
23
+ # Lifespan (Startup/Shutdown)
24
+ # -------------------------
25
+ @asynccontextmanager
26
+ async def lifespan(app: FastAPI):
27
+ global birefnet
28
+ if birefnet is None:
29
+ print("Loading BiRefNet model...")
30
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
31
+ "ZhengPeng7/BiRefNet",
32
+ cache_dir=MODEL_DIR,
33
+ trust_remote_code=True,
34
+ revision="main"
35
+ )
36
+ birefnet.to(device).eval()
37
+ print("Model loaded successfully.")
38
+ yield
39
+ # optional shutdown logic
40
+
41
+ # -------------------------
42
+ # FastAPI App
43
+ # -------------------------
44
+ app = FastAPI(title="Background Removal API", lifespan=lifespan)
45
 
46
  # -------------------------
47
  # Image Preprocessing
 
52
  mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
53
  std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
54
  arr = (arr - mean) / std
55
+ arr = np.transpose(arr, (2, 0, 1)) # HWC -> CHW
56
  tensor = torch.from_numpy(arr).unsqueeze(0).to(torch.float32).to(device)
57
  return tensor
58
 
 
68
  return image
69
 
70
  # -------------------------
71
+ # /remove-background endpoint
72
  # -------------------------
73
+ @app.post("/remove-background")
74
+ async def remove_background(file: UploadFile = File(None), image_url: str = Form(None)):
75
+ """
76
+ Accept either an uploaded file or an image URL.
77
+ Returns PNG with transparent background.
78
+ """
79
+ try:
80
+ if file:
81
+ image = Image.open(BytesIO(await file.read())).convert("RGB")
82
+ elif image_url:
83
+ image = load_img(image_url, output_type="pil").convert("RGB")
84
+ else:
85
+ raise HTTPException(status_code=400, detail="Provide file or image_url")
86
+
87
+ # run blocking image processing in a separate thread
88
+ loop = asyncio.get_running_loop()
89
+ result = await loop.run_in_executor(None, partial(process_image, image))
90
+
91
+ buf = BytesIO()
92
+ result.save(buf, format="PNG")
93
+ buf.seek(0)
94
+ return StreamingResponse(buf, media_type="image/png")
95
+ except Exception as e:
96
+ raise HTTPException(status_code=500, detail=str(e))
97
 
98
  # -------------------------
99
+ # Optional: Web Interface
100
  # -------------------------
101
+ @app.get("/", response_class=HTMLResponse)
102
+ async def index():
103
+ html_content = """
104
+ <!DOCTYPE html>
105
+ <html>
106
+ <head>
107
+ <title>Background Removal</title>
108
+ <style>
109
+ body { font-family: Arial; padding: 20px; }
110
+ .container { max-width: 600px; margin: auto; background: #f9f9f9; padding: 20px; border-radius: 10px; }
111
+ img { max-width: 100%; margin-top: 20px; }
112
+ </style>
113
+ </head>
114
+ <body>
115
+ <div class="container">
116
+ <h2>Background Removal Tool</h2>
117
+ <form id="fileForm" enctype="multipart/form-data">
118
+ <input type="file" name="file" id="fileInput">
119
+ <button type="submit">Remove Background</button>
120
+ </form>
121
+ <hr>
122
+ <form id="urlForm">
123
+ <input type="text" id="urlInput" placeholder="Image URL">
124
+ <button type="submit">Remove Background</button>
125
+ </form>
126
+ <img id="resultImg" src="">
127
+ </div>
128
+ <script>
129
+ const fileForm = document.getElementById('fileForm');
130
+ fileForm.addEventListener('submit', async e => {
131
+ e.preventDefault();
132
+ const fileInput = document.getElementById('fileInput');
133
+ if(fileInput.files.length === 0) return alert("Select a file!");
134
+ const formData = new FormData();
135
+ formData.append("file", fileInput.files[0]);
136
+ const res = await fetch('/remove-background', {method:'POST', body:formData});
137
+ const blob = await res.blob();
138
+ document.getElementById('resultImg').src = URL.createObjectURL(blob);
139
+ });
140
+ const urlForm = document.getElementById('urlForm');
141
+ urlForm.addEventListener('submit', async e => {
142
+ e.preventDefault();
143
+ const formData = new FormData();
144
+ formData.append("image_url", document.getElementById('urlInput').value);
145
+ const res = await fetch('/remove-background', {method:'POST', body:formData});
146
+ const blob = await res.blob();
147
+ document.getElementById('resultImg').src = URL.createObjectURL(blob);
148
+ });
149
+ </script>
150
+ </body>
151
+ </html>
152
+ """
153
+ return HTMLResponse(html_content)