# pest-detection / app.py
# Author: cabrel09 — last commit 62f4bef "Update app.py" (verified)
import gradio as gr
import onnxruntime as ort
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
import base64
import io
# Class names, one per model output. `predict` pairs labels[i] with output
# logit i, so this order must match the label order the model was trained
# with — do not reorder or rename entries.
labels = [
    "Fall Armyworms",
    "Western Corn Rootworms",
    "Colorado Potato Beetles",
    "Thrips",
    "Corn Earworms",
    "Cabbage Loopers",
    "Armyworms",
    "Brown Marmorated Stink Bugs",
    "Tomato Hornworms",
    "Citrus Canker",
    "Aphids",
    "Corn Borers",
    "Fruit Flies",
    "Africanized Honey Bees (Killer Bees)",
    "Spider Mites",
]
# Load the ONNX model once at startup. Any failure is reported on stdout and
# leaves `session` as None, so the UI still starts and `predict` can report
# that the model is unavailable instead of the app crashing on import.
session = None
try:
    model_path = hf_hub_download(
        repo_id="cabrel09/insect-detection-model",
        filename="vit_insects.onnx",
    )
    # Companion file, presumably the model's ONNX external-data (weights)
    # file; it is only fetched into the local cache (its path is unused),
    # presumably so ONNX Runtime can find it next to the .onnx graph — verify.
    hf_hub_download(
        repo_id="cabrel09/insect-detection-model",
        filename="vit_insects.onnx.data",
    )
    session = ort.InferenceSession(model_path)
except Exception as exc:
    print(f"Erreur chargement: {exc}")
def _decode_base64_image(data):
    """Decode a base64 string into a PIL image.

    If ``data`` contains a ``"base64,"`` marker (as in a data URI such as
    ``"data:image/png;base64,..."``), everything up to and including the
    first marker is dropped before decoding.
    """
    if "base64," in data:
        data = data.split("base64,")[1]
    return Image.open(io.BytesIO(base64.b64decode(data)))


def _load_image(img_input):
    """Turn any accepted input form into a PIL image.

    Accepted forms:
    - ``str``: a base64 payload, optionally with a data-URI prefix;
    - ``dict`` with a ``"url"`` key holding such a string;
    - anything else is returned unchanged and assumed to already be a PIL
      image (what ``gr.Image(type="pil")`` passes in).
    """
    if isinstance(img_input, str):
        return _decode_base64_image(img_input)
    if isinstance(img_input, dict) and "url" in img_input:
        return _decode_base64_image(img_input["url"])
    return img_input


def predict(img_input):
    """Classify an image and return a probability for each label.

    Args:
        img_input: a PIL image, a base64 string / data URI, or a dict with a
            ``"url"`` key holding one. ``None`` means "no image".

    Returns:
        A dict mapping each label to its softmax probability (a float), which
        is the format ``gr.Label`` renders. Returns ``None`` when
        ``img_input`` is ``None``.

    Raises:
        gr.Error: if the model failed to load, or if the image could not be
            decoded or classified. Errors go through Gradio's error channel
            because ``gr.Label`` cannot display an ``{"error": <str>}`` dict;
            its values must be numeric confidences.
    """
    if session is None:
        raise gr.Error("Modèle non chargé")
    if img_input is None:
        return None
    try:
        img = _load_image(img_input)
        # Preprocessing: RGB, 224x224, scale to [0, 1], then normalize with
        # mean 0.5 / std 0.5 to [-1, 1]; HWC -> CHW plus a batch axis of 1.
        img = img.convert("RGB").resize((224, 224))
        pixels = np.asarray(img, dtype=np.float32) / 255.0
        pixels = (pixels - 0.5) / 0.5
        batch = np.expand_dims(np.transpose(pixels, (2, 0, 1)), axis=0)
        # Inference: first output, first (only) batch item.
        input_name = session.get_inputs()[0].name
        logits = session.run(None, {input_name: batch})[0][0]
        # Softmax; subtracting the max first avoids overflow in exp().
        probs = np.exp(logits - np.max(logits))
        probs /= probs.sum()
    except Exception as e:
        print(f"Erreur prédiction: {e}")
        raise gr.Error(str(e)) from e
    # zip stops at the shorter sequence, so a label/output count mismatch
    # truncates rather than raising.
    return {label: float(p) for label, p in zip(labels, probs)}
# Web UI. With type="pil", gr.Image hands `predict` a PIL image (or None when
# no image is provided); `predict` can also decode base64 input when called
# directly. gr.Label shows the 5 most probable labels.
demo = gr.Interface(
    title="PlantPatrol Classifier",
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5),
)
demo.launch()