div828 commited on
Commit
6b9715c
·
verified ·
1 Parent(s): c4e6d01

Upload 8 files

Browse files
Files changed (8) hide show
  1. .gitattributes +1 -34
  2. .gitignore +2 -0
  3. LICENSE +21 -0
  4. Readme.md +180 -0
  5. app.py +704 -0
  6. predict.py +165 -0
  7. prediction_engine.py +157 -0
  8. requirements.txt +13 -3
.gitattributes CHANGED
@@ -1,35 +1,2 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
  *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  *.onnx filter=lfs diff=lfs merge=lfs -text
2
+ **.onnx filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ **/__pycache__
2
+
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 smokeyScraper
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
Readme.md ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Nature Nexus - Forest Surveillance System
2
+
3
+ ## Overview
4
+
5
+ Nature Nexus is an advanced forest surveillance system designed to protect natural ecosystems through AI-powered monitoring. It combines multiple detection technologies to identify illegal activities, monitor deforestation, and detect potential threats to forest areas.
6
+
7
+ The application leverages:
8
+ - **Satellite Imagery Analysis** - Detects deforestation using segmentation models
9
+ - **Audio Surveillance** - Identifies unusual sounds like chainsaws, vehicles, and human activity
10
+ - **Object Detection** - Recognizes trespassers, vehicles, fires, and other threats
11
+
12
+ ## Features
13
+
14
+ ### 1. Deforestation Detection
15
+ - Analyzes satellite or aerial imagery to identify deforested areas
16
+ - Uses Attention U-Net segmentation model optimized with ONNX runtime
17
+ - Provides detailed metrics on forest coverage and deforestation levels
18
+ - Visualizes results with color-coded overlays
19
+
20
+ ![Deforestation detection example](imgs/deforestation.png)
21
+
22
+ ### 2. Forest Audio Surveillance
23
+ - Detects unusual sounds that may indicate illegal activities
24
+ - Classifies various sounds including:
25
+ - **Human Sounds**: Footsteps, coughing, laughing, breathing, etc.
26
+ - **Tool Sounds**: Chainsaw, hand saw
27
+ - **Vehicle Sounds**: Car horn, engine, siren
28
+ - **Other Sounds**: Crackling fire, fireworks
29
+ - Supports both uploaded audio files and real-time recording
30
+
31
+ ![Forest audio surveillance example](imgs/audio.png)
32
+
33
+ ### 3. Object Detection
34
+ - Identifies potential threats using YOLOv11 model
35
+ - Detects objects including:
36
+ - Humans (trespassers)
37
+ - Vehicles (cars, bikes, buses/trucks)
38
+ - Fire and smoke
39
+ - Processes images, videos, and camera feeds
40
+ - Alerts on potential threats with confidence scores
41
+
42
+ ![Object detection example](imgs/yolo.png)
43
+
44
+ ## Getting Started
45
+
46
+ ### Prerequisites
47
+
48
+ - Python 3.8+
49
+ - pip package manager
50
+ - Virtual environment (recommended)
51
+
52
+ ### Installation
53
+
54
+ 1. Clone the repository
55
+ ```bash
56
+ git clone https://github.com/yourusername/nature-nexus.git
57
+ cd nature-nexus
58
+ ```
59
+
60
+ 2. Create and activate a virtual environment (optional but recommended)
61
+ ```bash
62
+ python -m venv venv
63
+ source venv/bin/activate # On Windows, use: venv\Scripts\activate
64
+ ```
65
+
66
+ 3. Install required dependencies
67
+ ```bash
68
+ pip install -r requirements.txt
69
+ ```
70
+
71
+ 4. Download models
72
+ ```bash
73
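+ # Create the models directory if it doesn't exist, then place deforestation_model.onnx, best_model.pth (audio) and best_model.onnx (YOLO) inside it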
74
+ mkdir -p models
75
+
76
+ ```
77
+
78
+ ### Running the Application
79
+
80
+ Launch the Streamlit application:
81
+ ```bash
82
+ streamlit run app.py
83
+ ```
84
+
85
+ The application will open in your default web browser at http://localhost:8501
86
+
87
+ ## Model Architecture
88
+
89
+ ### Deforestation Detection Model
90
+ - **Architecture**: Attention U-Net
91
+ - **Input**: Satellite/aerial imagery (RGB)
92
+ - **Output**: Binary segmentation mask (forest vs. deforested)
93
+ - **Optimization**: ONNX runtime for faster inference
94
+
95
+ ### Audio Classification Model
96
+ - **Architecture**: Convolutional Neural Network (CNN)
97
+ - **Input**: Audio spectrograms
98
+ - **Output**: 14 sound classes with confidence scores
99
+ - **Features**: Mel-spectrogram analysis
100
+
101
+ ### Object Detection Model
102
+ - **Architecture**: YOLOv11
103
+ - **Input**: Images/video frames
104
+ - **Output**: Bounding boxes, class labels, confidence scores
105
+ - **Classes**: Humans, vehicles, fire, smoke, etc.
106
+
107
+ ## System Architecture
108
+
109
+ ```
110
+ nature-nexus/
111
+
112
+ ├── app.py # Main Streamlit application
113
+ ├── prediction_engine.py # Deforestation model interface
114
+
115
+ ├── utils/
116
+ │ ├── audio_model.py # Audio classification model
117
+ │ ├── audio_processing.py # Audio preprocessing utilities
118
+ │ ├── helpers.py # Helper functions for visualization
119
+ │ ├── model.py # U-Net model definition
120
+ │ ├── onnx_converter.py # Converts PyTorch models to ONNX
121
+ │ ├── onnx_inference.py # YOLO object detection inference
122
+ │ └── preprocess.py # Image preprocessing utilities
123
+
124
+ └── models/ # Model weights (not included in repo)
125
+ ├── deforestation_model.onnx
126
+ ├── best_model.pth # Audio model
127
+ └── best_model.onnx # YOLO model
128
+ ```
129
+
130
+ ## Usage Guide
131
+
132
+ ### Deforestation Detection
133
+ 1. Select "Deforestation Detection" from the sidebar
134
+ 2. Upload satellite or aerial imagery of forest areas
135
+ 3. View segmentation results showing forest vs. deforested areas
136
+ 4. Analyze metrics including forest coverage and deforestation level
137
+
138
+ ### Audio Surveillance
139
+ 1. Select "Forest Audio Surveillance" from the sidebar
140
+ 2. Choose between uploading audio files or recording live audio
141
+ 3. Submit the audio for analysis
142
+ 4. View detected sound classification and potential alerts
143
+
144
+ ### Object Detection
145
+ 1. Select "Object Detection" from the sidebar
146
+ 2. Choose between image, video, or camera feed
147
+ 3. Adjust confidence and IoU thresholds as needed
148
+ 4. Upload or capture input for processing
149
+ 5. View detection results with bounding boxes and confidence scores
150
+
151
+ ## Custom Model Training
152
+
153
+ To train custom models for your specific forest environment:
154
+
155
+ ### Deforestation Model
156
+ ```bash
157
+ # Convert trained PyTorch model to ONNX
158
+ python -m utils.onnx_converter models/your_pytorch_model.pth models/deforestation_model.onnx [input_size]
159
+ ```
160
+
161
+ ### Audio Model
162
+ Train on your custom audio dataset and replace the model file at `models/best_model.pth`
163
+
164
+ ### YOLO Model
165
+ Train on your custom object dataset and replace the model file at `models/best_model.onnx`
166
+
167
+ ## Troubleshooting
168
+
169
+ ### Common Issues
170
+ - **Models not loading**: Ensure all model files exist in the `models/` directory
171
+ - **CUDA errors**: If using GPU, verify CUDA and cuDNN are correctly installed
172
+ - **Audio processing issues**: Check audio format compatibility (WAV, MP3, OGG)
173
+
174
+ ## Contributing
175
+
176
+ Contributions are welcome! Please feel free to submit a Pull Request.
177
+
178
+ ## License
179
+
180
+ This project is licensed under the MIT License - see the LICENSE file for details.
app.py ADDED
@@ -0,0 +1,704 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import cv2
3
+ import numpy as np
4
+ import os
5
+ import tempfile
6
+ import librosa
7
+ import librosa.display
8
+ import matplotlib.pyplot as plt
9
+ import tempfile
10
+ import librosa
11
+ import librosa.display
12
+ import matplotlib.pyplot as plt
13
+ from PIL import Image
14
+ import torch
15
+
16
+ # Import deforestation modules
17
+ from prediction_engine import load_onnx_model
18
+
19
+ # Import deforestation modules
20
+ from prediction_engine import load_onnx_model
21
+ from utils.helpers import calculate_deforestation_metrics, create_overlay
22
+
23
+ # Import audio classification modules
24
+ from utils.audio_processing import preprocess_audio
25
+ from utils.audio_model import load_audio_model, predict_audio, class_names
26
+
27
+ # Import YOLO detection modules
28
+ from utils.onnx_inference import YOLOv11
29
+
30
+ # Ensure torch classes path is initialized to avoid warnings
31
torch.classes.__path__ = []

# Page-level Streamlit configuration (title, icon, layout, sidebar state).
st.set_page_config(
    initial_sidebar_state="expanded",
    layout="wide",
    page_icon="🌳",
    page_title="Nature Nexus - Forest Surveillance",
)


# Model locations and the input resolution passed to the deforestation model.
YOLO_MODEL_PATH = "models/best_model.onnx"
AUDIO_MODEL_PATH = "models/best_model.pth"
DEFOREST_MODEL_INPUT_SIZE = 256

# Seed the navigation state only when a key is not already present, so
# choices made earlier in the session survive Streamlit reruns.
for _state_key, _initial_value in (
    ('current_service', 'deforestation'),
    ('audio_input_method', 'upload'),
    ('detection_input_method', 'image'),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
54
+
55
+ # Sidebar for navigation
56
# Sidebar: service navigation plus the controls that belong to each service.
with st.sidebar:
    st.title("Nature Nexus")
    st.subheader("Forest Surveillance System")

    selected_service = st.radio(
        "Select Service:",
        ["Deforestation Detection", "Forest Audio Surveillance", "Object Detection"]
    )

    # Translate the human-readable label into the internal service id.
    if selected_service == "Deforestation Detection":
        st.session_state.current_service = 'deforestation'
    elif selected_service == "Forest Audio Surveillance":
        st.session_state.current_service = 'audio'
    else:
        st.session_state.current_service = 'detection'

    st.markdown("---")

    # Service-specific sidebar content
    if st.session_state.current_service == 'deforestation':
        st.info(
            """
            **Deforestation Detection**

            Upload satellite or aerial images to detect areas of deforestation.
            """
        )
    elif st.session_state.current_service == 'audio':
        st.info(
            """
            **Forest Audio Surveillance**

            Detect unusual human-related sounds in forested regions.
            """
        )

        # Audio service specific controls
        st.subheader("Audio Configuration")
        audio_input_method = st.radio(
            "Select Input Method:",
            ("Upload Audio", "Record Audio"),
            index=0 if st.session_state.audio_input_method == 'upload' else 1
        )
        st.session_state.audio_input_method = 'upload' if audio_input_method == "Upload Audio" else 'record'

        # Audio class information
        st.markdown("**Detection Classes:**")

        # Group classes by category
        human_sounds = ['footsteps', 'coughing', 'laughing', 'breathing',
                        'drinking_sipping', 'snoring', 'sneezing']
        tool_sounds = ['chainsaw', 'hand_saw']
        vehicle_sounds = ['car_horn', 'engine', 'siren']
        other_sounds = ['crackling_fire', 'fireworks']

        st.markdown("👤 **Human Sounds:** " + ", ".join([s.capitalize() for s in human_sounds]))
        st.markdown("🔨 **Tool Sounds:** " + ", ".join([s.capitalize() for s in tool_sounds]))
        st.markdown("🚗 **Vehicle Sounds:** " + ", ".join([s.capitalize() for s in vehicle_sounds]))
        st.markdown("💥 **Other Sounds:** " + ", ".join([s.capitalize() for s in other_sounds]))
    else:  # Object Detection
        st.info(
            """
            **Object Detection**

            Detect trespassers, vehicles, fires, and other objects in forest surveillance footage.
            """
        )

        # Detection service specific controls
        st.subheader("Detection Configuration")
        detection_input_method = st.radio(
            "Select Input Method:",
            ("Image", "Video", "Camera"),
            index=0 if st.session_state.detection_input_method == 'image' else
            (1 if st.session_state.detection_input_method == 'video' else 2)
        )

        if detection_input_method == "Image":
            st.session_state.detection_input_method = 'image'
        elif detection_input_method == "Video":
            st.session_state.detection_input_method = 'video'
        else:
            st.session_state.detection_input_method = 'camera'

        # Detection threshold controls.
        # The explicit keys publish the slider values in st.session_state under
        # 'confidence' and 'iou_thres', which is where the object detection
        # page looks them up. Without keys those lookups always fell back to
        # their 0.5 defaults and the sliders had no effect.
        st.subheader("Detection Settings")
        confidence = st.slider("Confidence Threshold", 0.0, 1.0, 0.5, key="confidence")
        iou_thres = st.slider("IoU Threshold", 0.0, 1.0, 0.5, key="iou_thres")

        # Detection class information
        st.markdown("**Detection Classes:**")
        st.markdown("🚴 **Bike/Bicycle**")
        st.markdown("🚚 **Bus/Truck**")
        st.markdown("🚗 **Car**")
        st.markdown("🔥 **Fire**")
        st.markdown("👤 **Human**")
        st.markdown("💨 **Smoke**")
153
+
154
+ # Load deforestation model
155
@st.cache_resource
def load_cached_deforestation_model():
    """Return the deforestation ONNX model, cached by Streamlit as a resource.

    The model is loaded from ``models/deforestation_model.onnx`` with the
    input size given by ``DEFOREST_MODEL_INPUT_SIZE``.
    """
    return load_onnx_model(
        "models/deforestation_model.onnx",
        input_size=DEFOREST_MODEL_INPUT_SIZE,
    )
159
+
160
+ # Load audio model
161
@st.cache_resource
def load_cached_audio_model():
    """Return the audio classifier loaded from ``AUDIO_MODEL_PATH``.

    Streamlit caches the returned object as a resource.
    """
    audio_model = load_audio_model(AUDIO_MODEL_PATH)
    return audio_model
164
+
165
@st.cache_resource
def load_cached_yolo_model():
    """Return the YOLOv11 detector built from ``YOLO_MODEL_PATH``.

    Streamlit caches the returned object as a resource.
    """
    detector = YOLOv11(YOLO_MODEL_PATH)
    return detector
168
+
169
+ # Process image for deforestation detection
170
def process_image(model, image):
    """Run the deforestation model on one image and build the display artifacts.

    Args:
        model: Object exposing ``predict(image)`` that returns a mask array.
        image: Image array as decoded by OpenCV; it is converted from BGR to
            RGB before the overlay is drawn.

    Returns:
        A ``(binary_mask, overlay, metrics)`` tuple: the mask resized to the
        input image size and thresholded at 0.5 into 0/255 ``uint8`` values,
        the result of ``create_overlay``, and the result of
        ``calculate_deforestation_metrics`` computed on the un-resized mask.
    """
    height, width = image.shape[:2]

    prediction = model.predict(image)

    # Bring the prediction back to the size of the uploaded image for display.
    resized_prediction = cv2.resize(prediction, (width, height))
    rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    return (
        (resized_prediction > 0.5).astype(np.uint8) * 255,
        create_overlay(rgb_image, resized_prediction),
        calculate_deforestation_metrics(prediction),
    )
191
+
192
+ # Visualize audio for audio classification
193
def visualize_audio(audio_path):
    """Plot the waveform and mel spectrogram of an audio file in the app.

    Args:
        audio_path: Path of the audio file; it is loaded with librosa at a
            16 kHz sample rate.

    Returns:
        A ``(y, sr, duration)`` tuple: the loaded samples, the sample rate and
        the clip length in seconds.
    """
    y, sr = librosa.load(audio_path, sr=16000)
    duration = len(y) / sr

    fig, ax = plt.subplots(2, 1, figsize=(10, 6))
    try:
        # Waveform plot
        librosa.display.waveshow(y, sr=sr, ax=ax[0])
        ax[0].set_title('Audio Waveform')
        ax[0].set_xlabel('Time (s)')
        ax[0].set_ylabel('Amplitude')

        # Spectrogram plot
        S = librosa.feature.melspectrogram(y=y, sr=sr)
        S_db = librosa.power_to_db(S, ref=np.max)
        img = librosa.display.specshow(S_db, sr=sr, x_axis='time', y_axis='mel', ax=ax[1])
        fig.colorbar(img, ax=ax[1], format='%+2.0f dB')
        ax[1].set_title('Mel Spectrogram')

        fig.tight_layout()
        st.pyplot(fig)
    finally:
        # pyplot keeps every figure it creates registered until it is closed;
        # close this one so repeated analyses do not accumulate figures.
        plt.close(fig)

    return y, sr, duration
216
+
217
+ # Process audio for classification
218
def process_audio(audio_file):
    """Classify an uploaded or recorded clip and render the result and alerts.

    The audio is written to a temporary ``.wav``-suffixed file, visualised,
    passed to ``predict_audio`` and the temporary file is removed afterwards.
    Errors raised while analysing are reported in the UI instead of being
    propagated.

    Args:
        audio_file: Either raw bytes or a file-like object exposing
            ``getvalue()`` or ``read()``.
    """
    # Prefer getvalue(): it returns the whole buffer regardless of the current
    # stream position, whereas read() returns nothing once the stream has
    # already been consumed.
    if hasattr(audio_file, 'getvalue'):
        payload = audio_file.getvalue()
    elif hasattr(audio_file, 'read'):
        payload = audio_file.read()
    else:
        payload = audio_file

    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
        tmp_file.write(payload)
        audio_path = tmp_file.name

    try:
        # Load audio model
        audio_model = load_cached_audio_model()

        # Visualize audio
        with st.spinner('Analyzing audio...'):
            y, sr, duration = visualize_audio(audio_path)
            st.caption(f"Audio duration: {duration:.2f} seconds")

        # Make prediction
        with st.spinner('Making prediction...'):
            class_name, confidence = predict_audio(audio_path, audio_model)

        # Display results
        st.subheader("Detection Results")

        col1, col2 = st.columns(2)
        with col1:
            st.metric("Detected Sound", class_name.replace('_', ' ').title())
        with col2:
            st.metric("Confidence", f"{confidence*100:.2f}%")

        # Show alerts based on class
        human_sounds = ['footsteps', 'coughing', 'laughing', 'breathing',
                        'drinking_sipping', 'snoring', 'sneezing']
        tool_sounds = ['chainsaw', 'hand_saw']

        if class_name in human_sounds:
            st.warning("""
                ⚠️ **Human Activity Detected!**
                Potential human presence in the monitored area.
                """)
        elif class_name in tool_sounds:
            st.error("""
                🚨 **ALERT: Human Tool Detected!**
                Potential illegal logging or activity detected. Consider immediate verification.
                """)
        elif class_name in ['car_horn', 'engine', 'siren']:
            st.warning("""
                ⚠️ **Vehicle Detected!**
                Vehicle sounds detected in the monitored area.
                """)
        elif class_name == 'fireworks':
            st.error("""
                🚨 **ALERT: Fireworks Detected!**
                Potential fire hazard and disturbance to wildlife. Immediate verification required.
                """)
        elif class_name == 'crackling_fire':
            st.error("""
                🚨 **ALERT: Fire Detected!**
                Potential wildfire detected. Immediate verification required.
                """)
        else:
            st.success("✅ Environmental sound detected - no immediate threat")

    except Exception as e:
        st.error(f"Error processing audio: {str(e)}")
        st.exception(e)
    finally:
        # Best-effort removal of the temp file. Only filesystem errors are
        # ignored, so KeyboardInterrupt/SystemExit are no longer swallowed.
        try:
            os.unlink(audio_path)
        except OSError:
            pass
287
+
288
+ # Deforestation detection UI
289
def show_deforestation_detection():
    """Render the deforestation page: upload, segmentation views and metrics.

    Loads the cached deforestation model, lets the user upload an image and
    shows the binary mask, the overlay and the metrics returned by
    ``process_image``. Model-loading, decoding and processing problems are
    reported in the UI rather than raised.
    """
    # App title and description
    st.title("🌳 Deforestation Detection")
    st.markdown(
        """
        This service detects areas of deforestation in satellite or aerial images of forests.
        Upload an image to get started!
        """
    )

    # Model info
    st.info(
        f"⚙️ Model optimized for {DEFOREST_MODEL_INPUT_SIZE}x{DEFOREST_MODEL_INPUT_SIZE} pixel images using ONNX runtime"
    )

    # Load model
    try:
        model = load_cached_deforestation_model()
    except Exception as e:
        st.error(f"Error loading model: {e}")
        st.info(
            "Make sure you have converted your PyTorch model to ONNX format using the utils/onnx_converter.py script."
        )
        st.code(
            "python -m utils.onnx_converter models/best_model_100.pth models/deforestation_model.onnx"
        )
        return

    # File uploader for images
    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    # getvalue() returns the full upload regardless of the stream position.
    file_bytes = np.frombuffer(uploaded_file.getvalue(), dtype=np.uint8)
    image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)

    # cv2.imdecode returns None when the bytes are not a decodable image;
    # stop here with a message instead of crashing in cvtColor below.
    if image is None:
        st.error("Could not decode the uploaded image. Please upload a valid JPG or PNG file.")
        return

    # Display original image
    st.subheader("Original Image")
    st.image(
        cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
        caption="Uploaded Image",
        use_container_width=True,
    )

    # Add a spinner while processing
    with st.spinner("Processing..."):
        try:
            # Process image
            binary_mask, overlay, metrics = process_image(model, image)

            # Display results in columns
            col1, col2 = st.columns(2)

            with col1:
                st.subheader("Segmentation Result")
                st.image(
                    binary_mask,
                    caption="Forest Areas (White)",
                    use_container_width=True,
                )

            with col2:
                st.subheader("Overlay Visualization")
                st.image(
                    overlay,
                    caption="Green: Forest, Brown: Deforested",
                    use_container_width=True,
                )

            # Display metrics
            st.subheader("Deforestation Analysis")

            # Create metrics cards
            metrics_col1, metrics_col2, metrics_col3 = st.columns(3)

            with metrics_col1:
                st.metric(
                    label="Forest Coverage",
                    value=f"{metrics['forest_percentage']:.1f}%",
                )

            with metrics_col2:
                st.metric(
                    label="Deforested Area",
                    value=f"{metrics['deforested_percentage']:.1f}%",
                )

            with metrics_col3:
                st.metric(
                    label="Deforestation Level",
                    value=metrics["deforestation_level"],
                )

        except Exception as e:
            st.error(f"Error during processing: {e}")
384
+
385
+ # Audio classification UI
386
def show_audio_classification():
    """Render the audio surveillance page for the selected input method.

    In upload mode the user picks a WAV/MP3/OGG file; otherwise a clip is
    recorded with the browser microphone. In both cases the audio is previewed
    and then handed to ``process_audio``.
    """
    # App title and description
    st.title("🎧 Forest Audio Surveillance")
    supported_sounds = ", ".join(class_names)
    st.markdown(f"""
        Detect unusual human-related sounds in forested regions to prevent illegal activities.
        Supported sounds: {supported_sounds}
        """)

    if st.session_state.audio_input_method != 'upload':
        # Record mode
        st.header("Record Live Audio")

        st.info("""
            Click the microphone button below to record a sound for analysis.
            **Note:** Please ensure your browser has permission to access your microphone.
            When prompted, click "Allow" to enable recording.
            """)

        recorded_audio = st.audio_input(
            label="Record a sound",
            key="audio_recorder",
            help="Click to record forest sounds for analysis",
            label_visibility="visible"
        )

        if not recorded_audio:
            st.write("Waiting for recording...")
            return

        st.success("Audio recorded successfully!")
        with st.expander("Recorded Audio", expanded=True):
            st.audio(recorded_audio)
            process_audio(recorded_audio)
        return

    # Upload mode
    st.header("Upload Audio File")

    tips_column, uploader_column = st.columns(2)
    with tips_column:
        st.info("Upload a WAV, MP3 or OGG file with forest sounds")
        st.markdown("""
            **Tips for best results:**
            - Use audio with minimal background noise
            - Ensure the sound of interest is clear
            - 2-3 second clips work best
            """)

    with uploader_column:
        audio_file = st.file_uploader(
            "Choose an audio file",
            type=["wav", "mp3", "ogg"],
            help="Supported formats: WAV, MP3, OGG"
        )

    if audio_file:
        st.success("File uploaded successfully!")
        with st.expander("Audio Preview", expanded=True):
            st.audio(audio_file)
            process_audio(audio_file)
443
+
444
+ # Object Detection UI
445
# Emoji shown next to each detection class; any other class uses the truck.
_DETECTION_EMOJI = {
    "human": "👤",
    "fire": "🔥",
    "smoke": "💨",
    "car": "🚗",
    "bike-bicycle": "🚴",
}


def _count_detections_by_class(detections):
    """Return a dict mapping each detection's 'class' value to its count."""
    class_counts = {}
    for det in detections:
        class_counts[det['class']] = class_counts.get(det['class'], 0) + 1
    return class_counts


def _detect_in_uploaded_image(model):
    """Handle the image input method: upload, detect, show statistics and alerts."""
    img_file = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"])
    if img_file is None:
        return

    # getvalue() returns the full upload regardless of the stream position.
    file_bytes = np.frombuffer(img_file.getvalue(), dtype=np.uint8)
    image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
    if image is None:
        # Previously an undecodable upload was ignored without any feedback.
        st.error("Could not decode the uploaded image. Please upload a valid JPG or PNG file.")
        return

    # Display original image
    st.subheader("Original Image")
    st.image(
        cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
        caption="Uploaded Image",
        use_container_width=True,
    )

    # Process with detection model
    with st.spinner("Processing image..."):
        try:
            detections = model.detect(image)
            result_image = model.draw_detections(image.copy(), detections)

            # Display results
            st.subheader("Detection Results")
            st.image(
                cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB),
                caption="Detected Objects",
                use_container_width=True,
            )

            # Display detection statistics
            st.subheader("Detection Statistics")
            class_counts = _count_detections_by_class(detections)

            # Display counts with emojis, three metrics per row
            cols = st.columns(3)
            for col_idx, (class_name, count) in enumerate(class_counts.items()):
                emoji = _DETECTION_EMOJI.get(class_name, "🚚")
                with cols[col_idx % 3]:
                    st.metric(f"{emoji} {class_name.capitalize()}", count)

            # Check for priority threats
            if "fire" in class_counts or "smoke" in class_counts:
                st.error("🚨 **ALERT: Fire Detected!** Potential forest fire detected. Immediate action required.")

            if "human" in class_counts or "car" in class_counts or "bike-bicycle" in class_counts or "bus-truck" in class_counts:
                st.warning("⚠️ **Trespassers Detected!** Unauthorized entry detected in monitored area.")

        except Exception as e:
            st.error(f"Error during detection: {e}")
            st.exception(e)


def _process_uploaded_video(model, video_file):
    """Annotate an uploaded video and offer the result for download.

    Both the copy of the upload and the rendered output live in temporary
    files that are removed before returning, whether or not processing
    succeeded.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tfile:
        tfile.write(video_file.getvalue())
        temp_video_path = tfile.name
    # A unique output path keeps concurrent sessions from overwriting each
    # other's result, unlike a fixed file name in the working directory.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as out_file:
        output_path = out_file.name

    try:
        cap = cv2.VideoCapture(temp_video_path)
        out = None
        try:
            # Create video writer for output
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            fps = cap.get(cv2.CAP_PROP_FPS)
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

            # Create placeholder for video frames
            video_placeholder = st.empty()
            status_text = st.empty()

            # Process frames
            frame_count = 0
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break

                # Process every 5th frame for speed
                if frame_count % 5 == 0:
                    detections = model.detect(frame)
                    result_frame = model.draw_detections(frame.copy(), detections)

                    # Update preview
                    if frame_count % 15 == 0:  # Update display less frequently
                        video_placeholder.image(
                            cv2.cvtColor(result_frame, cv2.COLOR_BGR2RGB),
                            caption="Processing Video",
                            use_container_width=True
                        )
                        if total_frames > 0:
                            progress = min(100, int((frame_count / total_frames) * 100))
                            status_text.text(f"Processing: {progress}% complete")
                        else:
                            # Frame count unavailable: avoid dividing by zero.
                            status_text.text(f"Processing: {frame_count} frames done")
                else:
                    result_frame = frame  # Skip detection on some frames

                # Write frame to output video
                out.write(result_frame)
                frame_count += 1
        finally:
            # Release even on failure; the writer must be released before the
            # output file is read back below.
            cap.release()
            if out is not None:
                out.release()

        # Display completion message
        st.success("Video processing complete!")

        # Read the result into memory so the temp file can be deleted while
        # the download button keeps working.
        with open(output_path, "rb") as file:
            processed_video = file.read()
        st.download_button(
            label="Download Processed Video",
            data=processed_video,
            file_name="forest_surveillance_results.mp4",
            mime="video/mp4"
        )

    except Exception as e:
        st.error(f"Error processing video: {e}")
        st.exception(e)
    finally:
        # Best-effort cleanup of both temp files.
        for path in (temp_video_path, output_path):
            try:
                os.unlink(path)
            except OSError:
                pass


def _detect_in_camera_snapshot(model):
    """Handle the camera input method: capture, detect, summarise and alert."""
    st.subheader("Live Camera Detection")
    st.info("Use your webcam to detect objects in real-time")

    cam = st.camera_input("Camera Feed")
    if not cam:
        return

    # Process camera input
    with st.spinner("Processing image..."):
        try:
            # Convert image
            image = cv2.imdecode(np.frombuffer(cam.getvalue(), np.uint8), cv2.IMREAD_COLOR)
            if image is None:
                st.error("Error processing camera feed: could not decode the captured image")
                return

            # Run detection
            detections = model.detect(image)
            result_image = model.draw_detections(image.copy(), detections)

            # Display results
            st.image(
                cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB),
                caption="Detection Results",
                use_container_width=True
            )

            # Show detection summary
            if not detections:
                st.info("No objects detected in frame")
                return

            class_counts = _count_detections_by_class(detections)

            # Display as metrics
            st.subheader("Detection Summary")
            cols = st.columns(3)
            for i, (class_name, count) in enumerate(class_counts.items()):
                with cols[i % 3]:
                    st.metric(class_name.capitalize(), count)

            # Check for priority threats
            if "fire" in class_counts or "smoke" in class_counts:
                st.error("🚨 **ALERT: Fire Detected!** Potential forest fire detected.")

            if "human" in class_counts:
                st.warning("⚠️ **Trespasser Detected!** Human presence detected.")

        except Exception as e:
            st.error(f"Error processing camera feed: {e}")


def show_object_detection():
    """Render the object detection page for the selected input method.

    Loads the cached YOLO model, applies the confidence and IoU thresholds
    found in ``st.session_state`` (0.5 each when absent) and dispatches to the
    image, video or camera workflow.
    """
    # App title and description
    st.title("🔍 Forest Object Detection")
    st.markdown(
        """
        Detect trespassers, vehicles, fires, and other objects in forest surveillance footage.
        Choose an input method to begin detection.
        """
    )

    # Model info
    st.info("⚙️ Object detection model optimized with ONNX runtime for faster inference")

    # Load model
    try:
        model = load_cached_yolo_model()
        # Update model confidence and IoU thresholds from session state
        confidence = st.session_state.get('confidence', 0.5)
        iou_thres = st.session_state.get('iou_thres', 0.5)
        model.conf_thres = confidence
        model.iou_thres = iou_thres
    except Exception as e:
        st.error(f"Error loading model: {e}")
        st.info(
            "Make sure you have the YOLO ONNX model file available at models/best_model.onnx"
        )
        return

    # Input method based selection
    if st.session_state.detection_input_method == 'image':
        _detect_in_uploaded_image(model)

    elif st.session_state.detection_input_method == 'video':
        video_file = st.file_uploader("Upload Video", type=["mp4", "avi", "mov"])
        if video_file is not None:
            # Display video upload success
            st.success("Video uploaded successfully!")

            # The upload is copied to disk only once processing is requested,
            # so ordinary reruns no longer leave temp files behind.
            if st.button("Process Video"):
                with st.spinner("Processing video... This may take a while."):
                    _process_uploaded_video(model, video_file)

    else:  # Camera mode
        _detect_in_camera_snapshot(model)
683
+
684
+ # Main function
685
def main():
    """Render the page of the currently selected service, then the footer."""
    # 'deforestation' and 'audio' have dedicated pages; any other value is
    # treated as the object detection service.
    pages = {
        'deforestation': show_deforestation_detection,
        'audio': show_audio_classification,
    }
    render_page = pages.get(st.session_state.current_service, show_object_detection)
    render_page()

    # Footer
    st.markdown("---")
    footer_html = """
        <div style="text-align: center; padding: 10px;">
        <p>Nature Nexus - Forest Surveillance System | 🌳 Protect Natural Ecosystems</p>
        <p><small>Built with Streamlit and PyTorch</small></p>
        </div>
        """
    st.markdown(footer_html, unsafe_allow_html=True)


if __name__ == "__main__":
    main()
predict.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import cv2
4
+ import onnxruntime as ort
5
+ from utils.preprocess import preprocess_image
6
+ from utils.model import load_model
7
+
8
+
9
class PredictionEngine:
    """Run mask prediction through either ONNX Runtime or a PyTorch model.

    The backend is fixed at construction time by ``use_onnx``. Images are
    preprocessed with ``preprocess_image`` and the raw model output is passed
    through a sigmoid before being returned.
    """

    def __init__(self, model_path=None, use_onnx=True, input_size=256):
        """
        Initialize the prediction engine

        Args:
            model_path: Path to the model file (PyTorch or ONNX). When falsy,
                no model is loaded and ``self.model`` is left as ``None``.
            use_onnx: Whether to use ONNX runtime for inference
            input_size: Input size for the model (default is 256)
        """
        self.use_onnx = use_onnx
        self.input_size = input_size
        # Height/width of the most recently preprocessed image; defined up
        # front so the attribute exists before the first preprocess() call.
        self.original_shape = None

        if model_path:
            if use_onnx:
                self.model = self._load_onnx_model(model_path)
            else:
                self.model = load_model(model_path)
        else:
            self.model = None

    def _load_onnx_model(self, model_path):
        """
        Load an ONNX model

        Args:
            model_path: Path to the ONNX model

        Returns:
            ONNX Runtime InferenceSession
        """
        # Try with CUDA first, fall back to CPU if needed
        try:
            session = ort.InferenceSession(
                model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
            )
        except Exception as e:
            # Deliberately broad: any failure of the CUDA attempt triggers the
            # CPU retry; a failure of the retry itself propagates to the caller.
            print(f"Could not load ONNX model with CUDA, falling back to CPU: {e}")
            session = ort.InferenceSession(
                model_path, providers=["CPUExecutionProvider"]
            )

        # Requesting CUDA does not guarantee it is used: report the provider
        # the session actually ended up with instead of the one asked for.
        if "CUDAExecutionProvider" in session.get_providers():
            print("ONNX model loaded with CUDA support")
        else:
            print("ONNX model loaded with CPU support")
        return session

    @staticmethod
    def _sigmoid(logits):
        """
        Numerically stable element-wise sigmoid for NumPy arrays

        Only ever exponentiates non-positive values, so large-magnitude
        logits cannot overflow ``np.exp``.

        Args:
            logits: NumPy array of raw model outputs

        Returns:
            Array of the same shape with values in [0, 1]
        """
        decay = np.exp(-np.abs(logits))
        return np.where(logits >= 0, 1.0 / (1.0 + decay), decay / (1.0 + decay))

    def preprocess(self, image):
        """
        Preprocess an image for prediction

        Also records ``image.shape[:2]`` in ``self.original_shape``.

        Args:
            image: Input image (numpy array)

        Returns:
            Processed image suitable for the model: a NumPy array for the
            ONNX backend, otherwise the object returned by ``preprocess_image``
            (presumably a torch tensor -- confirm against utils.preprocess)
        """
        # Keep the original image size for reference
        self.original_shape = image.shape[:2]

        tensor = preprocess_image(image, img_size=self.input_size)
        if self.use_onnx:
            # ONNX Runtime is fed NumPy arrays rather than tensors
            return tensor.numpy()
        return tensor

    def predict(self, image):
        """
        Make a prediction on an image

        Args:
            image: Input image (numpy array)

        Returns:
            Predicted mask: sigmoid of the model output as a NumPy array with
            singleton dimensions squeezed out

        Raises:
            ValueError: If no model has been loaded
        """
        if self.model is None:
            raise ValueError("Model not loaded. Initialize with a valid model path.")

        # Preprocess the image
        processed_input = self.preprocess(image)

        if self.use_onnx:
            # Only the first declared input and output of the session are used
            input_name = self.model.get_inputs()[0].name
            output_name = self.model.get_outputs()[0].name

            # Run ONNX inference
            outputs = self.model.run([output_name], {input_name: processed_input})

            # Apply sigmoid to output
            return self._sigmoid(outputs[0].squeeze())

        # PyTorch inference
        with torch.no_grad():
            # Run on whichever device the model's parameters live on
            device = next(self.model.parameters()).device
            output = torch.sigmoid(self.model(processed_input.to(device)))

        # Convert to numpy
        return output.cpu().numpy().squeeze()
119
+
120
+
121
def load_pytorch_model(model_path, input_size=256):
    """
    Load the PyTorch model for prediction

    Args:
        model_path: Path to the PyTorch model
        input_size: Input size for the model (default is 256, matching
            ``load_onnx_model``)

    Returns:
        PredictionEngine instance
    """
    return PredictionEngine(model_path, use_onnx=False, input_size=input_size)
132
+
133
+
134
def load_onnx_model(model_path, input_size=256):
    """
    Build a prediction engine backed by ONNX Runtime

    Args:
        model_path: Path to the ONNX model
        input_size: Input size for the model

    Returns:
        PredictionEngine instance
    """
    engine = PredictionEngine(model_path, use_onnx=True, input_size=input_size)
    return engine
146
+
147
+
148
+ # For backwards compatibility
149
def predict(model, image):
    """
    Legacy function for prediction

    Accepts either a ready ``PredictionEngine`` or a raw model object. Raw
    models are wrapped in a temporary engine with the default input size.

    Args:
        model: PredictionEngine, PyTorch module, or ONNX Runtime session
        image: Input image

    Returns:
        Predicted mask
    """
    if isinstance(model, PredictionEngine):
        return model.predict(image)

    # A raw torch module has no get_inputs()/run(), so it must go through the
    # PyTorch path; anything else is treated as an ONNX Runtime session.
    is_torch_module = isinstance(model, torch.nn.Module)
    engine = PredictionEngine(use_onnx=not is_torch_module)
    engine.model = model
    return engine.predict(image)
prediction_engine.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import cv2
4
+ import onnxruntime as ort
5
+ from utils.preprocess import preprocess_image
6
+
7
+
8
class PredictionEngine:
    """Run mask prediction through either ONNX Runtime or a PyTorch model.

    The backend is fixed at construction time by ``use_onnx``. Images are
    preprocessed with ``preprocess_image`` and the raw model output is passed
    through a sigmoid before being returned.
    """

    def __init__(self, model_path=None, use_onnx=True, input_size=256):
        """
        Initialize the prediction engine

        Args:
            model_path: Path to the model file (PyTorch or ONNX). When falsy,
                no model is loaded and ``self.model`` is left as ``None``.
            use_onnx: Whether to use ONNX runtime for inference
            input_size: Input size for the model (default is 256)
        """
        self.use_onnx = use_onnx
        self.input_size = input_size
        # Height/width of the most recently preprocessed image; defined up
        # front so the attribute exists before the first preprocess() call.
        self.original_shape = None

        if model_path:
            if use_onnx:
                self.model = self._load_onnx_model(model_path)
            else:
                self.model = self._load_pytorch_model(model_path)
        else:
            self.model = None

    def _load_onnx_model(self, model_path):
        """
        Load an ONNX model

        Args:
            model_path: Path to the ONNX model

        Returns:
            ONNX Runtime InferenceSession
        """
        # Try with CUDA first, fall back to CPU if needed
        try:
            session = ort.InferenceSession(
                model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
            )
        except Exception as e:
            # Deliberately broad: any failure of the CUDA attempt triggers the
            # CPU retry; a failure of the retry itself propagates to the caller.
            print(f"Could not load ONNX model with CUDA, falling back to CPU: {e}")
            session = ort.InferenceSession(
                model_path, providers=["CPUExecutionProvider"]
            )

        # Requesting CUDA does not guarantee it is used: report the provider
        # the session actually ended up with instead of the one asked for.
        if "CUDAExecutionProvider" in session.get_providers():
            print("ONNX model loaded with CUDA support")
        else:
            print("ONNX model loaded with CPU support")
        return session

    def _load_pytorch_model(self, model_path):
        """
        Load a PyTorch model

        Args:
            model_path: Path to the PyTorch model

        Returns:
            PyTorch model
        """
        # Imported lazily so the loader module is only needed when the
        # PyTorch backend is actually requested.
        from utils.model import load_model
        return load_model(model_path)

    @staticmethod
    def _sigmoid(logits):
        """
        Numerically stable element-wise sigmoid for NumPy arrays

        Only ever exponentiates non-positive values, so large-magnitude
        logits cannot overflow ``np.exp``.

        Args:
            logits: NumPy array of raw model outputs

        Returns:
            Array of the same shape with values in [0, 1]
        """
        decay = np.exp(-np.abs(logits))
        return np.where(logits >= 0, 1.0 / (1.0 + decay), decay / (1.0 + decay))

    def preprocess(self, image):
        """
        Preprocess an image for prediction

        Also records ``image.shape[:2]`` in ``self.original_shape``.

        Args:
            image: Input image (numpy array)

        Returns:
            Processed image suitable for the model: a NumPy array for the
            ONNX backend, otherwise the object returned by ``preprocess_image``
            (presumably a torch tensor -- confirm against utils.preprocess)
        """
        # Keep the original image size for reference
        self.original_shape = image.shape[:2]

        tensor = preprocess_image(image, img_size=self.input_size)
        if self.use_onnx:
            # ONNX Runtime is fed NumPy arrays rather than tensors
            return tensor.numpy()
        return tensor

    def predict(self, image):
        """
        Make a prediction on an image

        Args:
            image: Input image (numpy array)

        Returns:
            Predicted mask: sigmoid of the model output as a NumPy array with
            singleton dimensions squeezed out

        Raises:
            ValueError: If no model has been loaded
        """
        if self.model is None:
            raise ValueError("Model not loaded. Initialize with a valid model path.")

        # Preprocess the image
        processed_input = self.preprocess(image)

        if self.use_onnx:
            # Only the first declared input and output of the session are used
            input_name = self.model.get_inputs()[0].name
            output_name = self.model.get_outputs()[0].name

            # Run ONNX inference
            outputs = self.model.run([output_name], {input_name: processed_input})

            # Apply sigmoid to output
            return self._sigmoid(outputs[0].squeeze())

        # PyTorch inference
        with torch.no_grad():
            # Run on whichever device the model's parameters live on
            device = next(self.model.parameters()).device
            output = torch.sigmoid(self.model(processed_input.to(device)))

        # Convert to numpy
        return output.cpu().numpy().squeeze()
131
+
132
+
133
def load_pytorch_model(model_path, input_size=256):
    """
    Load the PyTorch model for prediction

    Args:
        model_path: Path to the PyTorch model
        input_size: Input size for the model (default is 256, matching
            ``load_onnx_model``)

    Returns:
        PredictionEngine instance
    """
    return PredictionEngine(model_path, use_onnx=False, input_size=input_size)
144
+
145
+
146
def load_onnx_model(model_path, input_size=256):
    """
    Build a prediction engine backed by ONNX Runtime

    Args:
        model_path: Path to the ONNX model
        input_size: Input size for the model

    Returns:
        PredictionEngine instance
    """
    engine = PredictionEngine(model_path, use_onnx=True, input_size=input_size)
    return engine
requirements.txt CHANGED
@@ -1,3 +1,13 @@
1
- altair
2
- pandas
3
- streamlit
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ torch==2.1.2
3
+ torchvision==0.16.2
4
+ opencv-python-headless
5
+ albumentations
6
+ numpy
7
+ Pillow
8
+ scikit-image
9
+ scikit-learn
10
+ matplotlib
11
+ onnxruntime
12
+ onnx
13
+ supervision