Spaces:
Runtime error
Runtime error
| import os | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import gradio as gr | |
| from fastapi import FastAPI, File, UploadFile | |
| from sklearn.preprocessing import StandardScaler | |
| from sklearn.svm import SVC | |
| import logging | |
| from train import download_and_prepare_dataset, load_images_and_labels, train_and_evaluate_model, cleanup | |
| import threading | |
| # Настройка логгера | |
| logger = logging.getLogger(__name__) | |
| # Список классов болезней | |
| DISEASE_CLASSES = [ | |
| 'Tomato___Bacterial_spot', | |
| 'Tomato___Early_blight', | |
| 'Tomato___Late_blight', | |
| 'Tomato___Leaf_Mold', | |
| 'Tomato___Septoria_leaf_spot', | |
| 'Tomato___Spider_mites Two-spotted_spider_mite', | |
| 'Tomato___Target_Spot', | |
| 'Tomato___Tomato_Yellow_Leaf_Curl_Virus', | |
| 'Tomato___Tomato_mosaic_virus', | |
| 'Tomato___healthy' | |
| ] | |
| def preprocess_image(image): | |
| """Подготовка изображения для предсказания""" | |
| if image is None: | |
| return None | |
| # Resize и flatten | |
| img_resized = cv2.resize(image, (64, 64)) | |
| img_flattened = img_resized.flatten() | |
| return img_flattened | |
| def load_model(): | |
| """Загрузка обученной модели с поддержкой множественных путей""" | |
| try: | |
| # Список возможных путей для модели | |
| model_paths = [ | |
| '/home/user/app/tomato_disease_classifier.pth', # Основной путь | |
| '/tmp/data/state/SVC_comb_R.pth.pth', # Путь Hugging Face | |
| '/tmp/tomato_disease_classifier.pth', # Резервный путь | |
| 'tomato_disease_classifier.pth' # Локальный путь | |
| ] | |
| # Поиск первого существующего пути | |
| model_path = next((path for path in model_paths if os.path.exists(path)), None) | |
| if model_path is None: | |
| logger.info("Модель не найдена, запускаем обучение в фоновом режиме") | |
| threading.Thread(target=train_model).start() | |
| return None, None | |
| logger.info(f"Загрузка модели из: {model_path}") | |
| # Загрузка данных модели | |
| model_data = torch.load(model_path) | |
| # Создание pipeline с масштабированием | |
| scaler = StandardScaler() | |
| scaler.mean_ = model_data['mean'] | |
| scaler.scale_ = model_data['std'] | |
| classifier = model_data['classifier'] | |
| return scaler, classifier | |
| except Exception as e: | |
| logger.error(f"Ошибка загрузки модели: {e}") | |
| import traceback | |
| logger.error(traceback.format_exc()) | |
| return None, None | |
| def predict_disease(image): | |
| """Предсказание болезни томата""" | |
| if image is None: | |
| return "Пожалуйста, загрузите изображение" | |
| # Загрузка модели | |
| scaler, classifier = load_model() | |
| if scaler is None or classifier is None: | |
| return "Ошибка загрузки модели. Возможно, нужно сначала обучить модель." | |
| # Предобработка изображения | |
| processed_image = preprocess_image(image) | |
| if processed_image is None: | |
| return "Не удалось обработать изображение" | |
| # Масштабирование | |
| processed_image = scaler.transform([processed_image]) | |
| # Предсказание | |
| prediction = classifier.predict(processed_image) | |
| probabilities = classifier.predict_proba(processed_image)[0] | |
| # Формирование результата | |
| result = f"Обнаружено: {prediction[0]}\n\n" | |
| result += "Вероятности:\n" | |
| for disease, prob in zip(DISEASE_CLASSES, probabilities): | |
| result += f"{disease}: {prob*100:.2f}%\n" | |
| return result | |
| # FastAPI приложение | |
| app = FastAPI() | |
| # Gradio интерфейс | |
| iface = gr.Interface( | |
| fn=predict_disease, | |
| inputs=gr.Image(type="numpy", label="Загрузите изображение листа томата"), | |
| outputs=gr.Textbox(label="Результат диагностики"), | |
| title="Диагностика болезней томатов", | |
| description="Загрузите изображение листа томата для определения заболевания" | |
| ) | |
| # Маршрут для Gradio | |
| def read_root(): | |
| return {"status": "Tomato Disease Classifier is running"} | |
| # Запуск Gradio | |
| def train_model(): | |
| try: | |
| logger.info("Начинаем обучение модели...") | |
| download_and_prepare_dataset() | |
| X, y = load_images_and_labels() | |
| train_and_evaluate_model(X, y) | |
| logger.info("Модель успешно обучена!") | |
| except Exception as e: | |
| logger.error(f"Ошибка при обучении модели: {e}") | |
| finally: | |
| cleanup() | |
| if __name__ == "__main__": | |
| iface.launch(server_name="0.0.0.0", server_port=7860) | |