# pylogyn_detect / app.py
# minhvtt's picture
# Update app.py
# f0aa586 verified
from fastapi import FastAPI, File, UploadFile, HTTPException, Query, BackgroundTasks
import numpy as np
import cv2
import uvicorn
from PIL import Image
import io
from typing import List, Dict, Any, Optional, Tuple
from pydantic import BaseModel
import logging
from pathlib import Path
import time
import hashlib
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
from dataclasses import dataclass, field
import warnings
import torch
from torchvision import transforms
import onnxruntime as ort
from sklearn.cluster import KMeans
import os
# Cap OpenMP at a single thread.
# NOTE(review): this is assigned after numpy/torch/onnxruntime/sklearn are
# imported above - confirm those libraries still honor the value at this point.
os.environ["OMP_NUM_THREADS"] = "1"
# Silence every warning category for the whole process.
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# FastAPI application object; the route and startup hook below register on it.
app = FastAPI(
    title="Seat Extraction API v9.0 (No OCR)",
    description="BG removal → Section detection )",
    version="9.0.0"
)
# NOTE(review): wildcard origins together with allow_credentials=True is a very
# permissive CORS policy - confirm this is intended for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# On-disk cache directory, created at import time. It is not referenced again
# in the visible part of this file.
CACHE_DIR = Path("cache")
CACHE_DIR.mkdir(exist_ok=True)
# In-memory cache of extraction results, filled and evicted by
# EnhancedSeatExtractor.extract_polygons_enhanced.
RESULTS_CACHE = {}
# Maximum number of entries kept in RESULTS_CACHE before the oldest is dropped.
MAX_CACHE_SIZE = 100
# Shared extractor instance; assigned by startup_event, None until then.
extractor = None
# Response schema for the /extract-seats/ endpoint.
# The per-section lists (polygons, confidence_scores, areas, bounding_boxes,
# labels, colors) are built in one pass over the detected sections, so entry i
# of each list describes the same section.
class PolygonResponse(BaseModel):
    polygons: List[List[List[float]]]  # one [[x, y], ...] vertex list per section
    confidence_scores: List[float]  # per-section confidence (solidity capped at 1.0)
    areas: List[float]  # contour area of each section
    bounding_boxes: List[List[float]]  # [x_min, y_min, x_max, y_max] per section
    labels: List[str]  # "Section_<n>", numbered from 1
    colors: List[str]  # mean section color as "#rrggbb"
    seat_groups: Dict[str, List[int]]  # "ColorGroup_<id>" -> indices into the lists above
    processing_info: Dict[str, Any]  # timing and pipeline metadata
    cache_hit: bool = False  # set to True when the result was served from the in-memory cache
    detected_text: List[Dict[str, Any]] = []  # the visible pipeline always fills this with []
    geojson: Optional[Dict[str, Any]] = None  # FeatureCollection mirroring the sections
@dataclass
class OptimizationConfig:
    """Tunable settings for the seat/section detection pipeline (OCR removed).

    One instance is shared between the extractor and its color detector; the
    HTTP endpoint overwrites some fields per request.
    """
    # Run BiRefNet background removal before color analysis.
    use_background_removal: bool = True
    # Color detection
    # Drop pixels whose HSV value is below the detector's black threshold.
    exclude_pure_black: bool = True
    # Drop bright, low-saturation pixels (the detector's white test).
    exclude_pure_white: bool = True
    # Split the valid pixels into color groups with K-means before detection.
    use_color_clustering: bool = True
    # Upper bound on the number of K-means clusters.
    n_color_clusters: int = 20
    # Detection thresholds
    # Contours with an area outside [min_section_area, max_section_area] are skipped.
    min_section_area: int = 500
    max_section_area: int = 50000
    # Minimum contour-area / convex-hull-area ratio for a contour to be kept.
    min_solidity: float = 0.3
    # Morphology
    # Side length of the elliptical structuring element used for mask closing.
    morphology_kernel_size: int = 3
class BackgroundRemover:
    """Background removal using a BiRefNet ONNX model.

    The ONNX session is created by :meth:`load_model`. While no session is
    available (not loaded yet, or loading failed), :meth:`remove_background`
    acts as a pass-through and returns ``None`` for the mask.
    """

    def __init__(self, model_path: str = "models/BiRefNet.onnx"):
        """
        Args:
            model_path: Location of the BiRefNet ONNX file. Defaults to the
                path that used to be hard-coded in ``load_model``.
        """
        self.model_path = model_path
        self.session = None
        self.input_name = None
        self.output_name = None
        # Model input: fixed 1024x1024, normalized with values that match the
        # standard ImageNet channel mean/std.
        self.transform = transforms.Compose([
            transforms.Resize((1024, 1024)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def load_model(self):
        """Create the ONNX inference session if it does not exist yet.

        Prefers CUDA when onnxruntime reports a GPU build with the CUDA
        provider available, and always keeps the CPU provider as fallback.
        Any failure is logged and leaves ``self.session`` as ``None`` so the
        service keeps running without background removal.
        """
        if self.session is not None:
            return
        try:
            providers = []
            if ort.get_device() == 'GPU' and 'CUDAExecutionProvider' in ort.get_available_providers():
                providers.append('CUDAExecutionProvider')
            providers.append('CPUExecutionProvider')
            self.session = ort.InferenceSession(self.model_path, providers=providers)
            self.input_name = self.session.get_inputs()[0].name
            self.output_name = self.session.get_outputs()[0].name
            logger.info(f"BiRefNet loaded: {self.session.get_providers()}")
        except Exception as e:
            # Deliberate best-effort: a missing/broken model must not stop startup.
            logger.error(f"BiRefNet load failed: {e}")
            self.session = None

    def remove_background(self, image: Image.Image) -> Tuple[Image.Image, Optional[np.ndarray]]:
        """Black out the background of ``image``.

        Args:
            image: Input picture in any PIL mode; it is converted to RGB.

        Returns:
            ``(processed_image, mask)`` where ``mask`` is a uint8 array with
            the image's height/width, or ``(rgb_image, None)`` when no model
            is loaded or inference fails.
        """
        if image.mode != 'RGB':
            image = image.convert('RGB')
        if self.session is None:
            return image, None

        image_size = image.size
        input_numpy = self.transform(image).unsqueeze(0).numpy()
        try:
            outputs = self.session.run([self.output_name], {self.input_name: input_numpy})
            pred_numpy = outputs[0][0]
            # Sigmoid turns the raw model output into [0, 1] foreground scores.
            pred_numpy = 1 / (1 + np.exp(-pred_numpy))
            if pred_numpy.ndim == 3:
                pred_numpy = pred_numpy[0]
            pred_numpy = (pred_numpy * 255).astype(np.uint8)
            # Scale the prediction back to the original image size.
            mask = Image.fromarray(pred_numpy, mode='L').resize(image_size)
        except Exception as e:
            logger.error(f"ONNX inference failed: {e}")
            return image, None

        # ``mask`` is an 'L' image and ``image`` is RGB at this point, so the
        # arrays are always (H, W) and (H, W, 3); no further shape fix-ups needed.
        mask_np = np.array(mask)
        image_array = np.array(image)
        mask_normalized = mask_np.astype(np.float32) / 255.0
        # Broadcast the single-channel weight over the three color channels.
        masked_array = (image_array * mask_normalized[:, :, np.newaxis]).astype(np.uint8)
        return Image.fromarray(masked_array), mask_np
class SmartColorDetector:
    """Detect all colors except pure black/white"""

    def __init__(self, config: OptimizationConfig):
        self.config = config

    def create_valid_color_mask(self, image: np.ndarray) -> np.ndarray:
        """Create a mask of all colored pixels (not pure black/white).

        Args:
            image: RGB image array.

        Returns:
            uint8 mask with the image's height/width: 255 for pixels kept,
            0 for pixels excluded by the enabled black/white tests.
        """
        hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
        _, s, v = cv2.split(hsv)
        valid_mask = np.full(image.shape[:2], 255, dtype=np.uint8)
        if self.config.exclude_pure_black:
            black_mask = v < 20
            valid_mask[black_mask] = 0
            logger.info(f"Excluded {np.sum(black_mask)} pure black pixels")
        if self.config.exclude_pure_white:
            # White = very bright and almost unsaturated.
            white_mask = (v > 235) & (s < 25)
            valid_mask[white_mask] = 0
            logger.info(f"Excluded {np.sum(white_mask)} pure white pixels")
        logger.info(f"Valid colored pixels: {np.sum(valid_mask > 0)}")
        return valid_mask

    def cluster_colors(self, image: np.ndarray, valid_mask: np.ndarray) -> List[np.ndarray]:
        """Group similar colors using K-means clustering.

        Args:
            image: RGB image array.
            valid_mask: Mask from :meth:`create_valid_color_mask`; only
                pixels where it is non-zero are clustered.

        Returns:
            One cleaned uint8 mask (0/255) per retained color cluster. Falls
            back to ``[valid_mask]`` when there are too few pixels, fewer
            than two clusters would result, or clustering raises.
        """
        masks = []
        valid_pixels = image[valid_mask > 0]
        if len(valid_pixels) < 100:
            logger.warning("Not enough valid pixels for clustering")
            return [valid_mask]
        pixels_flat = valid_pixels.reshape(-1, 3).astype(np.float32)
        # Never ask for more clusters than ~1 per 100 pixels.
        n_clusters = min(self.config.n_color_clusters, len(pixels_flat) // 100)
        if n_clusters < 2:
            return [valid_mask]
        logger.info(f"Clustering into {n_clusters} color groups...")
        try:
            kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
            labels = kmeans.fit_predict(pixels_flat)
            centers = kmeans.cluster_centers_.astype(np.uint8)
            # Same row-major order as the boolean indexing above, so
            # pixel_coords[k] is the position of the pixel labelled labels[k].
            pixel_coords = np.argwhere(valid_mask > 0)
            # Loop-invariant structuring element.
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
            for cluster_id in range(n_clusters):
                cluster_pixels = pixel_coords[labels == cluster_id]
                if len(cluster_pixels) < 50:
                    continue
                cluster_mask = np.zeros(image.shape[:2], dtype=np.uint8)
                # Vectorized fill instead of one Python assignment per pixel.
                cluster_mask[cluster_pixels[:, 0], cluster_pixels[:, 1]] = 255
                cluster_mask = cv2.morphologyEx(cluster_mask, cv2.MORPH_CLOSE, kernel, iterations=2)
                cluster_mask = cv2.morphologyEx(cluster_mask, cv2.MORPH_OPEN, kernel, iterations=1)
                # Count pixels, not the sum of 255-valued entries: the sum
                # made this threshold pass for any non-empty mask and made
                # the log report 255x the real pixel count.
                pixel_count = int(np.count_nonzero(cluster_mask))
                if pixel_count > 100:
                    masks.append(cluster_mask)
                    logger.info(f" Cluster {cluster_id}: {pixel_count} pixels, "
                                f"center color: {centers[cluster_id]}")
        except Exception as e:
            # Best-effort: fall back to treating all valid pixels as one group.
            logger.error(f"Clustering failed: {e}")
            return [valid_mask]
        return masks
class EnhancedSeatExtractor:
    """Detect colored sections in an image and describe them as polygons.

    Pipeline (see :meth:`extract_polygons_enhanced`): optional background
    removal, black/white exclusion, optional K-means color clustering,
    contour-based section detection, then bounding-box overlap suppression.
    """

    def __init__(self, config: "Optional[OptimizationConfig]" = None):
        """
        Args:
            config: Detection settings. When omitted, a fresh
                ``OptimizationConfig`` is created for this instance (a
                default-argument instance would be built once and shared,
                and the config is mutated at request time).
        """
        self.config = config if config is not None else OptimizationConfig()
        # NOTE(review): this pool is not used or shut down anywhere in the
        # visible code; kept so the attribute stays available.
        self.executor = ThreadPoolExecutor(max_workers=4)
        self.bg_remover = BackgroundRemover()
        self.color_detector = SmartColorDetector(self.config)
        logger.info("Enhanced Extractor initialized")

    def compute_image_hash(self, image: np.ndarray) -> str:
        """Return an MD5 hex digest identifying ``image``.

        Shape and dtype are hashed together with the raw bytes so that two
        arrays with identical bytes but different layouts do not collide.
        """
        digest = hashlib.md5()
        digest.update(f"{image.shape}|{image.dtype}|".encode("utf-8"))
        digest.update(image.tobytes())
        return digest.hexdigest()

    def _cache_key(self, image: np.ndarray):
        """Build the results-cache key from the image and the active config."""
        # The config is part of the key because its fields change the result
        # for the same image (and are overwritten per request).
        return (self.compute_image_hash(image), repr(self.config))

    def extract_dominant_color(self, image: np.ndarray, contour: np.ndarray) -> str:
        """Return the mean color inside ``contour`` as a ``#rrggbb`` string.

        Returns ``"#808080"`` (gray) when the contour covers no pixels.
        """
        # Rasterize the filled contour to select the pixels it covers.
        mask = np.zeros(image.shape[:2], dtype=np.uint8)
        cv2.drawContours(mask, [contour], 0, 255, -1)
        pixels = image[mask > 0]
        if len(pixels) == 0:
            return "#808080"  # default gray
        mean_color = np.mean(pixels, axis=0).astype(int)
        return "#{:02x}{:02x}{:02x}".format(
            int(mean_color[0]),
            int(mean_color[1]),
            int(mean_color[2])
        )

    def detect_sections_in_mask(self, mask: np.ndarray, image: np.ndarray) -> List[Dict]:
        """Find section polygons in a 0/255 color mask.

        Args:
            mask: uint8 mask of the pixels belonging to one color group.
            image: RGB image used to sample each section's color.

        Returns:
            One dict per accepted contour with the keys ``contour``,
            ``bbox`` (``[x1, y1, x2, y2]``), ``area``, ``confidence``,
            ``center``, ``solidity`` and ``color``.
        """
        sections = []
        # Compare a pixel count with the area threshold; summing the
        # 255-valued mask made this early exit almost never trigger.
        if np.count_nonzero(mask) < self.config.min_section_area:
            return sections
        kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE,
            (self.config.morphology_kernel_size, self.config.morphology_kernel_size)
        )
        cleaned_mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=2)
        contours, _ = cv2.findContours(cleaned_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        for contour in contours:
            area = cv2.contourArea(contour)
            if area < self.config.min_section_area or area > self.config.max_section_area:
                continue
            # Solidity = contour area / convex hull area.
            hull_area = cv2.contourArea(cv2.convexHull(contour))
            solidity = area / hull_area if hull_area > 0 else 0
            if solidity < self.config.min_solidity:
                continue
            # Simplify the outline with a tolerance of 1% of its perimeter.
            epsilon = 0.01 * cv2.arcLength(contour, True)
            approx = cv2.approxPolyDP(contour, epsilon, True)
            if len(approx) < 3:
                continue
            x, y, w, h = cv2.boundingRect(contour)
            sections.append({
                'contour': approx,
                'bbox': [x, y, x + w, y + h],
                'area': area,
                'confidence': min(1.0, solidity),
                'center': (x + w // 2, y + h // 2),
                'solidity': solidity,
                'color': self.extract_dominant_color(image, approx)
            })
        return sections

    def extract_polygons_enhanced(self, image: np.ndarray) -> "PolygonResponse":
        """Run the full detection pipeline on ``image``.

        PIPELINE:
        1. Background removal for section detection
        2. Color detection & clustering
        3. Section detection + Color extraction

        Results are cached in ``RESULTS_CACHE`` per (image, config) pair; a
        cached response is returned with ``cache_hit`` set to ``True``.

        Args:
            image: Grayscale, RGB or RGBA image array.

        Returns:
            The populated ``PolygonResponse``.
        """
        start_time = time.time()
        cache_key = self._cache_key(image)
        cached_result = RESULTS_CACHE.get(cache_key)
        if cached_result is not None:
            logger.info("Returning cached results")
            cached_result.cache_hit = True
            return cached_result

        # Ensure RGB
        if image.ndim == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.ndim == 3 and image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

        # Step 1: Background Removal
        processed_image = image
        if self.config.use_background_removal:
            logger.info("Removing background for section detection...")
            pil_image = Image.fromarray(image).convert('RGB')
            bg_removed, _ = self.bg_remover.remove_background(pil_image)
            processed_image = np.array(bg_removed)
            if processed_image.ndim == 2:
                processed_image = cv2.cvtColor(processed_image, cv2.COLOR_GRAY2RGB)

        # Step 2: Smart Color Detection
        logger.info("Detecting all colors (excluding black/white)...")
        valid_color_mask = self.color_detector.create_valid_color_mask(processed_image)

        # Step 3: Cluster Colors & Detect Sections
        all_sections = []
        if self.config.use_color_clustering:
            logger.info("Clustering colors...")
            color_masks = self.color_detector.cluster_colors(processed_image, valid_color_mask)
            logger.info(f"Found {len(color_masks)} color groups")
            for i, mask in enumerate(color_masks):
                logger.info(f"Processing color group {i + 1}/{len(color_masks)}...")
                sections = self.detect_sections_in_mask(mask, processed_image)
                for section in sections:
                    section['color_group'] = i
                all_sections.extend(sections)
                logger.info(f" Found {len(sections)} sections in group {i}")
        else:
            all_sections = self.detect_sections_in_mask(valid_color_mask, processed_image)

        # Step 4: Remove overlapping sections
        filtered_sections = self.remove_overlapping_sections(all_sections)

        # Convert to response format (all lists share the section order).
        polygons = [s['contour'].reshape(-1, 2).tolist() for s in filtered_sections]
        confidence_scores = [s['confidence'] for s in filtered_sections]
        areas = [s['area'] for s in filtered_sections]
        bounding_boxes = [s['bbox'] for s in filtered_sections]
        labels = [f"Section_{i + 1}" for i in range(len(filtered_sections))]
        colors = [s['color'] for s in filtered_sections]
        seat_groups = self.group_sections(filtered_sections)
        processing_time = time.time() - start_time
        geojson_output = self.to_geojson(filtered_sections)

        response = PolygonResponse(
            polygons=polygons,
            confidence_scores=confidence_scores,
            areas=areas,
            bounding_boxes=bounding_boxes,
            labels=labels,
            colors=colors,
            seat_groups=seat_groups,
            detected_text=[],
            processing_info={
                "total_sections": len(polygons),
                "total_text_regions": 0,
                "sections_with_text": 0,
                "vietnamese_text": 0,
                "english_text": 0,
                "processing_time": processing_time,
                "ocr_engine": "None (Disabled for performance)",
                "pipeline": "BG Removal → Color Detection → Section Detection + Color Extraction",
                "techniques": [
                    "BiRefNet BG removal for section detection",
                    "Smart color detection (exclude black/white)",
                    "K-means color clustering",
                    "Morphological cleaning",
                    "Dominant color extraction (HEX format)"
                ]
            },
            cache_hit=False,
            geojson=geojson_output
        )

        # Evict the oldest inserted entry once the cache is full.
        if len(RESULTS_CACHE) >= MAX_CACHE_SIZE:
            RESULTS_CACHE.pop(next(iter(RESULTS_CACHE)))
        RESULTS_CACHE[cache_key] = response
        return response

    def remove_overlapping_sections(self, sections: List[Dict]) -> List[Dict]:
        """Greedily keep the most confident sections whose boxes do not clash.

        Sections are visited in descending ``confidence`` order; a section is
        dropped when its bbox overlaps an already kept one by more than 0.5
        (as measured by :meth:`calculate_overlap`).
        """
        if not sections:
            return sections
        kept = []
        for candidate in sorted(sections, key=lambda s: s['confidence'], reverse=True):
            clashes = any(
                self.calculate_overlap(candidate['bbox'], other['bbox']) > 0.5
                for other in kept
            )
            if not clashes:
                kept.append(candidate)
        return kept

    def calculate_overlap(self, bbox1: List, bbox2: List) -> float:
        """Return intersection-over-union of two ``[x1, y1, x2, y2]`` boxes."""
        x1_1, y1_1, x2_1, y2_1 = bbox1
        x1_2, y1_2, x2_2, y2_2 = bbox2
        x1_int = max(x1_1, x1_2)
        y1_int = max(y1_1, y1_2)
        x2_int = min(x2_1, x2_2)
        y2_int = min(y2_1, y2_2)
        # Disjoint or merely touching boxes have no overlap.
        if x2_int <= x1_int or y2_int <= y1_int:
            return 0.0
        intersection = (x2_int - x1_int) * (y2_int - y1_int)
        area1 = (x2_1 - x1_1) * (y2_1 - y1_1)
        area2 = (x2_2 - x1_2) * (y2_2 - y1_2)
        union = area1 + area2 - intersection
        return intersection / union if union > 0 else 0.0

    def group_sections(self, sections: List[Dict]) -> Dict[str, List[int]]:
        """Map ``"ColorGroup_<id>"`` to the indices of the sections in it.

        Sections without a ``color_group`` key are placed in group 0.
        """
        groups = defaultdict(list)
        for idx, section in enumerate(sections):
            group_id = section.get('color_group', 0)
            groups[f"ColorGroup_{group_id}"].append(idx)
        return dict(groups)

    def to_geojson(self, sections: List[Dict]) -> Dict[str, Any]:
        """Serialize sections as a GeoJSON ``FeatureCollection``.

        Each section becomes a ``Polygon`` feature whose single ring is
        explicitly closed (first position repeated at the end), as GeoJSON
        linear rings require.
        """
        features = []
        for section in sections:
            ring = [
                [float(x), float(y)]
                for x, y in section['contour'].reshape(-1, 2).tolist()
            ]
            if ring and ring[0] != ring[-1]:
                ring.append(list(ring[0]))
            features.append({
                "type": "Feature",
                "properties": {
                    "confidence": section.get("confidence"),
                    "area": section.get("area"),
                    "color_group": section.get("color_group"),
                    "color": section.get("color")
                },
                "geometry": {
                    "type": "Polygon",
                    "coordinates": [ring]
                }
            })
        return {
            "type": "FeatureCollection",
            "features": features
        }
@app.on_event("startup")
async def startup_event():
    """Build the shared extractor and load the BiRefNet model on startup.

    Any exception is logged and printed instead of being re-raised, so the
    application still starts; ``extractor`` keeps whatever value it had when
    the failure happened.
    """
    global extractor
    try:
        settings = dict(
            use_background_removal=True,
            exclude_pure_black=True,
            exclude_pure_white=True,
            use_color_clustering=True,
            n_color_clusters=20,
            min_section_area=500,
            max_section_area=50000,
        )
        extractor = EnhancedSeatExtractor(OptimizationConfig(**settings))
        logger.info("Loading BiRefNet...")
        extractor.bg_remover.load_model()
        logger.info("System initialized successfully")
    except Exception as exc:
        logger.error(f"Initialization failed: {exc}")
        import traceback
        traceback.print_exc()
@app.post("/extract-seats/", response_model=PolygonResponse)
async def extract_seats_endpoint(
    file: UploadFile = File(...),
    use_background_removal: bool = Query(True),
    use_clustering: bool = Query(True),
    n_clusters: int = Query(20, ge=2, le=50)
):
    """
    Extract sections with color information
    PIPELINE:
    1. Background removal for section detection
    2. Color detection & clustering
    3. Section detection + Color extraction
    Response includes:
    - colors: List of HEX color strings for each section
    """
    # Status codes: 503 when the extractor was not initialized, 400 for a
    # non-image or undecodable upload, 500 for any other processing failure.
    if extractor is None:
        raise HTTPException(status_code=503, detail="System not initialized")
    # content_type is None when the upload carries no Content-Type header;
    # treat that as "not an image" instead of crashing on None.startswith.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="Must be an image")
    try:
        contents = await file.read()
        try:
            # Normalize every PIL mode (palette, CMYK, LA, ...) to RGB before
            # handing pixels to the pipeline; raw np.array() of those modes
            # yields palette indices or channels with a different meaning.
            image = Image.open(io.BytesIO(contents)).convert("RGB")
        except OSError as e:
            # Covers unidentifiable and truncated image data: a client error.
            raise HTTPException(status_code=400, detail=f"Invalid image: {e}") from e
        image_array = np.array(image)
        # NOTE(review): this mutates the shared extractor's config. There is
        # no await between here and the synchronous call below, so requests
        # cannot interleave at this point - revisit if that call is ever
        # moved to a worker thread.
        extractor.config.use_background_removal = use_background_removal
        extractor.config.use_color_clustering = use_clustering
        extractor.config.n_color_clusters = n_clusters
        return extractor.extract_polygons_enhanced(image_array)
    except HTTPException:
        # Keep deliberate HTTP errors (e.g. the 400 above) as they are.
        raise
    except Exception as e:
        # logger.exception records the traceback through the logging setup.
        logger.exception(f"Processing failed: {e}")
        raise HTTPException(status_code=500, detail=f"Failed: {str(e)}") from e
if __name__ == "__main__":
    # Serve the app directly when run as a script. The port comes from the
    # PORT environment variable and falls back to 7860.
    server_options = {
        "host": "0.0.0.0",
        "port": int(os.environ.get("PORT", 7860)),
        "reload": False,
        "log_level": "info",
    }
    uvicorn.run("app:app", **server_options)