Spaces:
Running
Running
Upload 43 files
Browse files- .gitattributes +1 -0
- Dockerfile +26 -0
- Notebooks/EfficientNet_ConvNext_Fusion.ipynb +0 -0
- Notebooks/Resnet18_fine_tuning_final.ipynb +0 -0
- Notebooks/damage_detector_yolo.ipynb +0 -0
- README.md +699 -7
- app.py +236 -0
- assets/fusion_classification_report.png +0 -0
- assets/fusion_confusion_matrix.png +0 -0
- assets/fusion_training_curves.png +0 -0
- assets/resnet_classification_report.png +0 -0
- assets/resnet_confusion_matrix.png +0 -0
- assets/resnet_training_curves.png +0 -0
- assets/yolo_detection_sample.jpg +3 -0
- requirements.txt +15 -0
- scripts/gradcam.py +167 -0
- scripts/load_models.py +91 -0
- scripts/prediction_helper.py +313 -0
- scripts/yolo_predict.py +63 -0
- src/config.py +60 -0
- src/data/augmentation.py +90 -0
- src/data/dataset.py +189 -0
- src/data/ingestion.py +55 -0
- src/data/preprocessing.py +58 -0
- src/export/conver_model.py +68 -0
- src/export/upload_to_huggingface.py +90 -0
- src/models/fusion_model.py +112 -0
- src/models/resnet_model.py +64 -0
- src/training/train_fusion.py +92 -0
- src/training/train_resnet.py +68 -0
- src/training/train_yolo.py +85 -0
- src/training/trainer.py +305 -0
- test/test_augmentation.py +40 -0
- test/test_config.py +37 -0
- test/test_dataset.py +57 -0
- test/test_fusion_model.py +42 -0
- test/test_ingestion.py +32 -0
- test/test_model_conversion.py +39 -0
- test/test_preprocessing.py +37 -0
- test/test_resnet_model.py +38 -0
- test/test_train_fusion.py +39 -0
- test/test_train_resnet.py +39 -0
- test/test_train_yolo.py +36 -0
- test/test_upload_to_huggingface.py +53 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/yolo_detection_sample.jpg filter=lfs diff=lfs merge=lfs -text
|
Dockerfile
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.10-slim
|
| 2 |
+
|
| 3 |
+
ENV PYTHONDONTWRITEBYTECODE=1
|
| 4 |
+
ENV PYTHONUNBUFFERED=1
|
| 5 |
+
|
| 6 |
+
WORKDIR /app
|
| 7 |
+
|
| 8 |
+
# --- SYSTEM DEPENDENCIES (CRITICAL FOR OPENCV / YOLO) ---
|
| 9 |
+
RUN apt-get update && apt-get install -y \
|
| 10 |
+
build-essential \
|
| 11 |
+
gcc \
|
| 12 |
+
libgl1 \
|
| 13 |
+
libglib2.0-0 \
|
| 14 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 15 |
+
|
| 16 |
+
# --- PYTHON DEPENDENCIES ---
|
| 17 |
+
COPY requirements.txt .
|
| 18 |
+
RUN pip install --no-cache-dir --upgrade pip \
|
| 19 |
+
&& pip install --no-cache-dir -r requirements.txt
|
| 20 |
+
|
| 21 |
+
# --- APP CODE ---
|
| 22 |
+
COPY . .
|
| 23 |
+
|
| 24 |
+
EXPOSE 7860
|
| 25 |
+
|
| 26 |
+
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
|
Notebooks/EfficientNet_ConvNext_Fusion.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Notebooks/Resnet18_fine_tuning_final.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Notebooks/damage_detector_yolo.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
README.md
CHANGED
|
@@ -1,10 +1,702 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🚗 DamageLens: AI-Powered Car Damage Detection
|
| 2 |
+
|
| 3 |
+
[](https://python.org)
|
| 4 |
+
[](https://pytorch.org)
|
| 5 |
+
[](https://fastapi.tiangolo.com)
|
| 6 |
+
[](https://github.com/junaidariie/DamageLensAI/actions/workflows/ci.yaml)
|
| 7 |
+
[](LICENSE)
|
| 8 |
+
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
## ⚠️ Important Notes
|
| 12 |
+
|
| 13 |
+
> **Cold Startup Time**: The API may take **4-5 minutes** on the first request to warm up the models. Subsequent predictions will be significantly faster.
|
| 14 |
+
|
| 15 |
+
> **Model Size**: The Fusion model is computationally intensive. Individual predictions typically complete in 30-60 seconds depending on hardware.
|
| 16 |
+
|
| 17 |
+
---
|
| 18 |
+
|
| 19 |
+
**APP LINK** : https://junaidariie.github.io/DamageLensAI/
|
| 20 |
+
|
| 21 |
+
**HF REPO** : https://huggingface.co/spaces/junaid17/DamageLensAI/tree/main
|
| 22 |
+
|
| 23 |
+
---
|
| 24 |
+
|
| 25 |
+
## 📋 Table of Contents
|
| 26 |
+
|
| 27 |
+
- [Overview](#-overview)
|
| 28 |
+
- [Features](#-features)
|
| 29 |
+
- [Architecture](#-architecture)
|
| 30 |
+
- [Model Performance](#-model-performance)
|
| 31 |
+
- [CI Pipeline](#-ci-pipeline)
|
| 32 |
+
- [Setup & Installation](#-setup--installation)
|
| 33 |
+
- [Usage](#-usage)
|
| 34 |
+
- [API Documentation](#-api-documentation)
|
| 35 |
+
- [Model Optimization](#-model-optimization)
|
| 36 |
+
- [Dataset & Training](#-dataset--training)
|
| 37 |
+
- [Web UI Features](#-web-ui-features)
|
| 38 |
+
- [Directory Structure](#-directory-structure)
|
| 39 |
+
- [Limitations & Known Issues](#-limitations--known-issues)
|
| 40 |
+
|
| 41 |
+
---
|
| 42 |
+
|
| 43 |
+
## 🎯 Overview
|
| 44 |
+
|
| 45 |
+
**DamageLens** is an advanced AI system for detecting and classifying car damage using multi-model fusion architecture. It combines the power of **ResNet-18**, **EfficientNet-V2-S**, and **ConvNeXt-Small** to achieve robust damage classification across vehicle front and rear sections.
|
| 46 |
+
|
| 47 |
+
The system can identify six damage categories:
|
| 48 |
+
- ✅ Front Normal / Front Breakage / Front Crushed
|
| 49 |
+
- ✅ Rear Normal / Rear Breakage / Rear Crushed
|
| 50 |
+
|
| 51 |
+
Additionally, it uses **YOLO object detection** to localize damage regions with bounding boxes.
|
| 52 |
+
|
| 53 |
+
---
|
| 54 |
+
|
| 55 |
+
## ✨ Features
|
| 56 |
+
|
| 57 |
+
| Feature | Description |
|
| 58 |
+
|---------|-------------|
|
| 59 |
+
| **Dual Model Architecture** | ResNet (lightweight) and Fusion (high-accuracy) options |
|
| 60 |
+
| **Grad-CAM Visualization** | Understand which image regions drive predictions |
|
| 61 |
+
| **Real-time YOLO Detection** | Localize damage with confidence scores |
|
| 62 |
+
| **FP16 Optimization** | Reduced model size (788MB → 135MB) with minimal accuracy loss |
|
| 63 |
+
| **FastAPI Backend** | High-performance REST API with async support |
|
| 64 |
+
| **Responsive Web UI** | Modern, interactive web interface with real-time feedback |
|
| 65 |
+
| **Static File Serving** | Efficient caching and delivery of results |
|
| 66 |
+
| **CI/CD Pipeline** | Automated testing via GitHub Actions on every push/PR |
|
| 67 |
+
| **HuggingFace Integration** | Models auto-downloaded from HF Hub on first startup |
|
| 68 |
+
|
| 69 |
+
---
|
| 70 |
+
|
| 71 |
+
## 🏗️ Architecture
|
| 72 |
+
|
| 73 |
+
### System Overview
|
| 74 |
+
|
| 75 |
+
```
|
| 76 |
+
┌──────────────────────────────────────────────────────┐
|
| 77 |
+
│ Frontend (Web UI) │
|
| 78 |
+
│ HTML / CSS / JavaScript (Dark Mode, Glassmorphism) │
|
| 79 |
+
│ ├─ Drag & Drop Image Upload │
|
| 80 |
+
│ ├─ Model Selection (Fusion / ResNet) │
|
| 81 |
+
│ └─ Real-time Result Tabs (Prediction/GradCAM/YOLO) │
|
| 82 |
+
└───────────────────┬──────────────────────────────────┘
|
| 83 |
+
│ REST API (JSON)
|
| 84 |
+
┌───────────────────▼──────────────────────────────────┐
|
| 85 |
+
│ FastAPI Backend (app.py) │
|
| 86 |
+
│ ├─ POST /predict/resnet → ResNet inference │
|
| 87 |
+
│ ├─ POST /predict/fusion → Fusion inference │
|
| 88 |
+
│ ├─ POST /predict?mode=* → Grad-CAM generation │
|
| 89 |
+
│ └─ POST /predict/yolo → YOLO detection │
|
| 90 |
+
│ │
|
| 91 |
+
│ Lifespan: models loaded once at startup │
|
| 92 |
+
│ Static: /static/uploads /static/results │
|
| 93 |
+
└──────┬───────────┬──────────────┬────────────────────┘
|
| 94 |
+
│ │ │
|
| 95 |
+
┌──────▼──┐ ┌─────▼──────┐ ┌───▼──────────┐
|
| 96 |
+
│ ResNet │ │ Fusion │ │ YOLO v11m │
|
| 97 |
+
│ (77%) │ │ (84%) │ │ Detection │
|
| 98 |
+
└──────┬──┘ └─────┬──────┘ └───┬──────────┘
|
| 99 |
+
│ │ │
|
| 100 |
+
└─────┬─────┘ │
|
| 101 |
+
│ │
|
| 102 |
+
┌───────▼──────┐ ┌────────▼────────┐
|
| 103 |
+
│ Grad-CAM │ │ Bounding Boxes │
|
| 104 |
+
│ Heatmaps │ │ + Confidence │
|
| 105 |
+
└──────────────┘ └─────────────────┘
|
| 106 |
+
```
|
| 107 |
+
|
| 108 |
+
### Model Loading (scripts/load_models.py)
|
| 109 |
+
|
| 110 |
+
```
|
| 111 |
+
Startup
|
| 112 |
+
│
|
| 113 |
+
├─ hf_hub_download("junaid17/car-damage-classifier")
|
| 114 |
+
│ └─> ResnetCarDamagePredictor(checkpoint, class_map)
|
| 115 |
+
│
|
| 116 |
+
├─ hf_hub_download("junaid17/best_fusion_model_fp16")
|
| 117 |
+
│ └─> FusionCarDamagePredictor(checkpoint, class_map)
|
| 118 |
+
│
|
| 119 |
+
└─ hf_hub_download("junaid17/Yolo_Model")
|
| 120 |
+
└─> YOLO(checkpoint)
|
| 121 |
+
```
|
| 122 |
+
|
| 123 |
+
### Fusion Model (High Accuracy — 84%)
|
| 124 |
+
|
| 125 |
+
```
|
| 126 |
+
┌─────────────────────────────────────────────────────────────────┐
|
| 127 |
+
│ INPUT IMAGE │
|
| 128 |
+
│ (3, 260, 260) │
|
| 129 |
+
└────────────────┬────────────────────────────────┬──────────────┘
|
| 130 |
+
│ │
|
| 131 |
+
┌───────▼────────┐ ┌─────────▼────────┐
|
| 132 |
+
│ EfficientNet- │ │ ConvNeXt-Small │
|
| 133 |
+
│ V2-S Backbone │ │ Backbone │
|
| 134 |
+
│ │ │ │
|
| 135 |
+
│ Frozen except │ │ Frozen except │
|
| 136 |
+
│ features[5,6,7]│ │ stages[2,3] + │
|
| 137 |
+
│ (unfrozen) │ │ layernorm │
|
| 138 |
+
└───────┬────────┘ └─────────┬────────┘
|
| 139 |
+
│ │
|
| 140 |
+
┌───────▼────────┐ ┌─────────▼────────┐
|
| 141 |
+
│ AdaptiveAvg │ │ Pooler Output │
|
| 142 |
+
│ Pool → Flatten │ │ │
|
| 143 |
+
└───────┬────────┘ └─────────┬────────┘
|
| 144 |
+
│ (1280,) │ (768,)
|
| 145 |
+
└──────────────┬─────────────────┘
|
| 146 |
+
│
|
| 147 |
+
┌───────▼────────┐
|
| 148 |
+
│ CONCATENATE │
|
| 149 |
+
│ 1280 + 768 │
|
| 150 |
+
│ = (2048,) │
|
| 151 |
+
└───────┬────────┘
|
| 152 |
+
│
|
| 153 |
+
┌───────────▼───────────┐
|
| 154 |
+
│ FUSION HEAD │
|
| 155 |
+
│ Dropout(0.4) │
|
| 156 |
+
│ Linear(2048 → 512) │
|
| 157 |
+
│ LayerNorm(512) │
|
| 158 |
+
│ GELU() │
|
| 159 |
+
│ Dropout(0.3) │
|
| 160 |
+
│ Linear(512 → 256) │
|
| 161 |
+
│ LayerNorm(256) │
|
| 162 |
+
│ GELU() │
|
| 163 |
+
│ Dropout(0.2) │
|
| 164 |
+
│ Linear(256 → 6) │
|
| 165 |
+
└───────────┬───────────┘
|
| 166 |
+
│
|
| 167 |
+
┌───────▼────────┐
|
| 168 |
+
│ OUTPUT LOGITS │
|
| 169 |
+
│ (6 classes) │
|
| 170 |
+
└────────────────┘
|
| 171 |
+
```
|
| 172 |
+
|
| 173 |
+
**Optimizer**: AdamW with per-group learning rates
|
| 174 |
+
- EfficientNet features[5]: lr=1e-5
|
| 175 |
+
- EfficientNet features[6,7]: lr=3e-5
|
| 176 |
+
- ConvNeXt stages[2,3] + layernorm: lr=3e-5
|
| 177 |
+
- Fusion head: lr=1e-4
|
| 178 |
+
- Loss: CrossEntropyLoss with label_smoothing=0.1
|
| 179 |
+
- Early stopping patience: 7
|
| 180 |
+
|
| 181 |
+
### ResNet-18 (Lightweight — 77%)
|
| 182 |
+
|
| 183 |
+
```
|
| 184 |
+
┌──────────────────────────────────┐
|
| 185 |
+
│ INPUT IMAGE │
|
| 186 |
+
│ (3, 128, 128) │
|
| 187 |
+
└───────────────┬──────────────────┘
|
| 188 |
+
│
|
| 189 |
+
┌───────▼─────────┐
|
| 190 |
+
│ ResNet-18 │
|
| 191 |
+
│ Backbone │
|
| 192 |
+
│ │
|
| 193 |
+
│ Frozen except │
|
| 194 |
+
│ layer3, layer4 │
|
| 195 |
+
└───────┬─────────┘
|
| 196 |
+
│ (512,)
|
| 197 |
+
┌───────▼─────────────────────┐
|
| 198 |
+
│ Classification Head │
|
| 199 |
+
│ Dropout(0.5) │
|
| 200 |
+
│ Linear(512 → 256) │
|
| 201 |
+
│ ReLU() │
|
| 202 |
+
│ Dropout(0.3) │
|
| 203 |
+
│ Linear(256 → 6 classes) │
|
| 204 |
+
└───────┬─────────────────────┘
|
| 205 |
+
│
|
| 206 |
+
┌───────▼──────────┐
|
| 207 |
+
│ OUTPUT LOGITS │
|
| 208 |
+
│ (6 classes) │
|
| 209 |
+
└──────────────────┘
|
| 210 |
+
```
|
| 211 |
+
|
| 212 |
+
**Optimizer**: AdamW with per-group learning rates
|
| 213 |
+
- layer3: lr=1e-5
|
| 214 |
+
- layer4: lr=1e-5
|
| 215 |
+
- fc head: lr=1e-4
|
| 216 |
+
- Loss: CrossEntropyLoss
|
| 217 |
+
- Early stopping patience: 7
|
| 218 |
+
|
| 219 |
+
### YOLO v11m Integration
|
| 220 |
+
|
| 221 |
+
```
|
| 222 |
+
┌─────────────────────────────┐
|
| 223 |
+
│ INPUT IMAGE │
|
| 224 |
+
│ imgsz=640, conf=0.05 │
|
| 225 |
+
└──────────────┬──────────────┘
|
| 226 |
+
│
|
| 227 |
+
┌───────▼────────┐
|
| 228 |
+
│ YOLO v11m │
|
| 229 |
+
│ Inference │
|
| 230 |
+
└───────┬────────┘
|
| 231 |
+
│
|
| 232 |
+
┌──────────┴──────────┐
|
| 233 |
+
│ │
|
| 234 |
+
┌───▼───────┐ ┌──────▼──────┐
|
| 235 |
+
│ Bboxes │ │ Confidence │
|
| 236 |
+
│ (x1,y1, │ │ Scores + │
|
| 237 |
+
│ x2,y2) │ │ Class Label │
|
| 238 |
+
└───┬───────┘ └──────┬──────┘
|
| 239 |
+
└──────────┬──────────┘
|
| 240 |
+
│
|
| 241 |
+
┌───────▼────────┐
|
| 242 |
+
│ result.plot() │
|
| 243 |
+
│ Save to disk │
|
| 244 |
+
└────────────────┘
|
| 245 |
+
```
|
| 246 |
+
|
| 247 |
+
### Grad-CAM Pipeline (scripts/gradcam.py)
|
| 248 |
+
|
| 249 |
+
```
|
| 250 |
+
Image Path
|
| 251 |
+
│
|
| 252 |
+
├─ ResNet mode: target_layer = model.layer4[-1]
|
| 253 |
+
└─ Fusion mode: target_layer = model.eff_features[-1]
|
| 254 |
+
(FP16 → FP32 cast on CPU automatically)
|
| 255 |
+
│
|
| 256 |
+
├─ Register forward hook (_GradCAMHook)
|
| 257 |
+
├─ Forward pass → score.backward()
|
| 258 |
+
├─ acts [C,H,W] × weights (mean of grads) → CAM [H,W]
|
| 259 |
+
├─ ReLU → normalize → resize to original dims
|
| 260 |
+
└─ cv2.applyColorMap(COLORMAP_JET) → addWeighted overlay
|
| 261 |
+
```
|
| 262 |
+
|
| 263 |
+
### Data Pipeline (src/data/)
|
| 264 |
+
|
| 265 |
+
```
|
| 266 |
+
Raw Images (data/dataset/)
|
| 267 |
+
│
|
| 268 |
+
├─ ingestion.py → scan folders, build file list
|
| 269 |
+
├─ preprocessing.py → validate / clean images
|
| 270 |
+
├─ augmentation.py → train/val transforms
|
| 271 |
+
│ ResNet: Resize(128,128) + HFlip + Rotation(15°) + ColorJitter
|
| 272 |
+
│ Fusion: Resize(260,260) + HFlip + Rotation(10°) + ColorJitter
|
| 273 |
+
└─ dataset.py → ImageFolder DataLoaders
|
| 274 |
+
(train 80% / val 20%, seed=42)
|
| 275 |
+
```
|
| 276 |
+
|
| 277 |
+
### Export & Deployment (src/export/)
|
| 278 |
+
|
| 279 |
+
```
|
| 280 |
+
Trained Checkpoints (checkpoints/)
|
| 281 |
+
│
|
| 282 |
+
├─ conver_model.py → FP32 → FP16 conversion
|
| 283 |
+
│ 788MB → 135MB (82.9% reduction)
|
| 284 |
+
└─ upload_to_huggingface.py → HfApi upload to:
|
| 285 |
+
junaid17/new-damagelens-resnet-classifier
|
| 286 |
+
junaid17/new-damagelens-fusion-fp16
|
| 287 |
+
junaid17/new-damagelens-yolo-detector
|
| 288 |
+
```
|
| 289 |
+
|
| 290 |
+
---
|
| 291 |
+
|
| 292 |
+
## 📊 Model Performance
|
| 293 |
+
|
| 294 |
+
### Fusion Model (High Accuracy — 84% Overall)
|
| 295 |
+
|
| 296 |
+
**Classification Report:**
|
| 297 |
+
|
| 298 |
+

|
| 299 |
+
|
| 300 |
+
**Confusion Matrix:**
|
| 301 |
+
|
| 302 |
+

|
| 303 |
+
|
| 304 |
+
**Training Curves:**
|
| 305 |
+
|
| 306 |
+

|
| 307 |
+
|
| 308 |
+
---
|
| 309 |
+
|
| 310 |
+
### ResNet-18 (Lightweight — 77% Overall)
|
| 311 |
+
|
| 312 |
+
**Classification Report:**
|
| 313 |
+
|
| 314 |
+

|
| 315 |
+
|
| 316 |
+
**Confusion Matrix:**
|
| 317 |
+
|
| 318 |
+

|
| 319 |
+
|
| 320 |
+
**Training Curves:**
|
| 321 |
+
|
| 322 |
+

|
| 323 |
+
|
| 324 |
+
---
|
| 325 |
+
|
| 326 |
+
### YOLO Detection Results
|
| 327 |
+
|
| 328 |
+

|
| 329 |
+
|
| 330 |
+
---
|
| 331 |
+
|
| 332 |
+
## 🔁 CI Pipeline
|
| 333 |
+
|
| 334 |
+
DamageLens uses **GitHub Actions** for continuous integration. Every push or pull request to `main`, `master`, or `dev` triggers the full test suite automatically.
|
| 335 |
+
|
| 336 |
+
**CI Screenshot (GitHub Actions — All Tests Passing):**
|
| 337 |
+
|
| 338 |
+

|
| 339 |
+
|
| 340 |
+
### What the pipeline tests:
|
| 341 |
+
|
| 342 |
+
| Step | Test File | What it covers |
|
| 343 |
+
|------|-----------|----------------|
|
| 344 |
+
| Config | `test_config.py` | Paths, constants, class map |
|
| 345 |
+
| Ingestion | `test_ingestion.py` | Dataset folder scanning |
|
| 346 |
+
| Preprocessing | `test_preprocessing.py` | Image validation & cleaning |
|
| 347 |
+
| Augmentation | `test_augmentation.py` | Transform pipelines |
|
| 348 |
+
| Dataset | `test_dataset.py` | DataLoader creation |
|
| 349 |
+
| ResNet Architecture | `test_resnet_model.py` | Model init & forward pass |
|
| 350 |
+
| ResNet Training | `test_train_resnet.py` | Smoke test training loop |
|
| 351 |
+
|
| 352 |
+
### Pipeline config (`.github/workflows/ci.yaml`):
|
| 353 |
+
- Runs on: `ubuntu-latest`
|
| 354 |
+
- Python: `3.10`
|
| 355 |
+
- Triggers: push & PR to `main` / `master` / `dev`
|
| 356 |
+
|
| 357 |
---
|
| 358 |
+
|
| 359 |
+
## 🚀 Setup & Installation
|
| 360 |
+
|
| 361 |
+
### Prerequisites
|
| 362 |
+
|
| 363 |
+
- Python 3.11+
|
| 364 |
+
- CUDA 11.8+ (for GPU acceleration, optional but recommended)
|
| 365 |
+
- 8GB+ RAM (16GB recommended for Fusion model)
|
| 366 |
+
|
| 367 |
+
### Installation Steps
|
| 368 |
+
|
| 369 |
+
```bash
|
| 370 |
+
# Clone the repository
|
| 371 |
+
git clone https://github.com/junaid17/damagelens.git
|
| 372 |
+
cd DamageLens
|
| 373 |
+
|
| 374 |
+
# Create virtual environment
|
| 375 |
+
python -m venv myvenv
|
| 376 |
+
source myvenv/bin/activate # On Windows: myvenv\Scripts\activate
|
| 377 |
+
|
| 378 |
+
# Install dependencies
|
| 379 |
+
pip install -r requirements.txt
|
| 380 |
+
|
| 381 |
+
# Create required directories
|
| 382 |
+
mkdir -p static/uploads static/results checkpoints assets
|
| 383 |
+
```
|
| 384 |
+
|
| 385 |
+
### Download Pre-trained Models
|
| 386 |
+
|
| 387 |
+
Models are automatically downloaded from Hugging Face on first use:
|
| 388 |
+
- `car-damage-classifier.pt` — ResNet-18 checkpoint
|
| 389 |
+
- `best_fusion_model_fp16.pt` — Fusion model (FP16 optimized, 135MB)
|
| 390 |
+
- `damage_detector.pt` — YOLO v11m model
|
| 391 |
+
|
| 392 |
---
|
| 393 |
|
| 394 |
+
## 💻 Usage
|
| 395 |
+
|
| 396 |
+
### Running the FastAPI Server
|
| 397 |
+
|
| 398 |
+
```bash
|
| 399 |
+
uvicorn app:app --reload --host 127.0.0.1 --port 8000
|
| 400 |
+
```
|
| 401 |
+
|
| 402 |
+
Open your browser at `http://127.0.0.1:8000`
|
| 403 |
+
|
| 404 |
+
#### Quick Start:
|
| 405 |
+
1. Upload a car image (JPG/PNG)
|
| 406 |
+
2. Select analysis mode: **Fusion** (accurate) or **ResNet** (fast)
|
| 407 |
+
3. Click "Run AI Analysis"
|
| 408 |
+
4. View results in tabs:
|
| 409 |
+
- 📊 **Prediction**: Confidence scores and probabilities
|
| 410 |
+
- 👀 **Grad-CAM**: Visualize which regions influenced the prediction
|
| 411 |
+
- 🎯 **YOLO**: Damage bounding boxes with confidence
|
| 412 |
+
|
| 413 |
+
### Python API Example
|
| 414 |
+
|
| 415 |
+
```python
|
| 416 |
+
import requests
|
| 417 |
+
|
| 418 |
+
with open('car_image.jpg', 'rb') as f:
|
| 419 |
+
files = {'image': f}
|
| 420 |
+
resp = requests.post('http://127.0.0.1:8000/predict/resnet', files=files)
|
| 421 |
+
print(resp.json())
|
| 422 |
+
|
| 423 |
+
with open('car_image.jpg', 'rb') as f:
|
| 424 |
+
files = {'image': f}
|
| 425 |
+
resp = requests.post('http://127.0.0.1:8000/predict/fusion', files=files)
|
| 426 |
+
print(resp.json())
|
| 427 |
+
```
|
| 428 |
+
|
| 429 |
+
---
|
| 430 |
+
|
| 431 |
+
## 📡 API Documentation
|
| 432 |
+
|
| 433 |
+
### `POST /predict/resnet`
|
| 434 |
+
```
|
| 435 |
+
Content-Type: multipart/form-data
|
| 436 |
+
Body: image (File)
|
| 437 |
+
|
| 438 |
+
Response:
|
| 439 |
+
{
|
| 440 |
+
"status": "success",
|
| 441 |
+
"prediction": {
|
| 442 |
+
"Rear Normal": 0.47,
|
| 443 |
+
"Front Normal": 0.25,
|
| 444 |
+
...
|
| 445 |
+
}
|
| 446 |
+
}
|
| 447 |
+
```
|
| 448 |
+
|
| 449 |
+
### `POST /predict/fusion`
|
| 450 |
+
```
|
| 451 |
+
Content-Type: multipart/form-data
|
| 452 |
+
Body: image (File)
|
| 453 |
+
|
| 454 |
+
Response:
|
| 455 |
+
{
|
| 456 |
+
"status": "success",
|
| 457 |
+
"prediction": {
|
| 458 |
+
"Rear Normal": 0.49,
|
| 459 |
+
"Front Normal": 0.35,
|
| 460 |
+
...
|
| 461 |
+
}
|
| 462 |
+
}
|
| 463 |
+
```
|
| 464 |
+
|
| 465 |
+
### `POST /predict?mode={resnet|fusion}` — Grad-CAM
|
| 466 |
+
```
|
| 467 |
+
Content-Type: multipart/form-data
|
| 468 |
+
Body: file (File), mode (String)
|
| 469 |
+
|
| 470 |
+
Response:
|
| 471 |
+
{
|
| 472 |
+
"status": "success",
|
| 473 |
+
"mode": "fusion",
|
| 474 |
+
"original_image": "/static/uploads/{uuid}_input.jpg",
|
| 475 |
+
"selected_viz": "/static/results/{uuid}_fusion.jpg",
|
| 476 |
+
"resnet_viz": null,
|
| 477 |
+
"fusion_viz": "/static/results/{uuid}_fusion.jpg"
|
| 478 |
+
}
|
| 479 |
+
```
|
| 480 |
+
|
| 481 |
+
### `POST /predict/yolo`
|
| 482 |
+
```
|
| 483 |
+
Content-Type: multipart/form-data
|
| 484 |
+
Body: file (File)
|
| 485 |
+
|
| 486 |
+
Response:
|
| 487 |
+
{
|
| 488 |
+
"status": "success",
|
| 489 |
+
"original_image": "/static/uploads/{uuid}_input.jpg",
|
| 490 |
+
"yolo_image": "/static/results/{uuid}_yolo.jpg",
|
| 491 |
+
"detections": [
|
| 492 |
+
{ "label": "damage", "confidence": 0.87, "box": [x1, y1, x2, y2] }
|
| 493 |
+
],
|
| 494 |
+
"total_detections": 2,
|
| 495 |
+
"message": "Detections found"
|
| 496 |
+
}
|
| 497 |
+
```
|
| 498 |
+
|
| 499 |
+
---
|
| 500 |
+
|
| 501 |
+
## 🔧 Model Optimization
|
| 502 |
+
|
| 503 |
+
### FP16 Conversion (Fusion Model)
|
| 504 |
+
|
| 505 |
+
```
|
| 506 |
+
Original Model (FP32): 788 MB
|
| 507 |
+
Optimized Model (FP16): 135 MB
|
| 508 |
+
───────────────────────────────────
|
| 509 |
+
Compression Ratio: 82.9% reduction ✅
|
| 510 |
+
Accuracy Loss: < 1% ⚠️
|
| 511 |
+
Speed Improvement: ~1.3x faster ⚡
|
| 512 |
+
```
|
| 513 |
+
|
| 514 |
+
The system auto-detects FP16 checkpoints at load time:
|
| 515 |
+
|
| 516 |
+
```python
|
| 517 |
+
if first_tensor.dtype == torch.float16:
|
| 518 |
+
model = model.half()
|
| 519 |
+
|
| 520 |
+
# Grad-CAM on CPU: FP16 → FP32 cast applied automatically
|
| 521 |
+
if is_half:
|
| 522 |
+
model = model.float()
|
| 523 |
+
```
|
| 524 |
+
|
| 525 |
+
---
|
| 526 |
+
|
| 527 |
+
## 📚 Dataset & Training
|
| 528 |
+
|
| 529 |
+
### Data Constraints
|
| 530 |
+
|
| 531 |
+
- **Total Samples**: ~1,800 images
|
| 532 |
+
- **Train/Val Split**: 80/20 (seed=42)
|
| 533 |
+
- **Classes**: 6 (F_Breakage, F_Crushed, F_Normal, R_Breakage, R_Crushed, R_Normal)
|
| 534 |
+
- **YOLO subset**: ~100 annotated images (train/val split)
|
| 535 |
+
|
| 536 |
+
### Data Augmentation
|
| 537 |
+
|
| 538 |
+
| Transform | ResNet | Fusion |
|
| 539 |
+
|-----------|--------|--------|
|
| 540 |
+
| Resize | 128×128 | 260×260 |
|
| 541 |
+
| RandomHorizontalFlip | ✅ | ✅ |
|
| 542 |
+
| RandomRotation | ±15° | ±10° |
|
| 543 |
+
| ColorJitter (b/c/s) | ±20% | ±15% |
|
| 544 |
+
| ImageNet Normalize | ✅ | ✅ |
|
| 545 |
+
|
| 546 |
+
### Training Configuration
|
| 547 |
+
|
| 548 |
+
| Setting | ResNet | Fusion |
|
| 549 |
+
|---------|--------|--------|
|
| 550 |
+
| Backbone | ResNet-18 | EfficientNet-V2-S + ConvNeXt-Small |
|
| 551 |
+
| Frozen layers | All except layer3, layer4 | All except features[5,6,7] / stages[2,3] |
|
| 552 |
+
| Optimizer | AdamW | AdamW (per-group LR) |
|
| 553 |
+
| Loss | CrossEntropyLoss | CrossEntropyLoss (label_smoothing=0.1) |
|
| 554 |
+
| Early stopping | patience=7 | patience=7 |
|
| 555 |
+
| Input size | 128×128 | 260×260 (EfficientNet) / 224×224 (ConvNeXt) |
|
| 556 |
+
|
| 557 |
+
---
|
| 558 |
+
|
| 559 |
+
## 🎨 Web UI Features
|
| 560 |
+
|
| 561 |
+
- Dark mode glassmorphism design
|
| 562 |
+
- Drag & drop image upload
|
| 563 |
+
- Model selection dropdown (Fusion / ResNet)
|
| 564 |
+
- Real-time confidence bar animation
|
| 565 |
+
- Tab navigation: Prediction → Grad-CAM → YOLO
|
| 566 |
+
- Scan line effect during processing
|
| 567 |
+
- Plotly bar chart for class probabilities
|
| 568 |
+
- Side-by-side original vs heatmap comparison
|
| 569 |
+
|
| 570 |
+
---
|
| 571 |
+
|
| 572 |
+
## 🔍 Grad-CAM Visualization
|
| 573 |
+
|
| 574 |
+
Gradient-weighted Class Activation Mapping highlights which image regions most influenced the model's prediction.
|
| 575 |
+
|
| 576 |
+
```
|
| 577 |
+
Original Image + Grad-CAM Heatmap = Overlay
|
| 578 |
+
Red = High importance
|
| 579 |
+
Blue = Low importance
|
| 580 |
+
```
|
| 581 |
+
|
| 582 |
+
- ResNet: hooks into `layer4[-1]`
|
| 583 |
+
- Fusion: hooks into `eff_features[-1]` (EfficientNet's last block)
|
| 584 |
+
|
| 585 |
+
---
|
| 586 |
+
|
| 587 |
+
## 📋 Directory Structure
|
| 588 |
+
|
| 589 |
+
```
|
| 590 |
+
DamageLens/
|
| 591 |
+
├── app.py # FastAPI app + all endpoints
|
| 592 |
+
├── index.html # Web UI
|
| 593 |
+
├── requirements.txt
|
| 594 |
+
├── README.md
|
| 595 |
+
│
|
| 596 |
+
├── .github/
|
| 597 |
+
│ └── workflows/
|
| 598 |
+
│ └── ci.yaml # GitHub Actions CI pipeline
|
| 599 |
+
│
|
| 600 |
+
├── assets/ # ← Place README images here
|
| 601 |
+
│ ├── fusion_classification_report.png
|
| 602 |
+
│ ├── fusion_confusion_matrix.png
|
| 603 |
+
│ ├── fusion_training_curves.png
|
| 604 |
+
│ ├── resnet_classification_report.png
|
| 605 |
+
│ ├── resnet_confusion_matrix.png
|
| 606 |
+
│ ├── resnet_training_curves.png
|
| 607 |
+
│ ├── yolo_detection_sample.png
|
| 608 |
+
│ └── ci_pipeline_passing.png
|
| 609 |
+
│
|
| 610 |
+
├── scripts/
|
| 611 |
+
│ ├── prediction_helper.py # ResNet + Fusion model classes & inference
|
| 612 |
+
│ ├── gradcam.py # Grad-CAM (ResNet + Fusion, CPU-optimized)
|
| 613 |
+
│ ├── load_models.py # HF Hub download + model initialization
|
| 614 |
+
│ └── yolo_predict.py # YOLO inference + bbox drawing
|
| 615 |
+
│
|
| 616 |
+
├── src/
|
| 617 |
+
│ ├── config.py # Paths, hyperparams, class map
|
| 618 |
+
│ ├── data/
|
| 619 |
+
│ │ ├── ingestion.py # Dataset folder scanning
|
| 620 |
+
│ │ ├── preprocessing.py # Image validation
|
| 621 |
+
│ │ ├── augmentation.py # Train/val transforms
|
| 622 |
+
│ │ └── dataset.py # DataLoader creation
|
| 623 |
+
│ ├── models/
|
| 624 |
+
│ │ ├── resnet_model.py # CarClassifierResNet
|
| 625 |
+
│ │ └── fusion_model.py # FusionClassifier
|
| 626 |
+
│ ├── training/
|
| 627 |
+
│ │ ├── trainer.py # Generic train loop (single + dual input)
|
| 628 |
+
│ │ ├── train_resnet.py # ResNet training entry point
|
| 629 |
+
│ │ ├── train_fusion.py # Fusion training entry point
|
| 630 |
+
│ │ └── train_yolo.py # YOLO fine-tuning
|
| 631 |
+
│ └── export/
|
| 632 |
+
│ ├── conver_model.py # FP32 → FP16 conversion
|
| 633 |
+
│ └── upload_to_huggingface.py # HF Hub upload script
|
| 634 |
+
│
|
| 635 |
+
├── checkpoints/
|
| 636 |
+
│ ├── best_resnet_model.pt
|
| 637 |
+
│ ├── best_fusion_model_fp16.pt
|
| 638 |
+
│ ├── damage_detector.pt
|
| 639 |
+
│ └── yolo11m.pt
|
| 640 |
+
│
|
| 641 |
+
├── Notebooks/
|
| 642 |
+
│ ├── Resnet18_fine_tuning_final.ipynb
|
| 643 |
+
│ ├── EfficientNet_ConvNext_Fusion.ipynb
|
| 644 |
+
│ └── damage_detector_yolo.ipynb
|
| 645 |
+
│
|
| 646 |
+
├── test/
|
| 647 |
+
│ ├── test_config.py
|
| 648 |
+
│ ├── test_ingestion.py
|
| 649 |
+
│ ├── test_preprocessing.py
|
| 650 |
+
│ ├── test_augmentation.py
|
| 651 |
+
│ ├── test_dataset.py
|
| 652 |
+
│ ├── test_resnet_model.py
|
| 653 |
+
│ ├── test_fusion_model.py
|
| 654 |
+
│ ├── test_train_resnet.py
|
| 655 |
+
│ ├── test_train_fusion.py
|
| 656 |
+
│ ├── test_train_yolo.py
|
| 657 |
+
│ ├── test_model_conversion.py
|
| 658 |
+
│ └── test_upload_to_huggingface.py
|
| 659 |
+
│
|
| 660 |
+
├── data/
|
| 661 |
+
│ ├── dataset/ # 6-class image folders
|
| 662 |
+
│ │ ├── F_Breakage/
|
| 663 |
+
│ │ ├── F_Crushed/
|
| 664 |
+
│ │ ├── F_Normal/
|
| 665 |
+
│ │ ├── R_Breakage/
|
| 666 |
+
│ │ ├── R_Crushed/
|
| 667 |
+
│ │ └── R_Normal/
|
| 668 |
+
│ └── yolo/ # YOLO annotated subset
|
| 669 |
+
│ ├── train/images + labels/
|
| 670 |
+
│ ├── val/images + labels/
|
| 671 |
+
│ └── dataset_custom.yaml
|
| 672 |
+
│
|
| 673 |
+
└── static/
|
| 674 |
+
├── uploads/ # Temp uploaded images
|
| 675 |
+
└── results/ # Generated Grad-CAM / YOLO outputs
|
| 676 |
+
```
|
| 677 |
+
|
| 678 |
+
---
|
| 679 |
+
|
| 680 |
+
## ⚠️ Limitations & Known Issues
|
| 681 |
+
|
| 682 |
+
### Data Constraints
|
| 683 |
+
- **Limited Training Data**: ~1,800 samples — may show variance on edge cases
|
| 684 |
+
- **Class Imbalance**: Rear Crushed class has fewer samples, affecting recall
|
| 685 |
+
|
| 686 |
+
### Performance
|
| 687 |
+
|
| 688 |
+
| Metric | Value | Note |
|
| 689 |
+
|--------|-------|------|
|
| 690 |
+
| ResNet Inference | ~500ms | Fast, lower accuracy |
|
| 691 |
+
| Fusion Inference | 30-60s | Accurate, computationally heavy |
|
| 692 |
+
| Cold Startup | 4-5 min | HF Hub download + model warmup |
|
| 693 |
+
| GPU Memory | ~4GB | For Fusion model |
|
| 694 |
+
| ResNet Accuracy | 77% | Lightweight trade-off |
|
| 695 |
+
| Fusion Accuracy | 84% | Best accuracy |
|
| 696 |
+
|
| 697 |
+
### Technical Limitations
|
| 698 |
+
- Fusion accuracy is **7% higher** than ResNet (84% vs 77%)
|
| 699 |
+
- YOLO model may miss small or partially occluded damage
|
| 700 |
+
- Grad-CAM is for diagnostic/explainability purposes only
|
| 701 |
+
- Batch processing not currently supported
|
| 702 |
+
- FP16 Grad-CAM on CPU requires automatic FP32 cast (handled internally)
|
app.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import uuid
|
| 3 |
+
import shutil
|
| 4 |
+
import logging
|
| 5 |
+
from contextlib import asynccontextmanager
|
| 6 |
+
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from fastapi import FastAPI, UploadFile, File, HTTPException
|
| 9 |
+
from fastapi.staticfiles import StaticFiles
|
| 10 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 11 |
+
from dotenv import load_dotenv
|
| 12 |
+
|
| 13 |
+
from scripts.gradcam import get_resnet_gradcam, get_fusion_gradcam
|
| 14 |
+
from scripts.yolo_predict import get_yolo_damage_boxes
|
| 15 |
+
from scripts.load_models import initialize_models
|
| 16 |
+
|
| 17 |
+
# ---------------- LOGGING ----------------
|
| 18 |
+
logging.basicConfig(
|
| 19 |
+
level=logging.INFO,
|
| 20 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s"
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
# ---------------- ENV ----------------
|
| 26 |
+
load_dotenv()
|
| 27 |
+
|
| 28 |
+
# ---------------- DIRECTORIES ----------------
|
| 29 |
+
UPLOAD_DIR = "static/uploads"
|
| 30 |
+
RESULT_DIR = "static/results"
|
| 31 |
+
|
| 32 |
+
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
| 33 |
+
os.makedirs(RESULT_DIR, exist_ok=True)
|
| 34 |
+
|
| 35 |
+
# ---------------- GLOBAL MODELS ----------------
|
| 36 |
+
resnet_predictor = None
|
| 37 |
+
fusion_predictor = None
|
| 38 |
+
yolo_model = None
|
| 39 |
+
|
| 40 |
+
CLASS_MAP = {
|
| 41 |
+
0: "Front Breakage",
|
| 42 |
+
1: "Front Crushed",
|
| 43 |
+
2: "Front Normal",
|
| 44 |
+
3: "Rear Breakage",
|
| 45 |
+
4: "Rear Crushed",
|
| 46 |
+
5: "Rear Normal"
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
# ---------------- FASTAPI STARTUP ----------------
|
| 50 |
+
@asynccontextmanager
|
| 51 |
+
async def lifespan(app: FastAPI):
|
| 52 |
+
global resnet_predictor, fusion_predictor, yolo_model
|
| 53 |
+
|
| 54 |
+
logger.info("Loading models at startup...")
|
| 55 |
+
|
| 56 |
+
try:
|
| 57 |
+
resnet_predictor, fusion_predictor, yolo_model = initialize_models(CLASS_MAP)
|
| 58 |
+
logger.info("All models loaded successfully.")
|
| 59 |
+
|
| 60 |
+
except Exception as e:
|
| 61 |
+
logger.exception("Model loading failed.")
|
| 62 |
+
raise RuntimeError(str(e))
|
| 63 |
+
|
| 64 |
+
yield
|
| 65 |
+
|
| 66 |
+
logger.info("Application shutdown.")
|
| 67 |
+
|
| 68 |
+
# ---------------- APP ----------------
|
| 69 |
+
app = FastAPI(lifespan=lifespan)
|
| 70 |
+
|
| 71 |
+
app.add_middleware(
|
| 72 |
+
CORSMiddleware,
|
| 73 |
+
allow_origins=["*"], # restrict this in production
|
| 74 |
+
allow_credentials=True,
|
| 75 |
+
allow_methods=["*"],
|
| 76 |
+
allow_headers=["*"],
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
app.mount("/static", StaticFiles(directory="static"), name="static")
|
| 80 |
+
|
| 81 |
+
# ---------------- HELPERS ----------------
|
| 82 |
+
def validate_image(upload_file: UploadFile):
|
| 83 |
+
if not upload_file.content_type.startswith("image/"):
|
| 84 |
+
raise HTTPException(
|
| 85 |
+
status_code=400,
|
| 86 |
+
detail="Uploaded file must be an image."
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def save_upload(upload_file: UploadFile):
|
| 91 |
+
unique_id = str(uuid.uuid4())
|
| 92 |
+
|
| 93 |
+
filename = f"{unique_id}_input.jpg"
|
| 94 |
+
file_path = os.path.join(UPLOAD_DIR, filename)
|
| 95 |
+
|
| 96 |
+
with open(file_path, "wb") as buffer:
|
| 97 |
+
shutil.copyfileobj(upload_file.file, buffer)
|
| 98 |
+
|
| 99 |
+
return unique_id, filename, file_path
|
| 100 |
+
|
| 101 |
+
# ---------------- ROUTES ----------------
|
| 102 |
+
@app.get("/")
|
| 103 |
+
def api_status():
|
| 104 |
+
return {"status": "API is running"}
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
@app.post("/predict")
|
| 108 |
+
async def predict_and_generate_cams(
|
| 109 |
+
file: UploadFile = File(...),
|
| 110 |
+
mode: str = "resnet"
|
| 111 |
+
):
|
| 112 |
+
validate_image(file)
|
| 113 |
+
|
| 114 |
+
mode = mode.lower()
|
| 115 |
+
|
| 116 |
+
if mode not in {"resnet", "fusion"}:
|
| 117 |
+
raise HTTPException(
|
| 118 |
+
status_code=400,
|
| 119 |
+
detail="mode must be 'resnet' or 'fusion'"
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
try:
|
| 123 |
+
unique_id, input_filename, input_path = save_upload(file)
|
| 124 |
+
|
| 125 |
+
if mode == "resnet":
|
| 126 |
+
output_name = f"{unique_id}_resnet.jpg"
|
| 127 |
+
output_path = os.path.join(RESULT_DIR, output_name)
|
| 128 |
+
|
| 129 |
+
get_resnet_gradcam(
|
| 130 |
+
input_path,
|
| 131 |
+
resnet_predictor,
|
| 132 |
+
output_path
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
selected_viz = f"/static/results/{output_name}"
|
| 136 |
+
|
| 137 |
+
return {
|
| 138 |
+
"status": "success",
|
| 139 |
+
"mode": mode,
|
| 140 |
+
"original_image": f"/static/uploads/{input_filename}",
|
| 141 |
+
"selected_viz": selected_viz,
|
| 142 |
+
"resnet_viz": selected_viz,
|
| 143 |
+
"fusion_viz": None
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
output_name = f"{unique_id}_fusion.jpg"
|
| 147 |
+
output_path = os.path.join(RESULT_DIR, output_name)
|
| 148 |
+
|
| 149 |
+
get_fusion_gradcam(
|
| 150 |
+
input_path,
|
| 151 |
+
fusion_predictor,
|
| 152 |
+
output_path
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
selected_viz = f"/static/results/{output_name}"
|
| 156 |
+
|
| 157 |
+
return {
|
| 158 |
+
"status": "success",
|
| 159 |
+
"mode": mode,
|
| 160 |
+
"original_image": f"/static/uploads/{input_filename}",
|
| 161 |
+
"selected_viz": selected_viz,
|
| 162 |
+
"resnet_viz": None,
|
| 163 |
+
"fusion_viz": selected_viz
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
except Exception as e:
|
| 167 |
+
logger.exception("GradCAM generation failed.")
|
| 168 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
@app.post("/predict/resnet")
|
| 172 |
+
async def resnet_prediction(image: UploadFile = File(...)):
|
| 173 |
+
validate_image(image)
|
| 174 |
+
|
| 175 |
+
try:
|
| 176 |
+
pil_image = Image.open(image.file).convert("RGB")
|
| 177 |
+
|
| 178 |
+
result = resnet_predictor.resnet_predict(pil_image)
|
| 179 |
+
|
| 180 |
+
return {
|
| 181 |
+
"status": "success",
|
| 182 |
+
"prediction": result
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
except Exception as e:
|
| 186 |
+
logger.exception("ResNet prediction failed.")
|
| 187 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
@app.post("/predict/fusion")
|
| 191 |
+
async def fusion_prediction(image: UploadFile = File(...)):
|
| 192 |
+
validate_image(image)
|
| 193 |
+
|
| 194 |
+
try:
|
| 195 |
+
pil_image = Image.open(image.file).convert("RGB")
|
| 196 |
+
|
| 197 |
+
result = fusion_predictor.predict(pil_image)
|
| 198 |
+
|
| 199 |
+
return {
|
| 200 |
+
"status": "success",
|
| 201 |
+
"prediction": result
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
except Exception as e:
|
| 205 |
+
logger.exception("Fusion prediction failed.")
|
| 206 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
@app.post("/predict/yolo")
|
| 210 |
+
async def yolo_detection(file: UploadFile = File(...)):
|
| 211 |
+
validate_image(file)
|
| 212 |
+
|
| 213 |
+
try:
|
| 214 |
+
unique_id, input_filename, input_path = save_upload(file)
|
| 215 |
+
|
| 216 |
+
output_name = f"{unique_id}_yolo.jpg"
|
| 217 |
+
output_path = os.path.join(RESULT_DIR, output_name)
|
| 218 |
+
|
| 219 |
+
result = get_yolo_damage_boxes(
|
| 220 |
+
input_path,
|
| 221 |
+
yolo_model,
|
| 222 |
+
output_path
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
return {
|
| 226 |
+
"status": "success",
|
| 227 |
+
"original_image": f"/static/uploads/{input_filename}",
|
| 228 |
+
"yolo_image": f"/static/results/{output_name}",
|
| 229 |
+
"detections": result["detections"],
|
| 230 |
+
"total_detections": result["total_detections"],
|
| 231 |
+
"message": result["message"]
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
except Exception as e:
|
| 235 |
+
logger.exception("YOLO detection failed.")
|
| 236 |
+
raise HTTPException(status_code=500, detail=str(e))
|
assets/fusion_classification_report.png
ADDED
|
assets/fusion_confusion_matrix.png
ADDED
|
assets/fusion_training_curves.png
ADDED
|
assets/resnet_classification_report.png
ADDED
|
assets/resnet_confusion_matrix.png
ADDED
|
assets/resnet_training_curves.png
ADDED
|
assets/yolo_detection_sample.jpg
ADDED
|
Git LFS Details
|
requirements.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
torchvision
|
| 3 |
+
transformers
|
| 4 |
+
fastapi
|
| 5 |
+
uvicorn
|
| 6 |
+
dotenv
|
| 7 |
+
matplotlib
|
| 8 |
+
opencv-python
|
| 9 |
+
python-multipart
|
| 10 |
+
ultralytics
|
| 11 |
+
plotly
|
| 12 |
+
pandas
|
| 13 |
+
scikit-learn
|
| 14 |
+
seaborn
|
| 15 |
+
huggingface_hub
|
scripts/gradcam.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
import logging
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# ------------------------------------------------------------------
|
| 12 |
+
# Lightweight hook manager — CPU-only, no logging, direct capture
|
| 13 |
+
# ------------------------------------------------------------------
|
| 14 |
+
class _GradCAMHook:
|
| 15 |
+
__slots__ = ("activation", "gradient", "fwd_handle", "bwd_handle")
|
| 16 |
+
|
| 17 |
+
def __init__(self, target_layer):
|
| 18 |
+
self.activation = None
|
| 19 |
+
self.gradient = None
|
| 20 |
+
self.fwd_handle = target_layer.register_forward_hook(self._fwd_hook)
|
| 21 |
+
self.bwd_handle = None
|
| 22 |
+
|
| 23 |
+
def _fwd_hook(self, module, inp, out):
|
| 24 |
+
self.activation = out
|
| 25 |
+
# Tensor-level hook is lighter than full backward hook or retain_grad()
|
| 26 |
+
self.bwd_handle = out.register_hook(self._bwd_hook)
|
| 27 |
+
|
| 28 |
+
def _bwd_hook(self, grad):
|
| 29 |
+
self.gradient = grad
|
| 30 |
+
|
| 31 |
+
def remove(self):
|
| 32 |
+
self.fwd_handle.remove()
|
| 33 |
+
if self.bwd_handle is not None:
|
| 34 |
+
self.bwd_handle.remove()
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _postprocess_cam(cam_tensor, original_img, output_path, alpha=0.5, beta=0.6):
|
| 38 |
+
"""
|
| 39 |
+
CPU post-processing shared by both ResNet and Fusion.
|
| 40 |
+
cam_tensor: 2D torch tensor [H, W] on CPU, already ReLU'd
|
| 41 |
+
"""
|
| 42 |
+
h, w = original_img.height, original_img.width
|
| 43 |
+
|
| 44 |
+
# Normalize on CPU (vectorized)
|
| 45 |
+
cam_min = cam_tensor.min()
|
| 46 |
+
cam_max = cam_tensor.max()
|
| 47 |
+
if cam_max > cam_min:
|
| 48 |
+
cam_tensor = (cam_tensor - cam_min) / (cam_max - cam_min)
|
| 49 |
+
else:
|
| 50 |
+
cam_tensor = torch.zeros_like(cam_tensor)
|
| 51 |
+
|
| 52 |
+
# Convert to numpy once, then resize with OpenCV (very fast on CPU)
|
| 53 |
+
cam_np = cam_tensor.numpy()
|
| 54 |
+
cam_np = cv2.resize(cam_np, (w, h), interpolation=cv2.INTER_LINEAR)
|
| 55 |
+
|
| 56 |
+
cam_np = np.uint8(255 * cam_np)
|
| 57 |
+
heatmap = cv2.applyColorMap(cam_np, cv2.COLORMAP_JET)
|
| 58 |
+
|
| 59 |
+
original_bgr = cv2.cvtColor(np.array(original_img), cv2.COLOR_RGB2BGR)
|
| 60 |
+
overlay = cv2.addWeighted(original_bgr, alpha, heatmap, beta, 0)
|
| 61 |
+
cv2.imwrite(output_path, overlay)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# ------------------------------------------------------------------
|
| 65 |
+
# Optimized ResNet Grad-CAM (CPU)
|
| 66 |
+
# ------------------------------------------------------------------
|
| 67 |
+
def get_resnet_gradcam(image_path, predictor, output_path):
|
| 68 |
+
logger.info("Starting ResNet Grad-CAM generation...")
|
| 69 |
+
|
| 70 |
+
model = predictor.model
|
| 71 |
+
model.eval()
|
| 72 |
+
|
| 73 |
+
target_layer = model.model.layer4[-1]
|
| 74 |
+
hook = _GradCAMHook(target_layer)
|
| 75 |
+
|
| 76 |
+
try:
|
| 77 |
+
original_img = Image.open(image_path).convert("RGB")
|
| 78 |
+
input_tensor = predictor.test_transforms(original_img).unsqueeze(0)
|
| 79 |
+
|
| 80 |
+
output = model(input_tensor)
|
| 81 |
+
score, pred_class_idx = output[0].max(dim=0)
|
| 82 |
+
pred_class_idx = pred_class_idx.item()
|
| 83 |
+
|
| 84 |
+
logger.info(f"Predicted class index: {pred_class_idx}")
|
| 85 |
+
score.backward()
|
| 86 |
+
|
| 87 |
+
if hook.activation is None or hook.gradient is None:
|
| 88 |
+
raise RuntimeError("Failed to capture activations or gradients.")
|
| 89 |
+
|
| 90 |
+
# ----- Vectorized Grad-CAM on CPU -----
|
| 91 |
+
acts = hook.activation[0].detach().float() # [C, H, W]
|
| 92 |
+
grads = hook.gradient[0].detach().float() # [C, H, W]
|
| 93 |
+
|
| 94 |
+
weights = grads.mean(dim=(1, 2), keepdim=True) # [C, 1, 1]
|
| 95 |
+
cam = (weights * acts).sum(dim=0) # [H, W]
|
| 96 |
+
cam = F.relu(cam)
|
| 97 |
+
|
| 98 |
+
_postprocess_cam(cam, original_img, output_path, alpha=0.6, beta=0.4)
|
| 99 |
+
|
| 100 |
+
logger.info(f"ResNet Grad-CAM saved to: {output_path}")
|
| 101 |
+
return True
|
| 102 |
+
|
| 103 |
+
except Exception as e:
|
| 104 |
+
logger.exception("ResNet Grad-CAM generation failed.")
|
| 105 |
+
raise RuntimeError(f"ResNet Grad-CAM failed: {e}") from e
|
| 106 |
+
|
| 107 |
+
finally:
|
| 108 |
+
hook.remove()
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
# ------------------------------------------------------------------
|
| 112 |
+
# Optimized Fusion Grad-CAM (EfficientNet + ConvNeXt) (CPU)
|
| 113 |
+
# ------------------------------------------------------------------
|
| 114 |
+
def get_fusion_gradcam(image_path, predictor, output_path):
|
| 115 |
+
logger.info("Starting Fusion Grad-CAM generation...")
|
| 116 |
+
|
| 117 |
+
model = predictor.model
|
| 118 |
+
model.eval()
|
| 119 |
+
|
| 120 |
+
# FIX: PyTorch CPU does not support FP16 convolutions well.
|
| 121 |
+
# If the model is HalfTensor, cast it to FP32 for this pass.
|
| 122 |
+
is_half = next(model.parameters()).dtype == torch.float16
|
| 123 |
+
if is_half:
|
| 124 |
+
logger.info("FP16 model detected on CPU. Converting to FP32 for compatibility.")
|
| 125 |
+
model = model.float()
|
| 126 |
+
|
| 127 |
+
target_layer = model.eff_features[-1]
|
| 128 |
+
hook = _GradCAMHook(target_layer)
|
| 129 |
+
|
| 130 |
+
try:
|
| 131 |
+
original_img = Image.open(image_path).convert("RGB")
|
| 132 |
+
|
| 133 |
+
# CPU-only preprocessing (FloatTensor, no .to(device), no .half())
|
| 134 |
+
pixel_eff = predictor.eff_normalize(original_img).unsqueeze(0)
|
| 135 |
+
pixel_cnx = predictor.convnext_processor(
|
| 136 |
+
images=original_img, return_tensors="pt"
|
| 137 |
+
)["pixel_values"]
|
| 138 |
+
|
| 139 |
+
output = model(pixel_eff, pixel_cnx)
|
| 140 |
+
score, pred_class_idx = output[0].max(dim=0)
|
| 141 |
+
pred_class_idx = pred_class_idx.item()
|
| 142 |
+
|
| 143 |
+
logger.info(f"Predicted class index: {pred_class_idx}")
|
| 144 |
+
score.backward()
|
| 145 |
+
|
| 146 |
+
if hook.activation is None or hook.gradient is None:
|
| 147 |
+
raise RuntimeError("Failed to capture activations or gradients.")
|
| 148 |
+
|
| 149 |
+
# ----- Vectorized Grad-CAM on CPU -----
|
| 150 |
+
acts = hook.activation[0].detach().float() # [C, H, W]
|
| 151 |
+
grads = hook.gradient[0].detach().float() # [C, H, W]
|
| 152 |
+
|
| 153 |
+
weights = grads.mean(dim=(1, 2), keepdim=True) # [C, 1, 1]
|
| 154 |
+
cam = (weights * acts).sum(dim=0) # [H, W]
|
| 155 |
+
cam = F.relu(cam)
|
| 156 |
+
|
| 157 |
+
_postprocess_cam(cam, original_img, output_path, alpha=0.5, beta=0.6)
|
| 158 |
+
|
| 159 |
+
logger.info(f"Fusion Grad-CAM saved to: {output_path}")
|
| 160 |
+
return True
|
| 161 |
+
|
| 162 |
+
except Exception as e:
|
| 163 |
+
logger.exception("Fusion Grad-CAM generation failed.")
|
| 164 |
+
raise RuntimeError(f"Fusion Grad-CAM failed: {e}") from e
|
| 165 |
+
|
| 166 |
+
finally:
|
| 167 |
+
hook.remove()
|
scripts/load_models.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from huggingface_hub import hf_hub_download
|
| 4 |
+
from ultralytics import YOLO
|
| 5 |
+
|
| 6 |
+
from .prediction_helper import (
|
| 7 |
+
ResnetCarDamagePredictor,
|
| 8 |
+
FusionCarDamagePredictor,
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
MODEL_CONFIG = {
|
| 14 |
+
"resnet": {
|
| 15 |
+
"repo_id": "junaid17/car-damage-classifier",
|
| 16 |
+
"filename": "car-damage-classifier.pt",
|
| 17 |
+
},
|
| 18 |
+
"fusion": {
|
| 19 |
+
"repo_id": "junaid17/best_fusion_model_fp16",
|
| 20 |
+
"filename": "best_fusion_model_fp16.pt",
|
| 21 |
+
},
|
| 22 |
+
"yolo": {
|
| 23 |
+
"repo_id": "junaid17/Yolo_Model",
|
| 24 |
+
"filename": "damage_detector.pt",
|
| 25 |
+
},
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def get_checkpoint_path(model_key: str) -> Path:
|
| 30 |
+
if model_key not in MODEL_CONFIG:
|
| 31 |
+
raise ValueError(f"Unknown model key: {model_key}")
|
| 32 |
+
|
| 33 |
+
config = MODEL_CONFIG[model_key]
|
| 34 |
+
|
| 35 |
+
try:
|
| 36 |
+
logger.info(f"Fetching {model_key} model from Hugging Face Hub...")
|
| 37 |
+
logger.info(f"Repo: {config['repo_id']}")
|
| 38 |
+
logger.info(f"File: {config['filename']}")
|
| 39 |
+
|
| 40 |
+
local_path = hf_hub_download(
|
| 41 |
+
repo_id=config["repo_id"],
|
| 42 |
+
filename=config["filename"],
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
logger.info(f"{model_key} model available at: {local_path}")
|
| 46 |
+
|
| 47 |
+
return Path(local_path)
|
| 48 |
+
|
| 49 |
+
except Exception as e:
|
| 50 |
+
logger.exception(f"Failed to fetch {model_key} model.")
|
| 51 |
+
raise RuntimeError(f"Failed to load {model_key} checkpoint: {str(e)}")
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class ModelLoader:
|
| 55 |
+
def __init__(self):
|
| 56 |
+
logger.info("Initializing ModelLoader...")
|
| 57 |
+
|
| 58 |
+
def get_model_path(self, model_key: str) -> Path:
|
| 59 |
+
return get_checkpoint_path(model_key)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def initialize_models(class_map):
|
| 63 |
+
logger.info("Starting model initialization...")
|
| 64 |
+
|
| 65 |
+
try:
|
| 66 |
+
resnet_path = get_checkpoint_path("resnet")
|
| 67 |
+
fusion_path = get_checkpoint_path("fusion")
|
| 68 |
+
yolo_path = get_checkpoint_path("yolo")
|
| 69 |
+
|
| 70 |
+
logger.info("Initializing ResNet predictor...")
|
| 71 |
+
resnet_predictor = ResnetCarDamagePredictor(
|
| 72 |
+
checkpoint_path=resnet_path,
|
| 73 |
+
class_map=class_map
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
logger.info("Initializing Fusion predictor...")
|
| 77 |
+
fusion_predictor = FusionCarDamagePredictor(
|
| 78 |
+
checkpoint_path=fusion_path,
|
| 79 |
+
class_map=class_map
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
logger.info("Initializing YOLO model...")
|
| 83 |
+
yolo_model = YOLO(str(yolo_path))
|
| 84 |
+
|
| 85 |
+
logger.info("All models initialized successfully.")
|
| 86 |
+
|
| 87 |
+
return resnet_predictor, fusion_predictor, yolo_model
|
| 88 |
+
|
| 89 |
+
except Exception as e:
|
| 90 |
+
logger.exception("Model initialization failed.")
|
| 91 |
+
raise RuntimeError(f"Model initialization failed: {str(e)}")
|
scripts/prediction_helper.py
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import logging
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from torchvision import transforms, models
|
| 6 |
+
from PIL import Image, UnidentifiedImageError
|
| 7 |
+
from transformers import ConvNextModel, ConvNextImageProcessor
|
| 8 |
+
|
| 9 |
+
# ---------------- LOGGING SETUP ----------------
|
| 10 |
+
logging.basicConfig(
|
| 11 |
+
level=logging.INFO,
|
| 12 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s"
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# ---------------- RESNET MODEL ----------------
|
| 19 |
+
class Car_Classifier_Resnet(nn.Module):
|
| 20 |
+
def __init__(self, num_classes):
|
| 21 |
+
super().__init__()
|
| 22 |
+
|
| 23 |
+
logger.info("Initializing ResNet18 architecture...")
|
| 24 |
+
|
| 25 |
+
self.model = models.resnet18(weights="DEFAULT")
|
| 26 |
+
|
| 27 |
+
for param in self.model.parameters():
|
| 28 |
+
param.requires_grad = False
|
| 29 |
+
|
| 30 |
+
for param in self.model.layer3.parameters():
|
| 31 |
+
param.requires_grad = True
|
| 32 |
+
|
| 33 |
+
for param in self.model.layer4.parameters():
|
| 34 |
+
param.requires_grad = True
|
| 35 |
+
|
| 36 |
+
self.model.fc = nn.Sequential(
|
| 37 |
+
nn.Dropout(0.5),
|
| 38 |
+
nn.Linear(self.model.fc.in_features, 256),
|
| 39 |
+
nn.ReLU(),
|
| 40 |
+
nn.Dropout(0.3),
|
| 41 |
+
nn.Linear(256, num_classes)
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
logger.info("ResNet architecture initialized successfully.")
|
| 45 |
+
|
| 46 |
+
def forward(self, x):
|
| 47 |
+
return self.model(x)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class ResnetCarDamagePredictor:
|
| 51 |
+
def __init__(self, checkpoint_path, class_map):
|
| 52 |
+
logger.info("Initializing ResNet predictor...")
|
| 53 |
+
|
| 54 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 55 |
+
self.class_map = class_map
|
| 56 |
+
|
| 57 |
+
logger.info(f"Using device for ResNet: {self.device}")
|
| 58 |
+
|
| 59 |
+
self.test_transforms = transforms.Compose([
|
| 60 |
+
transforms.Resize((128, 128)),
|
| 61 |
+
transforms.ToTensor(),
|
| 62 |
+
transforms.Normalize(
|
| 63 |
+
[0.485, 0.456, 0.406],
|
| 64 |
+
[0.229, 0.224, 0.225]
|
| 65 |
+
)
|
| 66 |
+
])
|
| 67 |
+
|
| 68 |
+
try:
|
| 69 |
+
self.model = Car_Classifier_Resnet(num_classes=len(class_map))
|
| 70 |
+
|
| 71 |
+
logger.info(f"Loading ResNet checkpoint from: {checkpoint_path}")
|
| 72 |
+
|
| 73 |
+
checkpoint = torch.load(checkpoint_path, map_location=self.device)
|
| 74 |
+
state_dict = checkpoint.get("model_state_dict", checkpoint)
|
| 75 |
+
|
| 76 |
+
self.model.load_state_dict(state_dict)
|
| 77 |
+
self.model.to(self.device)
|
| 78 |
+
self.model.eval()
|
| 79 |
+
|
| 80 |
+
logger.info("ResNet model loaded successfully.")
|
| 81 |
+
|
| 82 |
+
except Exception as e:
|
| 83 |
+
logger.exception("Failed to load ResNet model.")
|
| 84 |
+
raise RuntimeError(f"Failed to load ResNet model: {str(e)}")
|
| 85 |
+
|
| 86 |
+
def resnet_predict(self, image_input):
|
| 87 |
+
logger.info("Starting ResNet prediction...")
|
| 88 |
+
|
| 89 |
+
try:
|
| 90 |
+
if isinstance(image_input, str):
|
| 91 |
+
logger.info(f"Loading image from file path: {image_input}")
|
| 92 |
+
image = Image.open(image_input).convert("RGB")
|
| 93 |
+
|
| 94 |
+
elif isinstance(image_input, Image.Image):
|
| 95 |
+
logger.info("Using PIL image input.")
|
| 96 |
+
image = image_input.convert("RGB")
|
| 97 |
+
|
| 98 |
+
else:
|
| 99 |
+
raise TypeError("image_input must be a file path or PIL.Image")
|
| 100 |
+
|
| 101 |
+
image = self.test_transforms(image)
|
| 102 |
+
image = image.unsqueeze(0).to(self.device)
|
| 103 |
+
|
| 104 |
+
with torch.no_grad():
|
| 105 |
+
outputs = self.model(image)
|
| 106 |
+
|
| 107 |
+
probs = torch.nn.functional.softmax(outputs, dim=1)[0]
|
| 108 |
+
|
| 109 |
+
class_probs = {
|
| 110 |
+
self.class_map[i]: float(probs[i].item())
|
| 111 |
+
for i in range(len(self.class_map))
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
sorted_probs = dict(
|
| 115 |
+
sorted(class_probs.items(), key=lambda x: x[1], reverse=True)
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
logger.info("ResNet prediction completed successfully.")
|
| 119 |
+
|
| 120 |
+
return sorted_probs
|
| 121 |
+
|
| 122 |
+
except UnidentifiedImageError:
|
| 123 |
+
logger.error("Invalid image file provided to ResNet predictor.")
|
| 124 |
+
raise ValueError("Invalid image file provided")
|
| 125 |
+
|
| 126 |
+
except Exception as e:
|
| 127 |
+
logger.exception("ResNet prediction failed.")
|
| 128 |
+
raise RuntimeError(f"ResNet prediction failed: {str(e)}")
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
# ---------------- FUSION MODEL ----------------
|
| 132 |
+
class FusionClassifier(nn.Module):
|
| 133 |
+
def __init__(self, num_classes, convnext_model_name="facebook/convnext-small-224"):
|
| 134 |
+
super().__init__()
|
| 135 |
+
|
| 136 |
+
logger.info("Initializing Fusion model architecture...")
|
| 137 |
+
|
| 138 |
+
eff = models.efficientnet_v2_s(
|
| 139 |
+
weights=models.EfficientNet_V2_S_Weights.IMAGENET1K_V1
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
for param in eff.parameters():
|
| 143 |
+
param.requires_grad = False
|
| 144 |
+
|
| 145 |
+
for param in eff.features[5].parameters():
|
| 146 |
+
param.requires_grad = True
|
| 147 |
+
|
| 148 |
+
for param in eff.features[6].parameters():
|
| 149 |
+
param.requires_grad = True
|
| 150 |
+
|
| 151 |
+
for param in eff.features[7].parameters():
|
| 152 |
+
param.requires_grad = True
|
| 153 |
+
|
| 154 |
+
self.eff_features = eff.features
|
| 155 |
+
self.eff_avgpool = eff.avgpool
|
| 156 |
+
self.eff_out_dim = eff.classifier[1].in_features
|
| 157 |
+
|
| 158 |
+
logger.info("Loading ConvNeXt backbone...")
|
| 159 |
+
|
| 160 |
+
cnx = ConvNextModel.from_pretrained(convnext_model_name)
|
| 161 |
+
|
| 162 |
+
for param in cnx.parameters():
|
| 163 |
+
param.requires_grad = False
|
| 164 |
+
|
| 165 |
+
for param in cnx.encoder.stages[2].parameters():
|
| 166 |
+
param.requires_grad = True
|
| 167 |
+
|
| 168 |
+
for param in cnx.encoder.stages[3].parameters():
|
| 169 |
+
param.requires_grad = True
|
| 170 |
+
|
| 171 |
+
for param in cnx.layernorm.parameters():
|
| 172 |
+
param.requires_grad = True
|
| 173 |
+
|
| 174 |
+
self.cnx_backbone = cnx
|
| 175 |
+
self.cnx_out_dim = 768
|
| 176 |
+
|
| 177 |
+
fused_dim = self.eff_out_dim + self.cnx_out_dim
|
| 178 |
+
|
| 179 |
+
self.fusion_head = nn.Sequential(
|
| 180 |
+
nn.Dropout(p=0.4),
|
| 181 |
+
nn.Linear(fused_dim, 512),
|
| 182 |
+
nn.LayerNorm(512),
|
| 183 |
+
nn.GELU(),
|
| 184 |
+
nn.Dropout(p=0.3),
|
| 185 |
+
nn.Linear(512, 256),
|
| 186 |
+
nn.LayerNorm(256),
|
| 187 |
+
nn.GELU(),
|
| 188 |
+
nn.Dropout(p=0.2),
|
| 189 |
+
nn.Linear(256, num_classes)
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
logger.info("Fusion architecture initialized successfully.")
|
| 193 |
+
|
| 194 |
+
def forward(self, pixel_values_eff, pixel_values_cnx):
|
| 195 |
+
x_eff = self.eff_features(pixel_values_eff)
|
| 196 |
+
x_eff = self.eff_avgpool(x_eff)
|
| 197 |
+
x_eff = torch.flatten(x_eff, 1)
|
| 198 |
+
|
| 199 |
+
cnx_out = self.cnx_backbone(
|
| 200 |
+
pixel_values=pixel_values_cnx,
|
| 201 |
+
return_dict=True
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
x_cnx = cnx_out.pooler_output
|
| 205 |
+
fused = torch.cat([x_eff, x_cnx], dim=1)
|
| 206 |
+
|
| 207 |
+
return self.fusion_head(fused)
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class FusionCarDamagePredictor:
|
| 211 |
+
def __init__(self, checkpoint_path, class_map, convnext_model_name="facebook/convnext-small-224"):
|
| 212 |
+
logger.info("Initializing Fusion predictor...")
|
| 213 |
+
|
| 214 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 215 |
+
self.class_map = class_map
|
| 216 |
+
|
| 217 |
+
logger.info(f"Using device for Fusion: {self.device}")
|
| 218 |
+
|
| 219 |
+
self.eff_normalize = transforms.Compose([
|
| 220 |
+
transforms.Resize((260, 260)),
|
| 221 |
+
transforms.ToTensor(),
|
| 222 |
+
transforms.Normalize(
|
| 223 |
+
[0.485, 0.456, 0.406],
|
| 224 |
+
[0.229, 0.224, 0.225]
|
| 225 |
+
)
|
| 226 |
+
])
|
| 227 |
+
|
| 228 |
+
logger.info("Loading ConvNeXt image processor...")
|
| 229 |
+
self.convnext_processor = ConvNextImageProcessor.from_pretrained(
|
| 230 |
+
convnext_model_name
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
try:
|
| 234 |
+
self.model = FusionClassifier(
|
| 235 |
+
num_classes=len(class_map),
|
| 236 |
+
convnext_model_name=convnext_model_name
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
logger.info(f"Loading Fusion checkpoint from: {checkpoint_path}")
|
| 240 |
+
|
| 241 |
+
checkpoint = torch.load(checkpoint_path, map_location=self.device)
|
| 242 |
+
state_dict = checkpoint.get("model_state_dict", checkpoint)
|
| 243 |
+
|
| 244 |
+
first_tensor = next(iter(state_dict.values()))
|
| 245 |
+
|
| 246 |
+
if first_tensor.dtype == torch.float16:
|
| 247 |
+
logger.info("FP16 checkpoint detected. Converting model to half precision.")
|
| 248 |
+
self.model = self.model.half()
|
| 249 |
+
|
| 250 |
+
self.model.load_state_dict(state_dict)
|
| 251 |
+
self.model.to(self.device)
|
| 252 |
+
self.model.eval()
|
| 253 |
+
|
| 254 |
+
logger.info("Fusion model loaded successfully.")
|
| 255 |
+
|
| 256 |
+
except Exception as e:
|
| 257 |
+
logger.exception("Failed to load Fusion model.")
|
| 258 |
+
raise RuntimeError(f"Failed to load Fusion model: {str(e)}")
|
| 259 |
+
|
| 260 |
+
def predict(self, image_input):
|
| 261 |
+
logger.info("Starting Fusion prediction...")
|
| 262 |
+
|
| 263 |
+
try:
|
| 264 |
+
if isinstance(image_input, str):
|
| 265 |
+
logger.info(f"Loading image from file path: {image_input}")
|
| 266 |
+
image = Image.open(image_input).convert("RGB")
|
| 267 |
+
|
| 268 |
+
elif isinstance(image_input, Image.Image):
|
| 269 |
+
logger.info("Using PIL image input.")
|
| 270 |
+
image = image_input.convert("RGB")
|
| 271 |
+
|
| 272 |
+
else:
|
| 273 |
+
raise TypeError("image_input must be a file path or PIL.Image")
|
| 274 |
+
|
| 275 |
+
pixel_eff = self.eff_normalize(image)
|
| 276 |
+
pixel_eff = pixel_eff.unsqueeze(0).to(self.device)
|
| 277 |
+
|
| 278 |
+
inputs_cnx = self.convnext_processor(
|
| 279 |
+
images=image,
|
| 280 |
+
return_tensors="pt"
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
pixel_cnx = inputs_cnx["pixel_values"].to(self.device)
|
| 284 |
+
|
| 285 |
+
if next(self.model.parameters()).dtype == torch.float16:
|
| 286 |
+
logger.info("Converting input tensors to FP16.")
|
| 287 |
+
pixel_eff = pixel_eff.half()
|
| 288 |
+
pixel_cnx = pixel_cnx.half()
|
| 289 |
+
|
| 290 |
+
with torch.no_grad():
|
| 291 |
+
logits = self.model(pixel_eff, pixel_cnx)
|
| 292 |
+
probs = torch.nn.functional.softmax(logits, dim=1)[0]
|
| 293 |
+
|
| 294 |
+
class_probs = {
|
| 295 |
+
self.class_map[i]: float(probs[i].item())
|
| 296 |
+
for i in range(len(self.class_map))
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
sorted_probs = dict(
|
| 300 |
+
sorted(class_probs.items(), key=lambda x: x[1], reverse=True)
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
logger.info("Fusion prediction completed successfully.")
|
| 304 |
+
|
| 305 |
+
return sorted_probs
|
| 306 |
+
|
| 307 |
+
except UnidentifiedImageError:
|
| 308 |
+
logger.error("Invalid image file provided to Fusion predictor.")
|
| 309 |
+
raise ValueError("Invalid image file provided")
|
| 310 |
+
|
| 311 |
+
except Exception as e:
|
| 312 |
+
logger.exception("Fusion prediction failed.")
|
| 313 |
+
raise RuntimeError(f"Fusion prediction failed: {str(e)}")
|
scripts/yolo_predict.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import logging
|
| 3 |
+
from PIL import Image
|
| 4 |
+
|
| 5 |
+
logger = logging.getLogger(__name__)
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def get_yolo_damage_boxes(image_path, yolo_model, output_path):
|
| 9 |
+
logger.info("Starting YOLO damage detection...")
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
image = Image.open(image_path).convert("RGB")
|
| 13 |
+
|
| 14 |
+
results = yolo_model.predict(
|
| 15 |
+
source=image,
|
| 16 |
+
conf=0.05,
|
| 17 |
+
imgsz=640,
|
| 18 |
+
verbose=False
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
result = results[0]
|
| 22 |
+
boxes = result.boxes
|
| 23 |
+
detections = []
|
| 24 |
+
|
| 25 |
+
if boxes is not None and len(boxes) > 0:
|
| 26 |
+
logger.info(f"{len(boxes)} detections found.")
|
| 27 |
+
|
| 28 |
+
for box in boxes:
|
| 29 |
+
conf = float(box.conf[0])
|
| 30 |
+
cls_id = int(box.cls[0])
|
| 31 |
+
|
| 32 |
+
label = yolo_model.names[cls_id]
|
| 33 |
+
|
| 34 |
+
x1, y1, x2, y2 = map(int, box.xyxy[0])
|
| 35 |
+
|
| 36 |
+
detections.append({
|
| 37 |
+
"label": label,
|
| 38 |
+
"confidence": round(conf, 4),
|
| 39 |
+
"box": [x1, y1, x2, y2]
|
| 40 |
+
})
|
| 41 |
+
|
| 42 |
+
else:
|
| 43 |
+
logger.info("No detections found.")
|
| 44 |
+
|
| 45 |
+
plotted = result.plot()
|
| 46 |
+
|
| 47 |
+
cv2.imwrite(output_path, plotted)
|
| 48 |
+
|
| 49 |
+
logger.info(f"YOLO output saved to: {output_path}")
|
| 50 |
+
|
| 51 |
+
return {
|
| 52 |
+
"detections": detections,
|
| 53 |
+
"total_detections": len(detections),
|
| 54 |
+
"message": (
|
| 55 |
+
"No damage detected"
|
| 56 |
+
if len(detections) == 0
|
| 57 |
+
else "Detections found"
|
| 58 |
+
)
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
except Exception as e:
|
| 62 |
+
logger.exception("YOLO detection failed.")
|
| 63 |
+
raise RuntimeError(f"YOLO failed: {str(e)}")
|
src/config.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
# ---------------- PATHS ----------------
|
| 5 |
+
BASE_DIR = Path(__file__).resolve().parents[1]
|
| 6 |
+
|
| 7 |
+
DATASET_DIR = BASE_DIR / "data" / "dataset"
|
| 8 |
+
CHECKPOINT_DIR = BASE_DIR / "checkpoints"
|
| 9 |
+
EXPORT_DIR = BASE_DIR / "exports"
|
| 10 |
+
|
| 11 |
+
CHECKPOINT_DIR.mkdir(exist_ok=True)
|
| 12 |
+
EXPORT_DIR.mkdir(exist_ok=True)
|
| 13 |
+
|
| 14 |
+
# ---------------- TRAINING ----------------
|
| 15 |
+
BATCH_SIZE = 16
|
| 16 |
+
NUM_WORKERS = 4
|
| 17 |
+
LEARNING_RATE = 1e-4
|
| 18 |
+
WEIGHT_DECAY = 1e-5
|
| 19 |
+
VALIDATION_SPLIT = 0.2
|
| 20 |
+
RANDOM_SEED = 42
|
| 21 |
+
|
| 22 |
+
# TEMP DEV SETTING
|
| 23 |
+
EPOCHS = 1
|
| 24 |
+
|
| 25 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 26 |
+
|
| 27 |
+
# ---------------- IMAGE SIZES ----------------
|
| 28 |
+
RESNET_IMAGE_SIZE = 128
|
| 29 |
+
FUSION_IMAGE_SIZE = 260
|
| 30 |
+
YOLO_IMAGE_SIZE = 640
|
| 31 |
+
|
| 32 |
+
# ---------------- YOLO ----------------
|
| 33 |
+
YOLO_BASE_MODEL = "yolo11m.pt"
|
| 34 |
+
YOLO_BATCH_SIZE = 10
|
| 35 |
+
YOLO_EPOCHS = 1
|
| 36 |
+
YOLO_CONFIDENCE_THRESHOLD = 0.05
|
| 37 |
+
|
| 38 |
+
# ---------------- CLASSES ----------------
|
| 39 |
+
CLASS_NAMES = [
|
| 40 |
+
"F_Breakage",
|
| 41 |
+
"F_Crushed",
|
| 42 |
+
"F_Normal",
|
| 43 |
+
"R_Breakage",
|
| 44 |
+
"R_Crushed",
|
| 45 |
+
"R_Normal"
|
| 46 |
+
]
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
CLASS_MAP = {idx: cls for idx, cls in enumerate(CLASS_NAMES)}
|
| 51 |
+
CLASS_TO_IDX = {cls: idx for idx, cls in enumerate(CLASS_NAMES)}
|
| 52 |
+
|
| 53 |
+
NUM_CLASSES = len(CLASS_NAMES)
|
| 54 |
+
|
| 55 |
+
# ---------------- HUGGING FACE ----------------
|
| 56 |
+
HF_USERNAME = "junaid17"
|
| 57 |
+
|
| 58 |
+
HF_RESNET_REPO = "new-car-damage-classifier"
|
| 59 |
+
HF_FUSION_REPO = "new-best-fusion-model-fp16"
|
| 60 |
+
HF_YOLO_REPO = "new-Yolo-Model"
|
src/data/augmentation.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from torchvision import transforms
|
| 3 |
+
|
| 4 |
+
from src.config import RESNET_IMAGE_SIZE, FUSION_IMAGE_SIZE
|
| 5 |
+
|
| 6 |
+
logger = logging.getLogger(__name__)
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def get_resnet_train_transforms():
|
| 10 |
+
logger.info("Creating ResNet training transforms...")
|
| 11 |
+
|
| 12 |
+
return transforms.Compose([
|
| 13 |
+
transforms.Resize((RESNET_IMAGE_SIZE, RESNET_IMAGE_SIZE)),
|
| 14 |
+
transforms.RandomHorizontalFlip(),
|
| 15 |
+
transforms.RandomRotation(15),
|
| 16 |
+
transforms.ColorJitter(
|
| 17 |
+
brightness=0.2,
|
| 18 |
+
contrast=0.2,
|
| 19 |
+
saturation=0.2
|
| 20 |
+
),
|
| 21 |
+
transforms.ToTensor(),
|
| 22 |
+
transforms.Normalize(
|
| 23 |
+
mean=[0.485, 0.456, 0.406],
|
| 24 |
+
std=[0.229, 0.224, 0.225]
|
| 25 |
+
)
|
| 26 |
+
])
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def get_resnet_val_transforms():
|
| 30 |
+
logger.info("Creating ResNet validation transforms...")
|
| 31 |
+
|
| 32 |
+
return transforms.Compose([
|
| 33 |
+
transforms.Resize((RESNET_IMAGE_SIZE, RESNET_IMAGE_SIZE)),
|
| 34 |
+
transforms.ToTensor(),
|
| 35 |
+
transforms.Normalize(
|
| 36 |
+
mean=[0.485, 0.456, 0.406],
|
| 37 |
+
std=[0.229, 0.224, 0.225]
|
| 38 |
+
)
|
| 39 |
+
])
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def get_fusion_train_transforms():
|
| 43 |
+
logger.info("Creating Fusion training transforms...")
|
| 44 |
+
|
| 45 |
+
return transforms.Compose([
|
| 46 |
+
transforms.Resize((FUSION_IMAGE_SIZE, FUSION_IMAGE_SIZE)),
|
| 47 |
+
transforms.RandomHorizontalFlip(),
|
| 48 |
+
transforms.RandomRotation(10),
|
| 49 |
+
transforms.ColorJitter(
|
| 50 |
+
brightness=0.15,
|
| 51 |
+
contrast=0.15,
|
| 52 |
+
saturation=0.15
|
| 53 |
+
),
|
| 54 |
+
transforms.ToTensor(),
|
| 55 |
+
transforms.Normalize(
|
| 56 |
+
mean=[0.485, 0.456, 0.406],
|
| 57 |
+
std=[0.229, 0.224, 0.225]
|
| 58 |
+
)
|
| 59 |
+
])
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def get_fusion_val_transforms():
|
| 63 |
+
logger.info("Creating Fusion validation transforms...")
|
| 64 |
+
|
| 65 |
+
return transforms.Compose([
|
| 66 |
+
transforms.Resize((FUSION_IMAGE_SIZE, FUSION_IMAGE_SIZE)),
|
| 67 |
+
transforms.ToTensor(),
|
| 68 |
+
transforms.Normalize(
|
| 69 |
+
mean=[0.485, 0.456, 0.406],
|
| 70 |
+
std=[0.229, 0.224, 0.225]
|
| 71 |
+
)
|
| 72 |
+
])
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
if __name__ == "__main__":
|
| 76 |
+
logging.basicConfig(
|
| 77 |
+
level=logging.INFO,
|
| 78 |
+
format="%(asctime)s - %(levelname)s - %(message)s"
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
resnet_train = get_resnet_train_transforms()
|
| 82 |
+
resnet_val = get_resnet_val_transforms()
|
| 83 |
+
fusion_train = get_fusion_train_transforms()
|
| 84 |
+
fusion_val = get_fusion_val_transforms()
|
| 85 |
+
|
| 86 |
+
print("\nTransforms created successfully:")
|
| 87 |
+
print("ResNet Train:", resnet_train)
|
| 88 |
+
print("ResNet Val:", resnet_val)
|
| 89 |
+
print("Fusion Train:", fusion_train)
|
| 90 |
+
print("Fusion Val:", fusion_val)
|
src/data/dataset.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from PIL import Image
|
| 3 |
+
from torch.utils.data import Dataset, DataLoader
|
| 4 |
+
from transformers import ConvNextImageProcessor
|
| 5 |
+
|
| 6 |
+
from src.config import (
|
| 7 |
+
BATCH_SIZE,
|
| 8 |
+
NUM_WORKERS
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
from src.data.ingestion import collect_image_paths
|
| 12 |
+
from src.data.preprocessing import split_dataset
|
| 13 |
+
from src.data.augmentation import (
|
| 14 |
+
get_resnet_train_transforms,
|
| 15 |
+
get_resnet_val_transforms,
|
| 16 |
+
get_fusion_train_transforms,
|
| 17 |
+
get_fusion_val_transforms
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class ResNetDataset(Dataset):
|
| 24 |
+
def __init__(self, samples, transforms=None):
|
| 25 |
+
self.samples = samples
|
| 26 |
+
self.transforms = transforms
|
| 27 |
+
|
| 28 |
+
def __len__(self):
|
| 29 |
+
return len(self.samples)
|
| 30 |
+
|
| 31 |
+
def __getitem__(self, idx):
|
| 32 |
+
image_path, label = self.samples[idx]
|
| 33 |
+
|
| 34 |
+
image = Image.open(image_path).convert("RGB")
|
| 35 |
+
|
| 36 |
+
if self.transforms:
|
| 37 |
+
image = self.transforms(image)
|
| 38 |
+
|
| 39 |
+
return image, label
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class FusionDataset(Dataset):
|
| 43 |
+
def __init__(
|
| 44 |
+
self,
|
| 45 |
+
samples,
|
| 46 |
+
transforms=None,
|
| 47 |
+
convnext_model_name="facebook/convnext-small-224"
|
| 48 |
+
):
|
| 49 |
+
self.samples = samples
|
| 50 |
+
self.transforms = transforms
|
| 51 |
+
|
| 52 |
+
logger.info("Loading ConvNeXt processor...")
|
| 53 |
+
|
| 54 |
+
self.processor = ConvNextImageProcessor.from_pretrained(
|
| 55 |
+
convnext_model_name
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
def __len__(self):
|
| 59 |
+
return len(self.samples)
|
| 60 |
+
|
| 61 |
+
def __getitem__(self, idx):
|
| 62 |
+
image_path, label = self.samples[idx]
|
| 63 |
+
|
| 64 |
+
image = Image.open(image_path).convert("RGB")
|
| 65 |
+
|
| 66 |
+
if self.transforms:
|
| 67 |
+
eff_tensor = self.transforms(image)
|
| 68 |
+
else:
|
| 69 |
+
raise ValueError("Fusion transforms are required.")
|
| 70 |
+
|
| 71 |
+
convnext_inputs = self.processor(
|
| 72 |
+
images=image,
|
| 73 |
+
return_tensors="pt"
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
convnext_tensor = convnext_inputs["pixel_values"].squeeze(0)
|
| 77 |
+
|
| 78 |
+
return {
|
| 79 |
+
"pixel_values_eff": eff_tensor,
|
| 80 |
+
"pixel_values_cnx": convnext_tensor,
|
| 81 |
+
"labels": label
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def create_resnet_dataloaders():
|
| 86 |
+
logger.info("Creating ResNet dataloaders...")
|
| 87 |
+
|
| 88 |
+
samples = collect_image_paths()
|
| 89 |
+
train_data, val_data = split_dataset(samples)
|
| 90 |
+
|
| 91 |
+
train_dataset = ResNetDataset(
|
| 92 |
+
train_data,
|
| 93 |
+
transforms=get_resnet_train_transforms()
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
val_dataset = ResNetDataset(
|
| 97 |
+
val_data,
|
| 98 |
+
transforms=get_resnet_val_transforms()
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
train_loader = DataLoader(
|
| 102 |
+
train_dataset,
|
| 103 |
+
batch_size=BATCH_SIZE,
|
| 104 |
+
shuffle=True,
|
| 105 |
+
num_workers=NUM_WORKERS
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
val_loader = DataLoader(
|
| 109 |
+
val_dataset,
|
| 110 |
+
batch_size=BATCH_SIZE,
|
| 111 |
+
shuffle=False,
|
| 112 |
+
num_workers=NUM_WORKERS
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
logger.info("ResNet dataloaders created successfully.")
|
| 116 |
+
|
| 117 |
+
return train_loader, val_loader
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def create_fusion_dataloaders():
|
| 121 |
+
logger.info("Creating Fusion dataloaders...")
|
| 122 |
+
|
| 123 |
+
samples = collect_image_paths()
|
| 124 |
+
train_data, val_data = split_dataset(samples)
|
| 125 |
+
|
| 126 |
+
train_dataset = FusionDataset(
|
| 127 |
+
train_data,
|
| 128 |
+
transforms=get_fusion_train_transforms()
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
val_dataset = FusionDataset(
|
| 132 |
+
val_data,
|
| 133 |
+
transforms=get_fusion_val_transforms()
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
train_loader = DataLoader(
|
| 137 |
+
train_dataset,
|
| 138 |
+
batch_size=BATCH_SIZE,
|
| 139 |
+
shuffle=True,
|
| 140 |
+
num_workers=NUM_WORKERS
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
val_loader = DataLoader(
|
| 144 |
+
val_dataset,
|
| 145 |
+
batch_size=BATCH_SIZE,
|
| 146 |
+
shuffle=False,
|
| 147 |
+
num_workers=NUM_WORKERS
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
logger.info("Fusion dataloaders created successfully.")
|
| 151 |
+
|
| 152 |
+
return train_loader, val_loader
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
if __name__ == "__main__":
|
| 156 |
+
logging.basicConfig(
|
| 157 |
+
level=logging.INFO,
|
| 158 |
+
format="%(asctime)s - %(levelname)s - %(message)s"
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
print("\nTesting ResNet dataloaders...\n")
|
| 162 |
+
|
| 163 |
+
train_loader, val_loader = create_resnet_dataloaders()
|
| 164 |
+
|
| 165 |
+
images, labels = next(iter(train_loader))
|
| 166 |
+
|
| 167 |
+
print("ResNet batch shape:", images.shape)
|
| 168 |
+
print("ResNet labels shape:", labels.shape)
|
| 169 |
+
|
| 170 |
+
print("\nTesting Fusion dataloaders...\n")
|
| 171 |
+
|
| 172 |
+
train_loader, val_loader = create_fusion_dataloaders()
|
| 173 |
+
|
| 174 |
+
batch = next(iter(train_loader))
|
| 175 |
+
|
| 176 |
+
print(
|
| 177 |
+
"Fusion EfficientNet batch shape:",
|
| 178 |
+
batch["pixel_values_eff"].shape
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
print(
|
| 182 |
+
"Fusion ConvNeXt batch shape:",
|
| 183 |
+
batch["pixel_values_cnx"].shape
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
print(
|
| 187 |
+
"Fusion labels shape:",
|
| 188 |
+
batch["labels"].shape
|
| 189 |
+
)
|
src/data/ingestion.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
from src.config import DATASET_DIR, CLASS_TO_IDX
|
| 5 |
+
|
| 6 |
+
logger = logging.getLogger(__name__)
|
| 7 |
+
|
| 8 |
+
VALID_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp"}
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def collect_image_paths():
|
| 12 |
+
logger.info("Starting dataset ingestion...")
|
| 13 |
+
|
| 14 |
+
if not DATASET_DIR.exists():
|
| 15 |
+
raise FileNotFoundError(f"Dataset directory not found: {DATASET_DIR}")
|
| 16 |
+
|
| 17 |
+
samples = []
|
| 18 |
+
|
| 19 |
+
for class_name, label in CLASS_TO_IDX.items():
|
| 20 |
+
class_dir = DATASET_DIR / class_name
|
| 21 |
+
|
| 22 |
+
if not class_dir.exists():
|
| 23 |
+
logger.warning(f"Missing class folder: {class_dir}")
|
| 24 |
+
continue
|
| 25 |
+
|
| 26 |
+
image_count = 0
|
| 27 |
+
|
| 28 |
+
for image_path in class_dir.iterdir():
|
| 29 |
+
if image_path.suffix.lower() in VALID_EXTENSIONS:
|
| 30 |
+
samples.append((str(image_path), label))
|
| 31 |
+
image_count += 1
|
| 32 |
+
|
| 33 |
+
logger.info(f"{class_name}: {image_count} images found")
|
| 34 |
+
|
| 35 |
+
if not samples:
|
| 36 |
+
raise ValueError("No valid images found in dataset.")
|
| 37 |
+
|
| 38 |
+
logger.info(f"Total images collected: {len(samples)}")
|
| 39 |
+
|
| 40 |
+
return samples
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
if __name__ == "__main__":
|
| 44 |
+
logging.basicConfig(
|
| 45 |
+
level=logging.INFO,
|
| 46 |
+
format="%(asctime)s - %(levelname)s - %(message)s"
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
data = collect_image_paths()
|
| 50 |
+
|
| 51 |
+
print(f"\nTotal samples: {len(data)}")
|
| 52 |
+
print("First 5 samples:")
|
| 53 |
+
|
| 54 |
+
for sample in data[:5]:
|
| 55 |
+
print(sample)
|
src/data/preprocessing.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from collections import Counter
|
| 3 |
+
from sklearn.model_selection import train_test_split
|
| 4 |
+
|
| 5 |
+
from src.config import VALIDATION_SPLIT, RANDOM_SEED
|
| 6 |
+
from src.data.ingestion import collect_image_paths
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def split_dataset(samples):
|
| 12 |
+
logger.info("Starting dataset preprocessing...")
|
| 13 |
+
|
| 14 |
+
if not samples:
|
| 15 |
+
raise ValueError("Empty dataset provided.")
|
| 16 |
+
|
| 17 |
+
image_paths = [sample[0] for sample in samples]
|
| 18 |
+
labels = [sample[1] for sample in samples]
|
| 19 |
+
|
| 20 |
+
logger.info(f"Total samples before split: {len(samples)}")
|
| 21 |
+
|
| 22 |
+
train_paths, val_paths, train_labels, val_labels = train_test_split(
|
| 23 |
+
image_paths,
|
| 24 |
+
labels,
|
| 25 |
+
test_size=VALIDATION_SPLIT,
|
| 26 |
+
stratify=labels,
|
| 27 |
+
random_state=RANDOM_SEED
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
train_data = list(zip(train_paths, train_labels))
|
| 31 |
+
val_data = list(zip(val_paths, val_labels))
|
| 32 |
+
|
| 33 |
+
logger.info(f"Training samples: {len(train_data)}")
|
| 34 |
+
logger.info(f"Validation samples: {len(val_data)}")
|
| 35 |
+
|
| 36 |
+
logger.info(f"Train distribution: {Counter(train_labels)}")
|
| 37 |
+
logger.info(f"Validation distribution: {Counter(val_labels)}")
|
| 38 |
+
|
| 39 |
+
return train_data, val_data
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
if __name__ == "__main__":
|
| 43 |
+
logging.basicConfig(
|
| 44 |
+
level=logging.INFO,
|
| 45 |
+
format="%(asctime)s - %(levelname)s - %(message)s"
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
samples = collect_image_paths()
|
| 49 |
+
|
| 50 |
+
train_data, val_data = split_dataset(samples)
|
| 51 |
+
|
| 52 |
+
print("\nTrain sample preview:")
|
| 53 |
+
for sample in train_data[:5]:
|
| 54 |
+
print(sample)
|
| 55 |
+
|
| 56 |
+
print("\nValidation sample preview:")
|
| 57 |
+
for sample in val_data[:5]:
|
| 58 |
+
print(sample)
|
src/export/conver_model.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import logging
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from src.config import DEVICE, NUM_CLASSES, CHECKPOINT_DIR
|
| 6 |
+
from src.models.fusion_model import FusionClassifier
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
INPUT_CHECKPOINT = CHECKPOINT_DIR / "best_fusion_model.pt"
|
| 11 |
+
OUTPUT_CHECKPOINT = CHECKPOINT_DIR / "best_fusion_model_fp16.pt"
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def convert_fusion_to_fp16():
|
| 15 |
+
logger.info("Initializing Fusion model for FP16 conversion...")
|
| 16 |
+
|
| 17 |
+
if not INPUT_CHECKPOINT.exists():
|
| 18 |
+
raise FileNotFoundError(
|
| 19 |
+
f"Fusion checkpoint not found: {INPUT_CHECKPOINT}"
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
model = FusionClassifier(
|
| 23 |
+
num_classes=NUM_CLASSES
|
| 24 |
+
).to(DEVICE)
|
| 25 |
+
|
| 26 |
+
logger.info(f"Loading checkpoint from: {INPUT_CHECKPOINT}")
|
| 27 |
+
|
| 28 |
+
checkpoint = torch.load(
|
| 29 |
+
INPUT_CHECKPOINT,
|
| 30 |
+
map_location=DEVICE
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
|
| 34 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
| 35 |
+
else:
|
| 36 |
+
model.load_state_dict(checkpoint)
|
| 37 |
+
|
| 38 |
+
logger.info("Model weights loaded successfully.")
|
| 39 |
+
|
| 40 |
+
model.eval()
|
| 41 |
+
|
| 42 |
+
logger.info("Converting model to FP16...")
|
| 43 |
+
|
| 44 |
+
model = model.half()
|
| 45 |
+
|
| 46 |
+
torch.save(
|
| 47 |
+
model.state_dict(),
|
| 48 |
+
OUTPUT_CHECKPOINT
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
size_mb = os.path.getsize(OUTPUT_CHECKPOINT) / (1024 * 1024)
|
| 52 |
+
|
| 53 |
+
logger.info(f"FP16 model saved at: {OUTPUT_CHECKPOINT}")
|
| 54 |
+
logger.info(f"FP16 model size: {size_mb:.2f} MB")
|
| 55 |
+
|
| 56 |
+
return OUTPUT_CHECKPOINT
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
if __name__ == "__main__":
|
| 60 |
+
logging.basicConfig(
|
| 61 |
+
level=logging.INFO,
|
| 62 |
+
format="%(asctime)s - %(levelname)s - %(message)s"
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
fp16_path = convert_fusion_to_fp16()
|
| 66 |
+
|
| 67 |
+
print("\nFusion FP16 conversion completed successfully.")
|
| 68 |
+
print(f"Saved model: {fp16_path}")
|
src/export/upload_to_huggingface.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import logging
|
| 3 |
+
from dotenv import load_dotenv
|
| 4 |
+
from huggingface_hub import HfApi
|
| 5 |
+
|
| 6 |
+
from src.config import CHECKPOINT_DIR
|
| 7 |
+
|
| 8 |
+
load_dotenv()
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
HF_USERNAME = "junaid17"
|
| 13 |
+
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 14 |
+
|
| 15 |
+
if not HF_TOKEN:
|
| 16 |
+
raise ValueError(
|
| 17 |
+
"HF_TOKEN not found in .env file."
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
MODELS = {
|
| 21 |
+
"new-damagelens-resnet-classifier": {
|
| 22 |
+
"path": CHECKPOINT_DIR / "best_resnet_model.pt",
|
| 23 |
+
"filename": "new_best_resnet_model.pt"
|
| 24 |
+
},
|
| 25 |
+
|
| 26 |
+
"new-damagelens-fusion-fp16": {
|
| 27 |
+
"path": CHECKPOINT_DIR / "best_fusion_model_fp16.pt",
|
| 28 |
+
"filename": "new_best_fusion_model_fp16.pt"
|
| 29 |
+
},
|
| 30 |
+
|
| 31 |
+
"new-damagelens-yolo-detector": {
|
| 32 |
+
"path": CHECKPOINT_DIR / "damage_detector.pt",
|
| 33 |
+
"filename": "new_damage_detector.pt"
|
| 34 |
+
}
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def upload_model(api, repo_name, file_path, filename):
|
| 39 |
+
if not file_path.exists():
|
| 40 |
+
raise FileNotFoundError(
|
| 41 |
+
f"Model file not found: {file_path}"
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
repo_id = f"{HF_USERNAME}/{repo_name}"
|
| 45 |
+
|
| 46 |
+
logger.info(f"Creating repo: {repo_id}")
|
| 47 |
+
|
| 48 |
+
api.create_repo(
|
| 49 |
+
repo_id=repo_id,
|
| 50 |
+
repo_type="model",
|
| 51 |
+
exist_ok=True
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
logger.info(f"Uploading {filename} to {repo_id}")
|
| 55 |
+
|
| 56 |
+
api.upload_file(
|
| 57 |
+
path_or_fileobj=str(file_path),
|
| 58 |
+
path_in_repo=filename,
|
| 59 |
+
repo_id=repo_id,
|
| 60 |
+
repo_type="model"
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
logger.info(f"Upload completed: {repo_id}")
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def upload_all_models():
|
| 67 |
+
logger.info("Starting Hugging Face model uploads...")
|
| 68 |
+
|
| 69 |
+
api = HfApi(token=HF_TOKEN)
|
| 70 |
+
|
| 71 |
+
for repo_name, model_info in MODELS.items():
|
| 72 |
+
upload_model(
|
| 73 |
+
api=api,
|
| 74 |
+
repo_name=repo_name,
|
| 75 |
+
file_path=model_info["path"],
|
| 76 |
+
filename=model_info["filename"]
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
logger.info("All model uploads completed successfully.")
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
if __name__ == "__main__":
|
| 83 |
+
logging.basicConfig(
|
| 84 |
+
level=logging.INFO,
|
| 85 |
+
format="%(asctime)s - %(levelname)s - %(message)s"
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
upload_all_models()
|
| 89 |
+
|
| 90 |
+
print("\nAll models uploaded successfully.")
|
src/models/fusion_model.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from torchvision import models
|
| 5 |
+
from transformers import ConvNextModel
|
| 6 |
+
|
| 7 |
+
logger = logging.getLogger(__name__)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class FusionClassifier(nn.Module):
|
| 11 |
+
def __init__(
|
| 12 |
+
self,
|
| 13 |
+
num_classes,
|
| 14 |
+
convnext_model_name="facebook/convnext-small-224"
|
| 15 |
+
):
|
| 16 |
+
super().__init__()
|
| 17 |
+
|
| 18 |
+
logger.info("Initializing Fusion model...")
|
| 19 |
+
|
| 20 |
+
# EfficientNet-V2-S
|
| 21 |
+
eff = models.efficientnet_v2_s(
|
| 22 |
+
weights=models.EfficientNet_V2_S_Weights.IMAGENET1K_V1
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
for param in eff.parameters():
|
| 26 |
+
param.requires_grad = False
|
| 27 |
+
|
| 28 |
+
for param in eff.features[5].parameters():
|
| 29 |
+
param.requires_grad = True
|
| 30 |
+
|
| 31 |
+
for param in eff.features[6].parameters():
|
| 32 |
+
param.requires_grad = True
|
| 33 |
+
|
| 34 |
+
for param in eff.features[7].parameters():
|
| 35 |
+
param.requires_grad = True
|
| 36 |
+
|
| 37 |
+
self.eff_features = eff.features
|
| 38 |
+
self.eff_avgpool = eff.avgpool
|
| 39 |
+
self.eff_out_dim = eff.classifier[1].in_features
|
| 40 |
+
|
| 41 |
+
# ConvNeXt
|
| 42 |
+
cnx = ConvNextModel.from_pretrained(convnext_model_name)
|
| 43 |
+
|
| 44 |
+
for param in cnx.parameters():
|
| 45 |
+
param.requires_grad = False
|
| 46 |
+
|
| 47 |
+
for param in cnx.encoder.stages[2].parameters():
|
| 48 |
+
param.requires_grad = True
|
| 49 |
+
|
| 50 |
+
for param in cnx.encoder.stages[3].parameters():
|
| 51 |
+
param.requires_grad = True
|
| 52 |
+
|
| 53 |
+
for param in cnx.layernorm.parameters():
|
| 54 |
+
param.requires_grad = True
|
| 55 |
+
|
| 56 |
+
self.cnx_backbone = cnx
|
| 57 |
+
self.cnx_out_dim = 768
|
| 58 |
+
|
| 59 |
+
fused_dim = self.eff_out_dim + self.cnx_out_dim
|
| 60 |
+
|
| 61 |
+
self.fusion_head = nn.Sequential(
|
| 62 |
+
nn.Dropout(0.4),
|
| 63 |
+
nn.Linear(fused_dim, 512),
|
| 64 |
+
nn.LayerNorm(512),
|
| 65 |
+
nn.GELU(),
|
| 66 |
+
|
| 67 |
+
nn.Dropout(0.3),
|
| 68 |
+
nn.Linear(512, 256),
|
| 69 |
+
nn.LayerNorm(256),
|
| 70 |
+
nn.GELU(),
|
| 71 |
+
|
| 72 |
+
nn.Dropout(0.2),
|
| 73 |
+
nn.Linear(256, num_classes)
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
logger.info("Fusion model initialized successfully.")
|
| 77 |
+
|
| 78 |
+
def forward(self, pixel_values_eff, pixel_values_cnx):
|
| 79 |
+
x_eff = self.eff_features(pixel_values_eff)
|
| 80 |
+
x_eff = self.eff_avgpool(x_eff)
|
| 81 |
+
x_eff = torch.flatten(x_eff, 1)
|
| 82 |
+
|
| 83 |
+
cnx_out = self.cnx_backbone(
|
| 84 |
+
pixel_values=pixel_values_cnx,
|
| 85 |
+
return_dict=True
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
x_cnx = cnx_out.pooler_output
|
| 89 |
+
|
| 90 |
+
fused = torch.cat([x_eff, x_cnx], dim=1)
|
| 91 |
+
|
| 92 |
+
logits = self.fusion_head(fused)
|
| 93 |
+
|
| 94 |
+
return logits
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
if __name__ == "__main__":
|
| 98 |
+
import logging
|
| 99 |
+
|
| 100 |
+
logging.basicConfig(
|
| 101 |
+
level=logging.INFO,
|
| 102 |
+
format="%(asctime)s - %(levelname)s - %(message)s"
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
model = FusionClassifier(num_classes=6)
|
| 106 |
+
|
| 107 |
+
eff_dummy = torch.randn(2, 3, 260, 260)
|
| 108 |
+
cnx_dummy = torch.randn(2, 3, 224, 224)
|
| 109 |
+
|
| 110 |
+
output = model(eff_dummy, cnx_dummy)
|
| 111 |
+
|
| 112 |
+
print("Fusion output shape:", output.shape)
|
src/models/resnet_model.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from torchvision import models
|
| 5 |
+
|
| 6 |
+
logger = logging.getLogger(__name__)
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class CarClassifierResNet(nn.Module):
|
| 10 |
+
def __init__(self, num_classes):
|
| 11 |
+
super().__init__()
|
| 12 |
+
|
| 13 |
+
logger.info("Initializing ResNet18 model...")
|
| 14 |
+
|
| 15 |
+
self.model = models.resnet18(weights="DEFAULT")
|
| 16 |
+
|
| 17 |
+
# Freeze everything
|
| 18 |
+
for param in self.model.parameters():
|
| 19 |
+
param.requires_grad = False
|
| 20 |
+
|
| 21 |
+
# Unfreeze last layers
|
| 22 |
+
for param in self.model.layer3.parameters():
|
| 23 |
+
param.requires_grad = True
|
| 24 |
+
|
| 25 |
+
for param in self.model.layer4.parameters():
|
| 26 |
+
param.requires_grad = True
|
| 27 |
+
|
| 28 |
+
# Custom classifier head
|
| 29 |
+
self.model.fc = nn.Sequential(
|
| 30 |
+
nn.Dropout(0.5),
|
| 31 |
+
nn.Linear(self.model.fc.in_features, 256),
|
| 32 |
+
nn.ReLU(),
|
| 33 |
+
nn.Dropout(0.3),
|
| 34 |
+
nn.Linear(256, num_classes)
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
logger.info("ResNet18 model initialized successfully.")
|
| 38 |
+
|
| 39 |
+
def forward(self, x):
|
| 40 |
+
return self.model(x)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
if __name__ == "__main__":
|
| 44 |
+
logging.basicConfig(
|
| 45 |
+
level=logging.INFO,
|
| 46 |
+
format="%(asctime)s - %(levelname)s - %(message)s"
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
model = CarClassifierResNet(num_classes=6)
|
| 50 |
+
|
| 51 |
+
dummy_input = torch.randn(2, 3, 128, 128)
|
| 52 |
+
|
| 53 |
+
output = model(dummy_input)
|
| 54 |
+
|
| 55 |
+
print("Output shape:", output.shape)
|
| 56 |
+
|
| 57 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 58 |
+
trainable_params = sum(
|
| 59 |
+
p.numel() for p in model.parameters()
|
| 60 |
+
if p.requires_grad
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
print("Total params:", total_params)
|
| 64 |
+
print("Trainable params:", trainable_params)
|
src/training/train_fusion.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch.optim import AdamW
|
| 4 |
+
|
| 5 |
+
from src.config import DEVICE, EPOCHS, NUM_CLASSES
|
| 6 |
+
from src.models.fusion_model import FusionClassifier
|
| 7 |
+
from src.data.dataset import create_fusion_dataloaders
|
| 8 |
+
from src.training.trainer import train_dual_input_model
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def run_fusion_training():
|
| 14 |
+
logger.info("Initializing Fusion training pipeline...")
|
| 15 |
+
|
| 16 |
+
train_loader, eval_loader = create_fusion_dataloaders()
|
| 17 |
+
|
| 18 |
+
model = FusionClassifier(
|
| 19 |
+
num_classes=NUM_CLASSES
|
| 20 |
+
).to(DEVICE)
|
| 21 |
+
|
| 22 |
+
criterion = nn.CrossEntropyLoss(
|
| 23 |
+
label_smoothing=0.1
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
optimizer = AdamW([
|
| 27 |
+
# EfficientNet unfrozen blocks
|
| 28 |
+
{
|
| 29 |
+
"params": model.eff_features[5].parameters(),
|
| 30 |
+
"lr": 1e-5
|
| 31 |
+
},
|
| 32 |
+
{
|
| 33 |
+
"params": model.eff_features[6].parameters(),
|
| 34 |
+
"lr": 3e-5
|
| 35 |
+
},
|
| 36 |
+
{
|
| 37 |
+
"params": model.eff_features[7].parameters(),
|
| 38 |
+
"lr": 3e-5
|
| 39 |
+
},
|
| 40 |
+
|
| 41 |
+
# ConvNeXt unfrozen blocks
|
| 42 |
+
{
|
| 43 |
+
"params": model.cnx_backbone.encoder.stages[2].parameters(),
|
| 44 |
+
"lr": 3e-5
|
| 45 |
+
},
|
| 46 |
+
{
|
| 47 |
+
"params": model.cnx_backbone.encoder.stages[3].parameters(),
|
| 48 |
+
"lr": 3e-5
|
| 49 |
+
},
|
| 50 |
+
{
|
| 51 |
+
"params": model.cnx_backbone.layernorm.parameters(),
|
| 52 |
+
"lr": 3e-5
|
| 53 |
+
},
|
| 54 |
+
|
| 55 |
+
# Fusion head
|
| 56 |
+
{
|
| 57 |
+
"params": model.fusion_head.parameters(),
|
| 58 |
+
"lr": 1e-4
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
], weight_decay=1e-4)
|
| 62 |
+
|
| 63 |
+
logger.info("Starting Fusion training...")
|
| 64 |
+
|
| 65 |
+
all_preds, all_labels = train_dual_input_model(
|
| 66 |
+
model=model,
|
| 67 |
+
train_loader=train_loader,
|
| 68 |
+
eval_loader=eval_loader,
|
| 69 |
+
optimizer=optimizer,
|
| 70 |
+
criterion=criterion,
|
| 71 |
+
device=DEVICE,
|
| 72 |
+
epochs=EPOCHS,
|
| 73 |
+
checkpoint_model_name="best_fusion_model",
|
| 74 |
+
patience=7
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
logger.info("Fusion training completed.")
|
| 78 |
+
|
| 79 |
+
return all_preds, all_labels
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
if __name__ == "__main__":
|
| 83 |
+
logging.basicConfig(
|
| 84 |
+
level=logging.INFO,
|
| 85 |
+
format="%(asctime)s - %(levelname)s - %(message)s"
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
preds, labels = run_fusion_training()
|
| 89 |
+
|
| 90 |
+
print("\nFusion training completed successfully.")
|
| 91 |
+
print("Prediction samples:", preds[:10])
|
| 92 |
+
print("Label samples:", labels[:10])
|
src/training/train_resnet.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch.optim import AdamW
|
| 4 |
+
|
| 5 |
+
from src.config import DEVICE, EPOCHS, NUM_CLASSES
|
| 6 |
+
from src.models.resnet_model import CarClassifierResNet
|
| 7 |
+
from src.data.dataset import create_resnet_dataloaders
|
| 8 |
+
from src.training.trainer import train_single_input_model
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def run_resnet_training():
|
| 14 |
+
logger.info("Initializing ResNet training pipeline...")
|
| 15 |
+
|
| 16 |
+
train_loader, eval_loader = create_resnet_dataloaders()
|
| 17 |
+
|
| 18 |
+
model = CarClassifierResNet(
|
| 19 |
+
num_classes=NUM_CLASSES
|
| 20 |
+
).to(DEVICE)
|
| 21 |
+
|
| 22 |
+
criterion = nn.CrossEntropyLoss()
|
| 23 |
+
|
| 24 |
+
optimizer = AdamW([
|
| 25 |
+
{
|
| 26 |
+
"params": model.model.layer3.parameters(),
|
| 27 |
+
"lr": 1e-5
|
| 28 |
+
},
|
| 29 |
+
{
|
| 30 |
+
"params": model.model.layer4.parameters(),
|
| 31 |
+
"lr": 1e-5
|
| 32 |
+
},
|
| 33 |
+
{
|
| 34 |
+
"params": model.model.fc.parameters(),
|
| 35 |
+
"lr": 1e-4
|
| 36 |
+
}
|
| 37 |
+
])
|
| 38 |
+
|
| 39 |
+
logger.info("Starting ResNet training...")
|
| 40 |
+
|
| 41 |
+
all_preds, all_labels = train_single_input_model(
|
| 42 |
+
model=model,
|
| 43 |
+
train_loader=train_loader,
|
| 44 |
+
eval_loader=eval_loader,
|
| 45 |
+
optimizer=optimizer,
|
| 46 |
+
criterion=criterion,
|
| 47 |
+
device=DEVICE,
|
| 48 |
+
epochs=EPOCHS,
|
| 49 |
+
checkpoint_model_name="best_resnet_model",
|
| 50 |
+
patience=7
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
logger.info("ResNet training completed.")
|
| 54 |
+
|
| 55 |
+
return all_preds, all_labels
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
if __name__ == "__main__":
|
| 59 |
+
logging.basicConfig(
|
| 60 |
+
level=logging.INFO,
|
| 61 |
+
format="%(asctime)s - %(levelname)s - %(message)s"
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
preds, labels = run_resnet_training()
|
| 65 |
+
|
| 66 |
+
print("\nTraining completed successfully.")
|
| 67 |
+
print("Prediction samples:", preds[:10])
|
| 68 |
+
print("Label samples:", labels[:10])
|
src/training/train_yolo.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from shutil import copy2, rmtree
|
| 3 |
+
from ultralytics import YOLO
|
| 4 |
+
|
| 5 |
+
from src.config import BASE_DIR, CHECKPOINT_DIR, DEVICE
|
| 6 |
+
|
| 7 |
+
logger = logging.getLogger(__name__)
|
| 8 |
+
|
| 9 |
+
YOLO_DATASET_CONFIG = BASE_DIR / "data" / "yolo" / "dataset_custom.yaml"
|
| 10 |
+
YOLO_BASE_MODEL = CHECKPOINT_DIR / "yolo11m.pt"
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def run_yolo_training():
|
| 14 |
+
logger.info("Initializing YOLO training pipeline...")
|
| 15 |
+
|
| 16 |
+
if not YOLO_DATASET_CONFIG.exists():
|
| 17 |
+
raise FileNotFoundError(
|
| 18 |
+
f"YOLO dataset config not found: {YOLO_DATASET_CONFIG}"
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
if not YOLO_BASE_MODEL.exists():
|
| 22 |
+
raise FileNotFoundError(
|
| 23 |
+
f"YOLO base model not found: {YOLO_BASE_MODEL}"
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
yolo_device = 0 if DEVICE == "cuda" else "cpu"
|
| 27 |
+
|
| 28 |
+
checkpoint_root = CHECKPOINT_DIR.resolve()
|
| 29 |
+
temp_run_name = "temp_yolo_run"
|
| 30 |
+
|
| 31 |
+
logger.info("Loading YOLO base model...")
|
| 32 |
+
|
| 33 |
+
model = YOLO(str(YOLO_BASE_MODEL.resolve()))
|
| 34 |
+
|
| 35 |
+
logger.info("Starting YOLO training...")
|
| 36 |
+
|
| 37 |
+
model.train(
|
| 38 |
+
data=str(YOLO_DATASET_CONFIG.resolve()),
|
| 39 |
+
imgsz=416,
|
| 40 |
+
batch=4,
|
| 41 |
+
epochs=1,
|
| 42 |
+
device=yolo_device,
|
| 43 |
+
project=str(checkpoint_root),
|
| 44 |
+
name=temp_run_name,
|
| 45 |
+
exist_ok=True
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
best_model_path = (
|
| 49 |
+
checkpoint_root /
|
| 50 |
+
temp_run_name /
|
| 51 |
+
"weights" /
|
| 52 |
+
"best.pt"
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
if not best_model_path.exists():
|
| 56 |
+
raise FileNotFoundError(
|
| 57 |
+
f"YOLO best model not found: {best_model_path}"
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
final_model_path = checkpoint_root / "damage_detector.pt"
|
| 61 |
+
|
| 62 |
+
copy2(best_model_path, final_model_path)
|
| 63 |
+
|
| 64 |
+
logger.info(f"Final YOLO model saved at: {final_model_path}")
|
| 65 |
+
|
| 66 |
+
# cleanup temp training folder
|
| 67 |
+
temp_run_dir = checkpoint_root / temp_run_name
|
| 68 |
+
|
| 69 |
+
if temp_run_dir.exists():
|
| 70 |
+
rmtree(temp_run_dir)
|
| 71 |
+
logger.info("Temporary YOLO training artifacts deleted.")
|
| 72 |
+
|
| 73 |
+
return final_model_path
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
if __name__ == "__main__":
|
| 77 |
+
logging.basicConfig(
|
| 78 |
+
level=logging.INFO,
|
| 79 |
+
format="%(asctime)s - %(levelname)s - %(message)s"
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
model_path = run_yolo_training()
|
| 83 |
+
|
| 84 |
+
print("\nYOLO training completed successfully.")
|
| 85 |
+
print(f"Saved model: {model_path}")
|
src/training/trainer.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import torch
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
from transformers import get_cosine_schedule_with_warmup
|
| 5 |
+
|
| 6 |
+
from src.config import CHECKPOINT_DIR
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class EarlyStopping:
|
| 12 |
+
def __init__(self, patience=7, min_delta=0.001):
|
| 13 |
+
self.patience = patience
|
| 14 |
+
self.min_delta = min_delta
|
| 15 |
+
self.counter = 0
|
| 16 |
+
self.best_score = None
|
| 17 |
+
self.early_stop = False
|
| 18 |
+
|
| 19 |
+
def __call__(self, val_acc):
|
| 20 |
+
if self.best_score is None:
|
| 21 |
+
self.best_score = val_acc
|
| 22 |
+
|
| 23 |
+
elif val_acc < self.best_score + self.min_delta:
|
| 24 |
+
self.counter += 1
|
| 25 |
+
|
| 26 |
+
logger.info(
|
| 27 |
+
f"EarlyStopping counter: {self.counter}/{self.patience}"
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
if self.counter >= self.patience:
|
| 31 |
+
self.early_stop = True
|
| 32 |
+
|
| 33 |
+
else:
|
| 34 |
+
self.best_score = val_acc
|
| 35 |
+
self.counter = 0
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def train_single_input_model(
|
| 39 |
+
model,
|
| 40 |
+
train_loader,
|
| 41 |
+
eval_loader,
|
| 42 |
+
optimizer,
|
| 43 |
+
criterion,
|
| 44 |
+
device,
|
| 45 |
+
epochs,
|
| 46 |
+
checkpoint_model_name,
|
| 47 |
+
patience=7
|
| 48 |
+
):
|
| 49 |
+
logger.info("Starting single-input training...")
|
| 50 |
+
|
| 51 |
+
num_training_steps = epochs * len(train_loader)
|
| 52 |
+
num_warmup_steps = int(0.1 * num_training_steps)
|
| 53 |
+
|
| 54 |
+
scheduler = get_cosine_schedule_with_warmup(
|
| 55 |
+
optimizer=optimizer,
|
| 56 |
+
num_warmup_steps=num_warmup_steps,
|
| 57 |
+
num_training_steps=num_training_steps
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
early_stopping = EarlyStopping(patience=patience)
|
| 61 |
+
|
| 62 |
+
best_acc = 0.0
|
| 63 |
+
all_preds = []
|
| 64 |
+
all_labels = []
|
| 65 |
+
|
| 66 |
+
for epoch in range(epochs):
|
| 67 |
+
logger.info(f"Epoch {epoch + 1}/{epochs}")
|
| 68 |
+
|
| 69 |
+
model.train()
|
| 70 |
+
|
| 71 |
+
running_loss = 0
|
| 72 |
+
correct = 0
|
| 73 |
+
total = 0
|
| 74 |
+
|
| 75 |
+
for images, labels in tqdm(
|
| 76 |
+
train_loader,
|
| 77 |
+
desc=f"Epoch {epoch+1} Training"
|
| 78 |
+
):
|
| 79 |
+
images = images.to(device)
|
| 80 |
+
labels = labels.to(device)
|
| 81 |
+
|
| 82 |
+
optimizer.zero_grad(set_to_none=True)
|
| 83 |
+
|
| 84 |
+
logits = model(images)
|
| 85 |
+
|
| 86 |
+
loss = criterion(logits, labels)
|
| 87 |
+
|
| 88 |
+
loss.backward()
|
| 89 |
+
|
| 90 |
+
optimizer.step()
|
| 91 |
+
scheduler.step()
|
| 92 |
+
|
| 93 |
+
running_loss += loss.item()
|
| 94 |
+
|
| 95 |
+
preds = torch.argmax(logits, dim=1)
|
| 96 |
+
|
| 97 |
+
correct += (preds == labels).sum().item()
|
| 98 |
+
total += labels.size(0)
|
| 99 |
+
|
| 100 |
+
train_loss = running_loss / len(train_loader)
|
| 101 |
+
train_acc = 100 * correct / total
|
| 102 |
+
|
| 103 |
+
model.eval()
|
| 104 |
+
|
| 105 |
+
val_running_loss = 0
|
| 106 |
+
val_correct = 0
|
| 107 |
+
val_total = 0
|
| 108 |
+
|
| 109 |
+
all_preds = []
|
| 110 |
+
all_labels = []
|
| 111 |
+
|
| 112 |
+
with torch.no_grad():
|
| 113 |
+
for images, labels in tqdm(
|
| 114 |
+
eval_loader,
|
| 115 |
+
desc=f"Epoch {epoch+1} Validation"
|
| 116 |
+
):
|
| 117 |
+
images = images.to(device)
|
| 118 |
+
labels = labels.to(device)
|
| 119 |
+
|
| 120 |
+
logits = model(images)
|
| 121 |
+
|
| 122 |
+
loss = criterion(logits, labels)
|
| 123 |
+
|
| 124 |
+
val_running_loss += loss.item()
|
| 125 |
+
|
| 126 |
+
preds = torch.argmax(logits, dim=1)
|
| 127 |
+
|
| 128 |
+
val_correct += (preds == labels).sum().item()
|
| 129 |
+
val_total += labels.size(0)
|
| 130 |
+
|
| 131 |
+
all_preds.extend(preds.cpu().numpy())
|
| 132 |
+
all_labels.extend(labels.cpu().numpy())
|
| 133 |
+
|
| 134 |
+
val_loss = val_running_loss / len(eval_loader)
|
| 135 |
+
val_acc = 100 * val_correct / val_total
|
| 136 |
+
|
| 137 |
+
logger.info(
|
| 138 |
+
f"Train Loss: {train_loss:.4f} | "
|
| 139 |
+
f"Train Acc: {train_acc:.2f}% || "
|
| 140 |
+
f"Val Loss: {val_loss:.4f} | "
|
| 141 |
+
f"Val Acc: {val_acc:.2f}%"
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
if val_acc > best_acc:
|
| 145 |
+
best_acc = val_acc
|
| 146 |
+
|
| 147 |
+
checkpoint_path = CHECKPOINT_DIR / f"{checkpoint_model_name}.pt"
|
| 148 |
+
|
| 149 |
+
torch.save(
|
| 150 |
+
{
|
| 151 |
+
"model_state_dict": model.state_dict(),
|
| 152 |
+
"optimizer_state_dict": optimizer.state_dict(),
|
| 153 |
+
"epoch": epoch,
|
| 154 |
+
"val_acc": val_acc
|
| 155 |
+
},
|
| 156 |
+
checkpoint_path
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
logger.info(f"Best checkpoint saved at: {checkpoint_path}")
|
| 160 |
+
|
| 161 |
+
early_stopping(val_acc)
|
| 162 |
+
|
| 163 |
+
if early_stopping.early_stop:
|
| 164 |
+
logger.info("Early stopping triggered.")
|
| 165 |
+
break
|
| 166 |
+
|
| 167 |
+
return all_preds, all_labels
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def train_dual_input_model(
|
| 171 |
+
model,
|
| 172 |
+
train_loader,
|
| 173 |
+
eval_loader,
|
| 174 |
+
optimizer,
|
| 175 |
+
criterion,
|
| 176 |
+
device,
|
| 177 |
+
epochs,
|
| 178 |
+
checkpoint_model_name,
|
| 179 |
+
patience=7
|
| 180 |
+
):
|
| 181 |
+
logger.info("Starting dual-input training...")
|
| 182 |
+
|
| 183 |
+
num_training_steps = epochs * len(train_loader)
|
| 184 |
+
num_warmup_steps = int(0.1 * num_training_steps)
|
| 185 |
+
|
| 186 |
+
scheduler = get_cosine_schedule_with_warmup(
|
| 187 |
+
optimizer=optimizer,
|
| 188 |
+
num_warmup_steps=num_warmup_steps,
|
| 189 |
+
num_training_steps=num_training_steps
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
early_stopping = EarlyStopping(patience=patience)
|
| 193 |
+
|
| 194 |
+
best_acc = 0.0
|
| 195 |
+
all_preds = []
|
| 196 |
+
all_labels = []
|
| 197 |
+
|
| 198 |
+
for epoch in range(epochs):
|
| 199 |
+
logger.info(f"Epoch {epoch + 1}/{epochs}")
|
| 200 |
+
|
| 201 |
+
model.train()
|
| 202 |
+
|
| 203 |
+
running_loss = 0
|
| 204 |
+
correct = 0
|
| 205 |
+
total = 0
|
| 206 |
+
|
| 207 |
+
for batch in tqdm(
|
| 208 |
+
train_loader,
|
| 209 |
+
desc=f"Epoch {epoch+1} Training"
|
| 210 |
+
):
|
| 211 |
+
images_eff = batch["pixel_values_eff"].to(device)
|
| 212 |
+
images_cnx = batch["pixel_values_cnx"].to(device)
|
| 213 |
+
labels = batch["labels"].to(device)
|
| 214 |
+
|
| 215 |
+
optimizer.zero_grad(set_to_none=True)
|
| 216 |
+
|
| 217 |
+
logits = model(images_eff, images_cnx)
|
| 218 |
+
|
| 219 |
+
loss = criterion(logits, labels)
|
| 220 |
+
|
| 221 |
+
loss.backward()
|
| 222 |
+
|
| 223 |
+
optimizer.step()
|
| 224 |
+
scheduler.step()
|
| 225 |
+
|
| 226 |
+
running_loss += loss.item()
|
| 227 |
+
|
| 228 |
+
preds = torch.argmax(logits, dim=1)
|
| 229 |
+
|
| 230 |
+
correct += (preds == labels).sum().item()
|
| 231 |
+
total += labels.size(0)
|
| 232 |
+
|
| 233 |
+
train_loss = running_loss / len(train_loader)
|
| 234 |
+
train_acc = 100 * correct / total
|
| 235 |
+
|
| 236 |
+
model.eval()
|
| 237 |
+
|
| 238 |
+
val_running_loss = 0
|
| 239 |
+
val_correct = 0
|
| 240 |
+
val_total = 0
|
| 241 |
+
|
| 242 |
+
all_preds = []
|
| 243 |
+
all_labels = []
|
| 244 |
+
|
| 245 |
+
with torch.no_grad():
|
| 246 |
+
for batch in tqdm(
|
| 247 |
+
eval_loader,
|
| 248 |
+
desc=f"Epoch {epoch+1} Validation"
|
| 249 |
+
):
|
| 250 |
+
images_eff = batch["pixel_values_eff"].to(device)
|
| 251 |
+
images_cnx = batch["pixel_values_cnx"].to(device)
|
| 252 |
+
labels = batch["labels"].to(device)
|
| 253 |
+
|
| 254 |
+
logits = model(images_eff, images_cnx)
|
| 255 |
+
|
| 256 |
+
loss = criterion(logits, labels)
|
| 257 |
+
|
| 258 |
+
val_running_loss += loss.item()
|
| 259 |
+
|
| 260 |
+
preds = torch.argmax(logits, dim=1)
|
| 261 |
+
|
| 262 |
+
val_correct += (preds == labels).sum().item()
|
| 263 |
+
val_total += labels.size(0)
|
| 264 |
+
|
| 265 |
+
all_preds.extend(preds.cpu().numpy())
|
| 266 |
+
all_labels.extend(labels.cpu().numpy())
|
| 267 |
+
|
| 268 |
+
val_loss = val_running_loss / len(eval_loader)
|
| 269 |
+
val_acc = 100 * val_correct / val_total
|
| 270 |
+
|
| 271 |
+
logger.info(
|
| 272 |
+
f"Train Loss: {train_loss:.4f} | "
|
| 273 |
+
f"Train Acc: {train_acc:.2f}% || "
|
| 274 |
+
f"Val Loss: {val_loss:.4f} | "
|
| 275 |
+
f"Val Acc: {val_acc:.2f}%"
|
| 276 |
+
)
|
| 277 |
+
|
| 278 |
+
if val_acc > best_acc:
|
| 279 |
+
best_acc = val_acc
|
| 280 |
+
|
| 281 |
+
checkpoint_path = CHECKPOINT_DIR / f"{checkpoint_model_name}.pt"
|
| 282 |
+
|
| 283 |
+
torch.save(
|
| 284 |
+
{
|
| 285 |
+
"model_state_dict": model.state_dict(),
|
| 286 |
+
"optimizer_state_dict": optimizer.state_dict(),
|
| 287 |
+
"epoch": epoch,
|
| 288 |
+
"val_acc": val_acc
|
| 289 |
+
},
|
| 290 |
+
checkpoint_path
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
logger.info(f"Best checkpoint saved at: {checkpoint_path}")
|
| 294 |
+
|
| 295 |
+
early_stopping(val_acc)
|
| 296 |
+
|
| 297 |
+
if early_stopping.early_stop:
|
| 298 |
+
logger.info("Early stopping triggered.")
|
| 299 |
+
break
|
| 300 |
+
|
| 301 |
+
return all_preds, all_labels
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
if __name__ == "__main__":
|
| 305 |
+
print("Trainer utilities ready.")
|
test/test_augmentation.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from PIL import Image
|
| 3 |
+
|
| 4 |
+
from src.data.augmentation import (
|
| 5 |
+
get_resnet_train_transforms,
|
| 6 |
+
get_fusion_train_transforms
|
| 7 |
+
)
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def test_augmentation():
|
| 13 |
+
logger.info("Testing augmentation pipelines...")
|
| 14 |
+
|
| 15 |
+
dummy_image = Image.new("RGB", (300, 300))
|
| 16 |
+
|
| 17 |
+
resnet_transform = get_resnet_train_transforms()
|
| 18 |
+
fusion_transform = get_fusion_train_transforms()
|
| 19 |
+
|
| 20 |
+
resnet_tensor = resnet_transform(dummy_image)
|
| 21 |
+
fusion_tensor = fusion_transform(dummy_image)
|
| 22 |
+
|
| 23 |
+
assert resnet_tensor.shape == (3, 128, 128), \
|
| 24 |
+
f"Unexpected ResNet shape: {resnet_tensor.shape}"
|
| 25 |
+
|
| 26 |
+
assert fusion_tensor.shape == (3, 260, 260), \
|
| 27 |
+
f"Unexpected Fusion shape: {fusion_tensor.shape}"
|
| 28 |
+
|
| 29 |
+
logger.info("Augmentation test passed.")
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
if __name__ == "__main__":
|
| 33 |
+
logging.basicConfig(
|
| 34 |
+
level=logging.INFO,
|
| 35 |
+
format="%(asctime)s - %(levelname)s - %(message)s"
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
test_augmentation()
|
| 39 |
+
|
| 40 |
+
print("Augmentation test completed successfully.")
|
test/test_config.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
|
| 3 |
+
from src.config import (
|
| 4 |
+
BASE_DIR,
|
| 5 |
+
CHECKPOINT_DIR,
|
| 6 |
+
DEVICE,
|
| 7 |
+
BATCH_SIZE,
|
| 8 |
+
EPOCHS,
|
| 9 |
+
NUM_CLASSES
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def test_config():
|
| 16 |
+
logger.info("Testing config settings...")
|
| 17 |
+
|
| 18 |
+
assert BASE_DIR.exists(), "BASE_DIR missing"
|
| 19 |
+
assert CHECKPOINT_DIR.exists(), "CHECKPOINT_DIR missing"
|
| 20 |
+
|
| 21 |
+
assert DEVICE in ["cpu", "cuda"], "Invalid device"
|
| 22 |
+
assert BATCH_SIZE > 0, "Invalid batch size"
|
| 23 |
+
assert EPOCHS > 0, "Invalid epochs"
|
| 24 |
+
assert NUM_CLASSES == 6, "NUM_CLASSES mismatch"
|
| 25 |
+
|
| 26 |
+
logger.info("Config test passed.")
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
if __name__ == "__main__":
|
| 30 |
+
logging.basicConfig(
|
| 31 |
+
level=logging.INFO,
|
| 32 |
+
format="%(asctime)s - %(levelname)s - %(message)s"
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
test_config()
|
| 36 |
+
|
| 37 |
+
print("Config test completed successfully.")
|
test/test_dataset.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
|
| 3 |
+
from src.data.dataset import (
|
| 4 |
+
create_resnet_dataloaders,
|
| 5 |
+
create_fusion_dataloaders
|
| 6 |
+
)
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def test_dataset():
|
| 12 |
+
logger.info("Testing dataset loaders...")
|
| 13 |
+
|
| 14 |
+
# ---------------- ResNet ----------------
|
| 15 |
+
resnet_loader, _ = create_resnet_dataloaders()
|
| 16 |
+
|
| 17 |
+
images, labels = next(iter(resnet_loader))
|
| 18 |
+
|
| 19 |
+
assert images.shape[1:] == (3, 128, 128), \
|
| 20 |
+
f"Unexpected ResNet image shape: {images.shape}"
|
| 21 |
+
|
| 22 |
+
assert len(labels.shape) == 1, \
|
| 23 |
+
f"Unexpected ResNet labels shape: {labels.shape}"
|
| 24 |
+
|
| 25 |
+
logger.info("ResNet dataloader test passed.")
|
| 26 |
+
|
| 27 |
+
# ---------------- Fusion ----------------
|
| 28 |
+
fusion_loader, _ = create_fusion_dataloaders()
|
| 29 |
+
|
| 30 |
+
batch = next(iter(fusion_loader))
|
| 31 |
+
|
| 32 |
+
assert "pixel_values_eff" in batch, "Missing EfficientNet input"
|
| 33 |
+
assert "pixel_values_cnx" in batch, "Missing ConvNeXt input"
|
| 34 |
+
assert "labels" in batch, "Missing labels"
|
| 35 |
+
|
| 36 |
+
assert batch["pixel_values_eff"].shape[1:] == (3, 260, 260), \
|
| 37 |
+
f"Unexpected Fusion EfficientNet shape: {batch['pixel_values_eff'].shape}"
|
| 38 |
+
|
| 39 |
+
assert batch["pixel_values_cnx"].shape[1:] == (3, 224, 224), \
|
| 40 |
+
f"Unexpected Fusion ConvNeXt shape: {batch['pixel_values_cnx'].shape}"
|
| 41 |
+
|
| 42 |
+
assert len(batch["labels"].shape) == 1, \
|
| 43 |
+
f"Unexpected Fusion labels shape: {batch['labels'].shape}"
|
| 44 |
+
|
| 45 |
+
logger.info("Fusion dataloader test passed.")
|
| 46 |
+
logger.info("Dataset test passed successfully.")
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
if __name__ == "__main__":
|
| 50 |
+
logging.basicConfig(
|
| 51 |
+
level=logging.INFO,
|
| 52 |
+
format="%(asctime)s - %(levelname)s - %(message)s"
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
test_dataset()
|
| 56 |
+
|
| 57 |
+
print("Dataset test completed successfully.")
|
test/test_fusion_model.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
from src.models.fusion_model import FusionClassifier
|
| 5 |
+
from src.config import NUM_CLASSES
|
| 6 |
+
|
| 7 |
+
logger = logging.getLogger(__name__)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def test_fusion_model():
|
| 11 |
+
logger.info("Testing Fusion model architecture...")
|
| 12 |
+
|
| 13 |
+
model = FusionClassifier(
|
| 14 |
+
num_classes=NUM_CLASSES
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
model.eval()
|
| 18 |
+
|
| 19 |
+
eff_dummy = torch.randn(2, 3, 260, 260)
|
| 20 |
+
cnx_dummy = torch.randn(2, 3, 224, 224)
|
| 21 |
+
|
| 22 |
+
with torch.no_grad():
|
| 23 |
+
output = model(
|
| 24 |
+
eff_dummy,
|
| 25 |
+
cnx_dummy
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
assert output.shape == (2, NUM_CLASSES), \
|
| 29 |
+
f"Unexpected output shape: {output.shape}"
|
| 30 |
+
|
| 31 |
+
logger.info("Fusion model test passed.")
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
if __name__ == "__main__":
|
| 35 |
+
logging.basicConfig(
|
| 36 |
+
level=logging.INFO,
|
| 37 |
+
format="%(asctime)s - %(levelname)s - %(message)s"
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
test_fusion_model()
|
| 41 |
+
|
| 42 |
+
print("Fusion model test completed successfully.")
|
test/test_ingestion.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
from src.data.ingestion import collect_image_paths
|
| 5 |
+
|
| 6 |
+
logger = logging.getLogger(__name__)
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def test_ingestion():
|
| 10 |
+
logger.info("Testing ingestion...")
|
| 11 |
+
|
| 12 |
+
samples = collect_image_paths()
|
| 13 |
+
|
| 14 |
+
assert len(samples) > 0, "No samples found"
|
| 15 |
+
|
| 16 |
+
image_path, label = samples[0]
|
| 17 |
+
|
| 18 |
+
assert os.path.exists(image_path), "Image path missing"
|
| 19 |
+
assert isinstance(label, int), "Label invalid"
|
| 20 |
+
|
| 21 |
+
logger.info("Ingestion test passed.")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
if __name__ == "__main__":
|
| 25 |
+
logging.basicConfig(
|
| 26 |
+
level=logging.INFO,
|
| 27 |
+
format="%(asctime)s - %(levelname)s - %(message)s"
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
test_ingestion()
|
| 31 |
+
|
| 32 |
+
print("Ingestion test completed successfully.")
|
test/test_model_conversion.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
from src.export.conver_model import convert_fusion_to_fp16
|
| 5 |
+
from src.config import CHECKPOINT_DIR
|
| 6 |
+
|
| 7 |
+
logger = logging.getLogger(__name__)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def test_model_conversion():
|
| 11 |
+
logger.info("Testing fusion FP16 conversion...")
|
| 12 |
+
|
| 13 |
+
input_checkpoint = CHECKPOINT_DIR / "best_fusion_model.pt"
|
| 14 |
+
|
| 15 |
+
assert input_checkpoint.exists(), \
|
| 16 |
+
f"Missing checkpoint: {input_checkpoint}"
|
| 17 |
+
|
| 18 |
+
output_path = convert_fusion_to_fp16()
|
| 19 |
+
|
| 20 |
+
assert output_path.exists(), \
|
| 21 |
+
"FP16 model was not created"
|
| 22 |
+
|
| 23 |
+
size_mb = os.path.getsize(output_path) / (1024 * 1024)
|
| 24 |
+
|
| 25 |
+
assert size_mb > 0, \
|
| 26 |
+
"Generated FP16 model is empty"
|
| 27 |
+
|
| 28 |
+
logger.info("Model conversion test passed.")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
if __name__ == "__main__":
|
| 32 |
+
logging.basicConfig(
|
| 33 |
+
level=logging.INFO,
|
| 34 |
+
format="%(asctime)s - %(levelname)s - %(message)s"
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
test_model_conversion()
|
| 38 |
+
|
| 39 |
+
print("Model conversion test completed successfully.")
|
test/test_preprocessing.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
|
| 3 |
+
from src.data.ingestion import collect_image_paths
|
| 4 |
+
from src.data.preprocessing import split_dataset
|
| 5 |
+
|
| 6 |
+
logger = logging.getLogger(__name__)
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def test_preprocessing():
|
| 10 |
+
logger.info("Testing preprocessing...")
|
| 11 |
+
|
| 12 |
+
samples = collect_image_paths()
|
| 13 |
+
|
| 14 |
+
train_data, val_data = split_dataset(samples)
|
| 15 |
+
|
| 16 |
+
assert len(train_data) > 0, "Train split is empty"
|
| 17 |
+
assert len(val_data) > 0, "Validation split is empty"
|
| 18 |
+
|
| 19 |
+
train_paths = set(x[0] for x in train_data)
|
| 20 |
+
val_paths = set(x[0] for x in val_data)
|
| 21 |
+
|
| 22 |
+
overlap = train_paths.intersection(val_paths)
|
| 23 |
+
|
| 24 |
+
assert len(overlap) == 0, "Train and validation overlap found"
|
| 25 |
+
|
| 26 |
+
logger.info("Preprocessing test passed.")
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
if __name__ == "__main__":
|
| 30 |
+
logging.basicConfig(
|
| 31 |
+
level=logging.INFO,
|
| 32 |
+
format="%(asctime)s - %(levelname)s - %(message)s"
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
test_preprocessing()
|
| 36 |
+
|
| 37 |
+
print("Preprocessing test completed successfully.")
|
test/test_resnet_model.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
from src.models.resnet_model import CarClassifierResNet
|
| 5 |
+
from src.config import NUM_CLASSES
|
| 6 |
+
|
| 7 |
+
logger = logging.getLogger(__name__)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def test_resnet_model():
|
| 11 |
+
logger.info("Testing ResNet model architecture...")
|
| 12 |
+
|
| 13 |
+
model = CarClassifierResNet(
|
| 14 |
+
num_classes=NUM_CLASSES
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
model.eval()
|
| 18 |
+
|
| 19 |
+
dummy_input = torch.randn(2, 3, 128, 128)
|
| 20 |
+
|
| 21 |
+
with torch.no_grad():
|
| 22 |
+
output = model(dummy_input)
|
| 23 |
+
|
| 24 |
+
assert output.shape == (2, NUM_CLASSES), \
|
| 25 |
+
f"Unexpected output shape: {output.shape}"
|
| 26 |
+
|
| 27 |
+
logger.info("ResNet model test passed.")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
if __name__ == "__main__":
|
| 31 |
+
logging.basicConfig(
|
| 32 |
+
level=logging.INFO,
|
| 33 |
+
format="%(asctime)s - %(levelname)s - %(message)s"
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
test_resnet_model()
|
| 37 |
+
|
| 38 |
+
print("ResNet model test completed successfully.")
|
test/test_train_fusion.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
|
| 3 |
+
from src.training.train_fusion import run_fusion_training
|
| 4 |
+
from src.config import CHECKPOINT_DIR
|
| 5 |
+
|
| 6 |
+
logger = logging.getLogger(__name__)
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def test_train_fusion():
|
| 10 |
+
logger.info("Testing Fusion training pipeline...")
|
| 11 |
+
|
| 12 |
+
checkpoint_path = CHECKPOINT_DIR / "best_fusion_model.pt"
|
| 13 |
+
|
| 14 |
+
if checkpoint_path.exists():
|
| 15 |
+
checkpoint_path.unlink()
|
| 16 |
+
|
| 17 |
+
preds, labels = run_fusion_training()
|
| 18 |
+
|
| 19 |
+
assert checkpoint_path.exists(), \
|
| 20 |
+
"Fusion checkpoint was not created"
|
| 21 |
+
|
| 22 |
+
assert len(preds) > 0, \
|
| 23 |
+
"No predictions returned"
|
| 24 |
+
|
| 25 |
+
assert len(labels) > 0, \
|
| 26 |
+
"No labels returned"
|
| 27 |
+
|
| 28 |
+
logger.info("Fusion training test passed.")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
if __name__ == "__main__":
|
| 32 |
+
logging.basicConfig(
|
| 33 |
+
level=logging.INFO,
|
| 34 |
+
format="%(asctime)s - %(levelname)s - %(message)s"
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
test_train_fusion()
|
| 38 |
+
|
| 39 |
+
print("Fusion training test completed successfully.")
|
test/test_train_resnet.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
|
| 3 |
+
from src.training.train_resnet import run_resnet_training
|
| 4 |
+
from src.config import CHECKPOINT_DIR
|
| 5 |
+
|
| 6 |
+
logger = logging.getLogger(__name__)
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def test_train_resnet():
|
| 10 |
+
logger.info("Testing ResNet training pipeline...")
|
| 11 |
+
|
| 12 |
+
checkpoint_path = CHECKPOINT_DIR / "best_resnet_model.pt"
|
| 13 |
+
|
| 14 |
+
if checkpoint_path.exists():
|
| 15 |
+
checkpoint_path.unlink()
|
| 16 |
+
|
| 17 |
+
preds, labels = run_resnet_training()
|
| 18 |
+
|
| 19 |
+
assert checkpoint_path.exists(), \
|
| 20 |
+
"ResNet checkpoint was not created"
|
| 21 |
+
|
| 22 |
+
assert len(preds) > 0, \
|
| 23 |
+
"No predictions returned"
|
| 24 |
+
|
| 25 |
+
assert len(labels) > 0, \
|
| 26 |
+
"No labels returned"
|
| 27 |
+
|
| 28 |
+
logger.info("ResNet training test passed.")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
if __name__ == "__main__":
|
| 32 |
+
logging.basicConfig(
|
| 33 |
+
level=logging.INFO,
|
| 34 |
+
format="%(asctime)s - %(levelname)s - %(message)s"
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
test_train_resnet()
|
| 38 |
+
|
| 39 |
+
print("ResNet training test completed successfully.")
|
test/test_train_yolo.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
|
| 3 |
+
from src.training.train_yolo import run_yolo_training
|
| 4 |
+
from src.config import CHECKPOINT_DIR
|
| 5 |
+
|
| 6 |
+
logger = logging.getLogger(__name__)
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def test_train_yolo():
|
| 10 |
+
logger.info("Testing YOLO training pipeline...")
|
| 11 |
+
|
| 12 |
+
checkpoint_path = CHECKPOINT_DIR / "damage_detector.pt"
|
| 13 |
+
|
| 14 |
+
if checkpoint_path.exists():
|
| 15 |
+
checkpoint_path.unlink()
|
| 16 |
+
|
| 17 |
+
output_path = run_yolo_training()
|
| 18 |
+
|
| 19 |
+
assert checkpoint_path.exists(), \
|
| 20 |
+
"YOLO checkpoint was not created"
|
| 21 |
+
|
| 22 |
+
assert output_path.exists(), \
|
| 23 |
+
"Returned YOLO model path invalid"
|
| 24 |
+
|
| 25 |
+
logger.info("YOLO training test passed.")
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
if __name__ == "__main__":
|
| 29 |
+
logging.basicConfig(
|
| 30 |
+
level=logging.INFO,
|
| 31 |
+
format="%(asctime)s - %(levelname)s - %(message)s"
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
test_train_yolo()
|
| 35 |
+
|
| 36 |
+
print("YOLO training test completed successfully.")
|
test/test_upload_to_huggingface.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
from dotenv import load_dotenv
|
| 4 |
+
from huggingface_hub import HfApi
|
| 5 |
+
|
| 6 |
+
from src.export.upload_to_huggingface import MODELS
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def test_huggingface_upload_setup():
|
| 12 |
+
logger.info("Testing Hugging Face upload setup...")
|
| 13 |
+
|
| 14 |
+
load_dotenv()
|
| 15 |
+
|
| 16 |
+
hf_token = os.getenv("HF_TOKEN")
|
| 17 |
+
|
| 18 |
+
assert hf_token is not None, \
|
| 19 |
+
"HF_TOKEN missing in .env"
|
| 20 |
+
|
| 21 |
+
assert hf_token.startswith("hf_"), \
|
| 22 |
+
"Invalid Hugging Face token format"
|
| 23 |
+
|
| 24 |
+
api = HfApi(token=hf_token)
|
| 25 |
+
|
| 26 |
+
assert api is not None, \
|
| 27 |
+
"Failed to initialize Hugging Face API"
|
| 28 |
+
|
| 29 |
+
for repo_name, model_info in MODELS.items():
|
| 30 |
+
file_path = model_info["path"]
|
| 31 |
+
filename = model_info["filename"]
|
| 32 |
+
|
| 33 |
+
assert file_path.exists(), \
|
| 34 |
+
f"Missing model file: {file_path}"
|
| 35 |
+
|
| 36 |
+
assert filename.endswith(".pt"), \
|
| 37 |
+
f"Invalid model filename: {filename}"
|
| 38 |
+
|
| 39 |
+
assert repo_name.startswith("new-"), \
|
| 40 |
+
f"Repo naming invalid: {repo_name}"
|
| 41 |
+
|
| 42 |
+
logger.info("Hugging Face upload setup test passed.")
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
if __name__ == "__main__":
|
| 46 |
+
logging.basicConfig(
|
| 47 |
+
level=logging.INFO,
|
| 48 |
+
format="%(asctime)s - %(levelname)s - %(message)s"
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
test_huggingface_upload_setup()
|
| 52 |
+
|
| 53 |
+
print("Hugging Face upload test completed successfully.")
|