Munaf1987 commited on
Commit
de8f0db
·
verified ·
1 Parent(s): 1a7d511

Rename main.py to app.py

Browse files
Files changed (2) hide show
  1. app.py +97 -0
  2. main.py +0 -248
app.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import onnxruntime as ort
5
+ import uuid
6
+ import base64
7
+ from PIL import Image
8
+ import gradio as gr
9
+ import spaces
10
+
11
+ # Setup
12
+ API_KEY = os.getenv("API_KEY", "demo")
13
+ INPUT_SIZE = (512, 512)
14
+ MODEL_PATH = "BiRefNet-general-resolution_512x512-fp16-epoch_216.onnx"
15
+
16
+ # Load ONNX model
17
+ assert os.path.exists(MODEL_PATH), f"Model not found: {MODEL_PATH}"
18
+ session = ort.InferenceSession(MODEL_PATH, providers=["CPUExecutionProvider"])
19
+ input_name = session.get_inputs()[0].name
20
+
21
+ # Preprocess
22
+ def preprocess_image(image: np.ndarray):
23
+ original_shape = image.shape[:2]
24
+ resized = cv2.resize(image, INPUT_SIZE)
25
+ normalized = (resized.astype(np.float32) / 255.0 - 0.5) / 0.5
26
+ transposed = np.transpose(normalized, (2, 0, 1))
27
+ input_tensor = np.expand_dims(transposed, axis=0).astype(np.float32)
28
+ return input_tensor, original_shape, image
29
+
30
+ # Mask logic
31
+ def apply_mask(original_img, mask_array, original_shape):
32
+ mask = np.squeeze(mask_array)
33
+ resized_mask = cv2.resize(mask, (original_shape[1], original_shape[0]))
34
+ binary_mask = (resized_mask > 0.5).astype(np.uint8)
35
+ alpha = (binary_mask * 255).astype(np.uint8)
36
+ masked = cv2.bitwise_and(original_img, original_img, mask=binary_mask)
37
+ bgra = cv2.cvtColor(masked, cv2.COLOR_RGB2BGRA)
38
+ bgra[:, :, 3] = alpha
39
+ return bgra
40
+
41
+ # ============ UI ============
42
+ @spaces.GPU
43
+ def remove_background_ui(image, bg=None):
44
+ input_tensor, original_shape, original_img = preprocess_image(image)
45
+ mask = session.run(None, {input_name: input_tensor})[0]
46
+ result = apply_mask(original_img, mask, original_shape)
47
+
48
+ if bg is not None:
49
+ bg_resized = cv2.resize(bg, (original_shape[1], original_shape[0]))
50
+ alpha = result[:, :, 3] / 255.0
51
+ fg = result[:, :, :3]
52
+ blended = (fg * alpha[..., None] + bg_resized * (1 - alpha[..., None])).astype(np.uint8)
53
+ return Image.fromarray(blended)
54
+ return Image.fromarray(result)
55
+
56
+ # ============ API ============
57
+ @spaces.GPU
58
+ def remove_background_api(image_file, api_key=""):
59
+ if api_key != API_KEY:
60
+ raise gr.Error("❌ Invalid API Key")
61
+ image = Image.open(image_file).convert("RGB")
62
+ image_np = np.array(image)
63
+ input_tensor, original_shape, original_img = preprocess_image(image_np)
64
+ mask = session.run(None, {input_name: input_tensor})[0]
65
+ result = apply_mask(original_img, mask, original_shape)
66
+ success, buffer = cv2.imencode(".png", result)
67
+ return f"data:image/png;base64,{base64.b64encode(buffer).decode('utf-8')}"
68
+
69
+ # Gradio interfaces
70
+ ui = gr.Interface(
71
+ fn=remove_background_ui,
72
+ inputs=[
73
+ gr.Image(type="numpy", label="Main Image"),
74
+ gr.Image(type="numpy", label="Optional Background")
75
+ ],
76
+ outputs=gr.Image(type="pil", label="Result"),
77
+ title="🖼️ Background Remover",
78
+ description="Upload a photo (and optionally a background)."
79
+ )
80
+
81
+ api = gr.Interface(
82
+ fn=remove_background_api,
83
+ inputs=[
84
+ gr.Image(type="filepath", label="Upload Image"),
85
+ gr.Text(label="API Key", type="password")
86
+ ],
87
+ outputs=gr.Text(label="Base64 PNG"),
88
+ title="🔐 API Access",
89
+ description="POST to `/run/predict` with file + API key."
90
+ )
91
+
92
+ # Final Gradio app with predict support
93
+ demo = gr.TabbedInterface([ui, api], ["Web UI", "API Access"])
94
+
95
+ # Launch
96
+ if __name__ == "__main__":
97
+ demo.launch()
main.py DELETED
@@ -1,248 +0,0 @@
1
- ###############################################################################
2
- # 1. Set environment variables BEFORE importing Gradio
3
- ###############################################################################
4
- import os
5
-
6
- os.environ["GRADIO_SERVER_NAME"] = "0.0.0.0"
7
- os.environ["GRADIO_SERVER_PORT"] = "7860"
8
- os.environ["GRADIO_ROOT_PATH"] = "/"
9
-
10
- ###############################################################################
11
- # 2. Imports
12
- ###############################################################################
13
- import uuid
14
- import base64
15
- import json
16
- import shutil
17
- import traceback
18
- import numpy as np
19
- import cv2
20
- import onnxruntime as ort
21
- from io import BytesIO
22
- from PIL import Image
23
- from datetime import datetime
24
- from pathlib import Path
25
-
26
- from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Depends, Request
27
- from fastapi.responses import FileResponse
28
- from fastapi.staticfiles import StaticFiles
29
- from fastapi.templating import Jinja2Templates
30
- from fastapi.middleware.cors import CORSMiddleware
31
-
32
- import gradio as gr
33
- import spaces
34
- import uvicorn
35
-
36
- ###############################################################################
37
- # 3. Setup
38
- ###############################################################################
39
- API_KEY = os.getenv("API_KEY")
40
-
41
- app = FastAPI(title="Background Removal API")
42
-
43
- app.add_middleware(
44
- CORSMiddleware,
45
- allow_origins=["*"],
46
- allow_credentials=True,
47
- allow_methods=["*"],
48
- allow_headers=["*"],
49
- )
50
-
51
- TMP_FOLDER = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
52
- os.makedirs(TMP_FOLDER, exist_ok=True)
53
- print(f"Created tmp folder at: {TMP_FOLDER}")
54
-
55
- app.mount("/tmp", StaticFiles(directory=TMP_FOLDER), name="tmp")
56
- templates = Jinja2Templates(directory="templates")
57
-
58
- model_path = "BiRefNet-general-resolution_512x512-fp16-epoch_216.onnx"
59
- session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
60
- assert "CUDAExecutionProvider" not in ort.get_available_providers(), \
61
- "CUDA provider found but not supported on ZeroGPU."
62
-
63
- input_name = session.get_inputs()[0].name
64
- INPUT_SIZE = (512, 512)
65
-
66
- ###############################################################################
67
- # 4. Utilities
68
- ###############################################################################
69
-
70
- def verify_api_key(api_key: str = Form(...)):
71
- if api_key != API_KEY:
72
- raise HTTPException(status_code=401, detail="Invalid API key")
73
- return api_key
74
- @spaces.GPU
75
- def preprocess_image(image):
76
- if isinstance(image, str):
77
- img = cv2.imread(image)
78
- elif isinstance(image, np.ndarray):
79
- img = image
80
- else:
81
- nparr = np.frombuffer(image, np.uint8)
82
- img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
83
-
84
- original_img = img.copy()
85
- original_shape = img.shape[:2]
86
- rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
87
- resized = cv2.resize(rgb, INPUT_SIZE)
88
- normalized = resized.astype(np.float32) / 255.0
89
- normalized = (normalized - 0.5) / 0.5
90
- transposed = np.transpose(normalized, (2, 0, 1))
91
- input_tensor = np.expand_dims(transposed, axis=0).astype(np.float32)
92
-
93
- return input_tensor, original_shape, original_img
94
-
95
- @spaces.GPU(duration=240)
96
- def apply_mask(original_img, mask_array, original_shape, output_path):
97
- try:
98
- mask = np.squeeze(mask_array)
99
- mask = cv2.resize(mask, (original_shape[1], original_shape[0]))
100
- mask = np.clip(mask, 0, 1)
101
- binary_mask = (mask > 0.5).astype(np.uint8)
102
-
103
- img = original_img.astype(np.uint8)
104
- masked_img = cv2.bitwise_and(img, img, mask=binary_mask)
105
- alpha = (binary_mask * 255).astype(np.uint8)
106
-
107
- bgra = cv2.cvtColor(masked_img, cv2.COLOR_BGR2BGRA)
108
- bgra[:, :, 3] = alpha
109
-
110
- os.makedirs(os.path.dirname(output_path), exist_ok=True)
111
- cv2.imwrite(output_path, bgra, [cv2.IMWRITE_PNG_COMPRESSION, 0])
112
- return bgra, True
113
- except Exception as e:
114
- print(f"Error applying mask: {e}")
115
- return None, False
116
-
117
- ###############################################################################
118
- # 5. FastAPI Routes
119
- ###############################################################################
120
- @app.post("/")
121
- async def index_post(request: Request, main_photo: UploadFile = File(...), bg_photo: UploadFile = File(None)):
122
- try:
123
- main_image_data = await main_photo.read()
124
- input_tensor, original_shape, original_img = preprocess_image(main_image_data)
125
- output = session.run(None, {input_name: input_tensor})
126
- mask = output[0]
127
-
128
- result_filename = f"{uuid.uuid4()}.png"
129
- output_path = os.path.join(TMP_FOLDER, result_filename)
130
- transparent_img, success = apply_mask(original_img, mask, original_shape, output_path)
131
- final_result_path = output_path
132
-
133
- if bg_photo:
134
- bg_image_data = await bg_photo.read()
135
- bg_np = np.frombuffer(bg_image_data, np.uint8)
136
- bg_img = cv2.imdecode(bg_np, cv2.IMREAD_COLOR)
137
- bg_img_resized = cv2.resize(bg_img, (original_shape[1], original_shape[0]))
138
- alpha = transparent_img[:, :, 3] / 255.0
139
- foreground = transparent_img[:, :, :3]
140
- blended = (foreground * alpha[..., None] + bg_img_resized * (1 - alpha[..., None])).astype(np.uint8)
141
- final_result_path = os.path.join(TMP_FOLDER, f"bg_replaced_{uuid.uuid4()}.png")
142
- cv2.imwrite(final_result_path, blended)
143
-
144
- return templates.TemplateResponse("index.html", {"request": request, "output_image": os.path.basename(final_result_path)})
145
- except Exception as e:
146
- print(traceback.format_exc())
147
- return templates.TemplateResponse("index.html", {"request": request, "error": f"Error: {str(e)}"})
148
-
149
- @app.get("/")
150
- async def index_get(request: Request):
151
- return templates.TemplateResponse("index.html", {"request": request})
152
-
153
- @app.post("/remove-background")
154
- async def remove_background(request: Request, api_key: str = Form(...), main_photo: UploadFile = File(...)):
155
- verify_api_key(api_key)
156
- try:
157
- image_data = await main_photo.read()
158
- result_filename = f"{uuid.uuid4()}.png"
159
- output_path = os.path.join(TMP_FOLDER, result_filename)
160
- input_tensor, original_shape, original_img = preprocess_image(image_data)
161
- output = session.run(None, {input_name: input_tensor})
162
- mask = output[0]
163
- _, success = apply_mask(original_img, mask, original_shape, output_path)
164
-
165
- if success:
166
- base_url = str(request.base_url).rstrip("/")
167
- image_url = f"{base_url}/tmp/{result_filename}"
168
- return {"status": "success", "message": "Background removed", "filename": result_filename, "image_url": image_url}
169
- return {"status": "failure", "message": "Failed to process image"}
170
- except Exception as e:
171
- print(traceback.format_exc())
172
- return {"status": "failure", "message": f"Error: {str(e)}"}
173
-
174
- @app.post("/process_image")
175
- async def process_image(request: Request, image: UploadFile = File(...), api_key: str = Form(...)):
176
- verify_api_key(api_key)
177
- try:
178
- image_data = await image.read()
179
- result_filename = f"{uuid.uuid4()}.png"
180
- output_path = os.path.join(TMP_FOLDER, result_filename)
181
- input_tensor, original_shape, original_img = preprocess_image(image_data)
182
- output = session.run(None, {input_name: input_tensor})
183
- mask = output[0]
184
- bgra, success = apply_mask(original_img, mask, original_shape, output_path)
185
-
186
- if success:
187
- with open(output_path, "rb") as img_file:
188
- base64_image = base64.b64encode(img_file.read()).decode("utf-8")
189
- return {"status": "success", "image_code": f"data:image/png;base64,{base64_image}"}
190
- return {"status": "failure", "message": "Failed to process image"}
191
- except Exception as e:
192
- print(traceback.format_exc())
193
- return {"status": "failure", "message": f"Error: {str(e)}"}
194
-
195
- @app.get("/download/{filename}")
196
- async def download_file(filename: str):
197
- file_path = os.path.join(TMP_FOLDER, filename)
198
- if os.path.exists(file_path):
199
- return FileResponse(path=file_path, filename=filename, media_type="image/png")
200
- raise HTTPException(status_code=404, detail="File not found")
201
-
202
- ###############################################################################
203
- # 6. Gradio Interface
204
- ###############################################################################
205
- @spaces.GPU
206
- def process_image_gradio(image):
207
- session = ort.InferenceSession(
208
- model_path,
209
- providers=["CUDAExecutionProvider"]
210
- )
211
- input_tensor, original_shape, original_img = preprocess_image(image)
212
- output = session.run(None, {input_name: input_tensor})
213
- mask = output[0]
214
- result_filename = f"{uuid.uuid4()}.png"
215
- output_path = os.path.join(TMP_FOLDER, result_filename)
216
- result_img, success = apply_mask(original_img, mask, original_shape, output_path)
217
- if success:
218
- return Image.fromarray(cv2.cvtColor(result_img, cv2.COLOR_BGRA2RGBA))
219
- return None
220
- process_image_gradio.zerogpu =True
221
- # Optional: this ensures startup detects GPU
222
- @spaces.GPU(duration=5)
223
- def _gpu_wakeup():
224
- return "GPU ready"
225
-
226
- # Gradio Blocks Interface
227
- with gr.Blocks() as demo:
228
- gr.Markdown("## 🧠 Background Removal (GPU)")
229
- gr.Markdown("Upload an image to remove the background using ONNX on GPU")
230
-
231
- with gr.Row():
232
- image_input = gr.Image(type="numpy", label="Upload Image")
233
- image_output = gr.Image(type="pil", label="Output")
234
-
235
- submit_btn = gr.Button("Remove Background")
236
- submit_btn.click(fn=process_image_gradio, inputs=image_input, outputs=image_output)
237
-
238
-
239
- ###############################################################################
240
- # 7. Mount Gradio App
241
- ###############################################################################
242
- gr.mount_gradio_app(app, demo, path="/", ssr_mode=False)
243
-
244
- ###############################################################################
245
- # 8. Run (only needed for local dev; not needed on Hugging Face)
246
- ###############################################################################
247
- if __name__ == "__main__":
248
- uvicorn.run(app, host="0.0.0.0", port=7860)