Deploy NAVADA 2.0 Lite - Optimized for HF Spaces (no face recognition)
Files changed:
- .gitattributes +1 -1
- .gitignore +44 -44
- README.md +43 -43
- app.py +271 -1254
- app_lite.py +271 -0
- backend/chat_agent.py +188 -188
- backend/custom_trainer.py +398 -398
- backend/database.py +677 -677
- backend/face_detection.py +298 -298
- backend/openai_client.py +65 -65
- backend/prisma_client.py +399 -399
- backend/recognition.py +366 -366
- backend/two_stage_inference.py +284 -284
- backend/yolo.py +34 -34
- backend/yolo_enhanced.py +230 -230
- packages.txt +4 -4
- requirements.txt +13 -16
- requirements_lite.txt +14 -0
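The lite entry point (the new `app.py`/`app_lite.py` in the diff below) drops all face-recognition code paths and imports only `backend.yolo.detect_objects` and `backend.openai_client.explain_detection`. The backend module itself is not shown in this commit; as a rough sketch, a `detect_objects` compatible with how the new code consumes it — the return shape (a `detections` list of dicts with `class`/`confidence` keys plus an `annotated_image`) is inferred from those call sites, not confirmed — could look like:

```python
# Hypothetical sketch of backend/yolo.py, NOT the file from this commit:
# the return shape is inferred from app_lite.py's calls
# (results['detections'], det['class'], det['confidence'], results['annotated_image']).
import numpy as np
from PIL import Image
from ultralytics import YOLO

_model = YOLO("yolov8n.pt")  # nano weights keep a CPU-only Space responsive

def detect_objects(image: Image.Image, confidence_threshold: float = 0.5) -> dict:
    """Run YOLOv8 on a PIL image and return detections plus an annotated copy."""
    result = _model(np.array(image.convert("RGB")), conf=confidence_threshold, verbose=False)[0]
    detections = [
        {
            "class": _model.names[int(box.cls)],
            "confidence": float(box.conf),
            "bbox": [float(v) for v in box.xyxy[0]],
        }
        for box in result.boxes
    ]
    # result.plot() draws the boxes and returns a BGR array; flip to RGB for st.image
    annotated = result.plot()[:, :, ::-1]
    return {"detections": detections, "annotated_image": annotated}
```

Locally, the lite build would then presumably run with `pip install -r requirements_lite.txt` followed by `streamlit run app_lite.py`.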
.gitattributes CHANGED
@@ -1 +1 @@
-*.pt filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -1,45 +1,45 @@
-# Python
-__pycache__/
-*.py[cod]
-*$py.class
-*.so
-.Python
-env/
-venv/
-.venv/
-ENV/
-
-# Environment variables
-.env
-.env.local
-
-# IDE
-.vscode/
-.idea/
-*.swp
-*.swo
-
-# OS
-.DS_Store
-Thumbs.db
-
-# Streamlit
-.streamlit/secrets.toml
-
-# Database
-*.db
-*.sqlite
-*.sqlite3
-
-# Models (if too large)
-yolov8m.pt
-yolov8l.pt
-yolov8x.pt
-
-# Logs
-*.log
-
-# Temporary files
-*.tmp
-temp/
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+env/
+venv/
+.venv/
+ENV/
+
+# Environment variables
+.env
+.env.local
+
+# IDE
+.vscode/
+.idea/
+*.swp
+*.swo
+
+# OS
+.DS_Store
+Thumbs.db
+
+# Streamlit
+.streamlit/secrets.toml
+
+# Database
+*.db
+*.sqlite
+*.sqlite3
+
+# Models (if too large)
+yolov8m.pt
+yolov8l.pt
+yolov8x.pt
+
+# Logs
+*.log
+
+# Temporary files
+*.tmp
+temp/
 tmp/
README.md CHANGED
@@ -1,44 +1,44 @@
----
-title: NAVADA 2.0 - Advanced AI Computer Vision
-emoji: 🚀
-colorFrom: blue
-colorTo: purple
-sdk: streamlit
-sdk_version: 1.28.0
-app_file: app.py
-pinned: false
-license: mit
----
-
-# 🚀 NAVADA 2.0 - Advanced AI Computer Vision Application
-
-An advanced AI-powered computer vision application featuring:
-- 🎯 Real-time object detection using YOLOv8
-- 👤 Face detection and recognition
-- 🤖 AI-powered explanations
-- 📊 Interactive analytics
-- 🎤 Voice narration
-- 💬 Intelligent chat agent
-
-## Features
-- **Object Detection**: State-of-the-art YOLOv8 model for accurate object detection
-- **Face Recognition**: Advanced face detection with emotion and feature analysis
-- **AI Intelligence**: OpenAI-powered explanations and insights
-- **Interactive Charts**: Real-time visualization of detection results
-- **Voice Output**: Text-to-speech narration of detection results
-- **Chat Interface**: Intelligent assistant for image analysis
-
-## Usage
-1. Upload an image using the file uploader
-2. The system will automatically detect objects and faces
-3. View detailed analytics and AI-generated explanations
-4. Interact with the chat agent for deeper insights
-
-## Technology Stack
-- Streamlit for the web interface
-- YOLOv8 for object detection
-- OpenAI API for intelligent analysis
-- Face Recognition library
-- Plotly for interactive visualizations
-
+---
+title: NAVADA 2.0 - Advanced AI Computer Vision
+emoji: 🚀
+colorFrom: blue
+colorTo: purple
+sdk: streamlit
+sdk_version: 1.28.0
+app_file: app.py
+pinned: false
+license: mit
+---
+
+# 🚀 NAVADA 2.0 - Advanced AI Computer Vision Application
+
+An advanced AI-powered computer vision application featuring:
+- 🎯 Real-time object detection using YOLOv8
+- 👤 Face detection and recognition
+- 🤖 AI-powered explanations
+- 📊 Interactive analytics
+- 🎤 Voice narration
+- 💬 Intelligent chat agent
+
+## Features
+- **Object Detection**: State-of-the-art YOLOv8 model for accurate object detection
+- **Face Recognition**: Advanced face detection with emotion and feature analysis
+- **AI Intelligence**: OpenAI-powered explanations and insights
+- **Interactive Charts**: Real-time visualization of detection results
+- **Voice Output**: Text-to-speech narration of detection results
+- **Chat Interface**: Intelligent assistant for image analysis
+
+## Usage
+1. Upload an image using the file uploader
+2. The system will automatically detect objects and faces
+3. View detailed analytics and AI-generated explanations
+4. Interact with the chat agent for deeper insights
+
+## Technology Stack
+- Streamlit for the web interface
+- YOLOv8 for object detection
+- OpenAI API for intelligent analysis
+- Face Recognition library
+- Plotly for interactive visualizations
+
 Created by Lee Akpareva | AI Consultant & Computer Vision Specialist
app.py
CHANGED
|
@@ -1,1254 +1,271 @@
|
|
| 1 |
-
"""
|
| 2 |
-
🚀 NAVADA 2.0 - Advanced AI Computer Vision Application
|
| 3 |
-
Streamlit Version for Hugging Face Spaces Deployment
|
| 4 |
-
|
| 5 |
-
Enhanced Edition by Lee Akpareva | AI Consultant & Computer Vision Specialist
|
| 6 |
-
"""
|
| 7 |
-
|
| 8 |
-
import streamlit as st
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
.
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
}
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
""
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
#
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
# Face recognition if enabled
|
| 273 |
-
if enable_recognition and recognition_system:
|
| 274 |
-
detected_img, face_matches = recognition_system.recognize_faces(detected_img)
|
| 275 |
-
|
| 276 |
-
# AI explanation - enhanced version includes detailed attributes
|
| 277 |
-
if detailed_attributes:
|
| 278 |
-
ai_explanation = get_intelligence_report(detailed_attributes)
|
| 279 |
-
else:
|
| 280 |
-
ai_explanation = explain_detection(detected_objects)
|
| 281 |
-
|
| 282 |
-
# RAG enhancement if recognition enabled
|
| 283 |
-
if enable_recognition and recognition_system:
|
| 284 |
-
rag_enhancement = recognition_system.enhance_with_rag(detected_objects, face_matches)
|
| 285 |
-
ai_explanation = f"{ai_explanation}\n\n{rag_enhancement}"
|
| 286 |
-
|
| 287 |
-
# Voice generation if enabled
|
| 288 |
-
audio_file = None
|
| 289 |
-
if enable_voice:
|
| 290 |
-
try:
|
| 291 |
-
st.info("🔊 Generating voice narration...")
|
| 292 |
-
audio_file = generate_voice(ai_explanation)
|
| 293 |
-
if audio_file:
|
| 294 |
-
st.success("✅ Voice narration generated successfully!")
|
| 295 |
-
else:
|
| 296 |
-
st.error("❌ Voice generation failed - no audio file created")
|
| 297 |
-
except Exception as e:
|
| 298 |
-
st.error(f"❌ Voice generation failed: {e}")
|
| 299 |
-
import traceback
|
| 300 |
-
st.error(f"Details: {traceback.format_exc()}")
|
| 301 |
-
|
| 302 |
-
# Save session data
|
| 303 |
-
processing_time = time.time() - start_time
|
| 304 |
-
if recognition_system:
|
| 305 |
-
recognition_system.save_session_data(
|
| 306 |
-
image_array, detected_objects, face_matches, processing_time
|
| 307 |
-
)
|
| 308 |
-
|
| 309 |
-
return detected_img, ai_explanation, detected_objects, face_stats, face_matches, audio_file, detailed_attributes
|
| 310 |
-
|
| 311 |
-
except Exception as e:
|
| 312 |
-
st.error(f"Processing failed: {e}")
|
| 313 |
-
return None, f"Error: {e}", [], None, None, None, None
|
| 314 |
-
|
| 315 |
-
def get_database_stats():
|
| 316 |
-
"""Get current database statistics"""
|
| 317 |
-
try:
|
| 318 |
-
if db:
|
| 319 |
-
stats = db.get_stats()
|
| 320 |
-
return {
|
| 321 |
-
"faces": stats.get("faces", 0),
|
| 322 |
-
"objects": stats.get("objects", 0),
|
| 323 |
-
"sessions": stats.get("recent_detections", 0),
|
| 324 |
-
"total_detections": stats.get("total_detections", 0)
|
| 325 |
-
}
|
| 326 |
-
return {"faces": 0, "objects": 0, "sessions": 0, "total_detections": 0}
|
| 327 |
-
except Exception as e:
|
| 328 |
-
st.warning(f"Database stats unavailable: {e}")
|
| 329 |
-
return {"faces": 0, "objects": 0, "sessions": 0, "total_detections": 0}
|
| 330 |
-
|
| 331 |
-
# Sidebar for database features and stats
|
| 332 |
-
with st.sidebar:
|
| 333 |
-
st.markdown("""
|
| 334 |
-
<div class="feature-card">
|
| 335 |
-
<h3>🗄️ NAVADA Database</h3>
|
| 336 |
-
<p>Custom Recognition & RAG System</p>
|
| 337 |
-
</div>
|
| 338 |
-
""", unsafe_allow_html=True)
|
| 339 |
-
|
| 340 |
-
# Database statistics
|
| 341 |
-
stats = get_database_stats()
|
| 342 |
-
|
| 343 |
-
# Prisma Studio Integration
|
| 344 |
-
st.markdown("#### 🔧 Database Management")
|
| 345 |
-
col_studio1, col_studio2 = st.columns(2)
|
| 346 |
-
|
| 347 |
-
with col_studio1:
|
| 348 |
-
if st.button("🎯 Open Prisma Studio", help="View and edit database in Prisma Studio"):
|
| 349 |
-
try:
|
| 350 |
-
import subprocess
|
| 351 |
-
subprocess.Popen(["npm", "run", "studio"], cwd=".", shell=True)
|
| 352 |
-
st.success("🚀 Prisma Studio starting on http://localhost:5556")
|
| 353 |
-
except Exception as e:
|
| 354 |
-
st.error(f"Failed to start Prisma Studio: {e}")
|
| 355 |
-
st.info("💡 Run manually: npm run studio")
|
| 356 |
-
|
| 357 |
-
with col_studio2:
|
| 358 |
-
if st.button("📊 Database Info", help="Show database connection details"):
|
| 359 |
-
st.info("📍 Database: navada_recognition.db\n🌐 Prisma Studio: http://localhost:5556")
|
| 360 |
-
|
| 361 |
-
col1, col2 = st.columns(2)
|
| 362 |
-
with col1:
|
| 363 |
-
st.markdown(f"""
|
| 364 |
-
<div class="stats-card">
|
| 365 |
-
<h4>{stats.get('faces', 0)}</h4>
|
| 366 |
-
<p>👥 Faces</p>
|
| 367 |
-
</div>
|
| 368 |
-
""", unsafe_allow_html=True)
|
| 369 |
-
|
| 370 |
-
st.markdown(f"""
|
| 371 |
-
<div class="stats-card">
|
| 372 |
-
<h4>{stats.get('sessions', 0)}</h4>
|
| 373 |
-
<p>📊 Sessions</p>
|
| 374 |
-
</div>
|
| 375 |
-
""", unsafe_allow_html=True)
|
| 376 |
-
|
| 377 |
-
with col2:
|
| 378 |
-
st.markdown(f"""
|
| 379 |
-
<div class="stats-card">
|
| 380 |
-
<h4>{stats.get('objects', 0)}</h4>
|
| 381 |
-
<p>🏷️ Objects</p>
|
| 382 |
-
</div>
|
| 383 |
-
""", unsafe_allow_html=True)
|
| 384 |
-
|
| 385 |
-
st.markdown(f"""
|
| 386 |
-
<div class="stats-card">
|
| 387 |
-
<h4>{stats.get('total_detections', 0)}</h4>
|
| 388 |
-
<p>🎯 Detections</p>
|
| 389 |
-
</div>
|
| 390 |
-
""", unsafe_allow_html=True)
|
| 391 |
-
|
| 392 |
-
st.markdown("---")
|
| 393 |
-
|
| 394 |
-
# Computer Vision Educational Section
|
| 395 |
-
with st.expander("🔬 Computer Vision Guide", expanded=False):
|
| 396 |
-
st.markdown('<i class="fas fa-microscope fa-primary fa-icon"></i>**Advanced CV Learning Hub**', unsafe_allow_html=True)
|
| 397 |
-
st.markdown('<h3><i class="fas fa-brain fa-primary fa-icon"></i>What is Computer Vision?</h3>', unsafe_allow_html=True)
|
| 398 |
-
st.markdown("""
|
| 399 |
-
**Computer Vision (CV)** is a field of artificial intelligence that enables machines to interpret and understand visual information from the world, mimicking human vision capabilities.
|
| 400 |
-
|
| 401 |
-
**Key Components:**
|
| 402 |
-
- **Image Processing**: Enhancing and filtering visual data
|
| 403 |
-
- **Pattern Recognition**: Identifying objects, faces, and features
|
| 404 |
-
- **Machine Learning**: Training models on visual datasets
|
| 405 |
-
- **Deep Learning**: Neural networks for complex visual understanding
|
| 406 |
-
""")
|
| 407 |
-
|
| 408 |
-
st.markdown('<h3><i class="fas fa-crosshairs fa-primary fa-icon"></i>Top 5 Real-World Use Cases</h3>', unsafe_allow_html=True)
|
| 409 |
-
|
| 410 |
-
use_cases = [
|
| 411 |
-
{
|
| 412 |
-
"icon": "hospital",
|
| 413 |
-
"title": "Healthcare & Medical Imaging",
|
| 414 |
-
"description": "Detecting diseases in X-rays, MRIs, and CT scans. Early cancer detection, automated diagnosis, and surgical assistance.",
|
| 415 |
-
"impact": "95% accuracy in mammography screening"
|
| 416 |
-
},
|
| 417 |
-
{
|
| 418 |
-
"icon": "car",
|
| 419 |
-
"title": "Autonomous Vehicles",
|
| 420 |
-
"description": "Real-time object detection, lane recognition, traffic sign identification, and pedestrian safety systems.",
|
| 421 |
-
"impact": "$7 trillion global market potential"
|
| 422 |
-
},
|
| 423 |
-
{
|
| 424 |
-
"icon": "industry",
|
| 425 |
-
"title": "Manufacturing & Quality Control",
|
| 426 |
-
"description": "Automated defect detection, product inspection, assembly line monitoring, and predictive maintenance.",
|
| 427 |
-
"impact": "40% reduction in production errors"
|
| 428 |
-
},
|
| 429 |
-
{
|
| 430 |
-
"icon": "shield-alt",
|
| 431 |
-
"title": "Security & Surveillance",
|
| 432 |
-
"description": "Facial recognition, anomaly detection, crowd monitoring, and threat identification in real-time.",
|
| 433 |
-
"impact": "$62B global security market"
|
| 434 |
-
},
|
| 435 |
-
{
|
| 436 |
-
"icon": "shopping-cart",
|
| 437 |
-
"title": "Retail & E-commerce",
|
| 438 |
-
"description": "Visual search, inventory management, customer behavior analysis, and augmented reality shopping.",
|
| 439 |
-
"impact": "30% increase in conversion rates"
|
| 440 |
-
}
|
| 441 |
-
]
|
| 442 |
-
|
| 443 |
-
for case in use_cases:
|
| 444 |
-
st.markdown(f"""
|
| 445 |
-
**<i class="fas fa-{case['icon']} fa-primary fa-icon"></i>{case['title']}**
|
| 446 |
-
{case['description']}
|
| 447 |
-
*<i class="fas fa-chart-bar fa-primary fa-icon"></i>Impact: {case['impact']}*
|
| 448 |
-
""", unsafe_allow_html=True)
|
| 449 |
-
st.markdown("---")
|
| 450 |
-
|
| 451 |
-
st.markdown('<h3><i class="fas fa-rocket fa-primary fa-icon"></i>Future Economic Impact</h3>', unsafe_allow_html=True)
|
| 452 |
-
st.markdown("""
|
| 453 |
-
**Job Market Transformation:**
|
| 454 |
-
|
| 455 |
-
**🔮 2025-2030 Predictions:**
|
| 456 |
-
- **+2.3M new CV jobs** globally by 2030
|
| 457 |
-
- **$733B market value** by 2030 (15.3% CAGR)
|
| 458 |
-
- **50% of industries** will integrate CV solutions
|
| 459 |
-
|
| 460 |
-
**💼 Emerging Job Roles:**
|
| 461 |
-
- CV Engineers & Architects
|
| 462 |
-
- AI Ethics Specialists
|
| 463 |
-
- Computer Vision Product Managers
|
| 464 |
-
- Visual AI Trainers
|
| 465 |
-
- Augmented Reality Developers
|
| 466 |
-
|
| 467 |
-
**🌍 Economic Benefits:**
|
| 468 |
-
- **Productivity**: 25-40% efficiency gains
|
| 469 |
-
- **Cost Reduction**: $390B in operational savings
|
| 470 |
-
- **Innovation**: New business models & services
|
| 471 |
-
- **Accessibility**: Enhanced tools for disabilities
|
| 472 |
-
|
| 473 |
-
**⚡ Industry Revolution:**
|
| 474 |
-
- **Healthcare**: Personalized medicine & diagnostics
|
| 475 |
-
- **Agriculture**: Precision farming & crop monitoring
|
| 476 |
-
- **Education**: Interactive learning & assessment
|
| 477 |
-
- **Entertainment**: Immersive AR/VR experiences
|
| 478 |
-
""")
|
| 479 |
-
|
| 480 |
-
st.markdown("### 🎓 Learning Path")
|
| 481 |
-
st.markdown("""
|
| 482 |
-
**Start Your CV Journey:**
|
| 483 |
-
1. **📚 Learn Fundamentals**: Python, OpenCV, Image Processing
|
| 484 |
-
2. **🧠 Master ML/DL**: TensorFlow, PyTorch, Neural Networks
|
| 485 |
-
3. **🔧 Hands-on Projects**: Like this NAVADA 2.0 demo!
|
| 486 |
-
4. **📊 Specialize**: Choose healthcare, automotive, etc.
|
| 487 |
-
5. **🚀 Build Portfolio**: Create real-world applications
|
| 488 |
-
""")
|
| 489 |
-
|
| 490 |
-
st.markdown('<h3><i class="fas fa-microchip fa-primary fa-icon"></i>Raspberry Pi Introduction</h3>', unsafe_allow_html=True)
|
| 491 |
-
st.markdown("""
|
| 492 |
-
**What is Raspberry Pi?**
|
| 493 |
-
|
| 494 |
-
The Raspberry Pi is a series of small, affordable single-board computers perfect for AI and computer vision projects.
|
| 495 |
-
|
| 496 |
-
**🎯 Real-World CV Use Cases:**
|
| 497 |
-
- **🏠 Smart Security**: Door surveillance with face recognition
|
| 498 |
-
- **🌿 Wildlife Monitoring**: Automated animal detection in reserves
|
| 499 |
-
- **🏭 Industrial Inspection**: Quality control in manufacturing
|
| 500 |
-
- **🚜 Agricultural Monitoring**: Plant health & pest detection
|
| 501 |
-
- **🚦 Traffic Analysis**: Vehicle counting & license recognition
|
| 502 |
-
|
| 503 |
-
**⚙️ NAVADA 2.0 on Pi Setup:**
|
| 504 |
-
```bash
|
| 505 |
-
# Optimized for Pi 4 (4GB+ recommended)
|
| 506 |
-
pip install ultralytics[cpu]
|
| 507 |
-
streamlit run app.py --server.port 8080
|
| 508 |
-
```
|
| 509 |
-
|
| 510 |
-
**🚀 Performance Tips:**
|
| 511 |
-
- Use YOLOv8n (nano) for faster Pi inference
|
| 512 |
-
- Enable VideoCore GPU acceleration
|
| 513 |
-
- External USB3 storage for database ops
|
| 514 |
-
- Lightweight OpenCV builds
|
| 515 |
-
""")
|
| 516 |
-
|
| 517 |
-
st.markdown('<h3><i class="fas fa-robot fa-primary fa-icon"></i>Robotics Introduction</h3>', unsafe_allow_html=True)
|
| 518 |
-
st.markdown("""
|
| 519 |
-
**Computer Vision in Robotics**
|
| 520 |
-
|
| 521 |
-
CV is the "eyes" of modern robots, enabling intelligent perception and interaction with environments.
|
| 522 |
-
|
| 523 |
-
**🔧 Integration Applications:**
|
| 524 |
-
- **🗺️ Autonomous Navigation**: Path planning & obstacle avoidance
|
| 525 |
-
- **🔧 Object Manipulation**: Precise pick-and-place operations
|
| 526 |
-
- **👥 Human-Robot Interaction**: Gesture & facial recognition
|
| 527 |
-
- **✅ Quality Assurance**: Robotic inspection systems
|
| 528 |
-
- **🏥 Medical Robotics**: Surgical assistance & monitoring
|
| 529 |
-
|
| 530 |
-
**🏢 Real-World Success Stories:**
|
| 531 |
-
- **📦 Amazon Warehouses**: Kiva robots with vision navigation
|
| 532 |
-
- **🚗 Tesla Autopilot**: Advanced CV for autonomous driving
|
| 533 |
-
- **🐕 Boston Dynamics**: Vision-guided locomotion
|
| 534 |
-
- **⚕️ Surgical Robots**: da Vinci precision guidance
|
| 535 |
-
- **🌾 Agricultural Robots**: Automated crop monitoring
|
| 536 |
-
|
| 537 |
-
**🛠️ Popular Frameworks:**
|
| 538 |
-
- **ROS**: Robot Operating System (industry standard)
|
| 539 |
-
- **OpenCV**: Essential computer vision processing
|
| 540 |
-
- **PyBullet**: Physics simulation for testing
|
| 541 |
-
- **MoveIt**: Motion planning for robotic arms
|
| 542 |
-
|
| 543 |
-
**🚀 Getting Started with Robotics + NAVADA:**
|
| 544 |
-
1. **Hardware**: Camera + actuators + microcontroller
|
| 545 |
-
2. **Software**: NAVADA 2.0 + ROS + hardware drivers
|
| 546 |
-
3. **Training**: Collect task-specific object data
|
| 547 |
-
4. **Integration**: Connect detection → robot control
|
| 548 |
-
""")
|
| 549 |
-
|
| 550 |
-
st.info("💡 **Pro Tip**: NAVADA 2.0 demonstrates key CV concepts - object detection, face recognition, and custom training!")
|
| 551 |
-
|
| 552 |
-
st.markdown("---")
|
| 553 |
-
|
| 554 |
-
# Face database addition
|
| 555 |
-
st.markdown("### 👤 Add Face to Database")
|
| 556 |
-
face_name = st.text_input("Enter person's name:", key="face_name")
|
| 557 |
-
if st.button("👤 Add Face", key="add_face"):
|
| 558 |
-
if st.session_state.get('current_image') is not None and face_name:
|
| 559 |
-
if recognition_system:
|
| 560 |
-
success = recognition_system.add_new_face(
|
| 561 |
-
np.array(st.session_state.current_image), face_name
|
| 562 |
-
)
|
| 563 |
-
if success:
|
| 564 |
-
st.success(f"✅ Added {face_name} to face database!")
|
| 565 |
-
st.rerun()
|
| 566 |
-
else:
|
| 567 |
-
st.error("❌ Failed to add face. Please ensure a clear face is visible.")
|
| 568 |
-
else:
|
| 569 |
-
st.error("Recognition system not available")
|
| 570 |
-
else:
|
| 571 |
-
st.warning("Please upload an image and enter a name first.")
|
| 572 |
-
|
| 573 |
-
st.markdown("---")
|
| 574 |
-
|
| 575 |
-
# Live Session Statistics
|
| 576 |
-
st.markdown("### 📈 Live Session Stats")
|
| 577 |
-
|
| 578 |
-
# Session metrics in a compact format
|
| 579 |
-
session_col1, session_col2 = st.columns(2)
|
| 580 |
-
with session_col1:
|
| 581 |
-
st.metric("🖼️ This Session",
|
| 582 |
-
st.session_state.get('images_processed', 0),
|
| 583 |
-
delta=None,
|
| 584 |
-
delta_color="normal")
|
| 585 |
-
|
| 586 |
-
total_objects_detected = 0
|
| 587 |
-
if 'last_results' in st.session_state and st.session_state.last_results:
|
| 588 |
-
detected_objects = st.session_state.last_results[2]
|
| 589 |
-
total_objects_detected = len(detected_objects) if detected_objects else 0
|
| 590 |
-
|
| 591 |
-
st.metric("🎯 Objects Found",
|
| 592 |
-
total_objects_detected,
|
| 593 |
-
delta=None)
|
| 594 |
-
|
| 595 |
-
with session_col2:
|
| 596 |
-
processing_time = 0
|
| 597 |
-
if 'start_time' in st.session_state:
|
| 598 |
-
processing_time = time.time() - st.session_state.start_time
|
| 599 |
-
|
| 600 |
-
st.metric("⚡ Last Process",
|
| 601 |
-
f"{processing_time:.1f}s" if processing_time > 0 else "0.0s",
|
| 602 |
-
delta=None)
|
| 603 |
-
|
| 604 |
-
accuracy_score = 0
|
| 605 |
-
if total_objects_detected > 0:
|
| 606 |
-
accuracy_score = min(95, 85 + total_objects_detected * 2)
|
| 607 |
-
|
| 608 |
-
st.metric("📊 Accuracy",
|
| 609 |
-
f"{accuracy_score}%" if accuracy_score > 0 else "0%",
|
| 610 |
-
delta=None)
|
| 611 |
-
|
| 612 |
-
# Session progress bar
|
| 613 |
-
session_target = 10 # Target images for session
|
| 614 |
-
current_progress = min(st.session_state.get('images_processed', 0) / session_target, 1.0)
|
| 615 |
-
st.progress(current_progress, text=f"Session Progress: {st.session_state.get('images_processed', 0)}/{session_target}")
|
| 616 |
-
|
| 617 |
-
st.markdown("---")
|
| 618 |
-
|
| 619 |
-
# Custom object addition
|
| 620 |
-
st.markdown("### 🏷️ Add Custom Object")
|
| 621 |
-
object_label = st.text_input("Object label:", key="object_label")
|
| 622 |
-
object_category = st.text_input("Category (optional):", key="object_category")
|
| 623 |
-
if st.button("🏷️ Add Object", key="add_object"):
|
| 624 |
-
if st.session_state.get('current_image') is not None and object_label:
|
| 625 |
-
if recognition_system:
|
| 626 |
-
success = recognition_system.add_custom_object(
|
| 627 |
-
np.array(st.session_state.current_image),
|
| 628 |
-
object_label,
|
| 629 |
-
object_category or "general"
|
| 630 |
-
)
|
| 631 |
-
if success:
|
| 632 |
-
st.success(f"✅ Added '{object_label}' to object database!")
|
| 633 |
-
st.rerun()
|
| 634 |
-
else:
|
| 635 |
-
st.error("❌ Failed to add object.")
|
| 636 |
-
else:
|
| 637 |
-
st.error("Recognition system not available")
|
| 638 |
-
else:
|
| 639 |
-
st.warning("Please upload an image and enter a label first.")
|
| 640 |
-
|
| 641 |
-
# Main content area
|
| 642 |
-
col1, col2 = st.columns([2, 1])
|
| 643 |
-
|
| 644 |
-
with col1:
|
| 645 |
-
# Image input tabs
|
| 646 |
-
tab1, tab2 = st.tabs(["📁 Upload Image", "📸 Camera Capture"])
|
| 647 |
-
|
| 648 |
-
with tab1:
|
| 649 |
-
uploaded_file = st.file_uploader(
|
| 650 |
-
"Choose an image file",
|
| 651 |
-
type=['png', 'jpg', 'jpeg'],
|
| 652 |
-
help="Upload an image for AI analysis"
|
| 653 |
-
)
|
| 654 |
-
|
| 655 |
-
if uploaded_file is not None:
|
| 656 |
-
image = Image.open(uploaded_file)
|
| 657 |
-
st.session_state.current_image = image
|
| 658 |
-
st.image(image, caption="Uploaded Image", use_container_width=True)
|
| 659 |
-
|
| 660 |
-
with tab2:
|
| 661 |
-
camera_image = st.camera_input("📸 Take a picture")
|
| 662 |
-
|
| 663 |
-
if camera_image is not None:
|
| 664 |
-
image = Image.open(camera_image)
|
| 665 |
-
st.session_state.current_image = image
|
| 666 |
-
st.image(image, caption="Captured Image", use_container_width=True)
|
| 667 |
-
|
| 668 |
-
with col2:
|
| 669 |
-
# Processing options
|
| 670 |
-
st.markdown("### ⚙️ Processing Options")
|
| 671 |
-
|
| 672 |
-
# Model information
|
| 673 |
-
model_info = two_stage_inference.get_model_info()
|
| 674 |
-
if model_info.get('custom_model_loaded'):
|
| 675 |
-
st.success(f"🧠 Custom Model Active: {model_info.get('num_custom_classes', 0)} trained classes")
|
| 676 |
-
with st.expander("📋 Model Details"):
|
| 677 |
-
st.text(f"Custom Classes: {', '.join(model_info.get('custom_classes', []))}")
|
| 678 |
-
st.text(f"Training Samples: {model_info.get('training_samples', 0)}")
|
| 679 |
-
st.text(f"Device: {model_info.get('device', 'unknown')}")
|
| 680 |
-
else:
|
| 681 |
-
st.info("🤖 Using YOLO only - Train custom model by providing corrections!")
|
| 682 |
-
|
| 683 |
-
# Enhanced accuracy controls
|
| 684 |
-
st.markdown("#### 🎯 Accuracy Settings")
|
| 685 |
-
use_enhanced = st.checkbox("**Use Enhanced Detection** (Better Accuracy)", value=True, help="Uses advanced model with color detection and custom training")
|
| 686 |
-
confidence_threshold = st.slider("Detection Confidence", 0.1, 0.9, 0.5, 0.05, help="Higher = fewer but more accurate detections")
|
| 687 |
-
|
| 688 |
-
# Make voice option more prominent
|
| 689 |
-
st.markdown("#### 🔊 Audio Features")
|
| 690 |
-
enable_voice = st.checkbox("**Enable Voice Narration** (OpenAI TTS)", value=False, help="Generate AI voice explanation of detected objects")
|
| 691 |
-
|
| 692 |
-
st.markdown("#### 🧠 AI Features")
|
| 693 |
-
enable_face_detection = st.checkbox("👤 Enable Face Detection", value=True)
|
| 694 |
-
enable_recognition = st.checkbox("🧠 Enable Smart Recognition", value=True)
|
| 695 |
-
|
| 696 |
-
# Launch button
|
| 697 |
-
if st.button("🚀 LAUNCH ANALYSIS", key="launch", type="primary"):
|
| 698 |
-
if 'current_image' in st.session_state:
|
| 699 |
-
# Track processing start time
|
| 700 |
-
st.session_state.start_time = time.time()
|
| 701 |
-
|
| 702 |
-
# Update session counters
|
| 703 |
-
st.session_state.images_processed = st.session_state.get('images_processed', 0) + 1
|
| 704 |
-
|
| 705 |
-
with st.spinner("🔄 Processing with NAVADA 2.0..."):
|
| 706 |
-
results = process_image(
|
| 707 |
-
st.session_state.current_image,
|
| 708 |
-
enable_voice,
|
| 709 |
-
enable_face_detection,
|
| 710 |
-
enable_recognition,
|
| 711 |
-
use_enhanced,
|
| 712 |
-
confidence_threshold
|
| 713 |
-
)
|
| 714 |
-
st.session_state.last_results = results
|
| 715 |
-
st.session_state.processing_complete = True
|
| 716 |
-
else:
|
| 717 |
-
st.warning("Please upload an image or take a photo first!")
|
| 718 |
-
|
| 719 |
-
# Results section
|
| 720 |
-
if st.session_state.processing_complete and st.session_state.last_results:
|
| 721 |
-
# Unpack results - handle both old and new format
|
| 722 |
-
if len(st.session_state.last_results) == 7:
|
| 723 |
-
detected_img, ai_explanation, detected_objects, face_stats, face_matches, audio_file, detailed_attributes = st.session_state.last_results
|
| 724 |
-
else:
|
| 725 |
-
detected_img, ai_explanation, detected_objects, face_stats, face_matches, audio_file = st.session_state.last_results
|
| 726 |
-
detailed_attributes = None
|
| 727 |
-
|
| 728 |
-
st.markdown("---")
|
| 729 |
-
st.markdown("## 🎯 Analysis Results")
|
| 730 |
-
|
| 731 |
-
# Display processed image
|
| 732 |
-
if detected_img is not None:
|
| 733 |
-
st.image(detected_img, caption="🔍 Processed Image with Detections", use_container_width=True)
|
| 734 |
-
|
| 735 |
-
# Results in two columns
|
| 736 |
-
res_col1, res_col2 = st.columns([3, 2])
|
| 737 |
-
|
| 738 |
-
with res_col1:
|
| 739 |
-
# AI explanation
|
| 740 |
-
st.markdown("### 🤖 AI Analysis")
|
| 741 |
-
st.markdown(ai_explanation)
|
| 742 |
-
|
| 743 |
-
# Audio playback
|
| 744 |
-
if audio_file:
|
| 745 |
-
st.markdown("### 🔊 Voice Narration")
|
| 746 |
-
st.audio(audio_file)
|
| 747 |
-
|
| 748 |
-
# Comprehensive App Statistics Section
|
| 749 |
-
st.markdown("---")
|
| 750 |
-
st.markdown("## 📊 NAVADA 2.0 Analytics Dashboard")
|
| 751 |
-
|
| 752 |
-
# Get processing stats for current session
|
| 753 |
-
processing_time = time.time() - st.session_state.get('start_time', time.time())
|
| 754 |
-
|
| 755 |
-
# Create statistics tabs
|
| 756 |
-
stats_tab1, stats_tab2, stats_tab3, stats_tab4 = st.tabs([
|
| 757 |
-
"🚀 Performance", "📈 Usage Metrics", "🎯 Detection Stats", "🧠 AI Insights"
|
| 758 |
-
])
|
| 759 |
-
|
| 760 |
-
with stats_tab1:
|
| 761 |
-
# Performance Metrics
|
| 762 |
-
col1, col2, col3 = st.columns(3)
|
| 763 |
-
|
| 764 |
-
with col1:
|
| 765 |
-
st.metric("⚡ Processing Speed", f"{processing_time:.2f}s",
|
| 766 |
-
delta=f"-{max(0, 2.5-processing_time):.1f}s vs avg")
|
| 767 |
-
|
| 768 |
-
with col2:
|
| 769 |
-
inference_time = 0.25 if detected_objects else 0.0 # Approximate from logs
|
| 770 |
-
st.metric("🧠 AI Inference", f"{inference_time*1000:.0f}ms",
|
| 771 |
-
delta=f"{inference_time*1000-200:.0f}ms")
|
| 772 |
-
|
| 773 |
-
with col3:
|
| 774 |
-
accuracy = min(95, 85 + len(detected_objects) * 2) if detected_objects else 0
|
| 775 |
-
st.metric("🎯 Detection Accuracy", f"{accuracy}%",
|
| 776 |
-
delta=f"+{accuracy-85}%" if accuracy > 85 else "0%")
|
| 777 |
-
|
| 778 |
-
# Performance trend chart
|
| 779 |
-
performance_data = {
|
| 780 |
-
'Metric': ['Preprocessing', 'Inference', 'Postprocessing', 'Face Detection', 'Recognition'],
|
| 781 |
-
'Time (ms)': [16, 250, 18, 45, 120],
|
| 782 |
-
'Efficiency': [95, 88, 92, 87, 91]
|
| 783 |
-
}
|
| 784 |
-
|
| 785 |
-
perf_chart = go.Figure()
|
| 786 |
-
perf_chart.add_trace(go.Bar(
|
| 787 |
-
x=performance_data['Metric'],
|
| 788 |
-
y=performance_data['Time (ms)'],
|
| 789 |
-
name='Processing Time (ms)',
|
| 790 |
-
marker_color='#FF6B6B'
|
| 791 |
-
))
|
| 792 |
-
|
| 793 |
-
perf_chart.update_layout(
|
| 794 |
-
title="⚡ NAVADA 2.0 Performance Breakdown",
|
| 795 |
-
xaxis_title="Processing Stage",
|
| 796 |
-
yaxis_title="Time (milliseconds)",
|
| 797 |
-
height=350,
|
| 798 |
-
template="plotly_dark"
|
| 799 |
-
)
|
| 800 |
-
st.plotly_chart(perf_chart, use_container_width=True)
|
| 801 |
-
|
| 802 |
-
with stats_tab2:
|
| 803 |
-
# Usage Analytics
|
| 804 |
-
col1, col2, col3, col4 = st.columns(4)
|
| 805 |
-
|
| 806 |
-
db_stats = get_database_stats()
|
| 807 |
-
|
| 808 |
-
with col1:
|
| 809 |
-
st.metric("📸 Images Processed",
|
| 810 |
-
st.session_state.get('images_processed', 1),
|
| 811 |
-
delta="+1")
|
| 812 |
-
|
| 813 |
-
with col2:
|
| 814 |
-
st.metric("👥 Faces Trained",
|
| 815 |
-
db_stats.get('faces', 0),
|
| 816 |
-
delta="+0")
|
| 817 |
-
|
| 818 |
-
with col3:
|
| 819 |
-
st.metric("🏷️ Objects Trained",
|
| 820 |
-
db_stats.get('objects', 0),
|
| 821 |
-
delta="+0")
|
| 822 |
-
|
| 823 |
-
with col4:
|
| 824 |
-
st.metric("🎯 Total Detections",
|
| 825 |
-
db_stats.get('total_detections', 0),
|
| 826 |
-
delta="+0")
|
| 827 |
-
|
| 828 |
-
# Usage trend over time (simulated data)
|
| 829 |
-
import datetime
|
| 830 |
-
dates = [datetime.datetime.now() - datetime.timedelta(days=x) for x in range(7, 0, -1)]
|
| 831 |
-
usage_data = {
|
| 832 |
-
'Date': dates,
|
| 833 |
-
'Detections': [12, 18, 25, 31, 28, 35, 42],
|
| 834 |
-
'Accuracy': [87, 89, 91, 93, 92, 94, 95]
|
| 835 |
-
}
|
| 836 |
-
|
| 837 |
-
usage_chart = go.Figure()
|
| 838 |
-
usage_chart.add_trace(go.Scatter(
|
| 839 |
-
x=usage_data['Date'],
|
| 840 |
-
y=usage_data['Detections'],
|
| 841 |
-
mode='lines+markers',
|
| 842 |
-
name='Daily Detections',
|
| 843 |
-
line=dict(color='#4ECDC4', width=3),
|
| 844 |
-
marker=dict(size=8)
|
| 845 |
-
))
|
| 846 |
-
|
| 847 |
-
usage_chart.add_trace(go.Scatter(
|
| 848 |
-
x=usage_data['Date'],
|
| 849 |
-
y=usage_data['Accuracy'],
|
| 850 |
-
mode='lines+markers',
|
| 851 |
-
name='Accuracy %',
|
| 852 |
-
yaxis='y2',
|
| 853 |
-
line=dict(color='#45B7D1', width=3),
|
| 854 |
-
marker=dict(size=8)
|
| 855 |
-
))
|
| 856 |
-
|
| 857 |
-
usage_chart.update_layout(
|
| 858 |
-
title="📈 NAVADA 2.0 Weekly Performance Trends",
|
| 859 |
-
xaxis_title="Date",
|
| 860 |
-
yaxis_title="Number of Detections",
|
| 861 |
-
yaxis2=dict(
|
| 862 |
-
title="Accuracy (%)",
|
| 863 |
-
overlaying='y',
|
| 864 |
-
side='right'
|
| 865 |
-
),
|
| 866 |
-
height=400,
|
| 867 |
-
template="plotly_dark",
|
| 868 |
-
hovermode='x unified'
|
| 869 |
-
)
|
| 870 |
-
st.plotly_chart(usage_chart, use_container_width=True)
|
| 871 |
-
|
| 872 |
-
with stats_tab3:
|
| 873 |
-
# Detection Statistics
|
| 874 |
-
if detected_objects:
|
| 875 |
-
# Object category distribution
|
| 876 |
-
object_categories = {
|
| 877 |
-
'Animals': ['bird', 'dog', 'cat', 'horse', 'elephant', 'bear', 'zebra', 'giraffe'],
|
| 878 |
-
'Vehicles': ['car', 'truck', 'bus', 'motorcycle', 'bicycle', 'airplane', 'boat'],
|
| 879 |
-
'People': ['person'],
|
| 880 |
-
'Objects': ['bottle', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'book', 'laptop']
|
| 881 |
-
}
|
| 882 |
-
|
| 883 |
-
category_counts = {}
|
| 884 |
-
for obj in detected_objects:
|
| 885 |
-
for category, items in object_categories.items():
|
| 886 |
-
if obj in items:
|
| 887 |
-
category_counts[category] = category_counts.get(category, 0) + 1
|
| 888 |
-
break
|
| 889 |
-
else:
|
| 890 |
-
category_counts['Other'] = category_counts.get('Other', 0) + 1
|
| 891 |
-
|
| 892 |
-
# Category pie chart
|
| 893 |
-
if category_counts:
|
| 894 |
-
category_chart = go.Figure(data=[go.Pie(
|
| 895 |
-
labels=list(category_counts.keys()),
|
| 896 |
-
values=list(category_counts.values()),
|
| 897 |
-
hole=.4,
|
| 898 |
-
marker_colors=['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FECA57', '#FF9FF3']
|
| 899 |
-
)])
|
| 900 |
-
|
| 901 |
-
category_chart.update_layout(
|
| 902 |
-
title="🎯 Object Categories Detected",
|
| 903 |
-
height=350,
|
| 904 |
-
template="plotly_dark"
|
| 905 |
-
)
|
| 906 |
-
st.plotly_chart(category_chart, use_container_width=True)
|
| 907 |
-
|
| 908 |
-
# Confidence levels radar chart
|
| 909 |
-
confidence_levels = {
|
| 910 |
-
'High Confidence (>90%)': len([obj for obj in detected_objects]) * 0.7,
|
| 911 |
-
'Medium Confidence (70-90%)': len([obj for obj in detected_objects]) * 0.25,
|
| 912 |
-
'Low Confidence (<70%)': len([obj for obj in detected_objects]) * 0.05
|
| 913 |
-
}
|
| 914 |
-
|
| 915 |
-
confidence_chart = go.Figure()
|
| 916 |
-
confidence_chart.add_trace(go.Bar(
|
| 917 |
-
x=list(confidence_levels.keys()),
|
| 918 |
-
y=list(confidence_levels.values()),
|
| 919 |
-
marker_color=['#4CAF50', '#FFC107', '#FF5722']
|
| 920 |
-
))
|
| 921 |
-
|
| 922 |
-
confidence_chart.update_layout(
|
| 923 |
-
title="🎯 Detection Confidence Distribution",
|
| 924 |
-
xaxis_title="Confidence Level",
|
| 925 |
-
yaxis_title="Number of Detections",
|
| 926 |
-
height=300,
|
| 927 |
-
template="plotly_dark"
|
| 928 |
-
)
|
| 929 |
-
st.plotly_chart(confidence_chart, use_container_width=True)
|
| 930 |
-
|
| 931 |
-
else:
|
| 932 |
-
st.info("📸 Upload an image to see detection statistics!")
|
| 933 |
-
|
| 934 |
-
with stats_tab4:
|
| 935 |
-
# AI Insights and Model Information
|
| 936 |
-
col1, col2 = st.columns(2)
|
| 937 |
-
|
| 938 |
-
with col1:
|
| 939 |
-
st.markdown("### 🧠 AI Model Information")
|
| 940 |
-
model_info = {
|
| 941 |
-
"🏗️ Architecture": "YOLOv8 + Custom Recognition",
|
| 942 |
-
"📊 Model Size": "6.2 MB (YOLOv8n)",
|
| 943 |
-
"🎯 Classes": "80+ COCO Objects",
|
| 944 |
-
"👥 Custom Faces": f"{db_stats.get('faces', 0)} trained",
|
| 945 |
-
"🏷️ Custom Objects": f"{db_stats.get('objects', 0)} trained",
|
| 946 |
-
"🧠 AI Engine": "OpenAI GPT-4o-mini",
|
| 947 |
-
"🔊 TTS Engine": "OpenAI TTS-1",
|
| 948 |
-
"💾 Database": "SQLite + RAG"
|
| 949 |
-
}
|
| 950 |
-
|
| 951 |
-
for key, value in model_info.items():
|
| 952 |
-
st.markdown(f"**{key}**: {value}")
|
| 953 |
-
|
| 954 |
-
with col2:
|
| 955 |
-
# Model comparison chart
|
| 956 |
-
models_comparison = {
|
| 957 |
-
'Model': ['NAVADA 2.0', 'YOLOv8', 'Standard CV', 'Basic Detection'],
|
| 958 |
-
'Accuracy': [94, 89, 82, 75],
|
| 959 |
-
'Speed (ms)': [280, 250, 400, 350],
|
| 960 |
-
'Features': [15, 8, 5, 3]
|
| 961 |
-
}
|
| 962 |
-
|
| 963 |
-
comparison_chart = go.Figure()
|
| 964 |
-
comparison_chart.add_trace(go.Scatterpolar(
|
| 965 |
-
r=[94, 95, 90, 98], # NAVADA 2.0 capabilities
|
| 966 |
-
theta=['Accuracy', 'Speed', 'Features', 'Innovation'],
|
| 967 |
-
fill='toself',
|
| 968 |
-
name='NAVADA 2.0',
|
| 969 |
-
line=dict(color='#4ECDC4')
|
| 970 |
-
))
|
| 971 |
-
comparison_chart.add_trace(go.Scatterpolar(
|
| 972 |
-
r=[89, 92, 60, 70], # Standard models
|
| 973 |
-
theta=['Accuracy', 'Speed', 'Features', 'Innovation'],
|
| 974 |
-
fill='toself',
|
| 975 |
-
name='Standard Models',
|
| 976 |
-
line=dict(color='#FF6B6B')
|
| 977 |
-
))
|
| 978 |
-
|
| 979 |
-
comparison_chart.update_layout(
|
| 980 |
-
title="🚀 NAVADA 2.0 vs Standard Models",
|
| 981 |
-
polar=dict(
|
| 982 |
-
radialaxis=dict(
|
| 983 |
-
visible=True,
|
| 984 |
-
range=[0, 100]
|
| 985 |
-
)),
|
| 986 |
-
height=350,
|
| 987 |
-
template="plotly_dark"
|
| 988 |
-
)
|
| 989 |
-
st.plotly_chart(comparison_chart, use_container_width=True)
|
| 990 |
-
|
| 991 |
-
# System capabilities matrix
|
| 992 |
-
st.markdown("### ⚡ System Capabilities")
|
| 993 |
-
|
| 994 |
-
# Create manual table to avoid pandas import
|
| 995 |
-
st.markdown("""
|
| 996 |
-
| 🎯 Feature | 📊 Status | ⚡ Performance |
|
| 997 |
-
|------------|-----------|----------------|
|
| 998 |
-
| Object Detection | ✅ Active | 94% |
|
| 999 |
-
| Face Recognition | ✅ Active | 91% |
|
| 1000 |
-
| Custom Training | ✅ Active | 89% |
|
| 1001 |
-
| Voice Narration | ✅ Active | 96% |
|
| 1002 |
-
| RAG Analysis | ✅ Active | 87% |
|
| 1003 |
-
| Real-time Processing | ✅ Active | 92% |
|
| 1004 |
-
""")
|
| 1005 |
-
|
| 1006 |
-
with res_col2:
|
| 1007 |
-
# Charts
|
| 1008 |
-
if detected_objects:
|
| 1009 |
-
# Detection chart
|
| 1010 |
-
detection_chart = create_detection_chart(detected_objects, face_stats, face_matches)
|
| 1011 |
-
st.plotly_chart(detection_chart, use_container_width=True)
|
| 1012 |
-
|
| 1013 |
-
# Confidence pie chart
|
| 1014 |
-
confidence_chart = create_confidence_pie_chart(detected_objects, face_matches)
|
| 1015 |
-
if confidence_chart:
|
| 1016 |
-
st.plotly_chart(confidence_chart, use_container_width=True)
|
| 1017 |
-
|
| 1018 |
-
# Detection summary
|
| 1019 |
-
st.markdown("### 📋 Detection Summary")
|
| 1020 |
-
if detected_objects:
|
| 1021 |
-
st.success(f"🎯 Found {len(detected_objects)} objects!")
|
| 1022 |
-
for obj in set(detected_objects):
|
| 1023 |
-
count = detected_objects.count(obj)
|
| 1024 |
-
st.markdown(f"• **{obj}**: {count}")
|
| 1025 |
-
else:
|
| 1026 |
-
st.warning("No objects detected in this image")
|
| 1027 |
-
|
| 1028 |
-
if face_matches:
|
| 1029 |
-
st.markdown("### 👥 Face Recognition")
|
| 1030 |
-
for match in face_matches:
|
| 1031 |
-
name = match['name']
|
| 1032 |
-
similarity = match.get('similarity', 0)
|
| 1033 |
-
if name != 'Unknown':
|
| 1034 |
-
st.markdown(f"• **{name}**: {similarity:.2f} confidence")
|
| 1035 |
-
else:
|
| 1036 |
-
st.markdown(f"• **{name}**: New face detected")
|
| 1037 |
-
|
| 1038 |
-
# Training Feedback Section
|
| 1039 |
-
st.markdown('<h3><i class="fas fa-brain fa-primary fa-icon"></i>Help Improve Detection Accuracy</h3>', unsafe_allow_html=True)
|
| 1040 |
-
st.markdown("Found an incorrect detection? Help train the AI by providing corrections!")
|
| 1041 |
-
|
| 1042 |
-
# Show corrections interface if we have detailed attributes
|
| 1043 |
-
if detailed_attributes and st.session_state.get('current_image'):
|
| 1044 |
-
# Create dropdown selection for cleaner UI
|
| 1045 |
-
detection_options = [f"Detection #{i+1}: {attr['label']} ({attr['confidence']}) - {attr['position']}"
|
| 1046 |
-
for i, attr in enumerate(detailed_attributes)]
|
| 1047 |
-
|
| 1048 |
-
selected_detection = st.selectbox(
|
| 1049 |
-
"Select detection to correct:",
|
| 1050 |
-
options=range(len(detailed_attributes)),
|
| 1051 |
-
format_func=lambda x: detection_options[x],
|
| 1052 |
-
key="selected_detection"
|
| 1053 |
-
)
|
| 1054 |
-
|
| 1055 |
-
if selected_detection is not None:
|
| 1056 |
-
attr = detailed_attributes[selected_detection]
|
| 1057 |
-
i = selected_detection
|
| 1058 |
-
|
| 1059 |
-
# Show selected detection details in expander
|
| 1060 |
-
with st.expander(f"🔧 Correcting: {attr['label']} ({attr['confidence']})", expanded=True):
|
| 1061 |
-
col1, col2 = st.columns(2)
|
| 1062 |
-
|
| 1063 |
-
with col1:
|
| 1064 |
-
st.markdown("**Detection Details:**")
|
| 1065 |
-
st.write(f"🏷️ **Label:** {attr['label']}")
|
| 1066 |
-
st.write(f"🎯 **Confidence:** {attr['confidence']}")
|
| 1067 |
-
st.write(f"🎨 **Colors:** {', '.join(attr['colors'][:2])}")
|
| 1068 |
-
st.write(f"📍 **Position:** {attr['position']}")
|
| 1069 |
-
st.write(f"📏 **Size:** {attr['size']}")
|
| 1070 |
-
|
| 1071 |
-
with col2:
|
| 1072 |
-
st.markdown("**Provide Correction:**")
|
| 1073 |
-
# Correction input
|
| 1074 |
-
correct_label = st.text_input(
|
| 1075 |
-
"Correct label:",
|
| 1076 |
-
key=f"correct_{i}",
|
| 1077 |
-
placeholder="e.g., rabbit, dog, car"
|
| 1078 |
-
)
|
| 1079 |
-
|
| 1080 |
-
feedback_text = st.text_area(
|
| 1081 |
-
"Feedback (optional):",
|
| 1082 |
-
key=f"feedback_{i}",
|
| 1083 |
-
placeholder="Why was this wrong?",
|
| 1084 |
-
height=68
|
| 1085 |
-
)
|
| 1086 |
-
|
| 1087 |
-
if st.button("✅ Submit Correction", key=f"submit_correction_{i}", use_container_width=True, type="primary"):
|
| 1088 |
-
st.markdown('<i class="fas fa-spinner fa-spin fa-primary fa-icon"></i>Processing...', unsafe_allow_html=True)
|
| 1089 |
-
if correct_label.strip():
|
| 1090 |
-
# Extract object region from image
|
| 1091 |
-
image_array = np.array(st.session_state.current_image)
|
| 1092 |
-
|
| 1093 |
-
# Get bounding box coordinates from detailed attributes
|
| 1094 |
-
bbox_coords = attr.get('bbox', [100, 100, 200, 200]) # [x1, y1, x2, y2]
|
| 1095 |
-
|
| 1096 |
-
# Extract object crop
|
| 1097 |
-
x1, y1, x2, y2 = bbox_coords
|
| 1098 |
-
object_crop = image_array[max(0, int(y1)):min(image_array.shape[0], int(y2)),
|
| 1099 |
-
max(0, int(x1)):min(image_array.shape[1], int(x2))]
|
| 1100 |
-
|
| 1101 |
-
if object_crop.size > 0:
|
| 1102 |
-
# Save correction to database
|
| 1103 |
-
try:
|
| 1104 |
-
success = db.save_correction(
|
| 1105 |
-
image_crop=object_crop,
|
| 1106 |
-
bbox_coords=bbox_coords,
|
| 1107 |
-
yolo_prediction=attr['label'],
|
| 1108 |
-
yolo_confidence=float(attr['confidence'].rstrip('%')) / 100.0,
|
| 1109 |
-
correct_label=correct_label.strip(),
|
| 1110 |
-
user_feedback=feedback_text.strip(),
|
| 1111 |
-
session_id=st.session_state.get('session_id', '')
|
| 1112 |
-
)
|
| 1113 |
-
|
| 1114 |
-
if success:
|
| 1115 |
-
st.markdown('<div class="fa-success"><i class="fas fa-check-circle fa-success fa-icon"></i>Successfully saved correction!</div>', unsafe_allow_html=True)
|
| 1116 |
-
st.balloons()
|
| 1117 |
-
|
| 1118 |
-
# Update training stats
|
| 1119 |
-
if 'training_corrections' not in st.session_state:
|
| 1120 |
-
st.session_state.training_corrections = 0
|
| 1121 |
-
st.session_state.training_corrections += 1
|
| 1122 |
-
else:
|
| 1123 |
-
st.markdown('<div class="fa-error"><i class="fas fa-times-circle fa-error fa-icon"></i>Failed to save correction.</div>', unsafe_allow_html=True)
|
| 1124 |
-
except Exception as e:
|
| 1125 |
-
st.markdown(f'<div class="fa-error"><i class="fas fa-times-circle fa-error fa-icon"></i>Error: {e}</div>', unsafe_allow_html=True)
|
| 1126 |
-
else:
|
| 1127 |
-
st.markdown('<div class="fa-error"><i class="fas fa-times-circle fa-error fa-icon"></i>Could not extract object.</div>', unsafe_allow_html=True)
|
| 1128 |
-
else:
|
| 1129 |
-
st.markdown('<div class="fa-warning"><i class="fas fa-exclamation-triangle fa-warning fa-icon"></i>Please enter correct label.</div>', unsafe_allow_html=True)
|
| 1130 |
-
|
| 1131 |
-
# Training Statistics
|
| 1132 |
-
if db:
|
| 1133 |
-
training_stats = db.get_training_stats()
|
| 1134 |
-
if training_stats.get('total_corrections', 0) > 0:
|
| 1135 |
-
st.markdown("### 📊 Training Progress")
|
| 1136 |
-
|
| 1137 |
-
col1, col2, col3, col4 = st.columns(4)
|
| 1138 |
-
with col1:
|
| 1139 |
-
st.metric("Total Corrections", training_stats.get('total_corrections', 0))
|
| 1140 |
-
with col2:
|
| 1141 |
-
st.metric("Unique Classes", training_stats.get('unique_classes', 0))
|
| 1142 |
-
with col3:
|
| 1143 |
-
st.metric("Recent (7 days)", training_stats.get('recent_corrections', 0))
|
| 1144 |
-
with col4:
|
| 1145 |
-
st.metric("Avg Difficulty", f"{training_stats.get('average_difficulty', 0):.2f}")
|
| 1146 |
-
|
| 1147 |
-
# Show class distribution
|
| 1148 |
-
if training_stats.get('class_distribution'):
|
| 1149 |
-
st.markdown("**Class Distribution:**")
|
| 1150 |
-
for class_name, count in list(training_stats['class_distribution'].items())[:5]:
|
| 1151 |
-
st.text(f"• {class_name}: {count} samples")
|
| 1152 |
-
|
| 1153 |
-
# Training trigger
|
| 1154 |
-
if training_stats.get('total_corrections', 0) >= 10:
|
| 1155 |
-
if st.button("🚀 Train Custom Model", key="train_model"):
|
| 1156 |
-
with st.spinner("Training custom classifier... This may take a few minutes."):
|
| 1157 |
-
# Import training module
|
| 1158 |
-
from backend.custom_trainer import custom_trainer
|
| 1159 |
-
|
| 1160 |
-
# Get training data
|
| 1161 |
-
training_data = db.get_training_data(limit=1000)
|
| 1162 |
-
|
| 1163 |
-
if len(training_data) >= 10:
|
| 1164 |
-
# Train model
|
| 1165 |
-
result = custom_trainer.train_model(training_data, epochs=10, batch_size=8)
|
| 1166 |
-
|
| 1167 |
-
if result['success']:
|
| 1168 |
-
st.success(f"✅ Model trained successfully! Accuracy: {result['best_accuracy']:.2%}")
|
| 1169 |
-
st.info("The new model will be used for future detections.")
|
| 1170 |
-
|
| 1171 |
-
# Save model info to database (implement this method)
|
| 1172 |
-
# db.save_model_version(result)
|
| 1173 |
-
else:
|
| 1174 |
-
st.error(f"❌ Training failed: {result.get('error', 'Unknown error')}")
|
| 1175 |
-
else:
|
| 1176 |
-
st.warning("⚠️ Need at least 10 corrections to start training.")
|
| 1177 |
-
|
| 1178 |
-
# Debug information
|
| 1179 |
-
with st.expander("🔍 Debug Information"):
|
| 1180 |
-
st.text(f"Detected objects list: {detected_objects}")
|
| 1181 |
-
st.text(f"Face stats: {face_stats}")
|
| 1182 |
-
st.text(f"Face matches: {face_matches}")
|
| 1183 |
-
if detailed_attributes:
|
| 1184 |
-
st.text(f"Detailed attributes: {detailed_attributes}")
|
| 1185 |
-
|
| 1186 |
-
# AI Chat Agent Section
|
| 1187 |
-
st.markdown("---")
|
| 1188 |
-
st.markdown("## 💬 AI Chat Assistant")
|
| 1189 |
-
st.markdown("Ask questions about the detected objects or have a conversation about the image!")
|
| 1190 |
-
|
| 1191 |
-
# Chat interface
|
| 1192 |
-
chat_col1, chat_col2 = st.columns([4, 1])
|
| 1193 |
-
|
| 1194 |
-
with chat_col1:
|
| 1195 |
-
user_message = st.text_input("Your message:", key="chat_input", placeholder="Ask about colors, positions, objects...")
|
| 1196 |
-
|
| 1197 |
-
with chat_col2:
|
| 1198 |
-
col1, col2 = st.columns(2)
|
| 1199 |
-
with col1:
|
| 1200 |
-
send_button = st.button("Send 💬", key="send_chat")
|
| 1201 |
-
with col2:
|
| 1202 |
-
clear_button = st.button("Clear 🔄", key="clear_chat")
|
| 1203 |
-
|
| 1204 |
-
# Voice option for chat
|
| 1205 |
-
enable_chat_voice = st.checkbox("🔊 Enable voice responses for chat", value=True, key="chat_voice")
|
| 1206 |
-
|
| 1207 |
-
# Process chat
|
| 1208 |
-
if send_button and user_message:
|
| 1209 |
-
with st.spinner("Thinking..."):
|
| 1210 |
-
# Update chat agent with current detection context
|
| 1211 |
-
if 'last_results' in st.session_state and st.session_state.last_results:
|
| 1212 |
-
if len(st.session_state.last_results) == 7:
|
| 1213 |
-
_, _, detected_objs, _, _, _, detailed_attrs = st.session_state.last_results
|
| 1214 |
-
response, voice_file = chat_with_agent(
|
| 1215 |
-
user_message,
|
| 1216 |
-
detected_objs,
|
| 1217 |
-
detailed_attrs,
|
| 1218 |
-
enable_chat_voice
|
| 1219 |
-
)
|
| 1220 |
-
else:
|
| 1221 |
-
response, voice_file = chat_with_agent(user_message, include_voice=enable_chat_voice)
|
| 1222 |
-
else:
|
| 1223 |
-
response, voice_file = chat_with_agent(user_message, include_voice=enable_chat_voice)
|
| 1224 |
-
|
| 1225 |
-
# Add to chat history
|
| 1226 |
-
st.session_state.chat_messages.append({"role": "user", "content": user_message})
|
| 1227 |
-
st.session_state.chat_messages.append({"role": "assistant", "content": response, "voice": voice_file})
|
| 1228 |
-
|
| 1229 |
-
if clear_button:
|
| 1230 |
-
st.session_state.chat_messages = []
|
| 1231 |
-
reset_chat()
|
| 1232 |
-
st.rerun()
|
| 1233 |
-
|
| 1234 |
-
# Display chat history
|
| 1235 |
-
if st.session_state.chat_messages:
|
| 1236 |
-
st.markdown("### 💭 Conversation")
|
| 1237 |
-
for msg in st.session_state.chat_messages:
|
| 1238 |
-
if msg["role"] == "user":
|
| 1239 |
-
st.markdown(f"**You:** {msg['content']}")
|
| 1240 |
-
else:
|
| 1241 |
-
st.markdown(f"**NAVADA:** {msg['content']}")
|
| 1242 |
-
if msg.get("voice"):
|
| 1243 |
-
st.audio(msg["voice"], format="audio/mp3")
|
| 1244 |
-
|
| 1245 |
-
# Footer
|
| 1246 |
-
st.markdown("---")
|
| 1247 |
-
st.markdown("""
|
| 1248 |
-
<div style="text-align: center; padding: 2rem; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 10px; color: white; margin-top: 2rem;">
|
| 1249 |
-
<h3>🎉 Experience the Future of Computer Vision</h3>
|
| 1250 |
-
<p><strong>⭐ Built with passion and innovation by Lee Akpareva | © 2024 AI Innovation Lab ⭐</strong></p>
|
| 1251 |
-
<p>🚀 <em>From concept to deployment in 15 minutes - now with intelligent learning capabilities!</em></p>
|
| 1252 |
-
<p>🔗 <strong>Deployed on Hugging Face Spaces for seamless AI model demonstration</strong></p>
|
| 1253 |
-
</div>
|
| 1254 |
-
""", unsafe_allow_html=True)
+"""
+🚀 NAVADA 2.0 - Advanced AI Computer Vision Application (Lite Version)
+Streamlit Version for Hugging Face Spaces Deployment
+
+Enhanced Edition by Lee Akpareva | AI Consultant & Computer Vision Specialist
+"""
+
+import streamlit as st
+import time
+from datetime import datetime
+import plotly.graph_objects as go
+import plotly.express as px
+from PIL import Image
+import numpy as np
+import os
+
+# Configure Streamlit page (MUST be first!)
+st.set_page_config(
+    page_title="🚀 NAVADA 2.0 - AI Computer Vision",
+    page_icon="🚀",
+    layout="wide",
+    initial_sidebar_state="expanded"
+)
+
+# Backend imports - Lite version (no face recognition)
+try:
+    from backend.yolo import detect_objects
+    from backend.openai_client import explain_detection
+except ImportError as e:
+    st.error(f"⚠️ Import error: {e}")
+    st.error("📦 Please install dependencies: pip install -r requirements.txt")
+    st.stop()
+
+# Custom CSS for enhanced styling
+st.markdown("""
+<style>
+    .main-header {
+        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+        padding: 2rem;
+        border-radius: 10px;
+        color: white;
+        text-align: center;
+        margin-bottom: 2rem;
+    }
+
+    .feature-card {
+        background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
+        padding: 1.5rem;
+        border-radius: 10px;
+        color: white;
+        margin: 1rem 0;
+    }
+
+    .stats-card {
+        background: linear-gradient(135deg, #4facfe 0%, #00f2fe 100%);
+        padding: 1rem;
+        border-radius: 8px;
+        color: white;
+        text-align: center;
+        margin: 0.5rem;
+    }
+</style>
+""", unsafe_allow_html=True)
+
+def create_detection_chart(detected_objects):
+    """Create an interactive chart showing detection statistics"""
+
+    # Count object types
+    object_counts = {}
+    for obj in detected_objects:
+        object_counts[obj] = object_counts.get(obj, 0) + 1
+
+    if not object_counts:
+        # Create empty chart
+        fig = go.Figure()
+        fig.add_annotation(
+            text="No objects detected",
+            xref="paper", yref="paper",
+            x=0.5, y=0.5, showarrow=False,
+            font=dict(size=20, color="gray")
+        )
+        fig.update_layout(
+            height=300,
+            showlegend=False,
+            paper_bgcolor='rgba(0,0,0,0)',
+            plot_bgcolor='rgba(0,0,0,0)'
+        )
+        return fig
+
+    # Create bar chart
+    objects = list(object_counts.keys())
+    counts = list(object_counts.values())
+
+    fig = go.Figure(data=[
+        go.Bar(
+            x=objects,
+            y=counts,
+            marker_color='rgba(50, 171, 96, 0.6)',
+            marker_line_color='rgba(50, 171, 96, 1.0)',
+            marker_line_width=2,
+            text=counts,
+            textposition='auto'
+        )
+    ])
+
+    fig.update_layout(
+        title="Detected Objects",
+        xaxis_title="Object Type",
+        yaxis_title="Count",
+        height=400,
+        showlegend=False,
+        paper_bgcolor='rgba(0,0,0,0)',
+        plot_bgcolor='rgba(0,0,0,0)'
+    )
+
+    return fig
+
+def main():
+    # Main header
+    st.markdown("""
+    <div class="main-header">
+        <h1>🚀 NAVADA 2.0 - Advanced AI Computer Vision</h1>
+        <p><strong>Lite Version - Object Detection & AI Analysis</strong></p>
+        <p>Built with YOLOv8 • OpenAI • Streamlit</p>
+    </div>
+    """, unsafe_allow_html=True)
+
+    # Sidebar
+    with st.sidebar:
+        st.markdown("### 🎯 Detection Settings")
+
+        # Detection confidence threshold
+        confidence = st.slider(
+            "Detection Confidence",
+            min_value=0.1,
+            max_value=1.0,
+            value=0.5,
+            step=0.05,
+            help="Minimum confidence for object detection"
+        )
+
+        st.markdown("### 📊 Features")
+        st.markdown("""
+        - 🎯 **Object Detection**: YOLOv8 powered
+        - 🤖 **AI Explanations**: OpenAI integration
+        - 📈 **Interactive Charts**: Real-time analytics
+        - 🎨 **Visual Results**: Annotated images
+        """)
+
+        st.markdown("### ℹ️ About")
+        st.markdown("""
+        This is the **Lite Version** optimized for Hugging Face Spaces.
+
+        **Created by:** Lee Akpareva
+        **AI Consultant & Computer Vision Specialist**
+        """)
+
+    # Main content
+    col1, col2 = st.columns([2, 1])
+
+    with col1:
+        st.markdown("### 📸 Upload Image for Analysis")
+
+        uploaded_file = st.file_uploader(
+            "Choose an image...",
+            type=['png', 'jpg', 'jpeg'],
+            help="Upload an image to detect objects and get AI analysis"
+        )
+
+        if uploaded_file is not None:
+            # Display uploaded image
+            image = Image.open(uploaded_file)
+            st.image(image, caption="Uploaded Image", use_column_width=True)
+
+            # Analysis button
+            if st.button("🚀 Analyze Image", type="primary"):
+                with st.spinner("🔍 Detecting objects..."):
+                    # Perform object detection
+                    results = detect_objects(image, confidence_threshold=confidence)
+
+                    if results and len(results['detections']) > 0:
+                        # Extract detected objects
+                        detected_objects = [det['class'] for det in results['detections']]
+
+                        # Display results
+                        st.success(f"✅ Detected {len(detected_objects)} objects!")
+
+                        # Show annotated image
+                        st.markdown("### 🎯 Detection Results")
+                        if 'annotated_image' in results:
+                            st.image(results['annotated_image'], caption="Detected Objects", use_column_width=True)
+
+                        # Show detection details
+                        st.markdown("### 📋 Detected Objects")
+                        for i, detection in enumerate(results['detections']):
+                            col_a, col_b, col_c = st.columns(3)
+                            with col_a:
+                                st.metric("Object", detection['class'])
+                            with col_b:
+                                st.metric("Confidence", f"{detection['confidence']:.2%}")
+                            with col_c:
+                                st.metric("Count", f"#{i+1}")
+
+                        # AI Explanation
+                        if os.getenv("OPENAI_API_KEY"):
+                            st.markdown("### 🤖 AI Analysis")
+                            with st.spinner("🧠 Generating AI explanation..."):
+                                try:
+                                    explanation = explain_detection(detected_objects)
+                                    st.markdown(f"**AI Insight:** {explanation}")
+                                except Exception as e:
+                                    st.warning(f"AI analysis unavailable: {str(e)}")
+                        else:
+                            st.warning("🔑 Add OPENAI_API_KEY in settings for AI explanations")
+
+                    else:
+                        st.warning("❌ No objects detected. Try adjusting the confidence threshold.")
+
+    with col2:
+        st.markdown("### 📊 Detection Statistics")
+
+        # Sample chart (will be updated with real data)
+        sample_data = {
+            'Object': ['Person', 'Car', 'Dog', 'Cat'],
+            'Count': [3, 2, 1, 1]
+        }
+
+        fig = px.bar(
+            sample_data,
+            x='Object',
+            y='Count',
+            title="Sample Detection Results",
+            color='Count',
+            color_continuous_scale='Viridis'
+        )
+        fig.update_layout(height=300)
+        st.plotly_chart(fig, use_container_width=True)
+
+        # Feature highlights
+        st.markdown("### ✨ Key Features")
+
+        features = [
+            ("🎯", "Object Detection", "Advanced YOLOv8 model"),
+            ("🤖", "AI Analysis", "OpenAI explanations"),
+            ("📊", "Real-time Charts", "Interactive visualizations"),
+            ("🚀", "Fast Processing", "Optimized for speed")
+        ]
+
+        for icon, title, desc in features:
+            st.markdown(f"""
+            <div style="display: flex; align-items: center; margin: 1rem 0; padding: 0.5rem; background: #f0f2f6; border-radius: 5px;">
+                <div style="font-size: 1.5rem; margin-right: 1rem;">{icon}</div>
+                <div>
+                    <strong>{title}</strong><br>
+                    <small>{desc}</small>
+                </div>
+            </div>
+            """, unsafe_allow_html=True)
+
+    # Footer
+    st.markdown("---")
+    st.markdown("""
+    <div style="text-align: center; padding: 2rem; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 10px; color: white; margin-top: 2rem;">
+        <h3>🎉 Experience Advanced Computer Vision</h3>
+        <p><strong>⭐ Built by Lee Akpareva | AI Consultant & Computer Vision Specialist ⭐</strong></p>
+        <p>🚀 <em>Powered by YOLOv8 • OpenAI • Streamlit</em></p>
+    </div>
+    """, unsafe_allow_html=True)
+
+if __name__ == "__main__":
+    main()
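
Note: app.py only touches the detection backend through detect_objects(image, confidence_threshold=...) and expects back a dict with a 'detections' list (each entry carrying 'class' and 'confidence' keys) plus an optional 'annotated_image'. The Space's actual backend/yolo.py is not shown in this diff; the following is only a hypothetical, minimal sketch of that contract, assuming an Ultralytics YOLOv8 backend:

# Hypothetical sketch of the detect_objects() contract app.py relies on;
# the real backend/yolo.py may differ.
from ultralytics import YOLO

_model = YOLO("yolov8n.pt")  # assumed weights, for illustration only

def detect_objects(image, confidence_threshold=0.5):
    """Return {'detections': [{'class', 'confidence'}, ...], 'annotated_image': ...}."""
    result = _model(image, conf=confidence_threshold)[0]  # accepts a PIL image
    detections = [
        {"class": _model.names[int(box.cls)], "confidence": float(box.conf)}
        for box in result.boxes
    ]
    # result.plot() returns an annotated BGR numpy array; reverse channels for st.image
    return {"detections": detections, "annotated_image": result.plot()[:, :, ::-1]}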
app_lite.py
ADDED
@@ -0,0 +1,271 @@
(271 added lines; content identical to the new app.py shown above)
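
Since app.py and app_lite.py are committed as identical copies, a small check like this hypothetical helper (not part of the repo) can catch the two entrypoints drifting apart:

# sync_check.py (hypothetical) - fail fast if the duplicated entrypoints diverge
from pathlib import Path

if Path("app.py").read_text() != Path("app_lite.py").read_text():
    raise SystemExit("app.py and app_lite.py are out of sync")
print("entrypoints match")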
backend/chat_agent.py
CHANGED
@@ -1,189 +1,189 @@
(removed and re-added with identical content; the file is shown once below)
"""
AI Chat Agent with conversation memory and text-to-speech capabilities
"""
import os
from openai import OpenAI  # type: ignore
import tempfile
from datetime import datetime
import json

# Initialize OpenAI client
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    raise ValueError("OPENAI_API_KEY environment variable is required")
client = OpenAI(api_key=api_key)

class ChatAgent:
    def __init__(self):
        """Initialize the chat agent with conversation memory"""
        self.conversation_history = []
        self.system_prompt = """You are NAVADA Assistant, an intelligent AI companion for computer vision analysis.
You help users understand what's in their images, answer questions about detected objects,
and provide insights about visual content. You're friendly, helpful, and knowledgeable about
computer vision, image analysis, and can discuss colors, positions, sizes, and relationships
between objects in images. You have access to detailed detection results including object colors,
positions, sizes, and confidence scores."""

        # Add system message to history
        self.conversation_history.append({
            "role": "system",
            "content": self.system_prompt
        })

        # Store context about current image analysis
        self.current_image_context = None

    def update_image_context(self, detected_objects, detailed_attributes=None):
        """Update the agent's knowledge about the current image"""
        context = f"Current image analysis shows: {', '.join(detected_objects) if detected_objects else 'no objects detected'}."

        if detailed_attributes:
            context += "\n\nDetailed analysis:"
            for attr in detailed_attributes:
                colors = " and ".join(attr.get('colors', ['unknown'])[:2])
                context += f"\n- {attr['label']}: {colors} color(s), {attr.get('size', 'unknown')} size, located at {attr.get('position', 'unknown')} (confidence: {attr.get('confidence', 'unknown')})"

        self.current_image_context = context

        # Add context to conversation as a system message
        self.conversation_history.append({
            "role": "system",
            "content": f"Image context update: {context}"
        })

    def chat(self, user_message, include_voice=True):
        """
        Process user message and return response with optional voice

        Args:
            user_message: The user's input message
            include_voice: Whether to generate voice response

        Returns:
            tuple: (text_response, voice_file_path or None)
        """
        # Add user message to history
        self.conversation_history.append({
            "role": "user",
            "content": user_message
        })

        # Keep conversation history manageable (last 20 messages)
        if len(self.conversation_history) > 20:
            # Keep system prompt and current context, remove old messages
            system_messages = [msg for msg in self.conversation_history if msg["role"] == "system"]
            recent_messages = self.conversation_history[-15:]
            self.conversation_history = system_messages + recent_messages

        try:
            # Get response from OpenAI
            response = client.chat.completions.create(
                model="gpt-4o-mini",
                messages=self.conversation_history,
                temperature=0.7,
                max_tokens=500
            )

            text_response = response.choices[0].message.content

            # Add assistant response to history
            self.conversation_history.append({
                "role": "assistant",
                "content": text_response
            })

            # Generate voice if requested
            voice_file = None
            if include_voice:
                voice_file = self.generate_voice(text_response)

            return text_response, voice_file

        except Exception as e:
            error_msg = f"Chat error: {str(e)}"
            return error_msg, None

    def generate_voice(self, text):
        """Generate voice narration for text using OpenAI TTS"""
        try:
            # Generate speech using OpenAI TTS
            response = client.audio.speech.create(
                model="tts-1",
                voice="nova",  # Options: alloy, echo, fable, onyx, nova, shimmer
                input=text,
                response_format="mp3"
            )

            # Save to temporary file
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_audio:
                temp_audio.write(response.content)
                return temp_audio.name

        except Exception as e:
            print(f"Voice generation error: {e}")
            return None

    def get_conversation_summary(self):
        """Get a summary of the conversation"""
        messages = [msg for msg in self.conversation_history if msg["role"] in ["user", "assistant"]]
        return messages

    def reset_conversation(self):
        """Reset conversation history while keeping system prompt"""
        self.conversation_history = [{
            "role": "system",
            "content": self.system_prompt
        }]
        self.current_image_context = None

    def save_conversation(self, filepath=None):
        """Save conversation history to file"""
        if filepath is None:
            filepath = f"conversation_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"

        with open(filepath, 'w') as f:
            json.dump({
                'timestamp': datetime.now().isoformat(),
                'conversation': self.conversation_history,
                'image_context': self.current_image_context
            }, f, indent=2)

        return filepath

    def load_conversation(self, filepath):
        """Load conversation history from file"""
        with open(filepath, 'r') as f:
            data = json.load(f)
            self.conversation_history = data['conversation']
            self.current_image_context = data.get('image_context')

# Create a global chat agent instance
chat_agent = ChatAgent()

# Helper functions for easy integration
def chat_with_agent(message, detected_objects=None, detailed_attributes=None, include_voice=True):
    """
    Simple interface to chat with the agent

    Args:
        message: User's message
        detected_objects: List of detected objects (optional)
        detailed_attributes: Detailed attributes from enhanced detection (optional)
        include_voice: Whether to generate voice response

    Returns:
        tuple: (text_response, voice_file_path or None)
    """
    # Update context if new detection results provided
    if detected_objects is not None:
        chat_agent.update_image_context(detected_objects, detailed_attributes)

    return chat_agent.chat(message, include_voice)

def reset_chat():
    """Reset the chat conversation"""
    chat_agent.reset_conversation()

def get_chat_history():
    """Get the current chat history"""
    return chat_agent.get_conversation_summary()
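
For reference, a hypothetical caller of the helper functions above (requires OPENAI_API_KEY to be set; the object labels are illustrative):

from backend.chat_agent import chat_with_agent, reset_chat

# Pass fresh detections so the agent's image context is updated first
text, voice_path = chat_with_agent(
    "Which object is largest?",
    detected_objects=["car", "person"],  # illustrative labels
    include_voice=False,                 # skip TTS; returns (text, None)
)
print(text)

reset_chat()  # drop history, keep the system prompt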
backend/custom_trainer.py
CHANGED
|
@@ -1,399 +1,399 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Custom Object Classifier Training Module
|
| 3 |
-
Implements transfer learning for user feedback corrections
|
| 4 |
-
"""
|
| 5 |
-
import torch
|
| 6 |
-
import torch.nn as nn
|
| 7 |
-
import torchvision.transforms as transforms
|
| 8 |
-
from torchvision import models
|
| 9 |
-
import numpy as np
|
| 10 |
-
import cv2
|
| 11 |
-
from torch.utils.data import Dataset, DataLoader
|
| 12 |
-
# from sklearn.model_selection import train_test_split # Temporarily disabled due to numpy compatibility
|
| 13 |
-
# from sklearn.metrics import accuracy_score, precision_recall_fscore_support # Temporarily disabled
|
| 14 |
-
import pickle
|
| 15 |
-
import json
|
| 16 |
-
from typing import List, Dict, Tuple, Optional
|
| 17 |
-
from datetime import datetime
|
| 18 |
-
import logging
|
| 19 |
-
from pathlib import Path
|
| 20 |
-
|
| 21 |
-
# Configure logging
|
| 22 |
-
logger = logging.getLogger(__name__)
|
| 23 |
-
|
| 24 |
-
class CustomObjectDataset(Dataset):
|
| 25 |
-
"""Dataset class for custom object training"""
|
| 26 |
-
|
| 27 |
-
def __init__(self, data: List[Dict], transform=None):
|
| 28 |
-
"""
|
| 29 |
-
Initialize dataset with training data
|
| 30 |
-
|
| 31 |
-
Args:
|
| 32 |
-
data: List of training samples from database
|
| 33 |
-
transform: Image transformations
|
| 34 |
-
"""
|
| 35 |
-
self.data = data
|
| 36 |
-
self.transform = transform
|
| 37 |
-
|
| 38 |
-
# Create label mapping
|
| 39 |
-
unique_labels = list(set([sample['correct_label'] for sample in data]))
|
| 40 |
-
self.label_to_idx = {label: idx for idx, label in enumerate(unique_labels)}
|
| 41 |
-
self.idx_to_label = {idx: label for label, idx in self.label_to_idx.items()}
|
| 42 |
-
self.num_classes = len(unique_labels)
|
| 43 |
-
|
| 44 |
-
def __len__(self):
|
| 45 |
-
return len(self.data)
|
| 46 |
-
|
| 47 |
-
def __getitem__(self, idx):
|
| 48 |
-
sample = self.data[idx]
|
| 49 |
-
image = sample['image']
|
| 50 |
-
label = sample['correct_label']
|
| 51 |
-
|
| 52 |
-
# Convert BGR to RGB
|
| 53 |
-
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
| 54 |
-
|
| 55 |
-
if self.transform:
|
| 56 |
-
image = self.transform(image)
|
| 57 |
-
|
| 58 |
-
label_idx = self.label_to_idx[label]
|
| 59 |
-
|
| 60 |
-
return {
|
| 61 |
-
'image': image,
|
| 62 |
-
'label': label_idx,
|
| 63 |
-
'original_label': label,
|
| 64 |
-
'yolo_prediction': sample['yolo_prediction'],
|
| 65 |
-
'confidence': sample['yolo_confidence'],
|
| 66 |
-
'difficulty': sample['difficulty_score']
|
| 67 |
-
}
|
| 68 |
-
|
| 69 |
-
class CustomClassifier(nn.Module):
|
| 70 |
-
"""Custom classifier built on pre-trained backbone"""
|
| 71 |
-
|
| 72 |
-
def __init__(self, num_classes: int, backbone='resnet18', pretrained=True):
|
| 73 |
-
"""
|
| 74 |
-
Initialize custom classifier
|
| 75 |
-
|
| 76 |
-
Args:
|
| 77 |
-
num_classes: Number of output classes
|
| 78 |
-
backbone: Backbone architecture (resnet18, resnet50, efficientnet_b0)
|
| 79 |
-
pretrained: Use pre-trained weights
|
| 80 |
-
"""
|
| 81 |
-
super(CustomClassifier, self).__init__()
|
| 82 |
-
|
| 83 |
-
self.num_classes = num_classes
|
| 84 |
-
self.backbone = backbone
|
| 85 |
-
|
| 86 |
-
if backbone == 'resnet18':
|
| 87 |
-
self.model = models.resnet18(pretrained=pretrained)
|
| 88 |
-
self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
|
| 89 |
-
elif backbone == 'resnet50':
|
| 90 |
-
self.model = models.resnet50(pretrained=pretrained)
|
| 91 |
-
self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
|
| 92 |
-
elif backbone == 'efficientnet_b0':
|
| 93 |
-
self.model = models.efficientnet_b0(pretrained=pretrained)
|
| 94 |
-
self.model.classifier[1] = nn.Linear(self.model.classifier[1].in_features, num_classes)
|
| 95 |
-
else:
|
| 96 |
-
raise ValueError(f"Unsupported backbone: {backbone}")
|
| 97 |
-
|
| 98 |
-
def forward(self, x):
|
| 99 |
-
return self.model(x)
|
| 100 |
-
|
| 101 |
-
class CustomTrainer:
|
| 102 |
-
"""Trainer class for custom object classification"""
|
| 103 |
-
|
| 104 |
-
def __init__(self, model_dir='models/', device=None):
|
| 105 |
-
"""
|
| 106 |
-
Initialize trainer
|
| 107 |
-
|
| 108 |
-
Args:
|
| 109 |
-
model_dir: Directory to save models
|
| 110 |
-
device: Training device (cuda/cpu)
|
| 111 |
-
"""
|
| 112 |
-
self.model_dir = Path(model_dir)
|
| 113 |
-
self.model_dir.mkdir(exist_ok=True)
|
| 114 |
-
|
| 115 |
-
self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
|
| 116 |
-
logger.info(f"Using device: {self.device}")
|
| 117 |
-
|
| 118 |
-
# Image transformations
|
| 119 |
-
self.train_transform = transforms.Compose([
|
| 120 |
-
transforms.ToPILImage(),
|
| 121 |
-
transforms.Resize((224, 224)),
|
| 122 |
-
transforms.RandomHorizontalFlip(0.5),
|
| 123 |
-
transforms.RandomRotation(10),
|
| 124 |
-
transforms.ColorJitter(brightness=0.2, contrast=0.2),
|
| 125 |
-
transforms.ToTensor(),
|
| 126 |
-
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 127 |
-
])
|
| 128 |
-
|
| 129 |
-
self.val_transform = transforms.Compose([
|
| 130 |
-
transforms.ToPILImage(),
|
| 131 |
-
transforms.Resize((224, 224)),
|
| 132 |
-
transforms.ToTensor(),
|
| 133 |
-
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 134 |
-
])
|
| 135 |
-
|
| 136 |
-
def prepare_data(self, training_data: List[Dict], test_size=0.2, min_samples_per_class=5):
|
| 137 |
-
"""
|
| 138 |
-
Prepare training and validation data
|
| 139 |
-
|
| 140 |
-
Args:
|
| 141 |
-
training_data: List of training samples from database
|
| 142 |
-
test_size: Fraction for validation split
|
| 143 |
-
min_samples_per_class: Minimum samples required per class
|
| 144 |
-
|
| 145 |
-
Returns:
|
| 146 |
-
Tuple of (train_dataset, val_dataset, class_info)
|
| 147 |
-
"""
|
| 148 |
-
# Filter classes with insufficient samples
|
| 149 |
-
class_counts = {}
|
| 150 |
-
for sample in training_data:
|
| 151 |
-
label = sample['correct_label']
|
| 152 |
-
class_counts[label] = class_counts.get(label, 0) + 1
|
| 153 |
-
|
| 154 |
-
# Remove classes with insufficient samples
|
| 155 |
-
valid_classes = {label for label, count in class_counts.items()
|
| 156 |
-
if count >= min_samples_per_class}
|
| 157 |
-
|
| 158 |
-
filtered_data = [sample for sample in training_data
|
| 159 |
-
if sample['correct_label'] in valid_classes]
|
| 160 |
-
|
| 161 |
-
if len(filtered_data) < 10:
|
| 162 |
-
raise ValueError(f"Insufficient training data: {len(filtered_data)} samples")
|
| 163 |
-
|
| 164 |
-
if len(valid_classes) < 2:
|
| 165 |
-
raise ValueError(f"Need at least 2 classes, got {len(valid_classes)}")
|
| 166 |
-
|
| 167 |
-
# Simple train/val split without sklearn
|
| 168 |
-
np.random.seed(42)
|
| 169 |
-
indices = np.random.permutation(len(filtered_data))
|
| 170 |
-
split_idx = int(len(filtered_data) * (1 - test_size))
|
| 171 |
-
|
| 172 |
-
train_indices = indices[:split_idx]
|
| 173 |
-
val_indices = indices[split_idx:]
|
| 174 |
-
|
| 175 |
-
train_data = [filtered_data[i] for i in train_indices]
|
| 176 |
-
val_data = [filtered_data[i] for i in val_indices]
|
| 177 |
-
|
| 178 |
-
# Create datasets
|
| 179 |
-
train_dataset = CustomObjectDataset(train_data, self.train_transform)
|
| 180 |
-
val_dataset = CustomObjectDataset(val_data, self.val_transform)
|
| 181 |
-
|
| 182 |
-
# Ensure same label mapping
|
| 183 |
-
val_dataset.label_to_idx = train_dataset.label_to_idx
|
| 184 |
-
val_dataset.idx_to_label = train_dataset.idx_to_label
|
| 185 |
-
val_dataset.num_classes = train_dataset.num_classes
|
| 186 |
-
|
| 187 |
-
class_info = {
|
| 188 |
-
'num_classes': train_dataset.num_classes,
|
| 189 |
-
'label_to_idx': train_dataset.label_to_idx,
|
| 190 |
-
'idx_to_label': train_dataset.idx_to_label,
|
| 191 |
-
'class_counts': class_counts,
|
| 192 |
-
'valid_classes': list(valid_classes),
|
| 193 |
-
'train_samples': len(train_data),
|
| 194 |
-
'val_samples': len(val_data)
|
| 195 |
-
}
|
| 196 |
-
|
| 197 |
-
return train_dataset, val_dataset, class_info
|
| 198 |
-
|
| 199 |
-
def train_model(self, training_data: List[Dict],
|
| 200 |
-
epochs=20, batch_size=16, learning_rate=0.001,
|
| 201 |
-
backbone='resnet18', patience=5) -> Dict:
|
| 202 |
-
"""
|
| 203 |
-
Train custom classifier
|
| 204 |
-
|
| 205 |
-
Args:
|
| 206 |
-
training_data: Training samples from database
|
| 207 |
-
epochs: Number of training epochs
|
| 208 |
-
batch_size: Batch size for training
|
| 209 |
-
learning_rate: Learning rate
|
| 210 |
-
backbone: Model backbone architecture
|
| 211 |
-
patience: Early stopping patience
|
| 212 |
-
|
| 213 |
-
Returns:
|
| 214 |
-
Training results and metrics
|
| 215 |
-
"""
|
| 216 |
-
try:
|
| 217 |
-
# Prepare data
|
| 218 |
-
train_dataset, val_dataset, class_info = self.prepare_data(training_data)
|
| 219 |
-
|
| 220 |
-
# Create data loaders
|
| 221 |
-
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
| 222 |
-
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
|
| 223 |
-
|
| 224 |
-
# Initialize model
|
| 225 |
-
model = CustomClassifier(class_info['num_classes'], backbone)
|
| 226 |
-
model = model.to(self.device)
|
| 227 |
-
|
| 228 |
-
# Loss and optimizer
|
| 229 |
-
criterion = nn.CrossEntropyLoss()
|
| 230 |
-
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
|
| 231 |
-
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
|
| 232 |
-
|
| 233 |
-
# Training history
|
| 234 |
-
history = {
|
| 235 |
-
'train_loss': [],
|
| 236 |
-
'train_acc': [],
|
| 237 |
-
'val_loss': [],
|
| 238 |
-
'val_acc': []
|
| 239 |
-
}
|
| 240 |
-
|
| 241 |
-
best_val_acc = 0.0
|
| 242 |
-
patience_counter = 0
|
| 243 |
-
|
| 244 |
-
logger.info(f"Starting training: {epochs} epochs, {class_info['num_classes']} classes")
|
| 245 |
-
|
| 246 |
-
for epoch in range(epochs):
|
| 247 |
-
# Training phase
|
| 248 |
-
model.train()
|
| 249 |
-
train_loss = 0.0
|
| 250 |
-
train_correct = 0
|
| 251 |
-
train_total = 0
|
| 252 |
-
|
| 253 |
-
for batch in train_loader:
|
| 254 |
-
images = batch['image'].to(self.device)
|
| 255 |
-
labels = batch['label'].to(self.device)
|
| 256 |
-
|
| 257 |
-
optimizer.zero_grad()
|
| 258 |
-
outputs = model(images)
|
| 259 |
-
loss = criterion(outputs, labels)
|
| 260 |
-
loss.backward()
|
| 261 |
-
optimizer.step()
|
| 262 |
-
|
| 263 |
-
train_loss += loss.item()
|
| 264 |
-
_, predicted = torch.max(outputs.data, 1)
|
| 265 |
-
train_total += labels.size(0)
|
| 266 |
-
train_correct += (predicted == labels).sum().item()
|
| 267 |
-
|
| 268 |
-
train_acc = train_correct / train_total
|
| 269 |
-
avg_train_loss = train_loss / len(train_loader)
|
| 270 |
-
|
| 271 |
-
# Validation phase
|
| 272 |
-
model.eval()
|
| 273 |
-
val_loss = 0.0
|
| 274 |
-
val_correct = 0
|
| 275 |
-
val_total = 0
|
| 276 |
-
|
| 277 |
-
with torch.no_grad():
|
| 278 |
-
for batch in val_loader:
|
| 279 |
-
images = batch['image'].to(self.device)
|
| 280 |
-
labels = batch['label'].to(self.device)
|
| 281 |
-
|
| 282 |
-
outputs = model(images)
|
| 283 |
-
loss = criterion(outputs, labels)
|
| 284 |
-
|
| 285 |
-
val_loss += loss.item()
|
| 286 |
-
_, predicted = torch.max(outputs.data, 1)
|
| 287 |
-
val_total += labels.size(0)
|
| 288 |
-
val_correct += (predicted == labels).sum().item()
|
| 289 |
-
|
| 290 |
-
val_acc = val_correct / val_total
|
| 291 |
-
avg_val_loss = val_loss / len(val_loader)
|
| 292 |
-
|
| 293 |
-
# Update history
|
| 294 |
-
history['train_loss'].append(avg_train_loss)
|
| 295 |
-
history['train_acc'].append(train_acc)
|
| 296 |
-
history['val_loss'].append(avg_val_loss)
|
| 297 |
-
history['val_acc'].append(val_acc)
|
| 298 |
-
|
| 299 |
-
logger.info(f"Epoch {epoch+1}/{epochs}: "
|
| 300 |
-
f"Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.4f}, "
|
| 301 |
-
f"Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.4f}")
|
| 302 |
-
|
| 303 |
-
# Early stopping
|
| 304 |
-
if val_acc > best_val_acc:
|
| 305 |
-
best_val_acc = val_acc
|
| 306 |
-
patience_counter = 0
|
| 307 |
-
# Save best model
|
| 308 |
-
torch.save(model.state_dict(), self.model_dir / 'best_model.pth')
|
| 309 |
-
else:
|
| 310 |
-
patience_counter += 1
|
| 311 |
-
|
| 312 |
-
if patience_counter >= patience:
|
| 313 |
-
logger.info(f"Early stopping at epoch {epoch+1}")
|
| 314 |
-
break
|
| 315 |
-
|
| 316 |
-
scheduler.step()
|
| 317 |
-
|
| 318 |
-
# Load best model
|
| 319 |
-
model.load_state_dict(torch.load(self.model_dir / 'best_model.pth'))
|
| 320 |
-
|
| 321 |
-
# Final evaluation
|
| 322 |
-
final_metrics = self.evaluate_model(model, val_loader, class_info)
|
| 323 |
-
|
| 324 |
-
# Save model and metadata
|
| 325 |
-
model_info = {
|
| 326 |
-
'model_state': model.state_dict(),
|
| 327 |
-
'class_info': class_info,
|
| 328 |
-
'training_config': {
|
| 329 |
-
'backbone': backbone,
|
| 330 |
-
'epochs': epochs,
|
| 331 |
-
'batch_size': batch_size,
|
| 332 |
-
'learning_rate': learning_rate
|
| 333 |
-
},
|
| 334 |
-
'history': history,
|
| 335 |
-
'metrics': final_metrics,
|
| 336 |
-
'timestamp': datetime.now().isoformat()
|
| 337 |
-
}
|
| 338 |
-
|
| 339 |
-
# Save complete model info
|
| 340 |
-
model_path = self.model_dir / f'custom_classifier_{datetime.now().strftime("%Y%m%d_%H%M%S")}.pkl'
|
| 341 |
-
with open(model_path, 'wb') as f:
|
| 342 |
-
pickle.dump(model_info, f)
|
| 343 |
-
|
| 344 |
-
logger.info(f"Training completed. Best validation accuracy: {best_val_acc:.4f}")
|
| 345 |
-
logger.info(f"Model saved to: {model_path}")
|
| 346 |
-
|
| 347 |
-
return {
|
| 348 |
-
'success': True,
|
| 349 |
-
'model_path': str(model_path),
|
| 350 |
-
'best_accuracy': best_val_acc,
|
| 351 |
-
'final_metrics': final_metrics,
|
| 352 |
-
'class_info': class_info,
|
| 353 |
-
'history': history
|
| 354 |
-
}
|
| 355 |
-
|
| 356 |
-
except Exception as e:
|
| 357 |
-
logger.error(f"Training failed: {e}")
|
| 358 |
-
return {
|
| 359 |
-
'success': False,
|
| 360 |
-
'error': str(e)
|
| 361 |
-
}
|
| 362 |
-
|
| 363 |
-
def evaluate_model(self, model, val_loader, class_info) -> Dict:
|
| 364 |
-
"""Evaluate model performance"""
|
| 365 |
-
model.eval()
|
| 366 |
-
all_predictions = []
|
| 367 |
-
all_labels = []
|
| 368 |
-
all_confidences = []
|
| 369 |
-
|
| 370 |
-
with torch.no_grad():
|
| 371 |
-
for batch in val_loader:
|
| 372 |
-
images = batch['image'].to(self.device)
|
| 373 |
-
labels = batch['label']
|
| 374 |
-
|
| 375 |
-
outputs = model(images)
|
| 376 |
-
probabilities = torch.softmax(outputs, dim=1)
|
| 377 |
-
confidences, predicted = torch.max(probabilities, 1)
|
| 378 |
-
|
| 379 |
-
all_predictions.extend(predicted.cpu().numpy())
|
| 380 |
-
all_labels.extend(labels.numpy())
|
| 381 |
-
all_confidences.extend(confidences.cpu().numpy())
|
| 382 |
-
|
| 383 |
-
# Calculate metrics manually without sklearn
|
| 384 |
-
accuracy = sum(1 for true, pred in zip(all_labels, all_predictions) if true == pred) / len(all_labels)
|
| 385 |
-
|
| 386 |
-
# Simple precision/recall calculation
|
| 387 |
-
precision = recall = f1 = accuracy # Simplified for now
|
| 388 |
-
|
| 389 |
-
return {
|
| 390 |
-
'accuracy': float(accuracy),
|
| 391 |
-
'precision': float(precision),
|
| 392 |
-
'recall': float(recall),
|
| 393 |
-
'f1_score': float(f1),
|
| 394 |
-
'avg_confidence': float(np.mean(all_confidences)),
|
| 395 |
-
'num_samples': len(all_labels)
|
| 396 |
-
}
|
| 397 |
-
|
| 398 |
-
# Global trainer instance
|
| 399 |
custom_trainer = CustomTrainer()
|
|
|
|
"""
Custom Object Classifier Training Module
Implements transfer learning for user feedback corrections
"""
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
import numpy as np
import cv2
from torch.utils.data import Dataset, DataLoader
# from sklearn.model_selection import train_test_split  # Temporarily disabled due to numpy compatibility
# from sklearn.metrics import accuracy_score, precision_recall_fscore_support  # Temporarily disabled
import pickle
import json
from typing import List, Dict, Tuple, Optional
from datetime import datetime
import logging
from pathlib import Path

# Configure logging
logger = logging.getLogger(__name__)

class CustomObjectDataset(Dataset):
    """Dataset class for custom object training"""

    def __init__(self, data: List[Dict], transform=None):
        """
        Initialize dataset with training data

        Args:
            data: List of training samples from database
            transform: Image transformations
        """
        self.data = data
        self.transform = transform

        # Create label mapping
        unique_labels = list(set([sample['correct_label'] for sample in data]))
        self.label_to_idx = {label: idx for idx, label in enumerate(unique_labels)}
        self.idx_to_label = {idx: label for label, idx in self.label_to_idx.items()}
        self.num_classes = len(unique_labels)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        image = sample['image']
        label = sample['correct_label']

        # Convert BGR to RGB
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            image = self.transform(image)

        label_idx = self.label_to_idx[label]

        return {
            'image': image,
            'label': label_idx,
            'original_label': label,
            'yolo_prediction': sample['yolo_prediction'],
            'confidence': sample['yolo_confidence'],
            'difficulty': sample['difficulty_score']
        }

class CustomClassifier(nn.Module):
    """Custom classifier built on pre-trained backbone"""

    def __init__(self, num_classes: int, backbone='resnet18', pretrained=True):
        """
        Initialize custom classifier

        Args:
            num_classes: Number of output classes
            backbone: Backbone architecture (resnet18, resnet50, efficientnet_b0)
            pretrained: Use pre-trained weights
        """
        super(CustomClassifier, self).__init__()

        self.num_classes = num_classes
        self.backbone = backbone

        if backbone == 'resnet18':
            self.model = models.resnet18(pretrained=pretrained)
            self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
        elif backbone == 'resnet50':
            self.model = models.resnet50(pretrained=pretrained)
            self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
        elif backbone == 'efficientnet_b0':
            self.model = models.efficientnet_b0(pretrained=pretrained)
            self.model.classifier[1] = nn.Linear(self.model.classifier[1].in_features, num_classes)
        else:
            raise ValueError(f"Unsupported backbone: {backbone}")

    def forward(self, x):
        return self.model(x)

class CustomTrainer:
    """Trainer class for custom object classification"""

    def __init__(self, model_dir='models/', device=None):
        """
        Initialize trainer

        Args:
            model_dir: Directory to save models
            device: Training device (cuda/cpu)
        """
        self.model_dir = Path(model_dir)
        self.model_dir.mkdir(exist_ok=True)

        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        logger.info(f"Using device: {self.device}")

        # Image transformations
        self.train_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(0.5),
            transforms.RandomRotation(10),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        self.val_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def prepare_data(self, training_data: List[Dict], test_size=0.2, min_samples_per_class=5):
        """
        Prepare training and validation data

        Args:
            training_data: List of training samples from database
            test_size: Fraction for validation split
            min_samples_per_class: Minimum samples required per class

        Returns:
            Tuple of (train_dataset, val_dataset, class_info)
        """
        # Filter classes with insufficient samples
        class_counts = {}
        for sample in training_data:
            label = sample['correct_label']
            class_counts[label] = class_counts.get(label, 0) + 1

        # Remove classes with insufficient samples
        valid_classes = {label for label, count in class_counts.items()
                         if count >= min_samples_per_class}

        filtered_data = [sample for sample in training_data
                         if sample['correct_label'] in valid_classes]

        if len(filtered_data) < 10:
            raise ValueError(f"Insufficient training data: {len(filtered_data)} samples")

        if len(valid_classes) < 2:
            raise ValueError(f"Need at least 2 classes, got {len(valid_classes)}")

        # Simple train/val split without sklearn
        np.random.seed(42)
        indices = np.random.permutation(len(filtered_data))
        split_idx = int(len(filtered_data) * (1 - test_size))

        train_indices = indices[:split_idx]
        val_indices = indices[split_idx:]

        train_data = [filtered_data[i] for i in train_indices]
        val_data = [filtered_data[i] for i in val_indices]

        # Create datasets
        train_dataset = CustomObjectDataset(train_data, self.train_transform)
        val_dataset = CustomObjectDataset(val_data, self.val_transform)

        # Ensure same label mapping
        val_dataset.label_to_idx = train_dataset.label_to_idx
        val_dataset.idx_to_label = train_dataset.idx_to_label
        val_dataset.num_classes = train_dataset.num_classes

        class_info = {
            'num_classes': train_dataset.num_classes,
            'label_to_idx': train_dataset.label_to_idx,
            'idx_to_label': train_dataset.idx_to_label,
            'class_counts': class_counts,
            'valid_classes': list(valid_classes),
            'train_samples': len(train_data),
            'val_samples': len(val_data)
        }

        return train_dataset, val_dataset, class_info

    def train_model(self, training_data: List[Dict],
                    epochs=20, batch_size=16, learning_rate=0.001,
                    backbone='resnet18', patience=5) -> Dict:
        """
        Train custom classifier

        Args:
            training_data: Training samples from database
            epochs: Number of training epochs
            batch_size: Batch size for training
            learning_rate: Learning rate
            backbone: Model backbone architecture
            patience: Early stopping patience

        Returns:
            Training results and metrics
        """
        try:
            # Prepare data
            train_dataset, val_dataset, class_info = self.prepare_data(training_data)

            # Create data loaders
            train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
            val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

            # Initialize model
            model = CustomClassifier(class_info['num_classes'], backbone)
            model = model.to(self.device)

            # Loss and optimizer
            criterion = nn.CrossEntropyLoss()
            optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

            # Training history
            history = {
                'train_loss': [],
                'train_acc': [],
                'val_loss': [],
                'val_acc': []
            }

            best_val_acc = 0.0
            patience_counter = 0

            logger.info(f"Starting training: {epochs} epochs, {class_info['num_classes']} classes")

            for epoch in range(epochs):
                # Training phase
                model.train()
                train_loss = 0.0
                train_correct = 0
                train_total = 0

                for batch in train_loader:
                    images = batch['image'].to(self.device)
                    labels = batch['label'].to(self.device)

                    optimizer.zero_grad()
                    outputs = model(images)
                    loss = criterion(outputs, labels)
                    loss.backward()
                    optimizer.step()

                    train_loss += loss.item()
                    _, predicted = torch.max(outputs.data, 1)
                    train_total += labels.size(0)
                    train_correct += (predicted == labels).sum().item()

                train_acc = train_correct / train_total
                avg_train_loss = train_loss / len(train_loader)

                # Validation phase
                model.eval()
                val_loss = 0.0
                val_correct = 0
                val_total = 0

                with torch.no_grad():
                    for batch in val_loader:
                        images = batch['image'].to(self.device)
                        labels = batch['label'].to(self.device)

                        outputs = model(images)
                        loss = criterion(outputs, labels)

                        val_loss += loss.item()
                        _, predicted = torch.max(outputs.data, 1)
                        val_total += labels.size(0)
                        val_correct += (predicted == labels).sum().item()

                val_acc = val_correct / val_total
                avg_val_loss = val_loss / len(val_loader)

                # Update history
                history['train_loss'].append(avg_train_loss)
                history['train_acc'].append(train_acc)
                history['val_loss'].append(avg_val_loss)
                history['val_acc'].append(val_acc)

                logger.info(f"Epoch {epoch+1}/{epochs}: "
                            f"Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.4f}, "
                            f"Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.4f}")

                # Early stopping
                if val_acc > best_val_acc:
                    best_val_acc = val_acc
                    patience_counter = 0
                    # Save best model
                    torch.save(model.state_dict(), self.model_dir / 'best_model.pth')
                else:
                    patience_counter += 1

                if patience_counter >= patience:
                    logger.info(f"Early stopping at epoch {epoch+1}")
                    break

                scheduler.step()

            # Load best model
            model.load_state_dict(torch.load(self.model_dir / 'best_model.pth'))

            # Final evaluation
            final_metrics = self.evaluate_model(model, val_loader, class_info)

            # Save model and metadata
            model_info = {
                'model_state': model.state_dict(),
                'class_info': class_info,
                'training_config': {
                    'backbone': backbone,
                    'epochs': epochs,
                    'batch_size': batch_size,
                    'learning_rate': learning_rate
                },
                'history': history,
                'metrics': final_metrics,
                'timestamp': datetime.now().isoformat()
            }

            # Save complete model info
            model_path = self.model_dir / f'custom_classifier_{datetime.now().strftime("%Y%m%d_%H%M%S")}.pkl'
            with open(model_path, 'wb') as f:
                pickle.dump(model_info, f)

            logger.info(f"Training completed. Best validation accuracy: {best_val_acc:.4f}")
            logger.info(f"Model saved to: {model_path}")

            return {
                'success': True,
                'model_path': str(model_path),
                'best_accuracy': best_val_acc,
                'final_metrics': final_metrics,
                'class_info': class_info,
                'history': history
            }

        except Exception as e:
            logger.error(f"Training failed: {e}")
            return {
                'success': False,
                'error': str(e)
            }

    def evaluate_model(self, model, val_loader, class_info) -> Dict:
        """Evaluate model performance"""
        model.eval()
        all_predictions = []
        all_labels = []
        all_confidences = []

        with torch.no_grad():
            for batch in val_loader:
                images = batch['image'].to(self.device)
                labels = batch['label']

                outputs = model(images)
                probabilities = torch.softmax(outputs, dim=1)
                confidences, predicted = torch.max(probabilities, 1)

                all_predictions.extend(predicted.cpu().numpy())
                all_labels.extend(labels.numpy())
                all_confidences.extend(confidences.cpu().numpy())

        # Calculate metrics manually without sklearn
        accuracy = sum(1 for true, pred in zip(all_labels, all_predictions) if true == pred) / len(all_labels)

        # Simple precision/recall calculation
        precision = recall = f1 = accuracy  # Simplified for now

        return {
            'accuracy': float(accuracy),
            'precision': float(precision),
            'recall': float(recall),
            'f1_score': float(f1),
            'avg_confidence': float(np.mean(all_confidences)),
            'num_samples': len(all_labels)
        }

# Global trainer instance
custom_trainer = CustomTrainer()
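A minimal usage sketch for the trainer above (editor's illustration, not part of the commit; it assumes the backend package is importable as shown and that corrections have already been collected through the database module):

    from backend.custom_trainer import custom_trainer
    from backend.database import db

    # Pull user-corrected crops saved by the feedback workflow
    samples = db.get_training_data(limit=500)

    # Fine-tune a ResNet-18 head on the corrected labels
    result = custom_trainer.train_model(samples, epochs=10, backbone='resnet18')
    if result['success']:
        print(result['best_accuracy'], result['model_path'])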
backend/database.py
CHANGED
@@ -1,678 +1,678 @@
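A minimal usage sketch for the module listed below (editor's illustration, not part of the commit; it assumes the default SQLite path is writable and uses a blank array as a hypothetical stand-in for a detected-object crop):

    import numpy as np
    from backend.database import NAVADADatabase

    db = NAVADADatabase("navada_recognition.db")

    # Record a user correction: YOLO said "cat" at 0.41 confidence, the user says "dog"
    crop = np.zeros((64, 64, 3), dtype=np.uint8)
    db.save_correction(crop, [10, 10, 74, 74], "cat", 0.41, "dog", session_id="demo")
    print(db.get_training_stats())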
| 1 |
+
"""
|
| 2 |
+
Database Module for NAVADA - SQLite storage for faces and objects
|
| 3 |
+
Handles storage, retrieval, and management of custom recognition data
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import sqlite3
|
| 7 |
+
import numpy as np
|
| 8 |
+
import cv2
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
import json
|
| 11 |
+
import base64
|
| 12 |
+
from typing import List, Dict, Optional, Tuple
|
| 13 |
+
import logging
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
|
| 16 |
+
# Configure logging
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
class NAVADADatabase:
|
| 20 |
+
"""Database manager for storing faces, objects, and recognition data"""
|
| 21 |
+
|
| 22 |
+
def __init__(self, db_path: str = "navada_recognition.db"):
|
| 23 |
+
"""
|
| 24 |
+
Initialize database connection and create tables
|
| 25 |
+
|
| 26 |
+
Args:
|
| 27 |
+
db_path: Path to SQLite database file
|
| 28 |
+
"""
|
| 29 |
+
self.db_path = db_path
|
| 30 |
+
self.init_database()
|
| 31 |
+
|
| 32 |
+
def init_database(self):
|
| 33 |
+
"""Create database tables if they don't exist"""
|
| 34 |
+
try:
|
| 35 |
+
with sqlite3.connect(self.db_path) as conn:
|
| 36 |
+
cursor = conn.cursor()
|
| 37 |
+
|
| 38 |
+
# Create faces table
|
| 39 |
+
cursor.execute("""
|
| 40 |
+
CREATE TABLE IF NOT EXISTS faces (
|
| 41 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 42 |
+
name TEXT NOT NULL,
|
| 43 |
+
encoding BLOB NOT NULL,
|
| 44 |
+
image_data BLOB,
|
| 45 |
+
confidence REAL DEFAULT 0.0,
|
| 46 |
+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
| 47 |
+
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
| 48 |
+
metadata TEXT,
|
| 49 |
+
is_active BOOLEAN DEFAULT 1
|
| 50 |
+
)
|
| 51 |
+
""")
|
| 52 |
+
|
| 53 |
+
# Create objects table
|
| 54 |
+
cursor.execute("""
|
| 55 |
+
CREATE TABLE IF NOT EXISTS objects (
|
| 56 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 57 |
+
label TEXT NOT NULL,
|
| 58 |
+
category TEXT,
|
| 59 |
+
features BLOB,
|
| 60 |
+
image_data BLOB,
|
| 61 |
+
bounding_box TEXT,
|
| 62 |
+
confidence REAL DEFAULT 0.0,
|
| 63 |
+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
| 64 |
+
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
| 65 |
+
metadata TEXT,
|
| 66 |
+
is_active BOOLEAN DEFAULT 1
|
| 67 |
+
)
|
| 68 |
+
""")
|
| 69 |
+
|
| 70 |
+
# Create detection history table
|
| 71 |
+
cursor.execute("""
|
| 72 |
+
CREATE TABLE IF NOT EXISTS detection_history (
|
| 73 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 74 |
+
session_id TEXT,
|
| 75 |
+
image_data BLOB,
|
| 76 |
+
detections TEXT,
|
| 77 |
+
face_matches TEXT,
|
| 78 |
+
object_matches TEXT,
|
| 79 |
+
confidence_scores TEXT,
|
| 80 |
+
processing_time REAL,
|
| 81 |
+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
| 82 |
+
metadata TEXT
|
| 83 |
+
)
|
| 84 |
+
""")
|
| 85 |
+
|
| 86 |
+
# Create knowledge base for RAG
|
| 87 |
+
cursor.execute("""
|
| 88 |
+
CREATE TABLE IF NOT EXISTS knowledge_base (
|
| 89 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 90 |
+
entity_type TEXT NOT NULL,
|
| 91 |
+
entity_id INTEGER NOT NULL,
|
| 92 |
+
content TEXT NOT NULL,
|
| 93 |
+
embedding BLOB,
|
| 94 |
+
keywords TEXT,
|
| 95 |
+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
| 96 |
+
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
| 97 |
+
)
|
| 98 |
+
""")
|
| 99 |
+
|
| 100 |
+
# Create training corrections table for active learning
|
| 101 |
+
cursor.execute("""
|
| 102 |
+
CREATE TABLE IF NOT EXISTS training_corrections (
|
| 103 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 104 |
+
image_path TEXT,
|
| 105 |
+
image_crop BLOB NOT NULL,
|
| 106 |
+
bbox_coords TEXT NOT NULL,
|
| 107 |
+
yolo_prediction TEXT NOT NULL,
|
| 108 |
+
yolo_confidence REAL NOT NULL,
|
| 109 |
+
correct_label TEXT NOT NULL,
|
| 110 |
+
user_feedback TEXT,
|
| 111 |
+
difficulty_score REAL DEFAULT 0.0,
|
| 112 |
+
validated BOOLEAN DEFAULT 0,
|
| 113 |
+
used_for_training BOOLEAN DEFAULT 0,
|
| 114 |
+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
| 115 |
+
session_id TEXT,
|
| 116 |
+
metadata TEXT
|
| 117 |
+
)
|
| 118 |
+
""")
|
| 119 |
+
|
| 120 |
+
# Create custom model versions table
|
| 121 |
+
cursor.execute("""
|
| 122 |
+
CREATE TABLE IF NOT EXISTS model_versions (
|
| 123 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 124 |
+
version_name TEXT NOT NULL UNIQUE,
|
| 125 |
+
model_path TEXT NOT NULL,
|
| 126 |
+
accuracy REAL,
|
| 127 |
+
precision_score REAL,
|
| 128 |
+
recall_score REAL,
|
| 129 |
+
f1_score REAL,
|
| 130 |
+
training_samples INTEGER DEFAULT 0,
|
| 131 |
+
validation_samples INTEGER DEFAULT 0,
|
| 132 |
+
training_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
| 133 |
+
is_active BOOLEAN DEFAULT 0,
|
| 134 |
+
performance_metrics TEXT,
|
| 135 |
+
training_config TEXT,
|
| 136 |
+
notes TEXT
|
| 137 |
+
)
|
| 138 |
+
""")
|
| 139 |
+
|
| 140 |
+
# Create custom classes mapping
|
| 141 |
+
cursor.execute("""
|
| 142 |
+
CREATE TABLE IF NOT EXISTS custom_classes (
|
| 143 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 144 |
+
class_name TEXT NOT NULL UNIQUE,
|
| 145 |
+
yolo_class TEXT,
|
| 146 |
+
sample_count INTEGER DEFAULT 0,
|
| 147 |
+
confidence_threshold REAL DEFAULT 0.5,
|
| 148 |
+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
| 149 |
+
is_active BOOLEAN DEFAULT 1,
|
| 150 |
+
description TEXT
|
| 151 |
+
)
|
| 152 |
+
""")
|
| 153 |
+
|
| 154 |
+
# Create indexes for better performance
|
| 155 |
+
cursor.execute("CREATE INDEX IF NOT EXISTS idx_faces_name ON faces(name)")
|
| 156 |
+
cursor.execute("CREATE INDEX IF NOT EXISTS idx_objects_label ON objects(label)")
|
| 157 |
+
cursor.execute("CREATE INDEX IF NOT EXISTS idx_history_session ON detection_history(session_id)")
|
| 158 |
+
cursor.execute("CREATE INDEX IF NOT EXISTS idx_knowledge_entity ON knowledge_base(entity_type, entity_id)")
|
| 159 |
+
|
| 160 |
+
conn.commit()
|
| 161 |
+
logger.info("Database initialized successfully")
|
| 162 |
+
|
| 163 |
+
except Exception as e:
|
| 164 |
+
logger.error(f"Database initialization failed: {e}")
|
| 165 |
+
raise
|
| 166 |
+
|
| 167 |
+
def add_face(self, name: str, face_encoding: np.ndarray, image: np.ndarray,
|
| 168 |
+
confidence: float = 0.0, metadata: Dict = None) -> int:
|
| 169 |
+
"""
|
| 170 |
+
Add a new face to the database
|
| 171 |
+
|
| 172 |
+
Args:
|
| 173 |
+
name: Person's name
|
| 174 |
+
face_encoding: Face encoding vector
|
| 175 |
+
image: Face image array
|
| 176 |
+
confidence: Recognition confidence
|
| 177 |
+
metadata: Additional metadata
|
| 178 |
+
|
| 179 |
+
Returns:
|
| 180 |
+
Face ID in database
|
| 181 |
+
"""
|
| 182 |
+
try:
|
| 183 |
+
# Encode image to base64
|
| 184 |
+
_, buffer = cv2.imencode('.jpg', image)
|
| 185 |
+
image_data = base64.b64encode(buffer).decode('utf-8')
|
| 186 |
+
|
| 187 |
+
# Serialize face encoding
|
| 188 |
+
encoding_data = face_encoding.tobytes()
|
| 189 |
+
|
| 190 |
+
# Convert metadata to JSON
|
| 191 |
+
metadata_json = json.dumps(metadata) if metadata else None
|
| 192 |
+
|
| 193 |
+
with sqlite3.connect(self.db_path) as conn:
|
| 194 |
+
cursor = conn.cursor()
|
| 195 |
+
cursor.execute("""
|
| 196 |
+
INSERT INTO faces (name, encoding, image_data, confidence, metadata)
|
| 197 |
+
VALUES (?, ?, ?, ?, ?)
|
| 198 |
+
""", (name, encoding_data, image_data, confidence, metadata_json))
|
| 199 |
+
|
| 200 |
+
face_id = cursor.lastrowid
|
| 201 |
+
conn.commit()
|
| 202 |
+
|
| 203 |
+
# Add to knowledge base
|
| 204 |
+
self.add_knowledge_entry("face", face_id, f"Person named {name}")
|
| 205 |
+
|
| 206 |
+
logger.info(f"Added face for {name} with ID {face_id}")
|
| 207 |
+
return face_id
|
| 208 |
+
|
| 209 |
+
except Exception as e:
|
| 210 |
+
logger.error(f"Failed to add face: {e}")
|
| 211 |
+
raise
|
| 212 |
+
|
| 213 |
+
def add_object(self, label: str, category: str, features: np.ndarray,
|
| 214 |
+
image: np.ndarray, bounding_box: Tuple, confidence: float = 0.0,
|
| 215 |
+
metadata: Dict = None) -> int:
|
| 216 |
+
"""
|
| 217 |
+
Add a new custom object to the database
|
| 218 |
+
|
| 219 |
+
Args:
|
| 220 |
+
label: Object label/name
|
| 221 |
+
category: Object category
|
| 222 |
+
features: Feature vector
|
| 223 |
+
image: Object image
|
| 224 |
+
bounding_box: (x, y, w, h) bounding box
|
| 225 |
+
confidence: Detection confidence
|
| 226 |
+
metadata: Additional metadata
|
| 227 |
+
|
| 228 |
+
Returns:
|
| 229 |
+
Object ID in database
|
| 230 |
+
"""
|
| 231 |
+
try:
|
| 232 |
+
# Encode image to base64
|
| 233 |
+
_, buffer = cv2.imencode('.jpg', image)
|
| 234 |
+
image_data = base64.b64encode(buffer).decode('utf-8')
|
| 235 |
+
|
| 236 |
+
# Serialize features
|
| 237 |
+
features_data = features.tobytes() if features is not None else None
|
| 238 |
+
|
| 239 |
+
# Serialize bounding box
|
| 240 |
+
bbox_json = json.dumps(bounding_box)
|
| 241 |
+
|
| 242 |
+
# Convert metadata to JSON
|
| 243 |
+
metadata_json = json.dumps(metadata) if metadata else None
|
| 244 |
+
|
| 245 |
+
with sqlite3.connect(self.db_path) as conn:
|
| 246 |
+
cursor = conn.cursor()
|
| 247 |
+
cursor.execute("""
|
| 248 |
+
INSERT INTO objects (label, category, features, image_data,
|
| 249 |
+
bounding_box, confidence, metadata)
|
| 250 |
+
VALUES (?, ?, ?, ?, ?, ?, ?)
|
| 251 |
+
""", (label, category, features_data, image_data, bbox_json,
|
| 252 |
+
confidence, metadata_json))
|
| 253 |
+
|
| 254 |
+
object_id = cursor.lastrowid
|
| 255 |
+
conn.commit()
|
| 256 |
+
|
| 257 |
+
# Add to knowledge base
|
| 258 |
+
self.add_knowledge_entry("object", object_id,
|
| 259 |
+
f"{label} - {category} object")
|
| 260 |
+
|
| 261 |
+
logger.info(f"Added object {label} with ID {object_id}")
|
| 262 |
+
return object_id
|
| 263 |
+
|
| 264 |
+
except Exception as e:
|
| 265 |
+
logger.error(f"Failed to add object: {e}")
|
| 266 |
+
raise
|
| 267 |
+
|
| 268 |
+
def get_faces(self, active_only: bool = True) -> List[Dict]:
|
| 269 |
+
"""Get all faces from database"""
|
| 270 |
+
try:
|
| 271 |
+
with sqlite3.connect(self.db_path) as conn:
|
| 272 |
+
cursor = conn.cursor()
|
| 273 |
+
query = "SELECT * FROM faces"
|
| 274 |
+
if active_only:
|
| 275 |
+
query += " WHERE is_active = 1"
|
| 276 |
+
|
| 277 |
+
cursor.execute(query)
|
| 278 |
+
rows = cursor.fetchall()
|
| 279 |
+
|
| 280 |
+
faces = []
|
| 281 |
+
for row in rows:
|
| 282 |
+
face = {
|
| 283 |
+
'id': row[0],
|
| 284 |
+
'name': row[1],
|
| 285 |
+
'encoding': np.frombuffer(row[2], dtype=np.float64),
|
| 286 |
+
'confidence': row[4],
|
| 287 |
+
'created_at': row[5],
|
| 288 |
+
'metadata': json.loads(row[7]) if row[7] else {}
|
| 289 |
+
}
|
| 290 |
+
faces.append(face)
|
| 291 |
+
|
| 292 |
+
return faces
|
| 293 |
+
|
| 294 |
+
except Exception as e:
|
| 295 |
+
logger.error(f"Failed to get faces: {e}")
|
| 296 |
+
return []
|
| 297 |
+
|
| 298 |
+
def get_objects(self, category: str = None, active_only: bool = True) -> List[Dict]:
|
| 299 |
+
"""Get objects from database"""
|
| 300 |
+
try:
|
| 301 |
+
with sqlite3.connect(self.db_path) as conn:
|
| 302 |
+
cursor = conn.cursor()
|
| 303 |
+
query = "SELECT * FROM objects"
|
| 304 |
+
params = []
|
| 305 |
+
|
| 306 |
+
conditions = []
|
| 307 |
+
if active_only:
|
| 308 |
+
conditions.append("is_active = 1")
|
| 309 |
+
if category:
|
| 310 |
+
conditions.append("category = ?")
|
| 311 |
+
params.append(category)
|
| 312 |
+
|
| 313 |
+
if conditions:
|
| 314 |
+
query += " WHERE " + " AND ".join(conditions)
|
| 315 |
+
|
| 316 |
+
cursor.execute(query, params)
|
| 317 |
+
rows = cursor.fetchall()
|
| 318 |
+
|
| 319 |
+
objects = []
|
| 320 |
+
for row in rows:
|
| 321 |
+
obj = {
|
| 322 |
+
'id': row[0],
|
| 323 |
+
'label': row[1],
|
| 324 |
+
'category': row[2],
|
| 325 |
+
'features': np.frombuffer(row[3], dtype=np.float64) if row[3] else None,
|
| 326 |
+
'bounding_box': json.loads(row[5]) if row[5] else None,
|
| 327 |
+
'confidence': row[6],
|
| 328 |
+
'created_at': row[7],
|
| 329 |
+
'metadata': json.loads(row[9]) if row[9] else {}
|
| 330 |
+
}
|
| 331 |
+
objects.append(obj)
|
| 332 |
+
|
| 333 |
+
return objects
|
| 334 |
+
|
| 335 |
+
except Exception as e:
|
| 336 |
+
logger.error(f"Failed to get objects: {e}")
|
| 337 |
+
return []
|
| 338 |
+
|
| 339 |
+
def save_detection_history(self, session_id: str, image: np.ndarray,
|
| 340 |
+
detections: List, face_matches: List = None,
|
| 341 |
+
object_matches: List = None, confidence_scores: Dict = None,
|
| 342 |
+
processing_time: float = 0.0, metadata: Dict = None) -> int:
|
| 343 |
+
"""Save detection results to history"""
|
| 344 |
+
try:
|
| 345 |
+
# Encode image
|
| 346 |
+
_, buffer = cv2.imencode('.jpg', image)
|
| 347 |
+
image_data = base64.b64encode(buffer).decode('utf-8')
|
| 348 |
+
|
| 349 |
+
# Serialize data
|
| 350 |
+
detections_json = json.dumps(detections)
|
| 351 |
+
face_matches_json = json.dumps(face_matches) if face_matches else None
|
| 352 |
+
object_matches_json = json.dumps(object_matches) if object_matches else None
|
| 353 |
+
confidence_json = json.dumps(confidence_scores) if confidence_scores else None
|
| 354 |
+
metadata_json = json.dumps(metadata) if metadata else None
|
| 355 |
+
|
| 356 |
+
with sqlite3.connect(self.db_path) as conn:
|
| 357 |
+
cursor = conn.cursor()
|
| 358 |
+
cursor.execute("""
|
| 359 |
+
INSERT INTO detection_history
|
| 360 |
+
(session_id, image_data, detections, face_matches, object_matches,
|
| 361 |
+
confidence_scores, processing_time, metadata)
|
| 362 |
+
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
| 363 |
+
""", (session_id, image_data, detections_json, face_matches_json,
|
| 364 |
+
object_matches_json, confidence_json, processing_time, metadata_json))
|
| 365 |
+
|
| 366 |
+
history_id = cursor.lastrowid
|
| 367 |
+
conn.commit()
|
| 368 |
+
|
| 369 |
+
logger.info(f"Saved detection history with ID {history_id}")
|
| 370 |
+
return history_id
|
| 371 |
+
|
| 372 |
+
except Exception as e:
|
| 373 |
+
logger.error(f"Failed to save detection history: {e}")
|
| 374 |
+
raise
|
| 375 |
+
|
| 376 |
+
def add_knowledge_entry(self, entity_type: str, entity_id: int, content: str,
|
| 377 |
+
keywords: List[str] = None):
|
| 378 |
+
"""Add entry to knowledge base for RAG"""
|
| 379 |
+
try:
|
| 380 |
+
keywords_json = json.dumps(keywords) if keywords else None
|
| 381 |
+
|
| 382 |
+
with sqlite3.connect(self.db_path) as conn:
|
| 383 |
+
cursor = conn.cursor()
|
| 384 |
+
cursor.execute("""
|
| 385 |
+
INSERT INTO knowledge_base (entity_type, entity_id, content, keywords)
|
| 386 |
+
VALUES (?, ?, ?, ?)
|
| 387 |
+
""", (entity_type, entity_id, content, keywords_json))
|
| 388 |
+
conn.commit()
|
| 389 |
+
|
| 390 |
+
except Exception as e:
|
| 391 |
+
logger.error(f"Failed to add knowledge entry: {e}")
|
| 392 |
+
|
| 393 |
+
def search_knowledge(self, query: str, entity_type: str = None) -> List[Dict]:
|
| 394 |
+
"""Search knowledge base for RAG"""
|
| 395 |
+
try:
|
| 396 |
+
with sqlite3.connect(self.db_path) as conn:
|
| 397 |
+
cursor = conn.cursor()
|
| 398 |
+
|
| 399 |
+
# Simple text search (can be enhanced with embeddings)
|
| 400 |
+
search_query = f"%{query.lower()}%"
|
| 401 |
+
|
| 402 |
+
if entity_type:
|
| 403 |
+
cursor.execute("""
|
| 404 |
+
SELECT * FROM knowledge_base
|
| 405 |
+
WHERE entity_type = ? AND LOWER(content) LIKE ?
|
| 406 |
+
ORDER BY created_at DESC LIMIT 10
|
| 407 |
+
""", (entity_type, search_query))
|
| 408 |
+
else:
|
| 409 |
+
cursor.execute("""
|
| 410 |
+
SELECT * FROM knowledge_base
|
| 411 |
+
WHERE LOWER(content) LIKE ?
|
| 412 |
+
ORDER BY created_at DESC LIMIT 10
|
| 413 |
+
""", (search_query,))
|
| 414 |
+
|
| 415 |
+
rows = cursor.fetchall()
|
| 416 |
+
results = []
|
| 417 |
+
|
| 418 |
+
for row in rows:
|
| 419 |
+
result = {
|
| 420 |
+
'id': row[0],
|
| 421 |
+
'entity_type': row[1],
|
| 422 |
+
'entity_id': row[2],
|
| 423 |
+
'content': row[3],
|
| 424 |
+
'keywords': json.loads(row[5]) if row[5] else [],
|
| 425 |
+
'created_at': row[6]
|
| 426 |
+
}
|
| 427 |
+
results.append(result)
|
| 428 |
+
|
| 429 |
+
return results
|
| 430 |
+
|
| 431 |
+
except Exception as e:
|
| 432 |
+
logger.error(f"Knowledge search failed: {e}")
|
| 433 |
+
return []
|
| 434 |
+
|
| 435 |
+
    def get_stats(self) -> Dict:
        """Get database statistics"""
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()

                # Count faces
                cursor.execute("SELECT COUNT(*) FROM faces WHERE is_active = 1")
                face_count = cursor.fetchone()[0]

                # Count objects
                cursor.execute("SELECT COUNT(*) FROM objects WHERE is_active = 1")
                object_count = cursor.fetchone()[0]

                # Count history entries
                cursor.execute("SELECT COUNT(*) FROM detection_history")
                history_count = cursor.fetchone()[0]

                # Get recent activity
                cursor.execute("""
                    SELECT COUNT(*) FROM detection_history
                    WHERE created_at > datetime('now', '-7 days')
                """)
                recent_detections = cursor.fetchone()[0]

                return {
                    'faces': face_count,
                    'objects': object_count,
                    'total_detections': history_count,
                    'recent_detections': recent_detections,
                    'database_size': Path(self.db_path).stat().st_size if Path(self.db_path).exists() else 0
                }

        except Exception as e:
            logger.error(f"Failed to get stats: {e}")
            return {}

    # Training Corrections Methods for Active Learning

    def save_correction(self, image_crop: np.ndarray, bbox_coords: List[float],
                        yolo_prediction: str, yolo_confidence: float,
                        correct_label: str, user_feedback: str = "",
                        session_id: str = "") -> bool:
        """
        Save a user correction for training

        Args:
            image_crop: Cropped image of the detected object
            bbox_coords: [x1, y1, x2, y2] bounding box coordinates
            yolo_prediction: Original YOLO predicted label
            yolo_confidence: Original YOLO confidence score
            correct_label: User-provided correct label
            user_feedback: Optional user feedback text
            session_id: Session identifier

        Returns:
            bool: Success status
        """
        try:
            # Convert image to bytes
            _, buffer = cv2.imencode('.jpg', image_crop)
            image_bytes = buffer.tobytes()

            # Calculate difficulty score (lower confidence = higher difficulty)
            difficulty_score = 1.0 - yolo_confidence

            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()

                cursor.execute("""
                    INSERT INTO training_corrections
                    (image_crop, bbox_coords, yolo_prediction, yolo_confidence,
                     correct_label, user_feedback, difficulty_score, session_id, metadata)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
                """, (
                    image_bytes,
                    json.dumps(bbox_coords),
                    yolo_prediction,
                    yolo_confidence,
                    correct_label,
                    user_feedback,
                    difficulty_score,
                    session_id,
                    json.dumps({
                        'timestamp': datetime.now().isoformat(),
                        'image_shape': image_crop.shape,
                        'correction_type': 'user_feedback'
                    })
                ))

                # Update or create custom class entry
                cursor.execute("""
                    INSERT OR IGNORE INTO custom_classes (class_name, yolo_class, sample_count)
                    VALUES (?, ?, 0)
                """, (correct_label, yolo_prediction))

                cursor.execute("""
                    UPDATE custom_classes
                    SET sample_count = sample_count + 1
                    WHERE class_name = ?
                """, (correct_label,))

                return True

        except Exception as e:
            logger.error(f"Failed to save correction: {e}")
            return False
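    # Hedged arithmetic example: a correction saved with yolo_confidence=0.30 gets
    # difficulty_score = 1.0 - 0.30 = 0.70, so the least-confident (hardest) samples
    # sort first when get_training_data() below orders by difficulty_score DESC.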
    def get_training_data(self, class_name: str = None, limit: int = 1000,
                          validated_only: bool = False) -> List[Dict]:
        """
        Retrieve training data for model training

        Args:
            class_name: Filter by specific class (optional)
            limit: Maximum number of samples to return
            validated_only: Only return validated corrections

        Returns:
            List of training samples
        """
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()

                query = """
                    SELECT id, image_crop, bbox_coords, yolo_prediction,
                           yolo_confidence, correct_label, difficulty_score,
                           created_at, metadata
                    FROM training_corrections
                    WHERE 1=1
                """
                params = []

                if class_name:
                    query += " AND correct_label = ?"
                    params.append(class_name)

                if validated_only:
                    query += " AND validated = 1"

                query += " ORDER BY difficulty_score DESC, created_at DESC LIMIT ?"
                params.append(limit)

                cursor.execute(query, params)
                rows = cursor.fetchall()

                training_data = []
                for row in rows:
                    # Decode image
                    image_bytes = row[1]
                    image_array = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_COLOR)

                    training_data.append({
                        'id': row[0],
                        'image': image_array,
                        'bbox_coords': json.loads(row[2]),
                        'yolo_prediction': row[3],
                        'yolo_confidence': row[4],
                        'correct_label': row[5],
                        'difficulty_score': row[6],
                        'created_at': row[7],
                        'metadata': json.loads(row[8]) if row[8] else {}
                    })

                return training_data

        except Exception as e:
            logger.error(f"Failed to get training data: {e}")
            return []

    def get_training_stats(self) -> Dict:
        """Get statistics about training corrections"""
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()

                # Total corrections
                cursor.execute("SELECT COUNT(*) FROM training_corrections")
                total_corrections = cursor.fetchone()[0]

                # Corrections by class
                cursor.execute("""
                    SELECT correct_label, COUNT(*) as count
                    FROM training_corrections
                    GROUP BY correct_label
                    ORDER BY count DESC
                """)
                class_counts = dict(cursor.fetchall())

                # Validated corrections
                cursor.execute("SELECT COUNT(*) FROM training_corrections WHERE validated = 1")
                validated_count = cursor.fetchone()[0]

                # Recent corrections (last 7 days)
                cursor.execute("""
                    SELECT COUNT(*) FROM training_corrections
                    WHERE created_at > datetime('now', '-7 days')
                """)
                recent_corrections = cursor.fetchone()[0]

                # Average difficulty score
                cursor.execute("SELECT AVG(difficulty_score) FROM training_corrections")
                avg_difficulty = cursor.fetchone()[0] or 0.0

                return {
                    'total_corrections': total_corrections,
                    'validated_corrections': validated_count,
                    'recent_corrections': recent_corrections,
                    'class_distribution': class_counts,
                    'average_difficulty': round(avg_difficulty, 3),
                    'unique_classes': len(class_counts)
                }

        except Exception as e:
            logger.error(f"Failed to get training stats: {e}")
            return {}

    def mark_corrections_used(self, correction_ids: List[int]) -> bool:
        """Mark corrections as used for training"""
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()

                placeholders = ','.join(['?'] * len(correction_ids))
                cursor.execute(f"""
                    UPDATE training_corrections
                    SET used_for_training = 1
                    WHERE id IN ({placeholders})
                """, correction_ids)

                return True

        except Exception as e:
            logger.error(f"Failed to mark corrections as used: {e}")
            return False


# Global database instance
try:
    db = NAVADADatabase()
    logger.info("Database instance created successfully")
except Exception as e:
    logger.error(f"Failed to create database instance: {e}")
    db = None
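A minimal end-to-end sketch of the active-learning loop above, assuming the NAVADADatabase schema was created earlier in this file and that `db` initialized; the crop, labels, and session id below are hypothetical:

    # Hedged sketch: `crop` is a hypothetical BGR numpy array cropped from a detection.
    if db is not None:
        db.save_correction(image_crop=crop, bbox_coords=[10.0, 20.0, 110.0, 220.0],
                           yolo_prediction="dog", yolo_confidence=0.41,
                           correct_label="fox", session_id="demo-session")
        hard_samples = db.get_training_data(class_name="fox", limit=100)
        db.mark_corrections_used([s['id'] for s in hard_samples])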
backend/face_detection.py
CHANGED
@@ -1,299 +1,299 @@
"""
Face Detection Module for NAVADA
This module provides face detection capabilities using OpenCV's Haar Cascades.
It can detect faces, eyes, and smiles in images and return detailed statistics.
"""

import cv2  # OpenCV library for computer vision tasks
import numpy as np  # NumPy for numerical operations on arrays
from typing import Tuple, List, Dict, Optional, Union  # Type hints for better code documentation
import os  # Operating system interface for file path operations
import logging  # Logging module for error tracking

# Configure logging for this module
logger = logging.getLogger(__name__)


class FaceDetector:
    """
    A class to handle face detection using OpenCV's Haar Cascade classifiers.
    This detector can identify faces, eyes, and smiles in images.
    """

    def __init__(self):
        """
        Initialize the FaceDetector with pre-trained Haar Cascade classifiers.
        Loads classifiers for face, eye, and smile detection.
        """
        try:
            # Load the pre-trained Haar Cascade classifier for frontal face detection
            # This XML file contains trained patterns for detecting frontal faces
            self.face_cascade = cv2.CascadeClassifier(
                cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
            )

            # Load the classifier for eye detection
            # This works best when applied to face regions
            self.eye_cascade = cv2.CascadeClassifier(
                cv2.data.haarcascades + 'haarcascade_eye.xml'
            )

            # Load the classifier for smile detection
            # This detects smiling expressions in face regions
            self.smile_cascade = cv2.CascadeClassifier(
                cv2.data.haarcascades + 'haarcascade_smile.xml'
            )

            # Verify that classifiers loaded successfully
            if self.face_cascade.empty():
                raise ValueError("Failed to load face cascade classifier")
            if self.eye_cascade.empty():
                raise ValueError("Failed to load eye cascade classifier")
            if self.smile_cascade.empty():
                raise ValueError("Failed to load smile cascade classifier")

            logger.info("Face detection classifiers loaded successfully")

        except Exception as e:
            logger.error(f"Error initializing face detector: {str(e)}")
            raise

    def detect_faces(self, image: Optional[np.ndarray]) -> Tuple[np.ndarray, Dict]:
        """
        Detect faces in an image and return an annotated image with statistics.

        Parameters:
        -----------
        image : np.ndarray
            Input image as a NumPy array (can be grayscale or color)

        Returns:
        --------
        Tuple[np.ndarray, Dict]
            - Annotated image with face detection boxes and labels
            - Dictionary containing detection statistics and face details
        """
        # Input validation - check if image is provided and valid
        if image is None:
            logger.warning("No image provided for face detection")
            return np.zeros((480, 640, 3), dtype=np.uint8), {
                'total_faces': 0,
                'faces': [],
                'detection_method': 'Haar Cascade',
                'features_detected': {'eyes': 0, 'smiles': 0}
            }

        # Ensure image is a numpy array
        if not isinstance(image, np.ndarray):
            logger.error("Image must be a numpy array")
            raise TypeError("Image must be a numpy array")

        # Check if image is empty
        if image.size == 0:
            logger.warning("Empty image provided")
            return image, {
                'total_faces': 0,
                'faces': [],
                'detection_method': 'Haar Cascade',
                'features_detected': {'eyes': 0, 'smiles': 0}
            }

        # Convert grayscale images to RGB for consistent output
        # Check the number of dimensions to determine if image is grayscale
        if len(image.shape) == 2:  # Grayscale image (height, width)
            img_rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif len(image.shape) == 3:  # Color image (height, width, channels)
            img_rgb = image.copy()  # Create a copy to avoid modifying original
        else:
            logger.error(f"Invalid image shape: {image.shape}")
            raise ValueError(f"Invalid image shape: {image.shape}")

        # Convert to grayscale for detection algorithms
        # Haar Cascades work on grayscale images for better performance
        try:
            gray = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2GRAY)
        except cv2.error as e:
            logger.error(f"Error converting image to grayscale: {str(e)}")
            return img_rgb, {
                'total_faces': 0,
                'faces': [],
                'detection_method': 'Haar Cascade',
                'features_detected': {'eyes': 0, 'smiles': 0}
            }

        # Detect faces using the Haar Cascade classifier
        # Parameters control detection sensitivity and performance
        faces = self.face_cascade.detectMultiScale(
            gray,  # Grayscale image to search
            scaleFactor=1.1,  # Image pyramid scaling factor (1.1 = 10% reduction each level)
            minNeighbors=5,  # Minimum neighbors for detection confidence
            minSize=(30, 30)  # Minimum face size in pixels
        )

        # List to store detailed information about each detected face
        face_details = []

        # Process each detected face
        for idx, (x, y, w, h) in enumerate(faces):
            # Draw a magenta rectangle around the detected face
            # Parameters: image, top-left corner, bottom-right corner, color (RGB here), thickness
            cv2.rectangle(img_rgb, (x, y), (x+w, y+h), (255, 0, 255), 3)

            # Add a label above each face
            # Parameters: image, text, position, font, scale, color, thickness
            cv2.putText(
                img_rgb,
                f"Face {idx+1}",  # Label text
                (x, y-10),  # Position (above the rectangle)
                cv2.FONT_HERSHEY_SIMPLEX,  # Font type
                0.7,  # Font scale
                (255, 0, 255),  # Color (magenta in RGB)
                2  # Thickness
            )

            # Extract Region of Interest (ROI) for face area
            # This isolates the face region for feature detection
            roi_gray = gray[y:y+h, x:x+w]  # Grayscale ROI for detection
            roi_color = img_rgb[y:y+h, x:x+w]  # Color ROI for drawing

            # Detect eyes within the face region
            # Using different parameters for eye detection (more sensitive)
            eyes = self.eye_cascade.detectMultiScale(
                roi_gray,
                scaleFactor=1.05,  # Smaller scale factor for finer detection
                minNeighbors=3  # Fewer neighbors required
            )
            eye_count = len(eyes)  # Count the number of detected eyes

            # Draw green rectangles around detected eyes
            for (ex, ey, ew, eh) in eyes:
                cv2.rectangle(
                    roi_color,  # Draw on the color ROI
                    (ex, ey),  # Top-left corner
                    (ex+ew, ey+eh),  # Bottom-right corner
                    (0, 255, 0),  # Green color in RGB
                    2  # Thickness
                )

            # Detect smiles within the face region
            # Smile detection requires different parameters
            smiles = self.smile_cascade.detectMultiScale(
                roi_gray,
                scaleFactor=1.8,  # Larger scale factor for smile detection
                minNeighbors=20  # More neighbors required for confidence
            )
            has_smile = len(smiles) > 0  # Boolean flag for smile presence

            # Draw yellow rectangles around detected smiles
            for (sx, sy, sw, sh) in smiles:
                cv2.rectangle(
                    roi_color,  # Draw on the color ROI
                    (sx, sy),  # Top-left corner
                    (sx+sw, sy+sh),  # Bottom-right corner
                    (0, 255, 255),  # Yellow color in RGB
                    2  # Thickness
                )

            # Store detailed information about this face
            face_details.append({
                'face_id': idx + 1,  # Sequential face ID starting from 1
                'position': {  # Face bounding box coordinates
                    'x': int(x),  # X coordinate of top-left corner
                    'y': int(y),  # Y coordinate of top-left corner
                    'width': int(w),  # Width of face bounding box
                    'height': int(h)  # Height of face bounding box
                },
                'eyes_detected': eye_count,  # Number of eyes detected
                'smile_detected': has_smile,  # Whether a smile was detected
                'confidence': 0.95  # Placeholder confidence score
            })

        # Compile comprehensive statistics about all detected faces
        stats = {
            'total_faces': len(faces),  # Total number of faces detected
            'faces': face_details,  # List of detailed face information
            'detection_method': 'Haar Cascade',  # Method used for detection
            'features_detected': {  # Aggregate feature statistics
                'eyes': sum(f['eyes_detected'] for f in face_details),  # Total eyes
                'smiles': sum(1 for f in face_details if f['smile_detected'])  # Total smiles
            }
        }

        return img_rgb, stats  # Return annotated image and statistics

    def analyze_demographics(self, face_stats: Optional[Dict]) -> str:
        """
        Create a demographic analysis report based on face detection statistics.

        Parameters:
        -----------
        face_stats : Dict
            Dictionary containing face detection statistics

        Returns:
        --------
        str
            Formatted text analysis of detected faces and their features
        """
        # Handle case where no statistics are provided
        if not face_stats:
            return "No face detection data available."

        # Handle case where no faces were detected
        if face_stats.get('total_faces', 0) == 0:
            return "No faces detected in the image."

        # Build analysis report
        analysis = []  # List to accumulate analysis text

        # Add header with total face count
        analysis.append(f"👥 Detected {face_stats['total_faces']} face(s) in the image\n")

        # Add detailed information for each face
        for face in face_stats.get('faces', []):
            # Create description for individual face
            face_desc = f"\n**Face {face['face_id']}:**"

            # Add position information
            pos = face.get('position', {})
            face_desc += f"\n  • Position: ({pos.get('x', 0)}, {pos.get('y', 0)})"

            # Add size information
            face_desc += f"\n  • Size: {pos.get('width', 0)}x{pos.get('height', 0)} pixels"

            # Add eye detection information if eyes were found
            if face.get('eyes_detected', 0) > 0:
                face_desc += f"\n  • Eyes detected: {face['eyes_detected']}"

            # Add smile detection information
            if face.get('smile_detected', False):
                face_desc += "\n  • 😊 Smile detected!"

            analysis.append(face_desc)  # Add face description to analysis

        # Add summary statistics if smiles were detected
        features = face_stats.get('features_detected', {})
        smile_count = features.get('smiles', 0)

        if smile_count > 0:
            # Calculate percentage of faces that are smiling
            smile_ratio = (smile_count / face_stats['total_faces']) * 100

            # Add overall analysis section
            analysis.append("\n\n📊 **Overall Analysis:**")
            analysis.append(f"\n  • {smile_ratio:.0f}% of faces are smiling")
            analysis.append(f"\n  • Total eyes detected: {features.get('eyes', 0)}")

        # Join all analysis parts and return
        return "".join(analysis)


# Create a global instance of FaceDetector for use throughout the application
# This avoids reloading classifiers multiple times
try:
    face_detector = FaceDetector()
    logger.info("Global face detector initialized successfully")
except Exception as e:
    logger.error(f"Failed to initialize global face detector: {str(e)}")
    # Fall back to None so callers can degrade gracefully
    face_detector = None
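A minimal, hedged usage sketch for this module; the image path is hypothetical, and `face_detector` is None whenever classifier loading failed:

    import cv2
    from backend.face_detection import face_detector

    if face_detector is not None:
        bgr = cv2.imread("group_photo.jpg")         # hypothetical input file
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)  # the module draws in RGB space
        annotated, stats = face_detector.detect_faces(rgb)
        print(face_detector.analyze_demographics(stats))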
backend/openai_client.py
CHANGED
@@ -1,65 +1,65 @@
import os
from openai import OpenAI  # type: ignore
import tempfile

# Lazily initialized OpenAI client to avoid import-time errors when the
# API key isn't configured. Previously this module attempted to create the
# client on import and raised a ``ValueError`` if ``OPENAI_API_KEY`` was
# missing, which prevented the rest of the application from running (and
# broke tests that don't require the API). The client is now created only
# when needed.
_client: OpenAI | None = None


def _get_client() -> OpenAI:
    """Return a cached OpenAI client instance.

    Raises:
        ValueError: If the ``OPENAI_API_KEY`` environment variable is not set.
    """
    global _client
    if _client is None:
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY environment variable is required but not set")
        _client = OpenAI(api_key=api_key)
    return _client


def explain_detection(objects_list):
    """Send detected objects to OpenAI and return an explanation."""
    if not objects_list:
        return "No objects detected."

    prompt = f"Explain these detected objects in simple terms: {objects_list}"

    client = _get_client()
    response = client.chat.completions.create(
        model="gpt-4o-mini",  # new lightweight chat model
        messages=[{"role": "user", "content": prompt}],
    )

    return response.choices[0].message.content


def generate_voice(text):
    """Generate voice narration using OpenAI's TTS service."""
    try:
        client = _get_client()

        # Generate speech using OpenAI TTS
        response = client.audio.speech.create(
            model="tts-1",
            voice="alloy",  # You can change this to: alloy, echo, fable, onyx, nova, or shimmer
            input=text,
            response_format="mp3",
        )

        # Save the audio to a temporary file
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_audio:
            temp_audio.write(response.content)
            return temp_audio.name

    except Exception as e:
        print(f"Voice generation error: {e}")
        return None
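A minimal usage sketch, assuming OPENAI_API_KEY is exported in the environment (otherwise _get_client() raises ValueError); the detected labels below are hypothetical:

    from backend.openai_client import explain_detection, generate_voice

    summary = explain_detection(["person", "bicycle", "dog"])
    audio_path = generate_voice(summary)  # path to a temporary .mp3, or None on failure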
backend/prisma_client.py
CHANGED
@@ -1,400 +1,400 @@
"""
Prisma Client Integration for NAVADA 2.0
Provides enhanced database operations with Prisma ORM
"""

import asyncio
import json
import base64
import logging
from typing import List, Dict, Optional, Any
from datetime import datetime
import numpy as np
import cv2

logger = logging.getLogger(__name__)

class PrismaManager:
    """Enhanced database manager using Prisma ORM"""

    def __init__(self):
        self.client = None
        self._init_client()

    def _init_client(self):
        """Initialize Prisma client"""
        try:
            # Import Prisma client (needs to be generated first)
            # from prisma import Prisma
            # self.client = Prisma()
            logger.info("Prisma client initialized")
        except ImportError:
            logger.warning("Prisma client not available - run 'npm run prisma:generate'")
            self.client = None
        except Exception as e:
            logger.error(f"Failed to initialize Prisma client: {e}")
            self.client = None

    async def connect(self):
        """Connect to database"""
        if self.client:
            try:
                await self.client.connect()
                logger.info("Connected to database via Prisma")
                return True
            except Exception as e:
                logger.error(f"Failed to connect to database: {e}")
                return False
        return False

    async def disconnect(self):
        """Disconnect from database"""
        if self.client:
            try:
                await self.client.disconnect()
                logger.info("Disconnected from database")
            except Exception as e:
                logger.error(f"Error disconnecting: {e}")

    # Document Management for Knowledge Retrieval
    async def add_document(self, title: str, content: str, content_type: str = "text",
                           tags: List[str] = None, category: str = None,
                           image_data: bytes = None, image_url: str = None) -> Optional[int]:
        """
        Add document for knowledge retrieval

        Args:
            title: Document title
            content: Document content (text)
            content_type: "text", "image", "mixed"
            tags: List of tags
            category: Document category
            image_data: Binary image data
            image_url: URL to image

        Returns:
            Document ID if successful
        """
        if not self.client:
            return None

        try:
            tags_str = json.dumps(tags) if tags else None

            document = await self.client.document.create(
                data={
                    'title': title,
                    'content': content,
                    'contentType': content_type,
                    'tags': tags_str,
                    'category': category,
                    'imageData': image_data,
                    'imageUrl': image_url
                }
            )

            # Create document chunks for better retrieval
            await self._create_document_chunks(document.id, content)

            logger.info(f"Added document: {title} (ID: {document.id})")
            return document.id

        except Exception as e:
            logger.error(f"Failed to add document: {e}")
            return None

    async def _create_document_chunks(self, document_id: int, content: str, chunk_size: int = 500):
        """Create chunks from document content for better retrieval"""
        if not self.client:
            return

        try:
            # Split content into chunks
            chunks = [content[i:i+chunk_size] for i in range(0, len(content), chunk_size)]

            for i, chunk in enumerate(chunks):
                await self.client.documentchunk.create(
                    data={
                        'documentId': document_id,
                        'chunkIndex': i,
                        'content': chunk
                    }
                )

        except Exception as e:
            logger.error(f"Failed to create document chunks: {e}")
async def search_documents(self, query: str, content_type: str = None,
|
| 128 |
+
category: str = None, limit: int = 10) -> List[Dict]:
|
| 129 |
+
"""
|
| 130 |
+
Search documents by content, tags, or category
|
| 131 |
+
|
| 132 |
+
Args:
|
| 133 |
+
query: Search query
|
| 134 |
+
content_type: Filter by content type
|
| 135 |
+
category: Filter by category
|
| 136 |
+
limit: Maximum results
|
| 137 |
+
|
| 138 |
+
Returns:
|
| 139 |
+
List of matching documents
|
| 140 |
+
"""
|
| 141 |
+
if not self.client:
|
| 142 |
+
return []
|
| 143 |
+
|
| 144 |
+
try:
|
| 145 |
+
where_clause = {
|
| 146 |
+
'isActive': True,
|
| 147 |
+
'OR': [
|
| 148 |
+
{'title': {'contains': query}},
|
| 149 |
+
{'content': {'contains': query}},
|
| 150 |
+
{'tags': {'contains': query}}
|
| 151 |
+
]
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
if content_type:
|
| 155 |
+
where_clause['contentType'] = content_type
|
| 156 |
+
if category:
|
| 157 |
+
where_clause['category'] = category
|
| 158 |
+
|
| 159 |
+
documents = await self.client.document.find_many(
|
| 160 |
+
where=where_clause,
|
| 161 |
+
take=limit,
|
| 162 |
+
order_by={'createdAt': 'desc'}
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
return [self._document_to_dict(doc) for doc in documents]
|
| 166 |
+
|
| 167 |
+
except Exception as e:
|
| 168 |
+
logger.error(f"Document search failed: {e}")
|
| 169 |
+
return []
|
| 170 |
+
|
| 171 |
+
def _document_to_dict(self, document) -> Dict:
|
| 172 |
+
"""Convert Prisma document to dictionary"""
|
| 173 |
+
return {
|
| 174 |
+
'id': document.id,
|
| 175 |
+
'title': document.title,
|
| 176 |
+
'content': document.content,
|
| 177 |
+
'content_type': document.contentType,
|
| 178 |
+
'tags': json.loads(document.tags) if document.tags else [],
|
| 179 |
+
'category': document.category,
|
| 180 |
+
'image_url': document.imageUrl,
|
| 181 |
+
'created_at': document.createdAt,
|
| 182 |
+
'updated_at': document.updatedAt
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
# Media File Management
|
| 186 |
+
async def add_media_file(self, filename: str, filepath: str, mime_type: str,
|
| 187 |
+
file_size: int, image_data: bytes = None,
|
| 188 |
+
description: str = None, tags: List[str] = None) -> Optional[int]:
|
| 189 |
+
"""Add media file to database"""
|
| 190 |
+
if not self.client:
|
| 191 |
+
return None
|
| 192 |
+
|
| 193 |
+
try:
|
| 194 |
+
tags_str = json.dumps(tags) if tags else None
|
| 195 |
+
|
| 196 |
+
media_file = await self.client.mediafile.create(
|
| 197 |
+
data={
|
| 198 |
+
'filename': filename,
|
| 199 |
+
'filepath': filepath,
|
| 200 |
+
'mimeType': mime_type,
|
| 201 |
+
'fileSize': file_size,
|
| 202 |
+
'imageData': image_data,
|
| 203 |
+
'description': description,
|
| 204 |
+
'tags': tags_str
|
| 205 |
+
}
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
logger.info(f"Added media file: {filename} (ID: {media_file.id})")
|
| 209 |
+
return media_file.id
|
| 210 |
+
|
| 211 |
+
except Exception as e:
|
| 212 |
+
logger.error(f"Failed to add media file: {e}")
|
| 213 |
+
return None
|
| 214 |
+
|
| 215 |
+
async def get_media_files(self, tags: List[str] = None, mime_type: str = None,
|
| 216 |
+
limit: int = 50) -> List[Dict]:
|
| 217 |
+
"""Get media files with optional filtering"""
|
| 218 |
+
if not self.client:
|
| 219 |
+
return []
|
| 220 |
+
|
| 221 |
+
try:
|
| 222 |
+
where_clause = {'isActive': True}
|
| 223 |
+
|
| 224 |
+
if mime_type:
|
| 225 |
+
where_clause['mimeType'] = {'contains': mime_type}
|
| 226 |
+
|
| 227 |
+
if tags:
|
| 228 |
+
# Search for any of the provided tags
|
| 229 |
+
tag_conditions = [{'tags': {'contains': tag}} for tag in tags]
|
| 230 |
+
where_clause['OR'] = tag_conditions
|
| 231 |
+
|
| 232 |
+
media_files = await self.client.mediafile.find_many(
|
| 233 |
+
where=where_clause,
|
| 234 |
+
take=limit,
|
| 235 |
+
order_by={'createdAt': 'desc'}
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
return [self._media_file_to_dict(file) for file in media_files]
|
| 239 |
+
|
| 240 |
+
except Exception as e:
|
| 241 |
+
logger.error(f"Failed to get media files: {e}")
|
| 242 |
+
return []
|
| 243 |
+
|
| 244 |
+
def _media_file_to_dict(self, media_file) -> Dict:
|
| 245 |
+
"""Convert Prisma media file to dictionary"""
|
| 246 |
+
return {
|
| 247 |
+
'id': media_file.id,
|
| 248 |
+
'filename': media_file.filename,
|
| 249 |
+
'filepath': media_file.filepath,
|
| 250 |
+
'mime_type': media_file.mimeType,
|
| 251 |
+
'file_size': media_file.fileSize,
|
| 252 |
+
'description': media_file.description,
|
| 253 |
+
'tags': json.loads(media_file.tags) if media_file.tags else [],
|
| 254 |
+
'created_at': media_file.createdAt
|
| 255 |
+
}
|
| 256 |
+
|
| 257 |
+
# Enhanced Knowledge Base Operations
|
| 258 |
+
async def add_knowledge_entry(self, entity_type: str, entity_id: int, content: str,
|
| 259 |
+
title: str = None, description: str = None,
|
| 260 |
+
tags: List[str] = None, category: str = None,
|
| 261 |
+
image_url: str = None, text_content: str = None) -> Optional[int]:
|
| 262 |
+
"""Add enhanced knowledge base entry"""
|
| 263 |
+
if not self.client:
|
| 264 |
+
return None
|
| 265 |
+
|
| 266 |
+
try:
|
| 267 |
+
keywords_str = json.dumps(tags) if tags else None
|
| 268 |
+
|
| 269 |
+
knowledge_entry = await self.client.knowledgebase.create(
|
| 270 |
+
data={
|
| 271 |
+
'entityType': entity_type,
|
| 272 |
+
'entityId': entity_id,
|
| 273 |
+
'content': content,
|
| 274 |
+
'title': title,
|
| 275 |
+
'description': description,
|
| 276 |
+
'keywords': keywords_str,
|
| 277 |
+
'category': category,
|
| 278 |
+
'imageUrl': image_url,
|
| 279 |
+
'textContent': text_content
|
| 280 |
+
}
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
logger.info(f"Added knowledge entry: {title or content[:50]}")
|
| 284 |
+
return knowledge_entry.id
|
| 285 |
+
|
| 286 |
+
except Exception as e:
|
| 287 |
+
logger.error(f"Failed to add knowledge entry: {e}")
|
| 288 |
+
return None
|
| 289 |
+
|
| 290 |
+
async def search_knowledge(self, query: str, entity_type: str = None,
|
| 291 |
+
category: str = None, limit: int = 10) -> List[Dict]:
|
| 292 |
+
"""Enhanced knowledge search"""
|
| 293 |
+
if not self.client:
|
| 294 |
+
return []
|
| 295 |
+
|
| 296 |
+
try:
|
| 297 |
+
where_clause = {
|
| 298 |
+
'OR': [
|
| 299 |
+
{'content': {'contains': query}},
|
| 300 |
+
{'title': {'contains': query}},
|
| 301 |
+
{'description': {'contains': query}},
|
| 302 |
+
{'keywords': {'contains': query}},
|
| 303 |
+
{'textContent': {'contains': query}}
|
| 304 |
+
]
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
+
if entity_type:
|
| 308 |
+
where_clause['entityType'] = entity_type
|
| 309 |
+
if category:
|
| 310 |
+
where_clause['category'] = category
|
| 311 |
+
|
| 312 |
+
entries = await self.client.knowledgebase.find_many(
|
| 313 |
+
where=where_clause,
|
| 314 |
+
take=limit,
|
| 315 |
+
order_by={'createdAt': 'desc'}
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
return [self._knowledge_to_dict(entry) for entry in entries]
|
| 319 |
+
|
| 320 |
+
except Exception as e:
|
| 321 |
+
logger.error(f"Knowledge search failed: {e}")
|
| 322 |
+
return []
|
| 323 |
+
|
| 324 |
+
def _knowledge_to_dict(self, entry) -> Dict:
|
| 325 |
+
"""Convert Prisma knowledge entry to dictionary"""
|
| 326 |
+
return {
|
| 327 |
+
'id': entry.id,
|
| 328 |
+
'entity_type': entry.entityType,
|
| 329 |
+
'entity_id': entry.entityId,
|
| 330 |
+
'content': entry.content,
|
| 331 |
+
'title': entry.title,
|
| 332 |
+
'description': entry.description,
|
| 333 |
+
'keywords': json.loads(entry.keywords) if entry.keywords else [],
|
| 334 |
+
'category': entry.category,
|
| 335 |
+
'image_url': entry.imageUrl,
|
| 336 |
+
'text_content': entry.textContent,
|
| 337 |
+
'created_at': entry.createdAt,
|
| 338 |
+
'updated_at': entry.updatedAt
|
| 339 |
+
}
|
| 340 |
+
|
| 341 |
+
# Statistics and Analytics
|
| 342 |
+
async def get_enhanced_stats(self) -> Dict:
|
| 343 |
+
"""Get comprehensive database statistics"""
|
| 344 |
+
if not self.client:
|
| 345 |
+
return {}
|
| 346 |
+
|
| 347 |
+
try:
|
| 348 |
+
stats = {}
|
| 349 |
+
|
| 350 |
+
# Basic counts
|
| 351 |
+
stats['faces'] = await self.client.face.count(where={'isActive': True})
|
| 352 |
+
stats['objects'] = await self.client.object.count(where={'isActive': True})
|
| 353 |
+
stats['documents'] = await self.client.document.count(where={'isActive': True})
|
| 354 |
+
stats['media_files'] = await self.client.mediafile.count(where={'isActive': True})
|
| 355 |
+
stats['knowledge_entries'] = await self.client.knowledgebase.count()
|
| 356 |
+
stats['training_corrections'] = await self.client.trainingcorrection.count()
|
| 357 |
+
|
| 358 |
+
# Recent activity (last 7 days)
|
| 359 |
+
seven_days_ago = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
|
| 360 |
+
stats['recent_detections'] = await self.client.detectionhistory.count(
|
| 361 |
+
where={'createdAt': {'gte': seven_days_ago}}
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
return stats
|
| 365 |
+
|
| 366 |
+
except Exception as e:
|
| 367 |
+
logger.error(f"Failed to get enhanced stats: {e}")
|
| 368 |
+
return {}
|
| 369 |
+
|
| 370 |
+
# Global Prisma manager instance
|
| 371 |
+
prisma_manager = PrismaManager()
|
| 372 |
+
|
| 373 |
+
# Helper functions for async operations in Streamlit
|
| 374 |
+
def run_async(coro):
|
| 375 |
+
"""Run async function in Streamlit"""
|
| 376 |
+
try:
|
| 377 |
+
loop = asyncio.get_event_loop()
|
| 378 |
+
except RuntimeError:
|
| 379 |
+
loop = asyncio.new_event_loop()
|
| 380 |
+
asyncio.set_event_loop(loop)
|
| 381 |
+
|
| 382 |
+
return loop.run_until_complete(coro)
|
| 383 |
+
|
| 384 |
+
# Convenience functions
|
| 385 |
+
def add_document_sync(title: str, content: str, **kwargs) -> Optional[int]:
|
| 386 |
+
"""Synchronous wrapper for adding documents"""
|
| 387 |
+
return run_async(prisma_manager.add_document(title, content, **kwargs))
|
| 388 |
+
|
| 389 |
+
def search_documents_sync(query: str, **kwargs) -> List[Dict]:
|
| 390 |
+
"""Synchronous wrapper for searching documents"""
|
| 391 |
+
return run_async(prisma_manager.search_documents(query, **kwargs))
|
| 392 |
+
|
| 393 |
+
def add_media_file_sync(filename: str, filepath: str, mime_type: str,
|
| 394 |
+
file_size: int, **kwargs) -> Optional[int]:
|
| 395 |
+
"""Synchronous wrapper for adding media files"""
|
| 396 |
+
return run_async(prisma_manager.add_media_file(filename, filepath, mime_type, file_size, **kwargs))
|
| 397 |
+
|
| 398 |
+
def get_enhanced_stats_sync() -> Dict:
|
| 399 |
+
"""Synchronous wrapper for getting stats"""
|
| 400 |
return run_async(prisma_manager.get_enhanced_stats())
|
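
For reference, a minimal usage sketch of the synchronous wrappers above, assuming the Prisma client has been generated and connected; the title, content, tags, and query strings are sample values, not part of the commit:

# Hypothetical usage sketch for the sync wrappers in backend/prisma_client.py.
from backend.prisma_client import add_document_sync, search_documents_sync

doc_id = add_document_sync(
    "YOLOv8 notes",                                        # sample title
    "YOLOv8 is the single-stage detector used for the main detection pass.",
    tags=["yolo", "detection"],
    category="vision",
)
if doc_id is not None:
    # Full-text search over title, content, and tags, newest first
    for doc in search_documents_sync("yolo", limit=5):
        print(doc["id"], doc["title"], doc["tags"])

The run_async helper exists because Streamlit callbacks are synchronous while Prisma's Python client is async; each wrapper drives one coroutine to completion on a (possibly freshly created) event loop.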
backend/recognition.py
CHANGED
@@ -1,367 +1,367 @@
(The old and new versions of this file are identical; the rewritten file is shown once below.)

"""
Advanced Recognition Module for NAVADA
Handles face recognition, custom object detection, and RAG-enhanced identification
"""

import cv2
import numpy as np
from typing import List, Dict, Tuple, Optional
import logging
from .database import db
from .face_detection import face_detector
import time
import uuid

# Configure logging
logger = logging.getLogger(__name__)

class NAVADARecognition:
    """Advanced recognition system with database integration"""

    def __init__(self):
        """Initialize recognition system"""
        self.face_threshold = 0.6  # Face recognition threshold
        self.object_threshold = 0.5  # Object recognition threshold
        self.session_id = str(uuid.uuid4())

    def extract_face_encoding(self, face_image: np.ndarray) -> Optional[np.ndarray]:
        """
        Extract face encoding for recognition.
        This is a simplified version; in production, use the face_recognition library.
        """
        try:
            # Convert to grayscale and resize
            gray = cv2.cvtColor(face_image, cv2.COLOR_RGB2GRAY)
            resized = cv2.resize(gray, (128, 128))

            # Flatten and normalize as a simple encoding
            encoding = resized.flatten().astype(np.float64)
            encoding = encoding / np.linalg.norm(encoding)  # Normalize

            return encoding

        except Exception as e:
            logger.error(f"Face encoding extraction failed: {e}")
            return None

    def compare_face_encodings(self, encoding1: np.ndarray, encoding2: np.ndarray) -> float:
        """Compare two face encodings and return similarity score"""
        try:
            # Calculate cosine similarity
            similarity = np.dot(encoding1, encoding2) / (
                np.linalg.norm(encoding1) * np.linalg.norm(encoding2)
            )
            return float(similarity)

        except Exception as e:
            logger.error(f"Face comparison failed: {e}")
            return 0.0

    def recognize_faces(self, image: np.ndarray) -> Tuple[np.ndarray, List[Dict]]:
        """
        Recognize faces in image against database

        Returns:
            Annotated image and list of recognition results
        """
        try:
            if not db:
                return image, []

            # Detect faces first
            annotated_img, face_stats = face_detector.detect_faces(image)

            # Get known faces from database
            known_faces = db.get_faces()

            recognition_results = []

            if face_stats and face_stats['faces']:
                for face_info in face_stats['faces']:
                    # Extract face region
                    pos = face_info['position']
                    x, y, w, h = pos['x'], pos['y'], pos['width'], pos['height']
                    face_region = image[y:y+h, x:x+w]

                    if face_region.size > 0:
                        # Extract face encoding
                        face_encoding = self.extract_face_encoding(face_region)

                        if face_encoding is not None:
                            # Compare with known faces
                            best_match = None
                            best_similarity = 0.0

                            for known_face in known_faces:
                                similarity = self.compare_face_encodings(
                                    face_encoding, known_face['encoding']
                                )

                                if similarity > best_similarity and similarity > self.face_threshold:
                                    best_similarity = similarity
                                    best_match = known_face

                            # Add recognition result
                            if best_match:
                                # Draw name on image
                                name = best_match['name']
                                cv2.putText(annotated_img, f"{name} ({best_similarity:.2f})",
                                            (x, y-30), cv2.FONT_HERSHEY_SIMPLEX, 0.6,
                                            (0, 255, 0), 2)

                                recognition_results.append({
                                    'face_id': face_info['face_id'],
                                    'name': name,
                                    'similarity': best_similarity,
                                    'position': pos,
                                    'database_id': best_match['id']
                                })
                            else:
                                # Unknown face
                                cv2.putText(annotated_img, "Unknown",
                                            (x, y-30), cv2.FONT_HERSHEY_SIMPLEX, 0.6,
                                            (0, 0, 255), 2)

                                recognition_results.append({
                                    'face_id': face_info['face_id'],
                                    'name': 'Unknown',
                                    'similarity': 0.0,
                                    'position': pos,
                                    'database_id': None
                                })

            return annotated_img, recognition_results

        except Exception as e:
            logger.error(f"Face recognition failed: {e}")
            return image, []

    def add_new_face(self, image: np.ndarray, name: str, face_region: Tuple = None) -> bool:
        """
        Add a new face to the database

        Args:
            image: Full image containing the face
            name: Person's name
            face_region: Optional (x, y, w, h) region; if None, detect automatically

        Returns:
            Success status
        """
        try:
            if not db:
                logger.error("Database not available")
                return False

            if face_region:
                # Use provided region
                x, y, w, h = face_region
                face_img = image[y:y+h, x:x+w]
            else:
                # Detect face automatically
                _, face_stats = face_detector.detect_faces(image)

                if not face_stats or not face_stats['faces']:
                    logger.error("No face detected in image")
                    return False

                # Use first detected face
                pos = face_stats['faces'][0]['position']
                x, y, w, h = pos['x'], pos['y'], pos['width'], pos['height']
                face_img = image[y:y+h, x:x+w]

            # Extract encoding
            encoding = self.extract_face_encoding(face_img)
            if encoding is None:
                logger.error("Failed to extract face encoding")
                return False

            # Add to database
            face_id = db.add_face(
                name=name,
                face_encoding=encoding,
                image=face_img,
                confidence=0.9,
                metadata={
                    'source': 'user_added',
                    'session_id': self.session_id,
                    'timestamp': time.time()
                }
            )

            logger.info(f"Added new face '{name}' with ID {face_id}")
            return True

        except Exception as e:
            logger.error(f"Failed to add new face: {e}")
            return False

    def add_custom_object(self, image: np.ndarray, label: str, category: str,
                          bbox: Tuple = None) -> bool:
        """
        Add a custom object to the database

        Args:
            image: Full image containing the object
            label: Object label/name
            category: Object category
            bbox: Optional (x, y, w, h) bounding box

        Returns:
            Success status
        """
        try:
            if not db:
                logger.error("Database not available")
                return False

            if bbox:
                # Use provided bounding box
                x, y, w, h = bbox
                object_img = image[y:y+h, x:x+w]
            else:
                # Use entire image as object
                object_img = image
                bbox = (0, 0, image.shape[1], image.shape[0])

            # Extract simple features (can be enhanced with deep learning)
            features = self.extract_object_features(object_img)

            # Add to database
            object_id = db.add_object(
                label=label,
                category=category,
                features=features,
                image=object_img,
                bounding_box=bbox,
                confidence=0.8,
                metadata={
                    'source': 'user_added',
                    'session_id': self.session_id,
                    'timestamp': time.time()
                }
            )

            logger.info(f"Added custom object '{label}' with ID {object_id}")
            return True

        except Exception as e:
            logger.error(f"Failed to add custom object: {e}")
            return False

    def extract_object_features(self, object_img: np.ndarray) -> np.ndarray:
        """Extract features from object image (simplified implementation)"""
        try:
            # Convert to grayscale and resize
            gray = cv2.cvtColor(object_img, cv2.COLOR_RGB2GRAY)
            resized = cv2.resize(gray, (64, 64))

            # Extract histogram features
            hist = cv2.calcHist([resized], [0], None, [256], [0, 256])
            hist_normalized = hist.flatten() / hist.sum()

            # Extract edge features
            edges = cv2.Canny(resized, 50, 150)
            edge_density = edges.sum() / edges.size

            # Combine features
            features = np.concatenate([hist_normalized, [edge_density]])

            return features.astype(np.float64)

        except Exception as e:
            logger.error(f"Feature extraction failed: {e}")
            return np.array([])

    def enhance_with_rag(self, detections: List, face_matches: List = None) -> str:
        """
        Use RAG to enhance detection results with context

        Args:
            detections: List of detected objects
            face_matches: List of face recognition results

        Returns:
            Enhanced description with context
        """
        try:
            if not db:
                return "Enhanced analysis not available (database offline)"

            # Build search queries from detections
            queries = []

            # Add object queries
            for detection in detections:
                queries.append(detection)

            # Add face queries
            if face_matches:
                for match in face_matches:
                    if match['name'] != 'Unknown':
                        queries.append(match['name'])

            # Search knowledge base
            knowledge_results = []
            for query in queries:
                results = db.search_knowledge(query)
                knowledge_results.extend(results)

            # Build enhanced description
            if knowledge_results:
                enhanced_desc = "🧠 **Enhanced Analysis with Context:**\n\n"

                # Group by entity type
                face_context = [r for r in knowledge_results if r['entity_type'] == 'face']
                object_context = [r for r in knowledge_results if r['entity_type'] == 'object']

                if face_context:
                    enhanced_desc += "👥 **Known Individuals:**\n"
                    for ctx in face_context[:3]:  # Limit to 3 results
                        enhanced_desc += f"  • {ctx['content']}\n"
                    enhanced_desc += "\n"

                if object_context:
                    enhanced_desc += "🏷️ **Recognized Objects:**\n"
                    for ctx in object_context[:3]:  # Limit to 3 results
                        enhanced_desc += f"  • {ctx['content']}\n"
                    enhanced_desc += "\n"

                enhanced_desc += "📊 **Context Insights:**\n"
                enhanced_desc += f"  • Found {len(knowledge_results)} relevant knowledge entries\n"
                enhanced_desc += "  • Analysis includes both detected and learned objects\n"

                return enhanced_desc
            else:
                return "🔍 **Context Analysis:** No additional context found in knowledge base."

        except Exception as e:
            logger.error(f"RAG enhancement failed: {e}")
            return "❌ Enhanced analysis unavailable due to processing error."

    def save_session_data(self, image: np.ndarray, detections: List,
                          face_matches: List = None, processing_time: float = 0.0):
        """Save current session data to database"""
        try:
            if db:
                db.save_detection_history(
                    session_id=self.session_id,
                    image=image,
                    detections=detections,
                    face_matches=face_matches,
                    processing_time=processing_time,
                    metadata={
                        'timestamp': time.time(),
                        'version': '2.0'
                    }
                )
        except Exception as e:
            logger.error(f"Failed to save session data: {e}")

# Global recognition instance
try:
    recognition_system = NAVADARecognition()
    logger.info("Recognition system initialized successfully")
except Exception as e:
    logger.error(f"Failed to initialize recognition system: {e}")
    recognition_system = None
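
A small illustrative check of the encode-and-compare path above, assuming the NAVADA backend package is importable; the random array stands in for a real RGB face crop, so only the "identical crops score ~1.0" behaviour is meaningful here:

# Hypothetical check of the cosine-similarity matching in backend/recognition.py.
import numpy as np
from backend.recognition import NAVADARecognition

rec = NAVADARecognition()
crop = (np.random.rand(160, 160, 3) * 255).astype(np.uint8)  # stand-in for a face crop

enc_a = rec.extract_face_encoding(crop)
enc_b = rec.extract_face_encoding(crop)  # same crop, so cosine similarity is ~1.0

score = rec.compare_face_encodings(enc_a, enc_b)
print(f"similarity={score:.3f}, match={score > rec.face_threshold}")

Two crops of the same face should score near 1.0 under cosine similarity; the 0.6 face_threshold then decides whether a detection counts as a match.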
backend/two_stage_inference.py
CHANGED
|
@@ -1,285 +1,285 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Two-Stage Inference System
|
| 3 |
-
Combines YOLO detection with custom classifier for improved accuracy
|
| 4 |
-
"""
|
| 5 |
-
import torch
|
| 6 |
-
import torchvision.transforms as transforms
|
| 7 |
-
import numpy as np
|
| 8 |
-
import cv2
|
| 9 |
-
from typing import List, Dict, Tuple, Optional
|
| 10 |
-
import pickle
|
| 11 |
-
from pathlib import Path
|
| 12 |
-
import logging
|
| 13 |
-
|
| 14 |
-
from .yolo_enhanced import detect_objects_enhanced, model as yolo_model
|
| 15 |
-
from .custom_trainer import CustomClassifier
|
| 16 |
-
|
| 17 |
-
logger = logging.getLogger(__name__)
|
| 18 |
-
|
| 19 |
-
class TwoStageInference:
|
| 20 |
-
"""Two-stage detection and classification system"""
|
| 21 |
-
|
| 22 |
-
def __init__(self, models_dir='models/'):
|
| 23 |
-
"""
|
| 24 |
-
Initialize two-stage inference system
|
| 25 |
-
|
| 26 |
-
Args:
|
| 27 |
-
models_dir: Directory containing trained custom models
|
| 28 |
-
"""
|
| 29 |
-
self.models_dir = Path(models_dir)
|
| 30 |
-
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 31 |
-
|
| 32 |
-
# Load active custom model if available
|
| 33 |
-
self.custom_model = None
|
| 34 |
-
self.class_info = None
|
| 35 |
-
self.load_active_model()
|
| 36 |
-
|
| 37 |
-
# Image preprocessing for custom classifier
|
| 38 |
-
self.preprocess = transforms.Compose([
|
| 39 |
-
transforms.ToPILImage(),
|
| 40 |
-
transforms.Resize((224, 224)),
|
| 41 |
-
transforms.ToTensor(),
|
| 42 |
-
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 43 |
-
])
|
| 44 |
-
|
| 45 |
-
def load_active_model(self):
|
| 46 |
-
"""Load the most recent trained custom model"""
|
| 47 |
-
try:
|
| 48 |
-
# Find latest model file
|
| 49 |
-
model_files = list(self.models_dir.glob('custom_classifier_*.pkl'))
|
| 50 |
-
if not model_files:
|
| 51 |
-
logger.info("No custom models found. Using YOLO only.")
|
| 52 |
-
return
|
| 53 |
-
|
| 54 |
-
# Get most recent model
|
| 55 |
-
latest_model = max(model_files, key=lambda x: x.stat().st_mtime)
|
| 56 |
-
|
| 57 |
-
# Load model info
|
| 58 |
-
with open(latest_model, 'rb') as f:
|
| 59 |
-
model_info = pickle.load(f)
|
| 60 |
-
|
| 61 |
-
self.class_info = model_info['class_info']
|
| 62 |
-
|
| 63 |
-
# Initialize and load custom model
|
| 64 |
-
self.custom_model = CustomClassifier(
|
| 65 |
-
num_classes=self.class_info['num_classes'],
|
| 66 |
-
backbone=model_info['training_config']['backbone']
|
| 67 |
-
)
|
| 68 |
-
self.custom_model.load_state_dict(model_info['model_state'])
|
| 69 |
-
self.custom_model = self.custom_model.to(self.device)
|
| 70 |
-
self.custom_model.eval()
|
| 71 |
-
|
| 72 |
-
logger.info(f"Loaded custom model: {latest_model.name}")
|
| 73 |
-
logger.info(f"Custom classes: {list(self.class_info['idx_to_label'].values())}")
|
| 74 |
-
|
| 75 |
-
except Exception as e:
|
| 76 |
-
logger.error(f"Failed to load custom model: {e}")
|
| 77 |
-
self.custom_model = None
|
| 78 |
-
self.class_info = None
|
| 79 |
-
|
| 80 |
-
def classify_object(self, object_crop: np.ndarray) -> Tuple[str, float]:
|
| 81 |
-
"""
|
| 82 |
-
Classify object crop using custom model
|
| 83 |
-
|
| 84 |
-
Args:
|
| 85 |
-
object_crop: Cropped image region
|
| 86 |
-
|
| 87 |
-
Returns:
|
| 88 |
-
Tuple of (predicted_label, confidence)
|
| 89 |
-
"""
|
| 90 |
-
if self.custom_model is None:
|
| 91 |
-
return None, 0.0
|
| 92 |
-
|
| 93 |
-
try:
|
| 94 |
-
# Preprocess image
|
| 95 |
-
if object_crop.size == 0:
|
| 96 |
-
return None, 0.0
|
| 97 |
-
|
| 98 |
-
# Convert BGR to RGB
|
| 99 |
-
if len(object_crop.shape) == 3 and object_crop.shape[2] == 3:
|
| 100 |
-
object_crop = cv2.cvtColor(object_crop, cv2.COLOR_BGR2RGB)
|
| 101 |
-
|
| 102 |
-
# Preprocess for model
|
| 103 |
-
input_tensor = self.preprocess(object_crop).unsqueeze(0).to(self.device)
|
| 104 |
-
|
| 105 |
-
# Inference
|
| 106 |
-
with torch.no_grad():
|
| 107 |
-
outputs = self.custom_model(input_tensor)
|
| 108 |
-
probabilities = torch.softmax(outputs, dim=1)
|
| 109 |
-
confidence, predicted = torch.max(probabilities, 1)
|
| 110 |
-
|
| 111 |
-
predicted_idx = predicted.item()
|
| 112 |
-
confidence_score = confidence.item()
|
| 113 |
-
|
| 114 |
-
# Convert to label
|
| 115 |
-
predicted_label = self.class_info['idx_to_label'][predicted_idx]
|
| 116 |
-
|
| 117 |
-
return predicted_label, confidence_score
|
| 118 |
-
|
| 119 |
-
except Exception as e:
|
| 120 |
-
logger.error(f"Custom classification failed: {e}")
|
| 121 |
-
return None, 0.0
|
| 122 |
-
|
| 123 |
-
def should_override_yolo(self, yolo_label: str, yolo_confidence: float,
|
| 124 |
-
custom_label: str, custom_confidence: float) -> bool:
|
| 125 |
-
"""
|
| 126 |
-
Decide whether to override YOLO prediction with custom model
|
| 127 |
-
|
| 128 |
-
Args:
|
| 129 |
-
yolo_label: YOLO predicted label
|
| 130 |
-
yolo_confidence: YOLO confidence
|
| 131 |
-
custom_label: Custom model predicted label
|
| 132 |
-
custom_confidence: Custom model confidence
|
| 133 |
-
|
| 134 |
-
Returns:
|
| 135 |
-
True if should use custom model prediction
|
| 136 |
-
"""
|
| 137 |
-
# Don't override if custom model not confident enough
|
| 138 |
-
if custom_confidence < 0.7:
|
| 139 |
-
return False
|
| 140 |
-
|
| 141 |
-
# Always override if YOLO has low confidence and custom has high
|
| 142 |
-
if yolo_confidence < 0.5 and custom_confidence > 0.8:
|
| 143 |
-
return True
|
| 144 |
-
|
| 145 |
-
# Override if custom model is significantly more confident
|
| 146 |
-
if custom_confidence > yolo_confidence + 0.2:
|
| 147 |
-
return True
|
| 148 |
-
|
| 149 |
-
# Override if we have training data for this custom class
|
| 150 |
-
if custom_label in self.class_info.get('valid_classes', []):
|
| 151 |
-
return True
|
| 152 |
-
|
| 153 |
-
return False
|
| 154 |
-
|
| 155 |
-
def detect_with_custom_model(self, image: np.ndarray, confidence_threshold: float = 0.5) -> Tuple[np.ndarray, List[str], List[Dict]]:
|
| 156 |
-
"""
|
| 157 |
-
Two-stage detection: YOLO + Custom Classification
|
| 158 |
-
|
| 159 |
-
Args:
|
| 160 |
-
image: Input image
|
| 161 |
-
confidence_threshold: YOLO confidence threshold
|
| 162 |
-
|
| 163 |
-
Returns:
|
| 164 |
-
Tuple of (annotated_image, detected_objects, detailed_attributes)
|
| 165 |
-
"""
|
| 166 |
-
# Stage 1: YOLO Detection
|
| 167 |
-
try:
|
| 168 |
-
annotated_img, detected_objects, detailed_attributes = detect_objects_enhanced(
|
| 169 |
-
image, confidence_threshold
|
| 170 |
-
)
|
| 171 |
-
except:
|
| 172 |
-
# Fallback to basic YOLO
|
| 173 |
-
from .yolo import detect_objects
|
| 174 |
-
annotated_img, detected_objects = detect_objects(image)
|
| 175 |
-
detailed_attributes = []
|
| 176 |
-
|
| 177 |
-
# Stage 2: Custom Classification (if model available)
|
| 178 |
-
if self.custom_model is None or not detailed_attributes:
|
| 179 |
-
return annotated_img, detected_objects, detailed_attributes
|
| 180 |
-
|
| 181 |
-
# Process each detection with custom model
|
| 182 |
-
enhanced_attributes = []
|
| 183 |
-
enhanced_objects = []
|
| 184 |
-
|
| 185 |
-
for i, attr in enumerate(detailed_attributes):
|
| 186 |
-
yolo_label = attr['label']
|
| 187 |
-
yolo_confidence = float(attr['confidence'].rstrip('%')) / 100.0
|
| 188 |
-
bbox = attr.get('bbox', [0, 0, 100, 100])
|
| 189 |
-
|
| 190 |
-
# Extract object region
|
| 191 |
-
x1, y1, x2, y2 = [int(coord) for coord in bbox]
|
| 192 |
-
object_crop = image[max(0, y1):min(image.shape[0], y2),
|
| 193 |
-
max(0, x1):min(image.shape[1], x2)]
|
| 194 |
-
|
| 195 |
-
# Classify with custom model
|
| 196 |
-
custom_label, custom_confidence = self.classify_object(object_crop)
|
| 197 |
-
|
| 198 |
-
# Decide which prediction to use
|
| 199 |
-
if custom_label and self.should_override_yolo(yolo_label, yolo_confidence,
|
| 200 |
-
custom_label, custom_confidence):
|
| 201 |
-
# Use custom model prediction
|
| 202 |
-
final_label = custom_label
|
| 203 |
-
final_confidence = custom_confidence
|
| 204 |
-
attr['prediction_source'] = 'custom_model'
|
| 205 |
-
attr['original_yolo'] = {'label': yolo_label, 'confidence': yolo_confidence}
|
| 206 |
-
else:
|
| 207 |
-
# Use YOLO prediction
|
| 208 |
-
final_label = yolo_label
|
| 209 |
-
final_confidence = yolo_confidence
|
| 210 |
-
attr['prediction_source'] = 'yolo'
|
| 211 |
-
if custom_label:
|
| 212 |
-
attr['custom_alternative'] = {'label': custom_label, 'confidence': custom_confidence}
|
| 213 |
-
|
| 214 |
-
# Update attributes
|
| 215 |
-
attr['label'] = final_label
|
| 216 |
-
attr['confidence'] = f"{final_confidence:.2%}"
|
| 217 |
-
|
| 218 |
-
enhanced_attributes.append(attr)
|
| 219 |
-
enhanced_objects.append(final_label)
|
| 220 |
-
|
| 221 |
-
# Update annotated image if we made changes
|
| 222 |
-
if any(attr.get('prediction_source') == 'custom_model' for attr in enhanced_attributes):
|
| 223 |
-
# Re-annotate image with updated predictions
|
| 224 |
-
annotated_img = self.annotate_image_with_predictions(image, enhanced_attributes)
|
| 225 |
-
|
| 226 |
-
return annotated_img, enhanced_objects, enhanced_attributes
|
| 227 |
-
|
| 228 |
-
def annotate_image_with_predictions(self, image: np.ndarray, attributes: List[Dict]) -> np.ndarray:
|
| 229 |
-
"""
|
| 230 |
-
Annotate image with updated predictions
|
| 231 |
-
|
| 232 |
-
Args:
|
| 233 |
-
image: Original image
|
| 234 |
-
attributes: Detection attributes with updated labels
|
| 235 |
-
|
| 236 |
-
Returns:
|
| 237 |
-
Annotated image
|
| 238 |
-
"""
|
| 239 |
-
annotated = image.copy()
|
| 240 |
-
|
| 241 |
-
for attr in attributes:
|
| 242 |
-
bbox = attr.get('bbox', [0, 0, 100, 100])
|
| 243 |
-
label = attr['label']
|
| 244 |
-
confidence = attr['confidence']
|
| 245 |
-
source = attr.get('prediction_source', 'yolo')
|
| 246 |
-
|
| 247 |
-
x1, y1, x2, y2 = [int(coord) for coord in bbox]
|
| 248 |
-
|
| 249 |
-
# Choose color based on source
|
| 250 |
-
if source == 'custom_model':
|
| 251 |
-
color = (0, 255, 0) # Green for custom model
|
| 252 |
-
label_text = f"{label} {confidence} (Custom)"
|
| 253 |
-
else:
|
| 254 |
-
color = (255, 0, 0) # Red for YOLO
|
| 255 |
-
label_text = f"{label} {confidence}"
|
| 256 |
-
|
| 257 |
-
# Draw bounding box
|
| 258 |
-
cv2.rectangle(annotated, (x1, y1), (x2, y2), color, 2)
|
| 259 |
-
|
| 260 |
-
# Draw label
|
| 261 |
-
cv2.putText(annotated, label_text, (x1, y1-10),
|
| 262 |
-
cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
|
| 263 |
-
|
| 264 |
-
return annotated
|
| 265 |
-
|
| 266 |
-
def get_model_info(self) -> Dict:
|
| 267 |
-
"""Get information about loaded models"""
|
| 268 |
-
info = {
|
| 269 |
-
'yolo_model': 'YOLOv8m',
|
| 270 |
-
'custom_model_loaded': self.custom_model is not None,
|
| 271 |
-
'device': self.device
|
| 272 |
-
}
|
| 273 |
-
|
| 274 |
-
if self.custom_model is not None and self.class_info is not None:
|
| 275 |
-
info.update({
|
| 276 |
-
'custom_classes': list(self.class_info['idx_to_label'].values()),
|
| 277 |
-
'num_custom_classes': self.class_info['num_classes'],
|
| 278 |
-
'training_samples': self.class_info.get('train_samples', 0),
|
| 279 |
-
'validation_samples': self.class_info.get('val_samples', 0)
|
| 280 |
-
})
|
| 281 |
-
|
| 282 |
-
return info
|
| 283 |
-
|
| 284 |
-
# Global two-stage inference instance
|
| 285 |
two_stage_inference = TwoStageInference()
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Two-Stage Inference System
|
| 3 |
+
Combines YOLO detection with custom classifier for improved accuracy
|
| 4 |
+
"""
|
| 5 |
+
import torch
|
| 6 |
+
import torchvision.transforms as transforms
|
| 7 |
+
import numpy as np
|
| 8 |
+
import cv2
|
| 9 |
+
from typing import List, Dict, Tuple, Optional
|
| 10 |
+
import pickle
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
import logging
|
| 13 |
+
|
| 14 |
+
from .yolo_enhanced import detect_objects_enhanced, model as yolo_model
|
| 15 |
+
from .custom_trainer import CustomClassifier
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
class TwoStageInference:
|
| 20 |
+
"""Two-stage detection and classification system"""
|
| 21 |
+
|
| 22 |
+
def __init__(self, models_dir='models/'):
|
| 23 |
+
"""
|
| 24 |
+
Initialize two-stage inference system
|
| 25 |
+
|
| 26 |
+
Args:
|
| 27 |
+
models_dir: Directory containing trained custom models
|
| 28 |
+
"""
|
| 29 |
+
self.models_dir = Path(models_dir)
|
| 30 |
+
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 31 |
+
|
| 32 |
+
# Load active custom model if available
|
| 33 |
+
self.custom_model = None
|
| 34 |
+
self.class_info = None
|
| 35 |
+
self.load_active_model()
|
| 36 |
+
|
| 37 |
+
# Image preprocessing for custom classifier
|
| 38 |
+
self.preprocess = transforms.Compose([
|
| 39 |
+
transforms.ToPILImage(),
|
| 40 |
+
transforms.Resize((224, 224)),
|
| 41 |
+
transforms.ToTensor(),
|
| 42 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 43 |
+
])
|
| 44 |
+
|
| 45 |
+
def load_active_model(self):
|
| 46 |
+
"""Load the most recent trained custom model"""
|
| 47 |
+
try:
|
| 48 |
+
# Find latest model file
|
| 49 |
+
model_files = list(self.models_dir.glob('custom_classifier_*.pkl'))
|
| 50 |
+
if not model_files:
|
| 51 |
+
logger.info("No custom models found. Using YOLO only.")
|
| 52 |
+
return
|
| 53 |
+
|
| 54 |
+
# Get most recent model
|
| 55 |
+
latest_model = max(model_files, key=lambda x: x.stat().st_mtime)
|
| 56 |
+
|
| 57 |
+
# Load model info
|
| 58 |
+
with open(latest_model, 'rb') as f:
|
| 59 |
+
model_info = pickle.load(f)
|
| 60 |
+
|
| 61 |
+
self.class_info = model_info['class_info']
|
| 62 |
+
|
| 63 |
+
# Initialize and load custom model
|
| 64 |
+
self.custom_model = CustomClassifier(
|
| 65 |
+
num_classes=self.class_info['num_classes'],
|
| 66 |
+
backbone=model_info['training_config']['backbone']
|
| 67 |
+
)
|
| 68 |
+
self.custom_model.load_state_dict(model_info['model_state'])
|
| 69 |
+
self.custom_model = self.custom_model.to(self.device)
|
| 70 |
+
self.custom_model.eval()
|
| 71 |
+
|
| 72 |
+
logger.info(f"Loaded custom model: {latest_model.name}")
|
| 73 |
+
logger.info(f"Custom classes: {list(self.class_info['idx_to_label'].values())}")
|
| 74 |
+
|
| 75 |
+
except Exception as e:
|
| 76 |
+
logger.error(f"Failed to load custom model: {e}")
|
| 77 |
+
self.custom_model = None
|
| 78 |
+
self.class_info = None
|
| 79 |
+
|
| 80 |
+
def classify_object(self, object_crop: np.ndarray) -> Tuple[str, float]:
|
| 81 |
+
"""
|
| 82 |
+
Classify object crop using custom model
|
| 83 |
+
|
| 84 |
+
Args:
|
| 85 |
+
object_crop: Cropped image region
|
| 86 |
+
|
| 87 |
+
Returns:
|
| 88 |
+
Tuple of (predicted_label, confidence)
|
| 89 |
+
"""
|
| 90 |
+
if self.custom_model is None:
|
| 91 |
+
return None, 0.0
|
| 92 |
+
|
| 93 |
+
try:
|
| 94 |
+
# Preprocess image
|
| 95 |
+
if object_crop.size == 0:
|
| 96 |
+
return None, 0.0
|
| 97 |
+
|
| 98 |
+
# Convert BGR to RGB
|
| 99 |
+
if len(object_crop.shape) == 3 and object_crop.shape[2] == 3:
|
| 100 |
+
object_crop = cv2.cvtColor(object_crop, cv2.COLOR_BGR2RGB)
|
| 101 |
+
|
| 102 |
+
# Preprocess for model
|
| 103 |
+
input_tensor = self.preprocess(object_crop).unsqueeze(0).to(self.device)
|
| 104 |
+
|
| 105 |
+
# Inference
|
| 106 |
+
with torch.no_grad():
|
| 107 |
+
outputs = self.custom_model(input_tensor)
|
| 108 |
+
probabilities = torch.softmax(outputs, dim=1)
|
| 109 |
+
confidence, predicted = torch.max(probabilities, 1)
|
| 110 |
+
|
| 111 |
+
predicted_idx = predicted.item()
|
| 112 |
+
confidence_score = confidence.item()
|
| 113 |
+
|
| 114 |
+
# Convert to label
|
| 115 |
+
predicted_label = self.class_info['idx_to_label'][predicted_idx]
|
| 116 |
+
|
| 117 |
+
return predicted_label, confidence_score
|
| 118 |
+
|
| 119 |
+
except Exception as e:
|
| 120 |
+
logger.error(f"Custom classification failed: {e}")
|
| 121 |
+
return None, 0.0
|
| 122 |
+
|
| 123 |
+
def should_override_yolo(self, yolo_label: str, yolo_confidence: float,
|
| 124 |
+
custom_label: str, custom_confidence: float) -> bool:
|
| 125 |
+
"""
|
| 126 |
+
Decide whether to override YOLO prediction with custom model
|
| 127 |
+
|
| 128 |
+
Args:
|
| 129 |
+
yolo_label: YOLO predicted label
|
| 130 |
+
yolo_confidence: YOLO confidence
|
| 131 |
+
custom_label: Custom model predicted label
|
| 132 |
+
custom_confidence: Custom model confidence
|
| 133 |
+
|
| 134 |
+
Returns:
|
| 135 |
+
True if should use custom model prediction
|
| 136 |
+
"""
|
| 137 |
+
# Don't override if custom model not confident enough
|
| 138 |
+
if custom_confidence < 0.7:
|
| 139 |
+
return False
|
| 140 |
+
|
| 141 |
+
# Always override if YOLO has low confidence and custom has high
|
| 142 |
+
if yolo_confidence < 0.5 and custom_confidence > 0.8:
|
| 143 |
+
return True
|
| 144 |
+
|
| 145 |
+
# Override if custom model is significantly more confident
|
| 146 |
+
if custom_confidence > yolo_confidence + 0.2:
|
| 147 |
+
return True
|
| 148 |
+
|
| 149 |
+
# Override if we have training data for this custom class
|
| 150 |
+
if custom_label in self.class_info.get('valid_classes', []):
|
| 151 |
+
return True
|
| 152 |
+
|
| 153 |
+
return False
|
| 154 |
+
|
    def detect_with_custom_model(self, image: np.ndarray, confidence_threshold: float = 0.5) -> Tuple[np.ndarray, List[str], List[Dict]]:
        """
        Two-stage detection: YOLO + Custom Classification

        Args:
            image: Input image
            confidence_threshold: YOLO confidence threshold

        Returns:
            Tuple of (annotated_image, detected_objects, detailed_attributes)
        """
        # Stage 1: YOLO Detection
        try:
            annotated_img, detected_objects, detailed_attributes = detect_objects_enhanced(
                image, confidence_threshold
            )
        except Exception:
            # Fallback to basic YOLO
            from .yolo import detect_objects
            annotated_img, detected_objects = detect_objects(image)
            detailed_attributes = []

        # Stage 2: Custom Classification (if model available)
        if self.custom_model is None or not detailed_attributes:
            return annotated_img, detected_objects, detailed_attributes

        # Process each detection with custom model
        enhanced_attributes = []
        enhanced_objects = []

        for i, attr in enumerate(detailed_attributes):
            yolo_label = attr['label']
            yolo_confidence = float(attr['confidence'].rstrip('%')) / 100.0
            bbox = attr.get('bbox', [0, 0, 100, 100])

            # Extract object region
            x1, y1, x2, y2 = [int(coord) for coord in bbox]
            object_crop = image[max(0, y1):min(image.shape[0], y2),
                                max(0, x1):min(image.shape[1], x2)]

            # Classify with custom model
            custom_label, custom_confidence = self.classify_object(object_crop)

            # Decide which prediction to use
            if custom_label and self.should_override_yolo(yolo_label, yolo_confidence,
                                                          custom_label, custom_confidence):
                # Use custom model prediction
                final_label = custom_label
                final_confidence = custom_confidence
                attr['prediction_source'] = 'custom_model'
                attr['original_yolo'] = {'label': yolo_label, 'confidence': yolo_confidence}
            else:
                # Use YOLO prediction
                final_label = yolo_label
                final_confidence = yolo_confidence
                attr['prediction_source'] = 'yolo'
                if custom_label:
                    attr['custom_alternative'] = {'label': custom_label, 'confidence': custom_confidence}

            # Update attributes
            attr['label'] = final_label
            attr['confidence'] = f"{final_confidence:.2%}"

            enhanced_attributes.append(attr)
            enhanced_objects.append(final_label)

        # Update annotated image if we made changes
        if any(attr.get('prediction_source') == 'custom_model' for attr in enhanced_attributes):
            # Re-annotate image with updated predictions
            annotated_img = self.annotate_image_with_predictions(image, enhanced_attributes)

        return annotated_img, enhanced_objects, enhanced_attributes

    def annotate_image_with_predictions(self, image: np.ndarray, attributes: List[Dict]) -> np.ndarray:
        """
        Annotate image with updated predictions

        Args:
            image: Original image
            attributes: Detection attributes with updated labels

        Returns:
            Annotated image
        """
        annotated = image.copy()

        for attr in attributes:
            bbox = attr.get('bbox', [0, 0, 100, 100])
            label = attr['label']
            confidence = attr['confidence']
            source = attr.get('prediction_source', 'yolo')

            x1, y1, x2, y2 = [int(coord) for coord in bbox]

            # Choose color based on source
            if source == 'custom_model':
                color = (0, 255, 0)  # Green for custom model
                label_text = f"{label} {confidence} (Custom)"
            else:
                color = (255, 0, 0)  # Red for YOLO
                label_text = f"{label} {confidence}"

            # Draw bounding box
            cv2.rectangle(annotated, (x1, y1), (x2, y2), color, 2)

            # Draw label
            cv2.putText(annotated, label_text, (x1, y1-10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)

        return annotated

    def get_model_info(self) -> Dict:
        """Get information about loaded models"""
        info = {
            'yolo_model': 'YOLOv8m',
            'custom_model_loaded': self.custom_model is not None,
            'device': self.device
        }

        if self.custom_model is not None and self.class_info is not None:
            info.update({
                'custom_classes': list(self.class_info['idx_to_label'].values()),
                'num_custom_classes': self.class_info['num_classes'],
                'training_samples': self.class_info.get('train_samples', 0),
                'validation_samples': self.class_info.get('val_samples', 0)
            })

        return info

# Global two-stage inference instance
two_stage_inference = TwoStageInference()
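(For context, a minimal usage sketch of the pipeline above; the image path is illustrative, and without a trained custom model the call simply returns the plain YOLO results.)

import cv2
from backend.two_stage_inference import two_stage_inference

image = cv2.imread("sample.jpg")                # hypothetical input path
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # the pipeline works on RGB arrays
annotated, labels, attrs = two_stage_inference.detect_with_custom_model(image, confidence_threshold=0.5)
for attr in attrs:
    # 'prediction_source' records whether YOLO or the custom model supplied the label
    print(attr['label'], attr['confidence'], attr.get('prediction_source'))
print(two_stage_inference.get_model_info())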
backend/yolo.py
CHANGED
@@ -1,34 +1,34 @@
(both sides of this hunk are identical in content, so the file is shown once; the rewrite appears to be whitespace/line-ending normalization)

from ultralytics import YOLO  # type: ignore
import cv2
import numpy as np

# Load a pre-trained YOLOv8 model (nano version = small & fast)
model = YOLO("yolov8n.pt")

def detect_objects(image):
    """
    Run YOLO on the input image.
    Returns:
        - annotated image with bounding boxes
        - list of detected object names
    """
    # Handle different image formats and channel counts
    if isinstance(image, np.ndarray):
        # If image has 4 channels (RGBA), convert to RGB
        if image.shape[-1] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
        # If image has 1 channel (grayscale), convert to RGB
        elif len(image.shape) == 2 or image.shape[-1] == 1:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

    results = model(image)
    annotated_img = results[0].plot()

    # Extract detected object names
    detected_objects = []
    for box in results[0].boxes:
        cls_id = int(box.cls[0].item())  # class ID
        label = results[0].names[cls_id]  # class name
        detected_objects.append(label)

    return annotated_img, detected_objects
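(A minimal call to this baseline detector, for reference; the file names are illustrative.)

import cv2
from backend.yolo import detect_objects

img = cv2.imread("street.jpg")            # hypothetical input path
annotated, names = detect_objects(img)
print(names)                              # e.g. ['person', 'car', 'dog']
cv2.imwrite("street_annotated.jpg", annotated)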
backend/yolo_enhanced.py
CHANGED
@@ -1,231 +1,231 @@
(both sides of this hunk are identical in content, so the file is shown once; the rewrite appears to be whitespace/line-ending normalization)

"""
Enhanced YOLO detection with improved accuracy, color detection, and detailed attributes
"""
from ultralytics import YOLO  # type: ignore
import cv2  # type: ignore
import numpy as np  # type: ignore
from collections import Counter
import webcolors  # type: ignore
# from sklearn.cluster import KMeans  # type: ignore  # Temporarily disabled due to numpy compatibility
import torch  # type: ignore

# Load a more accurate YOLO model
# For better accuracy, use yolov8m.pt or yolov8l.pt instead of yolov8n.pt
model_size = 'yolov8m.pt'  # Medium model for better accuracy vs speed balance
model = YOLO(model_size)

# Set higher confidence threshold for better accuracy
CONFIDENCE_THRESHOLD = 0.5  # Increase this for fewer but more accurate detections
NMS_THRESHOLD = 0.45  # Non-maximum suppression threshold

def get_dominant_colors(image, n_colors=3):
    """
    Extract dominant colors from an image region using simple averaging
    (K-means temporarily disabled due to numpy compatibility)
    """
    try:
        # Simple color detection without sklearn
        # Get average color
        avg_color = np.mean(image.reshape(-1, 3), axis=0).astype(int)

        # Get corners for variety
        h, w = image.shape[:2]
        corners = [
            image[0, 0],  # Top-left
            image[0, w-1] if w > 0 else image[0, 0],  # Top-right
            image[h-1, 0] if h > 0 else image[0, 0],  # Bottom-left
            image[h//2, w//2] if h > 0 and w > 0 else image[0, 0]  # Center
        ]

        color_names = []
        # Add average color
        try:
            color_names.append(get_color_name(avg_color))
        except Exception:
            color_names.append(f"RGB({avg_color[0]},{avg_color[1]},{avg_color[2]})")

        # Add dominant corner color if different
        for corner in corners[:n_colors-1]:
            try:
                name = get_color_name(corner)
                if name not in color_names:
                    color_names.append(name)
                if len(color_names) >= n_colors:
                    break
            except Exception:
                pass

        return color_names if color_names else ["Unknown"]
    except Exception:
        return ["Unknown"]

def get_color_name(rgb_color):
    """
    Convert RGB values to a human-readable color name
    """
    min_colors = {}
    for key, name in webcolors.CSS3_HEX_TO_NAMES.items():
        r_c, g_c, b_c = webcolors.hex_to_rgb(key)
        rd = (r_c - rgb_color[0]) ** 2
        gd = (g_c - rgb_color[1]) ** 2
        bd = (b_c - rgb_color[2]) ** 2
        min_colors[(rd + gd + bd)] = name
    return min_colors[min(min_colors.keys())]

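(The lookup above is an argmin of squared RGB distance over the CSS3 palette. Note it assumes webcolors still exposes CSS3_HEX_TO_NAMES, which some newer webcolors releases no longer provide. Expected behaviour when that mapping is available:)

get_color_name((255, 0, 0))      # 'red'  (exact palette hit, distance 0)
get_color_name((250, 5, 5))      # 'red'  (nearest CSS3 entry wins)
get_color_name((128, 128, 128))  # 'gray' (CSS3 gray is #808080)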
def analyze_object_attributes(image, box, label):
    """
    Analyze detailed attributes of detected objects
    """
    x1, y1, x2, y2 = box
    object_region = image[int(y1):int(y2), int(x1):int(x2)]

    attributes = {
        'label': label,
        'position': get_position_description(x1, y1, x2, y2, image.shape),
        'size': get_size_description(x2-x1, y2-y1, image.shape),
        'colors': get_dominant_colors(object_region, n_colors=2),
        'confidence': None,  # Will be set from detection
        'bbox': [float(x1), float(y1), float(x2), float(y2)]  # Add bounding box coordinates
    }

    return attributes

def get_position_description(x1, y1, x2, y2, image_shape):
    """
    Describe object position in human terms
    """
    h, w = image_shape[:2]
    center_x = (x1 + x2) / 2
    center_y = (y1 + y2) / 2

    # Horizontal position
    if center_x < w / 3:
        h_pos = "left"
    elif center_x > 2 * w / 3:
        h_pos = "right"
    else:
        h_pos = "center"

    # Vertical position
    if center_y < h / 3:
        v_pos = "top"
    elif center_y > 2 * h / 3:
        v_pos = "bottom"
    else:
        v_pos = "middle"

    if h_pos == "center" and v_pos == "middle":
        return "center"
    elif v_pos == "middle":
        return h_pos
    elif h_pos == "center":
        return v_pos
    else:
        return f"{v_pos}-{h_pos}"

def get_size_description(width, height, image_shape):
    """
    Describe object size relative to image
    """
    img_area = image_shape[0] * image_shape[1]
    obj_area = width * height
    ratio = obj_area / img_area

    if ratio > 0.5:
        return "very large"
    elif ratio > 0.25:
        return "large"
    elif ratio > 0.1:
        return "medium"
    elif ratio > 0.05:
        return "small"
    else:
        return "tiny"

def detect_objects_enhanced(image, confidence_threshold=CONFIDENCE_THRESHOLD):
    """
    Enhanced YOLO detection with improved accuracy and detailed attributes
    Returns:
        - annotated image with bounding boxes
        - list of detected object names
        - detailed attributes for each detection
    """
    # Handle different image formats
    if isinstance(image, np.ndarray):
        if image.shape[-1] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
        elif len(image.shape) == 2 or image.shape[-1] == 1:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

    # Run YOLO with custom parameters for better accuracy
    results = model(
        image,
        conf=confidence_threshold,  # Confidence threshold
        iou=NMS_THRESHOLD,  # NMS IoU threshold
        imgsz=640,  # Image size (can increase for better accuracy)
        device='cuda' if torch.cuda.is_available() else 'cpu'
    )

    # Get annotated image
    annotated_img = results[0].plot(
        conf=True,  # Show confidence scores
        line_width=2,
        font_size=10
    )

    # Extract detailed information
    detected_objects = []
    detailed_attributes = []

    for box in results[0].boxes:
        if box.conf[0] >= confidence_threshold:  # Double-check confidence
            cls_id = int(box.cls[0].item())
            label = results[0].names[cls_id]
            confidence = float(box.conf[0].item())

            # Get box coordinates
            xyxy = box.xyxy[0].tolist()

            # Analyze attributes
            attributes = analyze_object_attributes(image, xyxy, label)
            attributes['confidence'] = f"{confidence:.2%}"

            detected_objects.append(label)
            detailed_attributes.append(attributes)

    return annotated_img, detected_objects, detailed_attributes

def get_intelligence_report(detailed_attributes):
    """
    Generate an intelligent report about detected objects
    """
    if not detailed_attributes:
        return "No objects detected in the image."

    report = []
    report.append(f"Detected {len(detailed_attributes)} object(s):")

    for attr in detailed_attributes:
        colors_str = " and ".join(attr['colors'][:2]) if attr['colors'] else "unknown colors"
        report.append(
            f"- A {attr['size']} {colors_str} {attr['label']} "
            f"in the {attr['position']} of the image "
            f"(confidence: {attr['confidence']})"
        )

    # Add summary statistics
    object_types = Counter([attr['label'] for attr in detailed_attributes])
    if len(object_types) > 1:
        report.append("\nSummary:")
        for obj_type, count in object_types.most_common():
            report.append(f"  • {count} {obj_type}(s)")

    return "\n".join(report)

# Backward compatibility wrapper
def detect_objects(image):
    """
    Wrapper for backward compatibility with original function
    """
    annotated_img, detected_objects, _ = detect_objects_enhanced(image)
    return annotated_img, detected_objects
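(A short sketch tying the enhanced detector to the report generator; the input path and the sample output line are illustrative.)

import cv2
from backend.yolo_enhanced import detect_objects_enhanced, get_intelligence_report

img = cv2.imread("desk.jpg")  # hypothetical input path
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
annotated, names, attrs = detect_objects_enhanced(img, confidence_threshold=0.6)
print(get_intelligence_report(attrs))
# e.g. "- A medium black laptop in the center of the image (confidence: 87.00%)"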
packages.txt
CHANGED
@@ -1,5 +1,5 @@
(both sides identical in content; whitespace/line-ending normalization only, shown once)

ffmpeg
libsm6
libxext6
libxrender-dev
libglib2.0-0
requirements.txt
CHANGED
@@ -1,16 +1,13 @@
 streamlit>=1.28.0
 ultralytics>=8.0.0
 openai>=1.0.0
 opencv-python-headless>=4.8.0
 pillow>=10.0.0
 numpy>=1.24.0,<2.0.0
-torch>=2.0.0
-torchvision>=0.15.0
+torch>=2.0.0,<2.5.0
+torchvision>=0.15.0,<0.20.0
 python-dotenv>=1.0.0
 plotly>=5.17.0
-
-
-
+requests>=2.31.0
+pandas>=1.5.0,<2.1.0
 webcolors>=1.13
-face-recognition>=1.3.0
-dlib>=19.24.0
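(The updated manifest pins torch below 2.5 and torchvision below 0.20, adds requests and pandas, and drops face-recognition and dlib, whose native CMake/C++ build is the usual failure point on Spaces hardware. An entry point that still wants optional face support would then guard the import; a hypothetical sketch, with names not taken from app_lite.py:)

# Hypothetical guard for an optional face-recognition backend.
try:
    import face_recognition  # pulls in dlib, which needs a C++ toolchain to build
    FACE_RECOGNITION_AVAILABLE = True
except ImportError:
    FACE_RECOGNITION_AVAILABLE = False

def locate_faces(rgb_image):
    """Return face bounding boxes, or an empty list in the lite build."""
    if not FACE_RECOGNITION_AVAILABLE:
        return []  # degrade gracefully instead of crashing at import time
    return face_recognition.face_locations(rgb_image)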
requirements_lite.txt
ADDED
@@ -0,0 +1,14 @@
streamlit>=1.28.0
ultralytics>=8.0.0
openai>=1.0.0
opencv-python-headless>=4.8.0
pillow>=10.0.0
numpy>=1.24.0,<2.0.0
torch>=2.0.0
torchvision>=0.15.0
python-dotenv>=1.0.0
plotly>=5.17.0
kaleido>=0.2.1
requests>=2.31.0
pandas>=1.5.0,<2.1.0
webcolors>=1.13