videopix commited on
Commit
8ba660f
·
verified ·
1 Parent(s): 15ebf0e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -106
app.py CHANGED
@@ -1,13 +1,11 @@
1
- import os
2
- from fastapi import FastAPI, UploadFile, File, Form, HTTPException
3
- from fastapi.responses import StreamingResponse, HTMLResponse
4
  from PIL import Image
 
5
  import torch
6
  import numpy as np
7
  from transformers import AutoModelForImageSegmentation
8
- from io import BytesIO
9
  from loadimg import load_img
10
- from contextlib import asynccontextmanager
11
 
12
  # -------------------------
13
  # Model Setup
@@ -15,31 +13,16 @@ from contextlib import asynccontextmanager
15
  MODEL_DIR = "models/BiRefNet"
16
  os.makedirs(MODEL_DIR, exist_ok=True)
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
- birefnet = None # will initialize on startup
19
-
20
- # -------------------------
21
- # Lifespan (Startup/Shutdown)
22
- # -------------------------
23
- @asynccontextmanager
24
- async def lifespan(app: FastAPI):
25
- global birefnet
26
- if birefnet is None:
27
- print("Loading BiRefNet model...")
28
- birefnet = AutoModelForImageSegmentation.from_pretrained(
29
- "ZhengPeng7/BiRefNet",
30
- cache_dir=MODEL_DIR,
31
- trust_remote_code=True,
32
- revision="main"
33
- )
34
- birefnet.to(device).eval()
35
- print("Model loaded successfully.")
36
- yield
37
- # Optional shutdown logic here (nothing needed for this model)
38
 
39
- # -------------------------
40
- # FastAPI App
41
- # -------------------------
42
- app = FastAPI(title="Background Removal API", lifespan=lifespan)
 
 
 
 
 
43
 
44
  # -------------------------
45
  # Image Preprocessing
@@ -50,7 +33,7 @@ def transform_image(image: Image.Image) -> torch.Tensor:
50
  mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
51
  std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
52
  arr = (arr - mean) / std
53
- arr = np.transpose(arr, (2, 0, 1)) # HWC -> CHW
54
  tensor = torch.from_numpy(arr).unsqueeze(0).to(torch.float32).to(device)
55
  return tensor
56
 
@@ -66,83 +49,22 @@ def process_image(image: Image.Image) -> Image.Image:
66
  return image
67
 
68
  # -------------------------
69
- # Remove Background Endpoint
70
  # -------------------------
71
- @app.post("/remove-background")
72
- async def remove_background(file: UploadFile = File(None), image_url: str = Form(None)):
73
- """
74
- Accept either an uploaded file or an image URL.
75
- Returns PNG with transparent background.
76
- """
77
- try:
78
- if file:
79
- image = Image.open(BytesIO(await file.read())).convert("RGB")
80
- elif image_url:
81
- image = load_img(image_url, output_type="pil").convert("RGB")
82
- else:
83
- raise HTTPException(status_code=400, detail="Provide file or image_url")
84
-
85
- result = process_image(image)
86
- buf = BytesIO()
87
- result.save(buf, format="PNG")
88
- buf.seek(0)
89
- return StreamingResponse(buf, media_type="image/png")
90
- except Exception as e:
91
- raise HTTPException(status_code=500, detail=str(e))
92
 
93
  # -------------------------
94
- # Web Interface
95
  # -------------------------
96
- @app.get("/", response_class=HTMLResponse)
97
- async def index():
98
- html_content = """
99
- <!DOCTYPE html>
100
- <html>
101
- <head>
102
- <title>Background Removal</title>
103
- <style>
104
- body { font-family: Arial; padding: 20px; }
105
- .container { max-width: 600px; margin: auto; background: #f9f9f9; padding: 20px; border-radius: 10px; }
106
- img { max-width: 100%; margin-top: 20px; }
107
- </style>
108
- </head>
109
- <body>
110
- <div class="container">
111
- <h2>Background Removal Tool</h2>
112
- <form id="fileForm" enctype="multipart/form-data">
113
- <input type="file" name="file" id="fileInput">
114
- <button type="submit">Remove Background</button>
115
- </form>
116
- <hr>
117
- <form id="urlForm">
118
- <input type="text" id="urlInput" placeholder="Image URL">
119
- <button type="submit">Remove Background</button>
120
- </form>
121
- <img id="resultImg" src="">
122
- </div>
123
- <script>
124
- const fileForm = document.getElementById('fileForm');
125
- fileForm.addEventListener('submit', async e => {
126
- e.preventDefault();
127
- const fileInput = document.getElementById('fileInput');
128
- if(fileInput.files.length === 0) return alert("Select a file!");
129
- const formData = new FormData();
130
- formData.append("file", fileInput.files[0]);
131
- const res = await fetch('/remove-background', {method:'POST', body:formData});
132
- const blob = await res.blob();
133
- document.getElementById('resultImg').src = URL.createObjectURL(blob);
134
- });
135
- const urlForm = document.getElementById('urlForm');
136
- urlForm.addEventListener('submit', async e => {
137
- e.preventDefault();
138
- const formData = new FormData();
139
- formData.append("image_url", document.getElementById('urlInput').value);
140
- const res = await fetch('/remove-background', {method:'POST', body:formData});
141
- const blob = await res.blob();
142
- document.getElementById('resultImg').src = URL.createObjectURL(blob);
143
- });
144
- </script>
145
- </body>
146
- </html>
147
- """
148
- return HTMLResponse(html_content)
 
1
+ import gradio as gr
 
 
2
  from PIL import Image
3
+ from io import BytesIO
4
  import torch
5
  import numpy as np
6
  from transformers import AutoModelForImageSegmentation
 
7
  from loadimg import load_img
8
+ import os
9
 
10
  # -------------------------
11
  # Model Setup
 
13
  MODEL_DIR = "models/BiRefNet"
14
  os.makedirs(MODEL_DIR, exist_ok=True)
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
+ print("Loading BiRefNet model...")
18
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
19
+ "ZhengPeng7/BiRefNet",
20
+ cache_dir=MODEL_DIR,
21
+ trust_remote_code=True,
22
+ revision="main"
23
+ )
24
+ birefnet.to(device).eval()
25
+ print("Model loaded successfully.")
26
 
27
  # -------------------------
28
  # Image Preprocessing
 
33
  mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
34
  std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
35
  arr = (arr - mean) / std
36
+ arr = np.transpose(arr, (2, 0, 1))
37
  tensor = torch.from_numpy(arr).unsqueeze(0).to(torch.float32).to(device)
38
  return tensor
39
 
 
49
  return image
50
 
51
  # -------------------------
52
+ # Gradio Function
53
  # -------------------------
54
+ def remove_background_gradio(input_img):
55
+ # Gradio passes PIL images directly
56
+ result = process_image(input_img.convert("RGB"))
57
+ return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  # -------------------------
60
+ # Gradio Interface
61
  # -------------------------
62
+ iface = gr.Interface(
63
+ fn=remove_background_gradio,
64
+ inputs=gr.Image(type="pil"),
65
+ outputs=gr.Image(type="pil"),
66
+ title="Background Removal Tool",
67
+ description="Upload an image and get a transparent background."
68
+ )
69
+
70
+ iface.launch(server_name="0.0.0.0", server_port=7860)