"""Gradio app that classifies an image as healthy, unhealthy, or unspecified.

A YOLOv8 classification checkpoint (``best.pt``) is loaded at import time when
it exists; otherwise the app runs in a fallback mode that always reports
"unspecified" with zero confidence.
"""

import os
import sys
import logging
import traceback
import warnings

# Suppress library warnings before the third-party imports below can emit them.
warnings.filterwarnings("ignore")

import numpy as np
import gradio as gr
from PIL import Image, ImageDraw, ImageFont

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s — %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger("oryctes-classifier")
logger.info("Logger initialised.")
logger.info(f"Gradio version: {gr.__version__}")

MODEL_PATH = "best.pt"

# Global model handle; stays None when ultralytics or the checkpoint is missing,
# which every prediction path checks to decide on fallback behaviour.
model = None
try:
    from ultralytics import YOLO

    if os.path.exists(MODEL_PATH):
        model = YOLO(MODEL_PATH)
        logger.info(f"Model loaded: {MODEL_PATH}")
    else:
        logger.warning(f"best.pt not found at '{MODEL_PATH}'. Running in fallback mode.")
except Exception as e:
    logger.error(f"Failed to load model: {e}\n{traceback.format_exc()}")
    model = None

# Class index -> label; used to build all-zero probability tables in fallback paths.
HEALTH_CLASSES = {
    0: "healthy",
    1: "unhealthy",
    2: "unspecified",
}

# Minimum probability a class must reach to be accepted by health_cascade.
CONFIDENCE_THRESHOLDS = {
    "healthy": 0.50,
    "unhealthy": 0.45,
    "unspecified": 0.30,
}

# Colour used for the result bar on the image and for the HTML result text.
CLASS_COLORS = {
    "healthy": "#2ecc71",
    "unhealthy": "#e74c3c",
    "unspecified": "#f39c12",
}

# If model outputs "unknown" but you want "unspecified" everywhere, normalize here.
NAME_NORMALIZATION = {
    "unknown": "unspecified",
}


def normalize_class_name(name: str) -> str:
    """Return a canonical, lower-cased class label.

    Non-string inputs are converted with ``str()`` unchanged; string inputs are
    stripped, lower-cased and mapped through ``NAME_NORMALIZATION``.
    """
    if not isinstance(name, str):
        return str(name)
    key = name.strip().lower()
    return NAME_NORMALIZATION.get(key, key)


def health_cascade(probs: dict) -> tuple:
    """Choose a class by walking probabilities from highest to lowest.

    The first class whose probability meets its entry in
    ``CONFIDENCE_THRESHOLDS`` (default 0.30 for unlisted classes) wins. If no
    class clears its threshold, the top-ranked class is returned anyway.

    Args:
        probs: Mapping of class name -> probability.

    Returns:
        ``(class_name, probability)`` tuple.

    Raises:
        ValueError: If ``probs`` is empty.
    """
    if not probs:
        raise ValueError("health_cascade() requires at least one class probability.")
    ranked = sorted(probs.items(), key=lambda x: x[1], reverse=True)
    for cls_name, conf in ranked:
        threshold = CONFIDENCE_THRESHOLDS.get(cls_name, 0.30)
        if conf >= threshold:
            return cls_name, conf
    return ranked[0]


def multi_run_predict(image: Image.Image, runs: int = 3) -> dict:
    """
    Run model multiple times and average for better stability.

    Important: do NOT manually resize to a square (distorts aspect ratio).
    Let Ultralytics handle preprocessing via imgsz.

    Each run cycles through a different ``imgsz``. Probabilities are averaged
    over the runs that actually succeeded, so a failed run does not drag every
    class probability down.

    Args:
        image: RGB PIL image to classify.
        runs: Number of inference passes to attempt.

    Returns:
        Mapping of normalized class name -> mean probability, or ``{}`` when the
        model is unavailable or every run failed.
    """
    if model is None:
        return {}

    accumulated = {}
    successful_runs = 0
    imgsz_list = [224, 256, 192]

    for i in range(runs):
        imgsz = imgsz_list[i % len(imgsz_list)]
        try:
            result = model(image, imgsz=imgsz, verbose=False)[0]
            names = result.names
            probs = result.probs.data.cpu().numpy()
            # Build this run's table separately so a mid-loop failure leaves
            # no partial sums in `accumulated`.
            run_probs = {}
            for idx, prob in enumerate(probs):
                cls_name = normalize_class_name(names.get(idx, f"class_{idx}"))
                run_probs[cls_name] = run_probs.get(cls_name, 0.0) + float(prob)
        except Exception as e:
            logger.warning(f"Run {i + 1} failed: {e}")
            continue

        successful_runs += 1
        for cls_name, prob in run_probs.items():
            accumulated[cls_name] = accumulated.get(cls_name, 0.0) + prob

    if not accumulated or successful_runs == 0:
        return {}
    return {k: v / successful_runs for k, v in accumulated.items()}


def predict_classification(image: Image.Image) -> dict:
    """Classify a PIL image and return a result dictionary.

    Returns:
        Dict with keys ``success``, ``class``, ``confidence``, ``all_probs``
        and ``message``. When no model is loaded, a zero-confidence
        "unspecified" result is returned with ``success=True`` (fallback mode).
        When inference fails, ``success`` is ``False`` and ``message``
        describes the error.
    """
    if image is None:
        return {
            "success": False,
            "class": "unspecified",
            "confidence": 0.0,
            "all_probs": {},
            "message": "No image provided.",
        }

    image = image.convert("RGB")

    if model is None:
        return {
            "success": True,
            "class": "unspecified",
            "confidence": 0.0,
            "all_probs": {c: 0.0 for c in HEALTH_CLASSES.values()},
            "message": "Model not available. Please upload best.pt to the Space.",
        }

    try:
        avg_probs = multi_run_predict(image, runs=3)
        if not avg_probs:
            raise ValueError("No probabilities returned from model.")

        predicted_class, confidence = health_cascade(avg_probs)
        predicted_class = normalize_class_name(predicted_class)
        logger.info(f"Prediction: {predicted_class} ({confidence:.4f})")

        return {
            "success": True,
            "class": predicted_class,
            "confidence": round(confidence, 4),
            "all_probs": {k: round(v, 4) for k, v in avg_probs.items()},
            "message": "Classification successful.",
        }
    except Exception as e:
        logger.error(f"Prediction error: {e}\n{traceback.format_exc()}")
        return {
            "success": False,
            "class": "unspecified",
            "confidence": 0.0,
            "all_probs": {c: 0.0 for c in HEALTH_CLASSES.values()},
            "message": f"Prediction failed: {str(e)}",
        }


def _escape_html(s: str) -> str:
    """Escape ``&``, ``<`` and ``>`` so ``s`` can be embedded as HTML text."""
    # "&" must be replaced first so the entities added below are not re-escaped.
    return (
        str(s)
        .replace("&", "&amp;")
        .replace("<", "&lt;")
        .replace(">", "&gt;")
    )


def predict_on_health(input_image):
    """
    Gradio prediction function.

    Args:
        input_image: Image from the Gradio input (numpy array by default, or a
            PIL image / array-like); ``None`` when the input is cleared.

    Returns: annotated PIL image, HTML string (colored).
    """
    if input_image is None:
        blank = Image.new("RGB", (400, 200), color="#1a1a2e")
        draw = ImageDraw.Draw(blank)
        draw.text((80, 90), "Please upload an image.", fill="white")
        return blank, (
            '<div style="color:#f39c12; font-family:monospace; padding:12px;">'
            "No image uploaded."
            "</div>"
        )

    # Convert numpy array (Gradio default) to PIL
    if isinstance(input_image, np.ndarray):
        pil_image = Image.fromarray(input_image.astype(np.uint8))
    elif isinstance(input_image, Image.Image):
        pil_image = input_image
    else:
        pil_image = Image.fromarray(np.array(input_image).astype(np.uint8))

    result = predict_classification(pil_image)
    cls_name = normalize_class_name(result["class"])
    confidence = float(result["confidence"])
    all_probs = result.get("all_probs", {}) or {}
    message = result.get("message", "")

    # Draw colored bar on image
    img_display = pil_image.convert("RGB").copy()
    w, h = img_display.size
    draw = ImageDraw.Draw(img_display)
    # At least 50 px tall, but never taller than the image itself.
    bar_h = min(h, max(50, h // 8))
    bar_color = CLASS_COLORS.get(cls_name, "#888888")
    draw.rectangle([0, h - bar_h, w, h], fill=bar_color)

    label = f"{cls_name.upper()} {confidence * 100:.1f}%"
    try:
        font = ImageFont.truetype(
            "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
            max(14, bar_h // 2),
        )
    except Exception:
        # Font file not present on this host; fall back to Pillow's built-in font.
        font = ImageFont.load_default()

    bbox = draw.textbbox((0, 0), label, font=font)
    text_w = bbox[2] - bbox[0]
    text_h = bbox[3] - bbox[1]
    # Clamp to 0 so a label wider than the image starts at the left edge.
    text_x = max(0, (w - text_w) // 2)
    text_y = h - bar_h + (bar_h - text_h) // 2
    draw.text((text_x, text_y), label, fill="white", font=font)

    # Build colored HTML output
    emoji = {"healthy": "✅", "unhealthy": "❌", "unspecified": "⚠️"}.get(cls_name, "🔍")
    lines = [
        f"{emoji} Predicted Class : {cls_name.upper()}",
        f"📊 Confidence : {confidence * 100:.2f}%",
        "",
        "── All Class Probabilities ──",
    ]
    for c, p in sorted(all_probs.items(), key=lambda x: x[1], reverse=True):
        try:
            p = float(p)
        except Exception:
            p = 0.0
        # Text bar of up to 20 blocks, with p clamped to [0, 1].
        bar = "█" * int(max(0.0, min(1.0, p)) * 20)
        lines.append(f" {str(c):<14} {p * 100:5.1f}% {bar}")
    lines += ["", f"ℹ️ {message}"]

    text_color = CLASS_COLORS.get(cls_name, "#ffffff")
    safe_lines = "<br>".join(_escape_html(line) for line in lines)
    # pre-wrap + monospace keep the padded probability columns aligned.
    html = (
        f'<div style="color:{text_color}; font-family:monospace; '
        f'white-space:pre-wrap; background:#1a1a2e; padding:12px; '
        f'border-radius:8px;">'
        f"{safe_lines}"
        f"</div>"
    )
    return img_display, html


with gr.Blocks(title="Oryctes Health Classifier") as demo:
    gr.HTML(
        """
        <div style="text-align:center; padding:8px 0;">
          <h1>🌴 Oryctes Health Classifier</h1>
          <p>Upload an image to classify it as Healthy, Unhealthy, or Unspecified.</p>
          <p>Model: YOLOv8n-cls &nbsp;·&nbsp; 3 classes &nbsp;·&nbsp; Multi-run averaging</p>
        </div>
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="📷 Upload Image",
                type="numpy",
                height=350,
            )
            classify_btn = gr.Button(
                value="🔍 Classify",
                variant="primary",
            )

        with gr.Column(scale=1):
            output_image = gr.Image(
                label="🖼️ Image Classification Result",
                type="pil",
                height=350,
            )
            output_text = gr.HTML(label="📋 Image Classification Text")

    gr.HTML(
        """
        <div style="text-align:center; opacity:0.7;">
          <p>Powered by YOLOv8 · Gradio | cullamatmf123/cocoscanclassification</p>
        </div>
        """
    )

    # Classify on explicit click and also whenever the input image changes.
    classify_btn.click(
        fn=predict_on_health,
        inputs=input_image,
        outputs=[output_image, output_text],
    )
    input_image.change(
        fn=predict_on_health,
        inputs=input_image,
        outputs=[output_image, output_text],
    )


if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)