Cascade AI
Добавлено подробное логирование процессов обучения и предсказания
7056451
import gradio as gr
import torch
import numpy as np
import cv2
import os
import shutil
import time
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
import logging
from train import prepare_dataset, load_images_and_labels, train_and_evaluate_model
# Настройка логирования
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('tomato_disease_classifier')
# Список классов болезней
DISEASE_CLASSES = [
'Tomato___Bacterial_spot',
'Tomato___Early_blight',
'Tomato___Late_blight',
'Tomato___Leaf_Mold',
'Tomato___Septoria_leaf_spot',
'Tomato___Spider_mites Two-spotted_spider_mite',
'Tomato___Target_Spot',
'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
'Tomato___Tomato_mosaic_virus',
'Tomato___healthy'
]
def load_model():
"""Загрузка обученной модели"""
try:
model_path = '/tmp/tomato_disease_classifier.pth'
logger.info(f"Попытка загрузки модели из {model_path}")
# Если модель не существует, обучаем
if not torch.os.path.exists(model_path):
logger.warning(f"Модель не найдена в {model_path}, начинаем обучение")
# Копируем архив во временную директорию
dataset_path = '/tmp/tomato_dataset.zip'
if not os.path.exists(dataset_path):
logger.info(f"Копируем архив датасета в {dataset_path}")
shutil.copy('tomato_dataset.zip', dataset_path)
# Подготовка датасета
logger.info("Подготовка датасета...")
prepare_dataset(dataset_path)
logger.info("Загрузка изображений и меток...")
X, y = load_images_and_labels()
logger.info(f"Загружено {len(X)} изображений с {len(set(y))} классами")
logger.info("Начало обучения модели...")
train_and_evaluate_model(X, y)
logger.info("Модель успешно обучена и сохранена")
# Загрузка модели
logger.info(f"Загрузка модели из {model_path}")
model_data = torch.load(model_path)
# Создание scaler
scaler = StandardScaler()
scaler.mean_ = model_data['mean']
scaler.scale_ = model_data['std']
classifier = model_data['classifier']
logger.info("Модель успешно загружена")
return scaler, classifier
except Exception as e:
logger.error(f"Ошибка загрузки модели: {e}")
return None, None
def preprocess_image(image):
"""Подготовка изображения для предсказания"""
if image is None:
return None
# Resize и flatten
img_resized = cv2.resize(image, (64, 64))
img_flattened = img_resized.flatten()
return img_flattened
def predict_disease(image):
"""Предсказание болезни томата"""
if image is None:
logger.warning("Получено пустое изображение")
return "Пожалуйста, загрузите изображение"
logger.info("Начало процесса предсказания")
# Загрузка модели
logger.info("Загрузка модели для предсказания...")
scaler, classifier = load_model()
if scaler is None or classifier is None:
logger.error("Не удалось загрузить модель для предсказания")
return "Ошибка загрузки модели. Возможно, нужно сначала обучить модель."
# Предобработка изображения
logger.info("Предобработка изображения...")
processed_image = preprocess_image(image)
if processed_image is None:
logger.error("Не удалось обработать изображение")
return "Не удалось обработать изображение"
# Масштабирование
logger.info("Масштабирование изображения...")
processed_image = scaler.transform([processed_image])
# Предсказание
logger.info("Выполнение предсказания...")
prediction = classifier.predict(processed_image)
probabilities = classifier.predict_proba(processed_image)[0]
# Формирование результата
result = f"Обнаружено: {prediction[0]}\n\n"
result += "Вероятности:\n"
# Логируем результаты
logger.info(f"Результат предсказания: {prediction[0]}")
# Добавляем вероятности в результат
for disease, prob in zip(DISEASE_CLASSES, probabilities):
result += f"{disease}: {prob*100:.2f}%\n"
if prob > 0.1: # Логируем только значимые вероятности
logger.info(f" - {disease}: {prob*100:.2f}%")
logger.info("Предсказание успешно завершено")
return result
# Создание Gradio интерфейса
iface = gr.Interface(
fn=predict_disease,
inputs=gr.Image(type="numpy", label="Загрузите изображение листа томата"),
outputs=gr.Textbox(label="Результат диагностики"),
title="🍅Диагностика болезней томатов",
description="Загрузите изображение листа томата для определения заболевания"
)
# Функция для принудительного обучения модели
def train_model_if_needed(force=False):
"""Обучение модели, если она не существует или если force=True"""
model_path = '/tmp/tomato_disease_classifier.pth'
# Если модель не существует или force=True, обучаем
if force or not torch.os.path.exists(model_path):
logger.info("Начинаем обучение модели...")
start_time = time.time()
# Копируем архив во временную директорию
dataset_path = '/tmp/tomato_dataset.zip'
if not os.path.exists(dataset_path):
logger.info(f"Копируем архив датасета в {dataset_path}")
shutil.copy('tomato_dataset.zip', dataset_path)
# Подготовка датасета
logger.info("Подготовка датасета...")
prepare_dataset(dataset_path)
# Загрузка изображений и меток
logger.info("Загрузка изображений и меток...")
X, y = load_images_and_labels()
logger.info(f"Загружено {len(X)} изображений с {len(set(y))} классами")
# Обучение модели
logger.info("Начало обучения модели...")
train_and_evaluate_model(X, y)
# Завершение обучения
training_time = time.time() - start_time
logger.info(f"Обучение модели завершено за {training_time:.2f} секунд!")
else:
logger.info("Модель уже существует, пропускаем обучение.")
# Запуск приложения
if __name__ == "__main__":
# Принудительное обучение модели при запуске
train_model_if_needed(force=True)
# Запуск Gradio интерфейса
iface.launch(server_name="0.0.0.0", server_port=7860)