Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| import cv2 | |
| import os | |
| import shutil | |
| import time | |
| from sklearn.preprocessing import StandardScaler | |
| from sklearn.svm import SVC | |
| import logging | |
| from train import prepare_dataset, load_images_and_labels, train_and_evaluate_model | |
| # Настройка логирования | |
| logging.basicConfig( | |
| level=logging.INFO, | |
| format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' | |
| ) | |
| logger = logging.getLogger('tomato_disease_classifier') | |
| # Список классов болезней | |
| DISEASE_CLASSES = [ | |
| 'Tomato___Bacterial_spot', | |
| 'Tomato___Early_blight', | |
| 'Tomato___Late_blight', | |
| 'Tomato___Leaf_Mold', | |
| 'Tomato___Septoria_leaf_spot', | |
| 'Tomato___Spider_mites Two-spotted_spider_mite', | |
| 'Tomato___Target_Spot', | |
| 'Tomato___Tomato_Yellow_Leaf_Curl_Virus', | |
| 'Tomato___Tomato_mosaic_virus', | |
| 'Tomato___healthy' | |
| ] | |
| def load_model(): | |
| """Загрузка обученной модели""" | |
| try: | |
| model_path = '/tmp/tomato_disease_classifier.pth' | |
| logger.info(f"Попытка загрузки модели из {model_path}") | |
| # Если модель не существует, обучаем | |
| if not torch.os.path.exists(model_path): | |
| logger.warning(f"Модель не найдена в {model_path}, начинаем обучение") | |
| # Копируем архив во временную директорию | |
| dataset_path = '/tmp/tomato_dataset.zip' | |
| if not os.path.exists(dataset_path): | |
| logger.info(f"Копируем архив датасета в {dataset_path}") | |
| shutil.copy('tomato_dataset.zip', dataset_path) | |
| # Подготовка датасета | |
| logger.info("Подготовка датасета...") | |
| prepare_dataset(dataset_path) | |
| logger.info("Загрузка изображений и меток...") | |
| X, y = load_images_and_labels() | |
| logger.info(f"Загружено {len(X)} изображений с {len(set(y))} классами") | |
| logger.info("Начало обучения модели...") | |
| train_and_evaluate_model(X, y) | |
| logger.info("Модель успешно обучена и сохранена") | |
| # Загрузка модели | |
| logger.info(f"Загрузка модели из {model_path}") | |
| model_data = torch.load(model_path) | |
| # Создание scaler | |
| scaler = StandardScaler() | |
| scaler.mean_ = model_data['mean'] | |
| scaler.scale_ = model_data['std'] | |
| classifier = model_data['classifier'] | |
| logger.info("Модель успешно загружена") | |
| return scaler, classifier | |
| except Exception as e: | |
| logger.error(f"Ошибка загрузки модели: {e}") | |
| return None, None | |
| def preprocess_image(image): | |
| """Подготовка изображения для предсказания""" | |
| if image is None: | |
| return None | |
| # Resize и flatten | |
| img_resized = cv2.resize(image, (64, 64)) | |
| img_flattened = img_resized.flatten() | |
| return img_flattened | |
| def predict_disease(image): | |
| """Предсказание болезни томата""" | |
| if image is None: | |
| logger.warning("Получено пустое изображение") | |
| return "Пожалуйста, загрузите изображение" | |
| logger.info("Начало процесса предсказания") | |
| # Загрузка модели | |
| logger.info("Загрузка модели для предсказания...") | |
| scaler, classifier = load_model() | |
| if scaler is None or classifier is None: | |
| logger.error("Не удалось загрузить модель для предсказания") | |
| return "Ошибка загрузки модели. Возможно, нужно сначала обучить модель." | |
| # Предобработка изображения | |
| logger.info("Предобработка изображения...") | |
| processed_image = preprocess_image(image) | |
| if processed_image is None: | |
| logger.error("Не удалось обработать изображение") | |
| return "Не удалось обработать изображение" | |
| # Масштабирование | |
| logger.info("Масштабирование изображения...") | |
| processed_image = scaler.transform([processed_image]) | |
| # Предсказание | |
| logger.info("Выполнение предсказания...") | |
| prediction = classifier.predict(processed_image) | |
| probabilities = classifier.predict_proba(processed_image)[0] | |
| # Формирование результата | |
| result = f"Обнаружено: {prediction[0]}\n\n" | |
| result += "Вероятности:\n" | |
| # Логируем результаты | |
| logger.info(f"Результат предсказания: {prediction[0]}") | |
| # Добавляем вероятности в результат | |
| for disease, prob in zip(DISEASE_CLASSES, probabilities): | |
| result += f"{disease}: {prob*100:.2f}%\n" | |
| if prob > 0.1: # Логируем только значимые вероятности | |
| logger.info(f" - {disease}: {prob*100:.2f}%") | |
| logger.info("Предсказание успешно завершено") | |
| return result | |
| # Создание Gradio интерфейса | |
| iface = gr.Interface( | |
| fn=predict_disease, | |
| inputs=gr.Image(type="numpy", label="Загрузите изображение листа томата"), | |
| outputs=gr.Textbox(label="Результат диагностики"), | |
| title="🍅Диагностика болезней томатов", | |
| description="Загрузите изображение листа томата для определения заболевания" | |
| ) | |
| # Функция для принудительного обучения модели | |
| def train_model_if_needed(force=False): | |
| """Обучение модели, если она не существует или если force=True""" | |
| model_path = '/tmp/tomato_disease_classifier.pth' | |
| # Если модель не существует или force=True, обучаем | |
| if force or not torch.os.path.exists(model_path): | |
| logger.info("Начинаем обучение модели...") | |
| start_time = time.time() | |
| # Копируем архив во временную директорию | |
| dataset_path = '/tmp/tomato_dataset.zip' | |
| if not os.path.exists(dataset_path): | |
| logger.info(f"Копируем архив датасета в {dataset_path}") | |
| shutil.copy('tomato_dataset.zip', dataset_path) | |
| # Подготовка датасета | |
| logger.info("Подготовка датасета...") | |
| prepare_dataset(dataset_path) | |
| # Загрузка изображений и меток | |
| logger.info("Загрузка изображений и меток...") | |
| X, y = load_images_and_labels() | |
| logger.info(f"Загружено {len(X)} изображений с {len(set(y))} классами") | |
| # Обучение модели | |
| logger.info("Начало обучения модели...") | |
| train_and_evaluate_model(X, y) | |
| # Завершение обучения | |
| training_time = time.time() - start_time | |
| logger.info(f"Обучение модели завершено за {training_time:.2f} секунд!") | |
| else: | |
| logger.info("Модель уже существует, пропускаем обучение.") | |
| # Запуск приложения | |
| if __name__ == "__main__": | |
| # Принудительное обучение модели при запуске | |
| train_model_if_needed(force=True) | |
| # Запуск Gradio интерфейса | |
| iface.launch(server_name="0.0.0.0", server_port=7860) | |