minh9972t12 committed on
Commit ac5bf8b · verified · 1 Parent(s): e7da0b4

Update app.py

Files changed (1): app.py +745 -64
app.py CHANGED
@@ -1,70 +1,751 @@
- import gradio as gr
- from huggingface_hub import InferenceClient
-
-
- def respond(
-     message,
-     history: list[dict[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
-     hf_token: gr.OAuthToken,
- ):
      """
-     For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
      """
-     client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
-
-     messages = [{"role": "system", "content": system_message}]
-
-     messages.extend(history)
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         choices = message.choices
-         token = ""
-         if len(choices) and choices[0].delta.content:
-             token = choices[0].delta.content
-
-         response += token
-         yield response
-
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- chatbot = gr.ChatInterface(
-     respond,
-     type="messages",
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
- )
 
- with gr.Blocks() as demo:
-     with gr.Sidebar():
-         gr.LoginButton()
-     chatbot.render()
 
 
  if __name__ == "__main__":
-     demo.launch()
+ from fastapi import FastAPI, File, UploadFile, HTTPException, Query, BackgroundTasks
+ import numpy as np
+ import cv2
+ from PIL import Image
+ import io
+ from typing import List, Dict, Any, Optional, Tuple
+ from pydantic import BaseModel
+ import logging
+ from pathlib import Path
+ import time
+ import hashlib
+ from concurrent.futures import ThreadPoolExecutor
+ from collections import defaultdict
+ from dataclasses import dataclass, field
+ import warnings
+ from fastapi.middleware.cors import CORSMiddleware
+ import torch
+ from torchvision import transforms
+ import onnxruntime as ort
+ from sklearn.cluster import KMeans
+ import uvicorn
+
+ warnings.filterwarnings("ignore")
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # PaddleOCR for Vietnamese
+ try:
+     from paddleocr import PaddleOCR
+     PADDLEOCR_AVAILABLE = True
+ except ImportError:
+     PADDLEOCR_AVAILABLE = False
+     logger.warning("PaddleOCR not available. Install: pip install paddleocr")
+
+ app = FastAPI(
+     title="Fixed Seat Extraction API - Smart Color Detection",
+     description="Detects ALL colors except pure black and white",
+     version="6.0.0"
+ )
+
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ CACHE_DIR = Path("cache")
+ CACHE_DIR.mkdir(exist_ok=True)
+ RESULTS_CACHE = {}
+ MAX_CACHE_SIZE = 100
+
+ extractor = None
+
+
+ class PolygonResponse(BaseModel):
+     polygons: List[List[List[float]]]
+     confidence_scores: List[float]
+     areas: List[float]
+     bounding_boxes: List[List[float]]
+     labels: List[str]
+     seat_groups: Dict[str, List[int]]
+     processing_info: Dict[str, Any]
+     cache_hit: bool = False
+     detected_text: List[Dict[str, Any]] = []
+     geojson: Optional[Dict[str, Any]] = None
+
+
+ @dataclass
+ class OptimizationConfig:
+     """Fixed configuration - detect all colors except black/white"""
+     use_background_removal: bool = True
+     use_ocr: bool = True
+
+     # Color detection - NEW LOGIC:
+     # drop pure black and pure white, KEEP everything else
+     exclude_pure_black: bool = True   # V < 20 in HSV
+     exclude_pure_white: bool = True   # V > 235 AND S < 25 in HSV
+
+     # Clustering to group similar colors
+     use_color_clustering: bool = True
+     n_color_clusters: int = 20        # number of color groups
+
+     # Detection thresholds
+     min_section_area: int = 500       # minimum area in pixels
+     max_section_area: int = 50000
+     min_solidity: float = 0.3         # shape quality
+
+     # Morphology
+     morphology_kernel_size: int = 3
+
+     # OCR
+     ocr_languages: List[str] = field(default_factory=lambda: ["vi", "en"])
+     ocr_gpu: bool = True
+
+
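For orientation, here is a minimal sketch of what a serialized PolygonResponse looks like. All values are illustrative, not real pipeline output:

# Illustrative only: construct a minimal PolygonResponse to show the schema.
sample = PolygonResponse(
    polygons=[[[120.0, 80.0], [260.0, 80.0], [260.0, 190.0], [120.0, 190.0]]],
    confidence_scores=[0.87],
    areas=[15400.0],
    bounding_boxes=[[120.0, 80.0, 260.0, 190.0]],
    labels=["Section_1"],
    seat_groups={"ColorGroup_0": [0]},
    processing_info={"total_sections": 1, "processing_time": 2.4},
    detected_text=[{"text": "Khán đài A", "confidence": 0.91,
                    "bbox": [130, 90, 200, 110], "language": "vi"}],
)
print(sample.model_dump_json(indent=2))  # pydantic v2; use .json() on v1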
+ class BackgroundRemover:
+     """Background removal using BiRefNet ONNX"""
+
+     def __init__(self):
+         self.session = None
+         self.input_name = None
+         self.output_name = None
+         self.transform = transforms.Compose([
+             transforms.Resize((1024, 1024)),
+             transforms.ToTensor(),
+             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+         ])
+
+     def load_model(self):
+         if self.session is None:
+             try:
+                 providers = []
+                 if ort.get_device() == 'GPU' and 'CUDAExecutionProvider' in ort.get_available_providers():
+                     providers.append('CUDAExecutionProvider')
+                 providers.append('CPUExecutionProvider')
+
+                 model_path = "models/BiRefNet.onnx"
+                 self.session = ort.InferenceSession(model_path, providers=providers)
+                 self.input_name = self.session.get_inputs()[0].name
+                 self.output_name = self.session.get_outputs()[0].name
+
+                 logger.info(f"✅ BiRefNet loaded: {self.session.get_providers()}")
+             except Exception as e:
+                 logger.error(f"BiRefNet load failed: {e}")
+                 self.session = None
+
+     def remove_background(self, image: Image.Image) -> Tuple[Image.Image, np.ndarray]:
+         # Fall back to the original image when the model is unavailable
+         if self.session is None:
+             if image.mode != 'RGB':
+                 image = image.convert('RGB')
+             return image, None
+
+         if image.mode != 'RGB':
+             image = image.convert('RGB')
+
+         image_size = image.size
+         input_tensor = self.transform(image).unsqueeze(0)
+         input_numpy = input_tensor.numpy()
+
+         try:
+             outputs = self.session.run([self.output_name], {self.input_name: input_numpy})
+             pred_numpy = outputs[0][0]
+             # Sigmoid: map raw logits to a [0, 1] foreground probability mask
+             pred_numpy = 1 / (1 + np.exp(-pred_numpy))
+
+             if len(pred_numpy.shape) == 3:
+                 pred_numpy = pred_numpy[0]
+
+             pred_numpy = (pred_numpy * 255).astype(np.uint8)
+             pred_pil = Image.fromarray(pred_numpy, mode='L')
+             mask = pred_pil.resize(image_size)
+         except Exception as e:
+             logger.error(f"ONNX inference failed: {e}")
+             return image, None
+
+         mask_np = np.array(mask)
+         if len(mask_np.shape) == 3:
+             mask_np = mask_np[:, :, 0]
+
+         image_array = np.array(image)
+         if len(image_array.shape) == 2:
+             image_array = cv2.cvtColor(image_array, cv2.COLOR_GRAY2RGB)
+         elif image_array.shape[2] == 4:
+             image_array = cv2.cvtColor(image_array, cv2.COLOR_RGBA2RGB)
+
+         # Multiply each channel by the normalized mask to zero out the background
+         masked_array = np.zeros_like(image_array)
+         mask_normalized = mask_np.astype(np.float32) / 255.0
+
+         for c in range(3):
+             masked_array[:, :, c] = (image_array[:, :, c] * mask_normalized).astype(np.uint8)
+
+         processed_image = Image.fromarray(masked_array)
+         return processed_image, mask_np
+
+
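The channel-by-channel multiply at the end of remove_background is alpha compositing against a black background. A minimal self-contained sketch of the same operation on synthetic data (no BiRefNet needed), using broadcasting instead of the per-channel loop:

import numpy as np

rgb = np.full((4, 4, 3), 200, dtype=np.uint8)        # a flat gray image
mask = np.zeros((4, 4), dtype=np.uint8)
mask[1:3, 1:3] = 255                                 # foreground square

alpha = mask.astype(np.float32) / 255.0              # 0.0 background, 1.0 foreground
composited = (rgb * alpha[..., None]).astype(np.uint8)

assert composited[0, 0].tolist() == [0, 0, 0]        # background zeroed out
assert composited[1, 1].tolist() == [200, 200, 200]  # foreground untouched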
+ class TextDetector:
+     """OCR with Vietnamese support using PaddleOCR"""
+
+     def __init__(self, config: OptimizationConfig):
+         self.config = config
+         self.ocr = None
+
+     def load_models(self):
+         if not PADDLEOCR_AVAILABLE:
+             logger.error("PaddleOCR not available")
+             return
+
+         try:
+             # Initialize PaddleOCR with mobile lite models for Vietnamese
+             self.ocr = PaddleOCR(
+                 lang='latin',  # Vietnamese uses the Latin script
+                 # Use PP-OCRv4 mobile models (lightweight)
+                 text_detection_model_name="PP-OCRv4_mobile_det",
+                 text_recognition_model_name="PP-OCRv4_mobile_rec",
+                 # Disable unneeded features to speed things up
+                 use_angle_cls=False,
+                 use_doc_orientation_classify=False,
+                 use_doc_unwarping=False,
+                 use_textline_orientation=False,
+                 # GPU settings
+                 use_gpu=torch.cuda.is_available() and self.config.ocr_gpu,
+                 # Small batch size for the lightweight models
+                 det_db_box_thresh=0.5,    # detection threshold
+                 det_db_unclip_ratio=1.6,  # unclip ratio for bboxes
+                 # Recognition settings
+                 rec_batch_num=1,
+                 drop_score=0.3,           # low confidence threshold to catch more text
+                 # Silence logging
+                 show_log=False
+             )
+             logger.info("✅ PaddleOCR loaded (PP-OCRv4_mobile) for Vietnamese")
+             logger.info(f"   GPU enabled: {torch.cuda.is_available() and self.config.ocr_gpu}")
+         except Exception as e:
+             logger.error(f"PaddleOCR load failed: {e}")
+             import traceback
+             traceback.print_exc()
+             self.ocr = None
+
+     def preprocess_for_vietnamese_ocr(self, image: np.ndarray) -> np.ndarray:
+         """Preprocessing tuned for Vietnamese OCR with PaddleOCR"""
+         if len(image.shape) == 3:
+             gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
+         else:
+             gray = image.copy()
+
+         # 1. Denoise
+         denoised = cv2.fastNlMeansDenoising(gray, h=7)
+
+         # 2. Sharpen so diacritics stand out
+         kernel_sharpen = np.array([[-1, -1, -1],
+                                    [-1,  9, -1],
+                                    [-1, -1, -1]])
+         sharpened = cv2.filter2D(denoised, -1, kernel_sharpen)
+
+         # 3. CLAHE
+         clahe = cv2.createCLAHE(clipLimit=2.5, tileGridSize=(8, 8))
+         enhanced = clahe.apply(sharpened)
+
+         # 4. Contrast
+         alpha = 1.3
+         beta = 10
+         adjusted = cv2.convertScaleAbs(enhanced, alpha=alpha, beta=beta)
+
+         # PaddleOCR accepts grayscale or RGB; return RGB for consistency
+         rgb = cv2.cvtColor(adjusted, cv2.COLOR_GRAY2RGB)
+
+         return rgb
+
+     def detect_language(self, text: str) -> str:
+         """Detect Vietnamese by diacritics"""
+         vietnamese_chars = 'àáạảãâầấậẩẫăằắặẳẵèéẹẻẽêềếệểễìíịỉĩòóọỏõôồốộổỗơờớợởỡùúụủũưừứựửữỳýỵỷỹđ'
+         if any(c in vietnamese_chars for c in text.lower()):
+             return 'vi'
+         return 'en'
+
+     def detect_text(self, image: np.ndarray) -> List[Dict]:
+         text_regions = []
+         if self.ocr is None:
+             logger.warning("PaddleOCR not initialized")
+             return text_regions
+
+         try:
+             # Preprocessing
+             preprocessed = self.preprocess_for_vietnamese_ocr(image)
+
+             # PaddleOCR inference
+             # result[0] is the list for the first page;
+             # each item: [bbox_points, (text, confidence)]
+             result = self.ocr.ocr(preprocessed, cls=False)
+
+             if result is None or len(result) == 0:
+                 logger.warning("PaddleOCR returned no results")
+                 return text_regions
+
+             # Parse results
+             for line in result[0]:
+                 if line is None:
+                     continue
+
+                 bbox_points, (text, confidence) = line
+
+                 # bbox_points format: [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
+                 x_coords = [point[0] for point in bbox_points]
+                 y_coords = [point[1] for point in bbox_points]
+
+                 if confidence > 0.2:  # low threshold to catch more text
+                     # Detect language
+                     language = self.detect_language(text)
+
+                     text_regions.append({
+                         'bbox': [int(min(x_coords)), int(min(y_coords)),
+                                  int(max(x_coords)), int(max(y_coords))],
+                         'text': text,
+                         'confidence': float(confidence),
+                         'language': language
+                     })
+                     logger.info(f"OCR: '{text}' (conf: {confidence:.2f}, lang: {language})")
+
+             logger.info(f"✅ Detected {len(text_regions)} text regions")
+         except Exception as e:
+             logger.error(f"PaddleOCR failed: {e}")
+             import traceback
+             traceback.print_exc()
+
+         return text_regions
+
+
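Note that detect_language relies purely on the presence of Vietnamese diacritics, so ASCII-only Vietnamese text is classified as 'en'. A quick standalone check of the same rule (the sample strings are made up):

VIETNAMESE_CHARS = set('àáạảãâầấậẩẫăằắặẳẵèéẹẻẽêềếệểễìíịỉĩ'
                       'òóọỏõôồốộổỗơờớợởỡùúụủũưừứựửữỳýỵỷỹđ')

def detect_language(text: str) -> str:
    return 'vi' if any(c in VIETNAMESE_CHARS for c in text.lower()) else 'en'

print(detect_language("Khán đài A"))   # vi - has diacritics
print(detect_language("Khan dai A"))   # en - same words, diacritics stripped
print(detect_language("Gate 3"))       # en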
+ class SmartColorDetector:
      """
+     NEW LOGIC: detect ALL colors EXCEPT pure black and pure white
      """
+
+     def __init__(self, config: OptimizationConfig):
+         self.config = config
+
+     def create_valid_color_mask(self, image: np.ndarray) -> np.ndarray:
+         """
+         Build a mask of ALL colored pixels (not pure black/white/gray).
+
+         In HSV:
+         - Pure black: V (value) very low (0-20)
+         - Pure white: V very high (235-255) AND S (saturation) very low (0-25)
+         - Pure gray: S very low (0-30), regardless of hue
+         - EVERY other color: valid!
+         """
+         hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
+         h, s, v = cv2.split(hsv)
+
+         # Start with a mask that KEEPS every pixel
+         valid_mask = np.ones(image.shape[:2], dtype=np.uint8) * 255
+
+         # Drop pure black: V < 20
+         if self.config.exclude_pure_black:
+             black_mask = v < 20
+             valid_mask[black_mask] = 0
+             logger.info(f"Excluded {np.sum(black_mask)} pure black pixels")
+
+         # Drop pure white: V > 235 AND S < 25
+         if self.config.exclude_pure_white:
+             white_mask = (v > 235) & (s < 25)
+             valid_mask[white_mask] = 0
+             logger.info(f"Excluded {np.sum(white_mask)} pure white pixels")
+
+         # Drop pure gray: S < 30 (no saturation = gray),
+         # excluding the black/white ranges already handled above
+         gray_mask = (s < 30) & (v >= 20) & (v <= 235)
+         valid_mask[gray_mask] = 0
+         logger.info(f"Excluded {np.sum(gray_mask)} gray pixels")
+
+         logger.info(f"Valid colored pixels: {np.sum(valid_mask > 0)}")
+         return valid_mask
+
+     def cluster_colors(self, image: np.ndarray, valid_mask: np.ndarray) -> List[np.ndarray]:
+         """Group similar colors with K-means clustering"""
+         masks = []
+
+         # Collect all valid pixels
+         valid_pixels = image[valid_mask > 0]
+
+         if len(valid_pixels) < 100:
+             logger.warning("Not enough valid pixels for clustering")
+             return [valid_mask]
+
+         # K-means clustering
+         pixels_flat = valid_pixels.reshape(-1, 3).astype(np.float32)
+         n_clusters = min(self.config.n_color_clusters, len(pixels_flat) // 100)
+
+         if n_clusters < 2:
+             return [valid_mask]
+
+         logger.info(f"Clustering into {n_clusters} color groups...")
+
+         try:
+             kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
+             labels = kmeans.fit_predict(pixels_flat)
+             centers = kmeans.cluster_centers_.astype(np.uint8)
+
+             # Build a mask per cluster
+             pixel_coords = np.argwhere(valid_mask > 0)
+
+             for cluster_id in range(n_clusters):
+                 cluster_mask = np.zeros(image.shape[:2], dtype=np.uint8)
+                 cluster_pixels = pixel_coords[labels == cluster_id]
+
+                 if len(cluster_pixels) < 50:
+                     continue
+
+                 for coord in cluster_pixels:
+                     cluster_mask[coord[0], coord[1]] = 255
+
+                 # Clean up mask
+                 kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
+                 cluster_mask = cv2.morphologyEx(cluster_mask, cv2.MORPH_CLOSE, kernel, iterations=2)
+                 cluster_mask = cv2.morphologyEx(cluster_mask, cv2.MORPH_OPEN, kernel, iterations=1)
+
+                 if np.sum(cluster_mask) > 100:
+                     masks.append(cluster_mask)
+                     logger.info(f"   Cluster {cluster_id}: {np.sum(cluster_mask)} pixels, "
+                                 f"center color: {centers[cluster_id]}")
+
+         except Exception as e:
+             logger.error(f"Clustering failed: {e}")
+             return [valid_mask]
+
+         return masks
+
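To see the three exclusion rules on concrete pixels, here is the same HSV logic applied to a few hand-picked colors (OpenCV uses H in 0-179, S and V in 0-255; the sample colors are arbitrary):

import numpy as np
import cv2

# RGB samples: red, black, white, mid-gray, navy blue
samples = np.array([[[220, 30, 40], [5, 5, 5], [250, 250, 250],
                     [128, 128, 128], [20, 30, 120]]], dtype=np.uint8)
hsv = cv2.cvtColor(samples, cv2.COLOR_RGB2HSV)

for rgb, (h, s, v) in zip(samples[0], hsv[0]):
    if v < 20:
        verdict = "excluded: pure black"
    elif v > 235 and s < 25:
        verdict = "excluded: pure white"
    elif s < 30:
        verdict = "excluded: gray"
    else:
        verdict = "kept"
    print(f"RGB {rgb.tolist()} -> HSV ({h}, {s}, {v}): {verdict}")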
 
+ class EnhancedSeatExtractor:
+     def __init__(self, config: OptimizationConfig = OptimizationConfig()):
+         self.config = config
+         self.executor = ThreadPoolExecutor(max_workers=4)
+         self.bg_remover = BackgroundRemover()
+         self.text_detector = TextDetector(config)
+         self.color_detector = SmartColorDetector(config)
+         logger.info("✅ Enhanced Extractor with Smart Color Detection initialized")
+
+     def compute_image_hash(self, image: np.ndarray) -> str:
+         return hashlib.md5(image.tobytes()).hexdigest()
+
+     def detect_sections_in_mask(self, mask: np.ndarray, text_regions: List[Dict]) -> List[Dict]:
+         """Detect sections from a color mask"""
+         sections = []
+
+         if np.sum(mask) < self.config.min_section_area:
+             return sections
+
+         # Do NOT remove text regions - keep sections intact.
+         # Text is PART OF a section, not noise to strip out.
+         text_excluded_mask = mask.copy()
+
+         # Morphological operations - REDUCED iterations so sections are not eroded
+         kernel = cv2.getStructuringElement(
+             cv2.MORPH_ELLIPSE,
+             (self.config.morphology_kernel_size, self.config.morphology_kernel_size)
+         )
+         # Only CLOSE to connect nearby regions; skip OPEN to avoid shrinking sections
+         cleaned_mask = cv2.morphologyEx(text_excluded_mask, cv2.MORPH_CLOSE, kernel, iterations=2)
+
+         # Find contours
+         contours, _ = cv2.findContours(cleaned_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+         for contour in contours:
+             area = cv2.contourArea(contour)
+
+             if area < self.config.min_section_area or area > self.config.max_section_area:
+                 continue
+
+             # Check solidity (shape quality): contour area / convex hull area
+             hull = cv2.convexHull(contour)
+             hull_area = cv2.contourArea(hull)
+             solidity = area / hull_area if hull_area > 0 else 0
+
+             if solidity < self.config.min_solidity:
+                 continue
+
+             # Simplify contour
+             epsilon = 0.01 * cv2.arcLength(contour, True)
+             approx = cv2.approxPolyDP(contour, epsilon, True)
+
+             if len(approx) >= 3:
+                 x, y, w, h = cv2.boundingRect(contour)
+                 sections.append({
+                     'contour': approx,
+                     'bbox': [x, y, x + w, y + h],
+                     'area': area,
+                     'confidence': min(1.0, solidity),
+                     'center': (x + w // 2, y + h // 2),
+                     'solidity': solidity
+                 })
+
+         return sections
+
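Solidity (contour area divided by convex-hull area) is what separates compact seat blocks from ragged noise. A toy demonstration against the min_solidity=0.3 default, on a synthetic mask:

import numpy as np
import cv2

mask = np.zeros((100, 100), dtype=np.uint8)
cv2.rectangle(mask, (20, 20), (80, 80), 255, -1)   # a filled square

contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
contour = contours[0]

area = cv2.contourArea(contour)
hull_area = cv2.contourArea(cv2.convexHull(contour))
solidity = area / hull_area if hull_area > 0 else 0

print(f"area={area}, solidity={solidity:.2f}")     # a square is fully convex: ~1.0
print("kept" if solidity >= 0.3 else "rejected")   # passes min_solidity=0.3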
+     def extract_polygons_enhanced(self, image: np.ndarray) -> PolygonResponse:
+         """Main extraction pipeline"""
+         start_time = time.time()
+
+         # Check cache
+         image_hash = self.compute_image_hash(image)
+         if image_hash in RESULTS_CACHE:
+             logger.info("Returning cached results")
+             cached_result = RESULTS_CACHE[image_hash]
+             cached_result.cache_hit = True
+             return cached_result
+
+         # Ensure RGB
+         if len(image.shape) == 2:
+             image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
+         elif len(image.shape) == 3:
+             if image.shape[2] == 4:
+                 image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
+
+         # Step 1: Background Removal
+         if self.config.use_background_removal:
+             logger.info("🔄 Removing background...")
+             pil_image = Image.fromarray(image).convert('RGB')
+             processed_image, bg_mask = self.bg_remover.remove_background(pil_image)
+             image = np.array(processed_image)
+
+             if len(image.shape) != 3 or image.shape[2] != 3:
+                 if len(image.shape) == 2:
+                     image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
+
+         # Step 2: OCR Text Detection
+         text_regions = []
+         if self.config.use_ocr:
+             logger.info("🔄 Detecting text...")
+             text_regions = self.text_detector.detect_text(image)
+
+         # Step 3: Smart Color Detection
+         logger.info("🔄 Detecting all colors (excluding black/white)...")
+         valid_color_mask = self.color_detector.create_valid_color_mask(image)
+
+         # Step 4: Cluster Colors
+         all_sections = []
+         if self.config.use_color_clustering:
+             logger.info("🔄 Clustering colors...")
+             color_masks = self.color_detector.cluster_colors(image, valid_color_mask)
+             logger.info(f"Found {len(color_masks)} color groups")
+
+             # Detect sections in each color group
+             for i, mask in enumerate(color_masks):
+                 logger.info(f"Processing color group {i + 1}/{len(color_masks)}...")
+                 sections = self.detect_sections_in_mask(mask, text_regions)
+
+                 for section in sections:
+                     section['color_group'] = i
+
+                 all_sections.extend(sections)
+                 logger.info(f"   Found {len(sections)} sections in group {i}")
+         else:
+             # Single pass without clustering
+             all_sections = self.detect_sections_in_mask(valid_color_mask, text_regions)
+
+         # Step 5: Remove overlapping sections
+         filtered_sections = self.remove_overlapping_sections(all_sections)
+
+         # Convert to response format
+         polygons = []
+         confidence_scores = []
+         areas = []
+         bounding_boxes = []
+         labels = []
+
+         for i, section in enumerate(filtered_sections):
+             contour = section['contour']
+             polygon = contour.reshape(-1, 2).tolist()
+
+             polygons.append(polygon)
+             confidence_scores.append(section['confidence'])
+             areas.append(section['area'])
+             bounding_boxes.append(section['bbox'])
+             labels.append(f"Section_{i + 1}")
+
+         # Group sections
+         seat_groups = self.group_sections(filtered_sections)
+
+         processing_time = time.time() - start_time
+         geojson_output = self.to_geojson(filtered_sections)
+
+         response = PolygonResponse(
+             polygons=polygons,
+             confidence_scores=confidence_scores,
+             areas=areas,
+             bounding_boxes=bounding_boxes,
+             labels=labels,
+             seat_groups=seat_groups,
+             detected_text=[{
+                 'text': t['text'],
+                 'confidence': t['confidence'],
+                 'bbox': t['bbox'],
+                 'language': t.get('language', 'unknown')
+             } for t in text_regions],
+             processing_info={
+                 "total_sections": len(polygons),
+                 "total_text_regions": len(text_regions),
+                 "vietnamese_text": sum(1 for t in text_regions if t.get('language') == 'vi'),
+                 "english_text": sum(1 for t in text_regions if t.get('language') == 'en'),
+                 "processing_time": processing_time,
+                 "clustering_enabled": self.config.use_color_clustering
+             },
+             cache_hit=False,
+             geojson=geojson_output
+         )
+
+         # Cache result (FIFO eviction once the cache is full)
+         if len(RESULTS_CACHE) >= MAX_CACHE_SIZE:
+             RESULTS_CACHE.pop(next(iter(RESULTS_CACHE)))
+         RESULTS_CACHE[image_hash] = response
+
+         return response
+
+     def remove_overlapping_sections(self, sections: List[Dict]) -> List[Dict]:
+         if not sections:
+             return sections
+
+         # Greedy non-maximum suppression: keep the most confident section,
+         # drop any later section whose bbox IoU with an accepted one exceeds 0.5
+         sorted_sections = sorted(sections, key=lambda x: x['confidence'], reverse=True)
+         filtered = []
+
+         for section in sorted_sections:
+             overlap = False
+             for accepted in filtered:
+                 if self.calculate_overlap(section['bbox'], accepted['bbox']) > 0.5:
+                     overlap = True
+                     break
+
+             if not overlap:
+                 filtered.append(section)
+
+         return filtered
+
+     def calculate_overlap(self, bbox1: List, bbox2: List) -> float:
+         # Intersection-over-union (IoU) of two [x1, y1, x2, y2] boxes
+         x1_1, y1_1, x2_1, y2_1 = bbox1
+         x1_2, y1_2, x2_2, y2_2 = bbox2
+
+         x1_int = max(x1_1, x1_2)
+         y1_int = max(y1_1, y1_2)
+         x2_int = min(x2_1, x2_2)
+         y2_int = min(y2_1, y2_2)
+
+         if x2_int <= x1_int or y2_int <= y1_int:
+             return 0.0
+
+         intersection = (x2_int - x1_int) * (y2_int - y1_int)
+         area1 = (x2_1 - x1_1) * (y2_1 - y1_1)
+         area2 = (x2_2 - x1_2) * (y2_2 - y1_2)
+         union = area1 + area2 - intersection
+
+         return intersection / union if union > 0 else 0.0
+
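calculate_overlap is a plain IoU. Worked example: boxes [0, 0, 10, 10] and [5, 5, 15, 15] intersect in a 5x5 region, so IoU = 25 / (100 + 100 - 25) ≈ 0.14, well under the 0.5 suppression threshold. A standalone version of the same arithmetic:

def iou(bbox1, bbox2):
    x1 = max(bbox1[0], bbox2[0]); y1 = max(bbox1[1], bbox2[1])
    x2 = min(bbox1[2], bbox2[2]); y2 = min(bbox1[3], bbox2[3])
    if x2 <= x1 or y2 <= y1:
        return 0.0
    inter = (x2 - x1) * (y2 - y1)
    a1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
    a2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
    return inter / (a1 + a2 - inter)

print(iou([0, 0, 10, 10], [5, 5, 15, 15]))  # 0.1428... -> both boxes kept
print(iou([0, 0, 10, 10], [1, 1, 10, 10]))  # 0.81     -> second box suppressed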
+     def group_sections(self, sections: List[Dict]) -> Dict[str, List[int]]:
+         groups = defaultdict(list)
+
+         for idx, section in enumerate(sections):
+             group_id = section.get('color_group', 0)
+             groups[f"ColorGroup_{group_id}"].append(idx)
+
+         return dict(groups)
+
+     def to_geojson(self, sections: List[Dict]) -> Dict[str, Any]:
+         # Note: RFC 7946 expects each polygon ring to be closed (first
+         # position repeated at the end); strict consumers may need that.
+         features = []
+         for section in sections:
+             contour = section['contour'].reshape(-1, 2).tolist()
+             features.append({
+                 "type": "Feature",
+                 "properties": {
+                     "confidence": section.get("confidence"),
+                     "area": section.get("area"),
+                     "color_group": section.get("color_group")
+                 },
+                 "geometry": {
+                     "type": "Polygon",
+                     "coordinates": [[list(map(float, p)) for p in contour]]
+                 }
+             })
+
+         return {
+             "type": "FeatureCollection",
+             "features": features
+         }
+
+
+ @app.on_event("startup")
+ async def startup_event():
+     global extractor
+     try:
+         config = OptimizationConfig(
+             use_background_removal=True,
+             use_ocr=True,
+             exclude_pure_black=True,
+             exclude_pure_white=True,
+             use_color_clustering=True,
+             n_color_clusters=20,
+             min_section_area=500,
+             max_section_area=50000,
+             ocr_languages=["vi", "en"],  # for info only
+             ocr_gpu=True
+         )
+         extractor = EnhancedSeatExtractor(config)
+
+         logger.info("Loading BiRefNet...")
+         extractor.bg_remover.load_model()
+
+         logger.info("Loading PaddleOCR (PP-OCRv4_mobile)...")
+         extractor.text_detector.load_models()
+
+         logger.info("✅ System initialized successfully")
+         logger.info("✅ Using PaddleOCR lite for Vietnamese")
+         logger.info("✅ Color detection: ALL colors except pure black/white/gray")
+     except Exception as e:
+         logger.error(f"Initialization failed: {e}")
+         import traceback
+         traceback.print_exc()
+
+
+ @app.post("/extract-seats/", response_model=PolygonResponse)
+ async def extract_seats_endpoint(
+     file: UploadFile = File(...),
+     use_background_removal: bool = Query(True),
+     use_ocr: bool = Query(True),
+     use_clustering: bool = Query(True),
+     n_clusters: int = Query(20, ge=2, le=50)
+ ):
+     """
+     Extract sections with smart color detection.
+
+     Detects ALL colors except:
+     - Pure black (V < 20 in HSV)
+     - Pure white (V > 235 AND S < 25 in HSV)
+     """
+     if extractor is None:
+         raise HTTPException(status_code=503, detail="System not initialized")
+
+     if not file.content_type or not file.content_type.startswith("image/"):
+         raise HTTPException(status_code=400, detail="Must be an image")
+
+     try:
+         contents = await file.read()
+         image = Image.open(io.BytesIO(contents))
+         image_array = np.array(image)
+
+         # Update config
+         extractor.config.use_background_removal = use_background_removal
+         extractor.config.use_ocr = use_ocr
+         extractor.config.use_color_clustering = use_clustering
+         extractor.config.n_color_clusters = n_clusters
+
+         result = extractor.extract_polygons_enhanced(image_array)
+         return result
+
+     except Exception as e:
+         logger.error(f"Processing failed: {e}")
+         import traceback
+         traceback.print_exc()
+         raise HTTPException(status_code=500, detail=f"Failed: {str(e)}")
 
 
  if __name__ == "__main__":
+     import os
+     # "app:app" points uvicorn at this module (the file is app.py)
+     uvicorn.run(
+         "app:app",
+         host="0.0.0.0",
+         port=int(os.environ.get("PORT", 7860)),
+         reload=False,
+         log_level="info"
+     )
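For reference, a minimal client call against the endpoint above. This is a sketch: the host/port assume the uvicorn defaults from this file, and "seatmap.png" is a placeholder file name:

import requests

with open("seatmap.png", "rb") as f:
    resp = requests.post(
        "http://localhost:7860/extract-seats/",
        files={"file": ("seatmap.png", f, "image/png")},
        params={
            "use_background_removal": True,
            "use_ocr": True,
            "use_clustering": True,
            "n_clusters": 20,
        },
    )

resp.raise_for_status()
data = resp.json()
print(f"{data['processing_info']['total_sections']} sections, "
      f"{len(data['detected_text'])} text regions")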