| import os |
| import sys |
| import logging |
| import traceback |
| import warnings |
|
|
| warnings.filterwarnings("ignore") |
|
|
| import numpy as np |
| import gradio as gr |
| from PIL import Image, ImageDraw, ImageFont |
|
|
# Route every log record to stdout with a timestamp, level and logger name.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s — %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
    level=logging.INFO,
)

# Application-wide logger used by the model loader and prediction helpers.
logger = logging.getLogger("oryctes-classifier")
logger.info("Logger initialised.")
logger.info("Gradio version: %s", gr.__version__)
|
|
# Path to the trained classification weights. Defaults to "best.pt" (relative
# to the working directory); set ORYCTES_MODEL_PATH to load weights elsewhere.
MODEL_PATH = os.environ.get("ORYCTES_MODEL_PATH", "best.pt")

# The loaded Ultralytics YOLO model, or None when the app runs in fallback mode
# (ultralytics missing, weights missing, or loading raised).
model = None

try:
    # Imported inside the try so a missing/broken ultralytics install degrades
    # to fallback mode instead of crashing the whole app at import time.
    from ultralytics import YOLO

    if os.path.exists(MODEL_PATH):
        model = YOLO(MODEL_PATH)
        logger.info(f"Model loaded: {MODEL_PATH}")
    else:
        logger.warning(
            f"{os.path.basename(MODEL_PATH)} not found at '{MODEL_PATH}'. "
            "Running in fallback mode."
        )
except Exception as e:
    # Deliberate best-effort: log with traceback and keep serving in fallback mode.
    logger.exception(f"Failed to load model: {e}")
    model = None
|
|
# Class index -> canonical class name. Only the values are used in this file
# (to build all-zero probability tables for fallback results).
HEALTH_CLASSES = dict(enumerate(("healthy", "unhealthy", "unspecified")))

# Minimum averaged probability a class must reach to be accepted by
# health_cascade(); classes not listed here fall back to that function's default.
CONFIDENCE_THRESHOLDS = dict(
    healthy=0.50,
    unhealthy=0.45,
    unspecified=0.30,
)

# Hex colours used for the result banner and the HTML report text.
CLASS_COLORS = dict(
    healthy="#2ecc71",
    unhealthy="#e74c3c",
    unspecified="#f39c12",
)

# Aliases (already stripped and lower-cased) mapped onto canonical class names.
NAME_NORMALIZATION = dict(
    unknown="unspecified",
)
|
|
|
|
def normalize_class_name(name: str) -> str:
    """Return the canonical form of a class label.

    String labels are stripped and lower-cased, then mapped through
    ``NAME_NORMALIZATION`` (e.g. ``"Unknown"`` -> ``"unspecified"``); labels
    without an alias are returned in their stripped, lower-cased form.
    Non-string labels are simply converted with ``str()`` and not normalised.
    """
    if isinstance(name, str):
        key = name.strip().lower()
        return NAME_NORMALIZATION.get(key, key)
    return str(name)
|
|
|
|
def health_cascade(
    probs: dict,
    thresholds: "dict | None" = None,
    default_threshold: float = 0.30,
) -> tuple:
    """Pick a class from a probability table using per-class thresholds.

    Classes are visited from most to least probable. The first one whose
    probability reaches its own threshold is returned. If no class qualifies,
    the single most probable class is returned anyway.

    Args:
        probs: Mapping of class name -> probability.
        thresholds: Mapping of class name -> minimum probability. ``None``
            (the default) uses the module-level ``CONFIDENCE_THRESHOLDS``.
        default_threshold: Threshold applied to classes missing from
            *thresholds*.

    Returns:
        A ``(class_name, probability)`` tuple taken from *probs*.

    Raises:
        ValueError: If *probs* is empty.
    """
    if not probs:
        raise ValueError("health_cascade() needs at least one class probability.")
    if thresholds is None:
        thresholds = CONFIDENCE_THRESHOLDS

    ranked = sorted(probs.items(), key=lambda item: item[1], reverse=True)
    for cls_name, conf in ranked:
        if conf >= thresholds.get(cls_name, default_threshold):
            return cls_name, conf
    # Nothing cleared its threshold: fall back to the top-ranked class.
    return ranked[0]
|
|
|
|
def multi_run_predict(
    image: Image.Image,
    runs: int = 3,
    imgsz_list: tuple = (224, 256, 192),
) -> dict:
    """Run the model several times and average per-class probabilities.

    Each run uses the next inference size from *imgsz_list* (cycling), which
    gives a small test-time-augmentation effect. The image is NOT manually
    resized to a square (that would distort the aspect ratio); Ultralytics
    handles preprocessing via ``imgsz``.

    Runs that raise are logged and skipped. The average is taken over the
    runs that succeeded, so one failed run does not deflate the probabilities.

    Args:
        image: Input image passed straight to the model.
        runs: Number of inference passes.
        imgsz_list: Inference sizes to cycle through, one per run.

    Returns:
        Mapping of normalised class name -> averaged probability. Empty when
        the model is not loaded, ``runs <= 0``, or every run failed.

    Raises:
        ValueError: If *imgsz_list* is empty.
    """
    if model is None or runs <= 0:
        return {}
    if not imgsz_list:
        raise ValueError("imgsz_list must contain at least one image size.")

    accumulated = {}
    successful_runs = 0

    for i in range(runs):
        imgsz = imgsz_list[i % len(imgsz_list)]
        try:
            result = model(image, imgsz=imgsz, verbose=False)[0]
            names = result.names
            probs = result.probs.data.cpu().numpy()

            # Collect this run's probabilities separately so a failure halfway
            # through cannot leave a partial run in the accumulator.
            run_probs = {}
            for idx, prob in enumerate(probs):
                cls_name = normalize_class_name(names.get(idx, f"class_{idx}"))
                # Several raw names may normalise to the same class; sum them.
                run_probs[cls_name] = run_probs.get(cls_name, 0.0) + float(prob)
        except Exception as e:
            logger.warning(f"Run {i+1} failed: {e}")
            continue

        for cls_name, prob in run_probs.items():
            accumulated[cls_name] = accumulated.get(cls_name, 0.0) + prob
        successful_runs += 1

    if successful_runs == 0:
        return {}

    return {k: v / successful_runs for k, v in accumulated.items()}
|
|
|
|
def _unspecified_result(message: str, *, success: bool, all_probs: dict) -> dict:
    """Build a result dict reporting class "unspecified" with zero confidence."""
    return {
        "success": success,
        "class": "unspecified",
        "confidence": 0.0,
        "all_probs": all_probs,
        "message": message,
    }


def predict_classification(image: Image.Image) -> dict:
    """Classify *image* and return a result dictionary.

    Returns:
        A dict with keys ``"success"``, ``"class"``, ``"confidence"``,
        ``"all_probs"`` and ``"message"``.

        - No image: ``success`` is False and ``all_probs`` is empty.
        - Model not loaded (fallback mode): ``success`` is True, class is
          ``"unspecified"`` and every known class has probability 0.0.
        - Prediction raised: ``success`` is False, class is ``"unspecified"``
          and the message carries the error text.
        - Otherwise: the class chosen by ``health_cascade`` with its
          confidence, and all averaged probabilities rounded to 4 places.
    """
    if image is None:
        return _unspecified_result("No image provided.", success=False, all_probs={})

    if model is None:
        return _unspecified_result(
            f"Model not available. Please upload {os.path.basename(MODEL_PATH)} to the Space.",
            success=True,
            all_probs={c: 0.0 for c in HEALTH_CLASSES.values()},
        )

    try:
        # Inside the try so an unconvertible input yields a failure result
        # instead of an unhandled exception in the UI callback.
        image = image.convert("RGB")

        avg_probs = multi_run_predict(image, runs=3)
        if not avg_probs:
            raise ValueError("No probabilities returned from model.")

        predicted_class, confidence = health_cascade(avg_probs)
        predicted_class = normalize_class_name(predicted_class)

        logger.info(f"Prediction: {predicted_class} ({confidence:.4f})")

        return {
            "success": True,
            "class": predicted_class,
            "confidence": round(float(confidence), 4),
            "all_probs": {k: round(v, 4) for k, v in avg_probs.items()},
            "message": "Classification successful.",
        }
    except Exception as e:
        logger.exception(f"Prediction error: {e}")
        # A failed prediction is not a success, even though we still return
        # a well-formed "unspecified" result for the UI to render.
        return _unspecified_result(
            f"Prediction failed: {e}",
            success=False,
            all_probs={c: 0.0 for c in HEALTH_CLASSES.values()},
        )
|
|
|
|
| def _escape_html(s: str) -> str: |
| return ( |
| str(s) |
| .replace("&", "&") |
| .replace("<", "<") |
| .replace(">", ">") |
| ) |
|
|
|
|
def _to_pil_image(input_image) -> Image.Image:
    """Convert a Gradio image payload (PIL image, numpy array or array-like) to PIL."""
    if isinstance(input_image, Image.Image):
        return input_image
    return Image.fromarray(np.asarray(input_image).astype(np.uint8))


def _load_label_font(size: int):
    """Load bold DejaVu Sans at *size*, falling back to Pillow's built-in font."""
    try:
        return ImageFont.truetype(
            "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
            size,
        )
    except Exception:
        # Best-effort: the font file may not exist on every host.
        return ImageFont.load_default()


def _annotate_image(pil_image: Image.Image, cls_name: str, confidence: float) -> Image.Image:
    """Return an RGB copy of *pil_image* with a coloured class/confidence banner at the bottom."""
    # convert() always returns a new image, so the caller's image is untouched.
    canvas = pil_image.convert("RGB")
    w, h = canvas.size
    draw = ImageDraw.Draw(canvas)

    # Banner is 1/8 of the height (at least 50 px) but never taller than the image.
    bar_h = min(h, max(50, h // 8))
    draw.rectangle([0, h - bar_h, w, h], fill=CLASS_COLORS.get(cls_name, "#888888"))

    label = f"{cls_name.upper()} {confidence * 100:.1f}%"
    font = _load_label_font(max(14, bar_h // 2))

    left, top, right, bottom = draw.textbbox((0, 0), label, font=font)
    text_w, text_h = right - left, bottom - top
    # Centre the label in the banner, clamped so it never starts off-canvas.
    text_x = max(0, (w - text_w) // 2)
    text_y = h - bar_h + max(0, (bar_h - text_h) // 2)
    draw.text((text_x, text_y), label, fill="white", font=font)
    return canvas


def _build_report_html(cls_name: str, confidence: float, all_probs: dict, message: str) -> str:
    """Render the prediction summary and per-class probability bars as escaped HTML."""
    emoji = {"healthy": "✅", "unhealthy": "❌", "unspecified": "⚠️"}.get(cls_name, "🔍")

    lines = [
        f"{emoji} Predicted Class : {cls_name.upper()}",
        f"📊 Confidence : {confidence * 100:.2f}%",
        "",
        "── All Class Probabilities ──",
    ]

    # Coerce to float BEFORE sorting so a non-numeric value cannot break sorted().
    rows = []
    for c, p in all_probs.items():
        try:
            p = float(p)
        except (TypeError, ValueError):
            p = 0.0
        rows.append((c, p))
    rows.sort(key=lambda row: row[1], reverse=True)

    for c, p in rows:
        bar = "█" * int(max(0.0, min(1.0, p)) * 20)
        lines.append(f" {str(c):<14} {p * 100:5.1f}% {bar}")
    lines += ["", f"ℹ️ {message}"]

    text_color = CLASS_COLORS.get(cls_name, "#ffffff")
    body = "<br>".join(_escape_html(line) for line in lines)
    font_stack = (
        "ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "
        "'Liberation Mono', 'Courier New', monospace"
    )
    # pre-wrap keeps the padded columns aligned (white-space: normal would
    # collapse the padding). No whitespace is placed around ``body`` because
    # pre-wrap would render it.
    return (
        f'<div style="font-family: {font_stack}; white-space: pre-wrap; '
        f'line-height: 1.35; color: {text_color};">{body}</div>'
    )


def predict_on_health(input_image):
    """Gradio prediction callback.

    Args:
        input_image: The uploaded image (numpy array, PIL image, array-like)
            or None when the input is empty/cleared.

    Returns:
        ``(annotated PIL image, HTML string)``. With no input, a placeholder
        image and a short "No image uploaded." message are returned.
    """
    if input_image is None:
        placeholder = Image.new("RGB", (400, 200), color="#1a1a2e")
        ImageDraw.Draw(placeholder).text((80, 90), "Please upload an image.", fill="white")
        return placeholder, "<div style='color:#fff;'>No image uploaded.</div>"

    pil_image = _to_pil_image(input_image)

    result = predict_classification(pil_image)
    cls_name = normalize_class_name(result["class"])
    confidence = float(result["confidence"])
    all_probs = result.get("all_probs") or {}
    message = result.get("message", "")

    annotated = _annotate_image(pil_image, cls_name, confidence)
    report = _build_report_html(cls_name, confidence, all_probs, message)
    return annotated, report
|
|
|
|
# Static page header: title, class legend and a short model summary.
_HEADER_HTML = """
<div style="text-align:center; padding:16px 0;">
<h1>🌴 Oryctes Health Classifier</h1>
<p>Upload an image to classify it as
<b style="color:#2ecc71">Healthy</b>,
<b style="color:#e74c3c">Unhealthy</b>, or
<b style="color:#f39c12">Unspecified</b>.
</p>
<p style="color:#888; font-size:13px;">
Model: YOLOv8n-cls  ·  
3 classes  ·  
Multi-run averaging
</p>
</div>
"""

# Static page footer with attribution.
_FOOTER_HTML = """
<hr>
<p style="text-align:center; color:#aaa; font-size:12px;">
Powered by YOLOv8 · Gradio |
cullamatmf123/cocoscanclassification
</p>
"""

with gr.Blocks(title="Oryctes Health Classifier") as demo:
    gr.HTML(_HEADER_HTML)

    with gr.Row():
        # Left column: image input and the explicit "classify" trigger.
        with gr.Column(scale=1):
            input_image = gr.Image(label="📷 Upload Image", type="numpy", height=350)
            classify_btn = gr.Button(value="🔍 Classify", variant="primary")

        # Right column: annotated image and the HTML text report.
        with gr.Column(scale=1):
            output_image = gr.Image(label="🖼️ Image Classification Result", type="pil", height=350)
            output_text = gr.HTML(label="📋 Image Classification Text")

    gr.HTML(_FOOTER_HTML)

    # Classify on an explicit button press and whenever the input image changes
    # (including when it is cleared, which yields the placeholder output).
    for _register in (classify_btn.click, input_image.change):
        _register(
            fn=predict_on_health,
            inputs=input_image,
            outputs=[output_image, output_text],
        )
|
|
if __name__ == "__main__":
    # Enable the request queue, then serve on all interfaces at port 7860.
    queued_app = demo.queue()
    queued_app.launch(server_name="0.0.0.0", server_port=7860)