"""WelTech AI Server: blood-cell detection (YOLO) plus WBC sub-type classification.

POST /process-frame with JSON ``{"image": "<base64 or data URL>"}`` to get
per-class counts, WBC sub-type counts and an annotated JPEG (as a data URL).
"""
from flask import Flask, jsonify, request
from flask_cors import CORS, cross_origin
from ultralytics import YOLO
import cv2
import os
import base64
import binascii
import traceback
import numpy as np
import gc

app = Flask(__name__)
CORS(app, resources={r"/*": {"origins": "*", "allow_headers": "*"}})

# ================== CONFIG ==================
# ---------- Detection ----------
DET_MODEL_PATH = "detect.pt"
IMGSZ = 1536
# NOTE(review): CONF and IOU are not referenced by /process-frame, which runs
# detection with DET_CONF / DET_IOU below. Confirm which pair is intended.
CONF = 0.35
IOU = 0.60
DET_CONF = 0.25      # confidence threshold actually passed to det_model
DET_IOU = 0.20       # NMS IoU threshold actually passed to det_model
DET_MAX_DET = 1500   # max detections per frame passed to det_model

# ---------- Classification ----------
CLS_MODEL_PATH = "class.pt"
CLS_INPUT_SIZE = 224
TARGET_CLASS_ID = 2  # detector class id whose boxes are sent to the classifier

# ---------- Crop & Style ----------
CROP_MARGIN = 0.25
FONT = cv2.FONT_HERSHEY_PLAIN
FONT_SCALE = 0.5
FONT_THICKNESS = 1
TEXT_COLOR = (0, 0, 0)
BG_ALPHA = 0.45
PADDING_X = 4
PADDING_Y = 3
CLASS_COLORS = {
    0: (0, 0, 255),
    1: (0, 255, 0),
    2: (255, 0, 0),
    3: (0, 255, 255),
}
DEFAULT_COLOR = (180, 180, 180)

# Box colors by confidence band (BGR, as OpenCV expects).
LOW_CONF_THRESHOLD = 0.40
MID_CONF_THRESHOLD = 0.80
LOW_CONF_COLOR = (0, 0, 255)     # red
MID_CONF_COLOR = (0, 165, 255)   # orange
# High-confidence color chosen by substring of the cell type; first match wins.
HIGH_CONF_COLORS = (
    ("rbc", (19, 69, 139)),         # brown
    ("wbc", (255, 0, 0)),           # blue
    ("platelet", (128, 128, 128)),  # gray
)

# Substring of the classifier's raw (lower-cased) label -> canonical WBC name.
# Order matters: the first matching substring wins, so "neut" is tested
# before the short and permissive "eo".
WBC_NAME_KEYS = (
    ("neut", "Neutrophil"),
    ("lymph", "Lymphocyte"),
    ("mono", "Monocyte"),
    ("eo", "Eosinophil"),
    ("baso", "Basophil"),
)
# ==========================================

print("Loading AI Models...")
# Bind both names up front so a failed load is detectable in the endpoint
# (503) rather than surfacing later as a NameError inside a request.
det_model = None
cls_model = None
try:
    det_model = YOLO(DET_MODEL_PATH)
    cls_model = YOLO(CLS_MODEL_PATH)
    print("✅ Both Models Loaded Successfully")
except Exception as e:
    print(f"❌ Error loading models: {e}")


# ---------- Helper Functions ----------
def crop_with_margin_and_resize(img, box, margin, out_size):
    """Crop ``box`` from ``img`` grown by ``margin`` on every side, then resize.

    Args:
        img: Image array; its first two dimensions are taken as (height, width).
        box: ``(x1, y1, x2, y2)`` pixel coordinates; values are truncated to int.
        margin: Fraction of the box width/height added on each side.
        out_size: Side length of the square output image.

    Returns:
        The ``out_size`` x ``out_size`` resized crop, or None if the crop,
        after clamping to the image bounds, is empty.
    """
    h, w = img.shape[:2]
    x1, y1, x2, y2 = map(int, box)
    mx = int((x2 - x1) * margin)
    my = int((y2 - y1) * margin)

    # Expand the box by the margin, clamped to the image bounds.
    nx1 = max(0, x1 - mx)
    ny1 = max(0, y1 - my)
    nx2 = min(w, x2 + mx)
    ny2 = min(h, y2 + my)

    crop = img[ny1:ny2, nx1:nx2]
    if crop.size == 0:
        return None
    return cv2.resize(crop, (out_size, out_size))


def classify_wbc(crop_bgr, cls_model):
    """Classify a BGR crop and return ``(class_name, top1_confidence)``.

    The crop is converted BGR -> RGB before inference at CLS_INPUT_SIZE.
    """
    crop_rgb = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB)
    res = cls_model(crop_rgb, imgsz=CLS_INPUT_SIZE, verbose=False)[0]
    cls_id = int(res.probs.top1)
    conf = float(res.probs.top1conf)
    name = cls_model.names[cls_id]
    return name, conf


def draw_label(out, overlay, text, x1, y1, color):
    """Draw ``text`` with a filled background box anchored at ``(x1, y1)``.

    The filled rectangle is drawn on ``overlay`` and the text on ``out``.
    NOTE(review): presumably the caller alpha-blends ``overlay`` into ``out``
    (e.g. with BG_ALPHA); this function does not blend.

    The box is placed above ``y1``; if that would run past the top edge of
    the image it is placed below ``y1`` instead.
    """
    (tw, th), _ = cv2.getTextSize(text, FONT, FONT_SCALE, int(FONT_THICKNESS))
    tx1 = x1
    ty1 = y1 - th - PADDING_Y * 2
    tx2 = x1 + tw + PADDING_X * 2
    ty2 = y1
    if ty1 < 0:
        ty1 = y1
        ty2 = y1 + th + PADDING_Y * 2
    cv2.rectangle(overlay, (tx1, ty1), (tx2, ty2), color, -1)
    cv2.putText(out, text, (tx1 + PADDING_X, ty2 - PADDING_Y),
                FONT, FONT_SCALE, TEXT_COLOR, FONT_THICKNESS, cv2.LINE_AA)


def _standardize_wbc_name(raw_name):
    """Map a classifier label to a canonical WBC name using WBC_NAME_KEYS.

    Falls back to ``raw_name.capitalize()`` when no known substring matches.
    """
    lowered = raw_name.lower()
    for key, std_name in WBC_NAME_KEYS:
        if key in lowered:
            return std_name
    return raw_name.capitalize()


def _confidence_color(final_conf, cell_type):
    """Return the BGR box color for a detection.

    The color depends on the confidence band. Above MID_CONF_THRESHOLD it
    also depends on the cell type.
    """
    if final_conf < LOW_CONF_THRESHOLD:
        return LOW_CONF_COLOR
    if final_conf < MID_CONF_THRESHOLD:
        return MID_CONF_COLOR
    for key, color in HIGH_CONF_COLORS:
        if key in cell_type:
            return color
    return DEFAULT_COLOR


def _decode_image(image_b64):
    """Decode base64 text (optionally a ``data:...;base64,`` URL) into an image.

    Returns None if the payload is not valid base64, is empty, or is not an
    image OpenCV can decode.
    """
    # Strip a data-URL header such as "data:image/jpeg;base64,".
    encoded = image_b64.split(',')[1] if ',' in image_b64 else image_b64
    try:
        raw = base64.b64decode(encoded)
    except (binascii.Error, ValueError):
        # Bad padding, or non-ASCII characters in the str input.
        return None
    if not raw:
        # cv2.imdecode asserts on an empty buffer instead of returning None.
        return None
    return cv2.imdecode(np.frombuffer(raw, np.uint8), cv2.IMREAD_COLOR)


def _encode_jpeg_data_url(img, quality=80):
    """Encode ``img`` as JPEG and return it as a base64 ``data:`` URL.

    Raises:
        RuntimeError: if OpenCV reports that encoding failed.
    """
    ok, buffer = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality])
    if not ok:
        raise RuntimeError("Image Encode Failed")
    processed_b64 = base64.b64encode(buffer).decode('utf-8')
    return f"data:image/jpeg;base64,{processed_b64}"


def _describe_detection(frame, box, cls_id, base_name, conf):
    """Build the label text and color inputs for a single detection.

    Boxes of TARGET_CLASS_ID are cropped and run through the classifier. When
    that succeeds, the classifier's confidence (not the detector's) drives the
    box color.

    Returns:
        ``(display_label, final_conf, cell_type, wbc_name)``. ``wbc_name`` is
        the canonical sub-type when classification ran, else None.
    """
    if cls_id != TARGET_CLASS_ID:
        return f"{base_name} {conf:.2f}", conf, base_name.lower(), None

    crop = crop_with_margin_and_resize(frame, box, CROP_MARGIN, CLS_INPUT_SIZE)
    if crop is None:
        return f"WBC {conf:.2f}", conf, "wbc", None

    wbc_name, wbc_conf = classify_wbc(crop, cls_model)
    std_name = _standardize_wbc_name(wbc_name)
    return f"{std_name} {wbc_conf:.2f}", wbc_conf, "wbc", std_name


# ================== API ENDPOINTS ==================
@app.route('/', methods=['GET'])
def index():
    """Health check."""
    return jsonify({"status": "online", "message": "WelTech AI Server (Dual Model) is running"}), 200


@app.route('/process-frame', methods=['POST', 'OPTIONS'])
@cross_origin()
def process_frame():
    """Detect blood cells in a frame and classify WBC boxes.

    Request JSON: ``{"image": "<base64 or data URL>"}``.

    Responses:
        200 with ``counts`` (per detector class), ``wbc_details`` (per WBC
        sub-type), ``total``, ``total_wbc`` and ``processed_image`` (JPEG
        data URL).
        400 for a missing/invalid image.
        503 if the models failed to load.
        500 on unexpected errors.
    """
    if request.method == 'OPTIONS':
        return jsonify({"status": "ok"}), 200

    try:
        if det_model is None or cls_model is None:
            return jsonify({"status": "error", "message": "Models not loaded"}), 503

        # silent=True: a missing or malformed JSON body yields None instead of raising.
        data = request.get_json(silent=True)
        image_b64 = data.get('image') if isinstance(data, dict) else None
        if not isinstance(image_b64, str) or not image_b64:
            return jsonify({"status": "error", "message": "No image data"}), 400

        frame = _decode_image(image_b64)
        if frame is None:
            return jsonify({"status": "error", "message": "Image Decode Failed"}), 400

        # 1. Detection phase
        results = det_model(
            frame,
            imgsz=IMGSZ,
            conf=DET_CONF,
            iou=DET_IOU,
            max_det=DET_MAX_DET,
            agnostic_nms=True,
            verbose=False,
        )[0]

        counts = {}
        wbc_subcounts = {}
        # Keep separate lists so WBC boxes are drawn last (on top, not overdrawn).
        wbc_draw_list = []
        other_draw_list = []

        if results.boxes is not None and len(results.boxes) > 0:
            boxes = results.boxes.xyxy.cpu().numpy()
            clss = results.boxes.cls.cpu().numpy()
            confs = results.boxes.conf.cpu().numpy()
            det_names = det_model.names

            for box, cls, conf in zip(boxes, clss, confs):
                x1, y1, x2, y2 = map(int, box)
                cls_id = int(cls)
                base_name = det_names[cls_id]
                counts[base_name] = counts.get(base_name, 0) + 1

                # 2. Label text and the confidence that decides the color
                display_label, final_conf, cell_type, wbc_name = _describe_detection(
                    frame, box, cls_id, base_name, conf
                )
                if wbc_name is not None:
                    wbc_subcounts[wbc_name] = wbc_subcounts.get(wbc_name, 0) + 1

                current_color = _confidence_color(final_conf, cell_type)
                target = wbc_draw_list if cls_id == TARGET_CLASS_ID else other_draw_list
                target.append((x1, y1, x2, y2, display_label, current_color))

        # 3. Draw onto the image (clean UI: no background behind label text)
        vis = frame.copy()
        for (x1, y1, x2, y2, label, color) in other_draw_list + wbc_draw_list:
            cv2.rectangle(vis, (x1, y1), (x2, y2), color, 1)
            cv2.putText(vis, label, (x1, y1 - 4), cv2.FONT_HERSHEY_SIMPLEX,
                        0.7, color, 2, cv2.LINE_AA)

        # Encode back to a base64 JPEG data URL
        final_processed_image = _encode_jpeg_data_url(vis, quality=80)

        return jsonify({
            "status": "success",
            "counts": counts,
            "wbc_details": wbc_subcounts,
            "total": len(results.boxes) if results.boxes is not None else 0,
            "total_wbc": sum(wbc_subcounts.values()),
            "processed_image": final_processed_image,
        })

    except Exception as e:
        # Top-level request boundary: log the traceback and report a 500.
        traceback.print_exc()
        return jsonify({"status": "error", "message": str(e)}), 500
    finally:
        gc.collect()


if __name__ == '__main__':
    port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=port)