# serrasafe / app.py
# danyanderson's picture
# Update app.py
# 89ef1f0 verified
import numpy as np
import tensorflow as tf
import cv2
import smtplib
import tempfile
import threading
from tensorflow.keras.utils import img_to_array
from tensorflow.keras import layers, regularizers
from tensorflow.keras.applications.efficientnet_v2 import preprocess_input
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from PIL import Image
import gradio as gr
# Rebuild the classifier architecture in code so that the saved weights can be
# loaded into it below.
#
# Backbone: EfficientNetV2-B0 without its classification head
# (include_top=False). It is frozen, and it is also called with training=False
# further down.
input_shape = (224, 224, 3)
base_model = tf.keras.applications.efficientnet_v2.EfficientNetV2B0(include_top=False)
base_model.trainable = False
# Functional head: backbone -> global average pooling -> 12-unit dense layer
# with L2 regularisation -> softmax. 12 matches the number of entries in
# CLASS_LABELS.
inputs = layers.Input(shape=input_shape, name="input_layer")
x = base_model(inputs, training=False)
x = layers.GlobalAveragePooling2D(name="pooling_layer")(x)
x = layers.Dense(12, kernel_regularizer=regularizers.l2(0.001))(x)
# The softmax is pinned to float32.
# NOTE(review): presumably a leftover from mixed-precision training - confirm.
outputs = layers.Activation("softmax", dtype=tf.float32, name="softmax_float32")(x)
model = tf.keras.Model(inputs, outputs)
# Compile with a loss, optimizer and metric. In the code visible here the
# model is only used through model.predict().
model.compile(
loss="categorical_crossentropy",
optimizer=tf.keras.optimizers.Adam(),
metrics=["accuracy"]
)
# Load the weights from a path given relative to the current working directory.
model.load_weights('pest_classif.weights.h5')
# Display names (French) for the 12 classes. classify_image() indexes this list
# with the argmax of the model's prediction vector, so the order here has to
# match the order of the model outputs.
# NOTE(review): 'Verre de terre' looks like a misspelling of 'Ver de terre'
# (earthworm); left unchanged because the label is shown to users.
CLASS_LABELS = [
    'Fourmis',
    'Abeilles',
    'Scarabe',
    'Chenille',
    'Verre de terre',
    'Perce-oreille',
    'Criquet',
    'Papillon de nuit',
    'Limace',
    'Escargot',
    'Guêpes',
    'Charançon',
]
def preprocess_image(frame):
    """Turn an image or video frame into a single-item batch for the model.

    The input is normalised to three colour channels before resizing, so
    greyscale and RGBA images (common for PNG uploads) no longer reach the
    3-channel model with the wrong shape.

    Args:
        frame: A PIL image in any mode, or a NumPy array shaped (H, W),
            (H, W, 1), (H, W, 3) or (H, W, 4). The channel order of colour
            arrays is used as given; no RGB/BGR swap happens here.

    Returns:
        A NumPy array of shape (1, 224, 224, 3) that has been passed through
        EfficientNetV2's ``preprocess_input``.

    Raises:
        ValueError: If the array is not one of the supported layouts.
    """
    if isinstance(frame, Image.Image):
        # Let PIL deal with palette, greyscale and alpha modes in one step.
        frame = np.array(frame.convert("RGB"))
    else:
        frame = np.asarray(frame)

    if frame.ndim == 2:
        # Greyscale: replicate the single plane into three channels.
        frame = np.stack((frame,) * 3, axis=-1)
    elif frame.ndim == 3 and frame.shape[2] == 1:
        frame = np.repeat(frame, 3, axis=2)
    elif frame.ndim == 3 and frame.shape[2] == 4:
        # Drop the alpha channel; copy so OpenCV gets a contiguous buffer.
        frame = np.ascontiguousarray(frame[:, :, :3])
    elif frame.ndim != 3 or frame.shape[2] != 3:
        raise ValueError(f"Unsupported image shape: {frame.shape}")

    if frame.dtype != np.uint8:
        frame = frame.astype(np.uint8)

    frame_resized = cv2.resize(frame, (224, 224))
    img_array = preprocess_input(img_to_array(frame_resized))
    # The model expects a leading batch dimension.
    return np.expand_dims(img_array, axis=0)
def send_email(predicted_class, confidence):
    """Send a best-effort e-mail alert about a detected pest.

    Credentials and addresses are read from the environment instead of being
    embedded in the source:

    * ``SERRASAFE_SMTP_USER``       - sender address / Gmail login
    * ``SERRASAFE_SMTP_PASSWORD``   - SMTP (app) password for that account
    * ``SERRASAFE_ALERT_RECIPIENT`` - address that receives the alert

    If any of them is missing the alert is skipped with a console message.
    Any failure while sending is printed and swallowed, so a mail problem
    never breaks a prediction.

    NOTE: an SMTP password used to be hard-coded in this function; any secret
    that was committed that way should be revoked.

    Args:
        predicted_class: Label of the detected pest, inserted into the body.
        confidence: Prediction confidence in the range 0-1; reported as a
            percentage.

    Returns:
        None.
    """
    import os  # local import: keeps this block self-contained

    from_address = os.environ.get("SERRASAFE_SMTP_USER")
    password = os.environ.get("SERRASAFE_SMTP_PASSWORD")
    to_address = os.environ.get("SERRASAFE_ALERT_RECIPIENT")
    if not (from_address and password and to_address):
        print(
            "Envoi d'email ignoré: SERRASAFE_SMTP_USER, SERRASAFE_SMTP_PASSWORD "
            "et SERRASAFE_ALERT_RECIPIENT doivent être définis"
        )
        return

    msg = MIMEMultipart()
    msg['From'] = from_address
    msg['To'] = to_address
    msg['Subject'] = "Alerte peste Detectée!"

    body = (
        "\n"
        f"Vous recevez ce message car nous sommes à {confidence*100:.2f}% certain que votre serre est infestée de {predicted_class}\n"
        "Veuillez à prendre les mesures qu'il faut pour pallier à cette situation\n"
    )
    msg.attach(MIMEText(body, 'plain'))

    try:
        # The context manager closes the connection even if starttls/login
        # fails; the timeout keeps a stalled server from blocking a request.
        with smtplib.SMTP('smtp.gmail.com', 587, timeout=10) as server:
            server.starttls()
            server.login(from_address, password)
            server.sendmail(from_address, to_address, msg.as_string())
    except Exception as e:
        # Deliberately broad: alerting is best-effort.
        print(f"Erreur d'envoi d'email: {e}")
def classify_image(frame):
    """Classify one image and draw the prediction onto a copy of it.

    Args:
        frame: A NumPy array or any object ``np.array`` can convert (for
            example a PIL image).

    Returns:
        A tuple ``(annotated_frame, predicted_label, confidence_text)``: the
        annotated NumPy copy of the input, the winning entry of
        ``CLASS_LABELS``, and the confidence formatted as a percentage with
        two decimals.

    Side effects:
        Calls ``send_email`` when the confidence exceeds 0.85.
    """
    # Draw on a private copy so the caller's image is left untouched.
    if isinstance(frame, np.ndarray):
        canvas = frame.copy()
    else:
        canvas = np.array(frame).copy()

    batch = preprocess_image(frame)
    probabilities = model.predict(batch, verbose=0)
    best_idx = np.argmax(probabilities, axis=1)[0]
    confidence = probabilities[0][best_idx]
    predicted_label = CLASS_LABELS[best_idx]

    # NOTE(review): OpenCV's Hershey fonts cover ASCII only, so accented
    # labels (e.g. 'Guêpes') will presumably not render correctly - confirm.
    caption = f"Classe: {predicted_label} ({confidence*100:.1f}%)"
    cv2.putText(
        canvas,
        caption,
        (10, 30),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.8,
        (0, 255, 0),
        2,
    )

    # Only high-confidence detections trigger an alert.
    if confidence > 0.85:
        send_email(predicted_label, confidence)

    confidence_text = f"{confidence*100:.2f}%"
    return canvas, predicted_label, confidence_text
def predict_image(image):
    """Gradio handler: validate an image, classify it and report the result.

    Args:
        image: A PIL image (what Gradio supplies with ``type="pil"``), a
            NumPy array, or ``None`` when nothing was provided.

    Returns:
        A tuple ``(annotated_image, label, confidence)``. On missing input,
        an unsupported type, or any exception, the first item is ``None``,
        the second is an explanatory message and the third is ``"N/A"``.
    """
    import traceback

    try:
        print(f"[DEBUG] predict_image called with image type: {type(image)}")

        if image is None:
            print("[ERROR] No image provided")
            return None, "No image provided", "N/A"

        # Gradio normally delivers a PIL image already; NumPy arrays are
        # accepted as a fallback and anything else is rejected.
        if not isinstance(image, Image.Image):
            print(f"[DEBUG] Converting {type(image)} to PIL Image...")
            if not isinstance(image, np.ndarray):
                print(f"[ERROR] Unexpected image type: {type(image)}")
                return None, f"Unexpected type: {type(image)}", "N/A"
            image = Image.fromarray(image)

        print(f"[DEBUG] Image size: {image.size}, mode: {image.mode}")
        annotated_img, label, conf = classify_image(image)
        print(f"[DEBUG] Classification successful: {label} ({conf})")
        return annotated_img, label, conf
    except Exception as e:
        # UI boundary: report the failure in the output fields instead of
        # letting the exception escape, and keep the traceback in the logs.
        print(f"[ERROR] Exception in predict_image: {str(e)}")
        traceback.print_exc()
        return None, f"Error: {str(e)}", "N/A"
def predict_video(video_path):
    """Write an annotated copy of a video and return the new file's path.

    Every fifth frame is classified and has its prediction drawn on it; the
    other frames are copied through unchanged.

    Args:
        video_path: Path of the input video, or ``None``.

    Returns:
        Path of the annotated MP4 file, or ``None`` when no video was given
        or the input could not be opened.
    """
    if video_path is None:
        return None

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        # Unreadable or unsupported file: nothing to annotate.
        cap.release()
        return None

    out = None
    try:
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Keep fractional rates such as 29.97 and fall back to a fixed rate
        # when the container reports none (0, negative or NaN).
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = 25.0

        # mktemp() only invents a name and is open to races; create the file
        # for real and hand its path to OpenCV once it has been closed.
        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp:
            output_path = tmp.name

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_count % 5 == 0:
                # OpenCV decodes frames as BGR, whereas the image and webcam
                # paths feed RGB (PIL) data to the model; convert for
                # classification and back again for the writer.
                rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                annotated_rgb, _, _ = classify_image(rgb_frame)
                frame = cv2.cvtColor(annotated_rgb, cv2.COLOR_RGB2BGR)
            out.write(frame)
            frame_count += 1
    finally:
        # Release both handles even if classification raises mid-video.
        cap.release()
        if out is not None:
            out.release()

    return output_path
def predict_webcam(image):
    """Gradio handler for the webcam tab.

    A webcam capture is processed exactly like an uploaded image, so this
    delegates to ``predict_image`` and returns its result unchanged.
    """
    result = predict_image(image)
    return result
# Gradio user interface: one tab per input source (image upload, video upload,
# webcam snapshot). Each tab has its inputs in the left column and its results
# in the right column.
with gr.Blocks(title="SerraSafe - Détection de Pestes") as demo:
    # Page header and the list of supported classes.
    gr.Markdown(
        "\n"
        "# 🌱 SerraSafe Guardian - Système de Détection de Pestes\n"
        "Ce système utilise l'intelligence artificielle pour détecter et classifier les pestes dans votre serre.\n"
        "**Classes détectées:** Fourmis, Abeilles, Scarabe, Chenille, Verre de terre, Perce-oreille, Criquet, Papillon de nuit, Limace, Escargot, Guêpes, Charançon\n"
    )

    # --- Tab 1: still image upload ---------------------------------------
    with gr.Tab("📷 Télécharger Image"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Télécharger une image")
                image_button = gr.Button("Analyser l'image", variant="primary")
            with gr.Column():
                image_output = gr.Image(label="Résultat")
                image_label = gr.Textbox(label="Classe détectée")
                image_confidence = gr.Textbox(label="Confiance")
        image_button.click(
            fn=predict_image,
            inputs=image_input,
            outputs=[image_output, image_label, image_confidence],
        )

    # --- Tab 2: video upload ---------------------------------------------
    with gr.Tab("📹 Télécharger Vidéo"):
        with gr.Row():
            with gr.Column():
                video_input = gr.Video(label="Télécharger une vidéo")
                video_button = gr.Button("Analyser la vidéo", variant="primary")
            with gr.Column():
                video_output = gr.Video(label="Vidéo annotée")
        video_button.click(
            fn=predict_video,
            inputs=video_input,
            outputs=video_output,
        )

    # --- Tab 3: webcam snapshot (single capture, not a live stream) -------
    with gr.Tab("🎥 Webcam en Direct"):
        gr.Markdown("### Capturez une image avec votre webcam pour l'analyser")
        with gr.Row():
            with gr.Column():
                webcam_input = gr.Image(
                    sources=["webcam"],
                    type="pil",
                    label="Webcam",
                    streaming=False,
                )
                webcam_button = gr.Button("Analyser", variant="primary")
            with gr.Column():
                webcam_output = gr.Image(label="Résultat")
                webcam_label = gr.Textbox(label="Classe détectée")
                webcam_confidence = gr.Textbox(label="Confiance")
        webcam_button.click(
            fn=predict_webcam,
            inputs=webcam_input,
            outputs=[webcam_output, webcam_label, webcam_confidence],
        )
if __name__ == "__main__":
    # Script entry point: announce start-up on the console, then serve the UI.
    print("Lancement de l'interface SerraSafe...")
    launch_options = {
        "share": False,
        # 0.0.0.0 listens on every network interface, not just localhost.
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }
    demo.launch(**launch_options)