Spaces:
Runtime error
Runtime error
| # from fastapi import FastAPI, WebSocket, WebSocketDisconnect | |
| # from pydantic import BaseModel | |
| # from fastapi.middleware.cors import CORSMiddleware | |
| # import uvicorn | |
| # from pydantic import BaseModel | |
| # import json | |
| # import base64 | |
| # from io import BytesIO | |
| # from rembg import remove | |
| # from PIL import Image | |
| # import numpy as np | |
| # import scipy.ndimage as ndi | |
| # from model import detect_from_base64_image | |
| # def keep_largest_object_alpha(alpha_mask_np, threshold=0.1): | |
| # binary_mask = (alpha_mask_np > threshold).astype(np.uint8) | |
| # labeled, n = ndi.label(binary_mask) | |
| # if n == 0: | |
| # return alpha_mask_np | |
| # sizes = ndi.sum(binary_mask, labeled, range(n + 1)) | |
| # largest_label = sizes.argmax() | |
| # largest_mask = (labeled == largest_label).astype(np.float32) | |
| # return largest_mask * alpha_mask_np | |
| # def hex_to_rgb(hex_color: str): | |
| # hex_color = hex_color.lstrip('#') | |
| # return tuple(int(hex_color[i:i+2], 16) for i in (0, 2, 4)) | |
| # def remove_bg_base64(base64_image: str, bg_color_hex: str = "#FFFFFF") -> str: | |
| # # Decode base64 image to PIL Image | |
| # image_data = base64.b64decode(base64_image) | |
| # input_image = Image.open(BytesIO(image_data)).convert("RGBA") | |
| # # Remove background | |
| # output = remove(input_image) | |
| # # Convert to RGBA and extract alpha | |
| # output = output.convert("RGBA") | |
| # r, g, b, a = output.split() | |
| # alpha_np = np.array(a).astype(np.float32) / 255.0 | |
| # # Keep only largest subject in the alpha mask | |
| # filtered_alpha = keep_largest_object_alpha(alpha_np) | |
| # # Convert filtered alpha to image | |
| # filtered_alpha_img = Image.fromarray((filtered_alpha * 255).astype(np.uint8)) | |
| # output.putalpha(filtered_alpha_img) | |
| # # Convert hex background color to RGB | |
| # bg_rgb = hex_to_rgb(bg_color_hex) | |
| # # Create colored background | |
| # bg_image = Image.new("RGBA", output.size, bg_rgb + (255,)) | |
| # bg_image.paste(output, (0, 0), mask=output.split()[3]) | |
| # # Final image (no alpha) | |
| # final_image = bg_image.convert("RGB") | |
| # # Convert to base64 string | |
| # buffered = BytesIO() | |
| # final_image.save(buffered, format="JPEG") | |
| # final_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8") | |
| # return final_base64 | |
| # app = FastAPI() | |
| # app.add_middleware( | |
| # CORSMiddleware, | |
| # allow_origins=["*"], | |
| # allow_credentials=True, | |
| # allow_methods=["*"], | |
| # allow_headers=["*"], | |
| # ) | |
| # class ImageData(BaseModel): | |
| # image_data: str | |
| # colour_code: str | |
| # VERSION = "1.0.0" | |
| # @app.get("/") | |
| # async def health_check(): | |
| # return {"status": "healthy"} | |
| # @app.post("/BgRemoval") | |
| # async def bg_removal(request: ImageData): | |
| # try: | |
| # image_data = request.image_data | |
| # colour_code = request.colour_code | |
| # output_image = remove_bg_base64(image_data, colour_code) | |
| # return {"output_image": output_image} | |
| # except Exception as e: | |
| # print(f"Error Occured: {str(e)}") | |
| # return {"error": str(e)} | |
| # @app.websocket("/ws-detect") | |
| # async def websocket_detect(websocket: WebSocket): | |
| # await websocket.accept() | |
| # try: | |
| # while True: | |
| # try: | |
| # # ๐ธ Receive JSON data | |
| # payload = await websocket.receive_json() | |
| # base64_input = payload.get("image_data") | |
| # if not base64_input: | |
| # await websocket.send_json({"error": "Missing 'image_data' in request"}) | |
| # continue | |
| # # ๐ธ Call the detection function | |
| # result = detect_from_base64_image(base64_input) | |
| # # ๐ธ Send JSON result | |
| # await websocket.send_json(result) | |
| # except Exception as e: | |
| # await websocket.send_json({"error": str(e)}) | |
| # except WebSocketDisconnect: | |
| # print("WebSocket disconnected") | |
| # except Exception as e: | |
| # print(f"WebSocket closed: {e}") | |
| # # await websocket.close() | |
| # finally: | |
| # try: | |
| # await websocket.close() | |
| # print("WebSocket connection is closed.") | |
| # except RuntimeError as e: | |
| # print(f"WebSocket was already closed: {e}") | |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect | |
| from pydantic import BaseModel | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import uvicorn | |
| from pydantic import BaseModel | |
| import json | |
| import base64 | |
| from io import BytesIO | |
| from rembg import remove | |
| from PIL import Image | |
| import numpy as np | |
| import scipy.ndimage as ndi | |
| from model import detect_from_base64_image | |
def keep_largest_object_alpha(alpha_mask_np, threshold=0.1):
    """Zero every part of an alpha mask except its largest connected region.

    Pixels whose alpha is strictly greater than ``threshold`` count as
    foreground. They are grouped into connected regions with
    ``scipy.ndimage.label`` using its default connectivity. The region with
    the most pixels survives with its original (soft) alpha values. Everything
    else is multiplied by zero, including faint pixels at or below
    ``threshold``.

    Args:
        alpha_mask_np: Alpha values as a NumPy array (scaled to 0..1 by the
            caller in this module).
        threshold: Foreground cut-off applied before labelling.

    Returns:
        The input array object itself when no pixel exceeds ``threshold``.
        Otherwise a new array equal to ``alpha_mask_np`` inside the largest
        region and 0 elsewhere.
    """
    foreground = (alpha_mask_np > threshold).astype(np.uint8)
    labels, region_count = ndi.label(foreground)
    if not region_count:
        return alpha_mask_np

    # Index 0 is the background label. Its foreground sum is always 0, so
    # argmax lands on a real region (ties resolve to the lowest label).
    region_sizes = ndi.sum(foreground, labels, range(region_count + 1))
    winner = region_sizes.argmax()
    keep = (labels == winner).astype(np.float32)
    return keep * alpha_mask_np
def hex_to_rgb(hex_color: str):
    """Parse a hex colour string into an ``(r, g, b)`` tuple of ints in 0..255.

    Input handling:

    * A leading ``#`` and surrounding whitespace are accepted.
    * Six-digit ``RRGGBB`` is supported, and so is CSS three-digit shorthand
      ``RGB``, where each digit is doubled (``"#1a2"`` == ``"#11aa22"``).
    * Characters after the first six digits (e.g. an ``AA`` alpha suffix) are
      ignored, as before.

    Raises:
        ValueError: if fewer than six hex digits remain after shorthand
            expansion, or a digit pair is not valid hexadecimal.
    """
    digits = hex_color.strip().lstrip('#')
    if len(digits) == 3:
        # CSS shorthand: every nibble stands for a doubled pair.
        digits = "".join(ch * 2 for ch in digits)
    if len(digits) < 6:
        # Previously a 5-digit value silently produced a wrong colour and a
        # 3/4-digit value failed with an opaque int('') error.
        raise ValueError(f"Invalid hex colour: {hex_color!r}")
    return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))
def crop_image(input_image: Image.Image, width: int, height: int) -> Image.Image:
    """Cut a ``width`` x ``height`` window out of the centre of ``input_image``.

    The offsets use floor division. When the size difference is odd, the
    extra pixel of margin ends up on the right/bottom side.

    No clamping is done: a window larger than the image reaches past its
    bounds. How that area is filled is left to ``Image.crop``.
    """
    src_width, src_height = input_image.size
    x0 = (src_width - width) // 2
    y0 = (src_height - height) // 2
    box = (x0, y0, x0 + width, y0 + height)
    return input_image.crop(box)
def mm_to_pixels(mm: float, dpi: int = 300) -> int:
    """Turn a physical length in millimetres into a pixel count at ``dpi``.

    One inch is 25.4 mm, so the exact value is ``mm * dpi / 25.4``. It is
    rounded with Python's built-in ``round`` (ties go to the even integer).
    """
    exact_pixels = (mm * dpi) / 25.4
    return int(round(exact_pixels))
def remove_bg_base64(base64_image: str, bg_color_hex: str = "#FFFFFF") -> str:
    """Remove an image's background and flatten the subject onto a solid colour.

    Pipeline:

    1. Decode the base64 image.
    2. Run rembg ``remove``.
    3. Keep only the largest connected region of the resulting alpha mask
       (``keep_largest_object_alpha``).
    4. Composite onto a ``bg_color_hex`` canvas.
    5. Encode as JPEG and return it as base64.

    Args:
        base64_image: Base64-encoded bytes of any image Pillow can open.
        bg_color_hex: Background colour as a hex string, parsed by
            ``hex_to_rgb`` (e.g. ``"#FFFFFF"``).

    Returns:
        The flattened RGB result as a base64-encoded JPEG, decoded to ``str``.

    Raises:
        Exceptions from base64 decoding, Pillow, rembg or ``hex_to_rgb``
        propagate to the caller unchanged.
    """
    # Context manager closes the decoder; convert() has already loaded pixels.
    with Image.open(BytesIO(base64.b64decode(base64_image))) as source:
        input_image = source.convert("RGBA")

    output = remove(input_image).convert("RGBA")

    # Filter stray fragments out of the matte: keep only the largest subject.
    alpha_np = np.asarray(output.getchannel("A")).astype(np.float32) / 255.0
    filtered_alpha = keep_largest_object_alpha(alpha_np)
    alpha_mask = Image.fromarray((filtered_alpha * 255).astype(np.uint8))
    output.putalpha(alpha_mask)

    bg_image = Image.new("RGBA", output.size, hex_to_rgb(bg_color_hex) + (255,))
    # alpha_mask is exactly the alpha channel just written, so reuse it rather
    # than splitting the image into channels a second time.
    bg_image.paste(output, (0, 0), mask=alpha_mask)

    buffered = BytesIO()
    bg_image.convert("RGB").save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
app = FastAPI()

# CORS policy: any origin, any HTTP method, any request header.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# maximally permissive; confirm this is intended before adding any
# cookie- or auth-based endpoint.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
| # class ImageData(BaseModel): | |
| # image_data: str | |
| # colour_code: str | |
class ImageData(BaseModel):
    """Request body for the background-removal endpoint."""

    # Source image as a base64 string (decoded with base64.b64decode by the handler).
    image_data: str
    # Background colour as a hex string such as "#FFFFFF" (parsed by hex_to_rgb).
    colour_code: str
    # Target crop width in millimetres; converted to pixels via mm_to_pixels (300 DPI default).
    width: float
    # Target crop height in millimetres; converted to pixels via mm_to_pixels (300 DPI default).
    height: float
# Application version string.
# NOTE(review): not referenced by any code visible in this module.
VERSION = "1.0.0"


@app.get("/")
async def health_check():
    """Liveness endpoint.

    Without the route decorator the function was never registered, so ``GET /``
    was not served.

    Returns:
        Always ``{"status": "healthy"}``.
    """
    return {"status": "healthy"}
| # @app.post("/BgRemoval") | |
| # async def bg_removal(request: ImageData): | |
| # try: | |
| # image_data = request.image_data | |
| # colour_code = request.colour_code | |
| # output_image = remove_bg_base64(image_data, colour_code) | |
| # return {"output_image": output_image} | |
| # except Exception as e: | |
| # print(f"Error Occured: {str(e)}") | |
| # return {"error": str(e)} | |
def _crop_center_to_png_base64(image_b64: str, width_px: int, height_px: int) -> str:
    """Decode a base64 image, centre-crop it to pixel size, and re-encode it as base64 PNG."""
    with Image.open(BytesIO(base64.b64decode(image_b64))) as source:
        input_image = source.convert("RGBA")
    cropped_image = crop_image(input_image, width_px, height_px)
    # PNG is lossless and keeps alpha, so nothing is lost before the
    # background-removal step decodes this again.
    buffered = BytesIO()
    cropped_image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def _run_bg_removal(request: ImageData) -> str:
    """Synchronous crop + background-removal pipeline for one request; returns base64 JPEG."""
    width_px = mm_to_pixels(request.width)
    height_px = mm_to_pixels(request.height)
    cropped_base64 = _crop_center_to_png_base64(request.image_data, width_px, height_px)
    return remove_bg_base64(cropped_base64, request.colour_code)


@app.post("/BgRemoval")
async def bg_removal(request: ImageData):
    """Centre-crop the uploaded image to the requested size and replace its background.

    ``request.width`` / ``request.height`` are millimetres, converted to pixels
    by ``mm_to_pixels`` (300 DPI default).

    Returns:
        ``{"output_image": <base64 JPEG>}`` on success.
        ``{"error": "<message>"}`` on any failure; errors are returned as the
        response body rather than raised, keeping the existing client
        contract.

    The decode/crop/rembg work is synchronous. Running it directly inside this
    ``async`` function would block the event loop, and with it every other
    request and WebSocket connection, so it runs in FastAPI's threadpool.
    """
    # Local import keeps this block self-contained; FastAPI is already a dependency.
    from fastapi.concurrency import run_in_threadpool

    try:
        output_image = await run_in_threadpool(_run_bg_removal, request)
        return {"output_image": output_image}
    except Exception as e:
        print(f"Error Occurred: {str(e)}")
        return {"error": str(e)}
@app.websocket("/ws-detect")
async def websocket_detect(websocket: WebSocket):
    """Serve detection requests over a WebSocket until the connection ends.

    Each incoming message is read as JSON and must contain ``"image_data"``.
    That value is passed to ``detect_from_base64_image`` and the result is
    sent back as JSON.

    Error replies keep the connection open:

    * A message without ``image_data`` gets
      ``{"error": "Missing 'image_data' in request"}``.
    * Any other per-message failure is answered with
      ``{"error": "<message>"}``.

    When the client disconnects, the loop ends and the socket is closed.
    """
    await websocket.accept()
    try:
        while True:
            try:
                payload = await websocket.receive_json()
                base64_input = payload.get("image_data")
                if not base64_input:
                    await websocket.send_json({"error": "Missing 'image_data' in request"})
                    continue
                # NOTE(review): this runs synchronously on the event loop; if it is
                # CPU-heavy it stalls every other connection. Consider offloading it
                # to a thread once detect_from_base64_image is confirmed thread-safe.
                result = detect_from_base64_image(base64_input)
                await websocket.send_json(result)
            except WebSocketDisconnect:
                # WebSocketDisconnect subclasses Exception. Without this clause the
                # handler below would try to send an error frame on a closed socket
                # instead of ending the session.
                raise
            except Exception as e:
                await websocket.send_json({"error": str(e)})
    except WebSocketDisconnect:
        print("WebSocket disconnected")
    except Exception as e:
        print(f"WebSocket closed: {e}")
    finally:
        try:
            await websocket.close()
            print("WebSocket connection is closed.")
        except (RuntimeError, WebSocketDisconnect) as e:
            # Closing after the peer has already gone away can fail; that is expected here.
            print(f"WebSocket was already closed: {e}")