Cascade AI
Обновление модели и логики обучения
f63ca0d
import os
import cv2
import numpy as np
import torch
import gradio as gr
from fastapi import FastAPI, File, UploadFile
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
import logging
from train import download_and_prepare_dataset, load_images_and_labels, train_and_evaluate_model, cleanup
import threading
# Настройка логгера
logger = logging.getLogger(__name__)
# Список классов болезней
DISEASE_CLASSES = [
'Tomato___Bacterial_spot',
'Tomato___Early_blight',
'Tomato___Late_blight',
'Tomato___Leaf_Mold',
'Tomato___Septoria_leaf_spot',
'Tomato___Spider_mites Two-spotted_spider_mite',
'Tomato___Target_Spot',
'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
'Tomato___Tomato_mosaic_virus',
'Tomato___healthy'
]
def preprocess_image(image):
"""Подготовка изображения для предсказания"""
if image is None:
return None
# Resize и flatten
img_resized = cv2.resize(image, (64, 64))
img_flattened = img_resized.flatten()
return img_flattened
def load_model():
"""Загрузка обученной модели с поддержкой множественных путей"""
try:
# Список возможных путей для модели
model_paths = [
'/home/user/app/tomato_disease_classifier.pth', # Основной путь
'/tmp/data/state/SVC_comb_R.pth.pth', # Путь Hugging Face
'/tmp/tomato_disease_classifier.pth', # Резервный путь
'tomato_disease_classifier.pth' # Локальный путь
]
# Поиск первого существующего пути
model_path = next((path for path in model_paths if os.path.exists(path)), None)
if model_path is None:
logger.info("Модель не найдена, запускаем обучение в фоновом режиме")
threading.Thread(target=train_model).start()
return None, None
logger.info(f"Загрузка модели из: {model_path}")
# Загрузка данных модели
model_data = torch.load(model_path)
# Создание pipeline с масштабированием
scaler = StandardScaler()
scaler.mean_ = model_data['mean']
scaler.scale_ = model_data['std']
classifier = model_data['classifier']
return scaler, classifier
except Exception as e:
logger.error(f"Ошибка загрузки модели: {e}")
import traceback
logger.error(traceback.format_exc())
return None, None
def predict_disease(image):
"""Предсказание болезни томата"""
if image is None:
return "Пожалуйста, загрузите изображение"
# Загрузка модели
scaler, classifier = load_model()
if scaler is None or classifier is None:
return "Ошибка загрузки модели. Возможно, нужно сначала обучить модель."
# Предобработка изображения
processed_image = preprocess_image(image)
if processed_image is None:
return "Не удалось обработать изображение"
# Масштабирование
processed_image = scaler.transform([processed_image])
# Предсказание
prediction = classifier.predict(processed_image)
probabilities = classifier.predict_proba(processed_image)[0]
# Формирование результата
result = f"Обнаружено: {prediction[0]}\n\n"
result += "Вероятности:\n"
for disease, prob in zip(DISEASE_CLASSES, probabilities):
result += f"{disease}: {prob*100:.2f}%\n"
return result
# FastAPI приложение
app = FastAPI()
# Gradio интерфейс
iface = gr.Interface(
fn=predict_disease,
inputs=gr.Image(type="numpy", label="Загрузите изображение листа томата"),
outputs=gr.Textbox(label="Результат диагностики"),
title="Диагностика болезней томатов",
description="Загрузите изображение листа томата для определения заболевания"
)
# Маршрут для Gradio
@app.get("/")
def read_root():
return {"status": "Tomato Disease Classifier is running"}
# Запуск Gradio
def train_model():
try:
logger.info("Начинаем обучение модели...")
download_and_prepare_dataset()
X, y = load_images_and_labels()
train_and_evaluate_model(X, y)
logger.info("Модель успешно обучена!")
except Exception as e:
logger.error(f"Ошибка при обучении модели: {e}")
finally:
cleanup()
if __name__ == "__main__":
iface.launch(server_name="0.0.0.0", server_port=7860)