videopix commited on
Commit
4919185
·
verified ·
1 Parent(s): 1282773

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -107
app.py CHANGED
@@ -1,47 +1,28 @@
1
- import os
2
- from fastapi import FastAPI, UploadFile, File, Form, HTTPException
3
- from fastapi.responses import StreamingResponse, HTMLResponse
4
  from PIL import Image
5
  from io import BytesIO
6
  import torch
7
  import numpy as np
8
  from transformers import AutoModelForImageSegmentation
9
  from loadimg import load_img
10
- from contextlib import asynccontextmanager
11
- import asyncio
12
- from functools import partial
13
 
14
  # -------------------------
15
  # Model Setup
16
  # -------------------------
17
  MODEL_DIR = "models/BiRefNet"
18
  os.makedirs(MODEL_DIR, exist_ok=True)
19
- device = "cpu" # force CPU usage
20
- birefnet = None # will initialize on startup
21
 
22
- # -------------------------
23
- # Lifespan (Startup/Shutdown)
24
- # -------------------------
25
- @asynccontextmanager
26
- async def lifespan(app: FastAPI):
27
- global birefnet
28
- if birefnet is None:
29
- print("Loading BiRefNet model on CPU...")
30
- birefnet = AutoModelForImageSegmentation.from_pretrained(
31
- "ZhengPeng7/BiRefNet",
32
- cache_dir=MODEL_DIR,
33
- trust_remote_code=True,
34
- revision="main"
35
- )
36
- birefnet.to(device).eval()
37
- print("Model loaded successfully.")
38
- yield
39
- # optional shutdown logic
40
-
41
- # -------------------------
42
- # FastAPI App
43
- # -------------------------
44
- app = FastAPI(title="Background Removal API", lifespan=lifespan)
45
 
46
  # -------------------------
47
  # Image Preprocessing
@@ -52,7 +33,7 @@ def transform_image(image: Image.Image) -> torch.Tensor:
52
  mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
53
  std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
54
  arr = (arr - mean) / std
55
- arr = np.transpose(arr, (2, 0, 1)) # HWC -> CHW
56
  tensor = torch.from_numpy(arr).unsqueeze(0).to(torch.float32).to(device)
57
  return tensor
58
 
@@ -68,82 +49,23 @@ def process_image(image: Image.Image) -> Image.Image:
68
  return image
69
 
70
  # -------------------------
71
- # /remove-background endpoint
72
  # -------------------------
73
- @app.post("/remove-background")
74
- async def remove_background(file: UploadFile = File(None), image_url: str = Form(None)):
75
- try:
76
- if file:
77
- image = Image.open(BytesIO(await file.read())).convert("RGB")
78
- elif image_url:
79
- image = load_img(image_url, output_type="pil").convert("RGB")
80
- else:
81
- raise HTTPException(status_code=400, detail="Provide file or image_url")
82
-
83
- # run blocking image processing in a separate thread
84
- loop = asyncio.get_running_loop()
85
- result = await loop.run_in_executor(None, partial(process_image, image))
86
-
87
- buf = BytesIO()
88
- result.save(buf, format="PNG")
89
- buf.seek(0)
90
- return StreamingResponse(buf, media_type="image/png")
91
- except Exception as e:
92
- raise HTTPException(status_code=500, detail=str(e))
93
 
94
  # -------------------------
95
- # Optional: Web Interface
96
  # -------------------------
97
- @app.get("/", response_class=HTMLResponse)
98
- async def index():
99
- html_content = """
100
- <!DOCTYPE html>
101
- <html>
102
- <head>
103
- <title>Background Removal</title>
104
- <style>
105
- body { font-family: Arial; padding: 20px; }
106
- .container { max-width: 600px; margin: auto; background: #f9f9f9; padding: 20px; border-radius: 10px; }
107
- img { max-width: 100%; margin-top: 20px; }
108
- </style>
109
- </head>
110
- <body>
111
- <div class="container">
112
- <h2>Background Removal Tool</h2>
113
- <form id="fileForm" enctype="multipart/form-data">
114
- <input type="file" name="file" id="fileInput">
115
- <button type="submit">Remove Background</button>
116
- </form>
117
- <hr>
118
- <form id="urlForm">
119
- <input type="text" id="urlInput" placeholder="Image URL">
120
- <button type="submit">Remove Background</button>
121
- </form>
122
- <img id="resultImg" src="">
123
- </div>
124
- <script>
125
- const fileForm = document.getElementById('fileForm');
126
- fileForm.addEventListener('submit', async e => {
127
- e.preventDefault();
128
- const fileInput = document.getElementById('fileInput');
129
- if(fileInput.files.length === 0) return alert("Select a file!");
130
- const formData = new FormData();
131
- formData.append("file", fileInput.files[0]);
132
- const res = await fetch('/remove-background', {method:'POST', body:formData});
133
- const blob = await res.blob();
134
- document.getElementById('resultImg').src = URL.createObjectURL(blob);
135
- });
136
- const urlForm = document.getElementById('urlForm');
137
- urlForm.addEventListener('submit', async e => {
138
- e.preventDefault();
139
- const formData = new FormData();
140
- formData.append("image_url", document.getElementById('urlInput').value);
141
- const res = await fetch('/remove-background', {method:'POST', body:formData});
142
- const blob = await res.blob();
143
- document.getElementById('resultImg').src = URL.createObjectURL(blob);
144
- });
145
- </script>
146
- </body>
147
- </html>
148
- """
149
- return HTMLResponse(html_content)
 
1
+ import gradio as gr
 
 
2
  from PIL import Image
3
  from io import BytesIO
4
  import torch
5
  import numpy as np
6
  from transformers import AutoModelForImageSegmentation
7
  from loadimg import load_img
8
+ import os
 
 
9
 
10
  # -------------------------
11
  # Model Setup
12
  # -------------------------
13
  MODEL_DIR = "models/BiRefNet"
14
  os.makedirs(MODEL_DIR, exist_ok=True)
15
+ device = "cuda" if torch.cuda.is_available() else "cpu"
 
16
 
17
+ print("Loading BiRefNet model...")
18
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
19
+ "ZhengPeng7/BiRefNet",
20
+ cache_dir=MODEL_DIR,
21
+ trust_remote_code=True,
22
+ revision="main"
23
+ )
24
+ birefnet.to(device).eval()
25
+ print("Model loaded successfully.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  # -------------------------
28
  # Image Preprocessing
 
33
  mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
34
  std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
35
  arr = (arr - mean) / std
36
+ arr = np.transpose(arr, (2, 0, 1))
37
  tensor = torch.from_numpy(arr).unsqueeze(0).to(torch.float32).to(device)
38
  return tensor
39
 
 
49
  return image
50
 
51
  # -------------------------
52
+ # Gradio Function for API
53
  # -------------------------
54
+ def remove_background_gradio(input_img: Image.Image) -> Image.Image:
55
+ return process_image(input_img.convert("RGB"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  # -------------------------
58
+ # Gradio Interface (Web + API)
59
  # -------------------------
60
+ iface = gr.Interface(
61
+ fn=remove_background_gradio,
62
+ inputs=gr.Image(type="pil"),
63
+ outputs=gr.Image(type="pil"),
64
+ title="Background Removal Tool",
65
+ description="Upload an image and get a transparent background.",
66
+ allow_flagging="never",
67
+ api_name="remove-background" # This exposes /api/predict/remove-background
68
+ )
69
+
70
+ # Launch
71
+ iface.launch(server_name="0.0.0.0", server_port=7860)