from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import numpy as np
import cv2
from PIL import Image
import io
import base64
from typing import List, Optional
import torch
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
import uvicorn

app = FastAPI(title="Wall Color Visualizer API")

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
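
# Note: allow_origins=["*"] together with allow_credentials=True is convenient for
# local development, but a public deployment would normally restrict allow_origins
# to the frontend's actual origin.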

# Global variables for SAM model
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
device = "cuda" if torch.cuda.is_available() else "cpu"
sam = None
mask_generator = None
predictor = None

# Request models
class SegmentRequest(BaseModel):
    image_base64: str
    point_x: Optional[float] = None
    point_y: Optional[float] = None


class ColorChangeRequest(BaseModel):
    image_base64: str
    mask_base64: str
    color_hex: str
    opacity: float = 0.8
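
# Illustrative ColorChangeRequest body (field values are placeholders, not from the
# original source):
#   {
#     "image_base64": "<PNG/JPEG bytes, base64-encoded>",
#     "mask_base64": "<binary mask PNG, base64-encoded>",
#     "color_hex": "#3A7BD5",
#     "opacity": 0.8
#   }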

# Initialize SAM model
def initialize_sam():
    global sam, mask_generator, predictor
    try:
        print(f"Loading SAM model on {device}...")
        sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        sam.to(device=device)
        mask_generator = SamAutomaticMaskGenerator(sam)
        predictor = SamPredictor(sam)
        print("SAM model loaded successfully!")
    except Exception as e:
        print(f"Warning: Could not load SAM model: {e}")
        print("The API will run but segmentation features will be limited.")

# The FastAPI route decorators below were not preserved in the source listing; they
# are restored here, with endpoint paths inferred from the handler names.
@app.on_event("startup")
async def startup_event():
    initialize_sam()

@app.get("/")
async def root():
    return {
        "message": "Wall Color Visualizer API",
        "status": "running",
        "sam_loaded": sam is not None
    }

@app.get("/health")
async def health_check():
    return {
        "status": "healthy",
        "device": device,
        "sam_model_loaded": sam is not None
    }

def decode_base64_image(base64_string: str) -> np.ndarray:
    """Decode base64 string to numpy array image"""
    try:
        # Remove data URL prefix if present
        if "base64," in base64_string:
            base64_string = base64_string.split("base64,")[1]
        img_data = base64.b64decode(base64_string)
        img = Image.open(io.BytesIO(img_data))
        img_array = np.array(img.convert("RGB"))
        return img_array
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid image data: {str(e)}")

def encode_image_to_base64(image: np.ndarray) -> str:
    """Encode numpy array image to base64 string"""
    img = Image.fromarray(image.astype(np.uint8))
    buffered = io.BytesIO()
    img.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return img_str

def encode_mask_to_base64(mask: np.ndarray) -> str:
    """Encode binary mask to base64 string"""
    mask_uint8 = (mask * 255).astype(np.uint8)
    img = Image.fromarray(mask_uint8)
    buffered = io.BytesIO()
    img.save(buffered, format="PNG")
    mask_str = base64.b64encode(buffered.getvalue()).decode()
    return mask_str

def hex_to_rgb(hex_color: str) -> tuple:
    """Convert hex color to RGB tuple"""
    hex_color = hex_color.lstrip('#')
    return tuple(int(hex_color[i:i+2], 16) for i in (0, 2, 4))
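
# Example (not from the original source): hex_to_rgb("#3A7BD5") returns (58, 123, 213).
# Note that this helper expects a 6-digit hex string; shorthand forms such as "#FFF"
# are not handled.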

@app.post("/segment/automatic")  # path assumed
async def segment_automatic(file: UploadFile = File(...)):
    """Automatically segment all objects in the image"""
    if sam is None:
        raise HTTPException(status_code=503, detail="SAM model not loaded")
    try:
        # Read and decode image
        contents = await file.read()
        image = Image.open(io.BytesIO(contents))
        image_np = np.array(image.convert("RGB"))
        # Generate masks
        masks = mask_generator.generate(image_np)
        # Sort masks by area (largest first)
        masks = sorted(masks, key=lambda x: x['area'], reverse=True)
        # Return only the largest masks
        result_masks = []
        for i, mask_data in enumerate(masks[:2]):  # keep the two largest masks
            mask = mask_data['segmentation']
            result_masks.append({
                "id": i,
                "mask_base64": encode_mask_to_base64(mask),
                "area": int(mask_data['area']),
                "bbox": [int(x) for x in mask_data['bbox']]
            })
        return {
            "success": True,
            "num_masks": len(result_masks),
            "masks": result_masks,
            "image_base64": encode_image_to_base64(image_np)
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Segmentation failed: {str(e)}")

@app.post("/segment/point")  # path assumed
async def segment_point(request: SegmentRequest):
    """Segment object at a specific point in the image"""
    if sam is None:
        raise HTTPException(status_code=503, detail="SAM model not loaded")
    try:
        # Decode image
        image_np = decode_base64_image(request.image_base64)
        # Set image for predictor
        predictor.set_image(image_np)
        # Use point prompt
        if request.point_x is not None and request.point_y is not None:
            point_coords = np.array([[request.point_x, request.point_y]])
            point_labels = np.array([1])  # 1 = foreground point
            masks, scores, logits = predictor.predict(
                point_coords=point_coords,
                point_labels=point_labels,
                multimask_output=True
            )
            # Get the best mask (highest score)
            best_mask_idx = np.argmax(scores)
            best_mask = masks[best_mask_idx]
            return {
                "success": True,
                "mask_base64": encode_mask_to_base64(best_mask),
                "score": float(scores[best_mask_idx])
            }
        else:
            raise HTTPException(status_code=400, detail="Point coordinates required")
    except HTTPException:
        # Re-raise client errors (e.g. missing point coordinates) unchanged
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Segmentation failed: {str(e)}")

@app.post("/apply-color")  # path assumed
async def apply_color(request: ColorChangeRequest):
    """Apply color to masked region of the image"""
    try:
        # Decode image and mask
        image_np = decode_base64_image(request.image_base64)
        mask_np = decode_base64_image(request.mask_base64)
        # Convert mask to binary
        if len(mask_np.shape) == 3:
            mask_np = cv2.cvtColor(mask_np, cv2.COLOR_RGB2GRAY)
        mask_binary = (mask_np > 128).astype(np.uint8)
        # Convert hex color to RGB
        rgb_color = hex_to_rgb(request.color_hex)
        # Create colored overlay
        colored_mask = np.zeros_like(image_np)
        colored_mask[mask_binary == 1] = rgb_color
        # Blend with original image
        result = image_np.copy().astype(float)
        alpha = request.opacity
        result[mask_binary == 1] = (
            alpha * colored_mask[mask_binary == 1] +
            (1 - alpha) * image_np[mask_binary == 1]
        )
        result = result.astype(np.uint8)
        return {
            "success": True,
            "result_base64": encode_image_to_base64(result)
        }
    except HTTPException:
        # Propagate the 400 raised by decode_base64_image for malformed input
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Color application failed: {str(e)}")

@app.post("/segment/simple")  # path assumed
async def simple_segment(file: UploadFile = File(...)):
    """Simple segmentation using traditional CV methods (fallback when SAM is not available)"""
    try:
        # Read and decode image
        contents = await file.read()
        image = Image.open(io.BytesIO(contents))
        image_np = np.array(image.convert("RGB"))
        # Convert to different color spaces for better wall detection
        hsv = cv2.cvtColor(image_np, cv2.COLOR_RGB2HSV)
        gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
        # Apply edge detection
        edges = cv2.Canny(gray, 50, 150)
        # Dilate edges to create connected regions
        kernel = np.ones((5, 5), np.uint8)
        dilated = cv2.dilate(edges, kernel, iterations=2)
        # Find contours
        contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        # Create masks for largest contours
        result_masks = []
        h, w = image_np.shape[:2]
        # Sort by area
        contours = sorted(contours, key=cv2.contourArea, reverse=True)
        for i, contour in enumerate(contours[:5]):  # Top 5 regions
            area = cv2.contourArea(contour)
            if area < (h * w * 0.01):  # Skip very small regions
                continue
            mask = np.zeros((h, w), dtype=np.uint8)
            cv2.drawContours(mask, [contour], -1, 255, -1)
            # Get bounding box
            x, y, bw, bh = cv2.boundingRect(contour)
            result_masks.append({
                "id": i,
                "mask_base64": encode_mask_to_base64(mask / 255),
                "area": int(area),
                "bbox": [int(x), int(y), int(bw), int(bh)]
            })
        return {
            "success": True,
            "num_masks": len(result_masks),
            "masks": result_masks,
            "image_base64": encode_image_to_base64(image_np),
            "method": "traditional_cv"
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Segmentation failed: {str(e)}")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
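
# Example client call (illustrative only; the endpoint paths follow the assumed routes above):
#
#   import base64, requests
#
#   with open("room.jpg", "rb") as f:
#       img_b64 = base64.b64encode(f.read()).decode()
#
#   seg = requests.post("http://localhost:8000/segment/point",
#                       json={"image_base64": img_b64, "point_x": 320, "point_y": 240}).json()
#   painted = requests.post("http://localhost:8000/apply-color",
#                           json={"image_base64": img_b64,
#                                 "mask_base64": seg["mask_base64"],
#                                 "color_hex": "#7A9E7E",
#                                 "opacity": 0.8}).json()
#   # painted["result_base64"] holds the recolored image as a base64-encoded PNG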