Spaces:
Paused
Paused
File size: 6,228 Bytes
3740adc | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 | import base64
import io
import logging
import time
import os
import traceback
import cv2
import numpy as np
import onnxruntime as ort
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List
from PIL import Image # <-- BARIS YANG HILANG DITAMBAHKAN DI SINI
# --- Security configuration ---
# The API key comes from the SECRET_KEY environment variable, which must hold the
# key base64-encoded. base64 is only an encoding, not encryption: the built-in
# fallback below is readable by anyone who has this file, so a warning is logged
# whenever it is in effect.
_DEFAULT_ENCODED_SECRET_KEY = "S3VuY2lSYWhhc2lhVXRhbWFfVW50dWtTb2x2ZXJBbWF6b24="
ENCODED_SECRET_KEY = os.environ.get("SECRET_KEY", _DEFAULT_ENCODED_SECRET_KEY)
if ENCODED_SECRET_KEY == _DEFAULT_ENCODED_SECRET_KEY:
    logging.getLogger(__name__).warning(
        "SECRET_KEY is not set (or equals the built-in default); the API key is "
        "public in the source code. Set SECRET_KEY to a base64-encoded secret."
    )
try:
    SECRET_KEY = base64.b64decode(ENCODED_SECRET_KEY).decode('utf-8')
except ValueError as err:
    # binascii.Error (bad base64) and UnicodeDecodeError are both ValueErrors;
    # re-raise with a message that names the misconfigured variable.
    raise ValueError(
        "SECRET_KEY environment variable must be a base64-encoded UTF-8 string."
    ) from err
# Logging setup: INFO level and above, each record prefixed with time and level.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)
# --- Class mapping (taken from data.yaml) ---
# Keys are lower-case target phrases; each value is the model class index, i.e.
# the phrase's position in the tuple below.
CLASS_MAPPING = {
    phrase: class_index
    for class_index, phrase in enumerate(
        (
            "the bags",
            "the beds",
            "the buckets",
            "the chairs",
            "the clocks",
            "the curtains",
            "the hats",
        )
    )
}
# --- ONNX model state ---
# Filled in by load_model() at startup; these initial values mean "not loaded".
ort_session = None  # onnxruntime session once loaded, otherwise None
model_input_height, model_input_width = 0, 0  # spatial size the model expects
input_name = ""  # name of the model's first input tensor
# FastAPI application instance; the startup hook and routes below register on it.
app = FastAPI(
    title="Amazon Captcha Solver - ONNX FastAPI",
)
@app.on_event("startup")
def load_model(model_path: str = "best.onnx", fallback_size: int = 640):
    """Load the ONNX model once, when the FastAPI server starts.

    Creates a CPU-only onnxruntime session for ``model_path`` and records the
    first input's name and spatial size in module globals for the request
    handlers. On any failure the error is logged with its traceback and
    ``ort_session`` stays ``None``, so requests get a 503 instead of the server
    failing to start.

    Args:
        model_path: Path to the ONNX file (a relative path is looked up from the
            current working directory).
        fallback_size: Height/width to use when the model declares a dynamic
            (non-integer) spatial dimension.

    NOTE: ``app.on_event`` is deprecated in recent FastAPI releases in favour of
    lifespan handlers; it is kept so that ``app`` construction stays unchanged.
    """
    global ort_session, model_input_height, model_input_width, input_name
    logger.info("🚀 Memuat model ONNX... Harap tunggu.")
    try:
        providers = ['CPUExecutionProvider']
        session = ort.InferenceSession(model_path, providers=providers)
        first_input = session.get_inputs()[0]
        input_shape = first_input.shape
        # Index 2 is treated as height and index 3 as width (NCHW layout).
        height, width = input_shape[2], input_shape[3]
        # onnxruntime reports dynamic dims as str/None rather than int; using
        # them as sizes would make every request fail during preprocessing.
        if not isinstance(height, int) or not isinstance(width, int):
            logger.warning(
                "Model input has dynamic spatial dims %s; using %dx%d instead.",
                input_shape, fallback_size, fallback_size,
            )
            height = width = fallback_size
        # Publish globals only after every step succeeded, so a failure never
        # leaves a half-initialised model behind.
        input_name = first_input.name
        model_input_height, model_input_width = height, width
        ort_session = session
        logger.info(f"✅ Model ONNX berhasil dimuat. Input shape: {input_shape}")
    except Exception as exc:
        # Deliberately non-fatal: the endpoint answers 503 while ort_session is None.
        logger.exception("❌ Gagal memuat model ONNX: %s", exc)
        ort_session = None
# --- Request body schema (Pydantic) ---
class SolveRequest(BaseModel):
    """JSON body accepted by the solve endpoint."""

    api_key: str  # compared against the server's SECRET_KEY
    image_base64: str  # base64-encoded image bytes
    target_class: str  # phrase looked up (lower-cased, stripped) in CLASS_MAPPING
# --- ONNX helper functions ---
def preprocess_image(img: Image.Image, height: int, width: int) -> "tuple[np.ndarray, float, int, int]":
    """Letterbox an image into a ``(1, 3, height, width)`` float32 input tensor.

    The image is converted to RGB and scaled by a single factor so that it fits
    inside ``height`` x ``width`` with its aspect ratio preserved. It is then
    centred on a canvas filled with the value 114. Pixels are scaled to [0, 1],
    and the layout is changed from HWC to NCHW.

    Args:
        img: Source PIL image (any mode; converted to RGB).
        height: Target tensor height in pixels.
        width: Target tensor width in pixels.

    Returns:
        ``(tensor, ratio, dw, dh)``:
        - ``tensor`` is the C-contiguous float32 input.
        - ``ratio`` is the resize factor applied to the image.
        - ``dw``/``dh`` are the left/top padding in pixels.
        ``postprocess_output`` needs ``ratio``, ``dw`` and ``dh`` to map boxes
        back to original-image pixels.

    Raises:
        ValueError: If the image has zero width or height.
    """
    rgb = np.array(img.convert('RGB'))
    src_h, src_w = rgb.shape[:2]
    if src_h == 0 or src_w == 0:
        raise ValueError(f"Cannot preprocess an empty image ({src_w}x{src_h}).")
    ratio = min(width / src_w, height / src_h)
    # Clamp to >= 1 px so extreme aspect ratios never produce an empty resize target.
    new_w = max(1, int(src_w * ratio))
    new_h = max(1, int(src_h * ratio))
    resized = cv2.resize(rgb, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    canvas = np.full((height, width, 3), 114, dtype=np.uint8)
    dw, dh = (width - new_w) // 2, (height - new_h) // 2
    canvas[dh:dh + new_h, dw:dw + new_w, :] = resized
    tensor = canvas.astype(np.float32) / 255.0
    # HWC -> CHW plus a batch axis; make it contiguous so onnxruntime gets a plain buffer.
    tensor = np.ascontiguousarray(np.transpose(tensor, (2, 0, 1))[np.newaxis, ...])
    return tensor, ratio, dw, dh
def postprocess_output(output: np.ndarray, ratio: float, dw: int, dh: int, conf_threshold: float, nms_threshold: float) -> list:
    """Decode raw detector output into NMS-filtered object centres.

    After ``output[0]`` is transposed, each row is
    ``[cx, cy, w, h, score_0, ..., score_{C-1}]`` in letterboxed input pixels.
    ``output`` is therefore expected as ``(1, 4 + C, N)``; this presumably
    matches an Ultralytics YOLOv8-style export (confirm for other models).

    Rows whose best class score is strictly greater than ``conf_threshold`` are
    mapped back to original-image pixels using ``ratio``/``dw``/``dh`` from
    ``preprocess_image``. They are then filtered with OpenCV NMS. The NMS is
    class-agnostic: overlapping boxes suppress each other regardless of
    predicted class.

    Args:
        output: Raw model output with batch size 1.
        ratio: Resize factor returned by ``preprocess_image``.
        dw: Left padding returned by ``preprocess_image``.
        dh: Top padding returned by ``preprocess_image``.
        conf_threshold: Exclusive lower bound on the best class score.
        nms_threshold: IoU threshold passed to ``cv2.dnn.NMSBoxes``.

    Returns:
        A list of ``{'class_id': int, 'box': [center_x, center_y]}`` dicts,
        one per surviving detection, in original-image pixels. The list is
        empty when nothing passes the threshold.
    """
    predictions = np.transpose(output[0], (1, 0))  # (N, 4 + C)
    class_scores = predictions[:, 4:]
    best_scores = class_scores.max(axis=1)
    keep = best_scores > conf_threshold
    if not np.any(keep):
        return []
    kept = predictions[keep]
    scores = best_scores[keep]
    class_ids = class_scores[keep].argmax(axis=1)
    cx, cy, w, h = kept[:, 0], kept[:, 1], kept[:, 2], kept[:, 3]
    # Undo the letterbox: remove the padding, then rescale. astype(int)
    # truncates toward zero, like the per-element int() it replaces.
    x1 = ((cx - w / 2 - dw) / ratio).astype(int)
    y1 = ((cy - h / 2 - dh) / ratio).astype(int)
    x2 = ((cx + w / 2 - dw) / ratio).astype(int)
    y2 = ((cy + h / 2 - dh) / ratio).astype(int)
    # NMSBoxes takes (x, y, width, height) boxes; tolist() yields plain Python ints.
    boxes = np.stack([x1, y1, x2 - x1, y2 - y1], axis=1).tolist()
    indices = cv2.dnn.NMSBoxes(boxes, scores.astype(np.float32), conf_threshold, nms_threshold)
    # Depending on the OpenCV version, the result is a (K,) array, a (K, 1) array,
    # or an empty tuple; asarray + reshape(-1) normalises all three forms.
    final_boxes = []
    for i in np.asarray(indices, dtype=int).reshape(-1):
        x, y, box_w, box_h = boxes[i]
        final_boxes.append({'class_id': int(class_ids[i]), 'box': [x + box_w / 2, y + box_h / 2]})
    return final_boxes
# --- Main endpoint and its helpers ---
def _decode_image(image_base64: str) -> Image.Image:
    """Decode base64 text (optionally a ``data:...;base64,`` URL) into a loaded PIL image.

    Raises ValueError (including binascii.Error) for malformed base64, OSError
    (including PIL.UnidentifiedImageError) for unreadable image data, and
    Image.DecompressionBombError for oversized images.
    """
    data = image_base64.strip()
    # A data-URL header would otherwise be decoded as (garbage) image bytes.
    # ':' is not a base64 character, so plain base64 never matches this check.
    if data.startswith("data:") and "," in data:
        data = data.split(",", 1)[1]
    image = Image.open(io.BytesIO(base64.b64decode(data)))
    image.load()  # force a full decode now, so corrupt or truncated data fails here
    return image


def _detect(image: Image.Image) -> list:
    """Run the ONNX model on ``image`` and return NMS-filtered detections.

    This is synchronous, CPU-bound work; the endpoint runs it in a worker thread.
    """
    tensor, ratio, dw, dh = preprocess_image(image, model_input_height, model_input_width)
    logger.info("🤖 Menjalankan prediksi model ONNX...")
    start = time.perf_counter()
    outputs = ort_session.run(None, {input_name: tensor})
    elapsed_ms = (time.perf_counter() - start) * 1000
    logger.info(f"⏱️ Waktu prediksi ONNX: {elapsed_ms:.2f} ms")
    # NOTE(review): nms_threshold=0.02 suppresses almost any overlap. This is
    # presumably intentional because objects are not expected to overlap; confirm.
    return postprocess_output(outputs[0], ratio, dw, dh, conf_threshold=0.55, nms_threshold=0.02)


@app.post("/solve_captcha")
@app.post("/")
async def solve_captcha(payload: SolveRequest):
    """Locate objects of ``payload.target_class`` in an image expected to hold 9 objects.

    Returns ``{"coordinates": [[x, y], ...]}``: the centres, in original-image
    pixels, of every detection whose class matches the target.

    Error responses:
        401: ``api_key`` does not match ``SECRET_KEY``.
        503: the model failed to load at startup.
        400: the ``target_class`` is unknown, or the image cannot be decoded,
            or the model did not find exactly 9 objects
            (``error_code`` ``WRONG_OBJECT_COUNT``).
        500: any other failure during inference (logged, not returned).
    """
    import asyncio
    import hmac

    logger.info("Permintaan baru diterima...")
    # Constant-time comparison, so response timing does not reveal how much of
    # the key matched. Bytes are compared because compare_digest rejects non-ASCII str.
    if not hmac.compare_digest(payload.api_key.encode('utf-8'), SECRET_KEY.encode('utf-8')):
        raise HTTPException(status_code=401, detail="API Key tidak valid.")
    if ort_session is None:
        raise HTTPException(status_code=503, detail="Model tidak tersedia atau gagal dimuat.")
    # %r quotes the value and escapes control characters, preventing log-line injection.
    logger.info("🎯 Kelas target: %r", payload.target_class)
    # Validate the cheap input before spending time on decoding and inference.
    target_id = CLASS_MAPPING.get(payload.target_class.lower().strip(), -1)
    if target_id == -1:
        raise HTTPException(status_code=400, detail=f"Kelas '{payload.target_class}' tidak ada di mapping server.")

    # Decoding and inference are CPU-bound. Run inline, they would block the
    # event loop and therefore every other in-flight request.
    loop = asyncio.get_running_loop()
    try:
        image = await loop.run_in_executor(None, _decode_image, payload.image_base64)
    except (ValueError, OSError, Image.DecompressionBombError) as exc:
        logger.warning("⚠️ Gambar tidak valid: %s", exc)
        raise HTTPException(status_code=400, detail="Gambar base64 tidak valid atau tidak dapat dibaca.") from exc
    try:
        results = await loop.run_in_executor(None, _detect, image)
    except Exception as exc:
        logger.exception(f"💥 Terjadi error internal server: {exc}")
        raise HTTPException(status_code=500, detail="Internal Server Error") from exc

    expected_object_count = 9
    if len(results) != expected_object_count:
        error_msg = f"Model mendeteksi {len(results)} objek, bukan {expected_object_count}."
        logger.warning(f"⚠️ {error_msg} Melewatkan...")
        raise HTTPException(status_code=400, detail={"error_code": "WRONG_OBJECT_COUNT", "message": error_msg})
    coordinates = [res['box'] for res in results if res['class_id'] == target_id]
    logger.info("✅ Ditemukan %d koordinat untuk %r", len(coordinates), payload.target_class)
    return {"coordinates": coordinates}
|