import os import sys import logging import traceback import warnings warnings.filterwarnings("ignore") import numpy as np import gradio as gr from PIL import Image, ImageDraw, ImageFont logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s — %(message)s", handlers=[logging.StreamHandler(sys.stdout)], ) logger = logging.getLogger("oryctes-classifier") logger.info("Logger initialised.") logger.info(f"Gradio version: {gr.__version__}") MODEL_PATH = "best.pt" model = None try: from ultralytics import YOLO if os.path.exists(MODEL_PATH): model = YOLO(MODEL_PATH) logger.info(f"Model loaded: {MODEL_PATH}") else: logger.warning(f"best.pt not found at '{MODEL_PATH}'. Running in fallback mode.") except Exception as e: logger.error(f"Failed to load model: {e}\n{traceback.format_exc()}") model = None HEALTH_CLASSES = { 0: "healthy", 1: "unhealthy", 2: "unspecified", } CONFIDENCE_THRESHOLDS = { "healthy": 0.50, "unhealthy": 0.45, "unspecified": 0.30, } CLASS_COLORS = { "healthy": "#2ecc71", "unhealthy": "#e74c3c", "unspecified": "#f39c12", } # If model outputs "unknown" but you want "unspecified" everywhere, normalize here. NAME_NORMALIZATION = { "unknown": "unspecified", } def normalize_class_name(name: str) -> str: if not isinstance(name, str): return str(name) return NAME_NORMALIZATION.get(name.strip().lower(), name.strip().lower()) def health_cascade(probs: dict) -> tuple: ranked = sorted(probs.items(), key=lambda x: x[1], reverse=True) for cls_name, conf in ranked: threshold = CONFIDENCE_THRESHOLDS.get(cls_name, 0.30) if conf >= threshold: return cls_name, conf return ranked[0] def multi_run_predict(image: Image.Image, runs: int = 3) -> dict: """ Run model multiple times and average for better stability. Important: do NOT manually resize to a square (distorts aspect ratio). Let Ultralytics handle preprocessing via imgsz. """ if model is None: return {} accumulated = {} imgsz_list = [224, 256, 192] for i in range(runs): imgsz = imgsz_list[i % len(imgsz_list)] try: result = model(image, imgsz=imgsz, verbose=False)[0] names = result.names probs = result.probs.data.cpu().numpy() for idx, prob in enumerate(probs): cls_name = names.get(idx, f"class_{idx}") cls_name = normalize_class_name(cls_name) accumulated[cls_name] = accumulated.get(cls_name, 0.0) + float(prob) except Exception as e: logger.warning(f"Run {i+1} failed: {e}") continue if not accumulated: return {} return {k: v / runs for k, v in accumulated.items()} def predict_classification(image: Image.Image) -> dict: if image is None: return { "success": False, "class": "unspecified", "confidence": 0.0, "all_probs": {}, "message": "No image provided.", } image = image.convert("RGB") if model is None: return { "success": True, "class": "unspecified", "confidence": 0.0, "all_probs": {c: 0.0 for c in HEALTH_CLASSES.values()}, "message": "Model not available. Please upload best.pt to the Space.", } try: avg_probs = multi_run_predict(image, runs=3) if not avg_probs: raise ValueError("No probabilities returned from model.") predicted_class, confidence = health_cascade(avg_probs) predicted_class = normalize_class_name(predicted_class) logger.info(f"Prediction: {predicted_class} ({confidence:.4f})") return { "success": True, "class": predicted_class, "confidence": round(confidence, 4), "all_probs": {k: round(v, 4) for k, v in avg_probs.items()}, "message": "Classification successful.", } except Exception as e: logger.error(f"Prediction error: {e}\n{traceback.format_exc()}") return { "success": True, "class": "unspecified", "confidence": 0.0, "all_probs": {c: 0.0 for c in HEALTH_CLASSES.values()}, "message": f"Prediction failed: {str(e)}", } def _escape_html(s: str) -> str: return ( str(s) .replace("&", "&") .replace("<", "<") .replace(">", ">") ) def predict_on_health(input_image): """ Gradio prediction function. Returns: annotated PIL image, HTML string (colored). """ if input_image is None: blank = Image.new("RGB", (400, 200), color="#1a1a2e") draw = ImageDraw.Draw(blank) draw.text((80, 90), "Please upload an image.", fill="white") return blank, "
Upload an image to classify it as Healthy, Unhealthy, or Unspecified.
Model: YOLOv8n-cls  ·  3 classes  ·  Multi-run averaging
Powered by YOLOv8 · Gradio | cullamatmf123/cocoscanclassification
""" ) classify_btn.click( fn=predict_on_health, inputs=input_image, outputs=[output_image, output_text], ) input_image.change( fn=predict_on_health, inputs=input_image, outputs=[output_image, output_text], ) if __name__ == "__main__": demo.queue().launch(server_name="0.0.0.0", server_port=7860)