Spaces:
Build error
Build error
Upload 8 files
Browse files- .gitattributes +1 -34
- .gitignore +2 -0
- LICENSE +21 -0
- Readme.md +180 -0
- app.py +704 -0
- predict.py +165 -0
- prediction_engine.py +157 -0
- requirements.txt +13 -3
.gitattributes
CHANGED
|
@@ -1,35 +1,2 @@
|
|
| 1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
**.onnx filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
**/__pycache__
|
| 2 |
+
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 smokeyScraper
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
Readme.md
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Nature Nexus - Forest Surveillance System
|
| 2 |
+
|
| 3 |
+
## Overview
|
| 4 |
+
|
| 5 |
+
Nature Nexus is an advanced forest surveillance system designed to protect natural ecosystems through AI-powered monitoring. It combines multiple detection technologies to identify illegal activities, monitor deforestation, and detect potential threats to forest areas.
|
| 6 |
+
|
| 7 |
+
The application leverages:
|
| 8 |
+
- **Satellite Imagery Analysis** - Detects deforestation using segmentation models
|
| 9 |
+
- **Audio Surveillance** - Identifies unusual sounds like chainsaws, vehicles, and human activity
|
| 10 |
+
- **Object Detection** - Recognizes trespassers, vehicles, fires, and other threats
|
| 11 |
+
|
| 12 |
+
## Features
|
| 13 |
+
|
| 14 |
+
### 1. Deforestation Detection
|
| 15 |
+
- Analyzes satellite or aerial imagery to identify deforested areas
|
| 16 |
+
- Uses Attention U-Net segmentation model optimized with ONNX runtime
|
| 17 |
+
- Provides detailed metrics on forest coverage and deforestation levels
|
| 18 |
+
- Visualizes results with color-coded overlays
|
| 19 |
+
|
| 20 |
+

|
| 21 |
+
|
| 22 |
+
### 2. Forest Audio Surveillance
|
| 23 |
+
- Detects unusual sounds that may indicate illegal activities
|
| 24 |
+
- Classifies various sounds including:
|
| 25 |
+
- **Human Sounds**: Footsteps, coughing, laughing, breathing, etc.
|
| 26 |
+
- **Tool Sounds**: Chainsaw, hand saw
|
| 27 |
+
- **Vehicle Sounds**: Car horn, engine, siren
|
| 28 |
+
- **Other Sounds**: Crackling fire, fireworks
|
| 29 |
+
- Supports both uploaded audio files and real-time recording
|
| 30 |
+
|
| 31 |
+

|
| 32 |
+
|
| 33 |
+
### 3. Object Detection
|
| 34 |
+
- Identifies potential threats using YOLOv11 model
|
| 35 |
+
- Detects objects including:
|
| 36 |
+
- Humans (trespassers)
|
| 37 |
+
- Vehicles (cars, bikes, buses/trucks)
|
| 38 |
+
- Fire and smoke
|
| 39 |
+
- Processes images, videos, and camera feeds
|
| 40 |
+
- Alerts on potential threats with confidence scores
|
| 41 |
+
|
| 42 |
+

|
| 43 |
+
|
| 44 |
+
## Getting Started
|
| 45 |
+
|
| 46 |
+
### Prerequisites
|
| 47 |
+
|
| 48 |
+
- Python 3.8+
|
| 49 |
+
- pip package manager
|
| 50 |
+
- Virtual environment (recommended)
|
| 51 |
+
|
| 52 |
+
### Installation
|
| 53 |
+
|
| 54 |
+
1. Clone the repository
|
| 55 |
+
```bash
|
| 56 |
+
git clone https://github.com/yourusername/nature-nexus.git
|
| 57 |
+
cd nature-nexus
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
2. Create and activate a virtual environment (optional but recommended)
|
| 61 |
+
```bash
|
| 62 |
+
python -m venv venv
|
| 63 |
+
source venv/bin/activate # On Windows, use: venv\Scripts\activate
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
3. Install required dependencies
|
| 67 |
+
```bash
|
| 68 |
+
pip install -r requirements.txt
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
4. Download models
|
| 72 |
+
```bash
|
| 73 |
+
# Create models directory if it doesn't exist
|
| 74 |
+
mkdir -p models
|
| 75 |
+
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
### Running the Application
|
| 79 |
+
|
| 80 |
+
Launch the Streamlit application:
|
| 81 |
+
```bash
|
| 82 |
+
streamlit run app.py
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
The application will open in your default web browser at http://localhost:8501
|
| 86 |
+
|
| 87 |
+
## Model Architecture
|
| 88 |
+
|
| 89 |
+
### Deforestation Detection Model
|
| 90 |
+
- **Architecture**: Attention U-Net
|
| 91 |
+
- **Input**: Satellite/aerial imagery (RGB)
|
| 92 |
+
- **Output**: Binary segmentation mask (forest vs. deforested)
|
| 93 |
+
- **Optimization**: ONNX runtime for faster inference
|
| 94 |
+
|
| 95 |
+
### Audio Classification Model
|
| 96 |
+
- **Architecture**: Convolutional Neural Network (CNN)
|
| 97 |
+
- **Input**: Audio spectrograms
|
| 98 |
+
- **Output**: 14 sound classes with confidence scores
|
| 99 |
+
- **Features**: Mel-spectrogram analysis
|
| 100 |
+
|
| 101 |
+
### Object Detection Model
|
| 102 |
+
- **Architecture**: YOLOv11
|
| 103 |
+
- **Input**: Images/video frames
|
| 104 |
+
- **Output**: Bounding boxes, class labels, confidence scores
|
| 105 |
+
- **Classes**: Humans, vehicles, fire, smoke, etc.
|
| 106 |
+
|
| 107 |
+
## System Architecture
|
| 108 |
+
|
| 109 |
+
```
|
| 110 |
+
nature-nexus/
|
| 111 |
+
│
|
| 112 |
+
├── app.py # Main Streamlit application
|
| 113 |
+
├── prediction_engine.py # Deforestation model interface
|
| 114 |
+
│
|
| 115 |
+
├── utils/
|
| 116 |
+
│ ├── audio_model.py # Audio classification model
|
| 117 |
+
│ ├── audio_processing.py # Audio preprocessing utilities
|
| 118 |
+
│ ├── helpers.py # Helper functions for visualization
|
| 119 |
+
│ ├── model.py # U-Net model definition
|
| 120 |
+
│ ├── onnx_converter.py # Converts PyTorch models to ONNX
|
| 121 |
+
│ ├── onnx_inference.py # YOLO object detection inference
|
| 122 |
+
│ └── preprocess.py # Image preprocessing utilities
|
| 123 |
+
│
|
| 124 |
+
└── models/ # Model weights (not included in repo)
|
| 125 |
+
├── deforestation_model.onnx
|
| 126 |
+
├── best_model.pth # Audio model
|
| 127 |
+
└── best_model.onnx # YOLO model
|
| 128 |
+
```
|
| 129 |
+
|
| 130 |
+
## Usage Guide
|
| 131 |
+
|
| 132 |
+
### Deforestation Detection
|
| 133 |
+
1. Select "Deforestation Detection" from the sidebar
|
| 134 |
+
2. Upload satellite or aerial imagery of forest areas
|
| 135 |
+
3. View segmentation results showing forest vs. deforested areas
|
| 136 |
+
4. Analyze metrics including forest coverage and deforestation level
|
| 137 |
+
|
| 138 |
+
### Audio Surveillance
|
| 139 |
+
1. Select "Forest Audio Surveillance" from the sidebar
|
| 140 |
+
2. Choose between uploading audio files or recording live audio
|
| 141 |
+
3. Submit the audio for analysis
|
| 142 |
+
4. View detected sound classification and potential alerts
|
| 143 |
+
|
| 144 |
+
### Object Detection
|
| 145 |
+
1. Select "Object Detection" from the sidebar
|
| 146 |
+
2. Choose between image, video, or camera feed
|
| 147 |
+
3. Adjust confidence and IoU thresholds as needed
|
| 148 |
+
4. Upload or capture input for processing
|
| 149 |
+
5. View detection results with bounding boxes and confidence scores
|
| 150 |
+
|
| 151 |
+
## Custom Model Training
|
| 152 |
+
|
| 153 |
+
To train custom models for your specific forest environment:
|
| 154 |
+
|
| 155 |
+
### Deforestation Model
|
| 156 |
+
```bash
|
| 157 |
+
# Convert trained PyTorch model to ONNX
|
| 158 |
+
python -m utils.onnx_converter models/your_pytorch_model.pth models/deforestation_model.onnx [input_size]
|
| 159 |
+
```
|
| 160 |
+
|
| 161 |
+
### Audio Model
|
| 162 |
+
Train on your custom audio dataset and replace the model file at `models/best_model.pth`
|
| 163 |
+
|
| 164 |
+
### YOLO Model
|
| 165 |
+
Train on your custom object dataset and replace the model file at `models/best_model.onnx`
|
| 166 |
+
|
| 167 |
+
## Troubleshooting
|
| 168 |
+
|
| 169 |
+
### Common Issues
|
| 170 |
+
- **Models not loading**: Ensure all model files exist in the `models/` directory
|
| 171 |
+
- **CUDA errors**: If using GPU, verify CUDA and cuDNN are correctly installed
|
| 172 |
+
- **Audio processing issues**: Check audio format compatibility (WAV, MP3, OGG)
|
| 173 |
+
|
| 174 |
+
## Contributing
|
| 175 |
+
|
| 176 |
+
Contributions are welcome! Please feel free to submit a Pull Request.
|
| 177 |
+
|
| 178 |
+
## License
|
| 179 |
+
|
| 180 |
+
This project is licensed under the MIT License - see the LICENSE file for details.
|
app.py
ADDED
|
@@ -0,0 +1,704 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
import os
|
| 5 |
+
import tempfile
|
| 6 |
+
import librosa
|
| 7 |
+
import librosa.display
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
import tempfile
|
| 10 |
+
import librosa
|
| 11 |
+
import librosa.display
|
| 12 |
+
import matplotlib.pyplot as plt
|
| 13 |
+
from PIL import Image
|
| 14 |
+
import torch
|
| 15 |
+
|
| 16 |
+
# Import deforestation modules
|
| 17 |
+
from prediction_engine import load_onnx_model
|
| 18 |
+
|
| 19 |
+
# Import deforestation modules
|
| 20 |
+
from prediction_engine import load_onnx_model
|
| 21 |
+
from utils.helpers import calculate_deforestation_metrics, create_overlay
|
| 22 |
+
|
| 23 |
+
# Import audio classification modules
|
| 24 |
+
from utils.audio_processing import preprocess_audio
|
| 25 |
+
from utils.audio_model import load_audio_model, predict_audio, class_names
|
| 26 |
+
|
| 27 |
+
# Import YOLO detection modules
|
| 28 |
+
from utils.onnx_inference import YOLOv11
|
| 29 |
+
|
| 30 |
+
# Ensure torch classes path is initialized to avoid warnings
|
| 31 |
+
torch.classes.__path__ = []
|
| 32 |
+
|
| 33 |
+
# Set page config
|
| 34 |
+
st.set_page_config(
|
| 35 |
+
page_title="Nature Nexus - Forest Surveillance",
|
| 36 |
+
page_icon="🌳",
|
| 37 |
+
layout="wide",
|
| 38 |
+
initial_sidebar_state="expanded"
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# Constants
|
| 43 |
+
DEFOREST_MODEL_INPUT_SIZE = 256
|
| 44 |
+
AUDIO_MODEL_PATH = "models/best_model.pth"
|
| 45 |
+
YOLO_MODEL_PATH = "models/best_model.onnx"
|
| 46 |
+
|
| 47 |
+
# Initialize session state for navigation
|
| 48 |
+
if 'current_service' not in st.session_state:
|
| 49 |
+
st.session_state.current_service = 'deforestation'
|
| 50 |
+
if 'audio_input_method' not in st.session_state:
|
| 51 |
+
st.session_state.audio_input_method = 'upload'
|
| 52 |
+
if 'detection_input_method' not in st.session_state:
|
| 53 |
+
st.session_state.detection_input_method = 'image'
|
| 54 |
+
|
| 55 |
+
# Sidebar for navigation
|
| 56 |
+
with st.sidebar:
|
| 57 |
+
st.title("Nature Nexus")
|
| 58 |
+
st.subheader("Forest Surveillance System")
|
| 59 |
+
|
| 60 |
+
selected_service = st.radio(
|
| 61 |
+
"Select Service:",
|
| 62 |
+
["Deforestation Detection", "Forest Audio Surveillance", "Object Detection"]
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
if selected_service == "Deforestation Detection":
|
| 66 |
+
st.session_state.current_service = 'deforestation'
|
| 67 |
+
elif selected_service == "Forest Audio Surveillance":
|
| 68 |
+
st.session_state.current_service = 'audio'
|
| 69 |
+
else:
|
| 70 |
+
st.session_state.current_service = 'detection'
|
| 71 |
+
|
| 72 |
+
st.markdown("---")
|
| 73 |
+
|
| 74 |
+
# Service-specific sidebar content
|
| 75 |
+
if st.session_state.current_service == 'deforestation':
|
| 76 |
+
st.info(
|
| 77 |
+
"""
|
| 78 |
+
**Deforestation Detection**
|
| 79 |
+
|
| 80 |
+
Upload satellite or aerial images to detect areas of deforestation.
|
| 81 |
+
"""
|
| 82 |
+
)
|
| 83 |
+
elif st.session_state.current_service == 'audio':
|
| 84 |
+
st.info(
|
| 85 |
+
"""
|
| 86 |
+
**Forest Audio Surveillance**
|
| 87 |
+
|
| 88 |
+
Detect unusual human-related sounds in forested regions.
|
| 89 |
+
"""
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
# Audio service specific controls
|
| 93 |
+
st.subheader("Audio Configuration")
|
| 94 |
+
audio_input_method = st.radio(
|
| 95 |
+
"Select Input Method:",
|
| 96 |
+
("Upload Audio", "Record Audio"),
|
| 97 |
+
index=0 if st.session_state.audio_input_method == 'upload' else 1
|
| 98 |
+
)
|
| 99 |
+
st.session_state.audio_input_method = 'upload' if audio_input_method == "Upload Audio" else 'record'
|
| 100 |
+
|
| 101 |
+
# Audio class information
|
| 102 |
+
st.markdown("**Detection Classes:**")
|
| 103 |
+
|
| 104 |
+
# Group classes by category
|
| 105 |
+
human_sounds = ['footsteps', 'coughing', 'laughing', 'breathing',
|
| 106 |
+
'drinking_sipping', 'snoring', 'sneezing']
|
| 107 |
+
tool_sounds = ['chainsaw', 'hand_saw']
|
| 108 |
+
vehicle_sounds = ['car_horn', 'engine', 'siren']
|
| 109 |
+
other_sounds = ['crackling_fire', 'fireworks']
|
| 110 |
+
|
| 111 |
+
st.markdown("👤 **Human Sounds:** " + ", ".join([s.capitalize() for s in human_sounds]))
|
| 112 |
+
st.markdown("🔨 **Tool Sounds:** " + ", ".join([s.capitalize() for s in tool_sounds]))
|
| 113 |
+
st.markdown("🚗 **Vehicle Sounds:** " + ", ".join([s.capitalize() for s in vehicle_sounds]))
|
| 114 |
+
st.markdown("💥 **Other Sounds:** " + ", ".join([s.capitalize() for s in other_sounds]))
|
| 115 |
+
else: # Object Detection
|
| 116 |
+
st.info(
|
| 117 |
+
"""
|
| 118 |
+
**Object Detection**
|
| 119 |
+
|
| 120 |
+
Detect trespassers, vehicles, fires, and other objects in forest surveillance footage.
|
| 121 |
+
"""
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
# Detection service specific controls
|
| 125 |
+
st.subheader("Detection Configuration")
|
| 126 |
+
detection_input_method = st.radio(
|
| 127 |
+
"Select Input Method:",
|
| 128 |
+
("Image", "Video", "Camera"),
|
| 129 |
+
index=0 if st.session_state.detection_input_method == 'image' else
|
| 130 |
+
(1 if st.session_state.detection_input_method == 'video' else 2)
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
if detection_input_method == "Image":
|
| 134 |
+
st.session_state.detection_input_method = 'image'
|
| 135 |
+
elif detection_input_method == "Video":
|
| 136 |
+
st.session_state.detection_input_method = 'video'
|
| 137 |
+
else:
|
| 138 |
+
st.session_state.detection_input_method = 'camera'
|
| 139 |
+
|
| 140 |
+
# Detection threshold controls
|
| 141 |
+
st.subheader("Detection Settings")
|
| 142 |
+
confidence = st.slider("Confidence Threshold", 0.0, 1.0, 0.5)
|
| 143 |
+
iou_thres = st.slider("IoU Threshold", 0.0, 1.0, 0.5)
|
| 144 |
+
|
| 145 |
+
# Detection class information
|
| 146 |
+
st.markdown("**Detection Classes:**")
|
| 147 |
+
st.markdown("🚴 **Bike/Bicycle**")
|
| 148 |
+
st.markdown("🚚 **Bus/Truck**")
|
| 149 |
+
st.markdown("🚗 **Car**")
|
| 150 |
+
st.markdown("🔥 **Fire**")
|
| 151 |
+
st.markdown("👤 **Human**")
|
| 152 |
+
st.markdown("💨 **Smoke**")
|
| 153 |
+
|
| 154 |
+
# Load deforestation model
|
| 155 |
+
@st.cache_resource
|
| 156 |
+
def load_cached_deforestation_model():
|
| 157 |
+
model_path = "models/deforestation_model.onnx"
|
| 158 |
+
return load_onnx_model(model_path, input_size=DEFOREST_MODEL_INPUT_SIZE)
|
| 159 |
+
|
| 160 |
+
# Load audio model
|
| 161 |
+
@st.cache_resource
|
| 162 |
+
def load_cached_audio_model():
|
| 163 |
+
return load_audio_model(AUDIO_MODEL_PATH)
|
| 164 |
+
|
| 165 |
+
@st.cache_resource
|
| 166 |
+
def load_cached_yolo_model():
|
| 167 |
+
return YOLOv11(YOLO_MODEL_PATH)
|
| 168 |
+
|
| 169 |
+
# Process image for deforestation detection
|
| 170 |
+
def process_image(model, image):
|
| 171 |
+
"""Process a single image and return results"""
|
| 172 |
+
# Save original image dimensions for display
|
| 173 |
+
orig_height, orig_width = image.shape[:2]
|
| 174 |
+
|
| 175 |
+
# Make prediction
|
| 176 |
+
mask = model.predict(image)
|
| 177 |
+
|
| 178 |
+
# Resize mask back to original dimensions for display
|
| 179 |
+
display_mask = cv2.resize(mask, (orig_width, orig_height))
|
| 180 |
+
|
| 181 |
+
# Create binary mask for visualization
|
| 182 |
+
binary_mask = (display_mask > 0.5).astype(np.uint8) * 255
|
| 183 |
+
|
| 184 |
+
# Create colored overlay
|
| 185 |
+
overlay = create_overlay(cv2.cvtColor(image, cv2.COLOR_BGR2RGB), display_mask)
|
| 186 |
+
|
| 187 |
+
# Calculate metrics
|
| 188 |
+
metrics = calculate_deforestation_metrics(mask)
|
| 189 |
+
|
| 190 |
+
return binary_mask, overlay, metrics
|
| 191 |
+
|
| 192 |
+
# Visualize audio for audio classification
|
| 193 |
+
def visualize_audio(audio_path):
|
| 194 |
+
y, sr = librosa.load(audio_path, sr=16000)
|
| 195 |
+
duration = len(y) / sr
|
| 196 |
+
|
| 197 |
+
fig, ax = plt.subplots(2, 1, figsize=(10, 6))
|
| 198 |
+
|
| 199 |
+
# Waveform plot
|
| 200 |
+
librosa.display.waveshow(y, sr=sr, ax=ax[0])
|
| 201 |
+
ax[0].set_title('Audio Waveform')
|
| 202 |
+
ax[0].set_xlabel('Time (s)')
|
| 203 |
+
ax[0].set_ylabel('Amplitude')
|
| 204 |
+
|
| 205 |
+
# Spectrogram plot
|
| 206 |
+
S = librosa.feature.melspectrogram(y=y, sr=sr)
|
| 207 |
+
S_db = librosa.power_to_db(S, ref=np.max)
|
| 208 |
+
img = librosa.display.specshow(S_db, sr=sr, x_axis='time', y_axis='mel', ax=ax[1])
|
| 209 |
+
fig.colorbar(img, ax=ax[1], format='%+2.0f dB')
|
| 210 |
+
ax[1].set_title('Mel Spectrogram')
|
| 211 |
+
|
| 212 |
+
plt.tight_layout()
|
| 213 |
+
st.pyplot(fig)
|
| 214 |
+
|
| 215 |
+
return y, sr, duration
|
| 216 |
+
|
| 217 |
+
# Process audio for classification
|
| 218 |
+
def process_audio(audio_file):
|
| 219 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
|
| 220 |
+
tmp_file.write(audio_file.read() if hasattr(audio_file, 'read') else audio_file)
|
| 221 |
+
audio_path = tmp_file.name
|
| 222 |
+
|
| 223 |
+
try:
|
| 224 |
+
# Load audio model
|
| 225 |
+
audio_model = load_cached_audio_model()
|
| 226 |
+
|
| 227 |
+
# Visualize audio
|
| 228 |
+
with st.spinner('Analyzing audio...'):
|
| 229 |
+
y, sr, duration = visualize_audio(audio_path)
|
| 230 |
+
st.caption(f"Audio duration: {duration:.2f} seconds")
|
| 231 |
+
|
| 232 |
+
# Make prediction
|
| 233 |
+
with st.spinner('Making prediction...'):
|
| 234 |
+
class_name, confidence = predict_audio(audio_path, audio_model)
|
| 235 |
+
|
| 236 |
+
# Display results
|
| 237 |
+
st.subheader("Detection Results")
|
| 238 |
+
|
| 239 |
+
col1, col2 = st.columns(2)
|
| 240 |
+
with col1:
|
| 241 |
+
st.metric("Detected Sound", class_name.replace('_', ' ').title())
|
| 242 |
+
with col2:
|
| 243 |
+
st.metric("Confidence", f"{confidence*100:.2f}%")
|
| 244 |
+
|
| 245 |
+
# Show alerts based on class
|
| 246 |
+
human_sounds = ['footsteps', 'coughing', 'laughing', 'breathing',
|
| 247 |
+
'drinking_sipping', 'snoring', 'sneezing']
|
| 248 |
+
tool_sounds = ['chainsaw', 'hand_saw']
|
| 249 |
+
|
| 250 |
+
if class_name in human_sounds:
|
| 251 |
+
st.warning("""
|
| 252 |
+
⚠️ **Human Activity Detected!**
|
| 253 |
+
Potential human presence in the monitored area.
|
| 254 |
+
""")
|
| 255 |
+
elif class_name in tool_sounds:
|
| 256 |
+
st.error("""
|
| 257 |
+
🚨 **ALERT: Human Tool Detected!**
|
| 258 |
+
Potential illegal logging or activity detected. Consider immediate verification.
|
| 259 |
+
""")
|
| 260 |
+
elif class_name in ['car_horn', 'engine', 'siren']:
|
| 261 |
+
st.warning("""
|
| 262 |
+
⚠️ **Vehicle Detected!**
|
| 263 |
+
Vehicle sounds detected in the monitored area.
|
| 264 |
+
""")
|
| 265 |
+
elif class_name == 'fireworks':
|
| 266 |
+
st.error("""
|
| 267 |
+
🚨 **ALERT: Fireworks Detected!**
|
| 268 |
+
Potential fire hazard and disturbance to wildlife. Immediate verification required.
|
| 269 |
+
""")
|
| 270 |
+
elif class_name == 'crackling_fire':
|
| 271 |
+
st.error("""
|
| 272 |
+
🚨 **ALERT: Fire Detected!**
|
| 273 |
+
Potential wildfire detected. Immediate verification required.
|
| 274 |
+
""")
|
| 275 |
+
else:
|
| 276 |
+
st.success("✅ Environmental sound detected - no immediate threat")
|
| 277 |
+
|
| 278 |
+
except Exception as e:
|
| 279 |
+
st.error(f"Error processing audio: {str(e)}")
|
| 280 |
+
st.exception(e)
|
| 281 |
+
finally:
|
| 282 |
+
# Clean up temp file
|
| 283 |
+
try:
|
| 284 |
+
os.unlink(audio_path)
|
| 285 |
+
except:
|
| 286 |
+
pass
|
| 287 |
+
|
| 288 |
+
# Deforestation detection UI
|
| 289 |
+
def show_deforestation_detection():
|
| 290 |
+
# App title and description
|
| 291 |
+
st.title("🌳 Deforestation Detection")
|
| 292 |
+
st.markdown(
|
| 293 |
+
"""
|
| 294 |
+
This service detects areas of deforestation in satellite or aerial images of forests.
|
| 295 |
+
Upload an image to get started!
|
| 296 |
+
"""
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
# Model info
|
| 300 |
+
st.info(
|
| 301 |
+
f"⚙️ Model optimized for {DEFOREST_MODEL_INPUT_SIZE}x{DEFOREST_MODEL_INPUT_SIZE} pixel images using ONNX runtime"
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
# Load model
|
| 305 |
+
try:
|
| 306 |
+
model = load_cached_deforestation_model()
|
| 307 |
+
except Exception as e:
|
| 308 |
+
st.error(f"Error loading model: {e}")
|
| 309 |
+
st.info(
|
| 310 |
+
"Make sure you have converted your PyTorch model to ONNX format using the utils/onnx_converter.py script."
|
| 311 |
+
)
|
| 312 |
+
st.code(
|
| 313 |
+
"python -m utils.onnx_converter models/best_model_100.pth models/deforestation_model.onnx"
|
| 314 |
+
)
|
| 315 |
+
return
|
| 316 |
+
|
| 317 |
+
# File uploader for images
|
| 318 |
+
uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
|
| 319 |
+
|
| 320 |
+
if uploaded_file is not None:
|
| 321 |
+
# Load image
|
| 322 |
+
file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
|
| 323 |
+
image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
|
| 324 |
+
|
| 325 |
+
# Display original image
|
| 326 |
+
st.subheader("Original Image")
|
| 327 |
+
st.image(
|
| 328 |
+
cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
|
| 329 |
+
caption="Uploaded Image",
|
| 330 |
+
use_container_width=True,
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
# Add a spinner while processing
|
| 334 |
+
with st.spinner("Processing..."):
|
| 335 |
+
try:
|
| 336 |
+
# Process image
|
| 337 |
+
binary_mask, overlay, metrics = process_image(model, image)
|
| 338 |
+
|
| 339 |
+
# Display results in columns
|
| 340 |
+
col1, col2 = st.columns(2)
|
| 341 |
+
|
| 342 |
+
with col1:
|
| 343 |
+
st.subheader("Segmentation Result")
|
| 344 |
+
st.image(
|
| 345 |
+
binary_mask,
|
| 346 |
+
caption="Forest Areas (White)",
|
| 347 |
+
use_container_width=True,
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
with col2:
|
| 351 |
+
st.subheader("Overlay Visualization")
|
| 352 |
+
st.image(
|
| 353 |
+
overlay,
|
| 354 |
+
caption="Green: Forest, Brown: Deforested",
|
| 355 |
+
use_container_width=True,
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
# Display metrics
|
| 359 |
+
st.subheader("Deforestation Analysis")
|
| 360 |
+
|
| 361 |
+
# Create metrics cards
|
| 362 |
+
metrics_col1, metrics_col2, metrics_col3 = st.columns(3)
|
| 363 |
+
|
| 364 |
+
with metrics_col1:
|
| 365 |
+
st.metric(
|
| 366 |
+
label="Forest Coverage",
|
| 367 |
+
value=f"{metrics['forest_percentage']:.1f}%",
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
with metrics_col2:
|
| 371 |
+
st.metric(
|
| 372 |
+
label="Deforested Area",
|
| 373 |
+
value=f"{metrics['deforested_percentage']:.1f}%",
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
with metrics_col3:
|
| 377 |
+
st.metric(
|
| 378 |
+
label="Deforestation Level",
|
| 379 |
+
value=metrics["deforestation_level"],
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
except Exception as e:
|
| 383 |
+
st.error(f"Error during processing: {e}")
|
| 384 |
+
|
| 385 |
+
# Audio classification UI
|
| 386 |
+
def show_audio_classification():
|
| 387 |
+
# App title and description
|
| 388 |
+
st.title("🎧 Forest Audio Surveillance")
|
| 389 |
+
st.markdown("""
|
| 390 |
+
Detect unusual human-related sounds in forested regions to prevent illegal activities.
|
| 391 |
+
Supported sounds: {}
|
| 392 |
+
""".format(", ".join(class_names)))
|
| 393 |
+
|
| 394 |
+
if st.session_state.audio_input_method == 'upload':
|
| 395 |
+
st.header("Upload Audio File")
|
| 396 |
+
|
| 397 |
+
sample_col, upload_col = st.columns(2)
|
| 398 |
+
with sample_col:
|
| 399 |
+
st.info("Upload a WAV, MP3 or OGG file with forest sounds")
|
| 400 |
+
st.markdown("""
|
| 401 |
+
**Tips for best results:**
|
| 402 |
+
- Use audio with minimal background noise
|
| 403 |
+
- Ensure the sound of interest is clear
|
| 404 |
+
- 2-3 second clips work best
|
| 405 |
+
""")
|
| 406 |
+
|
| 407 |
+
with upload_col:
|
| 408 |
+
audio_file = st.file_uploader(
|
| 409 |
+
"Choose an audio file",
|
| 410 |
+
type=["wav", "mp3", "ogg"],
|
| 411 |
+
help="Supported formats: WAV, MP3, OGG"
|
| 412 |
+
)
|
| 413 |
+
|
| 414 |
+
if audio_file:
|
| 415 |
+
st.success("File uploaded successfully!")
|
| 416 |
+
with st.expander("Audio Preview", expanded=True):
|
| 417 |
+
st.audio(audio_file)
|
| 418 |
+
process_audio(audio_file)
|
| 419 |
+
|
| 420 |
+
else: # Record mode
|
| 421 |
+
st.header("Record Live Audio")
|
| 422 |
+
|
| 423 |
+
st.info("""
|
| 424 |
+
Click the microphone button below to record a sound for analysis.
|
| 425 |
+
**Note:** Please ensure your browser has permission to access your microphone.
|
| 426 |
+
When prompted, click "Allow" to enable recording.
|
| 427 |
+
""")
|
| 428 |
+
|
| 429 |
+
recorded_audio = st.audio_input(
|
| 430 |
+
label="Record a sound",
|
| 431 |
+
key="audio_recorder",
|
| 432 |
+
help="Click to record forest sounds for analysis",
|
| 433 |
+
label_visibility="visible"
|
| 434 |
+
)
|
| 435 |
+
|
| 436 |
+
if recorded_audio:
|
| 437 |
+
st.success("Audio recorded successfully!")
|
| 438 |
+
with st.expander("Recorded Audio", expanded=True):
|
| 439 |
+
st.audio(recorded_audio)
|
| 440 |
+
process_audio(recorded_audio)
|
| 441 |
+
else:
|
| 442 |
+
st.write("Waiting for recording...")
|
| 443 |
+
|
| 444 |
+
# Object Detection UI
|
| 445 |
+
def show_object_detection():
|
| 446 |
+
# App title and description
|
| 447 |
+
st.title("🔍 Forest Object Detection")
|
| 448 |
+
st.markdown(
|
| 449 |
+
"""
|
| 450 |
+
Detect trespassers, vehicles, fires, and other objects in forest surveillance footage.
|
| 451 |
+
Choose an input method to begin detection.
|
| 452 |
+
"""
|
| 453 |
+
)
|
| 454 |
+
|
| 455 |
+
# Model info
|
| 456 |
+
st.info("⚙️ Object detection model optimized with ONNX runtime for faster inference")
|
| 457 |
+
|
| 458 |
+
# Load model
|
| 459 |
+
try:
|
| 460 |
+
model = load_cached_yolo_model()
|
| 461 |
+
# Update model confidence and IoU thresholds from sidebar
|
| 462 |
+
confidence = st.session_state.get('confidence', 0.5)
|
| 463 |
+
iou_thres = st.session_state.get('iou_thres', 0.5)
|
| 464 |
+
model.conf_thres = confidence
|
| 465 |
+
model.iou_thres = iou_thres
|
| 466 |
+
except Exception as e:
|
| 467 |
+
st.error(f"Error loading model: {e}")
|
| 468 |
+
st.info(
|
| 469 |
+
"Make sure you have the YOLO ONNX model file available at models/best_model.onnx"
|
| 470 |
+
)
|
| 471 |
+
return
|
| 472 |
+
|
| 473 |
+
# Input method based selection
|
| 474 |
+
if st.session_state.detection_input_method == 'image':
|
| 475 |
+
# Image upload
|
| 476 |
+
img_file = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"])
|
| 477 |
+
if img_file is not None:
|
| 478 |
+
# Load image
|
| 479 |
+
file_bytes = np.asarray(bytearray(img_file.read()), dtype=np.uint8)
|
| 480 |
+
image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
|
| 481 |
+
if image is not None:
|
| 482 |
+
# Display original image
|
| 483 |
+
st.subheader("Original Image")
|
| 484 |
+
st.image(
|
| 485 |
+
cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
|
| 486 |
+
caption="Uploaded Image",
|
| 487 |
+
use_container_width=True,
|
| 488 |
+
)
|
| 489 |
+
|
| 490 |
+
# Process with detection model
|
| 491 |
+
with st.spinner("Processing image..."):
|
| 492 |
+
try:
|
| 493 |
+
detections = model.detect(image)
|
| 494 |
+
result_image = model.draw_detections(image.copy(), detections)
|
| 495 |
+
|
| 496 |
+
# Display results
|
| 497 |
+
st.subheader("Detection Results")
|
| 498 |
+
st.image(
|
| 499 |
+
cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB),
|
| 500 |
+
caption="Detected Objects",
|
| 501 |
+
use_container_width=True,
|
| 502 |
+
)
|
| 503 |
+
|
| 504 |
+
# Display detection statistics
|
| 505 |
+
st.subheader("Detection Statistics")
|
| 506 |
+
|
| 507 |
+
# Count detections by class
|
| 508 |
+
class_counts = {}
|
| 509 |
+
for det in detections:
|
| 510 |
+
class_name = det['class']
|
| 511 |
+
if class_name in class_counts:
|
| 512 |
+
class_counts[class_name] += 1
|
| 513 |
+
else:
|
| 514 |
+
class_counts[class_name] = 1
|
| 515 |
+
|
| 516 |
+
# Display counts with emojis
|
| 517 |
+
cols = st.columns(3)
|
| 518 |
+
col_idx = 0
|
| 519 |
+
|
| 520 |
+
for class_name, count in class_counts.items():
|
| 521 |
+
emoji = "👤" if class_name == "human" else (
|
| 522 |
+
"🔥" if class_name == "fire" else (
|
| 523 |
+
"💨" if class_name == "smoke" else (
|
| 524 |
+
"🚗" if class_name == "car" else (
|
| 525 |
+
"🚴" if class_name == "bike-bicycle" else "🚚"))))
|
| 526 |
+
|
| 527 |
+
with cols[col_idx % 3]:
|
| 528 |
+
st.metric(f"{emoji} {class_name.capitalize()}", count)
|
| 529 |
+
col_idx += 1
|
| 530 |
+
|
| 531 |
+
# Check for priority threats
|
| 532 |
+
if "fire" in class_counts or "smoke" in class_counts:
|
| 533 |
+
st.error("🚨 **ALERT: Fire Detected!** Potential forest fire detected. Immediate action required.")
|
| 534 |
+
|
| 535 |
+
if "human" in class_counts or "car" in class_counts or "bike-bicycle" in class_counts or "bus-truck" in class_counts:
|
| 536 |
+
st.warning("⚠️ **Trespassers Detected!** Unauthorized entry detected in monitored area.")
|
| 537 |
+
|
| 538 |
+
except Exception as e:
|
| 539 |
+
st.error(f"Error during detection: {e}")
|
| 540 |
+
st.exception(e)
|
| 541 |
+
|
| 542 |
+
elif st.session_state.detection_input_method == 'video':
|
| 543 |
+
# Video upload
|
| 544 |
+
video_file = st.file_uploader("Upload Video", type=["mp4", "avi", "mov"])
|
| 545 |
+
if video_file is not None:
|
| 546 |
+
# Save uploaded video to temp file
|
| 547 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tfile:
|
| 548 |
+
tfile.write(video_file.read())
|
| 549 |
+
temp_video_path = tfile.name
|
| 550 |
+
|
| 551 |
+
# Display video upload success
|
| 552 |
+
st.success("Video uploaded successfully!")
|
| 553 |
+
|
| 554 |
+
# Process video button
|
| 555 |
+
if st.button("Process Video"):
|
| 556 |
+
with st.spinner("Processing video... This may take a while."):
|
| 557 |
+
try:
|
| 558 |
+
# Open video file
|
| 559 |
+
cap = cv2.VideoCapture(temp_video_path)
|
| 560 |
+
|
| 561 |
+
# Create video writer for output
|
| 562 |
+
output_path = "output_video.mp4"
|
| 563 |
+
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
| 564 |
+
fps = cap.get(cv2.CAP_PROP_FPS)
|
| 565 |
+
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
| 566 |
+
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
| 567 |
+
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
|
| 568 |
+
|
| 569 |
+
# Create placeholder for video frames
|
| 570 |
+
video_placeholder = st.empty()
|
| 571 |
+
status_text = st.empty()
|
| 572 |
+
|
| 573 |
+
# Process frames
|
| 574 |
+
frame_count = 0
|
| 575 |
+
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 576 |
+
|
| 577 |
+
while cap.isOpened():
|
| 578 |
+
ret, frame = cap.read()
|
| 579 |
+
if not ret:
|
| 580 |
+
break
|
| 581 |
+
|
| 582 |
+
# Process every 5th frame for speed
|
| 583 |
+
if frame_count % 5 == 0:
|
| 584 |
+
detections = model.detect(frame)
|
| 585 |
+
result_frame = model.draw_detections(frame.copy(), detections)
|
| 586 |
+
|
| 587 |
+
# Update preview
|
| 588 |
+
if frame_count % 15 == 0: # Update display less frequently
|
| 589 |
+
video_placeholder.image(
|
| 590 |
+
cv2.cvtColor(result_frame, cv2.COLOR_BGR2RGB),
|
| 591 |
+
caption="Processing Video",
|
| 592 |
+
use_container_width=True
|
| 593 |
+
)
|
| 594 |
+
progress = min(100, int((frame_count / total_frames) * 100))
|
| 595 |
+
status_text.text(f"Processing: {progress}% complete")
|
| 596 |
+
else:
|
| 597 |
+
result_frame = frame # Skip detection on some frames
|
| 598 |
+
|
| 599 |
+
# Write frame to output video
|
| 600 |
+
out.write(result_frame)
|
| 601 |
+
frame_count += 1
|
| 602 |
+
|
| 603 |
+
# Release resources
|
| 604 |
+
cap.release()
|
| 605 |
+
out.release()
|
| 606 |
+
|
| 607 |
+
# Display completion message
|
| 608 |
+
st.success("Video processing complete!")
|
| 609 |
+
|
| 610 |
+
# Provide download button for processed video
|
| 611 |
+
with open(output_path, "rb") as file:
|
| 612 |
+
st.download_button(
|
| 613 |
+
label="Download Processed Video",
|
| 614 |
+
data=file,
|
| 615 |
+
file_name="forest_surveillance_results.mp4",
|
| 616 |
+
mime="video/mp4"
|
| 617 |
+
)
|
| 618 |
+
|
| 619 |
+
except Exception as e:
|
| 620 |
+
st.error(f"Error processing video: {e}")
|
| 621 |
+
st.exception(e)
|
| 622 |
+
finally:
|
| 623 |
+
# Clean up temp file
|
| 624 |
+
try:
|
| 625 |
+
os.unlink(temp_video_path)
|
| 626 |
+
except:
|
| 627 |
+
pass
|
| 628 |
+
|
| 629 |
+
else: # Camera mode
|
| 630 |
+
# Live camera feed
|
| 631 |
+
st.subheader("Live Camera Detection")
|
| 632 |
+
st.info("Use your webcam to detect objects in real-time")
|
| 633 |
+
|
| 634 |
+
cam = st.camera_input("Camera Feed")
|
| 635 |
+
|
| 636 |
+
if cam:
|
| 637 |
+
# Process camera input
|
| 638 |
+
with st.spinner("Processing image..."):
|
| 639 |
+
try:
|
| 640 |
+
# Convert image
|
| 641 |
+
image = cv2.imdecode(np.frombuffer(cam.getvalue(), np.uint8), cv2.IMREAD_COLOR)
|
| 642 |
+
|
| 643 |
+
# Run detection
|
| 644 |
+
detections = model.detect(image)
|
| 645 |
+
result_image = model.draw_detections(image.copy(), detections)
|
| 646 |
+
|
| 647 |
+
# Display results
|
| 648 |
+
st.image(
|
| 649 |
+
cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB),
|
| 650 |
+
caption="Detection Results",
|
| 651 |
+
use_container_width=True
|
| 652 |
+
)
|
| 653 |
+
|
| 654 |
+
# Show detection summary
|
| 655 |
+
if detections:
|
| 656 |
+
# Count detections by class
|
| 657 |
+
class_counts = {}
|
| 658 |
+
for det in detections:
|
| 659 |
+
class_name = det['class']
|
| 660 |
+
if class_name in class_counts:
|
| 661 |
+
class_counts[class_name] += 1
|
| 662 |
+
else:
|
| 663 |
+
class_counts[class_name] = 1
|
| 664 |
+
|
| 665 |
+
# Display as metrics
|
| 666 |
+
st.subheader("Detection Summary")
|
| 667 |
+
cols = st.columns(3)
|
| 668 |
+
for i, (class_name, count) in enumerate(class_counts.items()):
|
| 669 |
+
with cols[i % 3]:
|
| 670 |
+
st.metric(class_name.capitalize(), count)
|
| 671 |
+
|
| 672 |
+
# Check for priority threats
|
| 673 |
+
if "fire" in class_counts or "smoke" in class_counts:
|
| 674 |
+
st.error("🚨 **ALERT: Fire Detected!** Potential forest fire detected.")
|
| 675 |
+
|
| 676 |
+
if "human" in class_counts:
|
| 677 |
+
st.warning("⚠️ **Trespasser Detected!** Human presence detected.")
|
| 678 |
+
else:
|
| 679 |
+
st.info("No objects detected in frame")
|
| 680 |
+
|
| 681 |
+
except Exception as e:
|
| 682 |
+
st.error(f"Error processing camera feed: {e}")
|
| 683 |
+
|
| 684 |
+
# Main function
|
| 685 |
+
def main():
|
| 686 |
+
# Check which service is selected and render appropriate UI
|
| 687 |
+
if st.session_state.current_service == 'deforestation':
|
| 688 |
+
show_deforestation_detection()
|
| 689 |
+
elif st.session_state.current_service == 'audio':
|
| 690 |
+
show_audio_classification()
|
| 691 |
+
else: # 'detection'
|
| 692 |
+
show_object_detection()
|
| 693 |
+
|
| 694 |
+
# Footer
|
| 695 |
+
st.markdown("---")
|
| 696 |
+
st.markdown("""
|
| 697 |
+
<div style="text-align: center; padding: 10px;">
|
| 698 |
+
<p>Nature Nexus - Forest Surveillance System | 🌳 Protect Natural Ecosystems</p>
|
| 699 |
+
<p><small>Built with Streamlit and PyTorch</small></p>
|
| 700 |
+
</div>
|
| 701 |
+
""", unsafe_allow_html=True)
|
| 702 |
+
|
| 703 |
+
if __name__ == "__main__":
|
| 704 |
+
main()
|
predict.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import cv2
|
| 4 |
+
import onnxruntime as ort
|
| 5 |
+
from utils.preprocess import preprocess_image
|
| 6 |
+
from utils.model import load_model
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class PredictionEngine:
|
| 10 |
+
def __init__(self, model_path=None, use_onnx=True, input_size=256):
|
| 11 |
+
"""
|
| 12 |
+
Initialize the prediction engine
|
| 13 |
+
|
| 14 |
+
Args:
|
| 15 |
+
model_path: Path to the model file (PyTorch or ONNX)
|
| 16 |
+
use_onnx: Whether to use ONNX runtime for inference
|
| 17 |
+
input_size: Input size for the model (default is 256)
|
| 18 |
+
"""
|
| 19 |
+
self.use_onnx = use_onnx
|
| 20 |
+
self.input_size = input_size
|
| 21 |
+
|
| 22 |
+
if model_path:
|
| 23 |
+
if use_onnx:
|
| 24 |
+
self.model = self._load_onnx_model(model_path)
|
| 25 |
+
else:
|
| 26 |
+
self.model = load_model(model_path)
|
| 27 |
+
else:
|
| 28 |
+
self.model = None
|
| 29 |
+
|
| 30 |
+
def _load_onnx_model(self, model_path):
|
| 31 |
+
"""
|
| 32 |
+
Load an ONNX model
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
model_path: Path to the ONNX model
|
| 36 |
+
|
| 37 |
+
Returns:
|
| 38 |
+
ONNX Runtime InferenceSession
|
| 39 |
+
"""
|
| 40 |
+
# Try with CUDA first, fall back to CPU if needed
|
| 41 |
+
try:
|
| 42 |
+
session = ort.InferenceSession(
|
| 43 |
+
model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
|
| 44 |
+
)
|
| 45 |
+
print("ONNX model loaded with CUDA support")
|
| 46 |
+
return session
|
| 47 |
+
except Exception as e:
|
| 48 |
+
print(f"Could not load ONNX model with CUDA, falling back to CPU: {e}")
|
| 49 |
+
session = ort.InferenceSession(
|
| 50 |
+
model_path, providers=["CPUExecutionProvider"]
|
| 51 |
+
)
|
| 52 |
+
print("ONNX model loaded with CPU support")
|
| 53 |
+
return session
|
| 54 |
+
|
| 55 |
+
def preprocess(self, image):
|
| 56 |
+
"""
|
| 57 |
+
Preprocess an image for prediction
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
image: Input image (numpy array)
|
| 61 |
+
|
| 62 |
+
Returns:
|
| 63 |
+
Processed image suitable for the model
|
| 64 |
+
"""
|
| 65 |
+
# Keep the original image for reference
|
| 66 |
+
self.original_shape = image.shape[:2]
|
| 67 |
+
|
| 68 |
+
# Preprocess image
|
| 69 |
+
if self.use_onnx:
|
| 70 |
+
# For ONNX, we need to ensure the input is exactly the expected size
|
| 71 |
+
tensor = preprocess_image(image, img_size=self.input_size)
|
| 72 |
+
return tensor.numpy()
|
| 73 |
+
else:
|
| 74 |
+
# For PyTorch
|
| 75 |
+
return preprocess_image(image, img_size=self.input_size)
|
| 76 |
+
|
| 77 |
+
def predict(self, image):
|
| 78 |
+
"""
|
| 79 |
+
Make a prediction on an image
|
| 80 |
+
|
| 81 |
+
Args:
|
| 82 |
+
image: Input image (numpy array)
|
| 83 |
+
|
| 84 |
+
Returns:
|
| 85 |
+
Predicted mask
|
| 86 |
+
"""
|
| 87 |
+
if self.model is None:
|
| 88 |
+
raise ValueError("Model not loaded. Initialize with a valid model path.")
|
| 89 |
+
|
| 90 |
+
# Preprocess the image
|
| 91 |
+
processed_input = self.preprocess(image)
|
| 92 |
+
|
| 93 |
+
# Run inference
|
| 94 |
+
if self.use_onnx:
|
| 95 |
+
# Get input and output names
|
| 96 |
+
input_name = self.model.get_inputs()[0].name
|
| 97 |
+
output_name = self.model.get_outputs()[0].name
|
| 98 |
+
|
| 99 |
+
# Run ONNX inference
|
| 100 |
+
outputs = self.model.run([output_name], {input_name: processed_input})
|
| 101 |
+
|
| 102 |
+
# Apply sigmoid to output
|
| 103 |
+
mask = 1 / (1 + np.exp(-outputs[0].squeeze()))
|
| 104 |
+
else:
|
| 105 |
+
# PyTorch inference
|
| 106 |
+
with torch.no_grad():
|
| 107 |
+
# Move to device
|
| 108 |
+
device = next(self.model.parameters()).device
|
| 109 |
+
processed_input = processed_input.to(device)
|
| 110 |
+
|
| 111 |
+
# Forward pass
|
| 112 |
+
output = self.model(processed_input)
|
| 113 |
+
output = torch.sigmoid(output)
|
| 114 |
+
|
| 115 |
+
# Convert to numpy
|
| 116 |
+
mask = output.cpu().numpy().squeeze()
|
| 117 |
+
|
| 118 |
+
return mask
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def load_pytorch_model(model_path):
|
| 122 |
+
"""
|
| 123 |
+
Load the PyTorch model for prediction
|
| 124 |
+
|
| 125 |
+
Args:
|
| 126 |
+
model_path: Path to the PyTorch model
|
| 127 |
+
|
| 128 |
+
Returns:
|
| 129 |
+
PredictionEngine instance
|
| 130 |
+
"""
|
| 131 |
+
return PredictionEngine(model_path, use_onnx=False)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def load_onnx_model(model_path, input_size=256):
|
| 135 |
+
"""
|
| 136 |
+
Load the ONNX model for prediction
|
| 137 |
+
|
| 138 |
+
Args:
|
| 139 |
+
model_path: Path to the ONNX model
|
| 140 |
+
input_size: Input size for the model
|
| 141 |
+
|
| 142 |
+
Returns:
|
| 143 |
+
PredictionEngine instance
|
| 144 |
+
"""
|
| 145 |
+
return PredictionEngine(model_path, use_onnx=True, input_size=input_size)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
# For backwards compatibility
|
| 149 |
+
def predict(model, image):
|
| 150 |
+
"""
|
| 151 |
+
Legacy function for prediction
|
| 152 |
+
|
| 153 |
+
Args:
|
| 154 |
+
model: Model instance
|
| 155 |
+
image: Input image
|
| 156 |
+
|
| 157 |
+
Returns:
|
| 158 |
+
Predicted mask
|
| 159 |
+
"""
|
| 160 |
+
if isinstance(model, PredictionEngine):
|
| 161 |
+
return model.predict(image)
|
| 162 |
+
|
| 163 |
+
engine = PredictionEngine(use_onnx=True)
|
| 164 |
+
engine.model = model
|
| 165 |
+
return engine.predict(image)
|
prediction_engine.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import cv2
|
| 4 |
+
import onnxruntime as ort
|
| 5 |
+
from utils.preprocess import preprocess_image
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class PredictionEngine:
|
| 9 |
+
def __init__(self, model_path=None, use_onnx=True, input_size=256):
|
| 10 |
+
"""
|
| 11 |
+
Initialize the prediction engine
|
| 12 |
+
|
| 13 |
+
Args:
|
| 14 |
+
model_path: Path to the model file (PyTorch or ONNX)
|
| 15 |
+
use_onnx: Whether to use ONNX runtime for inference
|
| 16 |
+
input_size: Input size for the model (default is 256)
|
| 17 |
+
"""
|
| 18 |
+
self.use_onnx = use_onnx
|
| 19 |
+
self.input_size = input_size
|
| 20 |
+
|
| 21 |
+
if model_path:
|
| 22 |
+
if use_onnx:
|
| 23 |
+
self.model = self._load_onnx_model(model_path)
|
| 24 |
+
else:
|
| 25 |
+
self.model = self._load_pytorch_model(model_path)
|
| 26 |
+
else:
|
| 27 |
+
self.model = None
|
| 28 |
+
|
| 29 |
+
def _load_onnx_model(self, model_path):
|
| 30 |
+
"""
|
| 31 |
+
Load an ONNX model
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
model_path: Path to the ONNX model
|
| 35 |
+
|
| 36 |
+
Returns:
|
| 37 |
+
ONNX Runtime InferenceSession
|
| 38 |
+
"""
|
| 39 |
+
# Try with CUDA first, fall back to CPU if needed
|
| 40 |
+
try:
|
| 41 |
+
session = ort.InferenceSession(
|
| 42 |
+
model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
|
| 43 |
+
)
|
| 44 |
+
print("ONNX model loaded with CUDA support")
|
| 45 |
+
return session
|
| 46 |
+
except Exception as e:
|
| 47 |
+
print(f"Could not load ONNX model with CUDA, falling back to CPU: {e}")
|
| 48 |
+
session = ort.InferenceSession(
|
| 49 |
+
model_path, providers=["CPUExecutionProvider"]
|
| 50 |
+
)
|
| 51 |
+
print("ONNX model loaded with CPU support")
|
| 52 |
+
return session
|
| 53 |
+
|
| 54 |
+
def _load_pytorch_model(self, model_path):
|
| 55 |
+
"""
|
| 56 |
+
Load a PyTorch model
|
| 57 |
+
|
| 58 |
+
Args:
|
| 59 |
+
model_path: Path to the PyTorch model
|
| 60 |
+
|
| 61 |
+
Returns:
|
| 62 |
+
PyTorch model
|
| 63 |
+
"""
|
| 64 |
+
from utils.model import load_model
|
| 65 |
+
return load_model(model_path)
|
| 66 |
+
|
| 67 |
+
def preprocess(self, image):
|
| 68 |
+
"""
|
| 69 |
+
Preprocess an image for prediction
|
| 70 |
+
|
| 71 |
+
Args:
|
| 72 |
+
image: Input image (numpy array)
|
| 73 |
+
|
| 74 |
+
Returns:
|
| 75 |
+
Processed image suitable for the model
|
| 76 |
+
"""
|
| 77 |
+
# Keep the original image for reference
|
| 78 |
+
self.original_shape = image.shape[:2]
|
| 79 |
+
|
| 80 |
+
# Preprocess image
|
| 81 |
+
if self.use_onnx:
|
| 82 |
+
# For ONNX, we need to ensure the input is exactly the expected size
|
| 83 |
+
tensor = preprocess_image(image, img_size=self.input_size)
|
| 84 |
+
return tensor.numpy()
|
| 85 |
+
else:
|
| 86 |
+
# For PyTorch
|
| 87 |
+
return preprocess_image(image, img_size=self.input_size)
|
| 88 |
+
|
| 89 |
+
def predict(self, image):
|
| 90 |
+
"""
|
| 91 |
+
Make a prediction on an image
|
| 92 |
+
|
| 93 |
+
Args:
|
| 94 |
+
image: Input image (numpy array)
|
| 95 |
+
|
| 96 |
+
Returns:
|
| 97 |
+
Predicted mask
|
| 98 |
+
"""
|
| 99 |
+
if self.model is None:
|
| 100 |
+
raise ValueError("Model not loaded. Initialize with a valid model path.")
|
| 101 |
+
|
| 102 |
+
# Preprocess the image
|
| 103 |
+
processed_input = self.preprocess(image)
|
| 104 |
+
|
| 105 |
+
# Run inference
|
| 106 |
+
if self.use_onnx:
|
| 107 |
+
# Get input and output names
|
| 108 |
+
input_name = self.model.get_inputs()[0].name
|
| 109 |
+
output_name = self.model.get_outputs()[0].name
|
| 110 |
+
|
| 111 |
+
# Run ONNX inference
|
| 112 |
+
outputs = self.model.run([output_name], {input_name: processed_input})
|
| 113 |
+
|
| 114 |
+
# Apply sigmoid to output
|
| 115 |
+
mask = 1 / (1 + np.exp(-outputs[0].squeeze()))
|
| 116 |
+
else:
|
| 117 |
+
# PyTorch inference
|
| 118 |
+
with torch.no_grad():
|
| 119 |
+
# Move to device
|
| 120 |
+
device = next(self.model.parameters()).device
|
| 121 |
+
processed_input = processed_input.to(device)
|
| 122 |
+
|
| 123 |
+
# Forward pass
|
| 124 |
+
output = self.model(processed_input)
|
| 125 |
+
output = torch.sigmoid(output)
|
| 126 |
+
|
| 127 |
+
# Convert to numpy
|
| 128 |
+
mask = output.cpu().numpy().squeeze()
|
| 129 |
+
|
| 130 |
+
return mask
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def load_pytorch_model(model_path):
|
| 134 |
+
"""
|
| 135 |
+
Load the PyTorch model for prediction
|
| 136 |
+
|
| 137 |
+
Args:
|
| 138 |
+
model_path: Path to the PyTorch model
|
| 139 |
+
|
| 140 |
+
Returns:
|
| 141 |
+
PredictionEngine instance
|
| 142 |
+
"""
|
| 143 |
+
return PredictionEngine(model_path, use_onnx=False)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def load_onnx_model(model_path, input_size=256):
|
| 147 |
+
"""
|
| 148 |
+
Load the ONNX model for prediction
|
| 149 |
+
|
| 150 |
+
Args:
|
| 151 |
+
model_path: Path to the ONNX model
|
| 152 |
+
input_size: Input size for the model
|
| 153 |
+
|
| 154 |
+
Returns:
|
| 155 |
+
PredictionEngine instance
|
| 156 |
+
"""
|
| 157 |
+
return PredictionEngine(model_path, use_onnx=True, input_size=input_size)
|
requirements.txt
CHANGED
|
@@ -1,3 +1,13 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
streamlit
|
| 2 |
+
torch==2.1.2
|
| 3 |
+
torchvision==0.16.2
|
| 4 |
+
opencv-python-headless
|
| 5 |
+
albumentations
|
| 6 |
+
numpy
|
| 7 |
+
Pillow
|
| 8 |
+
scikit-image
|
| 9 |
+
scikit-learn
|
| 10 |
+
matplotlib
|
| 11 |
+
onnxruntime
|
| 12 |
+
onnx
|
| 13 |
+
supervision
|