Spaces:
Paused
Paused
import base64
import hmac
import io
import logging
import os
import time
import traceback
from typing import List, Tuple

import cv2
import numpy as np
import onnxruntime as ort
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel
# --- Logging configuration ---
# Configured first so that the security checks below can log through it.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Security configuration ---
# Fallback used when the SECRET_KEY environment variable is not set.
# NOTE(review): this default credential is committed to source, and base64 is an
# encoding, not encryption - anyone who can read this file can derive the key.
# Set SECRET_KEY (base64-encoded) in the deployment environment instead.
_DEFAULT_ENCODED_SECRET_KEY = "S3VuY2lSYWhhc2lhVXRhbWFfVW50dWtTb2x2ZXJBbWF6b24="
ENCODED_SECRET_KEY = os.environ.get("SECRET_KEY", _DEFAULT_ENCODED_SECRET_KEY)
SECRET_KEY = base64.b64decode(ENCODED_SECRET_KEY).decode('utf-8')
if ENCODED_SECRET_KEY == _DEFAULT_ENCODED_SECRET_KEY:
    logger.warning(
        "SECRET_KEY is unset or equal to the built-in default; "
        "the API key is effectively public."
    )
if not SECRET_KEY:
    # An empty key would let any request carrying an empty api_key authenticate,
    # so refuse to start rather than run with authentication effectively disabled.
    raise RuntimeError("SECRET_KEY decodes to an empty string; refusing to start.")
# --- Class mapping (class name -> class id), taken from data.yaml ---
# Each name's position in the tuple below is its class id, so the order must
# stay in sync with the class list the model was trained with.
CLASS_MAPPING = {
    class_name: class_id
    for class_id, class_name in enumerate((
        "the bags",
        "the beds",
        "the buckets",
        "the chairs",
        "the clocks",
        "the curtains",
        "the hats",
    ))
}
# --- ONNX model state, populated by load_model() ---
# The session stays None until a load succeeds; the input name and the
# expected input height/width are read from the model's first input.
ort_session = None
model_input_height = model_input_width = 0
input_name = ""
# Create the FastAPI application instance.
app = FastAPI(
    title="Amazon Captcha Solver - ONNX FastAPI",
)
def load_model():
    """Load ``best.onnx`` and publish it through the module-level globals.

    On success this sets ``ort_session``, ``input_name``, ``model_input_height``
    and ``model_input_width``. On any failure the error is logged together with
    its traceback and ``ort_session`` is set to ``None``, so the endpoint can
    report the model as unavailable instead of crashing.
    """
    # NOTE(review): this is meant to run when the FastAPI server starts, but no
    # startup hook is visible in this view - confirm it is registered somewhere.
    global ort_session, model_input_height, model_input_width, input_name
    logger.info("π Memuat model ONNX... Harap tunggu.")
    try:
        providers = ['CPUExecutionProvider']
        session = ort.InferenceSession("best.onnx", providers=providers)
        first_input = session.get_inputs()[0]
        input_shape = first_input.shape
        # Dims 2 and 3 are read as height and width, i.e. an NCHW input is assumed.
        height, width = input_shape[2], input_shape[3]
        # ONNX Runtime reports dynamic axes as a symbolic name (str) or None
        # rather than an int. preprocess_image needs concrete sizes, so reject
        # such models here instead of failing on every request later.
        if not isinstance(height, int) or not isinstance(width, int):
            raise ValueError(f"model input needs a fixed height/width, got shape {input_shape}")
    except Exception as e:
        logger.exception(f"β Gagal memuat model ONNX: {e}")
        ort_session = None
        return
    # Publish the new state only after every step above succeeded, so the
    # globals never describe a half-loaded model.
    input_name = first_input.name
    model_input_height, model_input_width = height, width
    ort_session = session
    logger.info(f"β Model ONNX berhasil dimuat. Input shape: {input_shape}")
# --- Pydantic request model ---
# Request body accepted by the captcha-solving endpoint.
class SolveRequest(BaseModel):
    # Must equal SECRET_KEY, otherwise the request is rejected with 401.
    api_key: str
    # Base64-encoded image bytes; decoded and opened with PIL by the endpoint.
    image_base64: str
    # Class name to locate, e.g. "the bags"; lowercased and stripped before
    # being looked up in CLASS_MAPPING.
    target_class: str
# --- ONNX helper functions ---
def preprocess_image(img: Image.Image, height: int, width: int) -> Tuple[np.ndarray, float, int, int]:
    """Letterbox an image into a normalized NCHW float32 tensor.

    The image is converted to RGB and resized, keeping its aspect ratio, to fit
    inside ``width`` x ``height``. It is centred on a canvas filled with the
    value 114, scaled to [0, 1], and reshaped to ``(1, 3, height, width)``.

    Args:
        img: Source PIL image (any mode; converted to RGB).
        height: Target tensor height in pixels.
        width: Target tensor width in pixels.

    Returns:
        ``(tensor, ratio, dw, dh)``, where:
        - ``ratio`` is the resize scale factor;
        - ``dw`` and ``dh`` are the left/top padding in pixels;
        - ``postprocess_output`` uses these three values to map boxes back to
          the original image.
    """
    rgb = np.array(img.convert('RGB'))
    src_h, src_w = rgb.shape[:2]
    ratio = min(width / src_w, height / src_h)
    # Clamp to at least one pixel. A very thin image could otherwise round one
    # side down to 0, which cv2.resize rejects.
    new_w = max(1, int(src_w * ratio))
    new_h = max(1, int(src_h * ratio))
    resized = cv2.resize(rgb, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    canvas = np.full((height, width, 3), 114, dtype=np.uint8)
    dw, dh = (width - new_w) // 2, (height - new_h) // 2
    canvas[dh:dh + new_h, dw:dw + new_w, :] = resized
    tensor = canvas.astype(np.float32) / 255.0
    tensor = np.transpose(tensor, (2, 0, 1))  # HWC -> CHW
    tensor = np.expand_dims(tensor, axis=0)   # add batch axis -> (1, C, H, W)
    return tensor, ratio, dw, dh
| def postprocess_output(output: np.ndarray, ratio: float, dw: int, dh: int, conf_threshold: float, nms_threshold: float) -> list: | |
| outputs = np.transpose(output[0], (1, 0)) | |
| boxes, scores, class_ids = [], [], [] | |
| for row in outputs: | |
| class_score = np.max(row[4:]) | |
| if class_score > conf_threshold: | |
| class_id = np.argmax(row[4:]) | |
| cx, cy, w, h = row[0:4] | |
| x1 = int((cx - w / 2 - dw) / ratio) | |
| y1 = int((cy - h / 2 - dh) / ratio) | |
| x2 = int((cx + w / 2 - dw) / ratio) | |
| y2 = int((cy + h / 2 - dh) / ratio) | |
| boxes.append([x1, y1, x2-x1, y2-y1]) | |
| scores.append(class_score) | |
| class_ids.append(class_id) | |
| if not boxes: return [] | |
| indices = cv2.dnn.NMSBoxes(boxes, np.array(scores), conf_threshold, nms_threshold) | |
| final_boxes = [] | |
| if len(indices) > 0: | |
| for i in indices.flatten(): | |
| x, y, w, h = boxes[i] | |
| center_x = x + w / 2 | |
| center_y = y + h / 2 | |
| final_boxes.append({'class_id': class_ids[i], 'box': [center_x, center_y]}) | |
| return final_boxes | |
# --- Main endpoint ---
async def solve_captcha(payload: SolveRequest):
    """Return the centre points of detections that match ``payload.target_class``.

    Flow:
    1. Check the API key and confirm the model is loaded.
    2. Resolve the target class.
    3. Decode the base64 image and run the ONNX model.
    4. Keep the detections whose class id matches.

    Returns:
        ``{"coordinates": [[center_x, center_y], ...]}`` in the pixel
        coordinates of the submitted image.

    Raises:
        HTTPException:
            - 401 for a wrong API key;
            - 503 when the model is not loaded;
            - 400 for an unknown target class, an undecodable image, or a
              detection count other than ``expected_object_count``;
            - 500 for any other error.
    """
    # NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible in this
    # view - confirm this coroutine is registered on ``app``.
    # NOTE(review): decoding and inference below are synchronous CPU work inside an
    # ``async def``, so they block the event loop while they run.
    logger.info("Permintaan baru diterima...")
    # compare_digest takes time independent of where the inputs differ, so response
    # timing does not reveal how much of the key matched. Bytes are compared
    # because compare_digest raises TypeError for non-ASCII str input.
    if not hmac.compare_digest(payload.api_key.encode("utf-8"), SECRET_KEY.encode("utf-8")):
        raise HTTPException(status_code=401, detail="API Key tidak valid.")
    if ort_session is None:
        raise HTTPException(status_code=503, detail="Model tidak tersedia atau gagal dimuat.")
    logger.info(f"π― Kelas target: '{payload.target_class}'")
    # Resolve the class before any decoding or inference so bad requests fail cheaply.
    target_id = CLASS_MAPPING.get(payload.target_class.lower().strip(), -1)
    if target_id == -1:
        raise HTTPException(status_code=400, detail=f"Kelas '{payload.target_class}' tidak ada di mapping server.")
    try:
        try:
            image_bytes = base64.b64decode(payload.image_base64)
            image = Image.open(io.BytesIO(image_bytes))
            # Image.open is lazy; load() forces full decoding so corrupt data is
            # reported here as a client error rather than later as a 500.
            image.load()
        except (ValueError, OSError) as exc:
            # binascii.Error (malformed base64) subclasses ValueError; PIL's
            # UnidentifiedImageError and truncated-image errors subclass OSError.
            logger.warning(f"Invalid image payload: {exc}")
            raise HTTPException(status_code=400, detail="image_base64 tidak valid atau bukan gambar.") from exc

        preprocessed_image, ratio, dw, dh = preprocess_image(image, model_input_height, model_input_width)
        logger.info("π€ Menjalankan prediksi model ONNX...")
        start_time = time.perf_counter()
        outputs = ort_session.run(None, {input_name: preprocessed_image})
        elapsed_ms = (time.perf_counter() - start_time) * 1000
        logger.info(f"β±οΈ Waktu prediksi ONNX: {elapsed_ms:.2f} ms")
        results = postprocess_output(outputs[0], ratio, dw, dh, conf_threshold=0.55, nms_threshold=0.02)

        # Any other total is rejected as an unreliable detection.
        # NOTE(review): 9 presumably matches the number of tiles in the captcha - confirm.
        expected_object_count = 9
        if len(results) != expected_object_count:
            error_msg = f"Model mendeteksi {len(results)} objek, bukan {expected_object_count}."
            logger.warning(f"β οΈ {error_msg} Melewatkan...")
            raise HTTPException(status_code=400, detail={"error_code": "WRONG_OBJECT_COUNT", "message": error_msg})

        coordinates = [res['box'] for res in results if res['class_id'] == target_id]
        logger.info(f"β Ditemukan {len(coordinates)} koordinat untuk '{payload.target_class}'")
        return {"coordinates": coordinates}
    except HTTPException:
        # Deliberate HTTP errors raised above pass through unchanged.
        raise
    except Exception as e:
        # logger.exception records the full traceback through the logging setup.
        logger.exception(f"π₯ Terjadi error internal server: {e}")
        raise HTTPException(status_code=500, detail="Internal Server Error") from e