# weltech-backend / server.py
# Last commit: 87cbff0 "Update server.py" by Punn1403 (verified)
from flask import Flask, jsonify, request
from flask_cors import CORS, cross_origin
from ultralytics import YOLO
import cv2
import os
import base64
import numpy as np
import gc
# Every route accepts cross-origin requests from any origin, with any headers.
_CORS_RESOURCES = {r"/*": {"origins": "*", "allow_headers": "*"}}
app = Flask(__name__)
CORS(app, resources=_CORS_RESOURCES)

# ================== CONFIG ==================

# ---------- Detection ----------
DET_MODEL_PATH = "detect.pt"  # weights loaded into the detection model
IMGSZ = 1536                  # inference size passed to the detector
# NOTE(review): CONF and IOU are not referenced by /process-frame, which passes
# its own hard-coded conf/iou values to the detector.
CONF = 0.35
IOU = 0.60

# ---------- Classification ----------
CLS_MODEL_PATH = "class.pt"   # weights loaded into the WBC sub-type classifier
CLS_INPUT_SIZE = 224          # side length of square crops and classifier imgsz
TARGET_CLASS_ID = 2           # detector class id whose boxes get sub-classified

# ---------- Crop & Style ----------
CROP_MARGIN = 0.25            # extra context around a box, as a fraction of its size
FONT = cv2.FONT_HERSHEY_PLAIN
FONT_SCALE = 0.5
FONT_THICKNESS = 1
TEXT_COLOR = (0, 0, 0)
BG_ALPHA = 0.45
PADDING_X, PADDING_Y = 4, 3   # label padding in pixels (horizontal, vertical)

# Per-class colours as OpenCV BGR tuples.
# NOTE(review): CLASS_COLORS / DEFAULT_COLOR are not referenced in the visible code.
CLASS_COLORS = {
    0: (0, 0, 255),     # red
    1: (0, 255, 0),     # green
    2: (255, 0, 0),     # blue
    3: (0, 255, 255),   # yellow
}
DEFAULT_COLOR = (180, 180, 180)
# ==========================================
print("Loading AI Models...")
# Bind both names up front. If loading fails they stay None (explicitly
# "not loaded") instead of being undefined, which would otherwise surface
# later as a NameError inside request handlers.
det_model = None
cls_model = None
try:
    det_model = YOLO(DET_MODEL_PATH)
    cls_model = YOLO(CLS_MODEL_PATH)
    print("✅ Both Models Loaded Successfully")
except Exception as e:
    # Best-effort: report the failure but keep importing the module so the
    # routes below are still registered and the server can start.
    print(f"❌ Error loading models: {e}")
# ---------- Helper Functions ----------
def crop_with_margin_and_resize(img, box, margin, out_size):
    """Cut ``box`` out of ``img`` with extra context and resize it to a square.

    The box is grown on every side by ``margin`` times its own width/height,
    clipped to the image bounds, then resized to ``out_size`` x ``out_size``
    (aspect ratio is not preserved).

    Args:
        img: Image array indexed as ``img[row, col]`` (height first).
        box: ``(x1, y1, x2, y2)`` corner coordinates; each is truncated with
            ``int``.
        margin: Fraction of the box width/height added on each side.
        out_size: Side length in pixels of the returned crop.

    Returns:
        The resized crop, or ``None`` when the clipped region is empty.
    """
    img_h, img_w = img.shape[:2]
    left, top, right, bottom = (int(v) for v in box)
    pad_x = int((right - left) * margin)
    pad_y = int((bottom - top) * margin)

    # Clamp the grown box to the image before slicing.
    row_lo, row_hi = max(0, top - pad_y), min(img_h, bottom + pad_y)
    col_lo, col_hi = max(0, left - pad_x), min(img_w, right + pad_x)
    region = img[row_lo:row_hi, col_lo:col_hi]

    if region.size == 0:
        return None
    return cv2.resize(region, (out_size, out_size))
def classify_wbc(crop_bgr, cls_model, imgsz=None):
    """Classify one white-blood-cell crop with the classification model.

    Args:
        crop_bgr: Image crop in OpenCV (BGR) channel order.
        cls_model: Ultralytics classification model; calling it returns a
            list of results whose ``probs`` expose ``top1``/``top1conf``.
        imgsz: Inference size passed to the model. Defaults to the
            module-level ``CLS_INPUT_SIZE``.

    Returns:
        ``(name, conf)``: the top-1 class name from ``cls_model.names`` and
        its confidence as a Python float.
    """
    if imgsz is None:
        imgsz = CLS_INPUT_SIZE
    # Ultralytics treats numpy-array sources as BGR (the OpenCV convention)
    # and converts to RGB itself during preprocessing. Pass the crop as-is.
    # Converting to RGB here first would swap red and blue for the model.
    res = cls_model(crop_bgr, imgsz=imgsz, verbose=False)[0]
    cls_id = int(res.probs.top1)
    conf = float(res.probs.top1conf)
    return cls_model.names[cls_id], conf
def draw_label(out, overlay, text, x1, y1, color):
    """Draw a filled label box on ``overlay`` and the label text on ``out``.

    The box is placed just above ``(x1, y1)``. If that would start above the
    top of the image (negative y), it is placed just below ``y1`` instead.

    Args:
        out: Image the text is written onto (modified in place).
        overlay: Image the filled background rectangle is drawn onto
            (modified in place). NOTE(review): presumably the caller blends
            ``overlay`` into ``out`` (e.g. with ``BG_ALPHA``); this function
            does not.
        text: Label string to render.
        x1, y1: Top-left corner of the box being labelled.
        color: Fill colour of the label background (OpenCV BGR tuple).
    """
    # Measure the same string that is drawn below.
    (tw, th), _ = cv2.getTextSize(text, FONT, FONT_SCALE, int(FONT_THICKNESS))
    box_h = th + PADDING_Y * 2
    tx1 = x1
    tx2 = x1 + tw + PADDING_X * 2
    ty1, ty2 = y1 - box_h, y1
    if ty1 < 0:
        # Not enough room above the box: flip the label below y1.
        ty1, ty2 = y1, y1 + box_h
    cv2.rectangle(overlay, (tx1, ty1), (tx2, ty2), color, -1)
    cv2.putText(out, text, (tx1 + PADDING_X, ty2 - PADDING_Y),
                FONT, FONT_SCALE, TEXT_COLOR, FONT_THICKNESS, cv2.LINE_AA)
# ================== API ENDPOINTS ==================
@app.route('/', methods=['GET'])
def index():
    """Report that the server process is up (simple status endpoint).

    Always answers HTTP 200 with a fixed JSON payload. It does not check
    whether the AI models loaded successfully.
    """
    payload = {
        "status": "online",
        "message": "WelTech AI Server (Dual Model) is running",
    }
    return jsonify(payload), 200
# Ordered (substring, canonical name) pairs used to normalise classifier
# labels. Order matters: the first matching substring wins.
_WBC_NAME_KEYS = (
    ('neut', 'Neutrophil'),
    ('lymph', 'Lymphocyte'),
    ('mono', 'Monocyte'),
    ('eo', 'Eosinophil'),
    ('baso', 'Basophil'),
)


def _standardize_wbc_name(raw_label):
    """Map a raw classifier label to a canonical WBC sub-type name.

    Matching is a case-insensitive substring test against ``_WBC_NAME_KEYS``.
    Labels that match nothing are returned ``capitalize()``-d.
    """
    lowered = raw_label.lower()
    for key, canonical in _WBC_NAME_KEYS:
        if key in lowered:
            return canonical
    return raw_label.capitalize()


def _decode_image_payload(image_b64):
    """Decode a base64 image string (optionally a data URL) into an image.

    Returns the array produced by ``cv2.imdecode`` (BGR), or ``None`` when the
    text is not valid base64, decodes to zero bytes, or is not an image.
    """
    # Drop a "data:image/...;base64," style prefix when present.
    encoded = image_b64.split(',')[1] if ',' in image_b64 else image_b64
    try:
        raw = base64.b64decode(encoded)
    except ValueError:  # binascii.Error and non-ASCII input both subclass ValueError
        return None
    if not raw:
        # Guard explicitly: OpenCV may raise on an empty buffer instead of
        # returning None.
        return None
    return cv2.imdecode(np.frombuffer(raw, np.uint8), cv2.IMREAD_COLOR)


def _box_color(final_conf, cell_type):
    """Pick a BGR box colour: by confidence band first, then by cell type."""
    if final_conf < 0.40:
        return (0, 0, 255)        # red: low confidence
    if final_conf < 0.80:
        return (0, 165, 255)      # orange: medium confidence
    if "rbc" in cell_type:
        return (19, 69, 139)      # brown
    if "wbc" in cell_type:
        return (255, 0, 0)        # blue
    if "platelet" in cell_type:
        return (128, 128, 128)    # gray
    return (180, 180, 180)        # fallback


@app.route('/process-frame', methods=['POST', 'OPTIONS'])
@cross_origin()
def process_frame():
    """Detect cells in a base64 frame and return counts plus an annotated image.

    Request JSON: ``{"image": "<base64 string or data URL>"}``.

    Success response JSON: ``status``, per-class detection ``counts``,
    ``wbc_details`` (per-sub-type counts from the classifier), ``total``
    (number of detections), ``total_wbc`` (sum of ``wbc_details``) and
    ``processed_image`` (annotated JPEG as a data URL).

    Errors: 400 for missing or undecodable image data, 503 when a model is
    not loaded, 500 for any other failure.
    """
    if request.method == 'OPTIONS':
        return jsonify({"status": "ok"}), 200
    try:
        # silent=True: a missing or invalid JSON body yields None and is
        # reported as a 400 below, not as a 500 from an exception.
        data = request.get_json(silent=True)
        image_b64 = data.get('image') if isinstance(data, dict) else None
        if not image_b64:
            return jsonify({"status": "error", "message": "No image data"}), 400

        frame = _decode_image_payload(image_b64) if isinstance(image_b64, str) else None
        if frame is None:
            return jsonify({"status": "error", "message": "Image Decode Failed"}), 400

        if det_model is None or cls_model is None:
            return jsonify({"status": "error", "message": "Models not loaded"}), 503

        # 1. Detection phase.
        # NOTE: conf/iou are hard-coded here and differ from the module-level
        # CONF / IOU constants, which this endpoint does not use.
        results = det_model(
            frame,
            imgsz=IMGSZ,
            conf=0.25,
            iou=0.20,
            max_det=1500,
            agnostic_nms=True,
            verbose=False,
        )[0]

        counts = {}
        wbc_subcounts = {}
        vis = frame.copy()
        det_boxes = results.boxes
        total = len(det_boxes) if det_boxes is not None else 0

        if total > 0:
            boxes = det_boxes.xyxy.cpu().numpy()
            clss = det_boxes.cls.cpu().numpy()
            confs = det_boxes.conf.cpu().numpy()
            det_names = det_model.names
            wbc_draw_list = []
            other_draw_list = []

            for box, cls, conf in zip(boxes, clss, confs):
                x1, y1, x2, y2 = map(int, box)
                cls_id = int(cls)
                base_name = det_names[cls_id]
                counts[base_name] = counts.get(base_name, 0) + 1
                is_wbc = cls_id == TARGET_CLASS_ID

                # Build the label and choose the confidence that drives the colour.
                if is_wbc:
                    cell_type = "wbc"
                    crop = crop_with_margin_and_resize(frame, box, CROP_MARGIN, CLS_INPUT_SIZE)
                    if crop is not None:
                        wbc_name, wbc_conf = classify_wbc(crop, cls_model)
                        std_name = _standardize_wbc_name(wbc_name)
                        wbc_subcounts[std_name] = wbc_subcounts.get(std_name, 0) + 1
                        display_label = f"{std_name} {wbc_conf:.2f}"
                        # The sub-type classifier's confidence decides the colour.
                        final_conf = wbc_conf
                    else:
                        display_label = f"WBC {conf:.2f}"
                        final_conf = conf
                else:
                    display_label = f"{base_name} {conf:.2f}"
                    final_conf = conf
                    cell_type = base_name.lower()

                entry = (x1, y1, x2, y2, display_label, _box_color(final_conf, cell_type))
                # WBCs are kept separately so they are drawn last (on top).
                if is_wbc:
                    wbc_draw_list.append(entry)
                else:
                    other_draw_list.append(entry)

            # Draw boxes and labels (plain text, no label background).
            for x1, y1, x2, y2, label, color in other_draw_list + wbc_draw_list:
                cv2.rectangle(vis, (x1, y1), (x2, y2), color, 1)
                cv2.putText(vis, label, (x1, y1 - 4), cv2.FONT_HERSHEY_SIMPLEX,
                            0.7, color, 2, cv2.LINE_AA)

        # Encode the annotated frame back to a base64 JPEG data URL.
        ok, buffer = cv2.imencode('.jpg', vis, [int(cv2.IMWRITE_JPEG_QUALITY), 80])
        if not ok:
            return jsonify({"status": "error", "message": "Image Encode Failed"}), 500
        processed_b64 = base64.b64encode(buffer).decode('utf-8')

        return jsonify({
            "status": "success",
            "counts": counts,
            "wbc_details": wbc_subcounts,
            "total": total,
            "total_wbc": sum(wbc_subcounts.values()),
            "processed_image": f"data:image/jpeg;base64,{processed_b64}",
        })
    except Exception as e:
        # Request boundary: log the full traceback and report a 500.
        import traceback
        traceback.print_exc()
        return jsonify({"status": "error", "message": str(e)}), 500
    finally:
        # Force a garbage collection after every request to release
        # per-request image buffers.
        gc.collect()
if __name__ == '__main__':
    # Listen on all interfaces; the port comes from $PORT, defaulting to 7860.
    listen_port = int(os.getenv('PORT', 7860))
    app.run(host='0.0.0.0', port=listen_port)