# Source: cullamatmf123/cocoscanclassification — commit f07d6a7 ("Update app.py", verified).
import os
import sys
import logging
import traceback
import warnings
warnings.filterwarnings("ignore")
import numpy as np
import gradio as gr
from PIL import Image, ImageDraw, ImageFont
# Root logging configuration: INFO and above, written to stdout with a
# timestamped "[LEVEL] logger-name — message" layout.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format="%(asctime)s [%(levelname)s] %(name)s — %(message)s",
    level=logging.INFO,
)

# Named logger shared by the rest of the app.
logger = logging.getLogger("oryctes-classifier")
logger.info("Logger initialised.")
logger.info("Gradio version: %s", gr.__version__)
# Weights file, resolved relative to the current working directory.
MODEL_PATH = "best.pt"

# The loaded classifier, or None when ultralytics / the weights are unavailable.
# multi_run_predict and predict_classification check for None and fall back.
model = None
try:
    from ultralytics import YOLO

    if not os.path.exists(MODEL_PATH):
        logger.warning(f"best.pt not found at '{MODEL_PATH}'. Running in fallback mode.")
    else:
        model = YOLO(MODEL_PATH)
        logger.info(f"Model loaded: {MODEL_PATH}")
except Exception as e:
    # Any import or load failure leaves the app running in fallback mode.
    logger.error(f"Failed to load model: {e}\n{traceback.format_exc()}")
    model = None
# Expected class labels keyed 0-2 (presumably mirroring the model's class
# order — confirm against best.pt). Used to build the all-zero probability
# table returned when no real prediction is available.
HEALTH_CLASSES = dict(enumerate(("healthy", "unhealthy", "unspecified")))

# Minimum probability each class needs before health_cascade accepts it;
# classes missing from this table use the default threshold in health_cascade.
CONFIDENCE_THRESHOLDS = dict(healthy=0.50, unhealthy=0.45, unspecified=0.30)

# Hex colours for the label bar drawn on the image and for the HTML result text.
CLASS_COLORS = dict(healthy="#2ecc71", unhealthy="#e74c3c", unspecified="#f39c12")

# Aliases applied by normalize_class_name: a cleaned raw label (left) is
# rewritten to the canonical label (right), so a model that emits "unknown"
# is reported as "unspecified" everywhere.
NAME_NORMALIZATION = dict(unknown="unspecified")
def normalize_class_name(name: str) -> str:
    """
    Return the canonical label for a raw model class name.

    Strings are stripped and lower-cased, then looked up in
    ``NAME_NORMALIZATION``. Labels with an alias are rewritten (e.g.
    "unknown" -> "unspecified"); all others are returned cleaned but
    otherwise unchanged. Non-string inputs are only passed through ``str()``;
    they are neither lower-cased nor aliased.
    """
    if isinstance(name, str):
        cleaned = name.strip().lower()
        return NAME_NORMALIZATION.get(cleaned, cleaned)
    return str(name)
def health_cascade(probs: dict) -> tuple:
    """
    Choose a ``(class_name, confidence)`` pair using per-class thresholds.

    Classes are visited from most to least probable. The first whose
    probability reaches its ``CONFIDENCE_THRESHOLDS`` entry wins; unlisted
    classes use a threshold of 0.30. If no class qualifies, the most probable
    class is returned anyway.

    A lower-ranked class can win when a higher-ranked one misses its own
    (stricter) threshold.

    Raises:
        IndexError: if ``probs`` is empty.
    """
    ranked = sorted(probs.items(), key=lambda pair: pair[1], reverse=True)
    passing = (
        (label, score)
        for label, score in ranked
        if score >= CONFIDENCE_THRESHOLDS.get(label, 0.30)
    )
    chosen = next(passing, None)
    return ranked[0] if chosen is None else chosen
def multi_run_predict(image: Image.Image, runs: int = 3) -> dict:
    """
    Run the classifier several times at different input sizes and average.

    Important: do NOT manually resize to a square (distorts aspect ratio).
    Let Ultralytics handle preprocessing via imgsz.

    Args:
        image: PIL image to classify.
        runs: number of inference passes; ``imgsz`` cycles through
            224, 256 and 192.

    Returns:
        Mapping of normalized class name -> mean probability over the runs
        that *succeeded*. Returns an empty dict when the model is not loaded,
        ``runs`` is not positive, or every run failed.
    """
    if model is None:
        return {}
    accumulated = {}
    successful_runs = 0
    imgsz_list = [224, 256, 192]
    for i in range(runs):
        imgsz = imgsz_list[i % len(imgsz_list)]
        try:
            result = model(image, imgsz=imgsz, verbose=False)[0]
            names = result.names
            probs = result.probs.data.cpu().numpy()
            # Collect this run's scores separately so that a failure partway
            # through cannot leave partial sums in `accumulated`.
            run_probs = {}
            for idx, prob in enumerate(probs):
                cls_name = normalize_class_name(names.get(idx, f"class_{idx}"))
                run_probs[cls_name] = run_probs.get(cls_name, 0.0) + float(prob)
        except Exception as e:
            logger.warning(f"Run {i+1} failed: {e}")
            continue
        successful_runs += 1
        for cls_name, prob in run_probs.items():
            accumulated[cls_name] = accumulated.get(cls_name, 0.0) + prob
    if not successful_runs or not accumulated:
        return {}
    # Average over successful runs only. Dividing by `runs` would shrink every
    # probability whenever a run failed, which skews the threshold cascade.
    return {k: v / successful_runs for k, v in accumulated.items()}
def predict_classification(image: Image.Image) -> dict:
    """
    Classify ``image`` and package the outcome as a UI-friendly dict.

    Returned keys:
        success (bool): False when no image was supplied or inference raised.
            True otherwise, including the fallback response returned when no
            model is loaded.
        class (str): predicted label; "unspecified" on fallback/failure.
        confidence (float): confidence of ``class``, rounded to 4 decimals.
        all_probs (dict): label -> averaged probability (rounded to 4
            decimals). Empty when no image was supplied; all zeros for the
            labels in HEALTH_CLASSES on fallback/failure.
        message (str): human-readable status text.
    """
    if image is None:
        return {
            "success": False,
            "class": "unspecified",
            "confidence": 0.0,
            "all_probs": {},
            "message": "No image provided.",
        }
    image = image.convert("RGB")
    if model is None:
        # Fallback mode (no weights loaded): keep the UI working with a
        # neutral answer instead of an error.
        return {
            "success": True,
            "class": "unspecified",
            "confidence": 0.0,
            "all_probs": {c: 0.0 for c in HEALTH_CLASSES.values()},
            "message": "Model not available. Please upload best.pt to the Space.",
        }
    try:
        avg_probs = multi_run_predict(image, runs=3)
        if not avg_probs:
            raise ValueError("No probabilities returned from model.")
        predicted_class, confidence = health_cascade(avg_probs)
        predicted_class = normalize_class_name(predicted_class)
        confidence = float(confidence)
        logger.info(f"Prediction: {predicted_class} ({confidence:.4f})")
        return {
            "success": True,
            "class": predicted_class,
            "confidence": round(confidence, 4),
            "all_probs": {k: round(v, 4) for k, v in avg_probs.items()},
            "message": "Classification successful.",
        }
    except Exception as e:
        logger.error(f"Prediction error: {e}\n{traceback.format_exc()}")
        # Inference failed, so report it as a failure rather than a success.
        return {
            "success": False,
            "class": "unspecified",
            "confidence": 0.0,
            "all_probs": {c: 0.0 for c in HEALTH_CLASSES.values()},
            "message": f"Prediction failed: {str(e)}",
        }
def _escape_html(s: str) -> str:
return (
str(s)
.replace("&", "&​amp;")
.replace("<", "&​lt;")
.replace(">", "&​gt;")
)
def _to_pil(input_image) -> Image.Image:
    """Convert a Gradio image payload (numpy array, PIL image or array-like) to PIL."""
    if isinstance(input_image, Image.Image):
        return input_image
    # Gradio sends numpy arrays for type="numpy"; cast to uint8 so PIL accepts them.
    return Image.fromarray(np.asarray(input_image).astype(np.uint8))


def _annotate_image(pil_image: Image.Image, cls_name: str, confidence: float) -> Image.Image:
    """Return an RGB copy of ``pil_image`` with a coloured label bar along the bottom."""
    img_display = pil_image.convert("RGB")  # convert() always returns a new image
    w, h = img_display.size
    draw = ImageDraw.Draw(img_display)
    bar_h = max(50, h // 8)
    draw.rectangle([0, h - bar_h, w, h], fill=CLASS_COLORS.get(cls_name, "#888888"))
    label = f"{cls_name.upper()} {confidence * 100:.1f}%"
    try:
        font = ImageFont.truetype(
            "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
            max(14, bar_h // 2),
        )
    except Exception:
        # Best effort: the DejaVu font may be missing, so use PIL's built-in font.
        font = ImageFont.load_default()
    # textbbox at the origin can report a non-zero left/top offset (the font
    # bearing). Subtract it so the glyphs themselves are centred in the bar.
    left, top, right, bottom = draw.textbbox((0, 0), label, font=font)
    text_x = (w - (right - left)) // 2 - left
    text_y = h - bar_h + (bar_h - (bottom - top)) // 2 - top
    draw.text((text_x, text_y), label, fill="white", font=font)
    return img_display


def _render_result_html(cls_name: str, confidence: float, all_probs: dict, message: str) -> str:
    """Build the coloured, monospace HTML summary of a classification result."""
    emoji = {"healthy": "✅", "unhealthy": "❌", "unspecified": "⚠️"}.get(cls_name, "🔍")
    lines = [
        f"{emoji} Predicted Class : {cls_name.upper()}",
        f"📊 Confidence : {confidence * 100:.2f}%",
        "",
        "── All Class Probabilities ──",
    ]
    for c, p in sorted(all_probs.items(), key=lambda x: x[1], reverse=True):
        try:
            p = float(p)
        except (TypeError, ValueError):
            p = 0.0
        bar = "█" * int(max(0.0, min(1.0, p)) * 20)
        lines.append(f" {str(c):<14} {p * 100:5.1f}% {bar}")
    lines += ["", f"ℹ️ {message}"]
    text_color = CLASS_COLORS.get(cls_name, "#ffffff")
    safe_lines = "<br>".join(_escape_html(line) for line in lines)
    # pre-wrap keeps the padded columns ({c:<14}) aligned in the monospace
    # font; "normal" would collapse them. No newlines surround the content,
    # because pre-wrap would render them as blank lines.
    return (
        '<div style="'
        "font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "
        "'Liberation Mono', 'Courier New', monospace; "
        "white-space: pre-wrap; "
        "line-height: 1.35; "
        f"color: {text_color};"
        f'">{safe_lines}</div>'
    )


def predict_on_health(input_image):
    """
    Gradio prediction function.

    Args:
        input_image: image from the Gradio component (numpy array by default),
            a PIL image, or None when the input was cleared.

    Returns:
        (annotated PIL image, HTML string coloured by predicted class).
    """
    if input_image is None:
        blank = Image.new("RGB", (400, 200), color="#1a1a2e")
        draw = ImageDraw.Draw(blank)
        draw.text((80, 90), "Please upload an image.", fill="white")
        return blank, "<div style='color:#fff;'>No image uploaded.</div>"
    pil_image = _to_pil(input_image)
    result = predict_classification(pil_image)
    cls_name = normalize_class_name(result["class"])
    confidence = float(result["confidence"])
    all_probs = result.get("all_probs", {}) or {}
    message = result.get("message", "")
    return (
        _annotate_image(pil_image, cls_name, confidence),
        _render_result_html(cls_name, confidence, all_probs, message),
    )
# Gradio UI: image input and Classify button on the left; annotated image and
# HTML summary on the right. Both the button click and any change to the input
# image run predict_on_health.
with gr.Blocks(title="Oryctes Health Classifier") as demo:
    # Header. The separators use plain &nbsp; entities. The earlier text had an
    # invisible zero-width space inside each entity, so browsers displayed it
    # literally.
    gr.HTML(
        """
        <div style="text-align:center; padding:16px 0;">
            <h1>🌴 Oryctes Health Classifier</h1>
            <p>Upload an image to classify it as
                <b style="color:#2ecc71">Healthy</b>,
                <b style="color:#e74c3c">Unhealthy</b>, or
                <b style="color:#f39c12">Unspecified</b>.
            </p>
            <p style="color:#888; font-size:13px;">
                Model: YOLOv8n-cls &nbsp;·&nbsp;
                3 classes &nbsp;·&nbsp;
                Multi-run averaging
            </p>
        </div>
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="📷 Upload Image",
                type="numpy",
                height=350,
            )
            classify_btn = gr.Button(
                value="🔍 Classify",
                variant="primary",
            )
        with gr.Column(scale=1):
            output_image = gr.Image(
                label="🖼️ Image Classification Result",
                type="pil",
                height=350,
            )
            output_text = gr.HTML(label="📋 Image Classification Text")
    gr.HTML(
        """
        <hr>
        <p style="text-align:center; color:#aaa; font-size:12px;">
            Powered by YOLOv8 · Gradio |
            cullamatmf123/cocoscanclassification
        </p>
        """
    )
    classify_btn.click(
        fn=predict_on_health,
        inputs=input_image,
        outputs=[output_image, output_text],
    )
    # Also classify automatically whenever the input image changes.
    input_image.change(
        fn=predict_on_health,
        inputs=input_image,
        outputs=[output_image, output_text],
    )
if __name__ == "__main__":
    # Route events through Gradio's queue, then serve on all network
    # interfaces at port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    queued_app = demo.queue()
    queued_app.launch(**launch_options)