amz / app.py
doniramdani820's picture
Upload 6 files
3740adc verified
import base64
import hmac
import io
import logging
import os
import time
import traceback
from typing import List

import cv2
import numpy as np
import onnxruntime as ort
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from PIL import Image  # <-- previously missing import, added here
from pydantic import BaseModel
# --- Logging (configured first so the security setup below can report problems) ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Security configuration ---
# Fallback API key (base64-encoded). It is committed to source, so anyone with
# the code knows it; a real deployment should override it via SECRET_KEY.
_DEFAULT_ENCODED_SECRET_KEY = "S3VuY2lSYWhhc2lhVXRhbWFfVW50dWtTb2x2ZXJBbWF6b24="
# The SECRET_KEY environment variable must hold the base64 encoding of the API key.
ENCODED_SECRET_KEY = os.environ.get("SECRET_KEY", _DEFAULT_ENCODED_SECRET_KEY)
if ENCODED_SECRET_KEY == _DEFAULT_ENCODED_SECRET_KEY:
    logger.warning(
        "SECRET_KEY tidak diatur; menggunakan kunci bawaan dari kode sumber. "
        "Atur SECRET_KEY di lingkungan produksi."
    )
try:
    SECRET_KEY = base64.b64decode(ENCODED_SECRET_KEY).decode('utf-8')
except ValueError as exc:
    # binascii.Error (bad base64) and UnicodeDecodeError (non-UTF-8 bytes) are both ValueErrors.
    raise RuntimeError("SECRET_KEY harus berupa string base64 dari kunci API (UTF-8).") from exc
# --- Class-name -> class-id mapping (taken from data.yaml) ---
# Keys are the lower-case target phrases the endpoint looks up; each phrase's
# position in this tuple is the class id compared against the detector output.
CLASS_MAPPING = {
    label: class_id
    for class_id, label in enumerate((
        "the bags",
        "the beds",
        "the buckets",
        "the chairs",
        "the clocks",
        "the curtains",
        "the hats",
    ))
}
# --- ONNX model state ---
# Populated by the startup hook; ort_session stays None if loading fails.
ort_session = None
model_input_height, model_input_width = 0, 0
input_name = ""
# FastAPI application instance; the startup hook and routes below register on it.
app = FastAPI(
    title="Amazon Captcha Solver - ONNX FastAPI",
)
@app.on_event("startup")
def load_model():
    """Load the ONNX model once when the FastAPI server starts.

    On success this sets the module-level ``ort_session``, ``input_name``,
    ``model_input_height`` and ``model_input_width``. On any failure the error is
    logged with its traceback and ``ort_session`` is left as ``None``, so the
    endpoint answers 503 instead of the server crashing.

    The model file defaults to ``best.onnx`` and can be overridden with the
    ``MODEL_PATH`` environment variable.
    """
    global ort_session, model_input_height, model_input_width, input_name
    logger.info("🚀 Memuat model ONNX... Harap tunggu.")
    model_path = os.environ.get("MODEL_PATH", "best.onnx")
    try:
        session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])
        model_input = session.get_inputs()[0]
        input_shape = model_input.shape
        # Read as (batch, channels, height, width), matching the tensor preprocess_image builds.
        _, _, height, width = input_shape
        # Dynamic axes are reported as str/None; preprocessing needs concrete integer sizes.
        if not (isinstance(height, int) and isinstance(width, int)):
            raise ValueError(f"Model input has non-static spatial dimensions: {input_shape}")
    except Exception as e:
        logger.exception(f"❌ Gagal memuat model ONNX: {e}")
        ort_session = None
        return
    input_name = model_input.name
    model_input_height, model_input_width = height, width
    # Publish the session last so it is never visible with stale input dimensions.
    ort_session = session
    logger.info(f"✅ Model ONNX berhasil dimuat. Input shape: {input_shape}")
# --- Pydantic request model ---
class SolveRequest(BaseModel):
    # Request body accepted by the solve endpoint; all three fields are required strings.
    api_key: str  # compared against SECRET_KEY; a mismatch is rejected with 401
    image_base64: str  # base64-encoded image bytes, decoded and opened with PIL
    target_class: str  # looked up in CLASS_MAPPING after lower-casing and stripping whitespace
# --- ONNX helper functions ---
def preprocess_image(img: "Image.Image", height: int, width: int) -> "tuple[np.ndarray, float, int, int]":
    """Letterbox *img* into a ``(1, 3, height, width)`` float32 tensor.

    The image is converted to RGB and resized with its aspect ratio preserved so
    it fits inside ``width x height``. It is then centred on a canvas filled with
    the value 114. Pixels are scaled to ``[0, 1]`` and the layout is changed from
    HWC to NCHW.

    Returns:
        ``(tensor, ratio, dw, dh)``: the model input; the resize scale factor; and
        the left/top padding in pixels. ``postprocess_output`` needs the last three
        to map boxes back to original-image coordinates.

    Raises:
        ValueError: if the image has zero width or height.
    """
    rgb = np.array(img.convert('RGB'))
    original_height, original_width = rgb.shape[:2]
    if original_height == 0 or original_width == 0:
        raise ValueError("Cannot preprocess an empty image.")
    ratio = min(width / original_width, height / original_height)
    # Clamp to 1 px so extreme aspect ratios never ask cv2.resize for a 0-sized output.
    new_width = max(1, int(original_width * ratio))
    new_height = max(1, int(original_height * ratio))
    img_resized = cv2.resize(rgb, (new_width, new_height), interpolation=cv2.INTER_LINEAR)
    padded_image = np.full((height, width, 3), 114, dtype=np.uint8)
    dw, dh = (width - new_width) // 2, (height - new_height) // 2
    padded_image[dh:dh + new_height, dw:dw + new_width, :] = img_resized
    image_tensor = padded_image.astype(np.float32) / 255.0
    # HWC -> CHW, add the batch axis, and hand onnxruntime a contiguous buffer.
    image_tensor = np.ascontiguousarray(np.transpose(image_tensor, (2, 0, 1))[np.newaxis, ...])
    return image_tensor, ratio, dw, dh
def postprocess_output(output: np.ndarray, ratio: float, dw: int, dh: int, conf_threshold: float, nms_threshold: float) -> list:
    """Decode raw detector output into de-duplicated box centres.

    ``output[0]`` is read as ``(4 + num_classes, num_candidates)``. After
    transposing, each candidate row holds ``cx, cy, w, h`` followed by one score
    per class. Boxes are mapped back through the letterbox (``dw``/``dh`` padding
    and ``ratio`` scale) to original-image pixels.

    Args:
        output: raw model output whose first element is the prediction matrix.
        ratio, dw, dh: values returned by ``preprocess_image``.
        conf_threshold: a candidate is kept only if its best class score is
            strictly greater than this. It is also passed to NMS as the score
            threshold.
        nms_threshold: IoU threshold for ``cv2.dnn.NMSBoxes``. NMS is applied
            across all classes together.

    Returns:
        A list of ``{'class_id': int, 'box': [center_x, center_y]}`` dicts, one
        per box surviving NMS.
    """
    predictions = np.asarray(output[0]).T  # -> (num_candidates, 4 + num_classes)
    class_scores = predictions[:, 4:]
    best_scores = class_scores.max(axis=1)
    keep = best_scores > conf_threshold
    if not np.any(keep):
        return []
    scores = best_scores[keep]
    class_ids = class_scores[keep].argmax(axis=1)
    cx, cy, w, h = predictions[keep, :4].astype(np.float64).T
    # Undo letterbox padding and scaling, truncating to integer pixel coordinates.
    x1 = ((cx - w / 2 - dw) / ratio).astype(int)
    y1 = ((cy - h / 2 - dh) / ratio).astype(int)
    x2 = ((cx + w / 2 - dw) / ratio).astype(int)
    y2 = ((cy + h / 2 - dh) / ratio).astype(int)
    boxes = np.stack([x1, y1, x2 - x1, y2 - y1], axis=1).tolist()
    indices = cv2.dnn.NMSBoxes(boxes, scores.astype(np.float32).tolist(), conf_threshold, nms_threshold)
    final_boxes = []
    # NMSBoxes may return (), an Nx1 array or a 1-D array depending on the OpenCV version.
    for i in np.asarray(indices, dtype=int).reshape(-1):
        x, y, box_w, box_h = boxes[i]
        final_boxes.append({'class_id': int(class_ids[i]), 'box': [x + box_w / 2, y + box_h / 2]})
    return final_boxes
# --- Main endpoint ---
# A result is trusted only when the model detects exactly this many objects.
_EXPECTED_OBJECT_COUNT = 9


def _decode_image(image_base64: str) -> "Image.Image":
    """Decode a base64 string into a fully loaded PIL image.

    Raises HTTPException(400) if the payload is not valid base64 or not a readable image.
    """
    try:
        image = Image.open(io.BytesIO(base64.b64decode(image_base64)))
        # Image.open is lazy; load() forces decoding so corrupt data fails here, not later.
        image.load()
    except (ValueError, OSError) as exc:
        # binascii.Error is a ValueError; PIL's UnidentifiedImageError is an OSError.
        raise HTTPException(status_code=400, detail="Gambar base64 tidak valid.") from exc
    return image


def _detect_objects(image: "Image.Image") -> list:
    """Run the loaded ONNX model on *image* and return the post-processed detections."""
    preprocessed_image, ratio, dw, dh = preprocess_image(image, model_input_height, model_input_width)
    logger.info("🤖 Menjalankan prediksi model ONNX...")
    start_time = time.perf_counter()
    outputs = ort_session.run(None, {input_name: preprocessed_image})
    elapsed_ms = (time.perf_counter() - start_time) * 1000
    logger.info(f"⏱️ Waktu prediksi ONNX: {elapsed_ms:.2f} ms")
    return postprocess_output(outputs[0], ratio, dw, dh, conf_threshold=0.55, nms_threshold=0.02)


@app.post("/solve_captcha")
@app.post("/")
async def solve_captcha(payload: SolveRequest):
    """Return the centre coordinates of every detection of ``payload.target_class``.

    Responses:
        200: ``{"coordinates": [[x, y], ...]}`` in original-image pixels.
        400: unknown target class, undecodable image, or a detection count other
            than ``_EXPECTED_OBJECT_COUNT`` (``error_code`` ``WRONG_OBJECT_COUNT``).
        401: wrong API key.
        500: unexpected failure during preprocessing or inference.
        503: the model failed to load at startup.
    """
    logger.info("Permintaan baru diterima...")
    # Constant-time comparison so response timing does not leak the key.
    if not hmac.compare_digest(payload.api_key.encode("utf-8"), SECRET_KEY.encode("utf-8")):
        raise HTTPException(status_code=401, detail="API Key tidak valid.")
    if ort_session is None:
        raise HTTPException(status_code=503, detail="Model tidak tersedia atau gagal dimuat.")
    logger.info(f"🎯 Kelas target: '{payload.target_class}'")
    # Reject unknown classes before spending CPU time on inference.
    target_id = CLASS_MAPPING.get(payload.target_class.lower().strip(), -1)
    if target_id == -1:
        raise HTTPException(status_code=400, detail=f"Kelas '{payload.target_class}' tidak ada di mapping server.")
    try:
        image = _decode_image(payload.image_base64)
        # Inference is CPU-bound; run it in a worker thread so the event loop stays responsive.
        results = await run_in_threadpool(_detect_objects, image)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"💥 Terjadi error internal server: {e}")
        raise HTTPException(status_code=500, detail="Internal Server Error") from e
    if len(results) != _EXPECTED_OBJECT_COUNT:
        error_msg = f"Model mendeteksi {len(results)} objek, bukan {_EXPECTED_OBJECT_COUNT}."
        logger.warning(f"⚠️ {error_msg} Melewatkan...")
        raise HTTPException(status_code=400, detail={"error_code": "WRONG_OBJECT_COUNT", "message": error_msg})
    coordinates = [res['box'] for res in results if res['class_id'] == target_id]
    logger.info(f"✅ Ditemukan {len(coordinates)} koordinat untuk '{payload.target_class}'")
    return {"coordinates": coordinates}