Spaces:
Paused
Paused
| import sys | |
| import json | |
| import io | |
| import os | |
| import logging | |
| from typing import Optional, Tuple, List | |
| from difflib import SequenceMatcher | |
| import random | |
| import numpy as np | |
| import cv2 | |
| from PIL import Image | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.concurrency import run_in_threadpool | |
| import onnxruntime as ort | |
| import asyncio | |
| import yaml | |
# --- Application configuration ---
# Crop model used to split the screenshot into task and hint regions.
CROP_MODEL_WEIGHTS = "crop coordinates/bestcrop.onnx"
CROP_YAML_PATH = "crop coordinates/data.yaml"
# Model A detects objects in the HINT image (bottom part).
MODEL_A_WEIGHTS = "best.onnx"
YAML_A_PATH = "data.yaml"
# Model B detects objects in the TASK image (top part).
MODEL_B_WEIGHTS = "best1.onnx"
YAML_B_PATH = "data1.yaml"
# IMPORTANT LOGIC: Model A (hint) caps the number of clicks:
# - If Model A detects 4 objects -> at most 4 clicks in Model B
# - If Model A detects 6 objects -> at most 6 clicks in Model B
# - Click order: LEFT TO RIGHT, following Model A's (hint) ordering
# - No clicks at all when Model A detects nothing
CONFIDENCE_THRESHOLD = 0.2
NMS_IOU_THRESHOLD = 0.8
SIMILARITY_CUTOFF = 0.6
PARTNERS_FILENAME = "custom_partners.json"
# Default class-name pairs treated as equivalent during pairing; may be
# overridden at runtime by custom_partners.json (see rebuild_custom_partners).
RAW_PARTNERS = [('kalkulator', 'keyboard'), ('save', 'simpan')]
CUSTOM_PARTNERS = {}
# Crop configuration - model-based detection
CROP_CONFIDENCE_THRESHOLD = 0.3  # Confidence threshold for crop model detection
CROP_NMS_IOU_THRESHOLD = 0.4  # NMS threshold for crop model
# --- FastAPI application initialization ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', stream=sys.stdout)
logger = logging.getLogger(__name__)
app = FastAPI(title="Optimized ONNX Server - Smart Pairing")
# Wide-open CORS so any frontend/extension can call this API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"],
)
# --- Global state (sessions/class lists populated by startup_event) ---
model_crop_session, model_a_session, model_b_session = None, None, None
CLASS_NAMES_CROP = []
CLASS_NAMES_A = []
CLASS_NAMES_B = []
# --- Optimizations for concurrent processing ---
import threading
import psutil
import gc
from concurrent.futures import ThreadPoolExecutor
# Dedicated thread pool for concurrent model inference (one worker per model).
inference_executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="inference")
| # --- Fungsi Helper --- | |
def crop_image_into_parts(image_bytes: bytes) -> Optional[Tuple[Image.Image, Image.Image]]:
    """Crop image using ONNX model to detect 'tugas' and 'petunjuk' regions"""
    buffer = np.frombuffer(image_bytes, np.uint8)
    img_bgr = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
    if img_bgr is None:
        logger.error("Gagal membaca byte gambar dengan OpenCV.")
        return None

    def to_pil(arr):
        # OpenCV stores BGR; PIL expects RGB.
        return Image.fromarray(cv2.cvtColor(arr, cv2.COLOR_BGR2RGB))

    try:
        logger.info("Using ONNX Crop Model for region detection")
        # Ask the crop model where the task/hint regions are.
        detections = predict_with_onnx(model_crop_session, to_pil(img_bgr), CLASS_NAMES_CROP, "Crop Model")
        if not detections:
            logger.error("GAGAL: Crop model tidak mendeteksi region apapun.")
            return None
        detected_classes = [det['class_name'].lower() for det in detections]
        logger.info(f"Crop model detected regions: {detected_classes}")
        # Keep the last detection seen for each region label.
        regions = {}
        for det in detections:
            label = det['class_name'].lower()
            if label in ('tugas', 'petunjuk'):
                regions[label] = det['box']  # [x1, y1, x2, y2]
        if 'tugas' not in regions:
            logger.error("GAGAL: Model crop tidak mendeteksi region 'tugas'.")
            return None
        if 'petunjuk' not in regions:
            logger.error("GAGAL: Model crop tidak mendeteksi region 'petunjuk'.")
            return None
        height, width = img_bgr.shape[:2]

        def bounded(box):
            # Clamp corners into the image and keep at least a 1px extent.
            left = max(0, min(int(box[0]), width - 1))
            top = max(0, min(int(box[1]), height - 1))
            right = max(left + 1, min(int(box[2]), width))
            bottom = max(top + 1, min(int(box[3]), height))
            return left, top, right, bottom

        tx1, ty1, tx2, ty2 = bounded(regions['tugas'])
        px1, py1, px2, py2 = bounded(regions['petunjuk'])
        tugas_part = img_bgr[ty1:ty2, tx1:tx2]
        petunjuk_part = img_bgr[py1:py2, px1:px2]
        if tugas_part.size == 0 or petunjuk_part.size == 0:
            logger.error("GAGAL: Salah satu region crop kosong.")
            return None
        logger.info(f"Crop SUKSES menggunakan ONNX Model: Tugas({tugas_part.shape[:2]}) Petunjuk({petunjuk_part.shape[:2]})")
        return to_pil(tugas_part), to_pil(petunjuk_part)
    except Exception as e:
        logger.error(f"GAGAL: Error dalam crop model detection: {e}", exc_info=True)
        return None
def rebuild_custom_partners():
    """Reload partner pairs from disk (if present) and rebuild the symmetric lookup map."""
    global CUSTOM_PARTNERS, RAW_PARTNERS
    if os.path.exists(PARTNERS_FILENAME):
        try:
            with open(PARTNERS_FILENAME, 'r', encoding='utf-8') as f:
                RAW_PARTNERS = json.load(f)
            logger.info(f"Loaded {len(RAW_PARTNERS)} partners from {PARTNERS_FILENAME}")
        except Exception as e:
            logger.warning(f"Failed to load {PARTNERS_FILENAME}, using default. Error: {e}")
    else:
        logger.info(f"'{PARTNERS_FILENAME}' not found, using default partners.")
    # Map each name to its partners in both directions so lookups are symmetric.
    CUSTOM_PARTNERS.clear()
    for first, second in RAW_PARTNERS:
        CUSTOM_PARTNERS.setdefault(first.lower(), []).append(second.lower())
        CUSTOM_PARTNERS.setdefault(second.lower(), []).append(first.lower())
    logger.info(f"Rebuilt custom partners with {len(RAW_PARTNERS)} rules.")
| def get_similarity(a, b): return SequenceMatcher(None, a.lower(), b.lower()).ratio() | |
def pair_and_output(detsA: List[dict], detsB: List[dict]) -> List[dict]:
    """Pair hint detections (detsA) with task detections (detsB), returning click coordinates."""
    # No hints means nothing may be clicked at all.
    if not detsA:
        logger.warning("Tidak ada deteksi petunjuk (detsA), tidak ada koordinat yang dikembalikan.")
        return []
    # Hints sorted LEFT-TO-RIGHT define the click order and the click count.
    detsA = sorted(detsA, key=lambda d: d['center_coords'][0])
    logger.info(f"Model A (Petunjuk) mendeteksi {len(detsA)} objek, menentukan maksimal {len(detsA)} klik.")
    taken = set()
    slots = [None] * len(detsA)
    # Greedy matching: exact name (1.1) > configured partner (1.0) > string
    # similarity (>= cutoff). Any still-unmatched candidate scores 0, which
    # beats the initial -1, so every hint grabs *some* leftover candidate.
    for idx, hint in enumerate(detsA):
        hint_name = hint['class_name'].lower()
        chosen, chosen_score = -1, -1
        for j, cand in enumerate(detsB):
            if j in taken:
                continue
            cand_name = cand['class_name'].lower()
            if hint_name == cand_name:
                score = 1.1  # exact match always wins
            elif cand_name in CUSTOM_PARTNERS.get(hint_name, []):
                score = 1.0  # configured partner pair
            else:
                sim = get_similarity(hint['class_name'], cand['class_name'])
                score = sim if sim >= SIMILARITY_CUTOFF else 0
            if score > chosen_score:
                chosen_score = score
                chosen = j
        if chosen != -1:
            slots[idx] = detsB[chosen]
            taken.add(chosen)
    # CONSTRAINT: the final count must equal the number of hints (detsA).
    paired = [entry for entry in slots if entry is not None]
    missing = len(detsA) - len(paired)
    if missing > 0:
        leftovers = [cand for j, cand in enumerate(detsB) if j not in taken]
        if leftovers:
            logger.info(f"Hasil pairing {len(paired)}/{len(detsA)} objek. Mengisi {missing} slot kosong dengan objek ber-confidence tertinggi...")
            # Fill remaining slots with the highest-confidence leftovers.
            leftovers.sort(key=lambda d: d['score'], reverse=True)
            used = 0
            for idx in range(len(slots)):
                if slots[idx] is None and used < missing and used < len(leftovers):
                    slots[idx] = leftovers[used]
                    used += 1
            paired = [entry for entry in slots if entry is not None]
    # Never return more coordinates than there are hints.
    paired = paired[:len(detsA)]
    logger.info(f"Pairing selesai, menghasilkan TEPAT {len(paired)} koordinat (sama dengan jumlah petunjuk).")
    return [{'x': entry['center_coords'][0], 'y': entry['center_coords'][1]} for entry in paired]
def predict_with_onnx(session: ort.InferenceSession, pil_image: Image, class_names: List[str], model_name: str) -> List[dict]:
    """Run a YOLO-style ONNX model on a PIL image and return its detections.

    Each detection is a dict with 'class_name', 'center_coords' (x, y),
    'score', and 'box' ([x1, y1, x2, y2]) in original-image pixels,
    sorted by confidence descending. Returns [] when nothing is found.
    """
    if pil_image is None:
        logger.warning(f"Input gambar untuk {model_name} kosong, prediksi dilewati.")
        return []
    # Letterbox to the model's fixed input size: resize preserving aspect
    # ratio, then pad with gray (114), matching Ultralytics preprocessing.
    input_shape = session.get_inputs()[0].shape
    input_height, input_width = input_shape[2], input_shape[3]
    img_np = np.array(pil_image)
    original_height, original_width, _ = img_np.shape
    ratio = min(input_width / original_width, input_height / original_height)
    new_width, new_height = int(original_width * ratio), int(original_height * ratio)
    img_resized = cv2.resize(img_np, (new_width, new_height), interpolation=cv2.INTER_LINEAR)
    padded_img = np.full((input_height, input_width, 3), 114, dtype=np.uint8)
    dw, dh = (input_width - new_width) // 2, (input_height - new_height) // 2
    padded_img[dh:new_height+dh, dw:new_width+dw, :] = img_resized
    # HWC uint8 -> normalized NCHW float32 tensor.
    input_tensor = padded_img.astype(np.float32) / 255.0
    input_tensor = input_tensor.transpose(2, 0, 1)
    input_tensor = np.expand_dims(input_tensor, axis=0)
    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: input_tensor})
    # Assumes output[0] is (1, 4+num_classes, num_preds); transposed to rows
    # of [cx, cy, w, h, class scores...] -- TODO confirm for each exported model.
    preds = np.transpose(outputs[0][0])
    box_scores = preds[:, 4:]
    max_scores = np.max(box_scores, axis=1)
    valid_preds = max_scores > CONFIDENCE_THRESHOLD
    preds, max_scores = preds[valid_preds], max_scores[valid_preds]
    if preds.shape[0] == 0:
        logger.info(f"Model ONNX ({model_name}) mendeteksi 0 objek.")
        return []
    class_ids = np.argmax(preds[:, 4:], axis=1)
    # Debugging aid: report the predicted class-id range so a mismatch
    # between the exported model and data.yaml is caught early.
    if len(class_ids) > 0:
        min_class_id, max_class_id = int(np.min(class_ids)), int(np.max(class_ids))
        logger.debug(f"Model {model_name}: Predicted class_ids range from {min_class_id} to {max_class_id} (class_names length: {len(class_names)})")
        # Check for potential issues
        if max_class_id >= len(class_names):
            logger.error(f"Model {model_name}: WARNING! Max predicted class_id ({max_class_id}) exceeds class_names bounds ({len(class_names)-1}). This will cause IndexError!")
    # Undo the letterbox transform: shift centers by the padding offset and
    # rescale widths/heights back to original-image pixels (mutates preds view).
    boxes_raw = preds[:, :4]
    gain = ratio
    x_offset, y_offset = dw, dh
    boxes_raw[:, 0] = (boxes_raw[:, 0] - x_offset) / gain
    boxes_raw[:, 1] = (boxes_raw[:, 1] - y_offset) / gain
    boxes_raw[:, 2] /= gain
    boxes_raw[:, 3] /= gain
    # Convert center/size to corner coordinates.
    x1, y1 = boxes_raw[:, 0] - boxes_raw[:, 2] / 2, boxes_raw[:, 1] - boxes_raw[:, 3] / 2
    x2, y2 = boxes_raw[:, 0] + boxes_raw[:, 2] / 2, boxes_raw[:, 1] + boxes_raw[:, 3] / 2
    boxes_processed = np.column_stack((x1, y1, x2, y2)).astype(np.float32)
    # NOTE(review): cv2.dnn.NMSBoxes documents boxes as (x, y, w, h); passing
    # (x1, y1, x2, y2) here skews the IoU computation -- confirm intended.
    indices = cv2.dnn.NMSBoxes(boxes_processed, max_scores, CONFIDENCE_THRESHOLD, NMS_IOU_THRESHOLD)
    detections = []
    if len(indices) > 0:
        for i in indices.flatten():
            box, score, class_id = boxes_processed[i], max_scores[i], class_ids[i]
            # Bounds check on class_id to prevent an IndexError on class_names.
            if class_id < 0 or class_id >= len(class_names):
                logger.warning(f"Model {model_name}: class_id {class_id} is out of bounds for class_names (length: {len(class_names)}). Skipping this detection.")
                continue
            detections.append({'class_name': class_names[class_id], 'center_coords': (int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)), 'score': float(score), 'box': box.tolist()})
    # Sort by confidence, highest first; the count is deliberately not capped.
    detections = sorted(detections, key=lambda x: x['score'], reverse=True)
    logger.info(f"Model ONNX ({model_name}) mengembalikan {len(detections)} objek (tidak dibatasi).")
    return detections
def startup_event():
    """Load class-name YAMLs and the three ONNX models; exits the process on any failure."""
    global model_crop_session, model_a_session, model_b_session, CLASS_NAMES_CROP, CLASS_NAMES_A, CLASS_NAMES_B
    providers = ['CPUExecutionProvider']
    logger.info(f"Using ONNX Runtime providers: {providers}")
    # Log host resources so capacity problems are visible at boot time.
    memory_info = psutil.virtual_memory()
    cpu_count = psutil.cpu_count()
    logger.info(f"System Resources: {cpu_count} CPUs, {memory_info.total / (1024**3):.1f}GB RAM, {memory_info.available / (1024**3):.1f}GB Available")
    try:
        # Class names for the crop model.
        logger.info(f"Loading class names for Crop Model from {CROP_YAML_PATH}...")
        with open(CROP_YAML_PATH, "r", encoding="utf-8") as f:
            data_crop = yaml.safe_load(f)
        CLASS_NAMES_CROP = data_crop['names']
        if not CLASS_NAMES_CROP: raise ValueError(f"File {CROP_YAML_PATH} tidak memiliki 'names' atau kosong.")
        logger.info(f"{len(CLASS_NAMES_CROP)} classes loaded for Crop Model: {CLASS_NAMES_CROP}")
        # Class names for Model A (hint).
        logger.info(f"Loading class names for Model A (Petunjuk) from {YAML_A_PATH}...")
        with open(YAML_A_PATH, "r", encoding="utf-8") as f:
            data_a = yaml.safe_load(f)
        CLASS_NAMES_A = data_a['names']
        if not CLASS_NAMES_A: raise ValueError(f"File {YAML_A_PATH} tidak memiliki 'names' atau kosong.")
        logger.info(f"{len(CLASS_NAMES_A)} classes loaded for Model A.")
        # Class names for Model B (task).
        logger.info(f"Loading class names for Model B (Tugas) from {YAML_B_PATH}...")
        with open(YAML_B_PATH, "r", encoding="utf-8") as f:
            data_b = yaml.safe_load(f)
        CLASS_NAMES_B = data_b['names']
        if not CLASS_NAMES_B: raise ValueError(f"File {YAML_B_PATH} tidak memiliki 'names' atau kosong.")
        logger.info(f"{len(CLASS_NAMES_B)} classes loaded for Model B.")
        # Load models with settings tuned for concurrent request processing:
        # single-threaded per session so the two inference workers don't
        # oversubscribe the CPU.
        session_options = ort.SessionOptions()
        session_options.intra_op_num_threads = 1  # Limit threads per model
        session_options.inter_op_num_threads = 1
        session_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
        logger.info(f"Loading ONNX Crop Model from '{CROP_MODEL_WEIGHTS}'...")
        model_crop_session = ort.InferenceSession(CROP_MODEL_WEIGHTS, providers=providers, sess_options=session_options)
        logger.info(f"Loading ONNX model A (Petunjuk) from '{MODEL_A_WEIGHTS}'...")
        model_a_session = ort.InferenceSession(MODEL_A_WEIGHTS, providers=providers, sess_options=session_options)
        logger.info(f"Loading ONNX model B (Tugas) from '{MODEL_B_WEIGHTS}'...")
        model_b_session = ort.InferenceSession(MODEL_B_WEIGHTS, providers=providers, sess_options=session_options)
        logger.info("All ONNX models loaded successfully with concurrent processing optimization.")
        rebuild_custom_partners()
        # Force garbage collection after model loading to reduce baseline RSS.
        gc.collect()
        memory_after = psutil.virtual_memory()
        logger.info(f"Memory after model loading: {memory_after.available / (1024**3):.1f}GB Available")
    except Exception as e:
        # Startup is all-or-nothing: a missing model or YAML kills the process.
        logger.error(f"FATAL: Failed to complete startup. Cause: {e}", exc_info=True)
        sys.exit(1)
def status_check():
    """Report server liveness together with basic host resource figures."""
    vm = psutil.virtual_memory()
    system_info = {
        "cpu_count": psutil.cpu_count(),
        "memory_total_gb": round(vm.total / (1024**3), 1),
        "memory_available_gb": round(vm.available / (1024**3), 1),
        "memory_percent_used": vm.percent,
    }
    return {
        "status": "Optimized ONNX server is running!",
        "system_info": system_info,
    }
def health_check():
    """Health check endpoint for monitoring"""
    try:
        vm = psutil.virtual_memory()
        return {
            "status": "healthy",
            "memory_available_gb": round(vm.available / (1024**3), 1),
            "memory_percent_used": vm.percent,
            # True only once all three ONNX sessions have been loaded.
            "models_loaded": all([model_crop_session, model_a_session, model_b_session]),
        }
    except Exception as e:
        return {"status": "unhealthy", "error": str(e)}
def shutdown_event():
    """Cleanup resources on shutdown"""
    global inference_executor
    logger.info("Shutting down application...")
    executor = inference_executor
    if executor:
        # Wait for in-flight inference tasks to finish before exiting.
        executor.shutdown(wait=True)
    gc.collect()
    logger.info("Application shutdown complete.")
async def predict(
    main_image: Optional[UploadFile] = File(None),
    hint_image: Optional[UploadFile] = File(None),
    image: Optional[UploadFile] = File(None)
):
    """Detect hint/task objects and return the click coordinates.

    Accepts either two pre-cropped uploads (main_image + hint_image) or one
    full screenshot (image) that is split server-side by the crop model.
    Returns {'coordinates': [{'x': ..., 'y': ...}, ...]} ordered left-to-right
    by the hint objects, capped at the number of hints.

    Raises:
        HTTPException 400: invalid input or server-side cropping failed.
        HTTPException 500: any unexpected processing error.
    """
    try:
        petunjuk_pil, tugas_pil = None, None
        if main_image and hint_image:
            logger.info("Processing 2 images (main + hint) from FormData...")
            main_bytes, hint_bytes = await main_image.read(), await hint_image.read()
            tugas_pil = await run_in_threadpool(Image.open, io.BytesIO(main_bytes))
            petunjuk_pil = await run_in_threadpool(Image.open, io.BytesIO(hint_bytes))
        elif image:
            logger.info("Processing 1 image with server-side crop...")
            image_bytes = await image.read()
            result = await run_in_threadpool(crop_image_into_parts, image_bytes)
            if result is None:
                raise HTTPException(status_code=400, detail="Gagal memotong gambar di server.")
            tugas_pil, petunjuk_pil = result
        else:
            raise HTTPException(status_code=400, detail="Input tidak valid.")
        if tugas_pil: tugas_pil = tugas_pil.convert("RGB")
        if petunjuk_pil: petunjuk_pil = petunjuk_pil.convert("RGB")
        logger.info("--- Predicting 'hint' (Model A) and 'task' (Model B) concurrently with ONNX ---")
        # Run both models in the dedicated executor so the event loop stays free.
        loop = asyncio.get_event_loop()
        task_a = loop.run_in_executor(inference_executor, predict_with_onnx, model_a_session, petunjuk_pil, CLASS_NAMES_A, "Model A (Petunjuk)")
        task_b = loop.run_in_executor(inference_executor, predict_with_onnx, model_b_session, tugas_pil, CLASS_NAMES_B, "Model B (Tugas)")
        detsA, detsB = await asyncio.gather(task_a, task_b)
        # Log memory usage after inference.
        memory_after_inference = psutil.virtual_memory()
        logger.info(f"Memory after inference: {memory_after_inference.available / (1024**3):.1f}GB Available")
        logger.info(f"Hasil deteksi mentah: Model A (Petunjuk) menemukan {len(detsA)} objek, Model B (Tugas) menemukan {len(detsB)} objek.")
        coords = await run_in_threadpool(pair_and_output, detsA, detsB)
        logger.info(f"Hasil setelah pairing: Ditemukan {len(coords)} koordinat akhir (DIBATASI sesuai jumlah petunjuk Model A).")
        # Drop large intermediates before returning.
        del detsA, detsB, tugas_pil, petunjuk_pil
        gc.collect()
        return {'coordinates': coords}
    except HTTPException:
        # BUGFIX: the deliberate 400 responses raised above were previously
        # swallowed by the blanket handler below and re-raised as opaque 500s.
        raise
    except Exception as e:
        logger.error(f"Error processing request: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f'Internal server error: {str(e)}')