"""FastAPI service that locates objects of a requested class in a captcha image.

The service loads an ONNX detection model once at startup, then answers POST
requests containing a base64-encoded image and a target class name with the
centre coordinates of every detected object of that class.
"""

import base64
import hmac
import io
import logging
import os
import time
import traceback
from typing import List, Tuple

import cv2
import numpy as np
import onnxruntime as ort
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel

# --- Security configuration ---
# SECURITY NOTE(review): the fallback value below is a base64-encoded key that
# lives in source control. It is kept so existing deployments keep working, but
# production should always provide SECRET_KEY through the environment.
ENCODED_SECRET_KEY = os.environ.get("SECRET_KEY", "S3VuY2lSYWhhc2lhVXRhbWFfVW50dWtTb2x2ZXJBbWF6b24=")
SECRET_KEY = base64.b64decode(ENCODED_SECRET_KEY).decode('utf-8')

# Logging setup
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Class-name to class-id mapping (taken from data.yaml) ---
CLASS_MAPPING = {
    "the bags": 0,
    "the beds": 1,
    "the buckets": 2,
    "the chairs": 3,
    "the clocks": 4,
    "the curtains": 5,
    "the hats": 6,
}

# --- Tunables used by the endpoint ---
MODEL_PATH = "best.onnx"
CONF_THRESHOLD = 0.55
NMS_THRESHOLD = 0.02
# The endpoint rejects any prediction that does not contain exactly this many
# objects in total (all classes combined).
EXPECTED_OBJECT_COUNT = 9
# Grey value used to pad the letterboxed image.
PAD_VALUE = 114

# --- ONNX model state, populated by load_model() at startup ---
ort_session = None
model_input_height = 0
model_input_width = 0
input_name = ""

# FastAPI application
app = FastAPI(title="Amazon Captcha Solver - ONNX FastAPI")


@app.on_event("startup")
def load_model():
    """Load the ONNX model when the FastAPI server starts.

    On success the module globals ``ort_session``, ``input_name``,
    ``model_input_height`` and ``model_input_width`` are populated. On any
    failure the error is logged and ``ort_session`` is left as ``None`` so the
    endpoint answers 503 instead of crashing.
    """
    global ort_session, model_input_height, model_input_width, input_name
    logger.info("🚀 Memuat model ONNX... Harap tunggu.")
    try:
        providers = ['CPUExecutionProvider']
        session = ort.InferenceSession(MODEL_PATH, providers=providers)
        first_input = session.get_inputs()[0]
        input_shape = first_input.shape
        # Indices 2 and 3 are read as height and width, i.e. the input is
        # assumed to be laid out as (N, C, H, W) -- TODO confirm for new models.
        height, width = input_shape[2], input_shape[3]
        # Dynamic axes are reported as strings/None; preprocessing needs real
        # integers, so refuse such a model here rather than fail per request.
        if not (isinstance(height, int) and isinstance(width, int)):
            raise ValueError(f"Dimensi input model tidak statis: {input_shape}")
        ort_session = session
        input_name = first_input.name
        model_input_height, model_input_width = height, width
        logger.info(f"✅ Model ONNX berhasil dimuat. Input shape: {input_shape}")
    except Exception as e:
        logger.exception(f"❌ Gagal memuat model ONNX: {e}")
        ort_session = None


# --- Pydantic request model ---
class SolveRequest(BaseModel):
    """Request body accepted by the solve endpoint."""

    # Shared secret; must match SECRET_KEY.
    api_key: str
    # Base64-encoded image bytes.
    image_base64: str
    # Human-readable class name; looked up in CLASS_MAPPING after lower/strip.
    target_class: str


# --- ONNX helper functions ---
def preprocess_image(img: Image.Image, height: int, width: int) -> Tuple[np.ndarray, float, int, int]:
    """Letterbox ``img`` to ``height`` x ``width`` and convert it to a model tensor.

    The image is converted to RGB, resized with its aspect ratio preserved,
    centred on a grey canvas, scaled to the 0..1 range and rearranged to
    channel-first order with a leading batch axis.

    Args:
        img: Source image.
        height: Target tensor height in pixels.
        width: Target tensor width in pixels.

    Returns:
        A tuple ``(tensor, ratio, dw, dh)`` where ``tensor`` is a float32 array
        of shape ``(1, 3, height, width)``, ``ratio`` is the resize factor that
        was applied, and ``dw``/``dh`` are the left/top padding in pixels.
    """
    img_np = np.array(img.convert('RGB'))
    original_height, original_width = img_np.shape[:2]
    ratio = min(width / original_width, height / original_height)
    # Clamp to at least one pixel: extreme aspect ratios would otherwise
    # truncate to 0 and make cv2.resize raise.
    new_width = max(1, int(original_width * ratio))
    new_height = max(1, int(original_height * ratio))
    img_resized = cv2.resize(img_np, (new_width, new_height), interpolation=cv2.INTER_LINEAR)

    padded_image = np.full((height, width, 3), PAD_VALUE, dtype=np.uint8)
    dw, dh = (width - new_width) // 2, (height - new_height) // 2
    padded_image[dh:dh + new_height, dw:dw + new_width, :] = img_resized

    image_tensor = padded_image.astype(np.float32) / 255.0
    image_tensor = np.transpose(image_tensor, (2, 0, 1))  # HWC -> CHW
    image_tensor = np.expand_dims(image_tensor, axis=0)  # add batch axis
    return image_tensor, ratio, dw, dh


def postprocess_output(output: np.ndarray, ratio: float, dw: int, dh: int,
                       conf_threshold: float, nms_threshold: float) -> List[dict]:
    """Turn the raw model output into de-duplicated detections.

    ``output[0]`` is transposed so that each row describes one candidate as
    ``[cx, cy, w, h, score_class_0, score_class_1, ...]``. Candidates whose
    best class score exceeds ``conf_threshold`` are mapped back from letterbox
    coordinates to original-image coordinates and filtered with OpenCV's
    non-maximum suppression.

    Args:
        output: Raw model output; ``output[0]`` must be 2-D.
        ratio: Resize factor returned by ``preprocess_image``.
        dw: Left padding returned by ``preprocess_image``.
        dh: Top padding returned by ``preprocess_image``.
        conf_threshold: Minimum best-class score for a candidate to be kept.
        nms_threshold: Overlap threshold passed to ``cv2.dnn.NMSBoxes``.

    Returns:
        A list of ``{'class_id': int, 'box': [center_x, center_y]}`` dicts, one
        per surviving detection; empty when nothing passes the threshold.
    """
    outputs = np.transpose(output[0], (1, 0))
    boxes, scores, class_ids = [], [], []
    for row in outputs:
        class_scores = row[4:]
        # One argmax gives both the winning class and its score.
        class_id = int(np.argmax(class_scores))
        class_score = class_scores[class_id]
        if not class_score > conf_threshold:
            continue
        cx, cy, w, h = row[0:4]
        # Undo the letterbox: remove padding, then divide by the resize factor.
        x1 = int((cx - w / 2 - dw) / ratio)
        y1 = int((cy - h / 2 - dh) / ratio)
        x2 = int((cx + w / 2 - dw) / ratio)
        y2 = int((cy + h / 2 - dh) / ratio)
        boxes.append([x1, y1, x2 - x1, y2 - y1])
        scores.append(class_score)
        class_ids.append(class_id)

    if not boxes:
        return []

    indices = cv2.dnn.NMSBoxes(boxes, np.array(scores), conf_threshold, nms_threshold)
    final_boxes = []
    # np.asarray keeps this working whether NMSBoxes returns an array or an
    # empty tuple.
    for i in np.asarray(indices).flatten():
        x, y, w, h = boxes[i]
        final_boxes.append({'class_id': class_ids[i], 'box': [x + w / 2, y + h / 2]})
    return final_boxes


# --- Main endpoint ---
@app.post("/solve_captcha")
@app.post("/")
async def solve_captcha(payload: SolveRequest):
    """Return the centre coordinates of every detected ``target_class`` object.

    Args:
        payload: API key, base64 image and target class name.

    Returns:
        ``{"coordinates": [[center_x, center_y], ...]}`` in original-image
        pixel coordinates.

    Raises:
        HTTPException: 401 for a wrong API key; 503 when the model is not
            loaded; 400 for an unknown class, an undecodable image, or a total
            detection count other than ``EXPECTED_OBJECT_COUNT``; 500 for any
            other failure.
    """
    # NOTE(review): inference below is blocking CPU work inside an ``async``
    # handler, so it runs on the event loop. The coroutine signature is kept
    # for compatibility.
    logger.info("Permintaan baru diterima...")

    # Constant-time comparison so response timing does not leak key prefixes.
    # Both sides are encoded because compare_digest rejects non-ASCII str.
    if not hmac.compare_digest(payload.api_key.encode("utf-8"), SECRET_KEY.encode("utf-8")):
        raise HTTPException(status_code=401, detail="API Key tidak valid.")

    if ort_session is None:
        raise HTTPException(status_code=503, detail="Model tidak tersedia atau gagal dimuat.")

    logger.info(f"🎯 Kelas target: '{payload.target_class}'")

    # Validate the class before spending time on decoding and inference.
    target_id = CLASS_MAPPING.get(payload.target_class.lower().strip())
    if target_id is None:
        raise HTTPException(status_code=400, detail=f"Kelas '{payload.target_class}' tidak ada di mapping server.")

    # Undecodable input is a client error, not a server error. load() forces
    # the actual pixel decode, which Image.open alone defers.
    try:
        image_bytes = base64.b64decode(payload.image_base64)
        image = Image.open(io.BytesIO(image_bytes))
        image.load()
    except (ValueError, OSError) as exc:
        logger.warning(f"⚠️ Gambar tidak valid: {exc}")
        raise HTTPException(status_code=400, detail="Gambar tidak valid atau rusak.") from exc

    try:
        preprocessed_image, ratio, dw, dh = preprocess_image(image, model_input_height, model_input_width)

        logger.info("🤖 Menjalankan prediksi model ONNX...")
        start_time = time.time()
        outputs = ort_session.run(None, {input_name: preprocessed_image})
        end_time = time.time()
        logger.info(f"⏱️ Waktu prediksi ONNX: {(end_time - start_time) * 1000:.2f} ms")

        results = postprocess_output(outputs[0], ratio, dw, dh,
                                     conf_threshold=CONF_THRESHOLD, nms_threshold=NMS_THRESHOLD)

        # The count check is on ALL detections, not only the target class.
        expected_object_count = EXPECTED_OBJECT_COUNT
        if len(results) != expected_object_count:
            error_msg = f"Model mendeteksi {len(results)} objek, bukan {expected_object_count}."
            logger.warning(f"⚠️ {error_msg} Melewatkan...")
            raise HTTPException(status_code=400, detail={"error_code": "WRONG_OBJECT_COUNT", "message": error_msg})

        coordinates = [res['box'] for res in results if res['class_id'] == target_id]
        logger.info(f"✅ Ditemukan {len(coordinates)} koordinat untuk '{payload.target_class}'")
        return {"coordinates": coordinates}
    except HTTPException:
        # Deliberate HTTP errors pass through untouched.
        raise
    except Exception as e:
        # logger.exception records the traceback through the logging system.
        logger.exception(f"💥 Terjadi error internal server: {e}")
        raise HTTPException(status_code=500, detail="Internal Server Error")