# Seat Extraction API v9.0 (No OCR)
# FastAPI service: background removal → color clustering → section detection.
| from fastapi import FastAPI, File, UploadFile, HTTPException, Query, BackgroundTasks | |
| import numpy as np | |
| import cv2 | |
| import uvicorn | |
| from PIL import Image | |
| import io | |
| from typing import List, Dict, Any, Optional, Tuple | |
| from pydantic import BaseModel | |
| import logging | |
| from pathlib import Path | |
| import time | |
| import hashlib | |
| from fastapi.responses import JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.staticfiles import StaticFiles | |
| from concurrent.futures import ThreadPoolExecutor | |
| from collections import defaultdict | |
| from dataclasses import dataclass, field | |
| import warnings | |
| import torch | |
| from torchvision import transforms | |
| import onnxruntime as ort | |
| from sklearn.cluster import KMeans | |
| import os | |
# --- Module-level setup (order-sensitive) ---
# Cap OpenMP threads BEFORE numeric libraries spin up their thread pools.
os.environ["OMP_NUM_THREADS"] = "1"
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI application instance.
app = FastAPI(
    title="Seat Extraction API v9.0 (No OCR)",
    description="BG removal → Section detection )",
    version="9.0.0"
)
# Fully open CORS (any origin/method/header) — acceptable for a demo deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# On-disk cache directory (not referenced elsewhere in this file).
CACHE_DIR = Path("cache")
CACHE_DIR.mkdir(exist_ok=True)
# In-memory FIFO cache of responses keyed by image hash, bounded below.
RESULTS_CACHE = {}
MAX_CACHE_SIZE = 100
# Global extractor instance, populated at application startup.
extractor = None
class PolygonResponse(BaseModel):
    """Response payload for section extraction.

    All per-section lists (polygons, confidence_scores, areas,
    bounding_boxes, labels, colors) are index-aligned: element i of each
    describes the same detected section.
    """
    # Each polygon is a list of [x, y] vertices in image pixel coordinates.
    polygons: List[List[List[float]]]
    # Per-section confidence in [0, 1] (derived from contour solidity).
    confidence_scores: List[float]
    # Contour area in pixels, per section.
    areas: List[float]
    # [x1, y1, x2, y2] per section.
    bounding_boxes: List[List[float]]
    # Human-readable labels ("Section_1", "Section_2", ...).
    labels: List[str]
    # Dominant color per section as "#rrggbb".
    colors: List[str]
    # Color-group name -> indices into the per-section lists above.
    seat_groups: Dict[str, List[int]]
    # Pipeline metadata (timings, technique list, counts).
    processing_info: Dict[str, Any]
    cache_hit: bool = False
    # Always empty — OCR is disabled in this build; pydantic copies the
    # default per instance, so the mutable [] default is safe here.
    detected_text: List[Dict[str, Any]] = []
    # GeoJSON FeatureCollection mirror of the polygons, if generated.
    geojson: Optional[Dict[str, Any]] = None
@dataclass
class OptimizationConfig:
    """Configuration for seat detection (OCR removed).

    Declared as a dataclass so callers can override any subset of settings
    by keyword, e.g. ``OptimizationConfig(n_color_clusters=5)``.  As a plain
    class (the previous form) that keyword construction — used elsewhere in
    this module — raised ``TypeError``.
    """
    # Pipeline toggle: run BiRefNet background removal first.
    use_background_removal: bool = True
    # Color detection: drop near-black / near-white pixels before clustering.
    exclude_pure_black: bool = True
    exclude_pure_white: bool = True
    use_color_clustering: bool = True
    n_color_clusters: int = 20
    # Detection thresholds: accepted contour area range (pixels) and the
    # minimum contour-area / convex-hull-area ratio.
    min_section_area: int = 500
    max_section_area: int = 50000
    min_solidity: float = 0.3
    # Morphological-cleaning kernel size (pixels).
    morphology_kernel_size: int = 3
class BackgroundRemover:
    """Background removal using BiRefNet ONNX.

    The model session is loaded lazily by load_model(); if loading or
    inference fails, remove_background() degrades to a pass-through.
    """
    def __init__(self):
        # Lazily-initialized ONNX runtime session; None until load_model().
        self.session = None
        self.input_name = None
        self.output_name = None
        # BiRefNet preprocessing: 1024x1024 resize + ImageNet normalization.
        self.transform = transforms.Compose([
            transforms.Resize((1024, 1024)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def load_model(self):
        """Load models/BiRefNet.onnx once, preferring CUDA when available.

        On any failure the session stays None, which makes
        remove_background() a no-op pass-through rather than an error.
        """
        if self.session is None:
            try:
                providers = []
                if ort.get_device() == 'GPU' and 'CUDAExecutionProvider' in ort.get_available_providers():
                    providers.append('CUDAExecutionProvider')
                # CPU fallback is always appended last.
                providers.append('CPUExecutionProvider')
                model_path = "models/BiRefNet.onnx"
                self.session = ort.InferenceSession(model_path, providers=providers)
                self.input_name = self.session.get_inputs()[0].name
                self.output_name = self.session.get_outputs()[0].name
                logger.info(f"BiRefNet loaded: {self.session.get_providers()}")
            except Exception as e:
                logger.error(f"BiRefNet load failed: {e}")
                self.session = None

    def remove_background(self, image: Image.Image) -> Tuple[Image.Image, np.ndarray]:
        """Return (foreground image with background zeroed to black, mask).

        The mask is a uint8 HxW array at the original image resolution, or
        None when no model is loaded / inference fails (image returned as-is,
        converted to RGB).
        """
        # No model available: pass the image through unchanged.
        if self.session is None:
            if image.mode != 'RGB':
                image = image.convert('RGB')
            return image, None
        if image.mode != 'RGB':
            image = image.convert('RGB')
        image_size = image.size
        input_tensor = self.transform(image).unsqueeze(0)
        input_numpy = input_tensor.numpy()
        try:
            outputs = self.session.run([self.output_name], {self.input_name: input_numpy})
            pred_numpy = outputs[0][0]
            # Model output is treated as logits: sigmoid -> [0, 1] alpha.
            pred_numpy = 1 / (1 + np.exp(-pred_numpy))
            # Drop a leading channel dim if present (C, H, W) -> (H, W).
            if len(pred_numpy.shape) == 3:
                pred_numpy = pred_numpy[0]
            pred_numpy = (pred_numpy * 255).astype(np.uint8)
            pred_pil = Image.fromarray(pred_numpy, mode='L')
            # Resize the 1024x1024 mask back to the original resolution.
            mask = pred_pil.resize(image_size)
        except Exception as e:
            logger.error(f"ONNX inference failed: {e}")
            return image, None
        mask_np = np.array(mask)
        if len(mask_np.shape) == 3:
            mask_np = mask_np[:, :, 0]
        image_array = np.array(image)
        # Normalize the pixel array to 3-channel RGB before masking.
        if len(image_array.shape) == 2:
            image_array = cv2.cvtColor(image_array, cv2.COLOR_GRAY2RGB)
        elif image_array.shape[2] == 4:
            image_array = cv2.cvtColor(image_array, cv2.COLOR_RGBA2RGB)
        # Multiply each channel by the soft mask: background fades to black,
        # which the downstream color detector then excludes as "pure black".
        masked_array = np.zeros_like(image_array)
        mask_normalized = mask_np.astype(np.float32) / 255.0
        for c in range(3):
            masked_array[:, :, c] = (image_array[:, :, c] * mask_normalized).astype(np.uint8)
        processed_image = Image.fromarray(masked_array)
        return processed_image, mask_np
class SmartColorDetector:
    """Detect all colors except pure black/white and group them by K-means."""

    def __init__(self, config: "OptimizationConfig"):
        self.config = config

    def create_valid_color_mask(self, image: np.ndarray) -> np.ndarray:
        """Create a 0/255 uint8 mask of colored pixels.

        Pixels are excluded when near-black (HSV value < 20) or near-white
        (value > 235 and saturation < 25), per the config toggles.
        """
        hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
        h, s, v = cv2.split(hsv)
        valid_mask = np.ones(image.shape[:2], dtype=np.uint8) * 255
        if self.config.exclude_pure_black:
            black_mask = v < 20
            valid_mask[black_mask] = 0
            logger.info(f"Excluded {np.sum(black_mask)} pure black pixels")
        if self.config.exclude_pure_white:
            white_mask = (v > 235) & (s < 25)
            valid_mask[white_mask] = 0
            logger.info(f"Excluded {np.sum(white_mask)} pure white pixels")
        logger.info(f"Valid colored pixels: {np.sum(valid_mask > 0)}")
        return valid_mask

    def cluster_colors(self, image: np.ndarray, valid_mask: np.ndarray) -> List[np.ndarray]:
        """Group similar colors using K-means clustering.

        Returns one binary (0/255) mask per surviving color cluster; falls
        back to [valid_mask] when there are too few pixels, fewer than two
        feasible clusters, or clustering raises.
        """
        masks = []
        # Boolean indexing returns pixels in row-major order — the same
        # order np.argwhere() produces below, so `labels` stays aligned
        # with `pixel_coords`.
        valid_pixels = image[valid_mask > 0]
        if len(valid_pixels) < 100:
            logger.warning("Not enough valid pixels for clustering")
            return [valid_mask]
        pixels_flat = valid_pixels.reshape(-1, 3).astype(np.float32)
        # Cap cluster count so each cluster averages >= 100 pixels.
        n_clusters = min(self.config.n_color_clusters, len(pixels_flat) // 100)
        if n_clusters < 2:
            return [valid_mask]
        logger.info(f"Clustering into {n_clusters} color groups...")
        try:
            kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
            labels = kmeans.fit_predict(pixels_flat)
            centers = kmeans.cluster_centers_.astype(np.uint8)
            pixel_coords = np.argwhere(valid_mask > 0)
            for cluster_id in range(n_clusters):
                cluster_mask = np.zeros(image.shape[:2], dtype=np.uint8)
                cluster_pixels = pixel_coords[labels == cluster_id]
                if len(cluster_pixels) < 50:
                    continue
                # Vectorized scatter: one fancy-index assignment replaces
                # the previous per-pixel Python loop (O(n) interpreter
                # iterations for what is a single C-level operation).
                cluster_mask[cluster_pixels[:, 0], cluster_pixels[:, 1]] = 255
                kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
                cluster_mask = cv2.morphologyEx(cluster_mask, cv2.MORPH_CLOSE, kernel, iterations=2)
                cluster_mask = cv2.morphologyEx(cluster_mask, cv2.MORPH_OPEN, kernel, iterations=1)
                # NOTE: mask values are 0/255, so this sum threshold passes
                # with as little as one surviving pixel — kept as-is to
                # preserve existing behavior.
                if np.sum(cluster_mask) > 100:
                    masks.append(cluster_mask)
                    logger.info(f" Cluster {cluster_id}: {np.sum(cluster_mask)} pixels, "
                                f"center color: {centers[cluster_id]}")
        except Exception as e:
            logger.error(f"Clustering failed: {e}")
            return [valid_mask]
        return masks
class EnhancedSeatExtractor:
    """End-to-end section extractor.

    Pipeline: optional BiRefNet background removal → smart color masking →
    K-means color clustering → contour-based section detection → overlap
    suppression. Results are memoized in the module-level RESULTS_CACHE,
    keyed by an MD5 hash of the raw pixel buffer.
    """

    def __init__(self, config: Optional["OptimizationConfig"] = None):
        # Default is None rather than `OptimizationConfig()`: a default
        # instance in the signature would be built once at import time and
        # shared (and mutated by the endpoint) across every extractor
        # created without an explicit config — the classic mutable-default
        # pitfall. Passing no argument still yields a fresh default config.
        self.config = config if config is not None else OptimizationConfig()
        self.executor = ThreadPoolExecutor(max_workers=4)
        self.bg_remover = BackgroundRemover()
        self.color_detector = SmartColorDetector(self.config)
        logger.info("Enhanced Extractor initialized")

    def compute_image_hash(self, image: np.ndarray) -> str:
        """MD5 of the raw pixel buffer; used as the results-cache key."""
        return hashlib.md5(image.tobytes()).hexdigest()

    def extract_dominant_color(self, image: np.ndarray, contour: np.ndarray) -> str:
        """Extract the dominant color inside a contour as "#rrggbb".

        "Dominant" here is the per-channel mean over the filled contour;
        the image is RGB, so channels map directly to the hex components.
        """
        # Rasterize the contour into a filled mask.
        mask = np.zeros(image.shape[:2], dtype=np.uint8)
        cv2.drawContours(mask, [contour], 0, 255, -1)
        # Collect the pixels the contour covers.
        pixels = image[mask > 0]
        if len(pixels) == 0:
            return "#808080"  # default gray for degenerate contours
        mean_color = np.mean(pixels, axis=0).astype(int)
        hex_color = "#{:02x}{:02x}{:02x}".format(
            int(mean_color[0]),
            int(mean_color[1]),
            int(mean_color[2])
        )
        return hex_color

    def detect_sections_in_mask(self, mask: np.ndarray, image: np.ndarray) -> List[Dict]:
        """Detect candidate sections in one binary color mask.

        Filters contours by area range and solidity, simplifies them with
        approxPolyDP, and attaches the dominant color of each survivor.
        """
        sections = []
        # Cheap early-out: mask holds 0/255, so this also rejects masks
        # whose total intensity is below the minimum-area threshold.
        if np.sum(mask) < self.config.min_section_area:
            return sections
        kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE,
            (self.config.morphology_kernel_size, self.config.morphology_kernel_size)
        )
        cleaned_mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=2)
        contours, _ = cv2.findContours(cleaned_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        for contour in contours:
            area = cv2.contourArea(contour)
            if area < self.config.min_section_area or area > self.config.max_section_area:
                continue
            # Solidity = contour area / convex-hull area; low values mean
            # ragged shapes that are unlikely to be seating sections.
            hull = cv2.convexHull(contour)
            hull_area = cv2.contourArea(hull)
            solidity = area / hull_area if hull_area > 0 else 0
            if solidity < self.config.min_solidity:
                continue
            # Simplify the polygon to ~1% of its perimeter.
            epsilon = 0.01 * cv2.arcLength(contour, True)
            approx = cv2.approxPolyDP(contour, epsilon, True)
            if len(approx) >= 3:
                x, y, w, h = cv2.boundingRect(contour)
                # Extract the section's dominant color (HEX).
                dominant_color = self.extract_dominant_color(image, approx)
                sections.append({
                    'contour': approx,
                    'bbox': [x, y, x + w, y + h],
                    'area': area,
                    'confidence': min(1.0, solidity),
                    'center': (x + w // 2, y + h // 2),
                    'solidity': solidity,
                    'color': dominant_color
                })
        return sections

    def extract_polygons_enhanced(self, image: np.ndarray) -> "PolygonResponse":
        """Run the full pipeline on an image array and build the response.

        PIPELINE:
        1. Background removal for section detection
        2. Color detection & clustering
        3. Section detection + color extraction
        """
        start_time = time.time()
        image_hash = self.compute_image_hash(image)
        if image_hash in RESULTS_CACHE:
            logger.info("Returning cached results")
            cached_result = RESULTS_CACHE[image_hash]
            # NOTE: mutates the cached object, so cache_hit stays True for
            # all later hits as well — intentional "served from cache" flag.
            cached_result.cache_hit = True
            return cached_result
        # Normalize input to 3-channel RGB.
        if len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif len(image.shape) == 3:
            if image.shape[2] == 4:
                image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
        # Step 1: background removal (background becomes black, which the
        # color detector then excludes).
        processed_image = image
        if self.config.use_background_removal:
            logger.info("Removing background for section detection...")
            pil_image = Image.fromarray(image).convert('RGB')
            bg_removed, bg_mask = self.bg_remover.remove_background(pil_image)
            processed_image = np.array(bg_removed)
            if len(processed_image.shape) != 3 or processed_image.shape[2] != 3:
                if len(processed_image.shape) == 2:
                    processed_image = cv2.cvtColor(processed_image, cv2.COLOR_GRAY2RGB)
        # Step 2: smart color detection.
        logger.info("Detecting all colors (excluding black/white)...")
        valid_color_mask = self.color_detector.create_valid_color_mask(processed_image)
        # Step 3: cluster colors and detect sections per cluster.
        all_sections = []
        if self.config.use_color_clustering:
            logger.info("Clustering colors...")
            color_masks = self.color_detector.cluster_colors(processed_image, valid_color_mask)
            logger.info(f"Found {len(color_masks)} color groups")
            for i, mask in enumerate(color_masks):
                logger.info(f"Processing color group {i + 1}/{len(color_masks)}...")
                sections = self.detect_sections_in_mask(mask, processed_image)
                for section in sections:
                    section['color_group'] = i
                all_sections.extend(sections)
                logger.info(f" Found {len(sections)} sections in group {i}")
        else:
            all_sections = self.detect_sections_in_mask(valid_color_mask, processed_image)
        # Step 4: suppress overlapping detections (keep highest confidence).
        filtered_sections = self.remove_overlapping_sections(all_sections)
        # Flatten sections into the index-aligned response lists.
        polygons = []
        confidence_scores = []
        areas = []
        bounding_boxes = []
        labels = []
        colors = []
        for i, section in enumerate(filtered_sections):
            contour = section['contour']
            polygon = contour.reshape(-1, 2).tolist()
            polygons.append(polygon)
            confidence_scores.append(section['confidence'])
            areas.append(section['area'])
            bounding_boxes.append(section['bbox'])
            labels.append(f"Section_{i + 1}")
            colors.append(section['color'])
        seat_groups = self.group_sections(filtered_sections)
        processing_time = time.time() - start_time
        geojson_output = self.to_geojson(filtered_sections)
        response = PolygonResponse(
            polygons=polygons,
            confidence_scores=confidence_scores,
            areas=areas,
            bounding_boxes=bounding_boxes,
            labels=labels,
            colors=colors,
            seat_groups=seat_groups,
            detected_text=[],
            processing_info={
                "total_sections": len(polygons),
                "total_text_regions": 0,
                "sections_with_text": 0,
                "vietnamese_text": 0,
                "english_text": 0,
                "processing_time": processing_time,
                "ocr_engine": "None (Disabled for performance)",
                "pipeline": "BG Removal → Color Detection → Section Detection + Color Extraction",
                "techniques": [
                    "BiRefNet BG removal for section detection",
                    "Smart color detection (exclude black/white)",
                    "K-means color clustering",
                    "Morphological cleaning",
                    "Dominant color extraction (HEX format)"
                ]
            },
            cache_hit=False,
            geojson=geojson_output
        )
        # FIFO eviction: drop the oldest entry once the cache is full
        # (dicts preserve insertion order).
        if len(RESULTS_CACHE) >= MAX_CACHE_SIZE:
            RESULTS_CACHE.pop(next(iter(RESULTS_CACHE)))
        RESULTS_CACHE[image_hash] = response
        return response

    def remove_overlapping_sections(self, sections: List[Dict]) -> List[Dict]:
        """Greedy non-maximum suppression on bboxes (IoU > 0.5 is a duplicate).

        Sections are visited in descending confidence order, so the most
        confident of an overlapping pair always survives.
        """
        if not sections:
            return sections
        sorted_sections = sorted(sections, key=lambda x: x['confidence'], reverse=True)
        filtered = []
        for section in sorted_sections:
            overlap = False
            for accepted in filtered:
                if self.calculate_overlap(section['bbox'], accepted['bbox']) > 0.5:
                    overlap = True
                    break
            if not overlap:
                filtered.append(section)
        return filtered

    def calculate_overlap(self, bbox1: List, bbox2: List) -> float:
        """Intersection-over-union of two [x1, y1, x2, y2] boxes."""
        x1_1, y1_1, x2_1, y2_1 = bbox1
        x1_2, y1_2, x2_2, y2_2 = bbox2
        x1_int = max(x1_1, x1_2)
        y1_int = max(y1_1, y1_2)
        x2_int = min(x2_1, x2_2)
        y2_int = min(y2_1, y2_2)
        if x2_int <= x1_int or y2_int <= y1_int:
            return 0.0
        intersection = (x2_int - x1_int) * (y2_int - y1_int)
        area1 = (x2_1 - x1_1) * (y2_1 - y1_1)
        area2 = (x2_2 - x1_2) * (y2_2 - y1_2)
        union = area1 + area2 - intersection
        return intersection / union if union > 0 else 0.0

    def group_sections(self, sections: List[Dict]) -> Dict[str, List[int]]:
        """Group section indices by their K-means color group."""
        groups = defaultdict(list)
        for idx, section in enumerate(sections):
            group_id = section.get('color_group', 0)
            groups[f"ColorGroup_{group_id}"].append(idx)
        return dict(groups)

    def to_geojson(self, sections: List[Dict]) -> Dict[str, Any]:
        """Render sections as a GeoJSON FeatureCollection (pixel coordinates)."""
        features = []
        for section in sections:
            contour = section['contour'].reshape(-1, 2).tolist()
            features.append({
                "type": "Feature",
                "properties": {
                    "confidence": section.get("confidence"),
                    "area": section.get("area"),
                    "color_group": section.get("color_group"),
                    "color": section.get("color")
                },
                "geometry": {
                    "type": "Polygon",
                    "coordinates": [[list(map(float, p)) for p in contour]]
                }
            })
        return {
            "type": "FeatureCollection",
            "features": features
        }
@app.on_event("startup")
async def startup_event():
    """Build the global extractor and warm up BiRefNet when the app starts.

    Registered via @app.on_event("startup") — previously this function was
    defined but never registered, so `extractor` stayed None and every
    request failed with 503.
    """
    global extractor
    try:
        # All previously-passed keyword values equal OptimizationConfig's
        # class defaults, so a default instance is equivalent (and does not
        # depend on keyword construction being supported).
        config = OptimizationConfig()
        extractor = EnhancedSeatExtractor(config)
        logger.info("Loading BiRefNet...")
        extractor.bg_remover.load_model()
        logger.info("System initialized successfully")
    except Exception as e:
        # Best-effort startup: log and continue; requests will 503 until
        # initialization succeeds.
        logger.error(f"Initialization failed: {e}")
        import traceback
        traceback.print_exc()
# NOTE(review): the original function had no route decorator, so the
# endpoint was unreachable. The path below is a reasonable guess — confirm
# against API clients and adjust if they expect a different route.
@app.post("/extract-seats", response_model=PolygonResponse)
async def extract_seats_endpoint(
    file: UploadFile = File(...),
    use_background_removal: bool = Query(True),
    use_clustering: bool = Query(True),
    n_clusters: int = Query(20, ge=2, le=50)
):
    """
    Extract sections with color information.

    PIPELINE:
    1. Background removal for section detection
    2. Color detection & clustering
    3. Section detection + color extraction

    Response includes:
    - colors: list of HEX color strings, one per section

    Raises 503 before startup completes, 400 for non-image uploads and
    500 for any processing failure.
    """
    if extractor is None:
        raise HTTPException(status_code=503, detail="System not initialized")
    # content_type may be None for some clients; guard before startswith.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="Must be an image")
    try:
        contents = await file.read()
        image = Image.open(io.BytesIO(contents))
        image_array = np.array(image)
        # Apply per-request overrides to the shared extractor config.
        # NOTE(review): mutating the global config is racy under concurrent
        # requests — confirm single-worker deployment or make config
        # per-request.
        extractor.config.use_background_removal = use_background_removal
        extractor.config.use_color_clustering = use_clustering
        extractor.config.n_color_clusters = n_clusters
        result = extractor.extract_polygons_enhanced(image_array)
        return result
    except Exception as e:
        logger.error(f"Processing failed: {e}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Failed: {str(e)}")
if __name__ == "__main__":
    # Local entry point: bind all interfaces; the port is taken from the
    # PORT environment variable when set (7860 otherwise, the HF default).
    serve_port = int(os.environ.get("PORT", 7860))
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=serve_port,
        reload=False,
        log_level="info",
    )