# beeproject / predict.py
# (Uploaded file: "Upload 11 files", commit 610edea)
import os
import json
import librosa
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
import matplotlib.pyplot as plt
from PIL import Image
from googleapiclient.discovery import build
from google.oauth2 import service_account
from googleapiclient.http import MediaIoBaseUpload
from flask import jsonify
import logging
from datetime import datetime
import subprocess
import tempfile
import shutil
from database import update_hive_health_in_db
# Ensure matplotlib runs in non-GUI mode.
# NOTE(review): pyplot is already imported above (L11); modern matplotlib
# still switches the backend here, but calling use("Agg") before importing
# pyplot would be safer — confirm.
import matplotlib
matplotlib.use("Agg")
# Configure logging; `logger` is the module-wide logger used by every helper below.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Google Drive API Setup.
# NOTE(review): SCOPES is defined but never passed to the credentials
# factory below — the service account's default scopes are used; verify
# this is intended.
SCOPES = ['https://www.googleapis.com/auth/drive.file']
# Root Drive folder that receives all category subfolders.
PARENT_FOLDER_ID = os.getenv("GOOGLE_DRIVE_FOLDER_ID")
# NOTE(review): raises TypeError at import time if G_Drive_Credentials is
# unset (json.loads(None)) — deliberate fail-fast? confirm.
service_account_info = json.loads(os.getenv("G_Drive_Credentials"))
credentials = service_account.Credentials.from_service_account_info(service_account_info)
# Initialize Google Drive API client shared by all upload helpers.
drive_service = build("drive", "v3", credentials=credentials)
# Load the bee/no bee model (MobileNet), used by predict_audio.
MODEL_PATH = "./mobilenet_best_model_merged_bee_nobee.keras"
bee_model = load_model(MODEL_PATH)
# Load the queen model (MobileNet), used by predict_queen_audio.
QUEEN_MODEL_PATH = "./mobilenet_best_model_merged_queen_noqueen.keras"
queen_model = load_model(QUEEN_MODEL_PATH)
# Load the mite attack model (MobileNet), used by predict_mite_audio.
MITE_MODEL_PATH = "./mobilenet_best_model_merged_mite_nomite.keras"
mite_model = load_model(MITE_MODEL_PATH)
def check_ffmpeg():
    """Return True when an ffmpeg executable can be located on PATH."""
    ffmpeg_path = shutil.which("ffmpeg")
    return ffmpeg_path is not None
def convert_to_wav(input_path, output_path):
    """Transcode an audio file to 16-bit stereo 44.1 kHz WAV via FFmpeg.

    Args:
        input_path: Path of the source audio file; must exist on disk.
        output_path: Destination path for the generated WAV file.

    Returns:
        output_path on success.

    Raises:
        FileNotFoundError: If input_path does not exist.
        RuntimeError: If FFmpeg is unavailable or exits non-zero.
    """
    try:
        if not os.path.exists(input_path):
            raise FileNotFoundError("Input audio file does not exist")
        if not check_ffmpeg():
            raise RuntimeError("FFmpeg is not installed or not found in PATH")
        command = [
            "ffmpeg",
            "-i", input_path,
            "-acodec", "pcm_s16le",  # 16-bit PCM
            "-ar", "44100",          # 44.1 kHz sample rate
            "-ac", "2",              # stereo
            "-y",                    # overwrite output without prompting
            output_path,
        ]
        completed = subprocess.run(command, capture_output=True, text=True)
        if completed.returncode != 0:
            raise RuntimeError(f"FFmpeg conversion failed: {completed.stderr}")
        logger.info(f"Converted {input_path} to {output_path}")
        return output_path
    except Exception as e:
        logger.error(f"Error converting audio to WAV: {str(e)}")
        raise
def get_or_create_folder(folder_name, parent_id):
    """Look up a Drive folder by name under parent_id, creating it if absent.

    Returns:
        The Drive file ID of the existing or newly created folder.
    """
    query = (
        f"name='{folder_name}' and mimeType='application/vnd.google-apps.folder' "
        f"and '{parent_id}' in parents and trashed=false"
    )
    listing = drive_service.files().list(q=query, spaces='drive', fields='files(id, name)').execute()
    matches = listing.get('files', [])
    if matches:
        return matches[0].get('id')
    # No match: create the folder under the given parent.
    new_folder = drive_service.files().create(
        body={
            'name': folder_name,
            'mimeType': 'application/vnd.google-apps.folder',
            'parents': [parent_id],
        },
        fields='id',
    ).execute()
    return new_folder.get('id')
def upload_to_drive(audio_file, result, user_predict=None):
    """Upload an audio file to Drive folder(s) chosen by the prediction result.

    "not bee" audio goes to the 'not bee' folder; everything else goes to
    'bee', plus a second copy into the user's own diagnosis folder when
    user_predict names one of the known categories.

    Args:
        audio_file: Werkzeug file object (has .filename, .content_type, .seek).
        result: Prediction label; only "not bee" is treated specially.
        user_predict: Optional user-supplied label for a second upload.

    Returns:
        The Drive file ID of the first upload performed.

    Raises:
        Exception: Re-raises any Drive API failure after logging it.
    """
    def _upload(folder_name):
        # One upload of the stream (at its current position) into the
        # named folder; the repeated metadata/media/create/log sequence
        # was previously triplicated inline.
        folder_id = get_or_create_folder(folder_name, PARENT_FOLDER_ID)
        file_metadata = {
            'name': audio_file.filename,
            'parents': [folder_id]
        }
        media = MediaIoBaseUpload(audio_file, mimetype=audio_file.content_type, resumable=True)
        file = drive_service.files().create(
            body=file_metadata,
            media_body=media,
            fields='id'
        ).execute()
        file_id = file.get('id')
        logger.info(f"Uploaded {audio_file.filename} to Google Drive folder '{folder_name}' with file ID: {file_id}")
        return file_id

    try:
        file_ids = []
        if result == "not bee":
            file_ids.append(_upload("not bee"))
        else:
            audio_file.seek(0)
            file_ids.append(_upload("bee"))
            if user_predict:
                user_predict = user_predict.strip().lower()
                valid_folders = {'healthy', 'no queen', 'mite attack', 'chalkbrood'}
                if user_predict in valid_folders:
                    # Rewind so the second upload re-reads the whole stream.
                    audio_file.seek(0)
                    file_ids.append(_upload(user_predict))
                else:
                    logger.warning(f"Ignoring invalid user_predict value: {user_predict}")
        return file_ids[0]
    except Exception as e:
        logger.error(f"Error uploading to Google Drive: {str(e)}")
        raise
def create_mel_spectrogram(audio_segment, sr):
    """Render a 224x224 RGB mel-spectrogram image array from an audio segment.

    Args:
        audio_segment: 1-D array of audio samples.
        sr: Sample rate of the segment.

    Returns:
        A (224, 224, 3) float array scaled to [0, 1], or None on failure.
    """
    temp_image_path = None
    try:
        spectrogram = librosa.feature.melspectrogram(y=audio_segment, sr=sr, n_mels=128)
        spectrogram_db = librosa.power_to_db(spectrogram, ref=np.max)
        plt.figure(figsize=(2, 2), dpi=100)
        plt.axis('off')
        plt.imshow(spectrogram_db, aspect='auto', cmap='magma', origin='lower')
        plt.tight_layout(pad=0)
        # Use a unique temp file per call: the previous fixed
        # "/tmp/temp_spectrogram.png" path collided when concurrent
        # requests rendered spectrograms simultaneously, and hard-coded
        # a POSIX-only directory.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            temp_image_path = tmp.name
        plt.savefig(temp_image_path, bbox_inches='tight', pad_inches=0)
        plt.close()
        img = Image.open(temp_image_path).convert('RGB')
        img = img.resize((224, 224))  # MobileNet input size
        img_array = np.array(img) / 255.0
        return img_array
    except Exception as e:
        logger.error(f"Error creating spectrogram: {e}")
        return None
    finally:
        # Always remove the scratch image, even when rendering fails midway
        # (previously the file leaked on any exception after savefig).
        if temp_image_path and os.path.exists(temp_image_path):
            os.remove(temp_image_path)
def predict_queen_audio(file_path, model):
    """Classify an audio file as 'healthy' vs 'no queen' from 10-second windows.

    Args:
        file_path: Path to a WAV/MP3 file longer than 10 seconds.
        model: Keras binary classifier over spectrogram images.

    Returns:
        "healthy" or "no queen", or a dict with an 'error' key on failure.
    """
    try:
        samples, rate = librosa.load(file_path, sr=None)
        total_seconds = librosa.get_duration(y=samples, sr=rate)
        if total_seconds <= 10:
            return {"error": "Audio file must be longer than 10 seconds"}
        healthy_windows = 0
        window_count = 0
        start = 0
        while start < total_seconds:
            end = min(start + 10, total_seconds)
            # A short trailing remainder is replaced by the final full
            # 10-second slice so every window has the same length.
            if end - start < 10 and start > 0:
                start = max(0, total_seconds - 10)
                end = total_seconds
            window = samples[int(start * rate):int(end * rate)]
            features = create_mel_spectrogram(window, rate)
            if features is not None:
                batch = np.expand_dims(features, axis=0)
                score = model.predict(batch)[0][0]
                # Binary output: high score means healthy/queen present.
                if score >= 0.8:
                    healthy_windows += 1
                window_count += 1
            start += 10
        if window_count == 0:
            return {"error": "No valid segments processed"}
        healthy_share = (healthy_windows / window_count) * 100
        return "healthy" if healthy_share >= 70 else "no queen"
    except Exception as e:
        logger.error(f"Error in queen prediction: {e}")
        return {"error": str(e)}
def predict_mite_audio(file_path, model):
    """Classify an audio file as 'healthy' vs 'mite attack' from 10-second windows.

    Args:
        file_path: Path to a WAV/MP3 file longer than 10 seconds.
        model: Keras binary classifier over spectrogram images.

    Returns:
        "healthy" or "mite attack", or a dict with an 'error' key on failure.
    """
    try:
        samples, rate = librosa.load(file_path, sr=None)
        total_seconds = librosa.get_duration(y=samples, sr=rate)
        if total_seconds <= 10:
            return {"error": "Audio file must be longer than 10 seconds"}
        healthy_windows = 0
        window_count = 0
        start = 0
        while start < total_seconds:
            end = min(start + 10, total_seconds)
            # A short trailing remainder is replaced by the final full
            # 10-second slice so every window has the same length.
            if end - start < 10 and start > 0:
                start = max(0, total_seconds - 10)
                end = total_seconds
            window = samples[int(start * rate):int(end * rate)]
            features = create_mel_spectrogram(window, rate)
            if features is not None:
                batch = np.expand_dims(features, axis=0)
                score = model.predict(batch)[0][0]
                # Binary output: high score means healthy/no mites.
                if score >= 0.8:
                    healthy_windows += 1
                window_count += 1
            start += 10
        if window_count == 0:
            return {"error": "No valid segments processed"}
        healthy_share = (healthy_windows / window_count) * 100
        return "healthy" if healthy_share >= 70 else "mite attack"
    except Exception as e:
        logger.error(f"Error in mite attack prediction: {e}")
        return {"error": str(e)}
def predict_audio(audio_path, request_id):
    """Decide whether an audio file contains bee sounds using 10-second windows.

    Args:
        audio_path: Path to a WAV/MP3 file longer than 10 seconds.
        request_id: Identifier used only for log correlation.

    Returns:
        {"result": "bee" | "not bee" | "try again"} or {"error": message}.
    """
    try:
        samples, rate = librosa.load(audio_path, sr=None)
        if samples is None or rate is None:
            return {"error": "Failed to load audio"}
        total_seconds = librosa.get_duration(y=samples, sr=rate)
        if total_seconds <= 10:
            return {"error": "Audio file must be longer than 10 seconds"}
        bee_windows = 0
        window_count = 0
        start = 0
        while start < total_seconds:
            end = min(start + 10, total_seconds)
            # A short trailing remainder is replaced by the final full
            # 10-second slice so every window has the same length.
            if end - start < 10 and start > 0:
                start = max(0, total_seconds - 10)
                end = total_seconds
            window = samples[int(start * rate):int(end * rate)]
            features = create_mel_spectrogram(window, rate)
            if features is not None:
                batch = np.expand_dims(features, axis=0)
                score = bee_model.predict(batch)[0][0]
                # Low score means "bee" for this model's label ordering.
                if score <= 0.2:
                    bee_windows += 1
                window_count += 1
            start += 10
        if window_count == 0:
            return {"result": "try again"}
        bee_share = (bee_windows / window_count) * 100
        verdict = "bee" if bee_share >= 70 else "not bee"
        logger.info(f"Request {request_id} - Prediction result: {verdict}")
        return {"result": verdict}
    except Exception as e:
        logger.error(f"Request {request_id} - Error during bee prediction: {e}")
        return {"error": str(e)}
def handle_predict(request, save_prediction):
    """Handle the /predict route: validate, classify, upload, and persist.

    Pipeline: save the uploaded audio to a temp file, convert to WAV if
    needed, run the bee/no-bee model, then (for bee audio) the queen and
    mite models, optionally record the user's own diagnosis, upload the
    audio to Google Drive, and persist the prediction.

    Args:
        request: Flask request with an 'audio' file and form fields
            'user_id' (required), plus optional 'hive_id' / 'user_predict'.
        save_prediction: Callable(user_id, filename, result, file_id,
            hive_id, user_predict) that stores the outcome.

    Returns:
        A Flask JSON response, optionally paired with an HTTP status code.
    """
    request_id = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    logger.info(f"Request {request_id} - Starting prediction")
    if 'audio' not in request.files or 'user_id' not in request.form:
        logger.error(f"Request {request_id} - Missing data")
        return jsonify({"error": "Missing data"}), 400
    user_id = request.form['user_id']
    audio_file = request.files['audio']
    hive_id = request.form.get('hive_id')
    user_predict = request.form.get('user_predict')
    if audio_file.filename == '':
        logger.error(f"Request {request_id} - No file selected")
        return jsonify({"error": "No file selected"}), 400
    if not check_ffmpeg():
        logger.error(f"Request {request_id} - FFmpeg not found")
        return jsonify({"error": "Server error: FFmpeg is not available"}), 500
    temp_dir = tempfile.gettempdir()
    original_ext = os.path.splitext(audio_file.filename)[1].lower()
    temp_original_path = os.path.join(temp_dir, f"temp_audio_{request_id}{original_ext}")
    audio_file.save(temp_original_path)
    temp_wav_path = os.path.join(temp_dir, f"temp_audio_{request_id}.wav")

    def _cleanup_temp_files():
        # Remove scratch audio files; temp_wav_path may alias
        # temp_original_path when no conversion was performed.
        if os.path.exists(temp_original_path):
            os.remove(temp_original_path)
        if temp_wav_path != temp_original_path and os.path.exists(temp_wav_path):
            os.remove(temp_wav_path)

    # try/finally guarantees temp files are removed on every exit path —
    # the cleanup was previously duplicated in five places and skipped
    # entirely if an unexpected exception escaped (leaking the files).
    try:
        if original_ext not in ['.wav', '.mp3']:
            try:
                convert_to_wav(temp_original_path, temp_wav_path)
            except Exception as e:
                logger.error(f"Request {request_id} - Audio conversion failed: {str(e)}")
                return jsonify({"error": f"Failed to process audio file: {str(e)}"}), 400
        else:
            # WAV/MP3 are consumed directly by librosa; skip conversion.
            temp_wav_path = temp_original_path
            logger.info(f"Request {request_id} - Using original file (no conversion needed): {temp_wav_path}")
        prediction_result = predict_audio(temp_wav_path, request_id)
        if "error" in prediction_result:
            logger.error(f"Request {request_id} - Prediction failed: {prediction_result['error']}")
            return jsonify({"result": "try again"}), 400
        result = prediction_result["result"].lower()
        if result == "try again":
            logger.info(f"Request {request_id} - Result: try again")
            return jsonify({"result": "try again"})
        if result == "bee":
            queen_result = predict_queen_audio(temp_wav_path, queen_model)
            mite_result = predict_mite_audio(temp_wav_path, mite_model)
            if "error" in queen_result or "error" in mite_result:
                # Success returns are plain strings; only failures are dicts.
                # Guard with isinstance: calling .get on a str raised
                # AttributeError here when only one of the two models failed.
                queen_err = queen_result.get('error', 'Unknown') if isinstance(queen_result, dict) else 'Unknown'
                mite_err = mite_result.get('error', 'Unknown') if isinstance(mite_result, dict) else 'Unknown'
                logger.error(f"Request {request_id} - Queen prediction failed: {queen_err}, Mite prediction failed: {mite_err}")
                return jsonify({"result": "try again"}), 400
            logger.info(f"Request {request_id} - Queen prediction result: {queen_result}, Mite prediction result: {mite_result}")
            # Combine the two binary verdicts into one hive-health label.
            if queen_result == "healthy" and mite_result == "healthy":
                result = "healthy"
            elif queen_result == "no queen" and mite_result == "healthy":
                result = "no queen"
            elif queen_result == "healthy" and mite_result == "mite attack":
                result = "mite attack"
            elif queen_result == "no queen" and mite_result == "mite attack":
                result = "no queen,mite attack"
            else:
                result = "try again"  # Fallback for unexpected cases
                logger.warning(f"Request {request_id} - Unexpected combination: queen={queen_result}, mite={mite_result}")
        if user_predict and hive_id:
            try:
                user_predict = user_predict.strip().lower()
                update_hive_health_in_db(hive_id, user_predict)
                logger.info(f"Request {request_id} - Updated hive {hive_id} health_status to {user_predict}")
            except Exception as e:
                logger.error(f"Request {request_id} - Failed to update hive health status: {str(e)}")
                return jsonify({"error": f"Failed to update hive health status: {str(e)}"}), 400
        file_id = None
        try:
            audio_file.seek(0)
            file_id = upload_to_drive(audio_file, result, user_predict)
        except Exception as e:
            # Best-effort upload: continue with file_id=None on failure.
            logger.error(f"Request {request_id} - Failed to upload to Google Drive: {str(e)}")
        save_prediction(user_id, audio_file.filename, result, file_id, hive_id, user_predict)
        logger.info(f"Request {request_id} - Final result: {result}, file_id: {file_id}")
        return jsonify({"result": result, "file_id": file_id})
    finally:
        _cleanup_temp_files()