Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -23,11 +23,13 @@ pillow_heif.register_heif_opener()
|
|
| 23 |
executor = ThreadPoolExecutor(max_workers=os.cpu_count() or 4)
|
| 24 |
|
| 25 |
# -------------------------
|
| 26 |
-
# Model Setup (
|
| 27 |
# -------------------------
|
| 28 |
MODEL_DIR = "models/BiRefNet"
|
| 29 |
os.makedirs(MODEL_DIR, exist_ok=True)
|
|
|
|
| 30 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 31 |
|
| 32 |
print("Loading BiRefNet model (first run may take a while)...")
|
| 33 |
birefnet = AutoModelForImageSegmentation.from_pretrained(
|
|
@@ -37,13 +39,15 @@ birefnet = AutoModelForImageSegmentation.from_pretrained(
|
|
| 37 |
)
|
| 38 |
birefnet.to(device)
|
| 39 |
birefnet.eval()
|
| 40 |
-
print("Model loaded successfully.")
|
| 41 |
|
| 42 |
# -------------------------
|
| 43 |
# Image Preprocessing
|
| 44 |
# -------------------------
|
|
|
|
|
|
|
| 45 |
def transform_image(image: Image.Image) -> torch.Tensor:
|
| 46 |
-
image = image.resize(
|
| 47 |
arr = np.array(image).astype(np.float32) / 255.0
|
| 48 |
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
|
| 49 |
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
|
|
@@ -53,16 +57,22 @@ def transform_image(image: Image.Image) -> torch.Tensor:
|
|
| 53 |
return tensor
|
| 54 |
|
| 55 |
def process_image_sync(image: Image.Image) -> BytesIO:
|
| 56 |
-
"""Process image synchronously and return PNG bytes (
|
| 57 |
image_size = image.size
|
| 58 |
input_tensor = transform_image(image)
|
|
|
|
| 59 |
with torch.no_grad():
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
pred = preds[0, 0].numpy()
|
| 63 |
mask = Image.fromarray((pred * 255).astype(np.uint8)).resize(image_size)
|
| 64 |
|
| 65 |
-
# Apply alpha mask and keep only in-memory result
|
| 66 |
image = image.copy()
|
| 67 |
image.putalpha(mask)
|
| 68 |
|
|
@@ -85,24 +95,24 @@ def open_image_safely(file_bytes: bytes) -> Image.Image:
|
|
| 85 |
img = Image.open(BytesIO(file_bytes))
|
| 86 |
fmt = (img.format or "").lower()
|
| 87 |
|
| 88 |
-
# Handle PDF: first page
|
| 89 |
if fmt == "pdf":
|
| 90 |
from pdf2image import convert_from_bytes
|
| 91 |
pdf_images = convert_from_bytes(file_bytes, first_page=1, last_page=1)
|
| 92 |
return pdf_images[0].convert("RGB")
|
| 93 |
|
| 94 |
-
# Handle
|
| 95 |
if fmt == "gif" and getattr(img, "is_animated", False):
|
| 96 |
img.seek(0)
|
| 97 |
return img.convert("RGB")
|
| 98 |
|
| 99 |
-
# Handle SVG
|
| 100 |
if fmt == "svg":
|
| 101 |
import cairosvg
|
| 102 |
png_bytes = cairosvg.svg2png(bytestring=file_bytes)
|
| 103 |
return Image.open(BytesIO(png_bytes)).convert("RGB")
|
| 104 |
|
| 105 |
-
# Other formats (HEIC, HEIF, JPG, PNG
|
| 106 |
return img.convert("RGB")
|
| 107 |
|
| 108 |
except Exception as e:
|
|
@@ -118,15 +128,11 @@ app = FastAPI(title="Background Removal API", description="Removes image backgro
|
|
| 118 |
# -------------------------
|
| 119 |
@app.post("/remove_bg_file")
|
| 120 |
async def remove_bg_file(file: UploadFile = File(...)):
|
| 121 |
-
"""Upload an image and get transparent PNG."""
|
| 122 |
try:
|
| 123 |
contents = await file.read()
|
| 124 |
image = open_image_safely(contents)
|
| 125 |
output_buffer = await process_image_async(image)
|
| 126 |
-
|
| 127 |
-
# Return directly from memory
|
| 128 |
return StreamingResponse(output_buffer, media_type="image/png")
|
| 129 |
-
|
| 130 |
except HTTPException as e:
|
| 131 |
raise e
|
| 132 |
except Exception as e:
|
|
@@ -134,7 +140,6 @@ async def remove_bg_file(file: UploadFile = File(...)):
|
|
| 134 |
|
| 135 |
@app.post("/remove_bg_url")
|
| 136 |
async def remove_bg_url(image_url: str = Form(...)):
|
| 137 |
-
"""Provide image URL and get transparent PNG."""
|
| 138 |
try:
|
| 139 |
image = load_img(image_url, output_type="pil").convert("RGB")
|
| 140 |
output_buffer = await process_image_async(image)
|
|
@@ -143,7 +148,7 @@ async def remove_bg_url(image_url: str = Form(...)):
|
|
| 143 |
raise HTTPException(status_code=500, detail=f"Error processing URL: {e}")
|
| 144 |
|
| 145 |
# -------------------------
|
| 146 |
-
# Web Interface
|
| 147 |
# -------------------------
|
| 148 |
@app.get("/", response_class=HTMLResponse)
|
| 149 |
async def index():
|
|
@@ -211,7 +216,7 @@ async def index():
|
|
| 211 |
fileForm.addEventListener('submit', async (e) => {
|
| 212 |
e.preventDefault();
|
| 213 |
const fileInput = document.getElementById('fileInput');
|
| 214 |
-
if (fileInput.files.length
|
| 215 |
const file = fileInput.files[0];
|
| 216 |
beforeImg.src = URL.createObjectURL(file);
|
| 217 |
const formData = new FormData();
|
|
@@ -249,16 +254,8 @@ async def index():
|
|
| 249 |
return HTMLResponse(content=html)
|
| 250 |
|
| 251 |
# -------------------------
|
| 252 |
-
# Run Server
|
| 253 |
# -------------------------
|
| 254 |
if __name__ == "__main__":
|
| 255 |
-
import sys
|
| 256 |
-
import os
|
| 257 |
-
import uvicorn
|
| 258 |
-
|
| 259 |
-
# Get current filename without .py
|
| 260 |
module_name = os.path.splitext(os.path.basename(__file__))[0]
|
| 261 |
-
|
| 262 |
-
# Run uvicorn using the detected module name
|
| 263 |
uvicorn.run(f"{module_name}:app", host="0.0.0.0", port=7860, workers=2)
|
| 264 |
-
|
|
|
|
| 23 |
executor = ThreadPoolExecutor(max_workers=os.cpu_count() or 4)
|
| 24 |
|
| 25 |
# -------------------------
|
| 26 |
+
# Model Setup (Load Once)
|
| 27 |
# -------------------------
|
| 28 |
MODEL_DIR = "models/BiRefNet"
|
| 29 |
os.makedirs(MODEL_DIR, exist_ok=True)
|
| 30 |
+
|
| 31 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 32 |
+
print(f"Using device: {device}")
|
| 33 |
|
| 34 |
print("Loading BiRefNet model (first run may take a while)...")
|
| 35 |
birefnet = AutoModelForImageSegmentation.from_pretrained(
|
|
|
|
| 39 |
)
|
| 40 |
birefnet.to(device)
|
| 41 |
birefnet.eval()
|
| 42 |
+
print(f"Model loaded successfully on {device}.")
|
| 43 |
|
| 44 |
# -------------------------
|
| 45 |
# Image Preprocessing
|
| 46 |
# -------------------------
|
| 47 |
+
TARGET_SIZE = (512, 512) # Lower resolution for faster inference
|
| 48 |
+
|
| 49 |
def transform_image(image: Image.Image) -> torch.Tensor:
|
| 50 |
+
image = image.resize(TARGET_SIZE)
|
| 51 |
arr = np.array(image).astype(np.float32) / 255.0
|
| 52 |
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
|
| 53 |
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
|
|
|
|
| 57 |
return tensor
|
| 58 |
|
| 59 |
def process_image_sync(image: Image.Image) -> BytesIO:
|
| 60 |
+
"""Process image synchronously and return PNG bytes (in-memory)."""
|
| 61 |
image_size = image.size
|
| 62 |
input_tensor = transform_image(image)
|
| 63 |
+
|
| 64 |
with torch.no_grad():
|
| 65 |
+
if device == "cuda":
|
| 66 |
+
# Mixed precision for GPU
|
| 67 |
+
with torch.cuda.amp.autocast():
|
| 68 |
+
preds = birefnet(input_tensor)[-1].sigmoid().cpu()
|
| 69 |
+
else:
|
| 70 |
+
# CPU fallback
|
| 71 |
+
preds = birefnet(input_tensor)[-1].sigmoid().cpu()
|
| 72 |
|
| 73 |
pred = preds[0, 0].numpy()
|
| 74 |
mask = Image.fromarray((pred * 255).astype(np.uint8)).resize(image_size)
|
| 75 |
|
|
|
|
| 76 |
image = image.copy()
|
| 77 |
image.putalpha(mask)
|
| 78 |
|
|
|
|
| 95 |
img = Image.open(BytesIO(file_bytes))
|
| 96 |
fmt = (img.format or "").lower()
|
| 97 |
|
| 98 |
+
# Handle PDF: first page
|
| 99 |
if fmt == "pdf":
|
| 100 |
from pdf2image import convert_from_bytes
|
| 101 |
pdf_images = convert_from_bytes(file_bytes, first_page=1, last_page=1)
|
| 102 |
return pdf_images[0].convert("RGB")
|
| 103 |
|
| 104 |
+
# Handle GIF: first frame
|
| 105 |
if fmt == "gif" and getattr(img, "is_animated", False):
|
| 106 |
img.seek(0)
|
| 107 |
return img.convert("RGB")
|
| 108 |
|
| 109 |
+
# Handle SVG
|
| 110 |
if fmt == "svg":
|
| 111 |
import cairosvg
|
| 112 |
png_bytes = cairosvg.svg2png(bytestring=file_bytes)
|
| 113 |
return Image.open(BytesIO(png_bytes)).convert("RGB")
|
| 114 |
|
| 115 |
+
# Other formats (HEIC, HEIF, JPG, PNG)
|
| 116 |
return img.convert("RGB")
|
| 117 |
|
| 118 |
except Exception as e:
|
|
|
|
| 128 |
# -------------------------
|
| 129 |
@app.post("/remove_bg_file")
|
| 130 |
async def remove_bg_file(file: UploadFile = File(...)):
|
|
|
|
| 131 |
try:
|
| 132 |
contents = await file.read()
|
| 133 |
image = open_image_safely(contents)
|
| 134 |
output_buffer = await process_image_async(image)
|
|
|
|
|
|
|
| 135 |
return StreamingResponse(output_buffer, media_type="image/png")
|
|
|
|
| 136 |
except HTTPException as e:
|
| 137 |
raise e
|
| 138 |
except Exception as e:
|
|
|
|
| 140 |
|
| 141 |
@app.post("/remove_bg_url")
|
| 142 |
async def remove_bg_url(image_url: str = Form(...)):
|
|
|
|
| 143 |
try:
|
| 144 |
image = load_img(image_url, output_type="pil").convert("RGB")
|
| 145 |
output_buffer = await process_image_async(image)
|
|
|
|
| 148 |
raise HTTPException(status_code=500, detail=f"Error processing URL: {e}")
|
| 149 |
|
| 150 |
# -------------------------
|
| 151 |
+
# Web Interface
|
| 152 |
# -------------------------
|
| 153 |
@app.get("/", response_class=HTMLResponse)
|
| 154 |
async def index():
|
|
|
|
| 216 |
fileForm.addEventListener('submit', async (e) => {
|
| 217 |
e.preventDefault();
|
| 218 |
const fileInput = document.getElementById('fileInput');
|
| 219 |
+
if (!fileInput.files.length) return alert("Select a file!");
|
| 220 |
const file = fileInput.files[0];
|
| 221 |
beforeImg.src = URL.createObjectURL(file);
|
| 222 |
const formData = new FormData();
|
|
|
|
| 254 |
return HTMLResponse(content=html)
|
| 255 |
|
| 256 |
# -------------------------
|
| 257 |
+
# Run Server (Auto-detect filename)
|
| 258 |
# -------------------------
|
| 259 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
module_name = os.path.splitext(os.path.basename(__file__))[0]
|
|
|
|
|
|
|
| 261 |
uvicorn.run(f"{module_name}:app", host="0.0.0.0", port=7860, workers=2)
|
|
|