darkgoolder commited on
Commit
a60511d
·
1 Parent(s): e82378c
.env.example ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # API Settings
2
+ API_V1_PREFIX=/api/v1
3
+ PROJECT_NAME=Wagon Classification API
4
+ VERSION=1.0.0
5
+
6
+ # Model Settings
7
+ MODEL_PATH=models/best_model.pth
8
+ CLASS_NAMES=pered,zad,none
9
+
10
+ # Security
11
+ MAX_UPLOAD_SIZE=10485760
12
+ ALLOWED_EXTENSIONS=.jpg,.jpeg,.png,.bmp
13
+
14
+ # CORS
15
+ ALLOWED_ORIGINS=http://localhost,http://localhost:8000,http://127.0.0.1:8000
16
+
17
+ # Logging
18
+ LOG_LEVEL=INFO
.gitignore CHANGED
@@ -205,3 +205,9 @@ cython_debug/
205
  marimo/_static/
206
  marimo/_lsp/
207
  __marimo__/
 
 
 
 
 
 
 
205
  marimo/_static/
206
  marimo/_lsp/
207
  __marimo__/
208
+
209
+
210
+ uploads/
211
+ wagon_classification/
212
+ wagon_data/
213
+
Dockerfile ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ WORKDIR /app
4
+
5
+ # Устанавливаем системные зависимости
6
+ RUN apt-get update && apt-get install -y \
7
+ gcc \
8
+ g++ \
9
+ libgl1-mesa-glx \
10
+ libglib2.0-0 \
11
+ libsm6 \
12
+ libxext6 \
13
+ libxrender-dev \
14
+ libgomp1 \
15
+ && rm -rf /var/lib/apt/lists/*
16
+
17
+ # Копируем зависимости
18
+ COPY requirements.txt .
19
+ RUN pip install --no-cache-dir -r requirements.txt
20
+
21
+ # Копируем приложение
22
+ COPY . .
23
+
24
+ # Создаем необходимые папки
25
+ RUN mkdir -p models uploads
26
+
27
+ # Открываем порт
28
+ EXPOSE 8000
29
+
30
+ # Запускаем приложение
31
+ CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
app/_init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """Wagon Classification API - приложение для классификации вагонов"""
2
+
3
+ __version__ = "1.0.0"
app/api/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """API маршруты приложения"""
2
+
3
+ from app.api.routes import router
4
+
5
+ __all__ = ['router']
app/api/dependencies.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Зависимости для API маршрутов
3
+ """
4
+
5
+ from fastapi import Request, HTTPException, status
6
+ from app.models.wagon_model import get_classifier
7
+
8
+
9
+ async def verify_model_loaded():
10
+ """Проверка, что модель загружена"""
11
+ try:
12
+ classifier = get_classifier()
13
+ if classifier.model is None:
14
+ raise HTTPException(
15
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
16
+ detail="Модель не загружена"
17
+ )
18
+ return classifier
19
+ except Exception as e:
20
+ raise HTTPException(
21
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
22
+ detail=f"Ошибка загрузки модели: {str(e)}"
23
+ )
app/api/routes.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ API эндпоинты для классификации вагонов
3
+ """
4
+
5
+ import os
6
+ import uuid
7
+ import logging
8
+ from typing import List
9
+ from fastapi import APIRouter, File, UploadFile, HTTPException, status
10
+ from fastapi.responses import JSONResponse
11
+ from PIL import Image
12
+ import io
13
+
14
+ from app.models.schemas import PredictionResponse, ErrorResponse, HealthResponse, BatchPredictionResponse
15
+ from app.models.wagon_model import get_classifier
16
+ from app.config import settings
17
+ from app.utils.image_utils import validate_image_file, process_image
18
+
19
+ logger = logging.getLogger(__name__)
20
+ router = APIRouter()
21
+
22
+
23
+ @router.get(
24
+ "/health",
25
+ response_model=HealthResponse,
26
+ tags=["System"],
27
+ summary="Проверка здоровья сервиса"
28
+ )
29
+ async def health_check():
30
+ """
31
+ Проверка работоспособности API и наличия модели
32
+ """
33
+ try:
34
+ classifier = get_classifier()
35
+ return HealthResponse(
36
+ status="healthy",
37
+ model_loaded=True,
38
+ device=classifier.device,
39
+ version=settings.VERSION
40
+ )
41
+ except Exception as e:
42
+ logger.error(f"Health check failed: {e}")
43
+ return HealthResponse(
44
+ status="unhealthy",
45
+ model_loaded=False,
46
+ device="unknown",
47
+ version=settings.VERSION
48
+ )
49
+
50
+
51
+ @router.post(
52
+ "/predict",
53
+ response_model=PredictionResponse,
54
+ tags=["Prediction"],
55
+ summary="Классификация одного изображения",
56
+ responses={
57
+ 400: {"model": ErrorResponse, "description": "Ошибка валидации"},
58
+ 413: {"model": ErrorResponse, "description": "Файл слишком большой"},
59
+ 500: {"model": ErrorResponse, "description": "Внутренняя ошибка сервера"}
60
+ }
61
+ )
62
+ async def predict_image(
63
+ file: UploadFile = File(..., description="Изображение вагона")
64
+ ):
65
+ """
66
+ Классифицирует изображение вагона
67
+
68
+ Определяет:
69
+ - **pered** - передняя часть вагона
70
+ - **zad** - задняя часть вагона
71
+ - **none** - вагон не обнаружен
72
+
73
+ Возвращает предсказанный класс и уверенность модели.
74
+ """
75
+ try:
76
+ # Валидация файла
77
+ validate_image_file(file, settings)
78
+
79
+ # Загружаем изображение
80
+ image = process_image(file)
81
+
82
+ # Получаем модель и делаем предсказание
83
+ classifier = get_classifier()
84
+ predicted_class, confidence, probabilities = classifier.predict(image)
85
+
86
+ # Формируем ответ
87
+ response_data = {
88
+ "class": predicted_class,
89
+ "class_name": classifier.class_names_ru.get(predicted_class, predicted_class),
90
+ "confidence": confidence,
91
+ "probabilities": probabilities
92
+ }
93
+
94
+ return PredictionResponse(
95
+ status="success",
96
+ data=response_data,
97
+ request_id=str(uuid.uuid4())
98
+ )
99
+
100
+ except HTTPException:
101
+ raise
102
+ except Exception as e:
103
+ logger.error(f"Непредвиденная ошибка: {e}", exc_info=True)
104
+ raise HTTPException(
105
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
106
+ detail={
107
+ "code": "INTERNAL_ERROR",
108
+ "message": "Внутренняя ошибка сервера"
109
+ }
110
+ )
111
+
112
+
113
+ @router.post(
114
+ "/predict-batch",
115
+ tags=["Prediction"],
116
+ summary="Пакетная классификация изображений"
117
+ )
118
+ async def predict_batch(
119
+ files: List[UploadFile] = File(..., description="Список изображений")
120
+ ):
121
+ """
122
+ Классифицирует несколько изображений одновременно
123
+
124
+ Максимальное количество файлов не ограничено, но каждый файл
125
+ должен соответствовать требованиям по размеру и формату.
126
+ """
127
+ try:
128
+ classifier = get_classifier()
129
+ results = []
130
+
131
+ for file in files:
132
+ try:
133
+ # Валидация
134
+ validate_image_file(file, settings)
135
+ image = process_image(file)
136
+
137
+ # Предсказание
138
+ predicted_class, confidence, probabilities = classifier.predict(image)
139
+
140
+ results.append({
141
+ "filename": file.filename,
142
+ "success": True,
143
+ "result": {
144
+ "class": predicted_class,
145
+ "class_name": classifier.class_names_ru.get(predicted_class, predicted_class),
146
+ "confidence": confidence,
147
+ "probabilities": probabilities
148
+ }
149
+ })
150
+
151
+ except HTTPException as e:
152
+ results.append({
153
+ "filename": file.filename,
154
+ "success": False,
155
+ "error": e.detail.get("message", str(e.detail))
156
+ })
157
+ except Exception as e:
158
+ results.append({
159
+ "filename": file.filename,
160
+ "success": False,
161
+ "error": str(e)
162
+ })
163
+
164
+ return JSONResponse(
165
+ status_code=status.HTTP_200_OK,
166
+ content={
167
+ "status": "success",
168
+ "results": results,
169
+ "total": len(results),
170
+ "successful": sum(1 for r in results if r["success"])
171
+ }
172
+ )
173
+
174
+ except Exception as e:
175
+ logger.error(f"Ошибка пакетной обработки: {e}", exc_info=True)
176
+ raise HTTPException(
177
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
178
+ detail={
179
+ "code": "BATCH_ERROR",
180
+ "message": "Ошибка пакетной обработки"
181
+ }
182
+ )
app/config.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Конфигурация приложения
3
+ Загружает настройки из переменных окружения
4
+ """
5
+
6
+ import os
7
+ from pathlib import Path
8
+ from typing import List
9
+ from pydantic_settings import BaseSettings
10
+ from dotenv import load_dotenv
11
+
12
+ # Загружаем переменные окружения
13
+ load_dotenv()
14
+
15
+ # Базовые пути
16
+ BASE_DIR = Path(__file__).resolve().parent.parent
17
+ MODEL_DIR = BASE_DIR / "models"
18
+ UPLOAD_DIR = BASE_DIR / "uploads"
19
+
20
+ # Создаем необходимые папки
21
+ UPLOAD_DIR.mkdir(exist_ok=True)
22
+ MODEL_DIR.mkdir(exist_ok=True)
23
+
24
+
25
+ class Settings(BaseSettings):
26
+ """Настройки приложения"""
27
+
28
+ # API настройки
29
+ API_V1_PREFIX: str = "/api/v1"
30
+ PROJECT_NAME: str = "Wagon Classification API"
31
+ VERSION: str = "1.0.0"
32
+
33
+ # Модель
34
+ MODEL_PATH: str = str(MODEL_DIR / "best_model.pth")
35
+ CLASS_NAMES: List[str] = ["pered", "zad", "none"]
36
+
37
+ # Безопасность
38
+ MAX_UPLOAD_SIZE: int = 10 * 1024 * 1024 # 10 MB
39
+ ALLOWED_EXTENSIONS: set = {".jpg", ".jpeg", ".png", ".bmp"}
40
+
41
+ # CORS
42
+ ALLOWED_ORIGINS: List[str] = [
43
+ "http://localhost",
44
+ "http://localhost:8000",
45
+ "http://127.0.0.1:8000"
46
+ ]
47
+
48
+ # Логирование
49
+ LOG_LEVEL: str = "INFO"
50
+
51
+ class Config:
52
+ env_file = ".env"
53
+ case_sensitive = True
54
+
55
+
56
+ settings = Settings()
app/main.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Основной файл приложения FastAPI
3
+ """
4
+
5
+ from fastapi import FastAPI
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+ from fastapi.staticfiles import StaticFiles
8
+ from fastapi.responses import FileResponse
9
+ import logging
10
+ import os
11
+ from pathlib import Path
12
+
13
+ from app.api.routes import router
14
+ from app.config import settings
15
+ from app.utils.logger import setup_logging
16
+
17
+ # Настройка логирования
18
+ setup_logging(settings.LOG_LEVEL)
19
+ logger = logging.getLogger(__name__)
20
+
21
+ # Создаем приложение
22
+ app = FastAPI(
23
+ title=settings.PROJECT_NAME,
24
+ version=settings.VERSION,
25
+ description="API для классификации вагонов по изображениям\n\n"
26
+ "Определяет переднюю и заднюю часть вагона на фотографии.",
27
+ docs_url="/docs",
28
+ redoc_url="/redoc",
29
+ openapi_tags=[
30
+ {
31
+ "name": "System",
32
+ "description": "Системные эндпоинты (health check)"
33
+ },
34
+ {
35
+ "name": "Prediction",
36
+ "description": "Эндпоинты для классификации изображений"
37
+ }
38
+ ]
39
+ )
40
+
41
+ # Настройка CORS
42
+ app.add_middleware(
43
+ CORSMiddleware,
44
+ allow_origins=settings.ALLOWED_ORIGINS,
45
+ allow_credentials=True,
46
+ allow_methods=["*"],
47
+ allow_headers=["*"],
48
+ )
49
+
50
+ # Подключаем API роуты
51
+ app.include_router(router, prefix=settings.API_V1_PREFIX)
52
+
53
+ # Статические файлы (веб-интерфейс)
54
+ static_dir = Path(__file__).parent / "static"
55
+ if static_dir.exists():
56
+ app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")
57
+
58
+
59
+ @app.get("/")
60
+ async def root():
61
+ """Главная страница"""
62
+ static_index = static_dir / "index.html"
63
+ if static_index.exists():
64
+ return FileResponse(str(static_index))
65
+ return {
66
+ "message": settings.PROJECT_NAME,
67
+ "version": settings.VERSION,
68
+ "docs": "/docs",
69
+ "health": f"{settings.API_V1_PREFIX}/health"
70
+ }
71
+
72
+
73
+ @app.on_event("startup")
74
+ async def startup_event():
75
+ """Загрузка модели при старте"""
76
+ logger.info("=" * 50)
77
+ logger.info(f"Запуск {settings.PROJECT_NAME} v{settings.VERSION}")
78
+ logger.info("=" * 50)
79
+
80
+ # Проверяем существование модели
81
+ if not os.path.exists(settings.MODEL_PATH):
82
+ logger.warning(f"⚠️ Модель не найдена: {settings.MODEL_PATH}")
83
+ logger.info("Пожалуйста, обучите модель командой: python train_model.py")
84
+ else:
85
+ try:
86
+ # Предварительная загрузка модели
87
+ from app.models.wagon_model import get_classifier
88
+ classifier = get_classifier()
89
+ logger.info(f"✅ Модель загружена на устройство: {classifier.device}")
90
+ logger.info(f"📋 Доступные классы: {classifier.class_names}")
91
+ except Exception as e:
92
+ logger.error(f"❌ Ошибка при загрузке модели: {e}")
93
+
94
+
95
+ @app.on_event("shutdown")
96
+ async def shutdown_event():
97
+ """Очистка при завершении"""
98
+ logger.info("Остановка API сервиса")
99
+
100
+
101
+ if __name__ == "__main__":
102
+ import uvicorn
103
+ uvicorn.run(
104
+ "app.main:app",
105
+ host="0.0.0.0",
106
+ port=8000,
107
+ reload=True
108
+ )
app/models/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Модели данных и ML модель"""
2
+
3
+ from app.models.wagon_model import WagonClassifier, get_classifier
4
+ from app.models.schemas import PredictionResponse, ErrorResponse, HealthResponse
5
+
6
+ __all__ = [
7
+ 'WagonClassifier',
8
+ 'get_classifier',
9
+ 'PredictionResponse',
10
+ 'ErrorResponse',
11
+ 'HealthResponse'
12
+ ]
app/models/schemas.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Pydantic схемы для валидации данных API
3
+ """
4
+
5
+ from pydantic import BaseModel, Field
6
+ from typing import Dict, Optional, List
7
+ from datetime import datetime
8
+
9
+
10
+ class PredictionResponse(BaseModel):
11
+ """Ответ API с предсказанием"""
12
+ status: str = Field(..., example="success")
13
+ data: Dict = Field(..., description="Результат классификации")
14
+ timestamp: datetime = Field(default_factory=datetime.now)
15
+ request_id: Optional[str] = Field(None, description="Уникальный идентификатор запроса")
16
+
17
+ class Config:
18
+ json_schema_extra = {
19
+ "example": {
20
+ "status": "success",
21
+ "data": {
22
+ "class": "pered",
23
+ "class_name": "передняя часть вагона",
24
+ "confidence": 0.95,
25
+ "probabilities": {
26
+ "pered": 0.95,
27
+ "zad": 0.03,
28
+ "none": 0.02
29
+ }
30
+ },
31
+ "timestamp": "2024-01-15T10:30:00",
32
+ "request_id": "123e4567-e89b-12d3-a456-426614174000"
33
+ }
34
+ }
35
+
36
+
37
+ class ErrorResponse(BaseModel):
38
+ """Ответ при ошибке"""
39
+ status: str = Field(..., example="error")
40
+ error: Dict = Field(..., description="Детали ошибки")
41
+ timestamp: datetime = Field(default_factory=datetime.now)
42
+
43
+ class Config:
44
+ json_schema_extra = {
45
+ "example": {
46
+ "status": "error",
47
+ "error": {
48
+ "code": "INVALID_IMAGE",
49
+ "message": "Файл не является корректным изображением",
50
+ "details": "Поддерживаются форматы: jpg, jpeg, png"
51
+ },
52
+ "timestamp": "2024-01-15T10:30:00"
53
+ }
54
+ }
55
+
56
+
57
+ class HealthResponse(BaseModel):
58
+ """Проверка здоровья сервиса"""
59
+ status: str = Field(..., description="Статус сервиса")
60
+ model_loaded: bool = Field(..., description="Загружена ли модель")
61
+ device: str = Field(..., description="Устройство выполнения")
62
+ version: str = Field(..., description="Версия API")
63
+
64
+
65
+ class BatchPredictionResponse(BaseModel):
66
+ """Ответ для пакетной классификации"""
67
+ status: str = Field(..., example="success")
68
+ results: List[Dict] = Field(..., description="Результаты для каждого файла")
69
+ total: int = Field(..., description="Всего файлов")
70
+ successful: int = Field(..., description="Успешно обработано")
app/models/wagon_model.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Обертка для модели машинного обучения
3
+ Загружает обученную модель и выполняет предсказания
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torchvision import models, transforms
9
+ from PIL import Image
10
+ import os
11
+ import logging
12
+ from typing import Dict, Tuple, List
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ class WagonClassifier:
18
+ """
19
+ Классификатор вагонов
20
+ Загружает обученную модель и выполняет инференс
21
+ """
22
+
23
+ def __init__(self, model_path: str, class_names: List[str], device: str = None):
24
+ """
25
+ Инициализация классификатора
26
+
27
+ Args:
28
+ model_path: Путь к файлу с весами модели (.pth)
29
+ class_names: Список названий классов
30
+ device: Устройство для выполнения (cuda/cpu)
31
+ """
32
+ self.model_path = model_path
33
+ self.class_names = class_names
34
+ self.num_classes = len(class_names)
35
+
36
+ # Определяем устройство
37
+ self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
38
+ logger.info(f"Используется устройство: {self.device}")
39
+
40
+ # Загружаем модель
41
+ self.model = self._load_model()
42
+
43
+ # Трансформации для изображений
44
+ self.transform = transforms.Compose([
45
+ transforms.Resize((224, 224)),
46
+ transforms.ToTensor(),
47
+ transforms.Normalize(
48
+ mean=[0.485, 0.456, 0.406],
49
+ std=[0.229, 0.224, 0.225]
50
+ )
51
+ ])
52
+
53
+ # Русские названия классов для вывода
54
+ self.class_names_ru = {
55
+ 'pered': 'передняя часть вагона',
56
+ 'zad': 'задняя часть вагона',
57
+ 'none': 'вагон не обнаружен'
58
+ }
59
+
60
+ logger.info(f"Модель загружена. Классы: {self.class_names}")
61
+
62
+ def _load_model(self) -> nn.Module:
63
+ """
64
+ Загрузка модели из файла
65
+
66
+ Returns:
67
+ Загруженная модель в режиме evaluation
68
+ """
69
+ # Создаем архитектуру модели (должна совпадать с train_model.py)
70
+ model = models.efficientnet_b2(weights=None)
71
+ in_features = model.classifier[1].in_features
72
+ model.classifier = nn.Sequential(
73
+ nn.Dropout(p=0.3),
74
+ nn.Linear(in_features, self.num_classes)
75
+ )
76
+
77
+ # Проверяем существование файла
78
+ if not os.path.exists(self.model_path):
79
+ raise FileNotFoundError(f"Модель не найдена: {self.model_path}")
80
+
81
+ # Загружаем веса
82
+ checkpoint = torch.load(self.model_path, map_location=self.device)
83
+
84
+ # Поддерживаем разные форматы сохранения
85
+ if 'model_state_dict' in checkpoint:
86
+ model.load_state_dict(checkpoint['model_state_dict'])
87
+ else:
88
+ model.load_state_dict(checkpoint)
89
+
90
+ # Перемещаем на устройство и переводим в режим оценки
91
+ model = model.to(self.device)
92
+ model.eval()
93
+
94
+ return model
95
+
96
+ def _preprocess_image(self, image: Image.Image) -> torch.Tensor:
97
+ """
98
+ Предобработка изображения перед подачей в модель
99
+
100
+ Args:
101
+ image: PIL Image
102
+
103
+ Returns:
104
+ Тензор, готовый для инференса
105
+ """
106
+ # Конвертируем в RGB если нужно
107
+ if image.mode != 'RGB':
108
+ image = image.convert('RGB')
109
+
110
+ # Применяем трансформации
111
+ input_tensor = self.transform(image)
112
+ input_tensor = input_tensor.unsqueeze(0) # Добавляем batch dimension
113
+ input_tensor = input_tensor.to(self.device)
114
+
115
+ return input_tensor
116
+
117
+ def predict(self, image: Image.Image) -> Tuple[str, float, Dict[str, float]]:
118
+ """
119
+ Предсказание для одного изображения
120
+
121
+ Args:
122
+ image: PIL Image
123
+
124
+ Returns:
125
+ Tuple: (предсказанный_класс, уверенность, словарь_вероятностей)
126
+ """
127
+ try:
128
+ # Предобработка
129
+ input_tensor = self._preprocess_image(image)
130
+
131
+ # Инференс
132
+ with torch.no_grad():
133
+ outputs = self.model(input_tensor)
134
+ probabilities = torch.nn.functional.softmax(outputs, dim=1)
135
+
136
+ # Получаем предсказание
137
+ predicted_idx = torch.argmax(probabilities, dim=1).item()
138
+ confidence = probabilities[0][predicted_idx].item()
139
+ predicted_class = self.class_names[predicted_idx]
140
+
141
+ # Все вероятности
142
+ all_probs = {
143
+ class_name: probabilities[0][i].item()
144
+ for i, class_name in enumerate(self.class_names)
145
+ }
146
+
147
+ logger.info(f"Предсказание: {predicted_class} с уверенностью {confidence:.2%}")
148
+
149
+ return predicted_class, confidence, all_probs
150
+
151
+ except Exception as e:
152
+ logger.error(f"Ошибка при предсказании: {e}")
153
+ raise
154
+
155
+ def predict_batch(self, images: List[Image.Image]) -> List[Dict]:
156
+ """
157
+ Пакетное предсказание для нескольких изображений
158
+
159
+ Args:
160
+ images: Список PIL Image
161
+
162
+ Returns:
163
+ Список результатов для каждого изображения
164
+ """
165
+ results = []
166
+ for image in images:
167
+ pred_class, confidence, probs = self.predict(image)
168
+ results.append({
169
+ 'class': pred_class,
170
+ 'class_name': self.class_names_ru.get(pred_class, pred_class),
171
+ 'confidence': confidence,
172
+ 'probabilities': probs
173
+ })
174
+ return results
175
+
176
+
177
+ # Глобальный экземпляр модели (синглтон)
178
+ _classifier_instance = None
179
+
180
+
181
+ def get_classifier() -> WagonClassifier:
182
+ """
183
+ Получить экземпляр классификатора (синглтон)
184
+ Модель загружается только один раз при первом вызове
185
+
186
+ Returns:
187
+ Экземпляр WagonClassifier
188
+ """
189
+ global _classifier_instance
190
+
191
+ if _classifier_instance is None:
192
+ from app.config import settings
193
+
194
+ _classifier_instance = WagonClassifier(
195
+ model_path=settings.MODEL_PATH,
196
+ class_names=settings.CLASS_NAMES
197
+ )
198
+
199
+ return _classifier_instance
app/services/__init__.py ADDED
File without changes
app/services/prediction_service.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Сервис для работы с предсказаниями
3
+ Содержит бизнес-логику обработки изображений
4
+ """
5
+
6
+ import logging
7
+ from typing import Dict, Any, List
8
+ from PIL import Image
9
+
10
+ from app.models.wagon_model import WagonClassifier
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ class PredictionService:
16
+ """Сервис для выполнения предсказаний"""
17
+
18
+ def __init__(self, classifier: WagonClassifier):
19
+ self.classifier = classifier
20
+
21
+ def predict_single(self, image: Image.Image) -> Dict[str, Any]:
22
+ """
23
+ Предсказание для одного изображения
24
+
25
+ Args:
26
+ image: PIL Image
27
+
28
+ Returns:
29
+ Словарь с результатами предсказания
30
+ """
31
+ predicted_class, confidence, probabilities = self.classifier.predict(image)
32
+
33
+ return {
34
+ "class": predicted_class,
35
+ "class_name": self.classifier.class_names_ru.get(predicted_class, predicted_class),
36
+ "confidence": confidence,
37
+ "probabilities": probabilities
38
+ }
39
+
40
+ def predict_batch(self, images: List[Image.Image]) -> List[Dict[str, Any]]:
41
+ """
42
+ Предсказание для нескольких изображений
43
+
44
+ Args:
45
+ images: Список PIL Image
46
+
47
+ Returns:
48
+ Список результатов
49
+ """
50
+ results = []
51
+ for image in images:
52
+ try:
53
+ result = self.predict_single(image)
54
+ results.append(result)
55
+ except Exception as e:
56
+ logger.error(f"Ошибка при предсказании: {e}")
57
+ results.append({"error": str(e)})
58
+
59
+ return results
60
+
61
+ def get_model_info(self) -> Dict[str, Any]:
62
+ """Получить информацию о модели"""
63
+ return {
64
+ "device": self.classifier.device,
65
+ "classes": self.classifier.class_names,
66
+ "num_classes": self.classifier.num_classes
67
+ }
app/static/index.html ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="ru">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=yes">
6
+ <title>Классификатор вагонов - WagonDetector</title>
7
+ <link rel="stylesheet" href="/static/style.css">
8
+ </head>
9
+ <body>
10
+ <div class="container">
11
+ <header>
12
+ <h1>🚂 Классификатор вагонов</h1>
13
+ <p>Определение передней и задней части вагона по фотографии</p>
14
+ </header>
15
+
16
+ <div class="upload-area" id="uploadArea">
17
+ <div class="upload-content">
18
+ <div class="upload-icon">📸</div>
19
+ <h3>Загрузите изображение вагона</h3>
20
+ <p>Перетащите файл сюда или нажмите для выбора</p>
21
+ <p class="file-info">Поддерживаются: JPG, JPEG, PNG (до 10 MB)</p>
22
+ <input type="file" id="fileInput" accept="image/jpeg,image/jpg,image/png" hidden>
23
+ <button class="btn btn-primary" id="selectFileBtn">Выбрать файл</button>
24
+ </div>
25
+ </div>
26
+
27
+ <div class="preview-area" id="previewArea" style="display: none;">
28
+ <div class="preview-image">
29
+ <img id="previewImg" alt="Предпросмотр">
30
+ <button class="btn-remove" id="removeImageBtn">✕</button>
31
+ </div>
32
+ <button class="btn btn-success" id="predictBtn">🔍 Распознать вагон</button>
33
+ </div>
34
+
35
+ <div class="results-area" id="resultsArea" style="display: none;">
36
+ <h3>Результат классификации</h3>
37
+ <div class="result-card">
38
+ <div class="result-class" id="resultClass"></div>
39
+ <div class="result-confidence" id="resultConfidence"></div>
40
+ <div class="probabilities" id="probabilities"></div>
41
+ </div>
42
+ </div>
43
+
44
+ <div class="loading" id="loading" style="display: none;">
45
+ <div class="spinner"></div>
46
+ <p>Обработка изображения...</p>
47
+ </div>
48
+
49
+ <div class="error-message" id="errorMessage" style="display: none;"></div>
50
+ </div>
51
+
52
+ <script src="/static/script.js"></script>
53
+ </body>
54
+ </html>
app/static/script.js ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // DOM элементы
2
+ const uploadArea = document.getElementById('uploadArea');
3
+ const fileInput = document.getElementById('fileInput');
4
+ const selectFileBtn = document.getElementById('selectFileBtn');
5
+ const previewArea = document.getElementById('previewArea');
6
+ const previewImg = document.getElementById('previewImg');
7
+ const removeImageBtn = document.getElementById('removeImageBtn');
8
+ const predictBtn = document.getElementById('predictBtn');
9
+ const resultsArea = document.getElementById('resultsArea');
10
+ const loading = document.getElementById('loading');
11
+ const errorMessage = document.getElementById('errorMessage');
12
+
13
+ let currentFile = null;
14
+
15
+ // API базовый URL
16
+ const API_URL = window.location.origin + '/api/v1';
17
+
18
+ // Обработчики событий
19
+ selectFileBtn.addEventListener('click', () => fileInput.click());
20
+ removeImageBtn.addEventListener('click', clearImage);
21
+ predictBtn.addEventListener('click', predictImage);
22
+ fileInput.addEventListener('change', handleFileSelect);
23
+
24
+ // Drag & Drop
25
+ uploadArea.addEventListener('dragover', (e) => {
26
+ e.preventDefault();
27
+ uploadArea.classList.add('drag-over');
28
+ });
29
+
30
+ uploadArea.addEventListener('dragleave', () => {
31
+ uploadArea.classList.remove('drag-over');
32
+ });
33
+
34
+ uploadArea.addEventListener('drop', (e) => {
35
+ e.preventDefault();
36
+ uploadArea.classList.remove('drag-over');
37
+ const file = e.dataTransfer.files[0];
38
+ if (file && file.type.startsWith('image/')) {
39
+ handleFile(file);
40
+ } else {
41
+ showError('Пожалуйста, загрузите изображение');
42
+ }
43
+ });
44
+
45
+ uploadArea.addEventListener('click', () => fileInput.click());
46
+
47
+ function handleFileSelect(e) {
48
+ const file = e.target.files[0];
49
+ if (file) {
50
+ handleFile(file);
51
+ }
52
+ }
53
+
54
+ function handleFile(file) {
55
+ // Проверка размера (10 MB)
56
+ if (file.size > 10 * 1024 * 1024) {
57
+ showError('Файл слишком большой. Максимум 10 MB');
58
+ return;
59
+ }
60
+
61
+ currentFile = file;
62
+
63
+ // Предпросмотр
64
+ const reader = new FileReader();
65
+ reader.onload = (e) => {
66
+ previewImg.src = e.target.result;
67
+ uploadArea.style.display = 'none';
68
+ previewArea.style.display = 'block';
69
+ resultsArea.style.display = 'none';
70
+ hideError();
71
+ };
72
+ reader.readAsDataURL(file);
73
+ }
74
+
75
+ function clearImage() {
76
+ currentFile = null;
77
+ fileInput.value = '';
78
+ previewArea.style.display = 'none';
79
+ uploadArea.style.display = 'block';
80
+ resultsArea.style.display = 'none';
81
+ hideError();
82
+ }
83
+
84
+ async function predictImage() {
85
+ if (!currentFile) {
86
+ showError('Сначала выберите изображение');
87
+ return;
88
+ }
89
+
90
+ // Показываем загрузку
91
+ loading.style.display = 'block';
92
+ resultsArea.style.display = 'none';
93
+ hideError();
94
+
95
+ // Создаем FormData
96
+ const formData = new FormData();
97
+ formData.append('file', currentFile);
98
+
99
+ try {
100
+ // Отправляем запрос
101
+ const response = await fetch(`${API_URL}/predict`, {
102
+ method: 'POST',
103
+ body: formData
104
+ });
105
+
106
+ const data = await response.json();
107
+
108
+ if (!response.ok) {
109
+ throw new Error(data.error?.message || 'Ошибка при обработке');
110
+ }
111
+
112
+ // Отображаем результаты
113
+ displayResults(data.data);
114
+
115
+ } catch (error) {
116
+ console.error('Error:', error);
117
+ showError(error.message || 'Ошибка при отправке запроса');
118
+ } finally {
119
+ loading.style.display = 'none';
120
+ }
121
+ }
122
+
123
+ function displayResults(data) {
124
+ const resultClass = document.getElementById('resultClass');
125
+ const resultConfidence = document.getElementById('resultConfidence');
126
+ const probabilitiesDiv = document.getElementById('probabilities');
127
+
128
+ // Определяем эмодзи для класса
129
+ let emoji = '';
130
+ if (data.class === 'pered') emoji = '🚂 Передняя часть';
131
+ else if (data.class === 'zad') emoji = '🚂 Задняя часть';
132
+ else emoji = '⭕ Вагон не обнаружен';
133
+
134
+ // Отображаем основной результат
135
+ resultClass.innerHTML = `${emoji}<br>${data.class_name}`;
136
+ resultConfidence.innerHTML = `Уверенность: <strong>${(data.confidence * 100).toFixed(1)}%</strong>`;
137
+
138
+ // Отображаем все вероятности
139
+ probabilitiesDiv.innerHTML = '<h4>Распределение вероятностей:</h4>';
140
+
141
+ const classNames = {
142
+ 'pered': 'Передняя часть',
143
+ 'zad': 'Задняя часть',
144
+ 'none': 'Вагон не обнаружен'
145
+ };
146
+
147
+ for (const [cls, prob] of Object.entries(data.probabilities)) {
148
+ const percent = (prob * 100).toFixed(1);
149
+ const isPredicted = cls === data.class;
150
+
151
+ const probItem = document.createElement('div');
152
+ probItem.className = 'prob-item';
153
+ probItem.innerHTML = `
154
+ <div class="prob-label">${classNames[cls] || cls}</div>
155
+ <div class="prob-bar">
156
+ <div class="prob-fill" style="width: ${percent}%; background: ${isPredicted ? 'linear-gradient(90deg, #667eea, #764ba2)' : '#cbd5e0'}">
157
+ ${percent}%
158
+ </div>
159
+ </div>
160
+ `;
161
+ probabilitiesDiv.appendChild(probItem);
162
+ }
163
+
164
+ // Показываем результаты
165
+ resultsArea.style.display = 'block';
166
+
167
+ // Прокручиваем к результатам
168
+ resultsArea.scrollIntoView({ behavior: 'smooth' });
169
+ }
170
+
171
+ function showError(message) {
172
+ errorMessage.textContent = message;
173
+ errorMessage.style.display = 'block';
174
+ setTimeout(() => {
175
+ errorMessage.style.display = 'none';
176
+ }, 5000);
177
+ }
178
+
179
+ function hideError() {
180
+ errorMessage.style.display = 'none';
181
+ }
182
+
183
+ // Проверка здоровья API при загрузке
184
+ async function checkHealth() {
185
+ try {
186
+ const response = await fetch(`${API_URL}/health`);
187
+ const data = await response.json();
188
+ if (data.status === 'healthy') {
189
+ console.log('✅ API готов к работе');
190
+ } else {
191
+ console.warn('⚠️ API не здоров:', data);
192
+ showError('Сервис временно недоступен');
193
+ }
194
+ } catch (error) {
195
+ console.error('❌ API недоступен:', error);
196
+ showError('Не удалось подключиться к серверу');
197
+ }
198
+ }
199
+
200
+ // Запускаем проверку при загрузке
201
+ checkHealth();
app/static/style.css ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ * {
2
+ margin: 0;
3
+ padding: 0;
4
+ box-sizing: border-box;
5
+ }
6
+
7
+ body {
8
+ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
9
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
10
+ min-height: 100vh;
11
+ padding: 20px;
12
+ }
13
+
14
+ .container {
15
+ max-width: 800px;
16
+ margin: 0 auto;
17
+ background: white;
18
+ border-radius: 20px;
19
+ box-shadow: 0 20px 60px rgba(0,0,0,0.3);
20
+ overflow: hidden;
21
+ padding: 30px;
22
+ }
23
+
24
+ header {
25
+ text-align: center;
26
+ margin-bottom: 30px;
27
+ }
28
+
29
+ header h1 {
30
+ color: #333;
31
+ font-size: 28px;
32
+ margin-bottom: 10px;
33
+ }
34
+
35
+ header p {
36
+ color: #666;
37
+ font-size: 14px;
38
+ }
39
+
40
+ .upload-area {
41
+ border: 2px dashed #ddd;
42
+ border-radius: 10px;
43
+ padding: 40px;
44
+ text-align: center;
45
+ transition: all 0.3s ease;
46
+ cursor: pointer;
47
+ }
48
+
49
+ .upload-area:hover {
50
+ border-color: #667eea;
51
+ background: #f8f9ff;
52
+ }
53
+
54
+ .upload-area.drag-over {
55
+ border-color: #667eea;
56
+ background: #f0f2ff;
57
+ }
58
+
59
+ .upload-icon {
60
+ font-size: 48px;
61
+ margin-bottom: 15px;
62
+ }
63
+
64
+ .upload-content h3 {
65
+ color: #333;
66
+ margin-bottom: 10px;
67
+ }
68
+
69
+ .upload-content p {
70
+ color: #666;
71
+ margin-bottom: 5px;
72
+ }
73
+
74
+ .file-info {
75
+ font-size: 12px;
76
+ color: #999;
77
+ }
78
+
79
+ .btn {
80
+ padding: 10px 20px;
81
+ border: none;
82
+ border-radius: 5px;
83
+ font-size: 14px;
84
+ cursor: pointer;
85
+ transition: all 0.3s ease;
86
+ margin-top: 15px;
87
+ }
88
+
89
+ .btn-primary {
90
+ background: #667eea;
91
+ color: white;
92
+ }
93
+
94
+ .btn-primary:hover {
95
+ background: #5a67d8;
96
+ transform: translateY(-2px);
97
+ }
98
+
99
+ .btn-success {
100
+ background: #48bb78;
101
+ color: white;
102
+ font-size: 16px;
103
+ padding: 12px 30px;
104
+ }
105
+
106
+ .btn-success:hover {
107
+ background: #38a169;
108
+ transform: translateY(-2px);
109
+ }
110
+
111
+ .preview-area {
112
+ text-align: center;
113
+ margin-top: 20px;
114
+ }
115
+
116
+ .preview-image {
117
+ position: relative;
118
+ display: inline-block;
119
+ margin-bottom: 20px;
120
+ }
121
+
122
+ .preview-image img {
123
+ max-width: 100%;
124
+ max-height: 400px;
125
+ border-radius: 10px;
126
+ box-shadow: 0 5px 15px rgba(0,0,0,0.2);
127
+ }
128
+
129
+ .btn-remove {
130
+ position: absolute;
131
+ top: -10px;
132
+ right: -10px;
133
+ width: 30px;
134
+ height: 30px;
135
+ border-radius: 50%;
136
+ background: #f56565;
137
+ color: white;
138
+ border: none;
139
+ cursor: pointer;
140
+ font-size: 18px;
141
+ transition: all 0.3s ease;
142
+ }
143
+
144
+ .btn-remove:hover {
145
+ background: #e53e3e;
146
+ transform: scale(1.1);
147
+ }
148
+
149
+ .results-area {
150
+ margin-top: 30px;
151
+ }
152
+
153
+ .results-area h3 {
154
+ text-align: center;
155
+ color: #333;
156
+ margin-bottom: 20px;
157
+ }
158
+
159
+ .result-card {
160
+ background: #f7fafc;
161
+ border-radius: 10px;
162
+ padding: 20px;
163
+ text-align: center;
164
+ }
165
+
166
+ .result-class {
167
+ font-size: 24px;
168
+ font-weight: bold;
169
+ color: #667eea;
170
+ margin-bottom: 10px;
171
+ }
172
+
173
+ .result-confidence {
174
+ font-size: 18px;
175
+ color: #48bb78;
176
+ margin-bottom: 20px;
177
+ }
178
+
179
+ .probabilities {
180
+ text-align: left;
181
+ margin-top: 20px;
182
+ }
183
+
184
+ .probabilities h4 {
185
+ margin-bottom: 10px;
186
+ color: #4a5568;
187
+ font-size: 14px;
188
+ }
189
+
190
+ .prob-item {
191
+ margin-bottom: 10px;
192
+ }
193
+
194
+ .prob-label {
195
+ font-size: 14px;
196
+ color: #666;
197
+ margin-bottom: 5px;
198
+ }
199
+
200
+ .prob-bar {
201
+ background: #e2e8f0;
202
+ height: 30px;
203
+ border-radius: 5px;
204
+ overflow: hidden;
205
+ position: relative;
206
+ }
207
+
208
+ .prob-fill {
209
+ background: linear-gradient(90deg, #667eea, #764ba2);
210
+ height: 100%;
211
+ display: flex;
212
+ align-items: center;
213
+ justify-content: flex-end;
214
+ padding-right: 10px;
215
+ color: white;
216
+ font-size: 12px;
217
+ font-weight: bold;
218
+ transition: width 0.5s ease;
219
+ }
220
+
221
+ .loading {
222
+ text-align: center;
223
+ padding: 40px;
224
+ }
225
+
226
+ .spinner {
227
+ border: 4px solid #f3f3f3;
228
+ border-top: 4px solid #667eea;
229
+ border-radius: 50%;
230
+ width: 50px;
231
+ height: 50px;
232
+ animation: spin 1s linear infinite;
233
+ margin: 0 auto 20px;
234
+ }
235
+
236
+ @keyframes spin {
237
+ 0% { transform: rotate(0deg); }
238
+ 100% { transform: rotate(360deg); }
239
+ }
240
+
241
+ .error-message {
242
+ background: #fed7d7;
243
+ color: #c53030;
244
+ padding: 15px;
245
+ border-radius: 10px;
246
+ margin-top: 20px;
247
+ text-align: center;
248
+ }
249
+
250
+ @media (max-width: 768px) {
251
+ .container {
252
+ padding: 20px;
253
+ }
254
+
255
+ .upload-area {
256
+ padding: 20px;
257
+ }
258
+
259
+ .result-class {
260
+ font-size: 20px;
261
+ }
262
+ }
app/utils/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ """Вспомогательные утилиты"""
2
+
3
+ from app.utils.image_utils import validate_image_file, process_image
4
+ from app.utils.logger import setup_logging
5
+
6
+ __all__ = ['validate_image_file', 'process_image', 'setup_logging']
app/utils/image_utils.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Утилиты для работы с изображениями
3
+ """
4
+
5
+ import os
6
+ import io
7
+ import logging
8
+ from fastapi import UploadFile, HTTPException, status
9
+ from PIL import Image, ImageFile
10
+
11
+ # Разрешаем загрузку усеченных изображений
12
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ def validate_image_file(file: UploadFile, settings) -> bool:
18
+ """
19
+ Проверка валидности файла изображения
20
+
21
+ Args:
22
+ file: Загруженный файл
23
+ settings: Настройки приложения
24
+
25
+ Returns:
26
+ True если файл валиден
27
+
28
+ Raises:
29
+ HTTPException при ошибках валидации
30
+ """
31
+ # Проверяем расширение
32
+ ext = os.path.splitext(file.filename)[1].lower()
33
+ if ext not in settings.ALLOWED_EXTENSIONS:
34
+ raise HTTPException(
35
+ status_code=status.HTTP_400_BAD_REQUEST,
36
+ detail={
37
+ "code": "INVALID_EXTENSION",
38
+ "message": f"Неподдерживаемый формат. Разрешенные: {', '.join(settings.ALLOWED_EXTENSIONS)}"
39
+ }
40
+ )
41
+
42
+ # Проверяем размер
43
+ file.file.seek(0, 2)
44
+ size = file.file.tell()
45
+ file.file.seek(0)
46
+
47
+ if size > settings.MAX_UPLOAD_SIZE:
48
+ raise HTTPException(
49
+ status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
50
+ detail={
51
+ "code": "FILE_TOO_LARGE",
52
+ "message": f"Файл слишком большой. Максимум: {settings.MAX_UPLOAD_SIZE // (1024*1024)} MB"
53
+ }
54
+ )
55
+
56
+ return True
57
+
58
+
59
+ def process_image(file: UploadFile) -> Image.Image:
60
+ """
61
+ Загрузка и предобработка изображения
62
+
63
+ Args:
64
+ file: Загруженный файл
65
+
66
+ Returns:
67
+ PIL Image объект
68
+
69
+ Raises:
70
+ HTTPException при ошибках загрузки
71
+ """
72
+ try:
73
+ # Читаем файл
74
+ contents = file.file.read()
75
+
76
+ # Пытаемся открыть как изображение
77
+ image = Image.open(io.BytesIO(contents))
78
+
79
+ # Проверяем, что изображение можно прочитать
80
+ image.verify()
81
+
82
+ # Переоткрываем (после verify нужно заново)
83
+ image = Image.open(io.BytesIO(contents))
84
+
85
+ return image
86
+
87
+ except Exception as e:
88
+ logger.error(f"Ошибка загрузки изображения: {e}")
89
+ raise HTTPException(
90
+ status_code=status.HTTP_400_BAD_REQUEST,
91
+ detail={
92
+ "code": "INVALID_IMAGE",
93
+ "message": "Файл не является корректным изображением"
94
+ }
95
+ )
app/utils/logger.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Настройка логирования
3
+ """
4
+
5
+ import logging
6
+ import sys
7
+ from typing import Optional
8
+
9
+
10
+ def setup_logging(level: Optional[str] = None):
11
+ """
12
+ Настройка логирования для приложения
13
+
14
+ Args:
15
+ level: Уровень логирования (DEBUG, INFO, WARNING, ERROR)
16
+ """
17
+ log_level = level or logging.INFO
18
+
19
+ # Настройка формата
20
+ formatter = logging.Formatter(
21
+ '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
22
+ datefmt='%Y-%m-%d %H:%M:%S'
23
+ )
24
+
25
+ # Обработчик для stdout
26
+ console_handler = logging.StreamHandler(sys.stdout)
27
+ console_handler.setFormatter(formatter)
28
+
29
+ # Настройка корневого логгера
30
+ root_logger = logging.getLogger()
31
+ root_logger.setLevel(log_level)
32
+ root_logger.addHandler(console_handler)
33
+
34
+ # Уменьшаем логи от библиотек
35
+ logging.getLogger("urllib3").setLevel(logging.WARNING)
36
+ logging.getLogger("torch").setLevel(logging.WARNING)
37
+
38
+ return root_logger
docker-compose.yml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ version: '3.8'
2
+
3
+ services:
4
+ api:
5
+ build: .
6
+ ports:
7
+ - "8000:8000"
8
+ volumes:
9
+ - ./models:/app/models
10
+ - ./uploads:/app/uploads
11
+ - ./app:/app/app
12
+ - ./logs:/app/logs
13
+ environment:
14
+ - MODEL_PATH=/app/models/best_model.pth
15
+ - LOG_LEVEL=INFO
16
+ restart: unless-stopped
17
+ healthcheck:
18
+ test: ["CMD", "curl", "-f", "http://localhost:8000/api/v1/health"]
19
+ interval: 30s
20
+ timeout: 10s
21
+ retries: 3
22
+ start_period: 40s
models/best_model.pth ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ tensorflow
2
+ opencv-python
3
+ matplotlib
4
+ numpy
5
+ scikit-learn
6
+ torch
7
+ torchvision
8
+ seaborn
9
+ tqdm
10
+ Pillow
11
+ pandas
12
+ patool
13
+ unrar
14
+ winrar
15
+ fastapi
16
+ uvicorn[standard]
17
+ python-multipart
18
+ python-dotenv
19
+ pydantic
20
+ pydantic-settings
21
+ aiofiles
22
+ python-jose[cryptography]
23
+ python-magic
24
+ pytest
25
+ pytest-asyncio
26
+ httpx
27
+
tests/__init__.py ADDED
File without changes
tests/test_api.py ADDED
File without changes
tests/test_model.py ADDED
File without changes
train_model.py ADDED
@@ -0,0 +1,1010 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torch.utils.data import DataLoader, Dataset
5
+ from torchvision import models, transforms
6
+ import torch.cuda.amp as amp
7
+ import os
8
+ import shutil
9
+ import numpy as np
10
+ from PIL import Image, ImageFile
11
+ import matplotlib.pyplot as plt
12
+ from tqdm import tqdm
13
+ from sklearn.model_selection import train_test_split
14
+ from sklearn.metrics import confusion_matrix, classification_report
15
+ import seaborn as sns
16
+ import warnings
17
+ warnings.filterwarnings('ignore')
18
+
19
+ # ================================================
20
+ # ВКЛЮЧАЕМ ОБРАБОТКУ УСЕЧЕННЫХ ИЗОБРАЖЕНИЙ
21
+ # ================================================
22
+ ImageFile.LOAD_TRUNCATED_IMAGES = True # Разрешаем загрузку усеченных изображений
23
+
24
+ # ================================================
25
+ # КОНФИГУРАЦИЯ (С ПРАВИЛЬНЫМИ НАЗВАНИЯМИ КЛАССОВ)
26
+ # ================================================
27
+ class Config:
28
+ # Пути (изменены для Windows)
29
+ BASE_DIR = os.path.join(os.getcwd(), 'wagon_classification')
30
+ DATA_DIR = os.path.join(os.getcwd(), 'wagon_classification', 'data', 'processed')
31
+ EXTRACTED_DIR = os.path.join(os.getcwd(), 'wagon_data', 'extracted')
32
+ MODEL_SAVE_PATH = os.path.join(os.getcwd(), 'wagon_classification', 'best_model.pth')
33
+
34
+ # Параметры - ИСПРАВЛЕНО: pered вместо prered
35
+ CLASS_NAMES = ['pered', 'zad', 'none'] # Изменено prered -> pered
36
+ NUM_CLASSES = 3
37
+
38
+ # Гиперпараметры
39
+ BATCH_SIZE = 32 # Оптимальный размер для T4
40
+ NUM_EPOCHS = 15
41
+
42
+ # Устройство
43
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
44
+
45
+ @staticmethod
46
+ def print_info():
47
+ print(f"\n📊 КОНФИГУРАЦИЯ:")
48
+ print(f" • Устройство: {Config.DEVICE}")
49
+ if Config.DEVICE.type == 'cuda':
50
+ print(f" • GPU: {torch.cuda.get_device_name(0)}")
51
+ print(f" • Память: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
52
+ print(f" • Классы: {Config.CLASS_NAMES}")
53
+ print(f" • Batch size: {Config.BATCH_SIZE}")
54
+ print(f" • Эпох: {Config.NUM_EPOCHS}")
55
+
56
+ # ================================================
57
+ # УТИЛИТЫ ДЛЯ РАБОТЫ С ИЗОБРАЖЕНИЯМИ
58
+ # ================================================
59
+ def load_image_safe(image_path, target_size=(224, 224)):
60
+ """Безопасная загрузка изображения с обработкой ошибок"""
61
+ try:
62
+ # Пытаемся открыть изображение
63
+ image = Image.open(image_path)
64
+
65
+ # Проверяем, что изображение валидно
66
+ image.verify() # Проверка целостности файла
67
+
68
+ # Закрываем и открываем снова (после verify нужно переоткрыть)
69
+ image = Image.open(image_path)
70
+
71
+ # Конвертируем в RGB если нужно
72
+ if image.mode != 'RGB':
73
+ image = image.convert('RGB')
74
+
75
+ # Проверяем размеры
76
+ if image.size[0] == 0 or image.size[1] == 0:
77
+ print(f"⚠ Изображение {image_path} имеет нулевые размеры")
78
+ # Создаем черное изображение
79
+ image = Image.new('RGB', target_size, color='black')
80
+
81
+ return image
82
+
83
+ except (IOError, OSError, Image.DecompressionBombError) as e:
84
+ print(f"⚠ Ошибка загрузки {image_path}: {e}")
85
+ # Создаем черное изображение в случае ошибки
86
+ return Image.new('RGB', target_size, color='black')
87
+
88
+ except Exception as e:
89
+ print(f"⚠ Неизвестная ошибка при загрузке {image_path}: {e}")
90
+ return Image.new('RGB', target_size, color='black')
91
+
92
+ def repair_image_file(image_path):
93
+ """Пытается восстановить поврежденный файл изображения"""
94
+ try:
95
+ with open(image_path, 'rb') as f:
96
+ data = f.read()
97
+
98
+ # Проверяем, что файл не пустой
99
+ if len(data) == 0:
100
+ print(f"❌ Файл {image_path} пустой")
101
+ return False
102
+
103
+ # Пытаемся восстановить как JPEG
104
+ if image_path.lower().endswith('.jpg') or image_path.lower().endswith('.jpeg'):
105
+ # Добавляем маркер конца JPEG если нужно
106
+ if not data.endswith(b'\xff\xd9'):
107
+ print(f"⚠ Восстанавливаю JPEG файл {image_path}")
108
+ data += b'\xff\xd9'
109
+ with open(image_path, 'wb') as f:
110
+ f.write(data)
111
+ return True
112
+
113
+ return False
114
+
115
+ except Exception as e:
116
+ print(f"❌ Ошибка при восстановлении {image_path}: {e}")
117
+ return False
118
+
119
+ # ================================================
120
+ # ТРАНСФОРМАЦИИ (ПРОСТЫЕ И РАБОЧИЕ)
121
+ # ================================================
122
+ def get_transforms():
123
+ """Создание простых и рабочих трансформаций"""
124
+ # Обучающие трансформации
125
+ train_transform = transforms.Compose([
126
+ transforms.Resize((256, 256)),
127
+ transforms.RandomCrop(224),
128
+ transforms.RandomHorizontalFlip(p=0.5),
129
+ transforms.ColorJitter(brightness=0.2, contrast=0.2),
130
+ transforms.ToTensor(),
131
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
132
+ std=[0.229, 0.224, 0.225])
133
+ ])
134
+
135
+ # Валидационные трансформации
136
+ val_transform = transforms.Compose([
137
+ transforms.Resize((224, 224)),
138
+ transforms.ToTensor(),
139
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
140
+ std=[0.229, 0.224, 0.225])
141
+ ])
142
+
143
+ return train_transform, val_transform
144
+
145
+ # ================================================
146
+ # ПОДГОТОВКА ДАННЫХ (С ПРАВИЛЬНЫМИ ИМЕНАМИ ПАПОК)
147
+ # ================================================
148
+ def prepare_data_simple():
149
+ """Упрощенная подготовка данных с правильными именами папок"""
150
+ print("=" * 60)
151
+ print("📊 ПОДГОТОВКА ДАННЫХ")
152
+ print("=" * 60)
153
+
154
+ # Создаем папки
155
+ os.makedirs(Config.BASE_DIR, exist_ok=True)
156
+ os.makedirs(Config.EXTRACTED_DIR, exist_ok=True)
157
+ os.makedirs(Config.DATA_DIR, exist_ok=True)
158
+
159
+ # Для Windows - ручной ввод пути к архиву
160
+ print("\n📤 ШАГ 1: Укажите путь к архиву vagon1.rar")
161
+ print("Пример: C:/Users/Username/Downloads/vagon1.rar")
162
+ archive_path = input("Введите полный путь к архиву: ").strip()
163
+
164
+ # Заменяем прямые слеши на обратные для Windows
165
+ archive_path = archive_path.replace('/', '\\')
166
+
167
+ if not os.path.exists(archive_path):
168
+ print(f"\n❌ Файл не найден: {archive_path}")
169
+ return False
170
+
171
+ print(f"\n✅ Найден архив: {os.path.basename(archive_path)}")
172
+
173
+ # Распаковываем с использованием patool (кроссплатформенный)
174
+ try:
175
+ import patoolib
176
+ print("📦 Распаковка архива...")
177
+ patoolib.extract_archive(archive_path, outdir=Config.EXTRACTED_DIR)
178
+ print("✅ Архив распакован")
179
+ except ImportError:
180
+ print("⚠ Установите библиотеку patool: pip install patool")
181
+ print("Или распакуйте архив вручную в папку:", Config.EXTRACTED_DIR)
182
+ return False
183
+ except Exception as e:
184
+ print(f"⚠ Ошибка при распаковке: {e}")
185
+ print("Попробуйте распаковать вручную в папку:", Config.EXTRACTED_DIR)
186
+ return False
187
+
188
+ # Проверяем структуру - ИЩЕМ ПРАВИЛЬНЫЕ ИМЕНА ПАПОК
189
+ print("\n🔍 Проверка данных...")
190
+
191
+ # Список возможных имен папок (учитываем опечатки)
192
+ possible_folders = {
193
+ 'pered': ['pered', 'prered', 'peredn', 'peredniy', 'front', 'перед'],
194
+ 'zad': ['zad', 'zadn', 'zadniy', 'back', 'rear', 'зад'],
195
+ 'none': ['none', 'non', 'empty', 'нет', 'пусто']
196
+ }
197
+
198
+ actual_folders = os.listdir(Config.EXTRACTED_DIR)
199
+ print(f"Найдены папки в extracted: {actual_folders}")
200
+
201
+ # Сопоставляем фактические папки с нашими классами
202
+ folder_mapping = {}
203
+ for target_class, possible_names in possible_folders.items():
204
+ for folder in actual_folders:
205
+ folder_lower = folder.lower()
206
+ if folder_lower in possible_names:
207
+ folder_mapping[target_class] = folder
208
+ print(f" ✓ {target_class} → {folder}")
209
+ break
210
+
211
+ # Если не нашли все классы, пытаемся найти по содержимому
212
+ if len(folder_mapping) < len(Config.CLASS_NAMES):
213
+ print("\n⚠ Не все классы найдены. Ищем изображения...")
214
+ for folder in actual_folders:
215
+ folder_path = os.path.join(Config.EXTRACTED_DIR, folder)
216
+ if os.path.isdir(folder_path):
217
+ images = [f for f in os.listdir(folder_path)
218
+ if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
219
+ if images:
220
+ print(f" Папка '{folder}': {len(images)} изображений")
221
+ # Пробуем угадать класс
222
+ if 'pered' in folder.lower() or 'перед' in folder.lower():
223
+ folder_mapping['pered'] = folder
224
+ elif 'zad' in folder.lower() or 'зад' in folder.lower():
225
+ folder_mapping['zad'] = folder
226
+ elif 'none' in folder.lower() or 'нет' in folder.lower():
227
+ folder_mapping['none'] = folder
228
+
229
+ # Проверяем, что нашли все необходимые классы
230
+ missing_classes = []
231
+ for cls in Config.CLASS_NAMES:
232
+ if cls not in folder_mapping:
233
+ missing_classes.append(cls)
234
+
235
+ if missing_classes:
236
+ print(f"\n❌ Отсутствуют классы: {missing_classes}")
237
+ print("Пожалуйста, убедитесь что в архиве есть папки с именами:")
238
+ print(" - 'pered' (или похожее) - для передней части вагона")
239
+ print(" - 'zad' (или похожее) - для задней части вагона")
240
+ print(" - 'none' (или похожее) - для отсутствия вагона")
241
+ return False
242
+
243
+ print("\n✅ Все классы найдены!")
244
+
245
+ # Создаем структуру
246
+ print("\n📁 Создание структуры train/val...")
247
+ for split in ['train', 'val']:
248
+ for cls in Config.CLASS_NAMES:
249
+ os.makedirs(os.path.join(Config.DATA_DIR, split, cls), exist_ok=True)
250
+
251
+ # Распределяем данные
252
+ print("📊 Разделение на train/val (80/20)...")
253
+ total_images = 0
254
+
255
+ for target_class, source_folder in folder_mapping.items():
256
+ source_dir = os.path.join(Config.EXTRACTED_DIR, source_folder)
257
+
258
+ if not os.path.exists(source_dir):
259
+ print(f"⚠ Папка {source_folder} не найдена, пропускаем")
260
+ continue
261
+
262
+ images = [f for f in os.listdir(source_dir)
263
+ if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
264
+
265
+ if not images:
266
+ print(f"⚠ В папке {source_folder} нет изображений")
267
+ continue
268
+
269
+ print(f"\n📂 Обрабатываем {source_folder} → {target_class}:")
270
+ print(f" Найдено {len(images)} изображений")
271
+
272
+ # Разделяем
273
+ train_imgs, val_imgs = train_test_split(
274
+ images, test_size=0.2, random_state=42
275
+ )
276
+
277
+ # Копируем train
278
+ print(f" Копируем {len(train_imgs)} в train...")
279
+ for img in tqdm(train_imgs, desc=f" {target_class} train"):
280
+ src = os.path.join(source_dir, img)
281
+ dst = os.path.join(Config.DATA_DIR, 'train', target_class, img)
282
+ shutil.copy2(src, dst)
283
+
284
+ # Копируем val
285
+ print(f" Копируем {len(val_imgs)} в val...")
286
+ for img in tqdm(val_imgs, desc=f" {target_class} val"):
287
+ src = os.path.join(source_dir, img)
288
+ dst = os.path.join(Config.DATA_DIR, 'val', target_class, img)
289
+ shutil.copy2(src, dst)
290
+
291
+ total_images += len(images)
292
+ print(f" ✓ {target_class}: {len(train_imgs)} train, {len(val_imgs)} val")
293
+
294
+ # Проверяем финальную структуру
295
+ print(f"\n✅ Готово! Всего {total_images} изображений")
296
+ print("\n📂 Финальная структура данных:")
297
+
298
+ for split in ['train', 'val']:
299
+ split_total = 0
300
+ print(f"\n {split.upper()}:")
301
+ for cls in Config.CLASS_NAMES:
302
+ cls_dir = os.path.join(Config.DATA_DIR, split, cls)
303
+ if os.path.exists(cls_dir):
304
+ count = len([f for f in os.listdir(cls_dir)
305
+ if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
306
+ print(f" {cls}: {count} изображений")
307
+ split_total += count
308
+ print(f" Всего: {split_total}")
309
+
310
+ return True
311
+
312
+ # ================================================
313
+ # ДАТАСЕТ С ОБРАБОТКОЙ ПОВРЕЖДЕННЫХ ИЗОБРАЖЕНИЙ
314
+ # ================================================
315
+ class RobustWagonDataset(Dataset):
316
+ """Надежный датасет с обработкой поврежденных изображений"""
317
+ def __init__(self, data_dir, transform=None, mode='train'):
318
+ self.image_paths = []
319
+ self.labels = []
320
+ self.transform = transform
321
+
322
+ data_path = os.path.join(data_dir, mode)
323
+
324
+ for class_idx, class_name in enumerate(Config.CLASS_NAMES):
325
+ class_dir = os.path.join(data_path, class_name)
326
+ if not os.path.exists(class_dir):
327
+ print(f"⚠ Папка {class_dir} не найдена!")
328
+ continue
329
+
330
+ images = [f for f in os.listdir(class_dir)
331
+ if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
332
+
333
+ for img in images:
334
+ self.image_paths.append(os.path.join(class_dir, img))
335
+ self.labels.append(class_idx)
336
+
337
+ print(f"✅ {mode.upper()}: загружено {len(self.image_paths)} изображений")
338
+
339
+ def __len__(self):
340
+ return len(self.image_paths)
341
+
342
+ def __getitem__(self, idx):
343
+ # Загружаем изображение с обработкой ошибок
344
+ img_path = self.image_paths[idx]
345
+
346
+ # Пытаемся загрузить изображение
347
+ image = load_image_safe(img_path)
348
+
349
+ # Применяем трансформации
350
+ if self.transform:
351
+ image = self.transform(image)
352
+
353
+ return image, self.labels[idx]
354
+
355
+ # ================================================
356
+ # МОДЕЛЬ
357
+ # ================================================
358
+ def create_simple_model():
359
+ """Создание простой модели"""
360
+ # Используем EfficientNet-B2 как компромисс между скоростью и точностью
361
+ model = models.efficientnet_b2(weights='DEFAULT')
362
+
363
+ # Заменяем классификатор
364
+ in_features = model.classifier[1].in_features
365
+ model.classifier = nn.Sequential(
366
+ nn.Dropout(p=0.3),
367
+ nn.Linear(in_features, Config.NUM_CLASSES)
368
+ )
369
+
370
+ return model.to(Config.DEVICE)
371
+
372
+ # ================================================
373
+ # ОБУЧЕНИЕ (РАБОЧАЯ ВЕРСИЯ)
374
+ # ================================================
375
+ def train_simple_model():
376
+ """Простая и рабочая функция обучения"""
377
+ print("\n" + "="*60)
378
+ print("🏋️‍♂️ НАЧИНАЕМ ОБУЧЕНИЕ")
379
+ print("=" * 60)
380
+
381
+ # Очищаем кэш GPU
382
+ if torch.cuda.is_available():
383
+ torch.cuda.empty_cache()
384
+
385
+ Config.print_info()
386
+
387
+ # Получаем трансформации
388
+ train_transform, val_transform = get_transforms()
389
+
390
+ # Создаем датасеты
391
+ print("\n📥 Загрузка данных...")
392
+ train_dataset = RobustWagonDataset(
393
+ Config.DATA_DIR,
394
+ transform=train_transform,
395
+ mode='train'
396
+ )
397
+
398
+ val_dataset = RobustWagonDataset(
399
+ Config.DATA_DIR,
400
+ transform=val_transform,
401
+ mode='val'
402
+ )
403
+
404
+ if len(train_dataset) == 0:
405
+ print("❌ Обучающие данные не найдены!")
406
+ return None, None
407
+
408
+ # Создаем DataLoader
409
+ train_loader = DataLoader(
410
+ train_dataset,
411
+ batch_size=Config.BATCH_SIZE,
412
+ shuffle=True,
413
+ num_workers=0, # 0 для Windows чтобы избежать проблем
414
+ pin_memory=True if torch.cuda.is_available() else False
415
+ )
416
+
417
+ val_loader = DataLoader(
418
+ val_dataset,
419
+ batch_size=Config.BATCH_SIZE,
420
+ shuffle=False,
421
+ num_workers=0, # 0 для Windows чтобы избежать проблем
422
+ pin_memory=True if torch.cuda.is_available() else False
423
+ )
424
+
425
+ # Создаем модель
426
+ print("\n🧠 Создание модели...")
427
+ model = create_simple_model()
428
+
429
+ # Функция потерь и оптимизатор
430
+ criterion = nn.CrossEntropyLoss()
431
+ optimizer = optim.Adam(model.parameters(), lr=1e-4)
432
+ scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
433
+
434
+ # История обучения
435
+ history = {
436
+ 'train_loss': [], 'train_acc': [],
437
+ 'val_loss': [], 'val_acc': []
438
+ }
439
+
440
+ best_val_acc = 0.0
441
+
442
+ print("\n" + "="*50)
443
+ print("🏁 НАЧАЛО ОБУЧЕНИЯ")
444
+ print("="*50)
445
+
446
+ for epoch in range(Config.NUM_EPOCHS):
447
+ print(f"\n📅 ЭПОХА {epoch + 1}/{Config.NUM_EPOCHS}")
448
+
449
+ # ===== ОБУЧЕНИЕ =====
450
+ model.train()
451
+ train_loss = 0.0
452
+ train_correct = 0
453
+ train_total = 0
454
+
455
+ train_bar = tqdm(train_loader, desc='Training')
456
+ for images, labels in train_bar:
457
+ # Перемещаем данные на GPU
458
+ images = images.to(Config.DEVICE)
459
+ labels = labels.to(Config.DEVICE)
460
+
461
+ # Forward pass
462
+ optimizer.zero_grad()
463
+ outputs = model(images)
464
+ loss = criterion(outputs, labels)
465
+
466
+ # Backward pass
467
+ loss.backward()
468
+ optimizer.step()
469
+
470
+ # Статистика
471
+ train_loss += loss.item()
472
+ _, predicted = outputs.max(1)
473
+ train_total += labels.size(0)
474
+ train_correct += predicted.eq(labels).sum().item()
475
+
476
+ # Обновляем прогресс-бар
477
+ train_bar.set_postfix({
478
+ 'Loss': f'{loss.item():.4f}',
479
+ 'Acc': f'{100.*train_correct/train_total:.1f}%'
480
+ })
481
+
482
+ # Средние значения за эпоху
483
+ avg_train_loss = train_loss / len(train_loader)
484
+ train_accuracy = train_correct / train_total
485
+
486
+ # ===== ВАЛИДАЦИЯ =====
487
+ model.eval()
488
+ val_loss = 0.0
489
+ val_correct = 0
490
+ val_total = 0
491
+ all_preds = []
492
+ all_labels = []
493
+
494
+ with torch.no_grad():
495
+ val_bar = tqdm(val_loader, desc='Validation')
496
+ for images, labels in val_bar:
497
+ images = images.to(Config.DEVICE)
498
+ labels = labels.to(Config.DEVICE)
499
+
500
+ outputs = model(images)
501
+ loss = criterion(outputs, labels)
502
+
503
+ val_loss += loss.item()
504
+ _, predicted = outputs.max(1)
505
+ val_total += labels.size(0)
506
+ val_correct += predicted.eq(labels).sum().item()
507
+
508
+ all_preds.extend(predicted.cpu().numpy())
509
+ all_labels.extend(labels.cpu().numpy())
510
+
511
+ avg_val_loss = val_loss / len(val_loader)
512
+ val_accuracy = val_correct / val_total
513
+
514
+ # Обновляем scheduler
515
+ scheduler.step()
516
+
517
+ # Сохраняем историю
518
+ history['train_loss'].append(avg_train_loss)
519
+ history['train_acc'].append(train_accuracy)
520
+ history['val_loss'].append(avg_val_loss)
521
+ history['val_acc'].append(val_accuracy)
522
+
523
+ # Сохраняем лучшую модель
524
+ if val_accuracy > best_val_acc:
525
+ best_val_acc = val_accuracy
526
+ torch.save({
527
+ 'epoch': epoch,
528
+ 'model_state_dict': model.state_dict(),
529
+ 'optimizer_state_dict': optimizer.state_dict(),
530
+ 'val_acc': val_accuracy,
531
+ 'train_acc': train_accuracy,
532
+ 'class_names': Config.CLASS_NAMES
533
+ }, Config.MODEL_SAVE_PATH)
534
+ print(f"💾 Сохранена лучшая модель! Точность: {val_accuracy:.4f}")
535
+
536
+ # Выводим статистику
537
+ print(f"📊 Результаты эпохи {epoch + 1}:")
538
+ print(f" Train Loss: {avg_train_loss:.4f}, Acc: {train_accuracy:.4f}")
539
+ print(f" Val Loss: {avg_val_loss:.4f}, Acc: {val_accuracy:.4f}")
540
+ print(f" LR: {scheduler.get_last_lr()[0]:.2e}")
541
+
542
+ # ===== ВИЗУАЛИЗАЦИЯ РЕЗУЛЬТАТОВ =====
543
+ print("\n📈 Визуализация результатов...")
544
+
545
+ fig, axes = plt.subplots(2, 2, figsize=(12, 10))
546
+
547
+ # График потерь
548
+ axes[0, 0].plot(history['train_loss'], label='Train', marker='o', linewidth=2)
549
+ axes[0, 0].plot(history['val_loss'], label='Val', marker='s', linewidth=2)
550
+ axes[0, 0].set_title('Loss History', fontsize=14, fontweight='bold')
551
+ axes[0, 0].set_xlabel('Epoch')
552
+ axes[0, 0].set_ylabel('Loss')
553
+ axes[0, 0].legend()
554
+ axes[0, 0].grid(True, alpha=0.3)
555
+
556
+ # График точности
557
+ axes[0, 1].plot(history['train_acc'], label='Train', marker='o', linewidth=2)
558
+ axes[0, 1].plot(history['val_acc'], label='Val', marker='s', linewidth=2)
559
+ axes[0, 1].set_title('Accuracy History', fontsize=14, fontweight='bold')
560
+ axes[0, 1].set_xlabel('Epoch')
561
+ axes[0, 1].set_ylabel('Accuracy')
562
+ axes[0, 1].legend()
563
+ axes[0, 1].grid(True, alpha=0.3)
564
+
565
+ # Confusion Matrix
566
+ try:
567
+ cm = confusion_matrix(all_labels, all_preds)
568
+ sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
569
+ xticklabels=Config.CLASS_NAMES,
570
+ yticklabels=Config.CLASS_NAMES,
571
+ ax=axes[1, 0])
572
+ axes[1, 0].set_title('Confusion Matrix', fontsize=14, fontweight='bold')
573
+ axes[1, 0].set_xlabel('Predicted')
574
+ axes[1, 0].set_ylabel('True')
575
+ except:
576
+ axes[1, 0].text(0.5, 0.5, 'Confusion Matrix\nне доступна',
577
+ ha='center', va='center', fontsize=12)
578
+ axes[1, 0].set_title('Confusion Matrix', fontsize=14, fontweight='bold')
579
+ axes[1, 0].axis('off')
580
+
581
+ # Classification Report
582
+ try:
583
+ report = classification_report(all_labels, all_preds,
584
+ target_names=Config.CLASS_NAMES)
585
+ axes[1, 1].text(0, 1, report, fontsize=10, fontfamily='monospace',
586
+ verticalalignment='top', transform=axes[1, 1].transAxes)
587
+ except:
588
+ axes[1, 1].text(0.5, 0.5, 'Classification Report\nне доступен',
589
+ ha='center', va='center', fontsize=12)
590
+
591
+ axes[1, 1].set_title('Classification Report', fontsize=14, fontweight='bold')
592
+ axes[1, 1].axis('off')
593
+
594
+ plt.tight_layout()
595
+ results_path = os.path.join(os.getcwd(), 'training_results.png')
596
+ plt.savefig(results_path, dpi=100, bbox_inches='tight')
597
+ plt.show()
598
+
599
+ # Выводим отчет по классификации
600
+ print("\n" + "="*60)
601
+ print("📋 ОТЧЕТ ПО КЛАССИФИКАЦИИ")
602
+ print("="*60)
603
+ try:
604
+ print(classification_report(all_labels, all_preds,
605
+ target_names=Config.CLASS_NAMES))
606
+ except:
607
+ print("Отчет не доступен")
608
+
609
+ print("\n" + "="*60)
610
+ print("🎉 ОБУЧЕНИЕ ЗАВЕРШЕНО!")
611
+ print("="*60)
612
+ print(f"🏆 Лучшая точность на валидации: {best_val_acc:.4f}")
613
+ print(f"💾 Модель сохранена: {Config.MODEL_SAVE_PATH}")
614
+ print("\n📋 Классы модели:")
615
+ for i, cls in enumerate(Config.CLASS_NAMES):
616
+ print(f" {i}: {cls}")
617
+
618
+ return model, history
619
+
620
+ # ================================================
621
+ # ПРЕДСКАЗАНИЕ С ОБРАБОТКОЙ ПОВРЕЖДЕННЫХ ИЗОБРАЖЕНИЙ
622
+ # ================================================
623
+ def predict_single_image():
624
+ """Предсказание для одного изображения с обработкой поврежденных файлов"""
625
+ if not os.path.exists(Config.MODEL_SAVE_PATH):
626
+ print("❌ Модель не обучена!")
627
+ return
628
+
629
+ print("\n📤 Введите путь к изображению для классификации...")
630
+ image_path = input("Введите полный путь к изображению: ").strip()
631
+ image_path = image_path.replace('/', '\\')
632
+
633
+ if not os.path.exists(image_path):
634
+ print(f"❌ Файл не найден: {image_path}")
635
+ return
636
+
637
+ print(f"✅ Изображение найдено: {os.path.basename(image_path)}")
638
+
639
+ # Пытаемся восстановить поврежденное изображение
640
+ print("🔧 Проверка целостности изображения...")
641
+ repair_success = repair_image_file(image_path)
642
+ if repair_success:
643
+ print("✅ Изображение восстановлено")
644
+
645
+ # Загружаем модель
646
+ model = create_simple_model()
647
+ checkpoint = torch.load(Config.MODEL_SAVE_PATH, map_location=Config.DEVICE)
648
+ model.load_state_dict(checkpoint['model_state_dict'])
649
+ model.eval()
650
+
651
+ # Получаем трансформации
652
+ _, val_transform = get_transforms()
653
+
654
+ # Загружаем и обрабатываем изображение
655
+ try:
656
+ # Используем безопасную загрузку
657
+ print("🖼️ Загрузка изображения...")
658
+ image = load_image_safe(image_path)
659
+
660
+ # Проверяем, что изображение загрузилось
661
+ if image is None:
662
+ print("❌ Не удалось загрузить изображение")
663
+ return None
664
+
665
+ print(f"✅ Изображение загружено. Размер: {image.size}")
666
+
667
+ # Применяем трансформации
668
+ print("🔄 Применение трансформаций...")
669
+ input_tensor = val_transform(image).unsqueeze(0).to(Config.DEVICE)
670
+
671
+ # Предсказание
672
+ print("🧠 Выполнение предсказания...")
673
+ with torch.no_grad():
674
+ outputs = model(input_tensor)
675
+ probabilities = torch.nn.functional.softmax(outputs, dim=1)
676
+ predicted_idx = torch.argmax(probabilities, dim=1).item()
677
+ confidence = probabilities[0][predicted_idx].item()
678
+
679
+ predicted_class = Config.CLASS_NAMES[predicted_idx]
680
+
681
+ # Выводим результат
682
+ print("\n" + "="*60)
683
+ print("🎯 РЕЗУЛЬТАТ КЛАССИФИКАЦИИ")
684
+ print("="*60)
685
+ print(f"📋 Класс: {predicted_class}")
686
+ print(f"📊 Уверенность: {confidence:.2%}")
687
+ print(f"\n📈 Распределение вероятностей:")
688
+ for i, cls in enumerate(Config.CLASS_NAMES):
689
+ prob = probabilities[0][i].item()
690
+ prob_percent = prob * 100
691
+ # Создаем прогресс-бар
692
+ bar_length = 20
693
+ filled_length = int(bar_length * prob)
694
+ bar = '█' * filled_length + '░' * (bar_length - filled_length)
695
+
696
+ # Подсвечиваем предсказанный класс
697
+ if i == predicted_idx:
698
+ print(f" ⭐ {cls}: {bar} {prob_percent:5.1f}% ({prob:.4f})")
699
+ else:
700
+ print(f" {cls}: {bar} {prob_percent:5.1f}% ({prob:.4f})")
701
+
702
+ # Визуализация
703
+ plt.figure(figsize=(14, 6))
704
+
705
+ # Изображение
706
+ plt.subplot(1, 3, 1)
707
+ plt.imshow(image)
708
+ plt.title(f"Входное изображение\n{os.path.basename(image_path)}", fontsize=12)
709
+ plt.axis('off')
710
+
711
+ # График вероятностей
712
+ plt.subplot(1, 3, 2)
713
+ colors = ['#FF6B6B', '#4ECDC4', '#45B7D1']
714
+ probs = probabilities[0].cpu().numpy()
715
+
716
+ bars = plt.bar(Config.CLASS_NAMES, probs, color=colors, alpha=0.7, edgecolor='black', linewidth=2)
717
+ bars[predicted_idx].set_alpha(1.0)
718
+ bars[predicted_idx].set_linewidth(3)
719
+ bars[predicted_idx].set_edgecolor('red')
720
+
721
+ plt.title(f"Предсказание: {predicted_class}\nУ��еренность: {confidence:.2%}",
722
+ fontsize=14, fontweight='bold')
723
+ plt.ylim([0, 1.1])
724
+ plt.ylabel('Вероятность', fontsize=12)
725
+ plt.grid(True, alpha=0.3, axis='y')
726
+
727
+ # Добавляем значения на столбцы
728
+ for i, (bar, prob) in enumerate(zip(bars, probs)):
729
+ plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,
730
+ f'{prob:.2%}', ha='center', va='bottom', fontsize=11,
731
+ fontweight='bold' if i == predicted_idx else 'normal',
732
+ color='red' if i == predicted_idx else 'black')
733
+
734
+ # Тепловая карта вероятностей
735
+ plt.subplot(1, 3, 3)
736
+ prob_matrix = probabilities.cpu().numpy().reshape(-1, 1)
737
+ plt.imshow(prob_matrix, cmap='RdYlGn', aspect='auto', vmin=0, vmax=1)
738
+ plt.colorbar(label='Вероятность')
739
+ plt.yticks(range(len(Config.CLASS_NAMES)), Config.CLASS_NAMES)
740
+ plt.xticks([])
741
+ plt.title('Тепловая карта вероятностей', fontsize=14, fontweight='bold')
742
+
743
+ # Добавляем значения в ячейки
744
+ for i, prob in enumerate(prob_matrix):
745
+ plt.text(0, i, f'{prob[0]:.3f}', ha='center', va='center',
746
+ color='white' if prob[0] > 0.5 else 'black',
747
+ fontweight='bold' if i == predicted_idx else 'normal')
748
+
749
+ plt.tight_layout()
750
+ plt.show()
751
+
752
+ # Дополнительная информация
753
+ print("\n📝 ИНТЕРПРЕТАЦИЯ РЕЗУЛЬТАТА:")
754
+ if predicted_class == 'pered':
755
+ print(" 🚂 Передняя часть вагона обнаружена")
756
+ elif predicted_class == 'zad':
757
+ print(" 🚂 Задняя часть вагона обнаружена")
758
+ elif predicted_class == 'none':
759
+ print(" ⭕ Вагон не обнаружен")
760
+
761
+ if confidence > 0.9:
762
+ print(" ✅ Высокая уверенность предсказания")
763
+ elif confidence > 0.7:
764
+ print(" ⚠ Средняя уверенность предсказания")
765
+ else:
766
+ print(" ❓ Низкая уверенность, возможно неоднозначное изображение")
767
+
768
+ return predicted_class, confidence
769
+
770
+ except Exception as e:
771
+ print(f"\n❌ Критическая ошибка при обработке изображения: {e}")
772
+ print("\n🔧 ВОЗМОЖНЫЕ РЕШЕНИЯ:")
773
+ print(" 1. Попробуйте загрузить другое изображение")
774
+ print(" 2. Убедитесь, что файл не поврежден")
775
+ print(" 3. Проверьте формат файла (должен быть JPG, PNG)")
776
+
777
+ import traceback
778
+ traceback.print_exc()
779
+ return None
780
+
781
+ # ================================================
782
+ # ПАКЕТНОЕ ТЕСТИРОВАНИЕ
783
+ # ================================================
784
+ def batch_test_images():
785
+ """Тестирование модели на нескольких изображениях"""
786
+ if not os.path.exists(Config.MODEL_SAVE_PATH):
787
+ print("❌ Модель не обучена!")
788
+ return
789
+
790
+ print("\n📤 Введите путь к папке с изображениями для тестирования...")
791
+ folder_path = input("Введите полный путь к папке: ").strip()
792
+ folder_path = folder_path.replace('/', '\\')
793
+
794
+ if not os.path.exists(folder_path):
795
+ print(f"❌ Папка не найдена: {folder_path}")
796
+ return
797
+
798
+ # Получаем список изображений
799
+ image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff')
800
+ image_files = [f for f in os.listdir(folder_path)
801
+ if f.lower().endswith(image_extensions)]
802
+
803
+ if not image_files:
804
+ print(f"❌ В папке нет изображений: {folder_path}")
805
+ return
806
+
807
+ print(f"✅ Найдено {len(image_files)} изображений")
808
+
809
+ # Загружаем модель
810
+ model = create_simple_model()
811
+ checkpoint = torch.load(Config.MODEL_SAVE_PATH, map_location=Config.DEVICE)
812
+ model.load_state_dict(checkpoint['model_state_dict'])
813
+ model.eval()
814
+
815
+ # Получаем трансформации
816
+ _, val_transform = get_transforms()
817
+
818
+ results = []
819
+
820
+ for image_name in image_files:
821
+ print(f"\n🔍 Обработка: {image_name}")
822
+
823
+ # Путь к изображению
824
+ image_path = os.path.join(folder_path, image_name)
825
+
826
+ try:
827
+ # Безопасная загрузка
828
+ image = load_image_safe(image_path)
829
+ if image is None:
830
+ print(f" ❌ Не удалось загрузить {image_name}")
831
+ continue
832
+
833
+ # Пред��казание
834
+ input_tensor = val_transform(image).unsqueeze(0).to(Config.DEVICE)
835
+
836
+ with torch.no_grad():
837
+ outputs = model(input_tensor)
838
+ probabilities = torch.nn.functional.softmax(outputs, dim=1)
839
+ predicted_idx = torch.argmax(probabilities, dim=1).item()
840
+ confidence = probabilities[0][predicted_idx].item()
841
+
842
+ predicted_class = Config.CLASS_NAMES[predicted_idx]
843
+ results.append((image_name, predicted_class, confidence))
844
+
845
+ print(f" ✅ {predicted_class} ({confidence:.2%})")
846
+
847
+ except Exception as e:
848
+ print(f" ❌ Ошибка: {e}")
849
+ results.append((image_name, "ERROR", 0.0))
850
+
851
+ # Выводим сводку
852
+ print("\n" + "="*60)
853
+ print("📊 СВОДКА ПО ТЕСТИРОВАНИЮ")
854
+ print("="*60)
855
+
856
+ if not results:
857
+ print("❌ Нет результатов")
858
+ return
859
+
860
+ # Группируем по классам
861
+ class_summary = {}
862
+ for _, cls, conf in results:
863
+ if cls not in class_summary:
864
+ class_summary[cls] = []
865
+ class_summary[cls].append(conf)
866
+
867
+ print("\n📈 Статистика по классам:")
868
+ for cls, confidences in class_summary.items():
869
+ if cls == "ERROR":
870
+ print(f" ❌ Ошибки: {len(confidences)} изображений")
871
+ else:
872
+ avg_conf = np.mean(confidences) if confidences else 0
873
+ print(f" {cls}: {len(confidences)} изображений, средняя уверенность: {avg_conf:.2%}")
874
+
875
+ # Подробные результаты
876
+ print("\n📋 Подробные результаты:")
877
+ for i, (img_name, cls, conf) in enumerate(results, 1):
878
+ if cls == "ERROR":
879
+ print(f" {i:2d}. ❌ {img_name}")
880
+ else:
881
+ print(f" {i:2d}. ✅ {img_name}: {cls} ({conf:.2%})")
882
+
883
+ return results
884
+
885
+ # ================================================
886
+ # ГЛАВНОЕ МЕНЮ
887
+ # ================================================
888
+ def main_menu():
889
+ """Главное меню"""
890
+ print("\n" + "="*60)
891
+ print("🚂 КЛАССИФИКАТОР ВАГОНОВ")
892
+ print("="*60)
893
+ print(f"📱 Устройство: {Config.DEVICE}")
894
+ if Config.DEVICE.type == 'cuda':
895
+ print(f"🎮 GPU: {torch.cuda.get_device_name(0)}")
896
+ print(f"💾 Память: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
897
+
898
+ while True:
899
+ print("\n" + "="*60)
900
+ print("ГЛАВНОЕ МЕНЮ:")
901
+ print("1. 📊 Подготовить данные")
902
+ print("2. 🏋️‍♂️ Обучить модель")
903
+ print("3. 🔍 Протестировать одно изображение")
904
+ print("4. 📦 Протестировать несколько изображений")
905
+ print("5. 📈 Показать графики")
906
+ print("6. 🧹 Очистить кэш")
907
+ print("0. ❌ Выход")
908
+ print("="*60)
909
+
910
+ choice = input("\nВыберите действие (0-6): ").strip()
911
+
912
+ if choice == '1':
913
+ # Подготовка данных
914
+ print("\n" + "="*60)
915
+ print("ПОДГОТОВКА ДАННЫХ")
916
+ print("="*60)
917
+
918
+ success = prepare_data_simple()
919
+ if success:
920
+ print("\n✅ Данные готовы к обучению!")
921
+ else:
922
+ print("\n❌ Ошибка при подготовке данных")
923
+
924
+ elif choice == '2':
925
+ # Обучение модели
926
+ if not os.path.exists(Config.DATA_DIR):
927
+ print("\n❌ Данные не подготовлены! Сначала выполните шаг 1.")
928
+ continue
929
+
930
+ try:
931
+ model, history = train_simple_model()
932
+ if model is not None:
933
+ print("\n✅ Обучение успешно завершено!")
934
+ except Exception as e:
935
+ print(f"\n❌ Ошибка при обучении: {e}")
936
+ import traceback
937
+ traceback.print_exc()
938
+
939
+ elif choice == '3':
940
+ # Тестирование одного изображения
941
+ if not os.path.exists(Config.MODEL_SAVE_PATH):
942
+ print("\n❌ Модель не обучена! Сначала выполните шаг 2.")
943
+ continue
944
+
945
+ try:
946
+ result = predict_single_image()
947
+ if result:
948
+ print("\n✅ Предсказание выполнено успешно!")
949
+ except Exception as e:
950
+ print(f"\n❌ Ошибка при предсказании: {e}")
951
+
952
+ elif choice == '4':
953
+ # Пакетное тестирование
954
+ if not os.path.exists(Config.MODEL_SAVE_PATH):
955
+ print("\n❌ Модель н�� обучена! Сначала выполните шаг 2.")
956
+ continue
957
+
958
+ try:
959
+ results = batch_test_images()
960
+ if results:
961
+ print("\n✅ Пакетное тестирование завершено!")
962
+ except Exception as e:
963
+ print(f"\n❌ Ошибка при пакетном тестировании: {e}")
964
+
965
+ elif choice == '5':
966
+ # Показать графики
967
+ results_path = os.path.join(os.getcwd(), 'training_results.png')
968
+ if os.path.exists(results_path):
969
+ print("\n📊 Графики обучения:")
970
+ try:
971
+ img = plt.imread(results_path)
972
+ plt.figure(figsize=(12, 8))
973
+ plt.imshow(img)
974
+ plt.axis('off')
975
+ plt.show()
976
+ except Exception as e:
977
+ print(f"❌ Ошибка при загрузке графиков: {e}")
978
+ else:
979
+ print("\n❌ Графики не найдены. Сначала обучите модель.")
980
+
981
+ elif choice == '6':
982
+ # Очистка кэша
983
+ if torch.cuda.is_available():
984
+ torch.cuda.empty_cache()
985
+ print("✅ Кэш GPU очищен!")
986
+ else:
987
+ print("⚠ GPU не доступна")
988
+
989
+ elif choice == '0':
990
+ print("\n👋 До свидания!")
991
+ if torch.cuda.is_available():
992
+ torch.cuda.empty_cache()
993
+ break
994
+
995
+ else:
996
+ print("\n❌ Неверный выбор. Пожалуйста, выберите от 0 до 6.")
997
+
998
+ # ================================================
999
+ # ЗАПУСК
1000
+ # ================================================
1001
+ if __name__ == "__main__":
1002
+ print("🚂 КЛАССИФИКАТОР ВАГОНОВ ДЛЯ WINDOWS 10")
1003
+ print("=" * 60)
1004
+ print("📦 Зависимости:")
1005
+ print("Установите библиотеки из файла requirements.txt:")
1006
+ print("pip install -r requirements.txt")
1007
+ print("=" * 60)
1008
+
1009
+ # Запускаем меню
1010
+ main_menu()