# Bg-Removal / app.py
# Hugging Face Space page-header residue (uploader: Amitshri,
# commit b074695 "Update app.py") — kept as a comment so the file parses.
# from fastapi import FastAPI, WebSocket, WebSocketDisconnect
# from pydantic import BaseModel
# from fastapi.middleware.cors import CORSMiddleware
# import uvicorn
# from pydantic import BaseModel
# import json
# import base64
# from io import BytesIO
# from rembg import remove
# from PIL import Image
# import numpy as np
# import scipy.ndimage as ndi
# from model import detect_from_base64_image
# def keep_largest_object_alpha(alpha_mask_np, threshold=0.1):
# binary_mask = (alpha_mask_np > threshold).astype(np.uint8)
# labeled, n = ndi.label(binary_mask)
# if n == 0:
# return alpha_mask_np
# sizes = ndi.sum(binary_mask, labeled, range(n + 1))
# largest_label = sizes.argmax()
# largest_mask = (labeled == largest_label).astype(np.float32)
# return largest_mask * alpha_mask_np
# def hex_to_rgb(hex_color: str):
# hex_color = hex_color.lstrip('#')
# return tuple(int(hex_color[i:i+2], 16) for i in (0, 2, 4))
# def remove_bg_base64(base64_image: str, bg_color_hex: str = "#FFFFFF") -> str:
# # Decode base64 image to PIL Image
# image_data = base64.b64decode(base64_image)
# input_image = Image.open(BytesIO(image_data)).convert("RGBA")
# # Remove background
# output = remove(input_image)
# # Convert to RGBA and extract alpha
# output = output.convert("RGBA")
# r, g, b, a = output.split()
# alpha_np = np.array(a).astype(np.float32) / 255.0
# # Keep only largest subject in the alpha mask
# filtered_alpha = keep_largest_object_alpha(alpha_np)
# # Convert filtered alpha to image
# filtered_alpha_img = Image.fromarray((filtered_alpha * 255).astype(np.uint8))
# output.putalpha(filtered_alpha_img)
# # Convert hex background color to RGB
# bg_rgb = hex_to_rgb(bg_color_hex)
# # Create colored background
# bg_image = Image.new("RGBA", output.size, bg_rgb + (255,))
# bg_image.paste(output, (0, 0), mask=output.split()[3])
# # Final image (no alpha)
# final_image = bg_image.convert("RGB")
# # Convert to base64 string
# buffered = BytesIO()
# final_image.save(buffered, format="JPEG")
# final_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
# return final_base64
# app = FastAPI()
# app.add_middleware(
# CORSMiddleware,
# allow_origins=["*"],
# allow_credentials=True,
# allow_methods=["*"],
# allow_headers=["*"],
# )
# class ImageData(BaseModel):
# image_data: str
# colour_code: str
# VERSION = "1.0.0"
# @app.get("/")
# async def health_check():
# return {"status": "healthy"}
# @app.post("/BgRemoval")
# async def bg_removal(request: ImageData):
# try:
# image_data = request.image_data
# colour_code = request.colour_code
# output_image = remove_bg_base64(image_data, colour_code)
# return {"output_image": output_image}
# except Exception as e:
# print(f"Error Occured: {str(e)}")
# return {"error": str(e)}
# @app.websocket("/ws-detect")
# async def websocket_detect(websocket: WebSocket):
# await websocket.accept()
# try:
# while True:
# try:
# # ๐Ÿ”ธ Receive JSON data
# payload = await websocket.receive_json()
# base64_input = payload.get("image_data")
# if not base64_input:
# await websocket.send_json({"error": "Missing 'image_data' in request"})
# continue
# # ๐Ÿ”ธ Call the detection function
# result = detect_from_base64_image(base64_input)
# # ๐Ÿ”ธ Send JSON result
# await websocket.send_json(result)
# except Exception as e:
# await websocket.send_json({"error": str(e)})
# except WebSocketDisconnect:
# print("WebSocket disconnected")
# except Exception as e:
# print(f"WebSocket closed: {e}")
# # await websocket.close()
# finally:
# try:
# await websocket.close()
# print("WebSocket connection is closed.")
# except RuntimeError as e:
# print(f"WebSocket was already closed: {e}")
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
from pydantic import BaseModel
import json
import base64
from io import BytesIO
from rembg import remove
from PIL import Image
import numpy as np
import scipy.ndimage as ndi
from model import detect_from_base64_image
def keep_largest_object_alpha(alpha_mask_np, threshold=0.1):
    """Suppress all but the largest connected component of an alpha mask.

    Pixels above ``threshold`` are treated as foreground; every connected
    region except the biggest one is zeroed out of the returned mask. The
    input mask is returned unchanged when no foreground pixel exists.
    """
    foreground = (alpha_mask_np > threshold).astype(np.uint8)
    labels, component_count = ndi.label(foreground)
    if component_count == 0:
        # Nothing above the threshold — nothing to filter.
        return alpha_mask_np
    # Per-label pixel counts. Label 0 (background) sums to zero because
    # `foreground` is zero there, so it can never win the argmax.
    component_sizes = ndi.sum(foreground, labels, range(component_count + 1))
    winner = component_sizes.argmax()
    keep = (labels == winner).astype(np.float32)
    return alpha_mask_np * keep
def hex_to_rgb(hex_color: str):
    """Convert a hex colour string to an ``(R, G, B)`` tuple of ints.

    Accepts the 6-digit form (``"#RRGGBB"``, leading ``#`` optional) as
    before, and additionally the 3-digit CSS shorthand (``"#FFF"`` ->
    ``(255, 255, 255)``).

    Raises:
        ValueError: if the string is not 3 or 6 hex digits.
    """
    hex_color = hex_color.lstrip('#')
    if len(hex_color) == 3:
        # Expand CSS shorthand: "abc" -> "aabbcc".
        hex_color = ''.join(ch * 2 for ch in hex_color)
    if len(hex_color) != 6:
        # Previously a bad length produced a confusing int() error or a
        # short tuple; fail explicitly instead.
        raise ValueError(f"Invalid hex colour: {hex_color!r}")
    return tuple(int(hex_color[i:i + 2], 16) for i in (0, 2, 4))
def crop_image(input_image: Image.Image, width: int, height: int) -> Image.Image:
    """Crop the centre of ``input_image`` to at most ``width`` x ``height``.

    The requested size is clamped to the image's own dimensions: PIL's
    ``crop`` silently pads with black pixels when the box extends past the
    image edges (which happens whenever the requested size exceeds the
    image), and padded output is never what a centre crop should produce.
    """
    img_width, img_height = input_image.size
    # Never request more than the image actually has — keeps the box valid.
    width = min(width, img_width)
    height = min(height, img_height)
    left = (img_width - width) // 2
    top = (img_height - height) // 2
    return input_image.crop((left, top, left + width, top + height))
def mm_to_pixels(mm: float, dpi: int = 300) -> int:
    """Convert a physical length in millimetres to whole pixels at ``dpi``.

    One inch is 25.4 mm, so pixels = (mm / 25.4) * dpi, rounded to the
    nearest integer.
    """
    inches = mm / 25.4
    return int(round(inches * dpi))
def remove_bg_base64(base64_image: str, bg_color_hex: str = "#FFFFFF") -> str:
    """Remove an image's background and flatten it onto a solid colour.

    The subject mask produced by rembg is cleaned so that only the largest
    connected foreground object survives, then the cut-out is composited
    onto a solid background of ``bg_color_hex`` and returned as a
    base64-encoded JPEG string.

    Args:
        base64_image: base64-encoded source image.
        bg_color_hex: background colour as a hex string, e.g. "#FFFFFF".
    """
    # Decode the incoming base64 payload into an RGBA image.
    image_data = base64.b64decode(base64_image)
    input_image = Image.open(BytesIO(image_data)).convert("RGBA")
    # rembg yields an RGBA cut-out with background pixels transparent.
    output = remove(input_image).convert("RGBA")
    # Normalise the alpha channel to [0, 1] and keep only the largest blob.
    # (The old code split all four channels but only used alpha.)
    alpha_np = np.array(output.getchannel("A")).astype(np.float32) / 255.0
    filtered_alpha = keep_largest_object_alpha(alpha_np)
    filtered_alpha_img = Image.fromarray((filtered_alpha * 255).astype(np.uint8))
    output.putalpha(filtered_alpha_img)
    # Composite the cut-out onto a solid background of the requested colour.
    bg_rgb = hex_to_rgb(bg_color_hex)
    bg_image = Image.new("RGBA", output.size, bg_rgb + (255,))
    # The paste mask is exactly the alpha we just installed — reuse it
    # instead of re-splitting the image.
    bg_image.paste(output, (0, 0), mask=filtered_alpha_img)
    # JPEG has no alpha channel, so flatten to RGB before encoding.
    final_image = bg_image.convert("RGB")
    buffered = BytesIO()
    final_image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
app = FastAPI()
# Allow browser clients from any origin to call this API — this is a
# public demo endpoint, so CORS is fully open for methods and headers.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True
# is rejected by browsers for credentialed requests — confirm whether
# credentials are actually needed here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# class ImageData(BaseModel):
# image_data: str
# colour_code: str
class ImageData(BaseModel):
    """Request body for the /BgRemoval endpoint."""

    image_data: str   # base64-encoded source image
    colour_code: str  # background colour as a hex string, e.g. "#FFFFFF"
    width: float      # target crop width in millimetres (converted at 300 DPI)
    height: float     # target crop height in millimetres (converted at 300 DPI)
# API version string — appears unused within this file; presumably kept
# for future use or external tooling (TODO confirm).
VERSION = "1.0.0"
@app.get("/")
async def health_check():
    """Liveness probe: always reports the service as healthy."""
    status_payload = {"status": "healthy"}
    return status_payload
# @app.post("/BgRemoval")
# async def bg_removal(request: ImageData):
# try:
# image_data = request.image_data
# colour_code = request.colour_code
# output_image = remove_bg_base64(image_data, colour_code)
# return {"output_image": output_image}
# except Exception as e:
# print(f"Error Occured: {str(e)}")
# return {"error": str(e)}
@app.post("/BgRemoval")
async def bg_removal(request: ImageData):
    """Centre-crop the image to the requested physical size, then remove its background.

    The request carries a base64 image, a hex background colour, and the
    target width/height in millimetres (converted to pixels at 300 DPI).
    Returns {"output_image": <base64 JPEG>} on success or {"error": <msg>}
    on failure — the HTTP status stays 200 either way, matching the
    existing client contract.
    """
    try:
        # Reject nonsensical crop sizes up front; the ValueError is turned
        # into the endpoint's standard {"error": ...} response below.
        if request.width <= 0 or request.height <= 0:
            raise ValueError("width and height must be positive millimetre values")
        width_px = mm_to_pixels(request.width)
        height_px = mm_to_pixels(request.height)
        # Decode the base64 payload into a PIL image. Keep the decoded
        # bytes under a distinct name — the old code shadowed the base64
        # string with the raw bytes.
        image_bytes = base64.b64decode(request.image_data)
        input_image = Image.open(BytesIO(image_bytes)).convert("RGBA")
        # Centre-crop to the requested pixel dimensions.
        cropped_image = crop_image(input_image, width_px, height_px)
        # Re-encode as base64 PNG because remove_bg_base64 takes base64 input.
        buffered = BytesIO()
        cropped_image.save(buffered, format="PNG")
        cropped_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
        # Remove background from the cropped image and composite the colour.
        output_image = remove_bg_base64(cropped_base64, request.colour_code)
        return {"output_image": output_image}
    except Exception as e:
        # Preserve the endpoint's soft-failure contract: log and return
        # the error text instead of raising a 500.
        print(f"Error Occurred: {str(e)}")
        return {"error": str(e)}
@app.websocket("/ws-detect")
async def websocket_detect(websocket: WebSocket):
    """Streaming detection endpoint over a WebSocket.

    Accepts the connection, then loops forever: each incoming JSON message
    must carry a base64 image under "image_data"; the result of
    detect_from_base64_image is sent back as JSON. Per-message failures are
    reported to the client without dropping the session; a disconnect or
    fatal error ends the loop, and the socket is closed in ``finally``.
    """
    await websocket.accept()
    try:
        while True:
            try:
                # Receive one JSON request from the client.
                payload = await websocket.receive_json()
                base64_input = payload.get("image_data")
                if not base64_input:
                    await websocket.send_json({"error": "Missing 'image_data' in request"})
                    continue
                # Run detection on the supplied base64 image.
                result = detect_from_base64_image(base64_input)
                # Send the JSON result back on the same socket.
                await websocket.send_json(result)
            except Exception as e:
                # Per-message failure: report it and keep the session alive.
                # NOTE(review): this broad catch presumably also catches
                # WebSocketDisconnect raised by receive_json, in which case
                # send_json on the dead socket re-raises and exits the loop
                # via the outer handlers — confirm this is the intended path.
                await websocket.send_json({"error": str(e)})
    except WebSocketDisconnect:
        print("WebSocket disconnected")
    except Exception as e:
        print(f"WebSocket closed: {e}")
        # await websocket.close()
    finally:
        try:
            await websocket.close()
            print("WebSocket connection is closed.")
        except RuntimeError as e:
            # close() on an already-closed socket raises RuntimeError;
            # treat that as a no-op.
            print(f"WebSocket was already closed: {e}")