minh9972t12 committed on
Commit
fdc8d37
·
verified ·
1 Parent(s): ac5bf8b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +133 -719
app.py CHANGED
@@ -1,744 +1,158 @@
1
- from fastapi import FastAPI, File, UploadFile, HTTPException, Query, BackgroundTasks
2
- import numpy as np
3
- import cv2
4
- from PIL import Image
5
- import io
6
- from typing import List, Dict, Any, Optional, Tuple
7
- from pydantic import BaseModel
8
- import logging
9
- from pathlib import Path
10
- import time
11
- import hashlib
12
- from concurrent.futures import ThreadPoolExecutor
13
- from collections import defaultdict
14
- from dataclasses import dataclass, field
15
- import warnings
16
- from fastapi.middleware.cors import CORSMiddleware
17
- import torch
18
- from torchvision import transforms
19
- import onnxruntime as ort
20
- from sklearn.cluster import KMeans
21
- import uvicorn
22
- # PaddleOCR for Vietnamese
23
- try:
24
- from paddleocr import PaddleOCR
25
- PADDLEOCR_AVAILABLE = True
26
- except ImportError:
27
- PADDLEOCR_AVAILABLE = False
28
- logger.warning("PaddleOCR not available. Install: pip install paddleocr")
29
-
30
- warnings.filterwarnings("ignore")
31
-
32
- logging.basicConfig(level=logging.INFO)
33
- logger = logging.getLogger(__name__)
34
-
35
- app = FastAPI(
36
- title="Fixed Seat Extraction API - Smart Color Detection",
37
- description="Detects ALL colors except pure black and white",
38
- version="6.0.0"
39
- )
40
-
41
- app.add_middleware(
42
- CORSMiddleware,
43
- allow_origins=["*"],
44
- allow_credentials=True,
45
- allow_methods=["*"],
46
- allow_headers=["*"],
47
- )
48
-
49
- CACHE_DIR = Path("cache")
50
- CACHE_DIR.mkdir(exist_ok=True)
51
- RESULTS_CACHE = {}
52
- MAX_CACHE_SIZE = 100
53
-
54
- extractor = None
55
-
56
-
57
- class PolygonResponse(BaseModel):
58
- polygons: List[List[List[float]]]
59
- confidence_scores: List[float]
60
- areas: List[float]
61
- bounding_boxes: List[List[float]]
62
- labels: List[str]
63
- seat_groups: Dict[str, List[int]]
64
- processing_info: Dict[str, Any]
65
- cache_hit: bool = False
66
- detected_text: List[Dict[str, Any]] = []
67
- geojson: Optional[Dict[str, Any]] = None
68
-
69
-
70
- @dataclass
71
- class OptimizationConfig:
72
- """Fixed configuration - detect all colors except black/white"""
73
- use_background_removal: bool = True
74
- use_ocr: bool = True
75
-
76
- # Color detection - NEW LOGIC
77
- # Loại BỎ thuần đen và thuần trắng, GIỮ LẠI tất cả còn lại
78
- exclude_pure_black: bool = True # V < 20 in HSV
79
- exclude_pure_white: bool = True # V > 235 AND S < 25 in HSV
80
-
81
- # Clustering để group màu giống nhau
82
- use_color_clustering: bool = True
83
- n_color_clusters: int = 20 # Số lượng nhóm màu
84
-
85
- # Detection thresholds
86
- min_section_area: int = 500 # Diện tích tối thiểu
87
- max_section_area: int = 50000
88
- min_solidity: float = 0.3 # Shape quality
89
-
90
- # Morphology
91
- morphology_kernel_size: int = 3
92
-
93
- # OCR
94
- ocr_languages: List[str] = field(default_factory=lambda: ["vi", "en"])
95
- ocr_gpu: bool = True
96
-
97
-
98
- class BackgroundRemover:
99
- """Background removal using BiRefNet ONNX"""
100
-
101
- def __init__(self):
102
- self.session = None
103
- self.input_name = None
104
- self.output_name = None
105
- self.transform = transforms.Compose([
106
- transforms.Resize((1024, 1024)),
107
- transforms.ToTensor(),
108
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
109
- ])
110
-
111
- def load_model(self):
112
- if self.session is None:
113
- try:
114
- providers = []
115
- if ort.get_device() == 'GPU' and 'CUDAExecutionProvider' in ort.get_available_providers():
116
- providers.append('CUDAExecutionProvider')
117
- providers.append('CPUExecutionProvider')
118
-
119
- model_path = "models/BiRefNet.onnx"
120
- self.session = ort.InferenceSession(model_path, providers=providers)
121
- self.input_name = self.session.get_inputs()[0].name
122
- self.output_name = self.session.get_outputs()[0].name
123
-
124
- logger.info(f"✅ BiRefNet loaded: {self.session.get_providers()}")
125
- except Exception as e:
126
- logger.error(f"BiRefNet load failed: {e}")
127
- self.session = None
128
-
129
- def remove_background(self, image: Image.Image) -> Tuple[Image.Image, np.ndarray]:
130
- if self.session is None:
131
- if image.mode != 'RGB':
132
- image = image.convert('RGB')
133
- return image, None
134
-
135
- if image.mode != 'RGB':
136
- image = image.convert('RGB')
137
-
138
- image_size = image.size
139
- input_tensor = self.transform(image).unsqueeze(0)
140
- input_numpy = input_tensor.numpy()
141
-
142
- try:
143
- outputs = self.session.run([self.output_name], {self.input_name: input_numpy})
144
- pred_numpy = outputs[0][0]
145
- pred_numpy = 1 / (1 + np.exp(-pred_numpy))
146
-
147
- if len(pred_numpy.shape) == 3:
148
- pred_numpy = pred_numpy[0]
149
-
150
- pred_numpy = (pred_numpy * 255).astype(np.uint8)
151
- pred_pil = Image.fromarray(pred_numpy, mode='L')
152
- mask = pred_pil.resize(image_size)
153
- except Exception as e:
154
- logger.error(f"ONNX inference failed: {e}")
155
- return image, None
156
-
157
- mask_np = np.array(mask)
158
- if len(mask_np.shape) == 3:
159
- mask_np = mask_np[:, :, 0]
160
-
161
- image_array = np.array(image)
162
- if len(image_array.shape) == 2:
163
- image_array = cv2.cvtColor(image_array, cv2.COLOR_GRAY2RGB)
164
- elif image_array.shape[2] == 4:
165
- image_array = cv2.cvtColor(image_array, cv2.COLOR_RGBA2RGB)
166
-
167
- masked_array = np.zeros_like(image_array)
168
- mask_normalized = mask_np.astype(np.float32) / 255.0
169
-
170
- for c in range(3):
171
- masked_array[:, :, c] = (image_array[:, :, c] * mask_normalized).astype(np.uint8)
172
-
173
- processed_image = Image.fromarray(masked_array)
174
- return processed_image, mask_np
175
-
176
-
177
- class TextDetector:
178
- """OCR with Vietnamese support using PaddleOCR"""
179
 
180
- def __init__(self, config: OptimizationConfig):
181
- self.config = config
182
- self.ocr = None
183
-
184
- def load_models(self):
185
- if not PADDLEOCR_AVAILABLE:
186
- logger.error("PaddleOCR not available")
187
- return
188
-
189
- try:
190
- # Initialize PaddleOCR với mobile lite models cho Vietnamese
191
- self.ocr = PaddleOCR(
192
- lang='latin', # Vietnamese sử dụng latin script
193
- # Sử dụng PP-OCRv4 mobile models (lightweight)
194
- text_detection_model_name="PP-OCRv4_mobile_det",
195
- text_recognition_model_name="PP-OCRv4_mobile_rec",
196
- # Tắt các features không cần thiết để tăng tốc
197
- use_angle_cls=False,
198
- use_doc_orientation_classify=False,
199
- use_doc_unwarping=False,
200
- use_textline_orientation=False,
201
- # GPU settings
202
- use_gpu=torch.cuda.is_available() and self.config.ocr_gpu,
203
- # Giảm batch size cho lightweight
204
- det_db_box_thresh=0.5, # Detection threshold
205
- det_db_unclip_ratio=1.6, # Unclip ratio cho bbox
206
- # Rec settings
207
- rec_batch_num=1,
208
- drop_score=0.3, # Confidence threshold thấp để catch nhiều text
209
- # Tắt logging
210
- show_log=False
211
- )
212
- logger.info("✅ PaddleOCR loaded (PP-OCRv4_mobile) for Vietnamese")
213
- logger.info(f" GPU enabled: {torch.cuda.is_available() and self.config.ocr_gpu}")
214
- except Exception as e:
215
- logger.error(f"PaddleOCR load failed: {e}")
216
- import traceback
217
- traceback.print_exc()
218
- self.ocr = None
219
-
220
- def preprocess_for_vietnamese_ocr(self, image: np.ndarray) -> np.ndarray:
221
- """
222
- Preprocessing tối ưu cho Vietnamese OCR với PaddleOCR
223
- """
224
- if len(image.shape) == 3:
225
- gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
226
- else:
227
- gray = image.copy()
228
-
229
- # 1. Denoise
230
- denoised = cv2.fastNlMeansDenoising(gray, h=7)
231
-
232
- # 2. Sharpen để diacritics rõ hơn
233
- kernel_sharpen = np.array([[-1,-1,-1],
234
- [-1, 9,-1],
235
- [-1,-1,-1]])
236
- sharpened = cv2.filter2D(denoised, -1, kernel_sharpen)
237
-
238
- # 3. CLAHE
239
- clahe = cv2.createCLAHE(clipLimit=2.5, tileGridSize=(8, 8))
240
- enhanced = clahe.apply(sharpened)
241
-
242
- # 4. Contrast
243
- alpha = 1.3
244
- beta = 10
245
- adjusted = cv2.convertScaleAbs(enhanced, alpha=alpha, beta=beta)
246
-
247
- # PaddleOCR có thể nhận grayscale hoặc RGB
248
- # Trả về RGB để consistent
249
- rgb = cv2.cvtColor(adjusted, cv2.COLOR_GRAY2RGB)
250
-
251
- return rgb
252
-
253
- def detect_language(self, text: str) -> str:
254
- """Detect Vietnamese by diacritics"""
255
- vietnamese_chars = 'àáạảãâầấậẩẫăằắặẳẵèéẹẻẽêềếệểễìíịỉĩòóọỏõôồốộổỗơờớợởỡùúụủũưừứựửữỳýỵỷỹđ'
256
- if any(c in vietnamese_chars for c in text.lower()):
257
- return 'vi'
258
- return 'en'
259
-
260
- def detect_text(self, image: np.ndarray) -> List[Dict]:
261
- text_regions = []
262
- if self.ocr is None:
263
- logger.warning("PaddleOCR not initialized")
264
- return text_regions
265
-
266
- try:
267
- # Preprocessing
268
- preprocessed = self.preprocess_for_vietnamese_ocr(image)
269
-
270
- # PaddleOCR inference
271
- # result[0] là list của page đầu tiên
272
- # Mỗi item: [bbox_points, (text, confidence)]
273
- result = self.ocr.ocr(preprocessed, cls=False)
274
-
275
- if result is None or len(result) == 0:
276
- logger.warning("PaddleOCR returned no results")
277
- return text_regions
278
-
279
- # Parse kết quả
280
- for line in result[0]:
281
- if line is None:
282
- continue
283
-
284
- bbox_points, (text, confidence) = line
285
-
286
- # bbox_points format: [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
287
- x_coords = [point[0] for point in bbox_points]
288
- y_coords = [point[1] for point in bbox_points]
289
-
290
- if confidence > 0.2: # Threshold thấp để catch nhiều text
291
- # Detect language
292
- language = self.detect_language(text)
293
-
294
- text_regions.append({
295
- 'bbox': [int(min(x_coords)), int(min(y_coords)),
296
- int(max(x_coords)), int(max(y_coords))],
297
- 'text': text,
298
- 'confidence': float(confidence),
299
- 'language': language
300
- })
301
- logger.info(f"OCR: '{text}' (conf: {confidence:.2f}, lang: {language})")
302
-
303
- logger.info(f"✅ Detected {len(text_regions)} text regions")
304
- except Exception as e:
305
- logger.error(f"PaddleOCR failed: {e}")
306
- import traceback
307
- traceback.print_exc()
308
-
309
- return text_regions
310
-
311
-
312
- class SmartColorDetector:
313
- """
314
- LOGIC MỚI: Detect TẤT CẢ màu NGOẠI TRỪ đen thuần và trắng thuần
315
- """
316
-
317
- def __init__(self, config: OptimizationConfig):
318
- self.config = config
319
-
320
- def create_valid_color_mask(self, image: np.ndarray) -> np.ndarray:
321
- """
322
- Tạo mask cho TẤT CẢ pixel có màu (không phải đen/trắng/xám thuần)
323
-
324
- Trong HSV:
325
- - Đen thuần: V (value) rất thấp (0-20)
326
- - Trắng thuần: V rất cao (235-255) VÀ S (saturation) rất thấp (0-25)
327
- - Xám thuần: S rất thấp (0-30) - không phân biệt hue
328
- - MỌI màu khác: VALID!
329
- """
330
- hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
331
- h, s, v = cv2.split(hsv)
332
-
333
- # Tạo mask GIỮ LẠI tất cả pixel
334
- valid_mask = np.ones(image.shape[:2], dtype=np.uint8) * 255
335
-
336
- # Loại BỎ đen thuần: V < 20
337
- if self.config.exclude_pure_black:
338
- black_mask = v < 20
339
- valid_mask[black_mask] = 0
340
- logger.info(f"Excluded {np.sum(black_mask)} pure black pixels")
341
-
342
- # Loại BỎ trắng thuần: V > 235 AND S < 25
343
- if self.config.exclude_pure_white:
344
- white_mask = (v > 235) & (s < 25)
345
- valid_mask[white_mask] = 0
346
- logger.info(f"Excluded {np.sum(white_mask)} pure white pixels")
347
-
348
- # Loại BỎ xám thuần: S < 30 (màu không có saturation = màu xám)
349
- # Nhưng KHÔNG loại nếu đã là đen hoặc trắng thuần (đã loại ở trên)
350
- gray_mask = (s < 30) & (v >= 20) & (v <= 235)
351
- valid_mask[gray_mask] = 0
352
- logger.info(f"Excluded {np.sum(gray_mask)} gray pixels")
353
-
354
- logger.info(f"Valid colored pixels: {np.sum(valid_mask > 0)}")
355
- return valid_mask
356
-
357
- def cluster_colors(self, image: np.ndarray, valid_mask: np.ndarray) -> List[np.ndarray]:
358
- """
359
- Group các màu giống nhau bằng K-means clustering
360
- """
361
- masks = []
362
-
363
- # Lấy tất cả pixel hợp lệ
364
- valid_pixels = image[valid_mask > 0]
365
-
366
- if len(valid_pixels) < 100:
367
- logger.warning("Not enough valid pixels for clustering")
368
- return [valid_mask]
369
-
370
- # K-means clustering
371
- pixels_flat = valid_pixels.reshape(-1, 3).astype(np.float32)
372
- n_clusters = min(self.config.n_color_clusters, len(pixels_flat) // 100)
373
-
374
- if n_clusters < 2:
375
- return [valid_mask]
376
-
377
- logger.info(f"Clustering into {n_clusters} color groups...")
378
 
379
- try:
380
- kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
381
- labels = kmeans.fit_predict(pixels_flat)
382
- centers = kmeans.cluster_centers_.astype(np.uint8)
383
 
384
- # Tạo mask cho mỗi cluster
385
- pixel_coords = np.argwhere(valid_mask > 0)
 
386
 
387
- for cluster_id in range(n_clusters):
388
- cluster_mask = np.zeros(image.shape[:2], dtype=np.uint8)
389
- cluster_pixels = pixel_coords[labels == cluster_id]
390
-
391
- if len(cluster_pixels) < 50:
392
- continue
393
-
394
- for coord in cluster_pixels:
395
- cluster_mask[coord[0], coord[1]] = 255
396
-
397
- # Clean up mask
398
- kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
399
- cluster_mask = cv2.morphologyEx(cluster_mask, cv2.MORPH_CLOSE, kernel, iterations=2)
400
- cluster_mask = cv2.morphologyEx(cluster_mask, cv2.MORPH_OPEN, kernel, iterations=1)
401
-
402
- if np.sum(cluster_mask) > 100:
403
- masks.append(cluster_mask)
404
- logger.info(f" Cluster {cluster_id}: {np.sum(cluster_mask)} pixels, "
405
- f"center color: {centers[cluster_id]}")
406
 
407
- except Exception as e:
408
- logger.error(f"Clustering failed: {e}")
409
- return [valid_mask]
410
-
411
- return masks
412
-
413
-
414
- class EnhancedSeatExtractor:
415
- def __init__(self, config: OptimizationConfig = OptimizationConfig()):
416
- self.config = config
417
- self.executor = ThreadPoolExecutor(max_workers=4)
418
- self.bg_remover = BackgroundRemover()
419
- self.text_detector = TextDetector(config)
420
- self.color_detector = SmartColorDetector(config)
421
- logger.info("✅ Enhanced Extractor with Smart Color Detection initialized")
422
-
423
- def compute_image_hash(self, image: np.ndarray) -> str:
424
- return hashlib.md5(image.tobytes()).hexdigest()
425
-
426
- def detect_sections_in_mask(self, mask: np.ndarray, text_regions: List[Dict]) -> List[Dict]:
427
- """Detect sections from a color mask"""
428
- sections = []
429
-
430
- if np.sum(mask) < self.config.min_section_area:
431
- return sections
432
-
433
- # KHÔNG loại bỏ text regions - giữ nguyên sections hoàn chỉnh
434
- # Text là PART OF section, không phải noise cần loại bỏ
435
- text_excluded_mask = mask.copy()
436
-
437
- # Morphological operations - GIảM iterations để không "ăn mòn" sections
438
- kernel = cv2.getStructuringElement(
439
- cv2.MORPH_ELLIPSE,
440
- (self.config.morphology_kernel_size, self.config.morphology_kernel_size)
441
- )
442
- # Chỉ CLOSE để nối các vùng gần nhau, không OPEN để tránh làm nhỏ sections
443
- cleaned_mask = cv2.morphologyEx(text_excluded_mask, cv2.MORPH_CLOSE, kernel, iterations=2)
444
-
445
- # Find contours
446
- contours, _ = cv2.findContours(cleaned_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
447
-
448
- for contour in contours:
449
- area = cv2.contourArea(contour)
450
 
451
- if area < self.config.min_section_area or area > self.config.max_section_area:
452
- continue
453
 
454
- # Check solidity (shape quality)
455
- hull = cv2.convexHull(contour)
456
- hull_area = cv2.contourArea(hull)
457
- solidity = area / hull_area if hull_area > 0 else 0
458
 
459
- if solidity < self.config.min_solidity:
460
- continue
 
 
461
 
462
- # Simplify contour
463
- epsilon = 0.01 * cv2.arcLength(contour, True)
464
- approx = cv2.approxPolyDP(contour, epsilon, True)
 
465
 
466
- if len(approx) >= 3:
467
- x, y, w, h = cv2.boundingRect(contour)
468
- sections.append({
469
- 'contour': approx,
470
- 'bbox': [x, y, x + w, y + h],
471
- 'area': area,
472
- 'confidence': min(1.0, solidity),
473
- 'center': (x + w // 2, y + h // 2),
474
- 'solidity': solidity
475
- })
476
-
477
- return sections
478
-
479
- def extract_polygons_enhanced(self, image: np.ndarray) -> PolygonResponse:
480
- """Main extraction pipeline"""
481
- start_time = time.time()
482
-
483
- # Check cache
484
- image_hash = self.compute_image_hash(image)
485
- if image_hash in RESULTS_CACHE:
486
- logger.info("Returning cached results")
487
- cached_result = RESULTS_CACHE[image_hash]
488
- cached_result.cache_hit = True
489
- return cached_result
490
-
491
- # Ensure RGB
492
- if len(image.shape) == 2:
493
- image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
494
- elif len(image.shape) == 3:
495
- if image.shape[2] == 4:
496
- image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
497
-
498
- # Step 1: Background Removal
499
- if self.config.use_background_removal:
500
- logger.info("🔄 Removing background...")
501
- pil_image = Image.fromarray(image).convert('RGB')
502
- processed_image, bg_mask = self.bg_remover.remove_background(pil_image)
503
- image = np.array(processed_image)
504
 
505
- if len(image.shape) != 3 or image.shape[2] != 3:
506
- if len(image.shape) == 2:
507
- image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
508
-
509
- # Step 2: OCR Text Detection
510
- text_regions = []
511
- if self.config.use_ocr:
512
- logger.info("🔄 Detecting text...")
513
- text_regions = self.text_detector.detect_text(image)
514
-
515
- # Step 3: Smart Color Detection
516
- logger.info("🔄 Detecting all colors (excluding black/white)...")
517
- valid_color_mask = self.color_detector.create_valid_color_mask(image)
518
-
519
- # Step 4: Cluster Colors
520
- all_sections = []
521
- if self.config.use_color_clustering:
522
- logger.info("🔄 Clustering colors...")
523
- color_masks = self.color_detector.cluster_colors(image, valid_color_mask)
524
- logger.info(f"Found {len(color_masks)} color groups")
525
 
526
- # Detect sections in each color group
527
- for i, mask in enumerate(color_masks):
528
- logger.info(f"Processing color group {i + 1}/{len(color_masks)}...")
529
- sections = self.detect_sections_in_mask(mask, text_regions)
530
-
531
- for section in sections:
532
- section['color_group'] = i
533
-
534
- all_sections.extend(sections)
535
- logger.info(f" Found {len(sections)} sections in group {i}")
536
  else:
537
- # Single pass without clustering
538
- all_sections = self.detect_sections_in_mask(valid_color_mask, text_regions)
539
-
540
- # Step 5: Remove overlapping sections
541
- filtered_sections = self.remove_overlapping_sections(all_sections)
542
-
543
- # Convert to response format
544
- polygons = []
545
- confidence_scores = []
546
- areas = []
547
- bounding_boxes = []
548
- labels = []
549
-
550
- for i, section in enumerate(filtered_sections):
551
- contour = section['contour']
552
- polygon = contour.reshape(-1, 2).tolist()
553
 
554
- polygons.append(polygon)
555
- confidence_scores.append(section['confidence'])
556
- areas.append(section['area'])
557
- bounding_boxes.append(section['bbox'])
558
- labels.append(f"Section_{i + 1}")
559
-
560
- # Group sections
561
- seat_groups = self.group_sections(filtered_sections)
562
-
563
- processing_time = time.time() - start_time
564
- geojson_output = self.to_geojson(filtered_sections)
565
-
566
- response = PolygonResponse(
567
- polygons=polygons,
568
- confidence_scores=confidence_scores,
569
- areas=areas,
570
- bounding_boxes=bounding_boxes,
571
- labels=labels,
572
- seat_groups=seat_groups,
573
- detected_text=[{
574
- 'text': t['text'],
575
- 'confidence': t['confidence'],
576
- 'bbox': t['bbox'],
577
- 'language': t.get('language', 'unknown')
578
- } for t in text_regions],
579
- processing_info={
580
- "total_sections": len(polygons),
581
- "total_text_regions": len(text_regions),
582
- "vietnamese_text": sum(1 for t in text_regions if t.get('language') == 'vi'),
583
- "english_text": sum(1 for t in text_regions if t.get('language') == 'en'),
584
- "processing_time": processing_time,
585
- "clustering_enabled": self.config.use_color_clustering
586
- },
587
- cache_hit=False,
588
- geojson=geojson_output
589
- )
590
-
591
- # Cache result
592
- if len(RESULTS_CACHE) >= MAX_CACHE_SIZE:
593
- RESULTS_CACHE.pop(next(iter(RESULTS_CACHE)))
594
- RESULTS_CACHE[image_hash] = response
595
-
596
- return response
597
-
598
- def remove_overlapping_sections(self, sections: List[Dict]) -> List[Dict]:
599
- if not sections:
600
- return sections
601
-
602
- sorted_sections = sorted(sections, key=lambda x: x['confidence'], reverse=True)
603
- filtered = []
604
-
605
- for section in sorted_sections:
606
- overlap = False
607
- for accepted in filtered:
608
- if self.calculate_overlap(section['bbox'], accepted['bbox']) > 0.5:
609
- overlap = True
610
- break
611
-
612
- if not overlap:
613
- filtered.append(section)
614
-
615
- return filtered
616
-
617
- def calculate_overlap(self, bbox1: List, bbox2: List) -> float:
618
- x1_1, y1_1, x2_1, y2_1 = bbox1
619
- x1_2, y1_2, x2_2, y2_2 = bbox2
620
-
621
- x1_int = max(x1_1, x1_2)
622
- y1_int = max(y1_1, y1_2)
623
- x2_int = min(x2_1, x2_2)
624
- y2_int = min(y2_1, y2_2)
625
-
626
- if x2_int <= x1_int or y2_int <= y1_int:
627
- return 0.0
628
-
629
- intersection = (x2_int - x1_int) * (y2_int - y1_int)
630
- area1 = (x2_1 - x1_1) * (y2_1 - y1_1)
631
- area2 = (x2_2 - x1_2) * (y2_2 - y1_2)
632
- union = area1 + area2 - intersection
633
-
634
- return intersection / union if union > 0 else 0.0
635
 
636
- def group_sections(self, sections: List[Dict]) -> Dict[str, List[int]]:
637
- groups = defaultdict(list)
638
-
639
- for idx, section in enumerate(sections):
640
- group_id = section.get('color_group', 0)
641
- groups[f"ColorGroup_{group_id}"].append(idx)
642
-
643
- return dict(groups)
644
 
645
- def to_geojson(self, sections: List[Dict]) -> Dict[str, Any]:
646
- features = []
647
- for section in sections:
648
- contour = section['contour'].reshape(-1, 2).tolist()
649
- features.append({
650
- "type": "Feature",
651
- "properties": {
652
- "confidence": section.get("confidence"),
653
- "area": section.get("area"),
654
- "color_group": section.get("color_group")
655
- },
656
- "geometry": {
657
- "type": "Polygon",
658
- "coordinates": [[list(map(float, p)) for p in contour]]
659
- }
660
- })
661
-
662
- return {
663
- "type": "FeatureCollection",
664
- "features": features
665
- }
666
-
667
-
668
- @app.on_event("startup")
669
- async def startup_event():
670
- global extractor
671
  try:
672
- config = OptimizationConfig(
673
- use_background_removal=True,
674
- use_ocr=True,
675
- exclude_pure_black=True,
676
- exclude_pure_white=True,
677
- use_color_clustering=True,
678
- n_color_clusters=20,
679
- min_section_area=500,
680
- max_section_area=50000,
681
- ocr_languages=["vi", "en"], # For info only
682
- ocr_gpu=True
683
  )
684
- extractor = EnhancedSeatExtractor(config)
685
-
686
- logger.info("Loading BiRefNet...")
687
- extractor.bg_remover.load_model()
688
-
689
- logger.info("Loading PaddleOCR (PP-OCRv4_mobile)...")
690
- extractor.text_detector.load_models()
691
 
692
- logger.info("✅ System initialized successfully")
693
- logger.info("✅ Using PaddleOCR lite for Vietnamese")
694
- logger.info("✅ Color detection: ALL colors except pure black/white/gray")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
695
  except Exception as e:
696
- logger.error(f"Initialization failed: {e}")
697
- import traceback
698
- traceback.print_exc()
699
 
700
 
701
- @app.post("/extract-seats/", response_model=PolygonResponse)
702
- async def extract_seats_endpoint(
703
- file: UploadFile = File(...),
704
- use_background_removal: bool = Query(True),
705
- use_ocr: bool = Query(True),
706
- use_clustering: bool = Query(True),
707
- n_clusters: int = Query(20, ge=2, le=50)
708
- ):
709
- """
710
- Extract sections with smart color detection
711
-
712
- Detects ALL colors except:
713
- - Pure black (V < 20 in HSV)
714
- - Pure white (V > 235 AND S < 25 in HSV)
715
- """
716
- if extractor is None:
717
- raise HTTPException(status_code=503, detail="System not initialized")
718
-
719
- if not file.content_type.startswith("image/"):
720
- raise HTTPException(status_code=400, detail="Must be an image")
721
-
722
- try:
723
- contents = await file.read()
724
- image = Image.open(io.BytesIO(contents))
725
- image_array = np.array(image)
726
-
727
- # Update config
728
- extractor.config.use_background_removal = use_background_removal
729
- extractor.config.use_ocr = use_ocr
730
- extractor.config.use_color_clustering = use_clustering
731
- extractor.config.n_color_clusters = n_clusters
732
-
733
- result = extractor.extract_polygons_enhanced(image_array)
734
- return result
735
-
736
- except Exception as e:
737
- logger.error(f"Processing failed: {e}")
738
- import traceback
739
- traceback.print_exc()
740
- raise HTTPException(status_code=500, detail=f"Failed: {str(e)}")
741
-
742
 
743
  if __name__ == "__main__":
744
  import os
 
1
+ """
2
+ Test script for Event Tags Generator API
3
+ """
4
+
5
+ import requests
6
+ import json
7
+
8
+ # API endpoint
9
+ BASE_URL = "http://localhost:8001"
10
+
11
+ def test_generate_tags():
12
+ """Test single event tag generation"""
13
+
14
+ print("=" * 60)
15
+ print("Testing Event Tags Generator")
16
+ print("=" * 60)
17
+
18
+ # Test data
19
+ event_data = {
20
+ "event_name": "Vietnam Music Festival 2025",
21
+ "category": "Âm nhạc",
22
+ "short_description": "Lễ hội âm nhạc quốc tế lớn nhất Việt Nam năm 2025",
23
+ "detailed_description": """
24
+ Vietnam Music Festival 2025 là sự kiện âm nhạc đỉnh cao quy tụ các nghệ sĩ
25
+ nổi tiếng trong nước và quốc tế. Sự kiện diễn ra trong 3 ngày với hơn 50
26
+ nghệ sĩ tham gia, từ nhạc pop, rock, EDM đến acoustic. Đặc biệt có sự góp
27
+ mặt của các DJ hàng đầu thế giới. Không gian festival rộng 10,000m2 tại
28
+ trung tâm Hà Nội với hệ thống âm thanh ánh sáng hiện đại. Dự kiến thu hút
29
+ hơn 30,000 khán giả mỗi ngày.
30
+ """,
31
+ "max_tags": 12,
32
+ "language": "vi"
33
+ }
34
+
35
+ print("\n📤 REQUEST:")
36
+ print(json.dumps(event_data, indent=2, ensure_ascii=False))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
+ try:
39
+ # Call API
40
+ response = requests.post(
41
+ f"{BASE_URL}/generate-tags",
42
+ json=event_data,
43
+ headers={"Content-Type": "application/json"}
44
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
+ if response.status_code == 200:
47
+ result = response.json()
 
 
48
 
49
+ print("\n✅ SUCCESS!")
50
+ print("\n📥 RESPONSE:")
51
+ print(json.dumps(result, indent=2, ensure_ascii=False))
52
 
53
+ print("\n" + "=" * 60)
54
+ print("GENERATED METADATA:")
55
+ print("=" * 60)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
+ print(f"\n🏷️ TAGS ({len(result['generated_tags'])} tags):")
58
+ for tag in result['generated_tags']:
59
+ print(f" • {tag}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
+ print(f"\n📁 PRIMARY CATEGORY: {result['primary_category']}")
 
62
 
63
+ if result['secondary_categories']:
64
+ print(f"\n📂 SECONDARY CATEGORIES:")
65
+ for cat in result['secondary_categories']:
66
+ print(f" • {cat}")
67
 
68
+ if result['keywords']:
69
+ print(f"\n🔍 SEO KEYWORDS:")
70
+ for kw in result['keywords']:
71
+ print(f" • {kw}")
72
 
73
+ if result['hashtags']:
74
+ print(f"\n#️⃣ HASHTAGS:")
75
+ for ht in result['hashtags']:
76
+ print(f" • {ht}")
77
 
78
+ if result['target_audience']:
79
+ print(f"\n👥 TARGET AUDIENCE:")
80
+ for aud in result['target_audience']:
81
+ print(f" • {aud}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
+ print(f"\n😊 SENTIMENT: {result['sentiment']}")
84
+ print(f"💯 CONFIDENCE: {result['confidence_score']}")
85
+ print(f"⏱️ GENERATION TIME: {result['generation_time']}")
86
+ print(f"🤖 MODEL USED: {result['model_used']}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
 
 
 
 
 
 
 
 
 
 
88
  else:
89
+ print(f"\n❌ ERROR: {response.status_code}")
90
+ print(response.text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
+ except requests.exceptions.ConnectionError:
93
+ print("\n❌ ERROR: Cannot connect to API")
94
+ print("Make sure the server is running: python event_tags_generator.py")
95
+ except Exception as e:
96
+ print(f"\n❌ ERROR: {str(e)}")
97
+
98
+
99
+ def test_batch_generation():
100
+ """Test batch event tag generation"""
101
+
102
+ print("\n\n" + "=" * 60)
103
+ print("Testing Batch Tag Generation")
104
+ print("=" * 60)
105
+
106
+ events = [
107
+ {
108
+ "event_name": "Tech Summit Vietnam 2025",
109
+ "category": "Công nghệ",
110
+ "short_description": "Hội nghị công nghệ lớn nhất Đông Nam Á",
111
+ "detailed_description": "Sự kiện quy tụ các chuyên gia AI, Blockchain, Cloud Computing từ Google, Microsoft, Amazon...",
112
+ "max_tags": 10,
113
+ "language": "vi"
114
+ },
115
+ {
116
+ "event_name": "Food Festival Saigon",
117
+ "category": "Ẩm thực",
118
+ "short_description": "Lễ hội ẩm thực đường phố Sài Gòn",
119
+ "detailed_description": "Khám phá hơn 100 món ăn đường phố đặc trưng của Sài Gòn với các đầu bếp nổi tiếng...",
120
+ "max_tags": 8,
121
+ "language": "vi"
122
+ }
123
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
+ print(f"\n📤 Generating tags for {len(events)} events...")
 
 
 
 
 
 
 
126
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  try:
128
+ response = requests.post(
129
+ f"{BASE_URL}/generate-tags/batch",
130
+ json=events,
131
+ headers={"Content-Type": "application/json"}
 
 
 
 
 
 
 
132
  )
 
 
 
 
 
 
 
133
 
134
+ if response.status_code == 200:
135
+ result = response.json()
136
+ print(f"\nBatch completed!")
137
+ print(f" Total: {result['total']}")
138
+ print(f" Successful: {result['successful']}")
139
+ print(f" Failed: {result['failed']}")
140
+
141
+ for item in result['results']:
142
+ if item['success']:
143
+ print(f"\n✓ {item['event_name']}")
144
+ print(f" Tags: {', '.join(item['data']['generated_tags'][:5])}...")
145
+ else:
146
+ print(f"\n✗ {item['event_name']}")
147
+ print(f" Error: {item['error']}")
148
+ else:
149
+ print(f"\n❌ ERROR: {response.status_code}")
150
+ print(response.text)
151
+
152
  except Exception as e:
153
+ print(f"\n❌ ERROR: {str(e)}")
 
 
154
 
155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
  if __name__ == "__main__":
158
  import os