# File size: 6,228 Bytes
# 3740adc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import base64
import hmac
import io
import logging
import os
import time
import traceback
from typing import List

import cv2
import numpy as np
import onnxruntime as ort
from fastapi import FastAPI, HTTPException
from PIL import Image  # used by preprocess_image and solve_captcha
from pydantic import BaseModel

# --- Logging setup ---
# Configured first so the security section below can report its fallback.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Security configuration ---
# The API key is read base64-encoded from the SECRET_KEY environment variable.
# The built-in fallback is kept so deployments that never set the variable keep
# working, but it is readable by anyone with access to this file, so falling
# back to it is logged as a warning.
# NOTE(review): remove the fallback once every deployment sets SECRET_KEY.
_DEFAULT_ENCODED_SECRET_KEY = "S3VuY2lSYWhhc2lhVXRhbWFfVW50dWtTb2x2ZXJBbWF6b24="
ENCODED_SECRET_KEY = os.environ.get("SECRET_KEY")
if ENCODED_SECRET_KEY is None:
    logger.warning("⚠️  Variabel lingkungan SECRET_KEY tidak diatur; memakai kunci bawaan dari kode sumber.")
    ENCODED_SECRET_KEY = _DEFAULT_ENCODED_SECRET_KEY
SECRET_KEY = base64.b64decode(ENCODED_SECRET_KEY).decode('utf-8')

# --- Class mapping from data.yaml ---
# Maps the class name sent by clients to the model's class index.
CLASS_MAPPING = {
    "the bags": 0,
    "the beds": 1,
    "the buckets": 2,
    "the chairs": 3,
    "the clocks": 4,
    "the curtains": 5,
    "the hats": 6,
}

# --- ONNX model state ---
# Placeholders that the startup hook (load_model) overwrites once the model
# has been read; ort_session stays None while no model is available.
ort_session = None
input_name = ""
model_input_height = model_input_width = 0

# FastAPI application instance
app = FastAPI(title="Amazon Captcha Solver - ONNX FastAPI")

@app.on_event("startup")
def load_model():
    """Load the ONNX model once, when the FastAPI server starts.

    On success the module-level ``ort_session``, ``input_name``,
    ``model_input_height`` and ``model_input_width`` are set from the model's
    first input. On any failure the error is logged together with its
    traceback and ``ort_session`` is left as ``None`` instead of raising, so
    the application still starts and the endpoint can answer with 503.
    """
    global ort_session, model_input_height, model_input_width, input_name
    logger.info("🚀 Memuat model ONNX... Harap tunggu.")
    try:
        providers = ['CPUExecutionProvider']
        session = ort.InferenceSession("best.onnx", providers=providers)
        first_input = session.get_inputs()[0]
        input_shape = first_input.shape
        # The input is indexed as (batch, channels, height, width).
        height, width = input_shape[2], input_shape[3]
        # Models exported with dynamic axes report non-integer dimensions here;
        # preprocessing needs concrete pixel sizes, so reject those at startup
        # rather than failing on every request.
        if not (isinstance(height, int) and isinstance(width, int) and height > 0 and width > 0):
            raise ValueError(f"Dimensi input model tidak statis: {input_shape}")
        # Publish the globals only after everything has been validated.
        ort_session = session
        input_name = first_input.name
        model_input_height, model_input_width = height, width
        logger.info(f"✅ Model ONNX berhasil dimuat. Input shape: {input_shape}")
    except Exception as e:
        # Best effort by design: keep the server up and record the full traceback.
        logger.exception(f"❌ Gagal memuat model ONNX: {e}")
        ort_session = None

# --- Pydantic request model ---
# Request body accepted by the solve endpoints.
class SolveRequest(BaseModel):
    # Compared against SECRET_KEY by the endpoint; a mismatch is rejected with HTTP 401.
    api_key: str
    # Base64-encoded image bytes; the endpoint decodes them and opens the result with PIL.
    image_base64: str
    # Class name looked up in CLASS_MAPPING after lower-casing and stripping whitespace.
    target_class: str

# --- Helper functions for ONNX pre-/post-processing ---
def preprocess_image(img: Image.Image, height: int, width: int) -> "tuple[np.ndarray, float, int, int]":
    """Letterbox *img* to ``height`` x ``width`` and turn it into a model input tensor.

    The image is converted to RGB, resized with its aspect ratio preserved so
    that it fits inside the target size, and centred on a canvas filled with
    the value 114.

    Args:
        img: Source PIL image.
        height: Target tensor height in pixels.
        width: Target tensor width in pixels.

    Returns:
        A tuple ``(image_tensor, ratio, dw, dh)`` where ``image_tensor`` is a
        float32 array of shape ``(1, 3, height, width)`` scaled to ``[0, 1]``,
        ``ratio`` is the resize factor that was applied, and ``dw`` / ``dh``
        are the left / top padding offsets in pixels.
    """
    img_np = np.array(img.convert('RGB'))
    original_height, original_width = img_np.shape[:2]
    ratio = min(width / original_width, height / original_height)
    # Truncation can produce 0 for extremely elongated images, which
    # cv2.resize rejects; keep at least one pixel per side.
    new_width = max(1, int(original_width * ratio))
    new_height = max(1, int(original_height * ratio))
    img_resized = cv2.resize(img_np, (new_width, new_height), interpolation=cv2.INTER_LINEAR)
    padded_image = np.full((height, width, 3), 114, dtype=np.uint8)
    # Split the leftover space evenly so the resized image sits in the centre.
    dw, dh = (width - new_width) // 2, (height - new_height) // 2
    padded_image[dh:dh + new_height, dw:dw + new_width, :] = img_resized
    image_tensor = padded_image.astype(np.float32) / 255.0
    image_tensor = np.transpose(image_tensor, (2, 0, 1))  # HWC -> CHW
    image_tensor = np.expand_dims(image_tensor, axis=0)  # add batch axis
    return image_tensor, ratio, dw, dh

def postprocess_output(output: np.ndarray, ratio: float, dw: int, dh: int, conf_threshold: float, nms_threshold: float) -> list:
    """Turn the raw model output into a list of detected object centres.

    ``output[0]`` is transposed so that every row is one candidate: the first
    four values are read as ``cx, cy, w, h`` and the remaining values as
    per-class scores. Candidates whose best class score exceeds
    ``conf_threshold`` are mapped back to the original image by removing the
    padding (``dw``, ``dh``) and dividing by ``ratio``, then filtered with
    ``cv2.dnn.NMSBoxes`` using ``nms_threshold``.

    Returns:
        A list of ``{'class_id': ..., 'box': [center_x, center_y]}`` dicts,
        one per detection kept by NMS; an empty list when nothing qualifies.
    """
    candidates = np.transpose(output[0], (1, 0))

    rects = []        # [x, y, w, h] in original-image pixels, the layout NMSBoxes takes
    confidences = []
    labels = []
    for candidate in candidates:
        class_scores = candidate[4:]
        best_score = np.max(class_scores)
        if not (best_score > conf_threshold):
            continue
        best_class = np.argmax(class_scores)
        cx, cy, w, h = candidate[0:4]
        # Undo the letterbox: remove the padding offset, then the resize factor.
        left = int((cx - w / 2 - dw) / ratio)
        top = int((cy - h / 2 - dh) / ratio)
        right = int((cx + w / 2 - dw) / ratio)
        bottom = int((cy + h / 2 - dh) / ratio)
        rects.append([left, top, right - left, bottom - top])
        confidences.append(best_score)
        labels.append(best_class)

    if not rects:
        return []

    kept = cv2.dnn.NMSBoxes(rects, np.array(confidences), conf_threshold, nms_threshold)
    if len(kept) == 0:
        return []

    detections = []
    for idx in kept.flatten():
        x, y, w, h = rects[idx]
        detections.append({'class_id': labels[idx], 'box': [x + w / 2, y + h / 2]})
    return detections

# --- MAIN ENDPOINT ---
@app.post("/solve_captcha")
@app.post("/")
async def solve_captcha(payload: SolveRequest):
    """Locate every object of the requested class in a captcha image.

    The request is authenticated, the target class is resolved through
    ``CLASS_MAPPING``, the image is decoded and run through the ONNX model,
    and the centre points of the detections matching the target class are
    returned.

    Args:
        payload: API key, base64-encoded image and target class name.

    Returns:
        ``{"coordinates": [[center_x, center_y], ...]}`` for the detections
        whose class matches ``payload.target_class``.

    Raises:
        HTTPException: 401 for a wrong API key; 503 when no model is loaded;
            400 for an unknown class, an undecodable image, or a detection
            count different from the expected one; 500 for any other failure.
    """
    logger.info("Permintaan baru diterima...")

    # Constant-time comparison so response timing does not leak key prefixes.
    # Encoding to bytes keeps compare_digest usable for non-ASCII input.
    supplied_key = payload.api_key.encode("utf-8", "surrogatepass")
    expected_key = SECRET_KEY.encode("utf-8", "surrogatepass")
    if not hmac.compare_digest(supplied_key, expected_key):
        raise HTTPException(status_code=401, detail="API Key tidak valid.")

    if ort_session is None:
        raise HTTPException(status_code=503, detail="Model tidak tersedia atau gagal dimuat.")

    logger.info(f"🎯 Kelas target: '{payload.target_class}'")

    # Reject unknown classes before spending time on decoding and inference.
    target_id = CLASS_MAPPING.get(payload.target_class.lower().strip())
    if target_id is None:
        raise HTTPException(status_code=400, detail=f"Kelas '{payload.target_class}' tidak ada di mapping server.")

    # Malformed base64 or unreadable image data is a client error, not a server fault.
    try:
        image_bytes = base64.b64decode(payload.image_base64)
        image = Image.open(io.BytesIO(image_bytes))
    except (ValueError, OSError) as exc:
        logger.warning(f"⚠️  Gambar tidak dapat dibaca: {exc}")
        raise HTTPException(status_code=400, detail="Gambar tidak valid.") from exc

    try:
        preprocessed_image, ratio, dw, dh = preprocess_image(image, model_input_height, model_input_width)

        # NOTE(review): inference runs synchronously inside this coroutine and
        # therefore occupies the event loop while it executes.
        logger.info("🤖 Menjalankan prediksi model ONNX...")
        start_time = time.time()
        outputs = ort_session.run(None, {input_name: preprocessed_image})
        end_time = time.time()
        logger.info(f"⏱️ Waktu prediksi ONNX: {(end_time - start_time) * 1000:.2f} ms")

        results = postprocess_output(outputs[0], ratio, dw, dh, conf_threshold=0.55, nms_threshold=0.02)

        # The handler only trusts a prediction when exactly this many objects
        # (of any class) were detected — presumably the captcha grid size; confirm.
        expected_object_count = 9
        if len(results) != expected_object_count:
            error_msg = f"Model mendeteksi {len(results)} objek, bukan {expected_object_count}."
            logger.warning(f"⚠️  {error_msg} Melewatkan...")
            raise HTTPException(status_code=400, detail={"error_code": "WRONG_OBJECT_COUNT", "message": error_msg})

        coordinates = [res['box'] for res in results if res['class_id'] == target_id]

        logger.info(f"✅ Ditemukan {len(coordinates)} koordinat untuk '{payload.target_class}'")
        return {"coordinates": coordinates}

    except HTTPException:
        # Deliberate HTTP errors pass through untouched.
        raise
    except Exception as e:
        # logger.exception records the traceback through the logging setup.
        logger.exception(f"💥 Terjadi error internal server: {e}")
        raise HTTPException(status_code=500, detail="Internal Server Error") from e