# Author: nusaibah0110
# Commit e195df9 — "Fix restart loop: implement lazy loading for YOLO models"
#------------------------------------------------
# Acetowhite Contour Inference
# -----------------------------------------------
import os
import cv2
import numpy as np
import torch
from ultralytics import YOLO
# MODEL LOAD (Safe Backend Path) - LAZY LOADING
# Weight files live in a "models" directory next to this source file, so the
# paths resolve the same way regardless of the process's working directory.
_MODELS_DIR = os.path.join(os.path.dirname(__file__), "models")
AW_MODEL_PATH = os.path.join(_MODELS_DIR, "AW_yolo.pt")
CERVIX_MODEL_PATH = os.path.join(_MODELS_DIR, "cervix_yolo.pt")

# Cached model instances; both start empty and are filled in by the
# load_* functions on first use.
aw_model = cervix_model = None
def load_aw_model():
    """Return the Acetowhite YOLO model, loading and caching it on first use.

    The model is read from ``AW_MODEL_PATH``, moved to the CPU, and stored in
    the module-level ``aw_model`` so later calls reuse the same instance.

    Returns:
        The loaded model, or ``None`` if any loading step raised. A failed load
        is not cached, so the next call retries it.
    """
    global aw_model
    if aw_model is not None:
        return aw_model
    try:
        print(f"πŸ”„ Loading Acetowhite model from: {AW_MODEL_PATH}")
        # Build into a local and publish to the global only after every step
        # below has succeeded; assigning the global first would leave a
        # half-initialised model cached if .to() or the patch raised.
        model = YOLO(AW_MODEL_PATH)
        model.to('cpu')
        # Patch the Segment head to prevent the detect() error
        if hasattr(model.model, 'model'):
            for module in model.model.modules():
                if module.__class__.__name__ == 'Segment':
                    print("⚠️ Patching Segment head to prevent detect() error")
                    # Disable the problematic detect call.
                    # NOTE(review): this turns `detect` into an identity function
                    # (it returns `x` unchanged). Whether the head's outputs stay
                    # valid depends on how the installed ultralytics version's
                    # Segment.forward uses `detect` — confirm.
                    module.detect = lambda self, x: x
        aw_model = model
        print("βœ… Acetowhite model loaded successfully")
        return aw_model
    except Exception as e:
        print(f"❌ Error loading Acetowhite model: {e}")
        import traceback
        traceback.print_exc()
        return None
def load_cervix_model():
    """Return the Cervix YOLO model, loading and caching it on first use.

    The model is read from ``CERVIX_MODEL_PATH``, moved to the CPU, and stored
    in the module-level ``cervix_model`` so later calls reuse the same instance.

    Returns:
        The loaded model, or ``None`` if any loading step raised. A failed load
        is not cached, so the next call retries it.
    """
    global cervix_model
    if cervix_model is not None:
        return cervix_model
    try:
        print(f"πŸ”„ Loading Cervix model from: {CERVIX_MODEL_PATH}")
        # Build into a local and publish to the global only once .to() has
        # succeeded, so a failure never leaves a half-initialised model cached.
        model = YOLO(CERVIX_MODEL_PATH)
        model.to('cpu')
        cervix_model = model
        print("βœ… Cervix model loaded successfully")
        return cervix_model
    except Exception as e:
        print(f"❌ Error loading Cervix model: {e}")
        import traceback
        traceback.print_exc()
        return None
# CONFIGURABLE PARAMETERS
# Contours whose cv2.contourArea (in pixels) is below this are dropped.
MIN_AREA: int = 150
# Fraction of the contour perimeter used as the approxPolyDP tolerance.
SMOOTHING_EPSILON: float = 0.002
# Confidence threshold used when the caller does not pass one.
DEFAULT_CONF: float = 0.4
# Image size handed to YOLO's predict() as `imgsz`.
IMG_SIZE: int = 640
# MAIN INFERENCE FUNCTION
def _aw_contour_from_polygon(polygon, confidence, conf_threshold):
    """Convert one mask polygon into a simplified contour, or None if filtered.

    Args:
        polygon: One entry of ``result.masks.xy`` (presumably an (N, 2) float
            array of pixel coordinates — TODO confirm).
        confidence: Confidence of the box matching this mask.
        conf_threshold: Minimum confidence required to keep the polygon.

    Returns:
        ``(contour, area)``, where ``contour`` is the ``cv2.approxPolyDP``
        result and ``area`` is the float area of the unsimplified polygon.
        Returns ``None`` when the confidence is below ``conf_threshold`` or
        the area is below ``MIN_AREA``.
    """
    if confidence < conf_threshold:
        return None
    contour = polygon.astype(np.int32)
    area = cv2.contourArea(contour)
    if area < MIN_AREA:
        return None
    # Simplification tolerance scales with the closed contour's perimeter.
    epsilon = SMOOTHING_EPSILON * cv2.arcLength(contour, True)
    return cv2.approxPolyDP(contour, epsilon, True), float(area)


def infer_aw_contour(frame, conf_threshold=DEFAULT_CONF):
    """Segment acetowhite regions in ``frame`` and outline them.

    Args:
        frame: Image passed to the Acetowhite YOLO model, or None.
        conf_threshold: Minimum confidence for a mask to be kept.

    Returns:
        Dict with ``overlay`` (a copy of ``frame`` with green contours drawn,
        or None when no contour was kept), ``contours`` (list of dicts with
        ``points``, ``area`` and ``confidence``), ``detections`` (number of
        kept contours) and ``frame_width``/``frame_height``. Prediction or
        mask-extraction errors are logged, and whatever contours were
        collected before the error are returned.
    """
    if frame is None:
        return {
            "overlay": None,
            "contours": [],
            "detections": 0,
            "frame_width": 0,
            "frame_height": 0
        }
    model = load_aw_model()
    if model is None:
        print("❌ Acetowhite model not available")
        return {
            "overlay": None,
            "contours": [],
            "detections": 0,
            "frame_width": frame.shape[1],
            "frame_height": frame.shape[0]
        }
    overlay = frame.copy()
    contours_list = []
    try:
        print(f"πŸ”„ Running YOLO prediction on frame shape: {frame.shape}")
        results = model.predict(
            frame,
            conf=conf_threshold,
            imgsz=IMG_SIZE,
            verbose=False,
            device='cpu'
        )
        # Handle both list and single result
        result = results[0] if isinstance(results, (list, tuple)) else results
        print("βœ… YOLO prediction complete")
        # Try to extract masks if available
        if getattr(result, 'masks', None) is not None:
            try:
                masks = result.masks.xy
                if len(masks) > 0:
                    print(f"βœ… Found {len(masks)} masks")
                    # Convert all box confidences once instead of indexing the
                    # tensor for every mask.
                    confidences = result.boxes.conf.cpu().numpy()
                    for idx, polygon in enumerate(masks):
                        confidence = float(confidences[idx])
                        processed = _aw_contour_from_polygon(polygon, confidence, conf_threshold)
                        if processed is None:
                            continue
                        contour, area = processed
                        cv2.polylines(overlay, [contour], isClosed=True, color=(0, 255, 0), thickness=2)
                        contours_list.append({
                            "points": contour.tolist(),
                            "area": area,
                            "confidence": round(confidence, 3)
                        })
            except Exception as mask_err:
                print(f"⚠️ Could not extract masks: {mask_err}")
    except Exception as e:
        print(f"❌ YOLO prediction error: {e}")
        import traceback
        traceback.print_exc()
        # Continue with empty results rather than crashing
    detection_count = len(contours_list)
    return {
        "overlay": overlay if detection_count > 0 else None,
        "contours": contours_list,
        "detections": detection_count,
        "frame_width": frame.shape[1],
        "frame_height": frame.shape[0]
    }
#-----------------------------------------------
# Cervical and Image Quality Check Inference
# ----------------------------------------------
import cv2
import numpy as np
from ultralytics import YOLO
from collections import deque
# Stability buffer for video: rolling record of recent frames, 1 when the
# cervix was detected and 0 otherwise; the oldest entry drops off once full.
detect_history: deque = deque((), maxlen=10)
# QUALITY COMPONENT FUNCTIONS
def compute_focus(gray_roi):
    """Score sharpness from the variance of the Laplacian of ``gray_roi``.

    The variance is divided by 200 and capped at 1.0, so any variance of 200
    or more counts as fully in focus.
    """
    sharpness = cv2.Laplacian(gray_roi, cv2.CV_64F).var()
    normalized = sharpness / 200
    return 1.0 if normalized > 1.0 else normalized
def compute_exposure(gray_roi):
    """Score exposure from the mean intensity of ``gray_roi``.

    A mean inside [70, 180] scores 1.0. Outside that band the score is
    ``1 - |mean - 125| / 125``, floored at 0.
    """
    brightness = np.mean(gray_roi)
    in_band = 70 <= brightness <= 180
    if in_band:
        return 1.0
    deviation = abs(brightness - 125) / 125
    return max(0, 1 - deviation)
def compute_glare(gray_roi):
    """Score glare from the fraction of ``gray_roi`` pixels brighter than 240.

    Less glare scores higher: a fraction under 1% gives 1.0, under 3% gives
    0.7, under 6% gives 0.4, and anything else gives 0.1.
    """
    _, bright_mask = cv2.threshold(gray_roi, 240, 255, cv2.THRESH_BINARY)
    glare_fraction = np.sum(bright_mask == 255) / gray_roi.size
    for upper_bound, score in ((0.01, 1.0), (0.03, 0.7), (0.06, 0.4)):
        if glare_fraction < upper_bound:
            return score
    return 0.1
# MAIN FRAME ANALYSIS
def _no_cervix_result(detection_conf=0.0):
    """Build the result reported when no usable cervix region is found."""
    return {
        "detected": False,
        "detection_confidence": detection_conf,
        "quality_score": 0.0,
        "quality_percent": 0
    }


def analyze_frame(frame, conf_threshold=0.3):
    """Detect the cervix in ``frame`` and score image quality inside its box.

    Args:
        frame: BGR image (it is converted with ``cv2.COLOR_BGR2GRAY``), or None.
        conf_threshold: Minimum detection confidence passed to the cervix model.

    Returns:
        Dict with ``detected``, ``detection_confidence``, ``quality_score`` and
        ``quality_percent``. When a cervix box with a non-empty region is found
        it also has ``focus_score``, ``exposure_score`` and ``glare_score``.
        The quality score is 0.35 * focus + 0.30 * exposure + 0.35 * glare.
    """
    if frame is None:
        return _no_cervix_result()
    model = load_cervix_model()
    if model is None:
        print("❌ Cervix model not loaded")
        return _no_cervix_result()
    try:
        results = model.predict(
            frame,
            conf=conf_threshold,
            imgsz=640,
            verbose=False,
            device='cpu'
        )
    except Exception as e:
        print(f"❌ Cervix model prediction error: {e}")
        import traceback
        traceback.print_exc()
        return _no_cervix_result()
    r = results[0]
    if r.boxes is None or len(r.boxes) == 0:
        return _no_cervix_result()

    # Take highest confidence box — chosen explicitly rather than relying on
    # the detector returning boxes already sorted by score.
    confidences = r.boxes.conf.cpu().numpy()
    best = int(np.argmax(confidences))
    detection_conf = float(confidences[best])
    x1, y1, x2, y2 = r.boxes.xyxy.cpu().numpy()[best].astype(int)

    # Clamp to the frame: a negative coordinate would otherwise be treated as
    # an index from the end and slice the wrong region.
    frame_h, frame_w = frame.shape[:2]
    x1, x2 = max(0, x1), min(frame_w, x2)
    y1, y2 = max(0, y1), min(frame_h, y2)

    # Grayscale is only needed once there is a region to score.
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    roi = gray[y1:y2, x1:x2]
    if roi.size == 0:
        return _no_cervix_result(detection_conf)

    # ---- Quality components ----
    focus_n = compute_focus(roi)
    exposure_n = compute_exposure(roi)
    glare_n = compute_glare(roi)
    quality_score = 0.35 * focus_n + 0.30 * exposure_n + 0.35 * glare_n
    return {
        "detected": True,
        "detection_confidence": round(detection_conf, 3),
        "quality_score": round(float(quality_score), 3),
        "quality_percent": int(quality_score * 100),
        "focus_score": round(float(focus_n), 3),
        "exposure_score": round(float(exposure_n), 3),
        "glare_score": round(float(glare_n), 3)
    }
# VIDEO STABILITY ANALYSIS
def analyze_video_frame(frame, conf_threshold=0.3):
    """Analyze one video frame and add a detection-stability ``status``.

    The frame's detected/not-detected outcome is pushed onto the shared
    ``detect_history`` buffer. The status depends on how many entries in that
    buffer are detections: 7 or more is stable, at least one is unstable, and
    none means the cervix is still being searched for.
    """
    frame_result = analyze_frame(frame, conf_threshold)
    detect_history.append(int(bool(frame_result["detected"])))
    hits = sum(detect_history)
    if hits >= 7:
        status = "Cervix Detected (Stable)"
    elif hits:
        status = "Cervix Detected (Unstable)"
    else:
        status = "Searching Cervix"
    frame_result["status"] = status
    return frame_result
#-----------------------------------------------
# Cervix Bounding Box Detection for Annotations
# -----------------------------------------------
def _empty_bbox_result(frame):
    """Build the no-detection result, reporting the frame size when one is given."""
    height, width = frame.shape[:2] if frame is not None else (0, 0)
    return {
        "overlay": None,
        "bounding_boxes": [],
        "detections": 0,
        "frame_width": width,
        "frame_height": height
    }


def infer_cervix_bbox(frame, conf_threshold=0.4):
    """
    Detect cervix bounding boxes using YOLO model.
    Returns bounding boxes and annotated frame.
    Args:
        frame: Input image frame (BGR)
        conf_threshold: Confidence threshold for detection
    Returns:
        Dictionary with ``overlay`` (copy of the frame with blue boxes drawn,
        or None when nothing was detected), ``bounding_boxes`` (corner, size,
        center and confidence per box), ``detections`` and
        ``frame_width``/``frame_height``. Whenever a frame is supplied, its
        real size is reported, including when the model is unavailable or
        prediction fails.
    """
    if frame is None:
        return _empty_bbox_result(None)
    model = load_cervix_model()
    if model is None:
        # Report the actual frame size, as the prediction-error path does.
        return _empty_bbox_result(frame)
    try:
        results = model.predict(
            frame,
            conf=conf_threshold,
            imgsz=640,
            verbose=False,
            device='cpu'
        )[0]
        overlay = frame.copy()
        bounding_boxes = []
        if results.boxes is not None and len(results.boxes) > 0:
            boxes = results.boxes.xyxy.cpu().numpy().astype(int)
            confidences = results.boxes.conf.cpu().numpy()
            for (x1, y1, x2, y2), raw_conf in zip(boxes, confidences):
                confidence = float(raw_conf)
                # Draw bounding box (BGR (255, 0, 0) is blue)
                cv2.rectangle(overlay, (x1, y1), (x2, y2), (255, 0, 0), 3)
                # Store bounding box info
                bounding_boxes.append({
                    "x1": int(x1),
                    "y1": int(y1),
                    "x2": int(x2),
                    "y2": int(y2),
                    "width": int(x2 - x1),
                    "height": int(y2 - y1),
                    "confidence": round(confidence, 3),
                    "center_x": int((x1 + x2) / 2),
                    "center_y": int((y1 + y2) / 2)
                })
        detection_count = len(bounding_boxes)
        return {
            "overlay": overlay if detection_count > 0 else None,
            "bounding_boxes": bounding_boxes,
            "detections": detection_count,
            "frame_width": frame.shape[1],
            "frame_height": frame.shape[0]
        }
    except Exception as e:
        print(f"❌ Cervix bounding box detection error: {e}")
        import traceback
        traceback.print_exc()
        return _empty_bbox_result(frame)