"""SerraSafe Guardian: Gradio front-end for a 12-class greenhouse pest classifier.

An EfficientNetV2-B0 backbone with a small softmax head classifies uploaded
images, webcam captures and sampled video frames. When a prediction is
confident enough, an alert e-mail is sent in the background.

SMTP settings are read from environment variables (see ``send_email``).
No credentials are stored in this file.
"""

import os
import smtplib
import ssl
import tempfile
import threading
import time
import traceback
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText

import cv2
import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image
from tensorflow.keras import layers, regularizers
from tensorflow.keras.applications.efficientnet_v2 import preprocess_input
from tensorflow.keras.utils import img_to_array

# Create base model (frozen feature extractor)
input_shape = (224, 224, 3)
base_model = tf.keras.applications.efficientnet_v2.EfficientNetV2B0(include_top=False)
base_model.trainable = False

# Create Functional model: backbone -> global average pooling -> 12-way softmax
inputs = layers.Input(shape=input_shape, name="input_layer")
x = base_model(inputs, training=False)
x = layers.GlobalAveragePooling2D(name="pooling_layer")(x)
x = layers.Dense(12, kernel_regularizer=regularizers.l2(0.001))(x)
# Output explicitly kept in float32 (relevant if a mixed-precision policy is active).
outputs = layers.Activation("softmax", dtype=tf.float32, name="softmax_float32")(x)
model = tf.keras.Model(inputs, outputs)

# Compile the model. Only predict() is used below, which does not require it.
model.compile(
    loss="categorical_crossentropy",
    optimizer=tf.keras.optimizers.Adam(),
    metrics=["accuracy"],
)

# Load trained weights (file is expected in the working directory).
model.load_weights("pest_classif.weights.h5")

# Class labels, indexed by the argmax of the model output.
# NOTE(review): "Verre de terre" looks like a misspelling of "Ver de terre"
# (earthworm); left unchanged because it is user-visible output.
CLASS_LABELS = [
    "Fourmis",
    "Abeilles",
    "Scarabe",
    "Chenille",
    "Verre de terre",
    "Perce-oreille",
    "Criquet",
    "Papillon de nuit",
    "Limace",
    "Escargot",
    "Guêpes",
    "Charançon",
]

# Minimum top-class probability that triggers an e-mail alert.
ALERT_CONFIDENCE_THRESHOLD = 0.85

# Minimum delay (seconds) between two alerts for the same class. Without
# this, a video with many confident frames sends one e-mail per analysed
# frame. Set SERRASAFE_ALERT_COOLDOWN=0 to alert on every detection.
try:
    ALERT_COOLDOWN_SECONDS = float(os.environ.get("SERRASAFE_ALERT_COOLDOWN", "300"))
except ValueError:
    ALERT_COOLDOWN_SECONDS = 300.0

# Only every N-th video frame is classified; the others are copied unchanged.
VIDEO_FRAME_STRIDE = 5

# Frame rate used when the video container reports none (0 or NaN).
_DEFAULT_FPS = 25.0

# Timestamp (time.monotonic) of the last alert sent per class label.
_last_alert_at = {}
# Guards _last_alert_at; Gradio may run handlers from several worker threads.
_alert_lock = threading.Lock()


def _to_rgb_uint8(frame):
    """Return *frame* as a C-contiguous (H, W, 3) uint8 numpy array.

    Accepts a PIL image (any mode) or a numpy array that is grayscale
    (H, W) / (H, W, 1), RGB (H, W, 3) or RGBA (H, W, 4). Arrays are assumed
    to already be in RGB channel order. Non-uint8 arrays are cast with
    ``astype`` (values are truncated, not rescaled).

    Raises:
        ValueError: if the array has any other shape.
    """
    if isinstance(frame, Image.Image):
        # Handles L, P, RGBA, CMYK, ... modes uniformly.
        frame = frame.convert("RGB")
    frame = np.asarray(frame)

    if frame.ndim == 2:
        frame = np.stack([frame] * 3, axis=-1)
    elif frame.ndim == 3 and frame.shape[2] == 1:
        frame = np.repeat(frame, 3, axis=2)
    elif frame.ndim == 3 and frame.shape[2] == 4:
        frame = frame[..., :3]  # drop alpha
    elif frame.ndim != 3 or frame.shape[2] != 3:
        raise ValueError(f"Unsupported image shape: {frame.shape}")

    if frame.dtype != np.uint8:
        frame = frame.astype(np.uint8)
    # Slicing off alpha yields a non-contiguous view; OpenCV wants contiguous data.
    return np.ascontiguousarray(frame)


def preprocess_image(frame):
    """Convert an RGB image into a model-ready batch of one.

    Args:
        frame: PIL image or numpy array in RGB order (see ``_to_rgb_uint8``
            for the accepted shapes).

    Returns:
        float32 array of shape (1, 224, 224, 3) after EfficientNetV2's
        ``preprocess_input``.
    """
    rgb_frame = _to_rgb_uint8(frame)
    frame_resized = cv2.resize(rgb_frame, (224, 224))
    img_array = img_to_array(frame_resized)
    img_array = preprocess_input(img_array)
    return np.expand_dims(img_array, axis=0)


def send_email(predicted_class, confidence):
    """Send a pest-alert e-mail over SMTP with STARTTLS (best effort).

    Errors are printed, never raised, so callers (including background
    threads) are not interrupted.

    Configuration (environment variables):
        SERRASAFE_SMTP_USER      sender address, also used as SMTP login
        SERRASAFE_SMTP_PASSWORD  SMTP (app) password; required
        SERRASAFE_ALERT_TO       recipient address
        SERRASAFE_SMTP_HOST      default "smtp.gmail.com"
        SERRASAFE_SMTP_PORT      default 587

    Args:
        predicted_class: label inserted in the message body.
        confidence: probability in [0, 1], shown as a percentage.
    """
    from_address = os.environ.get("SERRASAFE_SMTP_USER", "# Your email address")
    to_address = os.environ.get("SERRASAFE_ALERT_TO", "# Recipient's email address")
    password = os.environ.get("SERRASAFE_SMTP_PASSWORD")
    smtp_host = os.environ.get("SERRASAFE_SMTP_HOST", "smtp.gmail.com")

    if not password:
        print("Erreur d'envoi d'email: SERRASAFE_SMTP_PASSWORD n'est pas défini")
        return

    subject = "Alerte peste Detectée!"

    msg = MIMEMultipart()
    msg["From"] = from_address
    msg["To"] = to_address
    msg["Subject"] = subject
    body = f"""
Vous recevez ce message car nous sommes à {confidence*100:.2f}% certain que votre serre est infestée de {predicted_class}

Veuillez à prendre les mesures qu'il faut pour pallier à cette situation
"""
    msg.attach(MIMEText(body, "plain"))

    try:
        smtp_port = int(os.environ.get("SERRASAFE_SMTP_PORT", "587"))
        # The context manager sends QUIT and closes the socket even on error.
        with smtplib.SMTP(smtp_host, smtp_port, timeout=30) as server:
            # Without an explicit context, starttls() does not verify the
            # server certificate, which would expose the password to MITM.
            server.starttls(context=ssl.create_default_context())
            server.login(from_address, password)
            server.sendmail(from_address, to_address, msg.as_string())
    except (smtplib.SMTPException, OSError, ValueError) as e:
        print(f"Erreur d'envoi d'email: {e}")


def _dispatch_alert(predicted_class, confidence):
    """Send an alert in a daemon thread, unless this class alerted recently.

    The SMTP round-trip happens off the request / video-processing path.
    At most one alert per class is sent every ALERT_COOLDOWN_SECONDS.
    """
    now = time.monotonic()
    with _alert_lock:
        last_sent = _last_alert_at.get(predicted_class)
        if last_sent is not None and now - last_sent < ALERT_COOLDOWN_SECONDS:
            return
        _last_alert_at[predicted_class] = now
    threading.Thread(
        target=send_email,
        args=(predicted_class, confidence),
        daemon=True,
    ).start()


def classify_image(frame):
    """Classify one RGB image and draw the prediction on a copy of it.

    If the top-class confidence exceeds ALERT_CONFIDENCE_THRESHOLD, an alert
    e-mail is dispatched in the background (rate-limited per class).

    Args:
        frame: PIL image or numpy array in RGB order.

    Returns:
        Tuple (annotated RGB uint8 array, predicted label,
        confidence formatted as "xx.xx%").
    """
    rgb_frame = _to_rgb_uint8(frame)
    annotated_frame = rgb_frame.copy()  # never draw on the caller's data

    predictions = model.predict(preprocess_image(rgb_frame), verbose=0)
    predicted_class_idx = int(np.argmax(predictions[0]))
    confidence = float(predictions[0][predicted_class_idx])
    predicted_label = CLASS_LABELS[predicted_class_idx]

    # NOTE(review): cv2's Hershey fonts only cover ASCII, so accented labels
    # (e.g. "Guêpes", "Charançon") are likely rendered with placeholder glyphs.
    label_text = f"Classe: {predicted_label} ({confidence*100:.1f}%)"
    cv2.putText(
        annotated_frame,
        label_text,
        (10, 30),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.8,
        (0, 255, 0),
        2,
    )

    if confidence > ALERT_CONFIDENCE_THRESHOLD:
        _dispatch_alert(predicted_label, confidence)

    return annotated_frame, predicted_label, f"{confidence*100:.2f}%"


def predict_image(image):
    """Gradio handler for the image-upload tab (also used by the webcam tab).

    Args:
        image: PIL image (Gradio ``type="pil"``) or numpy array, or None.

    Returns:
        (annotated image, label, confidence) on success, or
        (None, error message, "N/A") on failure. Never raises.
    """
    try:
        print(f"[DEBUG] predict_image called with image type: {type(image)}")
        if image is None:
            print("[ERROR] No image provided")
            return None, "No image provided", "N/A"

        # With type="pil" Gradio passes a PIL image directly; still accept
        # numpy arrays just in case.
        if not isinstance(image, Image.Image):
            print(f"[DEBUG] Converting {type(image)} to PIL Image...")
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image)
            else:
                print(f"[ERROR] Unexpected image type: {type(image)}")
                return None, f"Unexpected type: {type(image)}", "N/A"

        print(f"[DEBUG] Image size: {image.size}, mode: {image.mode}")
        annotated_img, label, conf = classify_image(image)
        print(f"[DEBUG] Classification successful: {label} ({conf})")
        return annotated_img, label, conf
    except Exception as e:
        # UI boundary: report the failure in the interface instead of
        # failing the whole request.
        print(f"[ERROR] Exception in predict_image: {str(e)}")
        traceback.print_exc()
        return None, f"Error: {str(e)}", "N/A"


def predict_video(video_path):
    """Gradio handler: annotate a video, classifying every VIDEO_FRAME_STRIDE-th frame.

    Args:
        video_path: path to the uploaded video file, or None.

    Returns:
        Path to a new .mp4 file (mp4v codec), or None when no video was given
        or it cannot be read/written.
    """
    if video_path is None:
        return None

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"[ERROR] Impossible d'ouvrir la vidéo: {video_path}")
        cap.release()
        return None

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    # Also catches NaN (any comparison with NaN is False). Kept as float so
    # e.g. 29.97 fps is not truncated to 29.
    if not fps > 0:
        fps = _DEFAULT_FPS

    # mkstemp creates the file atomically; tempfile.mktemp is race-prone.
    fd, output_path = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    if not out.isOpened():
        print(f"[ERROR] Impossible de créer la vidéo de sortie: {output_path}")
        cap.release()
        out.release()
        return None

    try:
        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_count % VIDEO_FRAME_STRIDE == 0:
                # OpenCV decodes frames as BGR; the model (and classify_image)
                # expect RGB. Convert back before writing.
                rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                annotated_rgb, _, _ = classify_image(rgb_frame)
                out.write(cv2.cvtColor(annotated_rgb, cv2.COLOR_RGB2BGR))
            else:
                out.write(frame)
            frame_count += 1
    finally:
        cap.release()
        out.release()

    return output_path


def predict_webcam(image):
    """Gradio handler for the webcam tab; same behaviour as predict_image."""
    return predict_image(image)


# Create Gradio Interface
with gr.Blocks(title="SerraSafe - Détection de Pestes") as demo:
    # Class list is derived from CLASS_LABELS so the UI cannot drift from the model.
    gr.Markdown(
        f"""
# 🌱 SerraSafe Guardian - Système de Détection de Pestes

Ce système utilise l'intelligence artificielle pour détecter et classifier les pestes dans votre serre.

**Classes détectées:** {", ".join(CLASS_LABELS)}
"""
    )

    with gr.Tab("📷 Télécharger Image"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Télécharger une image")
                image_button = gr.Button("Analyser l'image", variant="primary")
            with gr.Column():
                image_output = gr.Image(label="Résultat")
                image_label = gr.Textbox(label="Classe détectée")
                image_confidence = gr.Textbox(label="Confiance")
        image_button.click(
            fn=predict_image,
            inputs=image_input,
            outputs=[image_output, image_label, image_confidence],
        )

    with gr.Tab("📹 Télécharger Vidéo"):
        with gr.Row():
            with gr.Column():
                video_input = gr.Video(label="Télécharger une vidéo")
                video_button = gr.Button("Analyser la vidéo", variant="primary")
            with gr.Column():
                video_output = gr.Video(label="Vidéo annotée")
        video_button.click(
            fn=predict_video,
            inputs=video_input,
            outputs=video_output,
        )

    with gr.Tab("🎥 Webcam en Direct"):
        gr.Markdown("### Capturez une image avec votre webcam pour l'analyser")
        with gr.Row():
            with gr.Column():
                webcam_input = gr.Image(
                    sources=["webcam"], type="pil", label="Webcam", streaming=False
                )
                webcam_button = gr.Button("Analyser", variant="primary")
            with gr.Column():
                webcam_output = gr.Image(label="Résultat")
                webcam_label = gr.Textbox(label="Classe détectée")
                webcam_confidence = gr.Textbox(label="Confiance")
        webcam_button.click(
            fn=predict_webcam,
            inputs=webcam_input,
            outputs=[webcam_output, webcam_label, webcam_confidence],
        )


if __name__ == "__main__":
    print("Lancement de l'interface SerraSafe...")
    demo.launch(
        share=False,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
    )