Spaces:
Sleeping
Sleeping
| from flask import Flask, jsonify, request | |
| from flask_cors import CORS, cross_origin | |
| from ultralytics import YOLO | |
| import cv2 | |
| import os | |
| import base64 | |
| import numpy as np | |
| import gc | |
# Flask application object; CORS is opened to every origin and header on all paths.
app = Flask(__name__)
CORS(app, resources={"/*": dict(origins="*", allow_headers="*")})
# ================== CONFIG ==================
# ---------- Detection ----------
# Weight paths may be overridden from the environment (e.g. to swap model
# versions on a deployment without a code change); the defaults are unchanged.
DET_MODEL_PATH = os.environ.get("DET_MODEL_PATH", "detect.pt")
IMGSZ = 1536  # inference image size passed to the detector
# NOTE(review): CONF and IOU are not referenced in the visible code -- the
# detection call passes its own conf/iou literals. Confirm which is intended.
CONF = 0.35
IOU = 0.60
# ---------- Classification ----------
CLS_MODEL_PATH = os.environ.get("CLS_MODEL_PATH", "class.pt")
CLS_INPUT_SIZE = 224  # side (px) of the square crops fed to the WBC classifier
TARGET_CLASS_ID = 2  # detector class id whose boxes are sent to the WBC classifier
# ---------- Crop & Style ----------
CROP_MARGIN = 0.25  # extra context around each WBC box, as a fraction of its width/height
FONT = cv2.FONT_HERSHEY_PLAIN
FONT_SCALE = 0.5
FONT_THICKNESS = 1
TEXT_COLOR = (0, 0, 0)  # BGR (OpenCV channel order)
BG_ALPHA = 0.45  # presumably the label-background blend opacity -- not used in the visible code
PADDING_X = 4  # horizontal padding (px) inside a label box
PADDING_Y = 3  # vertical padding (px) inside a label box
# Per-detector-class BGR colours.
# NOTE(review): not referenced in the visible code; boxes are coloured by
# confidence level instead.
CLASS_COLORS = {
    0: (0, 0, 255),
    1: (0, 255, 0),
    2: (255, 0, 0),
    3: (0, 255, 255),
}
DEFAULT_COLOR = (180, 180, 180)  # fallback BGR colour
# ==========================================
# Load both models once at import time. Loading is best-effort: on failure the
# process keeps running, but BOTH handles are left as None (never undefined or
# half-loaded), so request handlers can detect the condition explicitly instead
# of failing with a NameError on a missing global.
print("Loading AI Models...")
det_model = None
cls_model = None
try:
    det_model = YOLO(DET_MODEL_PATH)
    cls_model = YOLO(CLS_MODEL_PATH)
    print("✅ Both Models Loaded Successfully")
except Exception as e:
    # Do not keep a detector without its classifier (or vice versa).
    det_model = None
    cls_model = None
    print(f"❌ Error loading models: {e}")
| # ---------- Helper Functions ---------- | |
def crop_with_margin_and_resize(img, box, margin, out_size):
    """Crop ``box`` out of ``img`` with a proportional margin and resize it square.

    Args:
        img: image array indexed as ``img[y, x, ...]`` (e.g. an OpenCV frame).
        box: ``(x1, y1, x2, y2)`` pixel corners; values are truncated to int.
        margin: fraction of the box width/height added on every side.
        out_size: side length in pixels of the square output.

    Returns:
        The crop resized to ``(out_size, out_size)``, or ``None`` when the crop
        region (after clamping to the image) is empty.
    """
    h, w = img.shape[:2]
    x1, y1, x2, y2 = map(int, box)
    mx = int((x2 - x1) * margin)
    my = int((y2 - y1) * margin)
    # Clamp every edge into [0, w] / [0, h]. Clamping only one side lets a box
    # lying outside the image produce a negative slice end, which NumPy treats
    # as "count from the end" and silently crops the wrong region.
    nx1 = min(max(0, x1 - mx), w)
    ny1 = min(max(0, y1 - my), h)
    nx2 = min(max(0, x2 + mx), w)
    ny2 = min(max(0, y2 + my), h)
    crop = img[ny1:ny2, nx1:nx2]
    if crop.size == 0:
        return None
    return cv2.resize(crop, (out_size, out_size))
def classify_wbc(crop_bgr, cls_model, imgsz=None):
    """Classify one WBC crop and return its top-1 ``(class_name, confidence)``.

    Args:
        crop_bgr: image crop in OpenCV's BGR channel order (as produced by
            ``cv2.imdecode`` and ``crop_with_margin_and_resize``).
        cls_model: Ultralytics classification model; ``cls_model.names`` maps
            the top-1 class id to its label.
        imgsz: inference size; ``None`` (default) uses ``CLS_INPUT_SIZE``.

    Returns:
        ``(name, conf)``: the top-1 label and its confidence as a float.
    """
    if imgsz is None:
        imgsz = CLS_INPUT_SIZE
    # Ultralytics treats NumPy inputs as BGR and converts to RGB internally,
    # the same way the detector is fed the raw decoded frame. The crop is
    # therefore passed through unchanged; converting it to RGB here would swap
    # the red/blue channels the classifier actually sees.
    res = cls_model(crop_bgr, imgsz=imgsz, verbose=False)[0]
    top1 = int(res.probs.top1)
    return cls_model.names[top1], float(res.probs.top1conf)
def draw_label(out, overlay, text, x1, y1, color):
    """Draw ``text`` with a filled background box anchored at ``(x1, y1)``.

    The filled background rectangle is painted on ``overlay`` and the text on
    ``out``; any blending of ``overlay`` into ``out`` is left to the caller
    (NOTE(review): presumably with ``BG_ALPHA`` -- no caller is visible here).
    The box normally sits just above ``y1`` and is moved below it when it
    would cross the top edge of the image.

    Args:
        out: image the text is drawn on (modified in place).
        overlay: image the background box is drawn on (modified in place).
        text: label text.
        x1, y1: anchor point in pixels (top-left corner of the object box).
        color: BGR fill colour of the background box.
    """
    x1, y1 = int(x1), int(y1)  # OpenCV drawing calls require integer points
    # Measure the actual text argument (previously referenced an undefined name).
    (tw, th), _ = cv2.getTextSize(text, FONT, FONT_SCALE, FONT_THICKNESS)
    tx1 = x1
    ty1 = y1 - th - PADDING_Y * 2
    tx2 = x1 + tw + PADDING_X * 2
    ty2 = y1
    if ty1 < 0:
        # Not enough room above the anchor: place the label below it instead.
        ty1 = y1
        ty2 = y1 + th + PADDING_Y * 2
    cv2.rectangle(overlay, (tx1, ty1), (tx2, ty2), color, -1)
    cv2.putText(out, text, (tx1 + PADDING_X, ty2 - PADDING_Y),
                FONT, FONT_SCALE, TEXT_COLOR, FONT_THICKNESS, cv2.LINE_AA)
| # ================== API ENDPOINTS ================== | |
def index():
    """Health check: report that the server process is up and responding.

    NOTE(review): no ``@app.route`` decorator is visible above this function,
    so as written it is never registered on ``app``; confirm the intended URL
    rule in the deployed file.
    """
    payload = {
        "status": "online",
        "message": "WelTech AI Server (Dual Model) is running",
    }
    return jsonify(payload), 200
# Classifier-label substring -> canonical WBC sub-type name.
# Checked in order; the first substring found in the lower-cased label wins.
_WBC_NAME_RULES = (
    ('neut', 'Neutrophil'),
    ('lymph', 'Lymphocyte'),
    ('mono', 'Monocyte'),
    ('eo', 'Eosinophil'),  # NOTE: loose match -- any label containing "eo"
    ('baso', 'Basophil'),
)


def _standardize_wbc_name(raw_name):
    """Map a classifier label to a canonical WBC name (unknown labels are capitalized)."""
    lowered = raw_name.lower()
    for needle, canonical in _WBC_NAME_RULES:
        if needle in lowered:
            return canonical
    return raw_name.capitalize()


def _confidence_color(final_conf, cell_type):
    """Pick a BGR box colour: red below 0.40, orange below 0.80, else per cell type."""
    if final_conf < 0.40:
        return (0, 0, 255)      # red: low confidence
    if final_conf < 0.80:
        return (0, 165, 255)    # orange: medium confidence
    if "rbc" in cell_type:
        return (19, 69, 139)    # brown
    if "wbc" in cell_type:
        return (255, 0, 0)      # blue
    if "platelet" in cell_type:
        return (128, 128, 128)  # gray
    return DEFAULT_COLOR


def _decode_request_image():
    """Decode the ``image`` field of the JSON request body into an OpenCV BGR image.

    Returns ``(frame, None)`` on success, or ``(None, message)`` when the
    request carries no usable image; the message is returned to the client
    with HTTP 400.
    """
    # silent=True: a missing/non-JSON body yields None instead of raising.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return None, "No image data"
    image_b64 = data.get('image')
    if not image_b64 or not isinstance(image_b64, str):
        return None, "No image data"
    # Accept both raw base64 and data URLs ("data:image/jpeg;base64,<payload>").
    encoded_data = image_b64.split(',')[1] if ',' in image_b64 else image_b64
    try:
        raw = base64.b64decode(encoded_data)
    except ValueError:  # includes binascii.Error and non-ASCII input
        return None, "Image Decode Failed"
    if not raw:
        # cv2.imdecode asserts (raises) on an empty buffer.
        return None, "Image Decode Failed"
    frame = cv2.imdecode(np.frombuffer(raw, np.uint8), cv2.IMREAD_COLOR)
    if frame is None:
        return None, "Image Decode Failed"
    return frame, None


def process_frame():
    """Detect cells in one base64 image and sub-classify the detected WBCs.

    Expects a JSON object whose ``image`` key holds a base64-encoded image,
    optionally prefixed with a data-URL header.

    Responses:
        - 200 (``OPTIONS``): CORS pre-flight, answered immediately.
        - 200: ``counts`` per detector class, ``wbc_details`` per WBC sub-type
          (successfully classified WBCs only), ``total`` detections,
          ``total_wbc`` (sum of ``wbc_details``) and ``processed_image``, an
          annotated JPEG (quality 80) as a data URL.
        - 400: body missing/not JSON, no ``image``, or the image cannot be decoded.
        - 503: the models were not loaded at startup.
        - 500: any other failure (traceback printed to stdout).

    NOTE(review): no ``@app.route`` decorator is visible above this function,
    so as written it is never registered on ``app``; confirm the intended URL
    rule and methods (it handles ``OPTIONS``) in the deployed file.
    """
    if request.method == 'OPTIONS':
        return jsonify({"status": "ok"}), 200
    try:
        frame, error = _decode_request_image()
        if error is not None:
            return jsonify({"status": "error", "message": error}), 400
        if det_model is None or cls_model is None:
            return jsonify({"status": "error", "message": "Models not loaded"}), 503

        # 1. Detection phase
        results = det_model(
            frame,
            imgsz=IMGSZ,
            # NOTE(review): these literals differ from the CONF / IOU config
            # constants; confirm which values are intended.
            conf=0.25,
            iou=0.20,
            max_det=1500,
            agnostic_nms=True,
            verbose=False,
        )[0]

        counts = {}
        wbc_subcounts = {}
        vis = frame.copy()
        n_boxes = len(results.boxes) if results.boxes is not None else 0

        if n_boxes > 0:
            boxes = results.boxes.xyxy.cpu().numpy()
            clss = results.boxes.cls.cpu().numpy()
            confs = results.boxes.conf.cpu().numpy()
            det_names = det_model.names
            wbc_draw_list = []
            other_draw_list = []

            for box, cls, conf in zip(boxes, clss, confs):
                x1, y1, x2, y2 = map(int, box)
                cls_id = int(cls)
                base_name = det_names[cls_id]
                counts[base_name] = counts.get(base_name, 0) + 1

                # 2. Build the label and pick the confidence that drives the colour.
                if cls_id == TARGET_CLASS_ID:
                    crop = crop_with_margin_and_resize(frame, box, CROP_MARGIN, CLS_INPUT_SIZE)
                    if crop is not None:
                        wbc_name, wbc_conf = classify_wbc(crop, cls_model)
                        std_name = _standardize_wbc_name(wbc_name)
                        wbc_subcounts[std_name] = wbc_subcounts.get(std_name, 0) + 1
                        display_label = f"{std_name} {wbc_conf:.2f}"
                        # The sub-type classifier's confidence decides the colour.
                        final_conf = wbc_conf
                    else:
                        display_label = f"WBC {conf:.2f}"
                        final_conf = conf
                    cell_type = "wbc"
                else:
                    display_label = f"{base_name} {conf:.2f}"
                    final_conf = conf
                    cell_type = base_name.lower()

                entry = (x1, y1, x2, y2, display_label,
                         _confidence_color(final_conf, cell_type))
                # WBCs go to a separate list so they are drawn last (on top).
                if cls_id == TARGET_CLASS_ID:
                    wbc_draw_list.append(entry)
                else:
                    other_draw_list.append(entry)

            # 3. Draw boxes and labels (plain text, no label background).
            for (x1, y1, x2, y2, label, color) in other_draw_list + wbc_draw_list:
                cv2.rectangle(vis, (x1, y1), (x2, y2), color, 1)
                cv2.putText(vis, label, (x1, y1 - 4), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2, cv2.LINE_AA)

        # Encode the annotated frame back to a base64 JPEG data URL.
        ok, buffer = cv2.imencode('.jpg', vis, [int(cv2.IMWRITE_JPEG_QUALITY), 80])
        if not ok:
            raise RuntimeError("JPEG encoding failed")
        processed_b64 = base64.b64encode(buffer).decode('utf-8')

        return jsonify({
            "status": "success",
            "counts": counts,
            "wbc_details": wbc_subcounts,
            "total": n_boxes,
            "total_wbc": sum(wbc_subcounts.values()),
            "processed_image": f"data:image/jpeg;base64,{processed_b64}",
        })
    except Exception as e:
        # Endpoint boundary: log the traceback and report a JSON error.
        import traceback
        traceback.print_exc()
        return jsonify({"status": "error", "message": str(e)}), 500
    finally:
        # Release per-request image buffers promptly.
        gc.collect()
if __name__ == '__main__':
    # Listen on every interface, on $PORT when set (falls back to 7860).
    listen_port = int(os.getenv('PORT', '7860'))
    app.run(port=listen_port, host='0.0.0.0')