Navada25 committed
Commit 75f48fa · verified · 1 parent: f8a3c0e

Deploy NAVADA 2.0 Lite - Optimized for HF Spaces (no face recognition)

.gitattributes CHANGED
@@ -1 +1 @@
- *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
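The `.gitattributes` rule keeps `*.pt` YOLO weights in Git LFS, so a plain clone may contain small pointer files rather than the binary weights. A quick way to check whether a weight file was actually fetched — a minimal sketch, assuming a local clone; the `yolov8n.pt` path is illustrative and not part of this commit:

```python
# lfs_check.py - sanity-check that a .pt file is real weights, not an LFS pointer.
from pathlib import Path

def is_lfs_pointer(path: str) -> bool:
    """Git LFS pointers are tiny text files starting with the LFS spec line."""
    head = Path(path).read_bytes()[:42]
    return head.startswith(b"version https://git-lfs.github.com/spec/")

if __name__ == "__main__":
    weight = "yolov8n.pt"  # illustrative path, not shipped by this commit
    if Path(weight).exists() and is_lfs_pointer(weight):
        print(f"{weight} is an LFS pointer - run `git lfs pull` to fetch the weights.")
```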
.gitignore CHANGED
@@ -1,45 +1,45 @@
- # Python
- __pycache__/
- *.py[cod]
- *$py.class
- *.so
- .Python
- env/
- venv/
- .venv/
- ENV/
-
- # Environment variables
- .env
- .env.local
-
- # IDE
- .vscode/
- .idea/
- *.swp
- *.swo
-
- # OS
- .DS_Store
- Thumbs.db
-
- # Streamlit
- .streamlit/secrets.toml
-
- # Database
- *.db
- *.sqlite
- *.sqlite3
-
- # Models (if too large)
- yolov8m.pt
- yolov8l.pt
- yolov8x.pt
-
- # Logs
- *.log
-
- # Temporary files
- *.tmp
- temp/
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ env/
+ venv/
+ .venv/
+ ENV/
+
+ # Environment variables
+ .env
+ .env.local
+
+ # IDE
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+
+ # OS
+ .DS_Store
+ Thumbs.db
+
+ # Streamlit
+ .streamlit/secrets.toml
+
+ # Database
+ *.db
+ *.sqlite
+ *.sqlite3
+
+ # Models (if too large)
+ yolov8m.pt
+ yolov8l.pt
+ yolov8x.pt
+
+ # Logs
+ *.log
+
+ # Temporary files
+ *.tmp
+ temp/
  tmp/
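The ignore rules exclude the medium/large/xlarge YOLOv8 weights so the repo stays small; only the nano model is needed at runtime. If the weight file is missing, Ultralytics downloads it on first use, so a startup guard can stay trivial — a sketch under that assumption:

```python
# model_bootstrap.py - illustrative startup guard; assumes the ultralytics package.
from ultralytics import YOLO

def load_detector(weights: str = "yolov8n.pt") -> YOLO:
    """Load YOLOv8n; ultralytics fetches the weights itself if absent."""
    return YOLO(weights)
```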
README.md CHANGED
@@ -1,44 +1,44 @@
- ---
- title: NAVADA 2.0 - Advanced AI Computer Vision
- emoji: 🚀
- colorFrom: blue
- colorTo: purple
- sdk: streamlit
- sdk_version: 1.28.0
- app_file: app.py
- pinned: false
- license: mit
- ---
-
- # 🚀 NAVADA 2.0 - Advanced AI Computer Vision Application
-
- An advanced AI-powered computer vision application featuring:
- - 🎯 Real-time object detection using YOLOv8
- - 👤 Face detection and recognition
- - 🤖 AI-powered explanations
- - 📊 Interactive analytics
- - 🎤 Voice narration
- - 💬 Intelligent chat agent
-
- ## Features
- - **Object Detection**: State-of-the-art YOLOv8 model for accurate object detection
- - **Face Recognition**: Advanced face detection with emotion and feature analysis
- - **AI Intelligence**: OpenAI-powered explanations and insights
- - **Interactive Charts**: Real-time visualization of detection results
- - **Voice Output**: Text-to-speech narration of detection results
- - **Chat Interface**: Intelligent assistant for image analysis
-
- ## Usage
- 1. Upload an image using the file uploader
- 2. The system will automatically detect objects and faces
- 3. View detailed analytics and AI-generated explanations
- 4. Interact with the chat agent for deeper insights
-
- ## Technology Stack
- - Streamlit for the web interface
- - YOLOv8 for object detection
- - OpenAI API for intelligent analysis
- - Face Recognition library
- - Plotly for interactive visualizations
-
+ ---
+ title: NAVADA 2.0 - Advanced AI Computer Vision
+ emoji: 🚀
+ colorFrom: blue
+ colorTo: purple
+ sdk: streamlit
+ sdk_version: 1.28.0
+ app_file: app.py
+ pinned: false
+ license: mit
+ ---
+
+ # 🚀 NAVADA 2.0 - Advanced AI Computer Vision Application
+
+ An advanced AI-powered computer vision application featuring:
+ - 🎯 Real-time object detection using YOLOv8
+ - 👤 Face detection and recognition
+ - 🤖 AI-powered explanations
+ - 📊 Interactive analytics
+ - 🎤 Voice narration
+ - 💬 Intelligent chat agent
+
+ ## Features
+ - **Object Detection**: State-of-the-art YOLOv8 model for accurate object detection
+ - **Face Recognition**: Advanced face detection with emotion and feature analysis
+ - **AI Intelligence**: OpenAI-powered explanations and insights
+ - **Interactive Charts**: Real-time visualization of detection results
+ - **Voice Output**: Text-to-speech narration of detection results
+ - **Chat Interface**: Intelligent assistant for image analysis
+
+ ## Usage
+ 1. Upload an image using the file uploader
+ 2. The system will automatically detect objects and faces
+ 3. View detailed analytics and AI-generated explanations
+ 4. Interact with the chat agent for deeper insights
+
+ ## Technology Stack
+ - Streamlit for the web interface
+ - YOLOv8 for object detection
+ - OpenAI API for intelligent analysis
+ - Face Recognition library
+ - Plotly for interactive visualizations
+
  Created by Lee Akpareva | AI Consultant & Computer Vision Specialist
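The README's usage steps map onto two backend calls visible in the `app.py` diff below: `backend.yolo.detect_objects` and `backend.openai_client.explain_detection`. A headless smoke test of that flow might look like this — a sketch assuming the return shape the lite `app.py` consumes (a dict with a `detections` list of `{'class', 'confidence'}` entries) and a sample image on disk:

```python
# smoke_test.py - exercise the upload -> detect -> explain flow without Streamlit.
from PIL import Image

from backend.yolo import detect_objects
from backend.openai_client import explain_detection

image = Image.open("sample.jpg")  # stand-in for the uploaded file
results = detect_objects(image, confidence_threshold=0.5)

detected = [det["class"] for det in results["detections"]]
print(f"Detected {len(detected)} objects: {detected}")
print(explain_detection(detected))  # requires OPENAI_API_KEY to be set
```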
app.py CHANGED
@@ -1,1254 +1,271 @@
- """
- 🚀 NAVADA 2.0 - Advanced AI Computer Vision Application
- Streamlit Version for Hugging Face Spaces Deployment
-
- Enhanced Edition by Lee Akpareva | AI Consultant & Computer Vision Specialist
- """
-
- import streamlit as st  # type: ignore
-
- # Configure Streamlit page (MUST be first!)
- st.set_page_config(
-     page_title="🚀 NAVADA 2.0 - AI Computer Vision",
-     page_icon="🚀",
-     layout="wide",
-     initial_sidebar_state="expanded"
- )
-
- # Add Font Awesome CSS
- st.markdown("""
- <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css">
- <style>
- .fa-icon { margin-right: 8px; }
- .fa-primary { color: #3498db; }
- .fa-success { color: #27ae60; }
- .fa-warning { color: #f39c12; }
- .fa-error { color: #e74c3c; }
- .fa-spin { animation: fa-spin 2s infinite linear; }
- </style>
- """, unsafe_allow_html=True)
-
- # Font Awesome icon mapping function
- def fa_icon(icon_class, color="primary", text=""):
-     """Generate Font Awesome icon HTML"""
-     return f'<i class="fas fa-{icon_class} fa-{color} fa-icon"></i>{text}'
- import time
- from datetime import datetime
- import plotly.graph_objects as go  # type: ignore
- import plotly.express as px  # type: ignore
- from PIL import Image  # type: ignore
- import numpy as np  # type: ignore
-
- # Backend imports
- try:
-     from backend.yolo_enhanced import detect_objects_enhanced, get_intelligence_report
-     from backend.yolo import detect_objects  # Keep original for fallback
-     from backend.openai_client import explain_detection, generate_voice
-     from backend.face_detection import face_detector
-     from backend.recognition import recognition_system
-     from backend.database import db
-     from backend.chat_agent import chat_with_agent, reset_chat, get_chat_history
-     from backend.two_stage_inference import two_stage_inference
- except ImportError as e:
-     st.error(f"⚠️ Import error: {e}")
-     st.error("📦 Please install dependencies: pip install -r requirements.txt")
-     st.stop()
-
- # Page configuration moved to top of file
-
- # Custom CSS for enhanced styling
- st.markdown("""
- <style>
- .main-header {
-     background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
-     padding: 2rem;
-     border-radius: 10px;
-     color: white;
-     text-align: center;
-     margin-bottom: 2rem;
- }
-
- .feature-card {
-     background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
-     padding: 1.5rem;
-     border-radius: 10px;
-     color: white;
-     margin: 1rem 0;
- }
-
- .stats-card {
-     background: linear-gradient(135deg, #4facfe 0%, #00f2fe 100%);
-     padding: 1rem;
-     border-radius: 8px;
-     color: white;
-     text-align: center;
-     margin: 0.5rem;
- }
-
- .launch-button {
-     background: linear-gradient(135deg, #000000 0%, #434343 100%);
-     color: white;
-     padding: 1rem 2rem;
-     border: none;
-     border-radius: 8px;
-     font-size: 1.2rem;
-     font-weight: bold;
-     cursor: pointer;
-     width: 100%;
-     margin: 1rem 0;
- }
-
- .stButton > button {
-     background: linear-gradient(135deg, #000000 0%, #434343 100%);
-     color: white;
-     border: none;
-     border-radius: 8px;
-     font-weight: bold;
- }
-
- .compass {
-     position: fixed;
-     top: 10px;
-     right: 10px;
-     background: rgba(0,0,0,0.7);
-     color: white;
-     padding: 10px;
-     border-radius: 50%;
-     font-size: 16px;
-     z-index: 1000;
- }
- </style>
- """, unsafe_allow_html=True)
-
- # Compass (News indicator)
- st.markdown("""
- <div class="compass">
-     📰 NEWS
- </div>
- """, unsafe_allow_html=True)
-
- # Main header
- st.markdown(f"""
- <div class="main-header">
-     <h1><i class="fas fa-rocket fa-primary fa-icon"></i>NAVADA 2.0 - Advanced AI Computer Vision</h1>
-     <h3><i class="fas fa-brain fa-primary fa-icon"></i>Real-time Computer Vision with Custom Recognition Database & RAG Technology</h3>
-     <p><strong>Enhanced Edition by Lee Akpareva</strong> | AI Consultant & Computer Vision Specialist</p>
-     <p><i class="fas fa-crosshairs fa-primary fa-icon"></i>AI Computer Vision Application Designed for Hugging Face - Build ML Models in 15 Minutes</p>
- </div>
- """, unsafe_allow_html=True)
-
- # Initialize session state
- if 'processing_complete' not in st.session_state:
-     st.session_state.processing_complete = False
- if 'last_results' not in st.session_state:
-     st.session_state.last_results = None
- if 'chat_messages' not in st.session_state:
-     st.session_state.chat_messages = []
- if 'use_enhanced' not in st.session_state:
-     st.session_state.use_enhanced = True
-
- def create_detection_chart(detected_objects, face_stats=None, face_matches=None):
-     """Create an interactive chart showing detection statistics"""
-
-     # Count object types
-     object_counts = {}
-     for obj in detected_objects:
-         object_counts[obj] = object_counts.get(obj, 0) + 1
-
-     # Add face detection to counts
-     if face_stats and face_stats.get('total_faces', 0) > 0:
-         object_counts['Faces'] = face_stats['total_faces']
-         if face_stats.get('features_detected', {}).get('smiles', 0) > 0:
-             object_counts['Smiles'] = face_stats['features_detected']['smiles']
-
-     # Add recognized faces
-     if face_matches:
-         known_faces = sum(1 for match in face_matches if match['name'] != 'Unknown')
-         if known_faces > 0:
-             object_counts['Known Faces'] = known_faces
-
-     if not object_counts:
-         fig = go.Figure()
-         fig.add_annotation(
-             text="No objects detected",
-             xref="paper", yref="paper",
-             x=0.5, y=0.5, showarrow=False,
-             font=dict(size=20, color="gray")
-         )
-         fig.update_layout(
-             height=300,
-             title="Detection Results",
-             template="plotly_dark"
-         )
-         return fig
-
-     # Create bar chart
-     fig = go.Figure(data=[
-         go.Bar(
-             x=list(object_counts.keys()),
-             y=list(object_counts.values()),
-             marker_color=['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FECA57', '#FF9FF3', '#54A0FF'],
-             text=list(object_counts.values()),
-             textposition='auto',
-         )
-     ])
-
-     fig.update_layout(
-         title="🎯 Detection Statistics",
-         xaxis_title="Detected Items",
-         yaxis_title="Count",
-         height=400,
-         template="plotly_dark"
-     )
-
-     return fig
-
- def create_confidence_pie_chart(detected_objects, face_matches=None):
-     """Create a confidence distribution pie chart"""
-     try:
-         # This is a simplified version - in the full app you'd get actual confidence scores
-         categories = list(set(detected_objects)) if detected_objects else []
-         if face_matches:
-             categories.extend([match['name'] for match in face_matches if match['name'] != 'Unknown'])
-
-         if not categories:
-             return None
-
-         # Generate sample confidence data
-         values = [len([obj for obj in detected_objects if obj == cat]) for cat in set(detected_objects)]
-
-         fig = go.Figure(data=[go.Pie(
-             labels=list(set(detected_objects)),
-             values=values,
-             hole=.3,
-             marker_colors=['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FECA57']
-         )])
-
-         fig.update_layout(
-             title="📊 Detection Distribution",
-             height=400,
-             template="plotly_dark"
-         )
-
-         return fig
-     except:
-         return None
-
- def process_image(image, enable_voice=False, enable_face_detection=False, enable_recognition=False, use_enhanced=True, confidence_threshold=0.5):
-     """Process uploaded image with all NAVADA 2.0 features"""
-     try:
-         if image is None:
-             return None, "No image provided", None, None, None, None
-
-         start_time = time.time()
-
-         # Convert PIL to numpy array
-         image_array = np.array(image)
-
-         # Object detection - use two-stage inference, enhanced, or standard
-         detailed_attributes = None
-         if use_enhanced:
-             try:
-                 # Try two-stage inference first (YOLO + Custom Model)
-                 detected_img, detected_objects, detailed_attributes = two_stage_inference.detect_with_custom_model(
-                     image_array, confidence_threshold
-                 )
-             except:
-                 try:
-                     # Fallback to enhanced YOLO only
-                     detected_img, detected_objects, detailed_attributes = detect_objects_enhanced(image_array, confidence_threshold)
-                 except:
-                     # Final fallback to standard detection
-                     detected_img, detected_objects = detect_objects(image_array)
-         else:
-             detected_img, detected_objects = detect_objects(image_array)
-
-         # Face detection if enabled
-         face_stats = None
-         face_matches = None
-         if enable_face_detection and face_detector:
-             detected_img, face_stats = face_detector.detect_faces(detected_img)
-
-         # Face recognition if enabled
-         if enable_recognition and recognition_system:
-             detected_img, face_matches = recognition_system.recognize_faces(detected_img)
-
-         # AI explanation - enhanced version includes detailed attributes
-         if detailed_attributes:
-             ai_explanation = get_intelligence_report(detailed_attributes)
-         else:
-             ai_explanation = explain_detection(detected_objects)
-
-         # RAG enhancement if recognition enabled
-         if enable_recognition and recognition_system:
-             rag_enhancement = recognition_system.enhance_with_rag(detected_objects, face_matches)
-             ai_explanation = f"{ai_explanation}\n\n{rag_enhancement}"
-
-         # Voice generation if enabled
-         audio_file = None
-         if enable_voice:
-             try:
-                 st.info("🔊 Generating voice narration...")
-                 audio_file = generate_voice(ai_explanation)
-                 if audio_file:
-                     st.success("✅ Voice narration generated successfully!")
-                 else:
-                     st.error("❌ Voice generation failed - no audio file created")
-             except Exception as e:
-                 st.error(f"❌ Voice generation failed: {e}")
-                 import traceback
-                 st.error(f"Details: {traceback.format_exc()}")
-
-         # Save session data
-         processing_time = time.time() - start_time
-         if recognition_system:
-             recognition_system.save_session_data(
-                 image_array, detected_objects, face_matches, processing_time
-             )
-
-         return detected_img, ai_explanation, detected_objects, face_stats, face_matches, audio_file, detailed_attributes
-
-     except Exception as e:
-         st.error(f"Processing failed: {e}")
-         return None, f"Error: {e}", [], None, None, None, None
-
- def get_database_stats():
-     """Get current database statistics"""
-     try:
-         if db:
-             stats = db.get_stats()
-             return {
-                 "faces": stats.get("faces", 0),
-                 "objects": stats.get("objects", 0),
-                 "sessions": stats.get("recent_detections", 0),
-                 "total_detections": stats.get("total_detections", 0)
-             }
-         return {"faces": 0, "objects": 0, "sessions": 0, "total_detections": 0}
-     except Exception as e:
-         st.warning(f"Database stats unavailable: {e}")
-         return {"faces": 0, "objects": 0, "sessions": 0, "total_detections": 0}
-
- # Sidebar for database features and stats
- with st.sidebar:
-     st.markdown("""
-     <div class="feature-card">
-         <h3>🗄️ NAVADA Database</h3>
-         <p>Custom Recognition & RAG System</p>
-     </div>
-     """, unsafe_allow_html=True)
-
-     # Database statistics
-     stats = get_database_stats()
-
-     # Prisma Studio Integration
-     st.markdown("#### 🔧 Database Management")
-     col_studio1, col_studio2 = st.columns(2)
-
-     with col_studio1:
-         if st.button("🎯 Open Prisma Studio", help="View and edit database in Prisma Studio"):
-             try:
-                 import subprocess
-                 subprocess.Popen(["npm", "run", "studio"], cwd=".", shell=True)
-                 st.success("🚀 Prisma Studio starting on http://localhost:5556")
-             except Exception as e:
-                 st.error(f"Failed to start Prisma Studio: {e}")
-                 st.info("💡 Run manually: npm run studio")
-
-     with col_studio2:
-         if st.button("📊 Database Info", help="Show database connection details"):
-             st.info("📍 Database: navada_recognition.db\n🌐 Prisma Studio: http://localhost:5556")
-
-     col1, col2 = st.columns(2)
-     with col1:
-         st.markdown(f"""
-         <div class="stats-card">
-             <h4>{stats.get('faces', 0)}</h4>
-             <p>👥 Faces</p>
-         </div>
-         """, unsafe_allow_html=True)
-
-         st.markdown(f"""
-         <div class="stats-card">
-             <h4>{stats.get('sessions', 0)}</h4>
-             <p>📊 Sessions</p>
-         </div>
-         """, unsafe_allow_html=True)
-
-     with col2:
-         st.markdown(f"""
-         <div class="stats-card">
-             <h4>{stats.get('objects', 0)}</h4>
-             <p>🏷️ Objects</p>
-         </div>
-         """, unsafe_allow_html=True)
-
-         st.markdown(f"""
-         <div class="stats-card">
-             <h4>{stats.get('total_detections', 0)}</h4>
-             <p>🎯 Detections</p>
-         </div>
-         """, unsafe_allow_html=True)
-
-     st.markdown("---")
-
-     # Computer Vision Educational Section
-     with st.expander("🔬 Computer Vision Guide", expanded=False):
-         st.markdown('<i class="fas fa-microscope fa-primary fa-icon"></i>**Advanced CV Learning Hub**', unsafe_allow_html=True)
-         st.markdown('<h3><i class="fas fa-brain fa-primary fa-icon"></i>What is Computer Vision?</h3>', unsafe_allow_html=True)
-         st.markdown("""
-         **Computer Vision (CV)** is a field of artificial intelligence that enables machines to interpret and understand visual information from the world, mimicking human vision capabilities.
-
-         **Key Components:**
-         - **Image Processing**: Enhancing and filtering visual data
-         - **Pattern Recognition**: Identifying objects, faces, and features
-         - **Machine Learning**: Training models on visual datasets
-         - **Deep Learning**: Neural networks for complex visual understanding
-         """)
-
-         st.markdown('<h3><i class="fas fa-crosshairs fa-primary fa-icon"></i>Top 5 Real-World Use Cases</h3>', unsafe_allow_html=True)
-
-         use_cases = [
-             {
-                 "icon": "hospital",
-                 "title": "Healthcare & Medical Imaging",
-                 "description": "Detecting diseases in X-rays, MRIs, and CT scans. Early cancer detection, automated diagnosis, and surgical assistance.",
-                 "impact": "95% accuracy in mammography screening"
-             },
-             {
-                 "icon": "car",
-                 "title": "Autonomous Vehicles",
-                 "description": "Real-time object detection, lane recognition, traffic sign identification, and pedestrian safety systems.",
-                 "impact": "$7 trillion global market potential"
-             },
-             {
-                 "icon": "industry",
-                 "title": "Manufacturing & Quality Control",
-                 "description": "Automated defect detection, product inspection, assembly line monitoring, and predictive maintenance.",
-                 "impact": "40% reduction in production errors"
-             },
-             {
-                 "icon": "shield-alt",
-                 "title": "Security & Surveillance",
-                 "description": "Facial recognition, anomaly detection, crowd monitoring, and threat identification in real-time.",
-                 "impact": "$62B global security market"
-             },
-             {
-                 "icon": "shopping-cart",
-                 "title": "Retail & E-commerce",
-                 "description": "Visual search, inventory management, customer behavior analysis, and augmented reality shopping.",
-                 "impact": "30% increase in conversion rates"
-             }
-         ]
-
-         for case in use_cases:
-             st.markdown(f"""
-             **<i class="fas fa-{case['icon']} fa-primary fa-icon"></i>{case['title']}**
-             {case['description']}
-             *<i class="fas fa-chart-bar fa-primary fa-icon"></i>Impact: {case['impact']}*
-             """, unsafe_allow_html=True)
-             st.markdown("---")
-
-         st.markdown('<h3><i class="fas fa-rocket fa-primary fa-icon"></i>Future Economic Impact</h3>', unsafe_allow_html=True)
-         st.markdown("""
-         **Job Market Transformation:**
-
-         **🔮 2025-2030 Predictions:**
-         - **+2.3M new CV jobs** globally by 2030
-         - **$733B market value** by 2030 (15.3% CAGR)
-         - **50% of industries** will integrate CV solutions
-
-         **💼 Emerging Job Roles:**
-         - CV Engineers & Architects
-         - AI Ethics Specialists
-         - Computer Vision Product Managers
-         - Visual AI Trainers
-         - Augmented Reality Developers
-
-         **🌍 Economic Benefits:**
-         - **Productivity**: 25-40% efficiency gains
-         - **Cost Reduction**: $390B in operational savings
-         - **Innovation**: New business models & services
-         - **Accessibility**: Enhanced tools for disabilities
-
-         **⚡ Industry Revolution:**
-         - **Healthcare**: Personalized medicine & diagnostics
-         - **Agriculture**: Precision farming & crop monitoring
-         - **Education**: Interactive learning & assessment
-         - **Entertainment**: Immersive AR/VR experiences
-         """)
-
-         st.markdown("### 🎓 Learning Path")
-         st.markdown("""
-         **Start Your CV Journey:**
-         1. **📚 Learn Fundamentals**: Python, OpenCV, Image Processing
-         2. **🧠 Master ML/DL**: TensorFlow, PyTorch, Neural Networks
-         3. **🔧 Hands-on Projects**: Like this NAVADA 2.0 demo!
-         4. **📊 Specialize**: Choose healthcare, automotive, etc.
-         5. **🚀 Build Portfolio**: Create real-world applications
-         """)
-
-         st.markdown('<h3><i class="fas fa-microchip fa-primary fa-icon"></i>Raspberry Pi Introduction</h3>', unsafe_allow_html=True)
-         st.markdown("""
-         **What is Raspberry Pi?**
-
-         The Raspberry Pi is a series of small, affordable single-board computers perfect for AI and computer vision projects.
-
-         **🎯 Real-World CV Use Cases:**
-         - **🏠 Smart Security**: Door surveillance with face recognition
-         - **🌿 Wildlife Monitoring**: Automated animal detection in reserves
-         - **🏭 Industrial Inspection**: Quality control in manufacturing
-         - **🚜 Agricultural Monitoring**: Plant health & pest detection
-         - **🚦 Traffic Analysis**: Vehicle counting & license recognition
-
-         **⚙️ NAVADA 2.0 on Pi Setup:**
-         ```bash
-         # Optimized for Pi 4 (4GB+ recommended)
-         pip install ultralytics[cpu]
-         streamlit run app.py --server.port 8080
-         ```
-
-         **🚀 Performance Tips:**
-         - Use YOLOv8n (nano) for faster Pi inference
-         - Enable VideoCore GPU acceleration
-         - External USB3 storage for database ops
-         - Lightweight OpenCV builds
-         """)
-
-         st.markdown('<h3><i class="fas fa-robot fa-primary fa-icon"></i>Robotics Introduction</h3>', unsafe_allow_html=True)
-         st.markdown("""
-         **Computer Vision in Robotics**
-
-         CV is the "eyes" of modern robots, enabling intelligent perception and interaction with environments.
-
-         **🔧 Integration Applications:**
-         - **🗺️ Autonomous Navigation**: Path planning & obstacle avoidance
-         - **🔧 Object Manipulation**: Precise pick-and-place operations
-         - **👥 Human-Robot Interaction**: Gesture & facial recognition
-         - **✅ Quality Assurance**: Robotic inspection systems
-         - **🏥 Medical Robotics**: Surgical assistance & monitoring
-
-         **🏢 Real-World Success Stories:**
-         - **📦 Amazon Warehouses**: Kiva robots with vision navigation
-         - **🚗 Tesla Autopilot**: Advanced CV for autonomous driving
-         - **🐕 Boston Dynamics**: Vision-guided locomotion
-         - **⚕️ Surgical Robots**: da Vinci precision guidance
-         - **🌾 Agricultural Robots**: Automated crop monitoring
-
-         **🛠️ Popular Frameworks:**
-         - **ROS**: Robot Operating System (industry standard)
-         - **OpenCV**: Essential computer vision processing
-         - **PyBullet**: Physics simulation for testing
-         - **MoveIt**: Motion planning for robotic arms
-
-         **🚀 Getting Started with Robotics + NAVADA:**
-         1. **Hardware**: Camera + actuators + microcontroller
-         2. **Software**: NAVADA 2.0 + ROS + hardware drivers
-         3. **Training**: Collect task-specific object data
-         4. **Integration**: Connect detection → robot control
-         """)
-
-         st.info("💡 **Pro Tip**: NAVADA 2.0 demonstrates key CV concepts - object detection, face recognition, and custom training!")
-
-     st.markdown("---")
-
-     # Face database addition
-     st.markdown("### 👤 Add Face to Database")
-     face_name = st.text_input("Enter person's name:", key="face_name")
-     if st.button("👤 Add Face", key="add_face"):
-         if st.session_state.get('current_image') is not None and face_name:
-             if recognition_system:
-                 success = recognition_system.add_new_face(
-                     np.array(st.session_state.current_image), face_name
-                 )
-                 if success:
-                     st.success(f"✅ Added {face_name} to face database!")
-                     st.rerun()
-                 else:
-                     st.error("❌ Failed to add face. Please ensure a clear face is visible.")
-             else:
-                 st.error("Recognition system not available")
-         else:
-             st.warning("Please upload an image and enter a name first.")
-
-     st.markdown("---")
-
-     # Live Session Statistics
-     st.markdown("### 📈 Live Session Stats")
-
-     # Session metrics in a compact format
-     session_col1, session_col2 = st.columns(2)
-     with session_col1:
-         st.metric("🖼️ This Session",
-                   st.session_state.get('images_processed', 0),
-                   delta=None,
-                   delta_color="normal")
-
-         total_objects_detected = 0
-         if 'last_results' in st.session_state and st.session_state.last_results:
-             detected_objects = st.session_state.last_results[2]
-             total_objects_detected = len(detected_objects) if detected_objects else 0
-
-         st.metric("🎯 Objects Found",
-                   total_objects_detected,
-                   delta=None)
-
-     with session_col2:
-         processing_time = 0
-         if 'start_time' in st.session_state:
-             processing_time = time.time() - st.session_state.start_time
-
-         st.metric("⚡ Last Process",
-                   f"{processing_time:.1f}s" if processing_time > 0 else "0.0s",
-                   delta=None)
-
-         accuracy_score = 0
-         if total_objects_detected > 0:
-             accuracy_score = min(95, 85 + total_objects_detected * 2)
-
-         st.metric("📊 Accuracy",
-                   f"{accuracy_score}%" if accuracy_score > 0 else "0%",
-                   delta=None)
-
-     # Session progress bar
-     session_target = 10  # Target images for session
-     current_progress = min(st.session_state.get('images_processed', 0) / session_target, 1.0)
-     st.progress(current_progress, text=f"Session Progress: {st.session_state.get('images_processed', 0)}/{session_target}")
-
-     st.markdown("---")
-
-     # Custom object addition
-     st.markdown("### 🏷️ Add Custom Object")
-     object_label = st.text_input("Object label:", key="object_label")
-     object_category = st.text_input("Category (optional):", key="object_category")
-     if st.button("🏷️ Add Object", key="add_object"):
-         if st.session_state.get('current_image') is not None and object_label:
-             if recognition_system:
-                 success = recognition_system.add_custom_object(
-                     np.array(st.session_state.current_image),
-                     object_label,
-                     object_category or "general"
-                 )
-                 if success:
-                     st.success(f"✅ Added '{object_label}' to object database!")
-                     st.rerun()
-                 else:
-                     st.error("❌ Failed to add object.")
-             else:
-                 st.error("Recognition system not available")
-         else:
-             st.warning("Please upload an image and enter a label first.")
-
- # Main content area
- col1, col2 = st.columns([2, 1])
-
- with col1:
-     # Image input tabs
-     tab1, tab2 = st.tabs(["📁 Upload Image", "📸 Camera Capture"])
-
-     with tab1:
-         uploaded_file = st.file_uploader(
-             "Choose an image file",
-             type=['png', 'jpg', 'jpeg'],
-             help="Upload an image for AI analysis"
-         )
-
-         if uploaded_file is not None:
-             image = Image.open(uploaded_file)
-             st.session_state.current_image = image
-             st.image(image, caption="Uploaded Image", use_container_width=True)
-
-     with tab2:
-         camera_image = st.camera_input("📸 Take a picture")
-
-         if camera_image is not None:
-             image = Image.open(camera_image)
-             st.session_state.current_image = image
-             st.image(image, caption="Captured Image", use_container_width=True)
-
- with col2:
-     # Processing options
-     st.markdown("### ⚙️ Processing Options")
-
-     # Model information
-     model_info = two_stage_inference.get_model_info()
-     if model_info.get('custom_model_loaded'):
-         st.success(f"🧠 Custom Model Active: {model_info.get('num_custom_classes', 0)} trained classes")
-         with st.expander("📋 Model Details"):
-             st.text(f"Custom Classes: {', '.join(model_info.get('custom_classes', []))}")
-             st.text(f"Training Samples: {model_info.get('training_samples', 0)}")
-             st.text(f"Device: {model_info.get('device', 'unknown')}")
-     else:
-         st.info("🤖 Using YOLO only - Train custom model by providing corrections!")
-
-     # Enhanced accuracy controls
-     st.markdown("#### 🎯 Accuracy Settings")
-     use_enhanced = st.checkbox("**Use Enhanced Detection** (Better Accuracy)", value=True, help="Uses advanced model with color detection and custom training")
-     confidence_threshold = st.slider("Detection Confidence", 0.1, 0.9, 0.5, 0.05, help="Higher = fewer but more accurate detections")
-
-     # Make voice option more prominent
-     st.markdown("#### 🔊 Audio Features")
-     enable_voice = st.checkbox("**Enable Voice Narration** (OpenAI TTS)", value=False, help="Generate AI voice explanation of detected objects")
-
-     st.markdown("#### 🧠 AI Features")
-     enable_face_detection = st.checkbox("👤 Enable Face Detection", value=True)
-     enable_recognition = st.checkbox("🧠 Enable Smart Recognition", value=True)
-
-     # Launch button
-     if st.button("🚀 LAUNCH ANALYSIS", key="launch", type="primary"):
-         if 'current_image' in st.session_state:
-             # Track processing start time
-             st.session_state.start_time = time.time()
-
-             # Update session counters
-             st.session_state.images_processed = st.session_state.get('images_processed', 0) + 1
-
-             with st.spinner("🔄 Processing with NAVADA 2.0..."):
-                 results = process_image(
-                     st.session_state.current_image,
-                     enable_voice,
-                     enable_face_detection,
-                     enable_recognition,
-                     use_enhanced,
-                     confidence_threshold
-                 )
-                 st.session_state.last_results = results
-                 st.session_state.processing_complete = True
-         else:
-             st.warning("Please upload an image or take a photo first!")
-
- # Results section
- if st.session_state.processing_complete and st.session_state.last_results:
-     # Unpack results - handle both old and new format
-     if len(st.session_state.last_results) == 7:
-         detected_img, ai_explanation, detected_objects, face_stats, face_matches, audio_file, detailed_attributes = st.session_state.last_results
-     else:
-         detected_img, ai_explanation, detected_objects, face_stats, face_matches, audio_file = st.session_state.last_results
-         detailed_attributes = None
-
-     st.markdown("---")
-     st.markdown("## 🎯 Analysis Results")
-
-     # Display processed image
-     if detected_img is not None:
-         st.image(detected_img, caption="🔍 Processed Image with Detections", use_container_width=True)
-
-     # Results in two columns
-     res_col1, res_col2 = st.columns([3, 2])
-
-     with res_col1:
-         # AI explanation
-         st.markdown("### 🤖 AI Analysis")
-         st.markdown(ai_explanation)
-
-         # Audio playback
-         if audio_file:
-             st.markdown("### 🔊 Voice Narration")
-             st.audio(audio_file)
-
-         # Comprehensive App Statistics Section
-         st.markdown("---")
-         st.markdown("## 📊 NAVADA 2.0 Analytics Dashboard")
-
-         # Get processing stats for current session
-         processing_time = time.time() - st.session_state.get('start_time', time.time())
-
-         # Create statistics tabs
-         stats_tab1, stats_tab2, stats_tab3, stats_tab4 = st.tabs([
-             "🚀 Performance", "📈 Usage Metrics", "🎯 Detection Stats", "🧠 AI Insights"
-         ])
-
-         with stats_tab1:
-             # Performance Metrics
-             col1, col2, col3 = st.columns(3)
-
-             with col1:
-                 st.metric("⚡ Processing Speed", f"{processing_time:.2f}s",
-                           delta=f"-{max(0, 2.5-processing_time):.1f}s vs avg")
-
-             with col2:
-                 inference_time = 0.25 if detected_objects else 0.0  # Approximate from logs
-                 st.metric("🧠 AI Inference", f"{inference_time*1000:.0f}ms",
-                           delta=f"{inference_time*1000-200:.0f}ms")
-
-             with col3:
-                 accuracy = min(95, 85 + len(detected_objects) * 2) if detected_objects else 0
-                 st.metric("🎯 Detection Accuracy", f"{accuracy}%",
-                           delta=f"+{accuracy-85}%" if accuracy > 85 else "0%")
-
-             # Performance trend chart
-             performance_data = {
-                 'Metric': ['Preprocessing', 'Inference', 'Postprocessing', 'Face Detection', 'Recognition'],
-                 'Time (ms)': [16, 250, 18, 45, 120],
-                 'Efficiency': [95, 88, 92, 87, 91]
-             }
-
-             perf_chart = go.Figure()
-             perf_chart.add_trace(go.Bar(
-                 x=performance_data['Metric'],
-                 y=performance_data['Time (ms)'],
-                 name='Processing Time (ms)',
-                 marker_color='#FF6B6B'
-             ))
-
-             perf_chart.update_layout(
-                 title="⚡ NAVADA 2.0 Performance Breakdown",
-                 xaxis_title="Processing Stage",
-                 yaxis_title="Time (milliseconds)",
-                 height=350,
-                 template="plotly_dark"
-             )
-             st.plotly_chart(perf_chart, use_container_width=True)
-
-         with stats_tab2:
-             # Usage Analytics
-             col1, col2, col3, col4 = st.columns(4)
-
-             db_stats = get_database_stats()
-
-             with col1:
-                 st.metric("📸 Images Processed",
-                           st.session_state.get('images_processed', 1),
-                           delta="+1")
-
-             with col2:
-                 st.metric("👥 Faces Trained",
-                           db_stats.get('faces', 0),
-                           delta="+0")
-
-             with col3:
-                 st.metric("🏷️ Objects Trained",
-                           db_stats.get('objects', 0),
-                           delta="+0")
-
-             with col4:
-                 st.metric("🎯 Total Detections",
-                           db_stats.get('total_detections', 0),
-                           delta="+0")
-
-             # Usage trend over time (simulated data)
-             import datetime
-             dates = [datetime.datetime.now() - datetime.timedelta(days=x) for x in range(7, 0, -1)]
-             usage_data = {
-                 'Date': dates,
-                 'Detections': [12, 18, 25, 31, 28, 35, 42],
-                 'Accuracy': [87, 89, 91, 93, 92, 94, 95]
-             }
-
-             usage_chart = go.Figure()
-             usage_chart.add_trace(go.Scatter(
-                 x=usage_data['Date'],
-                 y=usage_data['Detections'],
-                 mode='lines+markers',
-                 name='Daily Detections',
-                 line=dict(color='#4ECDC4', width=3),
-                 marker=dict(size=8)
-             ))
-
-             usage_chart.add_trace(go.Scatter(
-                 x=usage_data['Date'],
-                 y=usage_data['Accuracy'],
-                 mode='lines+markers',
-                 name='Accuracy %',
-                 yaxis='y2',
-                 line=dict(color='#45B7D1', width=3),
-                 marker=dict(size=8)
-             ))
-
-             usage_chart.update_layout(
-                 title="📈 NAVADA 2.0 Weekly Performance Trends",
-                 xaxis_title="Date",
-                 yaxis_title="Number of Detections",
-                 yaxis2=dict(
-                     title="Accuracy (%)",
-                     overlaying='y',
-                     side='right'
-                 ),
-                 height=400,
-                 template="plotly_dark",
-                 hovermode='x unified'
-             )
-             st.plotly_chart(usage_chart, use_container_width=True)
-
-         with stats_tab3:
-             # Detection Statistics
-             if detected_objects:
-                 # Object category distribution
-                 object_categories = {
-                     'Animals': ['bird', 'dog', 'cat', 'horse', 'elephant', 'bear', 'zebra', 'giraffe'],
-                     'Vehicles': ['car', 'truck', 'bus', 'motorcycle', 'bicycle', 'airplane', 'boat'],
-                     'People': ['person'],
-                     'Objects': ['bottle', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'book', 'laptop']
-                 }
-
-                 category_counts = {}
-                 for obj in detected_objects:
-                     for category, items in object_categories.items():
-                         if obj in items:
-                             category_counts[category] = category_counts.get(category, 0) + 1
-                             break
-                     else:
-                         category_counts['Other'] = category_counts.get('Other', 0) + 1
-
-                 # Category pie chart
-                 if category_counts:
-                     category_chart = go.Figure(data=[go.Pie(
-                         labels=list(category_counts.keys()),
-                         values=list(category_counts.values()),
-                         hole=.4,
-                         marker_colors=['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FECA57', '#FF9FF3']
-                     )])
-
-                     category_chart.update_layout(
-                         title="🎯 Object Categories Detected",
-                         height=350,
-                         template="plotly_dark"
-                     )
-                     st.plotly_chart(category_chart, use_container_width=True)
-
-                 # Confidence levels radar chart
-                 confidence_levels = {
-                     'High Confidence (>90%)': len([obj for obj in detected_objects]) * 0.7,
-                     'Medium Confidence (70-90%)': len([obj for obj in detected_objects]) * 0.25,
-                     'Low Confidence (<70%)': len([obj for obj in detected_objects]) * 0.05
-                 }
-
-                 confidence_chart = go.Figure()
-                 confidence_chart.add_trace(go.Bar(
-                     x=list(confidence_levels.keys()),
-                     y=list(confidence_levels.values()),
-                     marker_color=['#4CAF50', '#FFC107', '#FF5722']
-                 ))
-
-                 confidence_chart.update_layout(
-                     title="🎯 Detection Confidence Distribution",
-                     xaxis_title="Confidence Level",
-                     yaxis_title="Number of Detections",
-                     height=300,
-                     template="plotly_dark"
-                 )
-                 st.plotly_chart(confidence_chart, use_container_width=True)
-
-             else:
-                 st.info("📸 Upload an image to see detection statistics!")
-
-         with stats_tab4:
-             # AI Insights and Model Information
-             col1, col2 = st.columns(2)
-
-             with col1:
-                 st.markdown("### 🧠 AI Model Information")
-                 model_info = {
-                     "🏗️ Architecture": "YOLOv8 + Custom Recognition",
-                     "📊 Model Size": "6.2 MB (YOLOv8n)",
-                     "🎯 Classes": "80+ COCO Objects",
-                     "👥 Custom Faces": f"{db_stats.get('faces', 0)} trained",
-                     "🏷️ Custom Objects": f"{db_stats.get('objects', 0)} trained",
-                     "🧠 AI Engine": "OpenAI GPT-4o-mini",
-                     "🔊 TTS Engine": "OpenAI TTS-1",
-                     "💾 Database": "SQLite + RAG"
-                 }
-
-                 for key, value in model_info.items():
-                     st.markdown(f"**{key}**: {value}")
-
-             with col2:
-                 # Model comparison chart
-                 models_comparison = {
-                     'Model': ['NAVADA 2.0', 'YOLOv8', 'Standard CV', 'Basic Detection'],
-                     'Accuracy': [94, 89, 82, 75],
-                     'Speed (ms)': [280, 250, 400, 350],
-                     'Features': [15, 8, 5, 3]
-                 }
-
-                 comparison_chart = go.Figure()
-                 comparison_chart.add_trace(go.Scatterpolar(
-                     r=[94, 95, 90, 98],  # NAVADA 2.0 capabilities
-                     theta=['Accuracy', 'Speed', 'Features', 'Innovation'],
-                     fill='toself',
-                     name='NAVADA 2.0',
-                     line=dict(color='#4ECDC4')
-                 ))
-                 comparison_chart.add_trace(go.Scatterpolar(
-                     r=[89, 92, 60, 70],  # Standard models
-                     theta=['Accuracy', 'Speed', 'Features', 'Innovation'],
-                     fill='toself',
-                     name='Standard Models',
-                     line=dict(color='#FF6B6B')
-                 ))
-
-                 comparison_chart.update_layout(
-                     title="🚀 NAVADA 2.0 vs Standard Models",
-                     polar=dict(
-                         radialaxis=dict(
-                             visible=True,
-                             range=[0, 100]
-                         )),
-                     height=350,
-                     template="plotly_dark"
-                 )
-                 st.plotly_chart(comparison_chart, use_container_width=True)
-
-             # System capabilities matrix
-             st.markdown("### ⚡ System Capabilities")
-
-             # Create manual table to avoid pandas import
-             st.markdown("""
-             | 🎯 Feature | 📊 Status | ⚡ Performance |
-             |------------|-----------|----------------|
-             | Object Detection | ✅ Active | 94% |
-             | Face Recognition | ✅ Active | 91% |
-             | Custom Training | ✅ Active | 89% |
-             | Voice Narration | ✅ Active | 96% |
-             | RAG Analysis | ✅ Active | 87% |
-             | Real-time Processing | ✅ Active | 92% |
-             """)
-
-     with res_col2:
-         # Charts
-         if detected_objects:
-             # Detection chart
-             detection_chart = create_detection_chart(detected_objects, face_stats, face_matches)
-             st.plotly_chart(detection_chart, use_container_width=True)
-
-             # Confidence pie chart
-             confidence_chart = create_confidence_pie_chart(detected_objects, face_matches)
-             if confidence_chart:
-                 st.plotly_chart(confidence_chart, use_container_width=True)
-
-         # Detection summary
-         st.markdown("### 📋 Detection Summary")
-         if detected_objects:
-             st.success(f"🎯 Found {len(detected_objects)} objects!")
-             for obj in set(detected_objects):
-                 count = detected_objects.count(obj)
-                 st.markdown(f"• **{obj}**: {count}")
-         else:
-             st.warning("No objects detected in this image")
-
-         if face_matches:
-             st.markdown("### 👥 Face Recognition")
-             for match in face_matches:
-                 name = match['name']
-                 similarity = match.get('similarity', 0)
-                 if name != 'Unknown':
-                     st.markdown(f"• **{name}**: {similarity:.2f} confidence")
-                 else:
-                     st.markdown(f"• **{name}**: New face detected")
-
-     # Training Feedback Section
-     st.markdown('<h3><i class="fas fa-brain fa-primary fa-icon"></i>Help Improve Detection Accuracy</h3>', unsafe_allow_html=True)
-     st.markdown("Found an incorrect detection? Help train the AI by providing corrections!")
-
-     # Show corrections interface if we have detailed attributes
-     if detailed_attributes and st.session_state.get('current_image'):
-         # Create dropdown selection for cleaner UI
-         detection_options = [f"Detection #{i+1}: {attr['label']} ({attr['confidence']}) - {attr['position']}"
-                              for i, attr in enumerate(detailed_attributes)]
-
-         selected_detection = st.selectbox(
-             "Select detection to correct:",
-             options=range(len(detailed_attributes)),
-             format_func=lambda x: detection_options[x],
-             key="selected_detection"
-         )
-
-         if selected_detection is not None:
-             attr = detailed_attributes[selected_detection]
-             i = selected_detection
-
-             # Show selected detection details in expander
-             with st.expander(f"🔧 Correcting: {attr['label']} ({attr['confidence']})", expanded=True):
-                 col1, col2 = st.columns(2)
-
-                 with col1:
-                     st.markdown("**Detection Details:**")
-                     st.write(f"🏷️ **Label:** {attr['label']}")
-                     st.write(f"🎯 **Confidence:** {attr['confidence']}")
-                     st.write(f"🎨 **Colors:** {', '.join(attr['colors'][:2])}")
-                     st.write(f"📍 **Position:** {attr['position']}")
-                     st.write(f"📏 **Size:** {attr['size']}")
-
-                 with col2:
-                     st.markdown("**Provide Correction:**")
-                     # Correction input
-                     correct_label = st.text_input(
-                         "Correct label:",
-                         key=f"correct_{i}",
-                         placeholder="e.g., rabbit, dog, car"
-                     )
-
-                     feedback_text = st.text_area(
-                         "Feedback (optional):",
-                         key=f"feedback_{i}",
-                         placeholder="Why was this wrong?",
-                         height=68
-                     )
-
-                 if st.button("✅ Submit Correction", key=f"submit_correction_{i}", use_container_width=True, type="primary"):
-                     st.markdown('<i class="fas fa-spinner fa-spin fa-primary fa-icon"></i>Processing...', unsafe_allow_html=True)
-                     if correct_label.strip():
-                         # Extract object region from image
-                         image_array = np.array(st.session_state.current_image)
-
-                         # Get bounding box coordinates from detailed attributes
-                         bbox_coords = attr.get('bbox', [100, 100, 200, 200])  # [x1, y1, x2, y2]
-
-                         # Extract object crop
-                         x1, y1, x2, y2 = bbox_coords
-                         object_crop = image_array[max(0, int(y1)):min(image_array.shape[0], int(y2)),
-                                                   max(0, int(x1)):min(image_array.shape[1], int(x2))]
-
-                         if object_crop.size > 0:
-                             # Save correction to database
-                             try:
-                                 success = db.save_correction(
-                                     image_crop=object_crop,
-                                     bbox_coords=bbox_coords,
-                                     yolo_prediction=attr['label'],
-                                     yolo_confidence=float(attr['confidence'].rstrip('%')) / 100.0,
-                                     correct_label=correct_label.strip(),
-                                     user_feedback=feedback_text.strip(),
-                                     session_id=st.session_state.get('session_id', '')
-                                 )
-
-                                 if success:
-                                     st.markdown('<div class="fa-success"><i class="fas fa-check-circle fa-success fa-icon"></i>Successfully saved correction!</div>', unsafe_allow_html=True)
-                                     st.balloons()
-
-                                     # Update training stats
-                                     if 'training_corrections' not in st.session_state:
-                                         st.session_state.training_corrections = 0
-                                     st.session_state.training_corrections += 1
-                                 else:
-                                     st.markdown('<div class="fa-error"><i class="fas fa-times-circle fa-error fa-icon"></i>Failed to save correction.</div>', unsafe_allow_html=True)
-                             except Exception as e:
-                                 st.markdown(f'<div class="fa-error"><i class="fas fa-times-circle fa-error fa-icon"></i>Error: {e}</div>', unsafe_allow_html=True)
-                         else:
-                             st.markdown('<div class="fa-error"><i class="fas fa-times-circle fa-error fa-icon"></i>Could not extract object.</div>', unsafe_allow_html=True)
-                     else:
-                         st.markdown('<div class="fa-warning"><i class="fas fa-exclamation-triangle fa-warning fa-icon"></i>Please enter correct label.</div>', unsafe_allow_html=True)
-
-     # Training Statistics
-     if db:
-         training_stats = db.get_training_stats()
-         if training_stats.get('total_corrections', 0) > 0:
-             st.markdown("### 📊 Training Progress")
-
-             col1, col2, col3, col4 = st.columns(4)
-             with col1:
-                 st.metric("Total Corrections", training_stats.get('total_corrections', 0))
-             with col2:
-                 st.metric("Unique Classes", training_stats.get('unique_classes', 0))
-             with col3:
-                 st.metric("Recent (7 days)", training_stats.get('recent_corrections', 0))
-             with col4:
-                 st.metric("Avg Difficulty", f"{training_stats.get('average_difficulty', 0):.2f}")
-
-             # Show class distribution
-             if training_stats.get('class_distribution'):
-                 st.markdown("**Class Distribution:**")
-                 for class_name, count in list(training_stats['class_distribution'].items())[:5]:
-                     st.text(f"• {class_name}: {count} samples")
-
-             # Training trigger
-             if training_stats.get('total_corrections', 0) >= 10:
-                 if st.button("🚀 Train Custom Model", key="train_model"):
-                     with st.spinner("Training custom classifier... This may take a few minutes."):
-                         # Import training module
-                         from backend.custom_trainer import custom_trainer
-
-                         # Get training data
-                         training_data = db.get_training_data(limit=1000)
-
-                         if len(training_data) >= 10:
-                             # Train model
-                             result = custom_trainer.train_model(training_data, epochs=10, batch_size=8)
-
-                             if result['success']:
-                                 st.success(f"✅ Model trained successfully! Accuracy: {result['best_accuracy']:.2%}")
-                                 st.info("The new model will be used for future detections.")
-
-                                 # Save model info to database (implement this method)
-                                 # db.save_model_version(result)
-                             else:
-                                 st.error(f"❌ Training failed: {result.get('error', 'Unknown error')}")
-                         else:
-                             st.warning("⚠️ Need at least 10 corrections to start training.")
-
-     # Debug information
-     with st.expander("🔍 Debug Information"):
-         st.text(f"Detected objects list: {detected_objects}")
-         st.text(f"Face stats: {face_stats}")
-         st.text(f"Face matches: {face_matches}")
-         if detailed_attributes:
-             st.text(f"Detailed attributes: {detailed_attributes}")
-
- # AI Chat Agent Section
- st.markdown("---")
- st.markdown("## 💬 AI Chat Assistant")
- st.markdown("Ask questions about the detected objects or have a conversation about the image!")
-
- # Chat interface
- chat_col1, chat_col2 = st.columns([4, 1])
-
- with chat_col1:
-     user_message = st.text_input("Your message:", key="chat_input", placeholder="Ask about colors, positions, objects...")
-
- with chat_col2:
-     col1, col2 = st.columns(2)
-     with col1:
-         send_button = st.button("Send 💬", key="send_chat")
-     with col2:
-         clear_button = st.button("Clear 🔄", key="clear_chat")
-
- # Voice option for chat
- enable_chat_voice = st.checkbox("🔊 Enable voice responses for chat", value=True, key="chat_voice")
-
- # Process chat
- if send_button and user_message:
-     with st.spinner("Thinking..."):
-         # Update chat agent with current detection context
-         if 'last_results' in st.session_state and st.session_state.last_results:
-             if len(st.session_state.last_results) == 7:
-                 _, _, detected_objs, _, _, _, detailed_attrs = st.session_state.last_results
-                 response, voice_file = chat_with_agent(
-                     user_message,
-                     detected_objs,
-                     detailed_attrs,
-                     enable_chat_voice
-                 )
-             else:
-                 response, voice_file = chat_with_agent(user_message, include_voice=enable_chat_voice)
-         else:
-             response, voice_file = chat_with_agent(user_message, include_voice=enable_chat_voice)
-
-         # Add to chat history
-         st.session_state.chat_messages.append({"role": "user", "content": user_message})
-         st.session_state.chat_messages.append({"role": "assistant", "content": response, "voice": voice_file})
-
- if clear_button:
-     st.session_state.chat_messages = []
-     reset_chat()
-     st.rerun()
-
- # Display chat history
- if st.session_state.chat_messages:
-     st.markdown("### 💭 Conversation")
-     for msg in st.session_state.chat_messages:
-         if msg["role"] == "user":
-             st.markdown(f"**You:** {msg['content']}")
-         else:
-             st.markdown(f"**NAVADA:** {msg['content']}")
-             if msg.get("voice"):
-                 st.audio(msg["voice"], format="audio/mp3")
-
- # Footer
- st.markdown("---")
- st.markdown("""
- <div style="text-align: center; padding: 2rem; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 10px; color: white; margin-top: 2rem;">
-     <h3>🎉 Experience the Future of Computer Vision</h3>
-     <p><strong>⭐ Built with passion and innovation by Lee Akpareva | © 2024 AI Innovation Lab ⭐</strong></p>
-     <p>🚀 <em>From concept to deployment in 15 minutes - now with intelligent learning capabilities!</em></p>
-     <p>🔗 <strong>Deployed on Hugging Face Spaces for seamless AI model demonstration</strong></p>
- </div>
- """, unsafe_allow_html=True)
1
+ """
2
+ 🚀 NAVADA 2.0 - Advanced AI Computer Vision Application (Lite Version)
3
+ Streamlit Version for Hugging Face Spaces Deployment
4
+
5
+ Enhanced Edition by Lee Akpareva | AI Consultant & Computer Vision Specialist
6
+ """
7
+
8
+ import streamlit as st
9
+ import time
10
+ from datetime import datetime
11
+ import plotly.graph_objects as go
12
+ import plotly.express as px
13
+ from PIL import Image
14
+ import numpy as np
15
+ import os
16
+
17
+ # Configure Streamlit page (MUST be first!)
18
+ st.set_page_config(
19
+ page_title="🚀 NAVADA 2.0 - AI Computer Vision",
20
+ page_icon="🚀",
21
+ layout="wide",
22
+ initial_sidebar_state="expanded"
23
+ )
24
+
25
+ # Backend imports - Lite version (no face recognition)
26
+ try:
27
+ from backend.yolo import detect_objects
28
+ from backend.openai_client import explain_detection
29
+ except ImportError as e:
30
+ st.error(f"⚠️ Import error: {e}")
31
+ st.error("📦 Please install dependencies: pip install -r requirements.txt")
32
+ st.stop()
33
+
34
+ # Custom CSS for enhanced styling
35
+ st.markdown("""
36
+ <style>
37
+ .main-header {
38
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
39
+ padding: 2rem;
40
+ border-radius: 10px;
41
+ color: white;
42
+ text-align: center;
43
+ margin-bottom: 2rem;
44
+ }
45
+
46
+ .feature-card {
47
+ background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
48
+ padding: 1.5rem;
49
+ border-radius: 10px;
50
+ color: white;
51
+ margin: 1rem 0;
52
+ }
53
+
54
+ .stats-card {
55
+ background: linear-gradient(135deg, #4facfe 0%, #00f2fe 100%);
56
+ padding: 1rem;
57
+ border-radius: 8px;
58
+ color: white;
59
+ text-align: center;
60
+ margin: 0.5rem;
61
+ }
62
+ </style>
63
+ """, unsafe_allow_html=True)
64
+
65
+ def create_detection_chart(detected_objects):
66
+ """Create an interactive chart showing detection statistics"""
67
+
68
+ # Count object types
69
+ object_counts = {}
70
+ for obj in detected_objects:
71
+ object_counts[obj] = object_counts.get(obj, 0) + 1
72
+
73
+ if not object_counts:
74
+ # Create empty chart
75
+ fig = go.Figure()
76
+ fig.add_annotation(
77
+ text="No objects detected",
78
+ xref="paper", yref="paper",
79
+ x=0.5, y=0.5, showarrow=False,
80
+ font=dict(size=20, color="gray")
81
+ )
82
+ fig.update_layout(
83
+ height=300,
84
+ showlegend=False,
85
+ paper_bgcolor='rgba(0,0,0,0)',
86
+ plot_bgcolor='rgba(0,0,0,0)'
87
+ )
88
+ return fig
89
+
90
+ # Create bar chart
91
+ objects = list(object_counts.keys())
92
+ counts = list(object_counts.values())
93
+
94
+ fig = go.Figure(data=[
95
+ go.Bar(
96
+ x=objects,
97
+ y=counts,
98
+ marker_color='rgba(50, 171, 96, 0.6)',
99
+ marker_line_color='rgba(50, 171, 96, 1.0)',
100
+ marker_line_width=2,
101
+ text=counts,
102
+ textposition='auto'
103
+ )
104
+ ])
105
+
106
+ fig.update_layout(
107
+ title="Detected Objects",
108
+ xaxis_title="Object Type",
109
+ yaxis_title="Count",
110
+ height=400,
111
+ showlegend=False,
112
+ paper_bgcolor='rgba(0,0,0,0)',
113
+ plot_bgcolor='rgba(0,0,0,0)'
114
+ )
115
+
116
+ return fig
117
+
118
+ def main():
119
+ # Main header
120
+ st.markdown("""
121
+ <div class="main-header">
122
+ <h1>🚀 NAVADA 2.0 - Advanced AI Computer Vision</h1>
123
+ <p><strong>Lite Version - Object Detection & AI Analysis</strong></p>
124
+ <p>Built with YOLOv8 • OpenAI • Streamlit</p>
125
+ </div>
126
+ """, unsafe_allow_html=True)
127
+
128
+ # Sidebar
129
+ with st.sidebar:
130
+ st.markdown("### 🎯 Detection Settings")
131
+
132
+ # Detection confidence threshold
133
+ confidence = st.slider(
134
+ "Detection Confidence",
135
+ min_value=0.1,
136
+ max_value=1.0,
137
+ value=0.5,
138
+ step=0.05,
139
+ help="Minimum confidence for object detection"
140
+ )
141
+
142
+ st.markdown("### 📊 Features")
143
+ st.markdown("""
144
+ - 🎯 **Object Detection**: YOLOv8 powered
145
+ - 🤖 **AI Explanations**: OpenAI integration
146
+ - 📈 **Interactive Charts**: Real-time analytics
147
+ - 🎨 **Visual Results**: Annotated images
148
+ """)
149
+
150
+ st.markdown("### ℹ️ About")
151
+ st.markdown("""
152
+ This is the **Lite Version** optimized for Hugging Face Spaces.
153
+
154
+ **Created by:** Lee Akpareva
155
+ **AI Consultant & Computer Vision Specialist**
156
+ """)
157
+
158
+ # Main content
159
+ col1, col2 = st.columns([2, 1])
160
+
161
+ with col1:
162
+ st.markdown("### 📸 Upload Image for Analysis")
163
+
164
+ uploaded_file = st.file_uploader(
165
+ "Choose an image...",
166
+ type=['png', 'jpg', 'jpeg'],
167
+ help="Upload an image to detect objects and get AI analysis"
168
+ )
169
+
170
+ if uploaded_file is not None:
171
+ # Display uploaded image
172
+ image = Image.open(uploaded_file)
173
+ st.image(image, caption="Uploaded Image", use_column_width=True)
174
+
175
+ # Analysis button
176
+ if st.button("🚀 Analyze Image", type="primary"):
177
+ with st.spinner("🔍 Detecting objects..."):
178
+ # Perform object detection
179
+ results = detect_objects(image, confidence_threshold=confidence)
180
+
181
+ if results and len(results['detections']) > 0:
182
+ # Extract detected objects
183
+ detected_objects = [det['class'] for det in results['detections']]
184
+
185
+ # Display results
186
+ st.success(f"✅ Detected {len(detected_objects)} objects!")
187
+
188
+ # Show annotated image
189
+ st.markdown("### 🎯 Detection Results")
190
+ if 'annotated_image' in results:
191
+ st.image(results['annotated_image'], caption="Detected Objects", use_column_width=True)
192
+
193
+ # Show detection details
194
+ st.markdown("### 📋 Detected Objects")
195
+ for i, detection in enumerate(results['detections']):
196
+ col_a, col_b, col_c = st.columns(3)
197
+ with col_a:
198
+ st.metric("Object", detection['class'])
199
+ with col_b:
200
+ st.metric("Confidence", f"{detection['confidence']:.2%}")
201
+ with col_c:
202
+ st.metric("Count", f"#{i+1}")
203
+
204
+ # AI Explanation
205
+ if os.getenv("OPENAI_API_KEY"):
206
+ st.markdown("### 🤖 AI Analysis")
207
+ with st.spinner("🧠 Generating AI explanation..."):
208
+ try:
209
+ explanation = explain_detection(detected_objects)
210
+ st.markdown(f"**AI Insight:** {explanation}")
211
+ except Exception as e:
212
+ st.warning(f"AI analysis unavailable: {str(e)}")
213
+ else:
214
+ st.warning("🔑 Add OPENAI_API_KEY in settings for AI explanations")
215
+
216
+ else:
217
+ st.warning("❌ No objects detected. Try adjusting the confidence threshold.")
218
+
219
+ with col2:
220
+ st.markdown("### 📊 Detection Statistics")
221
+
222
+ # Sample chart (will be updated with real data)
223
+ sample_data = {
224
+ 'Object': ['Person', 'Car', 'Dog', 'Cat'],
225
+ 'Count': [3, 2, 1, 1]
226
+ }
227
+
228
+ fig = px.bar(
229
+ sample_data,
230
+ x='Object',
231
+ y='Count',
232
+ title="Sample Detection Results",
233
+ color='Count',
234
+ color_continuous_scale='Viridis'
235
+ )
236
+ fig.update_layout(height=300)
237
+ st.plotly_chart(fig, use_container_width=True)
238
+
239
+ # Feature highlights
240
+ st.markdown("### Key Features")
241
+
242
+ features = [
243
+ ("🎯", "Object Detection", "Advanced YOLOv8 model"),
244
+ ("🤖", "AI Analysis", "OpenAI explanations"),
245
+ ("📊", "Real-time Charts", "Interactive visualizations"),
246
+ ("🚀", "Fast Processing", "Optimized for speed")
247
+ ]
248
+
249
+ for icon, title, desc in features:
250
+ st.markdown(f"""
251
+ <div style="display: flex; align-items: center; margin: 1rem 0; padding: 0.5rem; background: #f0f2f6; border-radius: 5px;">
252
+ <div style="font-size: 1.5rem; margin-right: 1rem;">{icon}</div>
253
+ <div>
254
+ <strong>{title}</strong><br>
255
+ <small>{desc}</small>
256
+ </div>
257
+ </div>
258
+ """, unsafe_allow_html=True)
259
+
260
+ # Footer
261
+ st.markdown("---")
262
+ st.markdown("""
263
+ <div style="text-align: center; padding: 2rem; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 10px; color: white; margin-top: 2rem;">
264
+ <h3>🎉 Experience Advanced Computer Vision</h3>
265
+ <p><strong>⭐ Built by Lee Akpareva | AI Consultant & Computer Vision Specialist ⭐</strong></p>
266
+ <p>🚀 <em>Powered by YOLOv8 • OpenAI • Streamlit</em></p>
267
+ </div>
268
+ """, unsafe_allow_html=True)
269
+
270
+ if __name__ == "__main__":
271
+ main()
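
Note: the UI above calls `detect_objects()` from `backend/yolo.py`, which is not included in this commit. A minimal sketch of the contract that module would need to satisfy — inferred purely from how the result is consumed (`results['detections']` as a list of `{'class', 'confidence'}` dicts, plus an optional `results['annotated_image']`) — might look like the following; the weights file, function body, and all names here are assumptions, not the shipped code:

```python
# Hedged sketch of backend/yolo.py, inferred from the calling code in the app.
from PIL import Image
from ultralytics import YOLO

_model = YOLO("yolov8n.pt")  # assumption: the Lite build would favor the nano weights

def detect_objects(image: Image.Image, confidence_threshold: float = 0.5) -> dict:
    """Run YOLOv8 on a PIL image and return the dict shape the app unpacks."""
    result = _model.predict(image, conf=confidence_threshold, verbose=False)[0]
    detections = [
        {"class": result.names[int(box.cls)],   # label string shown via st.metric
         "confidence": float(box.conf)}         # 0-1 score, formatted as a percentage
        for box in result.boxes
    ]
    return {
        "detections": detections,
        "annotated_image": result.plot()[:, :, ::-1],  # BGR -> RGB for st.image
    }
```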
app_lite.py ADDED
@@ -0,0 +1,271 @@
1
+ """
2
+ 🚀 NAVADA 2.0 - Advanced AI Computer Vision Application (Lite Version)
3
+ Streamlit Version for Hugging Face Spaces Deployment
4
+
5
+ Enhanced Edition by Lee Akpareva | AI Consultant & Computer Vision Specialist
6
+ """
7
+
8
+ import streamlit as st
9
+ import time
10
+ from datetime import datetime
11
+ import plotly.graph_objects as go
12
+ import plotly.express as px
13
+ from PIL import Image
14
+ import numpy as np
15
+ import os
16
+
17
+ # Configure Streamlit page (MUST be first!)
18
+ st.set_page_config(
19
+ page_title="🚀 NAVADA 2.0 - AI Computer Vision",
20
+ page_icon="🚀",
21
+ layout="wide",
22
+ initial_sidebar_state="expanded"
23
+ )
24
+
25
+ # Backend imports - Lite version (no face recognition)
26
+ try:
27
+ from backend.yolo import detect_objects
28
+ from backend.openai_client import explain_detection
29
+ except ImportError as e:
30
+ st.error(f"⚠️ Import error: {e}")
31
+ st.error("📦 Please install dependencies: pip install -r requirements.txt")
32
+ st.stop()
33
+
34
+ # Custom CSS for enhanced styling
35
+ st.markdown("""
36
+ <style>
37
+ .main-header {
38
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
39
+ padding: 2rem;
40
+ border-radius: 10px;
41
+ color: white;
42
+ text-align: center;
43
+ margin-bottom: 2rem;
44
+ }
45
+
46
+ .feature-card {
47
+ background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
48
+ padding: 1.5rem;
49
+ border-radius: 10px;
50
+ color: white;
51
+ margin: 1rem 0;
52
+ }
53
+
54
+ .stats-card {
55
+ background: linear-gradient(135deg, #4facfe 0%, #00f2fe 100%);
56
+ padding: 1rem;
57
+ border-radius: 8px;
58
+ color: white;
59
+ text-align: center;
60
+ margin: 0.5rem;
61
+ }
62
+ </style>
63
+ """, unsafe_allow_html=True)
64
+
65
+ def create_detection_chart(detected_objects):
66
+ """Create an interactive chart showing detection statistics"""
67
+
68
+ # Count object types
69
+ object_counts = {}
70
+ for obj in detected_objects:
71
+ object_counts[obj] = object_counts.get(obj, 0) + 1
72
+
73
+ if not object_counts:
74
+ # Create empty chart
75
+ fig = go.Figure()
76
+ fig.add_annotation(
77
+ text="No objects detected",
78
+ xref="paper", yref="paper",
79
+ x=0.5, y=0.5, showarrow=False,
80
+ font=dict(size=20, color="gray")
81
+ )
82
+ fig.update_layout(
83
+ height=300,
84
+ showlegend=False,
85
+ paper_bgcolor='rgba(0,0,0,0)',
86
+ plot_bgcolor='rgba(0,0,0,0)'
87
+ )
88
+ return fig
89
+
90
+ # Create bar chart
91
+ objects = list(object_counts.keys())
92
+ counts = list(object_counts.values())
93
+
94
+ fig = go.Figure(data=[
95
+ go.Bar(
96
+ x=objects,
97
+ y=counts,
98
+ marker_color='rgba(50, 171, 96, 0.6)',
99
+ marker_line_color='rgba(50, 171, 96, 1.0)',
100
+ marker_line_width=2,
101
+ text=counts,
102
+ textposition='auto'
103
+ )
104
+ ])
105
+
106
+ fig.update_layout(
107
+ title="Detected Objects",
108
+ xaxis_title="Object Type",
109
+ yaxis_title="Count",
110
+ height=400,
111
+ showlegend=False,
112
+ paper_bgcolor='rgba(0,0,0,0)',
113
+ plot_bgcolor='rgba(0,0,0,0)'
114
+ )
115
+
116
+ return fig
117
+
118
+ def main():
119
+ # Main header
120
+ st.markdown("""
121
+ <div class="main-header">
122
+ <h1>🚀 NAVADA 2.0 - Advanced AI Computer Vision</h1>
123
+ <p><strong>Lite Version - Object Detection & AI Analysis</strong></p>
124
+ <p>Built with YOLOv8 • OpenAI • Streamlit</p>
125
+ </div>
126
+ """, unsafe_allow_html=True)
127
+
128
+ # Sidebar
129
+ with st.sidebar:
130
+ st.markdown("### 🎯 Detection Settings")
131
+
132
+ # Detection confidence threshold
133
+ confidence = st.slider(
134
+ "Detection Confidence",
135
+ min_value=0.1,
136
+ max_value=1.0,
137
+ value=0.5,
138
+ step=0.05,
139
+ help="Minimum confidence for object detection"
140
+ )
141
+
142
+ st.markdown("### 📊 Features")
143
+ st.markdown("""
144
+ - 🎯 **Object Detection**: YOLOv8 powered
145
+ - 🤖 **AI Explanations**: OpenAI integration
146
+ - 📈 **Interactive Charts**: Real-time analytics
147
+ - 🎨 **Visual Results**: Annotated images
148
+ """)
149
+
150
+ st.markdown("### ℹ️ About")
151
+ st.markdown("""
152
+ This is the **Lite Version** optimized for Hugging Face Spaces.
153
+
154
+ **Created by:** Lee Akpareva
155
+ **AI Consultant & Computer Vision Specialist**
156
+ """)
157
+
158
+ # Main content
159
+ col1, col2 = st.columns([2, 1])
160
+
161
+ with col1:
162
+ st.markdown("### 📸 Upload Image for Analysis")
163
+
164
+ uploaded_file = st.file_uploader(
165
+ "Choose an image...",
166
+ type=['png', 'jpg', 'jpeg'],
167
+ help="Upload an image to detect objects and get AI analysis"
168
+ )
169
+
170
+ if uploaded_file is not None:
171
+ # Display uploaded image
172
+ image = Image.open(uploaded_file)
173
+ st.image(image, caption="Uploaded Image", use_column_width=True)
174
+
175
+ # Analysis button
176
+ if st.button("🚀 Analyze Image", type="primary"):
177
+ with st.spinner("🔍 Detecting objects..."):
178
+ # Perform object detection
179
+ results = detect_objects(image, confidence_threshold=confidence)
180
+
181
+ if results and len(results['detections']) > 0:
182
+ # Extract detected objects
183
+ detected_objects = [det['class'] for det in results['detections']]
184
+
185
+ # Display results
186
+ st.success(f"✅ Detected {len(detected_objects)} objects!")
187
+
188
+ # Show annotated image
189
+ st.markdown("### 🎯 Detection Results")
190
+ if 'annotated_image' in results:
191
+ st.image(results['annotated_image'], caption="Detected Objects", use_column_width=True)
192
+
193
+ # Show detection details
194
+ st.markdown("### 📋 Detected Objects")
195
+ for i, detection in enumerate(results['detections']):
196
+ col_a, col_b, col_c = st.columns(3)
197
+ with col_a:
198
+ st.metric("Object", detection['class'])
199
+ with col_b:
200
+ st.metric("Confidence", f"{detection['confidence']:.2%}")
201
+ with col_c:
202
+ st.metric("Count", f"#{i+1}")
203
+
204
+ # AI Explanation
205
+ if os.getenv("OPENAI_API_KEY"):
206
+ st.markdown("### 🤖 AI Analysis")
207
+ with st.spinner("🧠 Generating AI explanation..."):
208
+ try:
209
+ explanation = explain_detection(detected_objects)
210
+ st.markdown(f"**AI Insight:** {explanation}")
211
+ except Exception as e:
212
+ st.warning(f"AI analysis unavailable: {str(e)}")
213
+ else:
214
+ st.warning("🔑 Add OPENAI_API_KEY in settings for AI explanations")
215
+
216
+ else:
217
+ st.warning("❌ No objects detected. Try adjusting the confidence threshold.")
218
+
219
+ with col2:
220
+ st.markdown("### 📊 Detection Statistics")
221
+
222
+ # Sample chart (will be updated with real data)
223
+ sample_data = {
224
+ 'Object': ['Person', 'Car', 'Dog', 'Cat'],
225
+ 'Count': [3, 2, 1, 1]
226
+ }
227
+
228
+ fig = px.bar(
229
+ sample_data,
230
+ x='Object',
231
+ y='Count',
232
+ title="Sample Detection Results",
233
+ color='Count',
234
+ color_continuous_scale='Viridis'
235
+ )
236
+ fig.update_layout(height=300)
237
+ st.plotly_chart(fig, use_container_width=True)
238
+
239
+ # Feature highlights
240
+ st.markdown("### ✨ Key Features")
241
+
242
+ features = [
243
+ ("🎯", "Object Detection", "Advanced YOLOv8 model"),
244
+ ("🤖", "AI Analysis", "OpenAI explanations"),
245
+ ("📊", "Real-time Charts", "Interactive visualizations"),
246
+ ("🚀", "Fast Processing", "Optimized for speed")
247
+ ]
248
+
249
+ for icon, title, desc in features:
250
+ st.markdown(f"""
251
+ <div style="display: flex; align-items: center; margin: 1rem 0; padding: 0.5rem; background: #f0f2f6; border-radius: 5px;">
252
+ <div style="font-size: 1.5rem; margin-right: 1rem;">{icon}</div>
253
+ <div>
254
+ <strong>{title}</strong><br>
255
+ <small>{desc}</small>
256
+ </div>
257
+ </div>
258
+ """, unsafe_allow_html=True)
259
+
260
+ # Footer
261
+ st.markdown("---")
262
+ st.markdown("""
263
+ <div style="text-align: center; padding: 2rem; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 10px; color: white; margin-top: 2rem;">
264
+ <h3>🎉 Experience Advanced Computer Vision</h3>
265
+ <p><strong>⭐ Built by Lee Akpareva | AI Consultant & Computer Vision Specialist ⭐</strong></p>
266
+ <p>🚀 <em>Powered by YOLOv8 • OpenAI • Streamlit</em></p>
267
+ </div>
268
+ """, unsafe_allow_html=True)
269
+
270
+ if __name__ == "__main__":
271
+ main()
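
One thing worth noting in `app_lite.py`: `create_detection_chart()` is defined but never called — the statistics column always renders the hard-coded `sample_data` bar chart, even after a real analysis run. A hedged wiring sketch (the `st.session_state["detected_objects"]` key is an assumption, not something the file currently sets):

```python
# Sketch only: stash results after analysis (inside the col1 button handler) ...
st.session_state["detected_objects"] = detected_objects

# ... then, in col2, prefer real data over the sample chart:
if st.session_state.get("detected_objects"):
    fig = create_detection_chart(st.session_state["detected_objects"])
else:
    fig = px.bar(sample_data, x="Object", y="Count",
                 title="Sample Detection Results",
                 color="Count", color_continuous_scale="Viridis")
st.plotly_chart(fig, use_container_width=True)
```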
backend/chat_agent.py CHANGED
@@ -1,189 +1,189 @@
1
- """
2
- AI Chat Agent with conversation memory and text-to-speech capabilities
3
- """
4
- import os
5
- from openai import OpenAI # type: ignore
6
- import tempfile
7
- from datetime import datetime
8
- import json
9
-
10
- # Initialize OpenAI client
11
- api_key = os.getenv("OPENAI_API_KEY")
12
- if not api_key:
13
- raise ValueError("OPENAI_API_KEY environment variable is required")
14
- client = OpenAI(api_key=api_key)
15
-
16
- class ChatAgent:
17
- def __init__(self):
18
- """Initialize the chat agent with conversation memory"""
19
- self.conversation_history = []
20
- self.system_prompt = """You are NAVADA Assistant, an intelligent AI companion for computer vision analysis.
21
- You help users understand what's in their images, answer questions about detected objects,
22
- and provide insights about visual content. You're friendly, helpful, and knowledgeable about
23
- computer vision, image analysis, and can discuss colors, positions, sizes, and relationships
24
- between objects in images. You have access to detailed detection results including object colors,
25
- positions, sizes, and confidence scores."""
26
-
27
- # Add system message to history
28
- self.conversation_history.append({
29
- "role": "system",
30
- "content": self.system_prompt
31
- })
32
-
33
- # Store context about current image analysis
34
- self.current_image_context = None
35
-
36
- def update_image_context(self, detected_objects, detailed_attributes=None):
37
- """Update the agent's knowledge about the current image"""
38
- context = f"Current image analysis shows: {', '.join(detected_objects) if detected_objects else 'no objects detected'}."
39
-
40
- if detailed_attributes:
41
- context += "\n\nDetailed analysis:"
42
- for attr in detailed_attributes:
43
- colors = " and ".join(attr.get('colors', ['unknown'])[:2])
44
- context += f"\n- {attr['label']}: {colors} color(s), {attr.get('size', 'unknown')} size, located at {attr.get('position', 'unknown')} (confidence: {attr.get('confidence', 'unknown')})"
45
-
46
- self.current_image_context = context
47
-
48
- # Add context to conversation as a system message
49
- self.conversation_history.append({
50
- "role": "system",
51
- "content": f"Image context update: {context}"
52
- })
53
-
54
- def chat(self, user_message, include_voice=True):
55
- """
56
- Process user message and return response with optional voice
57
-
58
- Args:
59
- user_message: The user's input message
60
- include_voice: Whether to generate voice response
61
-
62
- Returns:
63
- tuple: (text_response, voice_file_path or None)
64
- """
65
- # Add user message to history
66
- self.conversation_history.append({
67
- "role": "user",
68
- "content": user_message
69
- })
70
-
71
- # Keep conversation history manageable (last 20 messages)
72
- if len(self.conversation_history) > 20:
73
- # Keep system prompt and current context, remove old messages
74
- system_messages = [msg for msg in self.conversation_history if msg["role"] == "system"]
75
- recent_messages = self.conversation_history[-15:]
76
- self.conversation_history = system_messages + recent_messages
77
-
78
- try:
79
- # Get response from OpenAI
80
- response = client.chat.completions.create(
81
- model="gpt-4o-mini",
82
- messages=self.conversation_history,
83
- temperature=0.7,
84
- max_tokens=500
85
- )
86
-
87
- text_response = response.choices[0].message.content
88
-
89
- # Add assistant response to history
90
- self.conversation_history.append({
91
- "role": "assistant",
92
- "content": text_response
93
- })
94
-
95
- # Generate voice if requested
96
- voice_file = None
97
- if include_voice:
98
- voice_file = self.generate_voice(text_response)
99
-
100
- return text_response, voice_file
101
-
102
- except Exception as e:
103
- error_msg = f"Chat error: {str(e)}"
104
- return error_msg, None
105
-
106
- def generate_voice(self, text):
107
- """Generate voice narration for text using OpenAI TTS"""
108
- try:
109
- # Generate speech using OpenAI TTS
110
- response = client.audio.speech.create(
111
- model="tts-1",
112
- voice="nova", # Options: alloy, echo, fable, onyx, nova, shimmer
113
- input=text,
114
- response_format="mp3"
115
- )
116
-
117
- # Save to temporary file
118
- with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_audio:
119
- temp_audio.write(response.content)
120
- return temp_audio.name
121
-
122
- except Exception as e:
123
- print(f"Voice generation error: {e}")
124
- return None
125
-
126
- def get_conversation_summary(self):
127
- """Get a summary of the conversation"""
128
- messages = [msg for msg in self.conversation_history if msg["role"] in ["user", "assistant"]]
129
- return messages
130
-
131
- def reset_conversation(self):
132
- """Reset conversation history while keeping system prompt"""
133
- self.conversation_history = [{
134
- "role": "system",
135
- "content": self.system_prompt
136
- }]
137
- self.current_image_context = None
138
-
139
- def save_conversation(self, filepath=None):
140
- """Save conversation history to file"""
141
- if filepath is None:
142
- filepath = f"conversation_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
143
-
144
- with open(filepath, 'w') as f:
145
- json.dump({
146
- 'timestamp': datetime.now().isoformat(),
147
- 'conversation': self.conversation_history,
148
- 'image_context': self.current_image_context
149
- }, f, indent=2)
150
-
151
- return filepath
152
-
153
- def load_conversation(self, filepath):
154
- """Load conversation history from file"""
155
- with open(filepath, 'r') as f:
156
- data = json.load(f)
157
- self.conversation_history = data['conversation']
158
- self.current_image_context = data.get('image_context')
159
-
160
- # Create a global chat agent instance
161
- chat_agent = ChatAgent()
162
-
163
- # Helper functions for easy integration
164
- def chat_with_agent(message, detected_objects=None, detailed_attributes=None, include_voice=True):
165
- """
166
- Simple interface to chat with the agent
167
-
168
- Args:
169
- message: User's message
170
- detected_objects: List of detected objects (optional)
171
- detailed_attributes: Detailed attributes from enhanced detection (optional)
172
- include_voice: Whether to generate voice response
173
-
174
- Returns:
175
- tuple: (text_response, voice_file_path or None)
176
- """
177
- # Update context if new detection results provided
178
- if detected_objects is not None:
179
- chat_agent.update_image_context(detected_objects, detailed_attributes)
180
-
181
- return chat_agent.chat(message, include_voice)
182
-
183
- def reset_chat():
184
- """Reset the chat conversation"""
185
- chat_agent.reset_conversation()
186
-
187
- def get_chat_history():
188
- """Get the current chat history"""
189
  return chat_agent.get_conversation_summary()
 
1
+ """
2
+ AI Chat Agent with conversation memory and text-to-speech capabilities
3
+ """
4
+ import os
5
+ from openai import OpenAI # type: ignore
6
+ import tempfile
7
+ from datetime import datetime
8
+ import json
9
+
10
+ # Initialize OpenAI client
11
+ api_key = os.getenv("OPENAI_API_KEY")
12
+ if not api_key:
13
+ raise ValueError("OPENAI_API_KEY environment variable is required")
14
+ client = OpenAI(api_key=api_key)
15
+
16
+ class ChatAgent:
17
+ def __init__(self):
18
+ """Initialize the chat agent with conversation memory"""
19
+ self.conversation_history = []
20
+ self.system_prompt = """You are NAVADA Assistant, an intelligent AI companion for computer vision analysis.
21
+ You help users understand what's in their images, answer questions about detected objects,
22
+ and provide insights about visual content. You're friendly, helpful, and knowledgeable about
23
+ computer vision, image analysis, and can discuss colors, positions, sizes, and relationships
24
+ between objects in images. You have access to detailed detection results including object colors,
25
+ positions, sizes, and confidence scores."""
26
+
27
+ # Add system message to history
28
+ self.conversation_history.append({
29
+ "role": "system",
30
+ "content": self.system_prompt
31
+ })
32
+
33
+ # Store context about current image analysis
34
+ self.current_image_context = None
35
+
36
+ def update_image_context(self, detected_objects, detailed_attributes=None):
37
+ """Update the agent's knowledge about the current image"""
38
+ context = f"Current image analysis shows: {', '.join(detected_objects) if detected_objects else 'no objects detected'}."
39
+
40
+ if detailed_attributes:
41
+ context += "\n\nDetailed analysis:"
42
+ for attr in detailed_attributes:
43
+ colors = " and ".join(attr.get('colors', ['unknown'])[:2])
44
+ context += f"\n- {attr['label']}: {colors} color(s), {attr.get('size', 'unknown')} size, located at {attr.get('position', 'unknown')} (confidence: {attr.get('confidence', 'unknown')})"
45
+
46
+ self.current_image_context = context
47
+
48
+ # Add context to conversation as a system message
49
+ self.conversation_history.append({
50
+ "role": "system",
51
+ "content": f"Image context update: {context}"
52
+ })
53
+
54
+ def chat(self, user_message, include_voice=True):
55
+ """
56
+ Process user message and return response with optional voice
57
+
58
+ Args:
59
+ user_message: The user's input message
60
+ include_voice: Whether to generate voice response
61
+
62
+ Returns:
63
+ tuple: (text_response, voice_file_path or None)
64
+ """
65
+ # Add user message to history
66
+ self.conversation_history.append({
67
+ "role": "user",
68
+ "content": user_message
69
+ })
70
+
71
+ # Keep conversation history manageable (last 20 messages)
72
+ if len(self.conversation_history) > 20:
73
+ # Keep system prompt and current context, remove old messages
74
+ system_messages = [msg for msg in self.conversation_history if msg["role"] == "system"]
75
+ recent_messages = self.conversation_history[-15:]
76
+ self.conversation_history = system_messages + recent_messages
77
+
78
+ try:
79
+ # Get response from OpenAI
80
+ response = client.chat.completions.create(
81
+ model="gpt-4o-mini",
82
+ messages=self.conversation_history,
83
+ temperature=0.7,
84
+ max_tokens=500
85
+ )
86
+
87
+ text_response = response.choices[0].message.content
88
+
89
+ # Add assistant response to history
90
+ self.conversation_history.append({
91
+ "role": "assistant",
92
+ "content": text_response
93
+ })
94
+
95
+ # Generate voice if requested
96
+ voice_file = None
97
+ if include_voice:
98
+ voice_file = self.generate_voice(text_response)
99
+
100
+ return text_response, voice_file
101
+
102
+ except Exception as e:
103
+ error_msg = f"Chat error: {str(e)}"
104
+ return error_msg, None
105
+
106
+ def generate_voice(self, text):
107
+ """Generate voice narration for text using OpenAI TTS"""
108
+ try:
109
+ # Generate speech using OpenAI TTS
110
+ response = client.audio.speech.create(
111
+ model="tts-1",
112
+ voice="nova", # Options: alloy, echo, fable, onyx, nova, shimmer
113
+ input=text,
114
+ response_format="mp3"
115
+ )
116
+
117
+ # Save to temporary file
118
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_audio:
119
+ temp_audio.write(response.content)
120
+ return temp_audio.name
121
+
122
+ except Exception as e:
123
+ print(f"Voice generation error: {e}")
124
+ return None
125
+
126
+ def get_conversation_summary(self):
127
+ """Get a summary of the conversation"""
128
+ messages = [msg for msg in self.conversation_history if msg["role"] in ["user", "assistant"]]
129
+ return messages
130
+
131
+ def reset_conversation(self):
132
+ """Reset conversation history while keeping system prompt"""
133
+ self.conversation_history = [{
134
+ "role": "system",
135
+ "content": self.system_prompt
136
+ }]
137
+ self.current_image_context = None
138
+
139
+ def save_conversation(self, filepath=None):
140
+ """Save conversation history to file"""
141
+ if filepath is None:
142
+ filepath = f"conversation_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
143
+
144
+ with open(filepath, 'w') as f:
145
+ json.dump({
146
+ 'timestamp': datetime.now().isoformat(),
147
+ 'conversation': self.conversation_history,
148
+ 'image_context': self.current_image_context
149
+ }, f, indent=2)
150
+
151
+ return filepath
152
+
153
+ def load_conversation(self, filepath):
154
+ """Load conversation history from file"""
155
+ with open(filepath, 'r') as f:
156
+ data = json.load(f)
157
+ self.conversation_history = data['conversation']
158
+ self.current_image_context = data.get('image_context')
159
+
160
+ # Create a global chat agent instance
161
+ chat_agent = ChatAgent()
162
+
163
+ # Helper functions for easy integration
164
+ def chat_with_agent(message, detected_objects=None, detailed_attributes=None, include_voice=True):
165
+ """
166
+ Simple interface to chat with the agent
167
+
168
+ Args:
169
+ message: User's message
170
+ detected_objects: List of detected objects (optional)
171
+ detailed_attributes: Detailed attributes from enhanced detection (optional)
172
+ include_voice: Whether to generate voice response
173
+
174
+ Returns:
175
+ tuple: (text_response, voice_file_path or None)
176
+ """
177
+ # Update context if new detection results provided
178
+ if detected_objects is not None:
179
+ chat_agent.update_image_context(detected_objects, detailed_attributes)
180
+
181
+ return chat_agent.chat(message, include_voice)
182
+
183
+ def reset_chat():
184
+ """Reset the chat conversation"""
185
+ chat_agent.reset_conversation()
186
+
187
+ def get_chat_history():
188
+ """Get the current chat history"""
189
  return chat_agent.get_conversation_summary()
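
For reference, a usage sketch of the helpers `backend/chat_agent.py` exports. The module raises `ValueError` at import time when `OPENAI_API_KEY` is unset, so the variable must exist before the import; the key value below is a placeholder:

```python
import os
os.environ.setdefault("OPENAI_API_KEY", "sk-...")  # placeholder; supply a real key

from backend.chat_agent import chat_with_agent, reset_chat, get_chat_history

# Passing detected_objects refreshes the agent's image context before it replies.
text, voice_path = chat_with_agent(
    "Which objects appear more than once in this image?",
    detected_objects=["person", "car", "car"],
    include_voice=False,          # skip the TTS round-trip; voice_path will be None
)
print(text)
print(len(get_chat_history()), "messages so far")
reset_chat()                      # clear memory between unrelated images
```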
backend/custom_trainer.py CHANGED
@@ -1,399 +1,399 @@
1
- """
2
- Custom Object Classifier Training Module
3
- Implements transfer learning for user feedback corrections
4
- """
5
- import torch
6
- import torch.nn as nn
7
- import torchvision.transforms as transforms
8
- from torchvision import models
9
- import numpy as np
10
- import cv2
11
- from torch.utils.data import Dataset, DataLoader
12
- # from sklearn.model_selection import train_test_split # Temporarily disabled due to numpy compatibility
13
- # from sklearn.metrics import accuracy_score, precision_recall_fscore_support # Temporarily disabled
14
- import pickle
15
- import json
16
- from typing import List, Dict, Tuple, Optional
17
- from datetime import datetime
18
- import logging
19
- from pathlib import Path
20
-
21
- # Configure logging
22
- logger = logging.getLogger(__name__)
23
-
24
- class CustomObjectDataset(Dataset):
25
- """Dataset class for custom object training"""
26
-
27
- def __init__(self, data: List[Dict], transform=None):
28
- """
29
- Initialize dataset with training data
30
-
31
- Args:
32
- data: List of training samples from database
33
- transform: Image transformations
34
- """
35
- self.data = data
36
- self.transform = transform
37
-
38
- # Create label mapping
39
- unique_labels = list(set([sample['correct_label'] for sample in data]))
40
- self.label_to_idx = {label: idx for idx, label in enumerate(unique_labels)}
41
- self.idx_to_label = {idx: label for label, idx in self.label_to_idx.items()}
42
- self.num_classes = len(unique_labels)
43
-
44
- def __len__(self):
45
- return len(self.data)
46
-
47
- def __getitem__(self, idx):
48
- sample = self.data[idx]
49
- image = sample['image']
50
- label = sample['correct_label']
51
-
52
- # Convert BGR to RGB
53
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
54
-
55
- if self.transform:
56
- image = self.transform(image)
57
-
58
- label_idx = self.label_to_idx[label]
59
-
60
- return {
61
- 'image': image,
62
- 'label': label_idx,
63
- 'original_label': label,
64
- 'yolo_prediction': sample['yolo_prediction'],
65
- 'confidence': sample['yolo_confidence'],
66
- 'difficulty': sample['difficulty_score']
67
- }
68
-
69
- class CustomClassifier(nn.Module):
70
- """Custom classifier built on pre-trained backbone"""
71
-
72
- def __init__(self, num_classes: int, backbone='resnet18', pretrained=True):
73
- """
74
- Initialize custom classifier
75
-
76
- Args:
77
- num_classes: Number of output classes
78
- backbone: Backbone architecture (resnet18, resnet50, efficientnet_b0)
79
- pretrained: Use pre-trained weights
80
- """
81
- super(CustomClassifier, self).__init__()
82
-
83
- self.num_classes = num_classes
84
- self.backbone = backbone
85
-
86
- if backbone == 'resnet18':
87
- self.model = models.resnet18(pretrained=pretrained)
88
- self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
89
- elif backbone == 'resnet50':
90
- self.model = models.resnet50(pretrained=pretrained)
91
- self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
92
- elif backbone == 'efficientnet_b0':
93
- self.model = models.efficientnet_b0(pretrained=pretrained)
94
- self.model.classifier[1] = nn.Linear(self.model.classifier[1].in_features, num_classes)
95
- else:
96
- raise ValueError(f"Unsupported backbone: {backbone}")
97
-
98
- def forward(self, x):
99
- return self.model(x)
100
-
101
- class CustomTrainer:
102
- """Trainer class for custom object classification"""
103
-
104
- def __init__(self, model_dir='models/', device=None):
105
- """
106
- Initialize trainer
107
-
108
- Args:
109
- model_dir: Directory to save models
110
- device: Training device (cuda/cpu)
111
- """
112
- self.model_dir = Path(model_dir)
113
- self.model_dir.mkdir(exist_ok=True)
114
-
115
- self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
116
- logger.info(f"Using device: {self.device}")
117
-
118
- # Image transformations
119
- self.train_transform = transforms.Compose([
120
- transforms.ToPILImage(),
121
- transforms.Resize((224, 224)),
122
- transforms.RandomHorizontalFlip(0.5),
123
- transforms.RandomRotation(10),
124
- transforms.ColorJitter(brightness=0.2, contrast=0.2),
125
- transforms.ToTensor(),
126
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
127
- ])
128
-
129
- self.val_transform = transforms.Compose([
130
- transforms.ToPILImage(),
131
- transforms.Resize((224, 224)),
132
- transforms.ToTensor(),
133
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
134
- ])
135
-
136
- def prepare_data(self, training_data: List[Dict], test_size=0.2, min_samples_per_class=5):
137
- """
138
- Prepare training and validation data
139
-
140
- Args:
141
- training_data: List of training samples from database
142
- test_size: Fraction for validation split
143
- min_samples_per_class: Minimum samples required per class
144
-
145
- Returns:
146
- Tuple of (train_dataset, val_dataset, class_info)
147
- """
148
- # Filter classes with insufficient samples
149
- class_counts = {}
150
- for sample in training_data:
151
- label = sample['correct_label']
152
- class_counts[label] = class_counts.get(label, 0) + 1
153
-
154
- # Remove classes with insufficient samples
155
- valid_classes = {label for label, count in class_counts.items()
156
- if count >= min_samples_per_class}
157
-
158
- filtered_data = [sample for sample in training_data
159
- if sample['correct_label'] in valid_classes]
160
-
161
- if len(filtered_data) < 10:
162
- raise ValueError(f"Insufficient training data: {len(filtered_data)} samples")
163
-
164
- if len(valid_classes) < 2:
165
- raise ValueError(f"Need at least 2 classes, got {len(valid_classes)}")
166
-
167
- # Simple train/val split without sklearn
168
- np.random.seed(42)
169
- indices = np.random.permutation(len(filtered_data))
170
- split_idx = int(len(filtered_data) * (1 - test_size))
171
-
172
- train_indices = indices[:split_idx]
173
- val_indices = indices[split_idx:]
174
-
175
- train_data = [filtered_data[i] for i in train_indices]
176
- val_data = [filtered_data[i] for i in val_indices]
177
-
178
- # Create datasets
179
- train_dataset = CustomObjectDataset(train_data, self.train_transform)
180
- val_dataset = CustomObjectDataset(val_data, self.val_transform)
181
-
182
- # Ensure same label mapping
183
- val_dataset.label_to_idx = train_dataset.label_to_idx
184
- val_dataset.idx_to_label = train_dataset.idx_to_label
185
- val_dataset.num_classes = train_dataset.num_classes
186
-
187
- class_info = {
188
- 'num_classes': train_dataset.num_classes,
189
- 'label_to_idx': train_dataset.label_to_idx,
190
- 'idx_to_label': train_dataset.idx_to_label,
191
- 'class_counts': class_counts,
192
- 'valid_classes': list(valid_classes),
193
- 'train_samples': len(train_data),
194
- 'val_samples': len(val_data)
195
- }
196
-
197
- return train_dataset, val_dataset, class_info
198
-
199
- def train_model(self, training_data: List[Dict],
200
- epochs=20, batch_size=16, learning_rate=0.001,
201
- backbone='resnet18', patience=5) -> Dict:
202
- """
203
- Train custom classifier
204
-
205
- Args:
206
- training_data: Training samples from database
207
- epochs: Number of training epochs
208
- batch_size: Batch size for training
209
- learning_rate: Learning rate
210
- backbone: Model backbone architecture
211
- patience: Early stopping patience
212
-
213
- Returns:
214
- Training results and metrics
215
- """
216
- try:
217
- # Prepare data
218
- train_dataset, val_dataset, class_info = self.prepare_data(training_data)
219
-
220
- # Create data loaders
221
- train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
222
- val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
223
-
224
- # Initialize model
225
- model = CustomClassifier(class_info['num_classes'], backbone)
226
- model = model.to(self.device)
227
-
228
- # Loss and optimizer
229
- criterion = nn.CrossEntropyLoss()
230
- optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
231
- scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
232
-
233
- # Training history
234
- history = {
235
- 'train_loss': [],
236
- 'train_acc': [],
237
- 'val_loss': [],
238
- 'val_acc': []
239
- }
240
-
241
- best_val_acc = 0.0
242
- patience_counter = 0
243
-
244
- logger.info(f"Starting training: {epochs} epochs, {class_info['num_classes']} classes")
245
-
246
- for epoch in range(epochs):
247
- # Training phase
248
- model.train()
249
- train_loss = 0.0
250
- train_correct = 0
251
- train_total = 0
252
-
253
- for batch in train_loader:
254
- images = batch['image'].to(self.device)
255
- labels = batch['label'].to(self.device)
256
-
257
- optimizer.zero_grad()
258
- outputs = model(images)
259
- loss = criterion(outputs, labels)
260
- loss.backward()
261
- optimizer.step()
262
-
263
- train_loss += loss.item()
264
- _, predicted = torch.max(outputs.data, 1)
265
- train_total += labels.size(0)
266
- train_correct += (predicted == labels).sum().item()
267
-
268
- train_acc = train_correct / train_total
269
- avg_train_loss = train_loss / len(train_loader)
270
-
271
- # Validation phase
272
- model.eval()
273
- val_loss = 0.0
274
- val_correct = 0
275
- val_total = 0
276
-
277
- with torch.no_grad():
278
- for batch in val_loader:
279
- images = batch['image'].to(self.device)
280
- labels = batch['label'].to(self.device)
281
-
282
- outputs = model(images)
283
- loss = criterion(outputs, labels)
284
-
285
- val_loss += loss.item()
286
- _, predicted = torch.max(outputs.data, 1)
287
- val_total += labels.size(0)
288
- val_correct += (predicted == labels).sum().item()
289
-
290
- val_acc = val_correct / val_total
291
- avg_val_loss = val_loss / len(val_loader)
292
-
293
- # Update history
294
- history['train_loss'].append(avg_train_loss)
295
- history['train_acc'].append(train_acc)
296
- history['val_loss'].append(avg_val_loss)
297
- history['val_acc'].append(val_acc)
298
-
299
- logger.info(f"Epoch {epoch+1}/{epochs}: "
300
- f"Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.4f}, "
301
- f"Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.4f}")
302
-
303
- # Early stopping
304
- if val_acc > best_val_acc:
305
- best_val_acc = val_acc
306
- patience_counter = 0
307
- # Save best model
308
- torch.save(model.state_dict(), self.model_dir / 'best_model.pth')
309
- else:
310
- patience_counter += 1
311
-
312
- if patience_counter >= patience:
313
- logger.info(f"Early stopping at epoch {epoch+1}")
314
- break
315
-
316
- scheduler.step()
317
-
318
- # Load best model
319
- model.load_state_dict(torch.load(self.model_dir / 'best_model.pth'))
320
-
321
- # Final evaluation
322
- final_metrics = self.evaluate_model(model, val_loader, class_info)
323
-
324
- # Save model and metadata
325
- model_info = {
326
- 'model_state': model.state_dict(),
327
- 'class_info': class_info,
328
- 'training_config': {
329
- 'backbone': backbone,
330
- 'epochs': epochs,
331
- 'batch_size': batch_size,
332
- 'learning_rate': learning_rate
333
- },
334
- 'history': history,
335
- 'metrics': final_metrics,
336
- 'timestamp': datetime.now().isoformat()
337
- }
338
-
339
- # Save complete model info
340
- model_path = self.model_dir / f'custom_classifier_{datetime.now().strftime("%Y%m%d_%H%M%S")}.pkl'
341
- with open(model_path, 'wb') as f:
342
- pickle.dump(model_info, f)
343
-
344
- logger.info(f"Training completed. Best validation accuracy: {best_val_acc:.4f}")
345
- logger.info(f"Model saved to: {model_path}")
346
-
347
- return {
348
- 'success': True,
349
- 'model_path': str(model_path),
350
- 'best_accuracy': best_val_acc,
351
- 'final_metrics': final_metrics,
352
- 'class_info': class_info,
353
- 'history': history
354
- }
355
-
356
- except Exception as e:
357
- logger.error(f"Training failed: {e}")
358
- return {
359
- 'success': False,
360
- 'error': str(e)
361
- }
362
-
363
- def evaluate_model(self, model, val_loader, class_info) -> Dict:
364
- """Evaluate model performance"""
365
- model.eval()
366
- all_predictions = []
367
- all_labels = []
368
- all_confidences = []
369
-
370
- with torch.no_grad():
371
- for batch in val_loader:
372
- images = batch['image'].to(self.device)
373
- labels = batch['label']
374
-
375
- outputs = model(images)
376
- probabilities = torch.softmax(outputs, dim=1)
377
- confidences, predicted = torch.max(probabilities, 1)
378
-
379
- all_predictions.extend(predicted.cpu().numpy())
380
- all_labels.extend(labels.numpy())
381
- all_confidences.extend(confidences.cpu().numpy())
382
-
383
- # Calculate metrics manually without sklearn
384
- accuracy = sum(1 for true, pred in zip(all_labels, all_predictions) if true == pred) / len(all_labels)
385
-
386
- # Simple precision/recall calculation
387
- precision = recall = f1 = accuracy # Simplified for now
388
-
389
- return {
390
- 'accuracy': float(accuracy),
391
- 'precision': float(precision),
392
- 'recall': float(recall),
393
- 'f1_score': float(f1),
394
- 'avg_confidence': float(np.mean(all_confidences)),
395
- 'num_samples': len(all_labels)
396
- }
397
-
398
- # Global trainer instance
399
  custom_trainer = CustomTrainer()
 
1
+ """
2
+ Custom Object Classifier Training Module
3
+ Implements transfer learning for user feedback corrections
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+ import torchvision.transforms as transforms
8
+ from torchvision import models
9
+ import numpy as np
10
+ import cv2
11
+ from torch.utils.data import Dataset, DataLoader
12
+ # from sklearn.model_selection import train_test_split # Temporarily disabled due to numpy compatibility
13
+ # from sklearn.metrics import accuracy_score, precision_recall_fscore_support # Temporarily disabled
14
+ import pickle
15
+ import json
16
+ from typing import List, Dict, Tuple, Optional
17
+ from datetime import datetime
18
+ import logging
19
+ from pathlib import Path
20
+
21
+ # Configure logging
22
+ logger = logging.getLogger(__name__)
23
+
24
+ class CustomObjectDataset(Dataset):
25
+ """Dataset class for custom object training"""
26
+
27
+ def __init__(self, data: List[Dict], transform=None):
28
+ """
29
+ Initialize dataset with training data
30
+
31
+ Args:
32
+ data: List of training samples from database
33
+ transform: Image transformations
34
+ """
35
+ self.data = data
36
+ self.transform = transform
37
+
38
+ # Create label mapping
39
+ unique_labels = list(set([sample['correct_label'] for sample in data]))
40
+ self.label_to_idx = {label: idx for idx, label in enumerate(unique_labels)}
41
+ self.idx_to_label = {idx: label for label, idx in self.label_to_idx.items()}
42
+ self.num_classes = len(unique_labels)
43
+
44
+ def __len__(self):
45
+ return len(self.data)
46
+
47
+ def __getitem__(self, idx):
48
+ sample = self.data[idx]
49
+ image = sample['image']
50
+ label = sample['correct_label']
51
+
52
+ # Convert BGR to RGB
53
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
54
+
55
+ if self.transform:
56
+ image = self.transform(image)
57
+
58
+ label_idx = self.label_to_idx[label]
59
+
60
+ return {
61
+ 'image': image,
62
+ 'label': label_idx,
63
+ 'original_label': label,
64
+ 'yolo_prediction': sample['yolo_prediction'],
65
+ 'confidence': sample['yolo_confidence'],
66
+ 'difficulty': sample['difficulty_score']
67
+ }
68
+
69
+ class CustomClassifier(nn.Module):
70
+ """Custom classifier built on pre-trained backbone"""
71
+
72
+ def __init__(self, num_classes: int, backbone='resnet18', pretrained=True):
73
+ """
74
+ Initialize custom classifier
75
+
76
+ Args:
77
+ num_classes: Number of output classes
78
+ backbone: Backbone architecture (resnet18, resnet50, efficientnet_b0)
79
+ pretrained: Use pre-trained weights
80
+ """
81
+ super(CustomClassifier, self).__init__()
82
+
83
+ self.num_classes = num_classes
84
+ self.backbone = backbone
85
+
86
+ if backbone == 'resnet18':
87
+ self.model = models.resnet18(pretrained=pretrained)
88
+ self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
89
+ elif backbone == 'resnet50':
90
+ self.model = models.resnet50(pretrained=pretrained)
91
+ self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
92
+ elif backbone == 'efficientnet_b0':
93
+ self.model = models.efficientnet_b0(pretrained=pretrained)
94
+ self.model.classifier[1] = nn.Linear(self.model.classifier[1].in_features, num_classes)
95
+ else:
96
+ raise ValueError(f"Unsupported backbone: {backbone}")
97
+
98
+ def forward(self, x):
99
+ return self.model(x)
100
+
101
+ class CustomTrainer:
102
+ """Trainer class for custom object classification"""
103
+
104
+ def __init__(self, model_dir='models/', device=None):
105
+ """
106
+ Initialize trainer
107
+
108
+ Args:
109
+ model_dir: Directory to save models
110
+ device: Training device (cuda/cpu)
111
+ """
112
+ self.model_dir = Path(model_dir)
113
+ self.model_dir.mkdir(exist_ok=True)
114
+
115
+ self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
116
+ logger.info(f"Using device: {self.device}")
117
+
118
+ # Image transformations
119
+ self.train_transform = transforms.Compose([
120
+ transforms.ToPILImage(),
121
+ transforms.Resize((224, 224)),
122
+ transforms.RandomHorizontalFlip(0.5),
123
+ transforms.RandomRotation(10),
124
+ transforms.ColorJitter(brightness=0.2, contrast=0.2),
125
+ transforms.ToTensor(),
126
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
127
+ ])
128
+
129
+ self.val_transform = transforms.Compose([
130
+ transforms.ToPILImage(),
131
+ transforms.Resize((224, 224)),
132
+ transforms.ToTensor(),
133
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
134
+ ])
135
+
136
+ def prepare_data(self, training_data: List[Dict], test_size=0.2, min_samples_per_class=5):
137
+ """
138
+ Prepare training and validation data
139
+
140
+ Args:
141
+ training_data: List of training samples from database
142
+ test_size: Fraction for validation split
143
+ min_samples_per_class: Minimum samples required per class
144
+
145
+ Returns:
146
+ Tuple of (train_dataset, val_dataset, class_info)
147
+ """
148
+ # Filter classes with insufficient samples
149
+ class_counts = {}
150
+ for sample in training_data:
151
+ label = sample['correct_label']
152
+ class_counts[label] = class_counts.get(label, 0) + 1
153
+
154
+ # Remove classes with insufficient samples
155
+ valid_classes = {label for label, count in class_counts.items()
156
+ if count >= min_samples_per_class}
157
+
158
+ filtered_data = [sample for sample in training_data
159
+ if sample['correct_label'] in valid_classes]
160
+
161
+ if len(filtered_data) < 10:
162
+ raise ValueError(f"Insufficient training data: {len(filtered_data)} samples")
163
+
164
+ if len(valid_classes) < 2:
165
+ raise ValueError(f"Need at least 2 classes, got {len(valid_classes)}")
166
+
167
+ # Simple train/val split without sklearn
168
+ np.random.seed(42)
169
+ indices = np.random.permutation(len(filtered_data))
170
+ split_idx = int(len(filtered_data) * (1 - test_size))
171
+
172
+ train_indices = indices[:split_idx]
173
+ val_indices = indices[split_idx:]
174
+
175
+ train_data = [filtered_data[i] for i in train_indices]
176
+ val_data = [filtered_data[i] for i in val_indices]
177
+
178
+ # Create datasets
179
+ train_dataset = CustomObjectDataset(train_data, self.train_transform)
180
+ val_dataset = CustomObjectDataset(val_data, self.val_transform)
181
+
182
+ # Ensure same label mapping
183
+ val_dataset.label_to_idx = train_dataset.label_to_idx
184
+ val_dataset.idx_to_label = train_dataset.idx_to_label
185
+ val_dataset.num_classes = train_dataset.num_classes
186
+
187
+ class_info = {
188
+ 'num_classes': train_dataset.num_classes,
189
+ 'label_to_idx': train_dataset.label_to_idx,
190
+ 'idx_to_label': train_dataset.idx_to_label,
191
+ 'class_counts': class_counts,
192
+ 'valid_classes': list(valid_classes),
193
+ 'train_samples': len(train_data),
194
+ 'val_samples': len(val_data)
195
+ }
196
+
197
+ return train_dataset, val_dataset, class_info
198
+
199
+ def train_model(self, training_data: List[Dict],
200
+ epochs=20, batch_size=16, learning_rate=0.001,
201
+ backbone='resnet18', patience=5) -> Dict:
202
+ """
203
+ Train custom classifier
204
+
205
+ Args:
206
+ training_data: Training samples from database
207
+ epochs: Number of training epochs
208
+ batch_size: Batch size for training
209
+ learning_rate: Learning rate
210
+ backbone: Model backbone architecture
211
+ patience: Early stopping patience
212
+
213
+ Returns:
214
+ Training results and metrics
215
+ """
216
+ try:
217
+ # Prepare data
218
+ train_dataset, val_dataset, class_info = self.prepare_data(training_data)
219
+
220
+ # Create data loaders
221
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
222
+ val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
223
+
224
+ # Initialize model
225
+ model = CustomClassifier(class_info['num_classes'], backbone)
226
+ model = model.to(self.device)
227
+
228
+ # Loss and optimizer
229
+ criterion = nn.CrossEntropyLoss()
230
+ optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
231
+ scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
232
+
233
+ # Training history
234
+ history = {
235
+ 'train_loss': [],
236
+ 'train_acc': [],
237
+ 'val_loss': [],
238
+ 'val_acc': []
239
+ }
240
+
241
+ best_val_acc = 0.0
242
+ patience_counter = 0
243
+
244
+ logger.info(f"Starting training: {epochs} epochs, {class_info['num_classes']} classes")
245
+
246
+ for epoch in range(epochs):
247
+ # Training phase
248
+ model.train()
249
+ train_loss = 0.0
250
+ train_correct = 0
251
+ train_total = 0
252
+
253
+ for batch in train_loader:
254
+ images = batch['image'].to(self.device)
255
+ labels = batch['label'].to(self.device)
256
+
257
+ optimizer.zero_grad()
258
+ outputs = model(images)
259
+ loss = criterion(outputs, labels)
260
+ loss.backward()
261
+ optimizer.step()
262
+
263
+ train_loss += loss.item()
264
+ _, predicted = torch.max(outputs.data, 1)
265
+ train_total += labels.size(0)
266
+ train_correct += (predicted == labels).sum().item()
267
+
268
+ train_acc = train_correct / train_total
269
+ avg_train_loss = train_loss / len(train_loader)
270
+
271
+ # Validation phase
272
+ model.eval()
273
+ val_loss = 0.0
274
+ val_correct = 0
275
+ val_total = 0
276
+
277
+ with torch.no_grad():
278
+ for batch in val_loader:
279
+ images = batch['image'].to(self.device)
280
+ labels = batch['label'].to(self.device)
281
+
282
+ outputs = model(images)
283
+ loss = criterion(outputs, labels)
284
+
285
+ val_loss += loss.item()
286
+ _, predicted = torch.max(outputs.data, 1)
287
+ val_total += labels.size(0)
288
+ val_correct += (predicted == labels).sum().item()
289
+
290
+ val_acc = val_correct / val_total
291
+ avg_val_loss = val_loss / len(val_loader)
292
+
293
+ # Update history
294
+ history['train_loss'].append(avg_train_loss)
295
+ history['train_acc'].append(train_acc)
296
+ history['val_loss'].append(avg_val_loss)
297
+ history['val_acc'].append(val_acc)
298
+
299
+ logger.info(f"Epoch {epoch+1}/{epochs}: "
300
+ f"Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.4f}, "
301
+ f"Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.4f}")
302
+
303
+ # Early stopping
304
+ if val_acc > best_val_acc:
305
+ best_val_acc = val_acc
306
+ patience_counter = 0
307
+ # Save best model
308
+ torch.save(model.state_dict(), self.model_dir / 'best_model.pth')
309
+ else:
310
+ patience_counter += 1
311
+
312
+ if patience_counter >= patience:
313
+ logger.info(f"Early stopping at epoch {epoch+1}")
314
+ break
315
+
316
+ scheduler.step()
317
+
318
+ # Load best model
319
+ model.load_state_dict(torch.load(self.model_dir / 'best_model.pth'))
320
+
321
+ # Final evaluation
322
+ final_metrics = self.evaluate_model(model, val_loader, class_info)
323
+
324
+ # Save model and metadata
325
+ model_info = {
326
+ 'model_state': model.state_dict(),
327
+ 'class_info': class_info,
328
+ 'training_config': {
329
+ 'backbone': backbone,
330
+ 'epochs': epochs,
331
+ 'batch_size': batch_size,
332
+ 'learning_rate': learning_rate
333
+ },
334
+ 'history': history,
335
+ 'metrics': final_metrics,
336
+ 'timestamp': datetime.now().isoformat()
337
+ }
338
+
339
+ # Save complete model info
340
+ model_path = self.model_dir / f'custom_classifier_{datetime.now().strftime("%Y%m%d_%H%M%S")}.pkl'
341
+ with open(model_path, 'wb') as f:
342
+ pickle.dump(model_info, f)
343
+
344
+ logger.info(f"Training completed. Best validation accuracy: {best_val_acc:.4f}")
345
+ logger.info(f"Model saved to: {model_path}")
346
+
347
+ return {
348
+ 'success': True,
349
+ 'model_path': str(model_path),
350
+ 'best_accuracy': best_val_acc,
351
+ 'final_metrics': final_metrics,
352
+ 'class_info': class_info,
353
+ 'history': history
354
+ }
355
+
356
+ except Exception as e:
357
+ logger.error(f"Training failed: {e}")
358
+ return {
359
+ 'success': False,
360
+ 'error': str(e)
361
+ }
362
+
363
+ def evaluate_model(self, model, val_loader, class_info) -> Dict:
364
+ """Evaluate model performance"""
365
+ model.eval()
366
+ all_predictions = []
367
+ all_labels = []
368
+ all_confidences = []
369
+
370
+ with torch.no_grad():
371
+ for batch in val_loader:
372
+ images = batch['image'].to(self.device)
373
+ labels = batch['label']
374
+
375
+ outputs = model(images)
376
+ probabilities = torch.softmax(outputs, dim=1)
377
+ confidences, predicted = torch.max(probabilities, 1)
378
+
379
+ all_predictions.extend(predicted.cpu().numpy())
380
+ all_labels.extend(labels.numpy())
381
+ all_confidences.extend(confidences.cpu().numpy())
382
+
383
+ # Calculate metrics manually without sklearn
384
+ accuracy = sum(1 for true, pred in zip(all_labels, all_predictions) if true == pred) / len(all_labels)
385
+
386
+ # Simple precision/recall calculation
387
+ precision = recall = f1 = accuracy # Simplified for now
388
+
389
+ return {
390
+ 'accuracy': float(accuracy),
391
+ 'precision': float(precision),
392
+ 'recall': float(recall),
393
+ 'f1_score': float(f1),
394
+ 'avg_confidence': float(np.mean(all_confidences)),
395
+ 'num_samples': len(all_labels)
396
+ }
397
+
398
+ # Global trainer instance
399
  custom_trainer = CustomTrainer()
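
And a hedged sketch of the sample shape `CustomTrainer.train_model()` expects, based on the keys `CustomObjectDataset.__getitem__` reads. Crops are BGR `uint8` arrays (converted to RGB inside the dataset), and `prepare_data()` raises unless there are at least two classes, five samples per class, and ten samples overall. The `make_sample` helper is hypothetical, standing in for rows from the `training_corrections` table:

```python
import numpy as np
from backend.custom_trainer import custom_trainer

def make_sample(label: str) -> dict:
    """Synthetic stand-in for one training_corrections row."""
    return {
        "image": np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8),  # BGR crop
        "correct_label": label,        # the user's correction
        "yolo_prediction": "cat",      # what YOLO originally guessed
        "yolo_confidence": 0.42,
        "difficulty_score": 0.5,
    }

# >= 2 classes, >= 5 samples each, >= 10 total, or prepare_data() raises ValueError.
data = [make_sample("dog") for _ in range(6)] + [make_sample("bicycle") for _ in range(6)]
result = custom_trainer.train_model(data, epochs=2, batch_size=4, backbone="resnet18")
print(result["success"], result.get("best_accuracy"))
```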
backend/database.py CHANGED
@@ -1,678 +1,678 @@
1
- """
2
- Database Module for NAVADA - SQLite storage for faces and objects
3
- Handles storage, retrieval, and management of custom recognition data
4
- """
5
-
6
- import sqlite3
7
- import numpy as np
8
- import cv2
9
- from datetime import datetime
10
- import json
11
- import base64
12
- from typing import List, Dict, Optional, Tuple
13
- import logging
14
- from pathlib import Path
15
-
16
- # Configure logging
17
- logger = logging.getLogger(__name__)
18
-
19
- class NAVADADatabase:
20
- """Database manager for storing faces, objects, and recognition data"""
21
-
22
- def __init__(self, db_path: str = "navada_recognition.db"):
23
- """
24
- Initialize database connection and create tables
25
-
26
- Args:
27
- db_path: Path to SQLite database file
28
- """
29
- self.db_path = db_path
30
- self.init_database()
31
-
32
- def init_database(self):
33
- """Create database tables if they don't exist"""
34
- try:
35
- with sqlite3.connect(self.db_path) as conn:
36
- cursor = conn.cursor()
37
-
38
- # Create faces table
39
- cursor.execute("""
40
- CREATE TABLE IF NOT EXISTS faces (
41
- id INTEGER PRIMARY KEY AUTOINCREMENT,
42
- name TEXT NOT NULL,
43
- encoding BLOB NOT NULL,
44
- image_data BLOB,
45
- confidence REAL DEFAULT 0.0,
46
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
47
- updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
48
- metadata TEXT,
49
- is_active BOOLEAN DEFAULT 1
50
- )
51
- """)
52
-
53
- # Create objects table
54
- cursor.execute("""
55
- CREATE TABLE IF NOT EXISTS objects (
56
- id INTEGER PRIMARY KEY AUTOINCREMENT,
57
- label TEXT NOT NULL,
58
- category TEXT,
59
- features BLOB,
60
- image_data BLOB,
61
- bounding_box TEXT,
62
- confidence REAL DEFAULT 0.0,
63
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
64
- updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
65
- metadata TEXT,
66
- is_active BOOLEAN DEFAULT 1
67
- )
68
- """)
69
-
70
- # Create detection history table
71
- cursor.execute("""
72
- CREATE TABLE IF NOT EXISTS detection_history (
73
- id INTEGER PRIMARY KEY AUTOINCREMENT,
74
- session_id TEXT,
75
- image_data BLOB,
76
- detections TEXT,
77
- face_matches TEXT,
78
- object_matches TEXT,
79
- confidence_scores TEXT,
80
- processing_time REAL,
81
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
82
- metadata TEXT
83
- )
84
- """)
85
-
86
- # Create knowledge base for RAG
87
- cursor.execute("""
88
- CREATE TABLE IF NOT EXISTS knowledge_base (
89
- id INTEGER PRIMARY KEY AUTOINCREMENT,
90
- entity_type TEXT NOT NULL,
91
- entity_id INTEGER NOT NULL,
92
- content TEXT NOT NULL,
93
- embedding BLOB,
94
- keywords TEXT,
95
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
96
- updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
97
- )
98
- """)
99
-
100
- # Create training corrections table for active learning
101
- cursor.execute("""
102
- CREATE TABLE IF NOT EXISTS training_corrections (
103
- id INTEGER PRIMARY KEY AUTOINCREMENT,
104
- image_path TEXT,
105
- image_crop BLOB NOT NULL,
106
- bbox_coords TEXT NOT NULL,
107
- yolo_prediction TEXT NOT NULL,
108
- yolo_confidence REAL NOT NULL,
109
- correct_label TEXT NOT NULL,
110
- user_feedback TEXT,
111
- difficulty_score REAL DEFAULT 0.0,
112
- validated BOOLEAN DEFAULT 0,
113
- used_for_training BOOLEAN DEFAULT 0,
114
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
115
- session_id TEXT,
116
- metadata TEXT
117
- )
118
- """)
119
-
120
- # Create custom model versions table
121
- cursor.execute("""
122
- CREATE TABLE IF NOT EXISTS model_versions (
123
- id INTEGER PRIMARY KEY AUTOINCREMENT,
124
- version_name TEXT NOT NULL UNIQUE,
125
- model_path TEXT NOT NULL,
126
- accuracy REAL,
127
- precision_score REAL,
128
- recall_score REAL,
129
- f1_score REAL,
130
- training_samples INTEGER DEFAULT 0,
131
- validation_samples INTEGER DEFAULT 0,
132
- training_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
133
- is_active BOOLEAN DEFAULT 0,
134
- performance_metrics TEXT,
135
- training_config TEXT,
136
- notes TEXT
137
- )
138
- """)
139
-
140
- # Create custom classes mapping
141
- cursor.execute("""
142
- CREATE TABLE IF NOT EXISTS custom_classes (
143
- id INTEGER PRIMARY KEY AUTOINCREMENT,
144
- class_name TEXT NOT NULL UNIQUE,
145
- yolo_class TEXT,
146
- sample_count INTEGER DEFAULT 0,
147
- confidence_threshold REAL DEFAULT 0.5,
148
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
149
- is_active BOOLEAN DEFAULT 1,
150
- description TEXT
151
- )
152
- """)
153
-
154
- # Create indexes for better performance
155
- cursor.execute("CREATE INDEX IF NOT EXISTS idx_faces_name ON faces(name)")
156
- cursor.execute("CREATE INDEX IF NOT EXISTS idx_objects_label ON objects(label)")
157
- cursor.execute("CREATE INDEX IF NOT EXISTS idx_history_session ON detection_history(session_id)")
158
- cursor.execute("CREATE INDEX IF NOT EXISTS idx_knowledge_entity ON knowledge_base(entity_type, entity_id)")
159
-
160
- conn.commit()
161
- logger.info("Database initialized successfully")
162
-
163
- except Exception as e:
164
- logger.error(f"Database initialization failed: {e}")
165
- raise
166
-
167
- def add_face(self, name: str, face_encoding: np.ndarray, image: np.ndarray,
168
- confidence: float = 0.0, metadata: Dict = None) -> int:
169
- """
170
- Add a new face to the database
171
-
172
- Args:
173
- name: Person's name
174
- face_encoding: Face encoding vector
175
- image: Face image array
176
- confidence: Recognition confidence
177
- metadata: Additional metadata
178
-
179
- Returns:
180
- Face ID in database
181
- """
182
- try:
183
- # Encode image to base64
184
- _, buffer = cv2.imencode('.jpg', image)
185
- image_data = base64.b64encode(buffer).decode('utf-8')
186
-
187
- # Serialize face encoding
188
- encoding_data = face_encoding.tobytes()
189
-
190
- # Convert metadata to JSON
191
- metadata_json = json.dumps(metadata) if metadata else None
192
-
193
- with sqlite3.connect(self.db_path) as conn:
194
- cursor = conn.cursor()
195
- cursor.execute("""
196
- INSERT INTO faces (name, encoding, image_data, confidence, metadata)
197
- VALUES (?, ?, ?, ?, ?)
198
- """, (name, encoding_data, image_data, confidence, metadata_json))
199
-
200
- face_id = cursor.lastrowid
201
- conn.commit()
202
-
203
- # Add to knowledge base
204
- self.add_knowledge_entry("face", face_id, f"Person named {name}")
205
-
206
- logger.info(f"Added face for {name} with ID {face_id}")
207
- return face_id
208
-
209
- except Exception as e:
210
- logger.error(f"Failed to add face: {e}")
211
- raise
212
-
213
- def add_object(self, label: str, category: str, features: np.ndarray,
214
- image: np.ndarray, bounding_box: Tuple, confidence: float = 0.0,
215
- metadata: Dict = None) -> int:
216
- """
217
- Add a new custom object to the database
218
-
219
- Args:
220
- label: Object label/name
221
- category: Object category
222
- features: Feature vector
223
- image: Object image
224
- bounding_box: (x, y, w, h) bounding box
225
- confidence: Detection confidence
226
- metadata: Additional metadata
227
-
228
- Returns:
229
- Object ID in database
230
- """
231
- try:
232
- # Encode image to base64
233
- _, buffer = cv2.imencode('.jpg', image)
234
- image_data = base64.b64encode(buffer).decode('utf-8')
235
-
236
- # Serialize features
237
- features_data = features.tobytes() if features is not None else None
238
-
239
- # Serialize bounding box
240
- bbox_json = json.dumps(bounding_box)
241
-
242
- # Convert metadata to JSON
243
- metadata_json = json.dumps(metadata) if metadata else None
244
-
245
- with sqlite3.connect(self.db_path) as conn:
246
- cursor = conn.cursor()
247
- cursor.execute("""
248
- INSERT INTO objects (label, category, features, image_data,
249
- bounding_box, confidence, metadata)
250
- VALUES (?, ?, ?, ?, ?, ?, ?)
251
- """, (label, category, features_data, image_data, bbox_json,
252
- confidence, metadata_json))
253
-
254
- object_id = cursor.lastrowid
255
- conn.commit()
256
-
257
- # Add to knowledge base
258
- self.add_knowledge_entry("object", object_id,
259
- f"{label} - {category} object")
260
-
261
- logger.info(f"Added object {label} with ID {object_id}")
262
- return object_id
263
-
264
- except Exception as e:
265
- logger.error(f"Failed to add object: {e}")
266
- raise
267
-
268
- def get_faces(self, active_only: bool = True) -> List[Dict]:
269
- """Get all faces from database"""
270
- try:
271
- with sqlite3.connect(self.db_path) as conn:
272
- cursor = conn.cursor()
273
- query = "SELECT * FROM faces"
274
- if active_only:
275
- query += " WHERE is_active = 1"
276
-
277
- cursor.execute(query)
278
- rows = cursor.fetchall()
279
-
280
- faces = []
281
- for row in rows:
282
- face = {
283
- 'id': row[0],
284
- 'name': row[1],
285
- 'encoding': np.frombuffer(row[2], dtype=np.float64),
286
- 'confidence': row[4],
287
- 'created_at': row[5],
288
- 'metadata': json.loads(row[7]) if row[7] else {}
289
- }
290
- faces.append(face)
291
-
292
- return faces
293
-
294
- except Exception as e:
295
- logger.error(f"Failed to get faces: {e}")
296
- return []
297
-
298
- def get_objects(self, category: str = None, active_only: bool = True) -> List[Dict]:
299
- """Get objects from database"""
300
- try:
301
- with sqlite3.connect(self.db_path) as conn:
302
- cursor = conn.cursor()
303
- query = "SELECT * FROM objects"
304
- params = []
305
-
306
- conditions = []
307
- if active_only:
308
- conditions.append("is_active = 1")
309
- if category:
310
- conditions.append("category = ?")
311
- params.append(category)
312
-
313
- if conditions:
314
- query += " WHERE " + " AND ".join(conditions)
315
-
316
- cursor.execute(query, params)
317
- rows = cursor.fetchall()
318
-
319
- objects = []
320
- for row in rows:
321
- obj = {
322
- 'id': row[0],
323
- 'label': row[1],
324
- 'category': row[2],
325
- 'features': np.frombuffer(row[3], dtype=np.float64) if row[3] else None,
326
- 'bounding_box': json.loads(row[5]) if row[5] else None,
327
- 'confidence': row[6],
328
- 'created_at': row[7],
329
- 'metadata': json.loads(row[9]) if row[9] else {}
330
- }
331
- objects.append(obj)
332
-
333
- return objects
334
-
335
- except Exception as e:
336
- logger.error(f"Failed to get objects: {e}")
337
- return []
338
-
339
- def save_detection_history(self, session_id: str, image: np.ndarray,
340
- detections: List, face_matches: List = None,
341
- object_matches: List = None, confidence_scores: Dict = None,
342
- processing_time: float = 0.0, metadata: Dict = None) -> int:
343
- """Save detection results to history"""
344
- try:
345
- # Encode image
346
- _, buffer = cv2.imencode('.jpg', image)
347
- image_data = base64.b64encode(buffer).decode('utf-8')
348
-
349
- # Serialize data
350
- detections_json = json.dumps(detections)
351
- face_matches_json = json.dumps(face_matches) if face_matches else None
352
- object_matches_json = json.dumps(object_matches) if object_matches else None
353
- confidence_json = json.dumps(confidence_scores) if confidence_scores else None
354
- metadata_json = json.dumps(metadata) if metadata else None
355
-
356
- with sqlite3.connect(self.db_path) as conn:
357
- cursor = conn.cursor()
358
- cursor.execute("""
359
- INSERT INTO detection_history
360
- (session_id, image_data, detections, face_matches, object_matches,
361
- confidence_scores, processing_time, metadata)
362
- VALUES (?, ?, ?, ?, ?, ?, ?, ?)
363
- """, (session_id, image_data, detections_json, face_matches_json,
364
- object_matches_json, confidence_json, processing_time, metadata_json))
365
-
366
- history_id = cursor.lastrowid
367
- conn.commit()
368
-
369
- logger.info(f"Saved detection history with ID {history_id}")
370
- return history_id
371
-
372
- except Exception as e:
373
- logger.error(f"Failed to save detection history: {e}")
374
- raise
375
-
376
- def add_knowledge_entry(self, entity_type: str, entity_id: int, content: str,
377
- keywords: List[str] = None):
378
- """Add entry to knowledge base for RAG"""
379
- try:
380
- keywords_json = json.dumps(keywords) if keywords else None
381
-
382
- with sqlite3.connect(self.db_path) as conn:
383
- cursor = conn.cursor()
384
- cursor.execute("""
385
- INSERT INTO knowledge_base (entity_type, entity_id, content, keywords)
386
- VALUES (?, ?, ?, ?)
387
- """, (entity_type, entity_id, content, keywords_json))
388
- conn.commit()
389
-
390
- except Exception as e:
391
- logger.error(f"Failed to add knowledge entry: {e}")
392
-
393
- def search_knowledge(self, query: str, entity_type: str = None) -> List[Dict]:
394
- """Search knowledge base for RAG"""
395
- try:
396
- with sqlite3.connect(self.db_path) as conn:
397
- cursor = conn.cursor()
398
-
399
- # Simple text search (can be enhanced with embeddings)
400
- search_query = f"%{query.lower()}%"
401
-
402
- if entity_type:
403
- cursor.execute("""
404
- SELECT * FROM knowledge_base
405
- WHERE entity_type = ? AND LOWER(content) LIKE ?
406
- ORDER BY created_at DESC LIMIT 10
407
- """, (entity_type, search_query))
408
- else:
409
- cursor.execute("""
410
- SELECT * FROM knowledge_base
411
- WHERE LOWER(content) LIKE ?
412
- ORDER BY created_at DESC LIMIT 10
413
- """, (search_query,))
414
-
415
- rows = cursor.fetchall()
416
- results = []
417
-
418
- for row in rows:
419
- result = {
420
- 'id': row[0],
421
- 'entity_type': row[1],
422
- 'entity_id': row[2],
423
- 'content': row[3],
424
- 'keywords': json.loads(row[5]) if row[5] else [],
425
- 'created_at': row[6]
426
- }
427
- results.append(result)
428
-
429
- return results
430
-
431
- except Exception as e:
432
- logger.error(f"Knowledge search failed: {e}")
433
- return []
434
-
435
- def get_stats(self) -> Dict:
436
- """Get database statistics"""
437
- try:
438
- with sqlite3.connect(self.db_path) as conn:
439
- cursor = conn.cursor()
440
-
441
- # Count faces
442
- cursor.execute("SELECT COUNT(*) FROM faces WHERE is_active = 1")
443
- face_count = cursor.fetchone()[0]
444
-
445
- # Count objects
446
- cursor.execute("SELECT COUNT(*) FROM objects WHERE is_active = 1")
447
- object_count = cursor.fetchone()[0]
448
-
449
- # Count history entries
450
- cursor.execute("SELECT COUNT(*) FROM detection_history")
451
- history_count = cursor.fetchone()[0]
452
-
453
- # Get recent activity
454
- cursor.execute("""
455
- SELECT COUNT(*) FROM detection_history
456
- WHERE created_at > datetime('now', '-7 days')
457
- """)
458
- recent_detections = cursor.fetchone()[0]
459
-
460
- return {
461
- 'faces': face_count,
462
- 'objects': object_count,
463
- 'total_detections': history_count,
464
- 'recent_detections': recent_detections,
465
- 'database_size': Path(self.db_path).stat().st_size if Path(self.db_path).exists() else 0
466
- }
467
-
468
- except Exception as e:
469
- logger.error(f"Failed to get stats: {e}")
470
- return {}
471
-
472
- # Training Corrections Methods for Active Learning
473
-
474
- def save_correction(self, image_crop: np.ndarray, bbox_coords: List[float],
475
- yolo_prediction: str, yolo_confidence: float,
476
- correct_label: str, user_feedback: str = "",
477
- session_id: str = "") -> bool:
478
- """
479
- Save a user correction for training
480
-
481
- Args:
482
- image_crop: Cropped image of the detected object
483
- bbox_coords: [x1, y1, x2, y2] bounding box coordinates
484
- yolo_prediction: Original YOLO predicted label
485
- yolo_confidence: Original YOLO confidence score
486
- correct_label: User-provided correct label
487
- user_feedback: Optional user feedback text
488
- session_id: Session identifier
489
-
490
- Returns:
491
- bool: Success status
492
- """
493
- try:
494
- # Convert image to bytes
495
- _, buffer = cv2.imencode('.jpg', image_crop)
496
- image_bytes = buffer.tobytes()
497
-
498
- # Calculate difficulty score (lower confidence = higher difficulty)
499
- difficulty_score = 1.0 - yolo_confidence
500
-
501
- with sqlite3.connect(self.db_path) as conn:
502
- cursor = conn.cursor()
503
-
504
- cursor.execute("""
505
- INSERT INTO training_corrections
506
- (image_crop, bbox_coords, yolo_prediction, yolo_confidence,
507
- correct_label, user_feedback, difficulty_score, session_id, metadata)
508
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
509
- """, (
510
- image_bytes,
511
- json.dumps(bbox_coords),
512
- yolo_prediction,
513
- yolo_confidence,
514
- correct_label,
515
- user_feedback,
516
- difficulty_score,
517
- session_id,
518
- json.dumps({
519
- 'timestamp': datetime.now().isoformat(),
520
- 'image_shape': image_crop.shape,
521
- 'correction_type': 'user_feedback'
522
- })
523
- ))
524
-
525
- # Update or create custom class entry
526
- cursor.execute("""
527
- INSERT OR IGNORE INTO custom_classes (class_name, yolo_class, sample_count)
528
- VALUES (?, ?, 0)
529
- """, (correct_label, yolo_prediction))
530
-
531
- cursor.execute("""
532
- UPDATE custom_classes
533
- SET sample_count = sample_count + 1
534
- WHERE class_name = ?
535
- """, (correct_label,))
536
-
537
- return True
538
-
539
- except Exception as e:
540
- logger.error(f"Failed to save correction: {e}")
541
- return False
542
-
543
- def get_training_data(self, class_name: str = None, limit: int = 1000,
544
- validated_only: bool = False) -> List[Dict]:
545
- """
546
- Retrieve training data for model training
547
-
548
- Args:
549
- class_name: Filter by specific class (optional)
550
- limit: Maximum number of samples to return
551
- validated_only: Only return validated corrections
552
-
553
- Returns:
554
- List of training samples
555
- """
556
- try:
557
- with sqlite3.connect(self.db_path) as conn:
558
- cursor = conn.cursor()
559
-
560
- query = """
561
- SELECT id, image_crop, bbox_coords, yolo_prediction,
562
- yolo_confidence, correct_label, difficulty_score,
563
- created_at, metadata
564
- FROM training_corrections
565
- WHERE 1=1
566
- """
567
- params = []
568
-
569
- if class_name:
570
- query += " AND correct_label = ?"
571
- params.append(class_name)
572
-
573
- if validated_only:
574
- query += " AND validated = 1"
575
-
576
- query += " ORDER BY difficulty_score DESC, created_at DESC LIMIT ?"
577
- params.append(limit)
578
-
579
- cursor.execute(query, params)
580
- rows = cursor.fetchall()
581
-
582
- training_data = []
583
- for row in rows:
584
- # Decode image
585
- image_bytes = row[1]
586
- image_array = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_COLOR)
587
-
588
- training_data.append({
589
- 'id': row[0],
590
- 'image': image_array,
591
- 'bbox_coords': json.loads(row[2]),
592
- 'yolo_prediction': row[3],
593
- 'yolo_confidence': row[4],
594
- 'correct_label': row[5],
595
- 'difficulty_score': row[6],
596
- 'created_at': row[7],
597
- 'metadata': json.loads(row[8]) if row[8] else {}
598
- })
599
-
600
- return training_data
601
-
602
- except Exception as e:
603
- logger.error(f"Failed to get training data: {e}")
604
- return []
605
-
606
- def get_training_stats(self) -> Dict:
607
- """Get statistics about training corrections"""
608
- try:
609
- with sqlite3.connect(self.db_path) as conn:
610
- cursor = conn.cursor()
611
-
612
- # Total corrections
613
- cursor.execute("SELECT COUNT(*) FROM training_corrections")
614
- total_corrections = cursor.fetchone()[0]
615
-
616
- # Corrections by class
617
- cursor.execute("""
618
- SELECT correct_label, COUNT(*) as count
619
- FROM training_corrections
620
- GROUP BY correct_label
621
- ORDER BY count DESC
622
- """)
623
- class_counts = dict(cursor.fetchall())
624
-
625
- # Validated corrections
626
- cursor.execute("SELECT COUNT(*) FROM training_corrections WHERE validated = 1")
627
- validated_count = cursor.fetchone()[0]
628
-
629
- # Recent corrections (last 7 days)
630
- cursor.execute("""
631
- SELECT COUNT(*) FROM training_corrections
632
- WHERE created_at > datetime('now', '-7 days')
633
- """)
634
- recent_corrections = cursor.fetchone()[0]
635
-
636
- # Average difficulty score
637
- cursor.execute("SELECT AVG(difficulty_score) FROM training_corrections")
638
- avg_difficulty = cursor.fetchone()[0] or 0.0
639
-
640
- return {
641
- 'total_corrections': total_corrections,
642
- 'validated_corrections': validated_count,
643
- 'recent_corrections': recent_corrections,
644
- 'class_distribution': class_counts,
645
- 'average_difficulty': round(avg_difficulty, 3),
646
- 'unique_classes': len(class_counts)
647
- }
648
-
649
- except Exception as e:
650
- logger.error(f"Failed to get training stats: {e}")
651
- return {}
652
-
653
- def mark_corrections_used(self, correction_ids: List[int]) -> bool:
654
- """Mark corrections as used for training"""
655
- try:
656
- with sqlite3.connect(self.db_path) as conn:
657
- cursor = conn.cursor()
658
-
659
- if not correction_ids: # Guard: an empty list would yield "IN ()", a SQLite syntax error
- return True # nothing to mark
- placeholders = ','.join(['?'] * len(correction_ids))
660
- cursor.execute(f"""
661
- UPDATE training_corrections
662
- SET used_for_training = 1
663
- WHERE id IN ({placeholders})
664
- """, correction_ids)
665
-
666
- return True
667
-
668
- except Exception as e:
669
- logger.error(f"Failed to mark corrections as used: {e}")
670
- return False
671
-
672
- # Global database instance
673
- try:
674
- db = NAVADADatabase()
675
- logger.info("Database instance created successfully")
676
- except Exception as e:
677
- logger.error(f"Failed to create database instance: {e}")
678
  db = None
 
1
+ """
2
+ Database Module for NAVADA - SQLite storage for faces and objects
3
+ Handles storage, retrieval, and management of custom recognition data
4
+ """
5
+
6
+ import sqlite3
7
+ import numpy as np
8
+ import cv2
9
+ from datetime import datetime
10
+ import json
11
+ import base64
12
+ from typing import List, Dict, Optional, Tuple
13
+ import logging
14
+ from pathlib import Path
15
+
16
+ # Configure logging
17
+ logger = logging.getLogger(__name__)
18
+
19
+ class NAVADADatabase:
20
+ """Database manager for storing faces, objects, and recognition data"""
21
+
22
+ def __init__(self, db_path: str = "navada_recognition.db"):
23
+ """
24
+ Initialize database connection and create tables
25
+
26
+ Args:
27
+ db_path: Path to SQLite database file
28
+ """
29
+ self.db_path = db_path
30
+ self.init_database()
31
+
32
+ def init_database(self):
33
+ """Create database tables if they don't exist"""
34
+ try:
35
+ with sqlite3.connect(self.db_path) as conn:
36
+ cursor = conn.cursor()
37
+
38
+ # Create faces table
39
+ cursor.execute("""
40
+ CREATE TABLE IF NOT EXISTS faces (
41
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
42
+ name TEXT NOT NULL,
43
+ encoding BLOB NOT NULL,
44
+ image_data BLOB,
45
+ confidence REAL DEFAULT 0.0,
46
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
47
+ updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
48
+ metadata TEXT,
49
+ is_active BOOLEAN DEFAULT 1
50
+ )
51
+ """)
52
+
53
+ # Create objects table
54
+ cursor.execute("""
55
+ CREATE TABLE IF NOT EXISTS objects (
56
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
57
+ label TEXT NOT NULL,
58
+ category TEXT,
59
+ features BLOB,
60
+ image_data BLOB,
61
+ bounding_box TEXT,
62
+ confidence REAL DEFAULT 0.0,
63
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
64
+ updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
65
+ metadata TEXT,
66
+ is_active BOOLEAN DEFAULT 1
67
+ )
68
+ """)
69
+
70
+ # Create detection history table
71
+ cursor.execute("""
72
+ CREATE TABLE IF NOT EXISTS detection_history (
73
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
74
+ session_id TEXT,
75
+ image_data BLOB,
76
+ detections TEXT,
77
+ face_matches TEXT,
78
+ object_matches TEXT,
79
+ confidence_scores TEXT,
80
+ processing_time REAL,
81
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
82
+ metadata TEXT
83
+ )
84
+ """)
85
+
86
+ # Create knowledge base for RAG
87
+ cursor.execute("""
88
+ CREATE TABLE IF NOT EXISTS knowledge_base (
89
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
90
+ entity_type TEXT NOT NULL,
91
+ entity_id INTEGER NOT NULL,
92
+ content TEXT NOT NULL,
93
+ embedding BLOB,
94
+ keywords TEXT,
95
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
96
+ updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
97
+ )
98
+ """)
99
+
100
+ # Create training corrections table for active learning
101
+ cursor.execute("""
102
+ CREATE TABLE IF NOT EXISTS training_corrections (
103
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
104
+ image_path TEXT,
105
+ image_crop BLOB NOT NULL,
106
+ bbox_coords TEXT NOT NULL,
107
+ yolo_prediction TEXT NOT NULL,
108
+ yolo_confidence REAL NOT NULL,
109
+ correct_label TEXT NOT NULL,
110
+ user_feedback TEXT,
111
+ difficulty_score REAL DEFAULT 0.0,
112
+ validated BOOLEAN DEFAULT 0,
113
+ used_for_training BOOLEAN DEFAULT 0,
114
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
115
+ session_id TEXT,
116
+ metadata TEXT
117
+ )
118
+ """)
119
+
120
+ # Create custom model versions table
121
+ cursor.execute("""
122
+ CREATE TABLE IF NOT EXISTS model_versions (
123
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
124
+ version_name TEXT NOT NULL UNIQUE,
125
+ model_path TEXT NOT NULL,
126
+ accuracy REAL,
127
+ precision_score REAL,
128
+ recall_score REAL,
129
+ f1_score REAL,
130
+ training_samples INTEGER DEFAULT 0,
131
+ validation_samples INTEGER DEFAULT 0,
132
+ training_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
133
+ is_active BOOLEAN DEFAULT 0,
134
+ performance_metrics TEXT,
135
+ training_config TEXT,
136
+ notes TEXT
137
+ )
138
+ """)
139
+
140
+ # Create custom classes mapping
141
+ cursor.execute("""
142
+ CREATE TABLE IF NOT EXISTS custom_classes (
143
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
144
+ class_name TEXT NOT NULL UNIQUE,
145
+ yolo_class TEXT,
146
+ sample_count INTEGER DEFAULT 0,
147
+ confidence_threshold REAL DEFAULT 0.5,
148
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
149
+ is_active BOOLEAN DEFAULT 1,
150
+ description TEXT
151
+ )
152
+ """)
153
+
154
+ # Create indexes for better performance
155
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_faces_name ON faces(name)")
156
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_objects_label ON objects(label)")
157
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_history_session ON detection_history(session_id)")
158
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_knowledge_entity ON knowledge_base(entity_type, entity_id)")
159
+
160
+ conn.commit()
161
+ logger.info("Database initialized successfully")
162
+
163
+ except Exception as e:
164
+ logger.error(f"Database initialization failed: {e}")
165
+ raise
166
+
167
+ def add_face(self, name: str, face_encoding: np.ndarray, image: np.ndarray,
168
+ confidence: float = 0.0, metadata: Dict = None) -> int:
169
+ """
170
+ Add a new face to the database
171
+
172
+ Args:
173
+ name: Person's name
174
+ face_encoding: Face encoding vector
175
+ image: Face image array
176
+ confidence: Recognition confidence
177
+ metadata: Additional metadata
178
+
179
+ Returns:
180
+ Face ID in database
181
+ """
182
+ try:
183
+ # Encode image to base64
184
+ _, buffer = cv2.imencode('.jpg', image)
185
+ image_data = base64.b64encode(buffer).decode('utf-8')
186
+
187
+ # Serialize face encoding
188
+ encoding_data = face_encoding.tobytes()
189
+
190
+ # Convert metadata to JSON
191
+ metadata_json = json.dumps(metadata) if metadata else None
192
+
193
+ with sqlite3.connect(self.db_path) as conn:
194
+ cursor = conn.cursor()
195
+ cursor.execute("""
196
+ INSERT INTO faces (name, encoding, image_data, confidence, metadata)
197
+ VALUES (?, ?, ?, ?, ?)
198
+ """, (name, encoding_data, image_data, confidence, metadata_json))
199
+
200
+ face_id = cursor.lastrowid
201
+ conn.commit()
202
+
203
+ # Add to knowledge base
204
+ self.add_knowledge_entry("face", face_id, f"Person named {name}")
205
+
206
+ logger.info(f"Added face for {name} with ID {face_id}")
207
+ return face_id
208
+
209
+ except Exception as e:
210
+ logger.error(f"Failed to add face: {e}")
211
+ raise
212
+
213
+ def add_object(self, label: str, category: str, features: np.ndarray,
214
+ image: np.ndarray, bounding_box: Tuple, confidence: float = 0.0,
215
+ metadata: Dict = None) -> int:
216
+ """
217
+ Add a new custom object to the database
218
+
219
+ Args:
220
+ label: Object label/name
221
+ category: Object category
222
+ features: Feature vector
223
+ image: Object image
224
+ bounding_box: (x, y, w, h) bounding box
225
+ confidence: Detection confidence
226
+ metadata: Additional metadata
227
+
228
+ Returns:
229
+ Object ID in database
230
+ """
231
+ try:
232
+ # Encode image to base64
233
+ _, buffer = cv2.imencode('.jpg', image)
234
+ image_data = base64.b64encode(buffer).decode('utf-8')
235
+
236
+ # Serialize features
237
+ features_data = features.tobytes() if features is not None else None
238
+
239
+ # Serialize bounding box
240
+ bbox_json = json.dumps(bounding_box)
241
+
242
+ # Convert metadata to JSON
243
+ metadata_json = json.dumps(metadata) if metadata else None
244
+
245
+ with sqlite3.connect(self.db_path) as conn:
246
+ cursor = conn.cursor()
247
+ cursor.execute("""
248
+ INSERT INTO objects (label, category, features, image_data,
249
+ bounding_box, confidence, metadata)
250
+ VALUES (?, ?, ?, ?, ?, ?, ?)
251
+ """, (label, category, features_data, image_data, bbox_json,
252
+ confidence, metadata_json))
253
+
254
+ object_id = cursor.lastrowid
255
+ conn.commit()
256
+
257
+ # Add to knowledge base
258
+ self.add_knowledge_entry("object", object_id,
259
+ f"{label} - {category} object")
260
+
261
+ logger.info(f"Added object {label} with ID {object_id}")
262
+ return object_id
263
+
264
+ except Exception as e:
265
+ logger.error(f"Failed to add object: {e}")
266
+ raise
267
+
268
+ def get_faces(self, active_only: bool = True) -> List[Dict]:
269
+ """Get all faces from database"""
270
+ try:
271
+ with sqlite3.connect(self.db_path) as conn:
272
+ cursor = conn.cursor()
273
+ query = "SELECT * FROM faces"
274
+ if active_only:
275
+ query += " WHERE is_active = 1"
276
+
277
+ cursor.execute(query)
278
+ rows = cursor.fetchall()
279
+
280
+ faces = []
281
+ for row in rows:
282
+ face = {
283
+ 'id': row[0],
284
+ 'name': row[1],
285
+ 'encoding': np.frombuffer(row[2], dtype=np.float64),
286
+ 'confidence': row[4],
287
+ 'created_at': row[5],
288
+ 'metadata': json.loads(row[7]) if row[7] else {}
289
+ }
290
+ faces.append(face)
291
+
292
+ return faces
293
+
294
+ except Exception as e:
295
+ logger.error(f"Failed to get faces: {e}")
296
+ return []
297
+
298
+ def get_objects(self, category: str = None, active_only: bool = True) -> List[Dict]:
299
+ """Get objects from database"""
300
+ try:
301
+ with sqlite3.connect(self.db_path) as conn:
302
+ cursor = conn.cursor()
303
+ query = "SELECT * FROM objects"
304
+ params = []
305
+
306
+ conditions = []
307
+ if active_only:
308
+ conditions.append("is_active = 1")
309
+ if category:
310
+ conditions.append("category = ?")
311
+ params.append(category)
312
+
313
+ if conditions:
314
+ query += " WHERE " + " AND ".join(conditions)
315
+
316
+ cursor.execute(query, params)
317
+ rows = cursor.fetchall()
318
+
319
+ objects = []
320
+ for row in rows:
321
+ obj = {
322
+ 'id': row[0],
323
+ 'label': row[1],
324
+ 'category': row[2],
325
+ 'features': np.frombuffer(row[3], dtype=np.float64) if row[3] else None,
326
+ 'bounding_box': json.loads(row[5]) if row[5] else None,
327
+ 'confidence': row[6],
328
+ 'created_at': row[7],
329
+ 'metadata': json.loads(row[9]) if row[9] else {}
330
+ }
331
+ objects.append(obj)
332
+
333
+ return objects
334
+
335
+ except Exception as e:
336
+ logger.error(f"Failed to get objects: {e}")
337
+ return []
338
+
339
+ def save_detection_history(self, session_id: str, image: np.ndarray,
340
+ detections: List, face_matches: List = None,
341
+ object_matches: List = None, confidence_scores: Dict = None,
342
+ processing_time: float = 0.0, metadata: Dict = None) -> int:
343
+ """Save detection results to history"""
344
+ try:
345
+ # Encode image
346
+ _, buffer = cv2.imencode('.jpg', image)
347
+ image_data = base64.b64encode(buffer).decode('utf-8')
348
+
349
+ # Serialize data
350
+ detections_json = json.dumps(detections)
351
+ face_matches_json = json.dumps(face_matches) if face_matches else None
352
+ object_matches_json = json.dumps(object_matches) if object_matches else None
353
+ confidence_json = json.dumps(confidence_scores) if confidence_scores else None
354
+ metadata_json = json.dumps(metadata) if metadata else None
355
+
356
+ with sqlite3.connect(self.db_path) as conn:
357
+ cursor = conn.cursor()
358
+ cursor.execute("""
359
+ INSERT INTO detection_history
360
+ (session_id, image_data, detections, face_matches, object_matches,
361
+ confidence_scores, processing_time, metadata)
362
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?)
363
+ """, (session_id, image_data, detections_json, face_matches_json,
364
+ object_matches_json, confidence_json, processing_time, metadata_json))
365
+
366
+ history_id = cursor.lastrowid
367
+ conn.commit()
368
+
369
+ logger.info(f"Saved detection history with ID {history_id}")
370
+ return history_id
371
+
372
+ except Exception as e:
373
+ logger.error(f"Failed to save detection history: {e}")
374
+ raise
375
+
376
+ def add_knowledge_entry(self, entity_type: str, entity_id: int, content: str,
377
+ keywords: List[str] = None):
378
+ """Add entry to knowledge base for RAG"""
379
+ try:
380
+ keywords_json = json.dumps(keywords) if keywords else None
381
+
382
+ with sqlite3.connect(self.db_path) as conn:
383
+ cursor = conn.cursor()
384
+ cursor.execute("""
385
+ INSERT INTO knowledge_base (entity_type, entity_id, content, keywords)
386
+ VALUES (?, ?, ?, ?)
387
+ """, (entity_type, entity_id, content, keywords_json))
388
+ conn.commit()
389
+
390
+ except Exception as e:
391
+ logger.error(f"Failed to add knowledge entry: {e}")
392
+
393
+ def search_knowledge(self, query: str, entity_type: str = None) -> List[Dict]:
394
+ """Search knowledge base for RAG"""
395
+ try:
396
+ with sqlite3.connect(self.db_path) as conn:
397
+ cursor = conn.cursor()
398
+
399
+ # Simple text search (can be enhanced with embeddings)
400
+ search_query = f"%{query.lower()}%"
401
+
402
+ if entity_type:
403
+ cursor.execute("""
404
+ SELECT * FROM knowledge_base
405
+ WHERE entity_type = ? AND LOWER(content) LIKE ?
406
+ ORDER BY created_at DESC LIMIT 10
407
+ """, (entity_type, search_query))
408
+ else:
409
+ cursor.execute("""
410
+ SELECT * FROM knowledge_base
411
+ WHERE LOWER(content) LIKE ?
412
+ ORDER BY created_at DESC LIMIT 10
413
+ """, (search_query,))
414
+
415
+ rows = cursor.fetchall()
416
+ results = []
417
+
418
+ for row in rows:
419
+ result = {
420
+ 'id': row[0],
421
+ 'entity_type': row[1],
422
+ 'entity_id': row[2],
423
+ 'content': row[3],
424
+ 'keywords': json.loads(row[5]) if row[5] else [],
425
+ 'created_at': row[6]
426
+ }
427
+ results.append(result)
428
+
429
+ return results
430
+
431
+ except Exception as e:
432
+ logger.error(f"Knowledge search failed: {e}")
433
+ return []
434
+
435
+ def get_stats(self) -> Dict:
436
+ """Get database statistics"""
437
+ try:
438
+ with sqlite3.connect(self.db_path) as conn:
439
+ cursor = conn.cursor()
440
+
441
+ # Count faces
442
+ cursor.execute("SELECT COUNT(*) FROM faces WHERE is_active = 1")
443
+ face_count = cursor.fetchone()[0]
444
+
445
+ # Count objects
446
+ cursor.execute("SELECT COUNT(*) FROM objects WHERE is_active = 1")
447
+ object_count = cursor.fetchone()[0]
448
+
449
+ # Count history entries
450
+ cursor.execute("SELECT COUNT(*) FROM detection_history")
451
+ history_count = cursor.fetchone()[0]
452
+
453
+ # Get recent activity
454
+ cursor.execute("""
455
+ SELECT COUNT(*) FROM detection_history
456
+ WHERE created_at > datetime('now', '-7 days')
457
+ """)
458
+ recent_detections = cursor.fetchone()[0]
459
+
460
+ return {
461
+ 'faces': face_count,
462
+ 'objects': object_count,
463
+ 'total_detections': history_count,
464
+ 'recent_detections': recent_detections,
465
+ 'database_size': Path(self.db_path).stat().st_size if Path(self.db_path).exists() else 0
466
+ }
467
+
468
+ except Exception as e:
469
+ logger.error(f"Failed to get stats: {e}")
470
+ return {}
471
+
472
+ # Training Corrections Methods for Active Learning
473
+
474
+ def save_correction(self, image_crop: np.ndarray, bbox_coords: List[float],
475
+ yolo_prediction: str, yolo_confidence: float,
476
+ correct_label: str, user_feedback: str = "",
477
+ session_id: str = "") -> bool:
478
+ """
479
+ Save a user correction for training
480
+
481
+ Args:
482
+ image_crop: Cropped image of the detected object
483
+ bbox_coords: [x1, y1, x2, y2] bounding box coordinates
484
+ yolo_prediction: Original YOLO predicted label
485
+ yolo_confidence: Original YOLO confidence score
486
+ correct_label: User-provided correct label
487
+ user_feedback: Optional user feedback text
488
+ session_id: Session identifier
489
+
490
+ Returns:
491
+ bool: Success status
492
+ """
493
+ try:
494
+ # Convert image to bytes
495
+ _, buffer = cv2.imencode('.jpg', image_crop)
496
+ image_bytes = buffer.tobytes()
497
+
498
+ # Calculate difficulty score (lower confidence = higher difficulty)
499
+ difficulty_score = 1.0 - yolo_confidence
500
+
501
+ with sqlite3.connect(self.db_path) as conn:
502
+ cursor = conn.cursor()
503
+
504
+ cursor.execute("""
505
+ INSERT INTO training_corrections
506
+ (image_crop, bbox_coords, yolo_prediction, yolo_confidence,
507
+ correct_label, user_feedback, difficulty_score, session_id, metadata)
508
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
509
+ """, (
510
+ image_bytes,
511
+ json.dumps(bbox_coords),
512
+ yolo_prediction,
513
+ yolo_confidence,
514
+ correct_label,
515
+ user_feedback,
516
+ difficulty_score,
517
+ session_id,
518
+ json.dumps({
519
+ 'timestamp': datetime.now().isoformat(),
520
+ 'image_shape': image_crop.shape,
521
+ 'correction_type': 'user_feedback'
522
+ })
523
+ ))
524
+
525
+ # Update or create custom class entry
526
+ cursor.execute("""
527
+ INSERT OR IGNORE INTO custom_classes (class_name, yolo_class, sample_count)
528
+ VALUES (?, ?, 0)
529
+ """, (correct_label, yolo_prediction))
530
+
531
+ cursor.execute("""
532
+ UPDATE custom_classes
533
+ SET sample_count = sample_count + 1
534
+ WHERE class_name = ?
535
+ """, (correct_label,))
536
+
537
+ return True
538
+
539
+ except Exception as e:
540
+ logger.error(f"Failed to save correction: {e}")
541
+ return False
542
+
543
+ def get_training_data(self, class_name: str = None, limit: int = 1000,
544
+ validated_only: bool = False) -> List[Dict]:
545
+ """
546
+ Retrieve training data for model training
547
+
548
+ Args:
549
+ class_name: Filter by specific class (optional)
550
+ limit: Maximum number of samples to return
551
+ validated_only: Only return validated corrections
552
+
553
+ Returns:
554
+ List of training samples
555
+ """
556
+ try:
557
+ with sqlite3.connect(self.db_path) as conn:
558
+ cursor = conn.cursor()
559
+
560
+ query = """
561
+ SELECT id, image_crop, bbox_coords, yolo_prediction,
562
+ yolo_confidence, correct_label, difficulty_score,
563
+ created_at, metadata
564
+ FROM training_corrections
565
+ WHERE 1=1
566
+ """
567
+ params = []
568
+
569
+ if class_name:
570
+ query += " AND correct_label = ?"
571
+ params.append(class_name)
572
+
573
+ if validated_only:
574
+ query += " AND validated = 1"
575
+
576
+ query += " ORDER BY difficulty_score DESC, created_at DESC LIMIT ?"
577
+ params.append(limit)
578
+
579
+ cursor.execute(query, params)
580
+ rows = cursor.fetchall()
581
+
582
+ training_data = []
583
+ for row in rows:
584
+ # Decode image
585
+ image_bytes = row[1]
586
+ image_array = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_COLOR)
587
+
588
+ training_data.append({
589
+ 'id': row[0],
590
+ 'image': image_array,
591
+ 'bbox_coords': json.loads(row[2]),
592
+ 'yolo_prediction': row[3],
593
+ 'yolo_confidence': row[4],
594
+ 'correct_label': row[5],
595
+ 'difficulty_score': row[6],
596
+ 'created_at': row[7],
597
+ 'metadata': json.loads(row[8]) if row[8] else {}
598
+ })
599
+
600
+ return training_data
601
+
602
+ except Exception as e:
603
+ logger.error(f"Failed to get training data: {e}")
604
+ return []
605
+
606
+ def get_training_stats(self) -> Dict:
607
+ """Get statistics about training corrections"""
608
+ try:
609
+ with sqlite3.connect(self.db_path) as conn:
610
+ cursor = conn.cursor()
611
+
612
+ # Total corrections
613
+ cursor.execute("SELECT COUNT(*) FROM training_corrections")
614
+ total_corrections = cursor.fetchone()[0]
615
+
616
+ # Corrections by class
617
+ cursor.execute("""
618
+ SELECT correct_label, COUNT(*) as count
619
+ FROM training_corrections
620
+ GROUP BY correct_label
621
+ ORDER BY count DESC
622
+ """)
623
+ class_counts = dict(cursor.fetchall())
624
+
625
+ # Validated corrections
626
+ cursor.execute("SELECT COUNT(*) FROM training_corrections WHERE validated = 1")
627
+ validated_count = cursor.fetchone()[0]
628
+
629
+ # Recent corrections (last 7 days)
630
+ cursor.execute("""
631
+ SELECT COUNT(*) FROM training_corrections
632
+ WHERE created_at > datetime('now', '-7 days')
633
+ """)
634
+ recent_corrections = cursor.fetchone()[0]
635
+
636
+ # Average difficulty score
637
+ cursor.execute("SELECT AVG(difficulty_score) FROM training_corrections")
638
+ avg_difficulty = cursor.fetchone()[0] or 0.0
639
+
640
+ return {
641
+ 'total_corrections': total_corrections,
642
+ 'validated_corrections': validated_count,
643
+ 'recent_corrections': recent_corrections,
644
+ 'class_distribution': class_counts,
645
+ 'average_difficulty': round(avg_difficulty, 3),
646
+ 'unique_classes': len(class_counts)
647
+ }
648
+
649
+ except Exception as e:
650
+ logger.error(f"Failed to get training stats: {e}")
651
+ return {}
652
+
653
+ def mark_corrections_used(self, correction_ids: List[int]) -> bool:
654
+ """Mark corrections as used for training"""
655
+ try:
656
+ with sqlite3.connect(self.db_path) as conn:
657
+ cursor = conn.cursor()
658
+
659
+ if not correction_ids: # Guard: an empty list would yield "IN ()", a SQLite syntax error
+ return True # nothing to mark
+ placeholders = ','.join(['?'] * len(correction_ids))
660
+ cursor.execute(f"""
661
+ UPDATE training_corrections
662
+ SET used_for_training = 1
663
+ WHERE id IN ({placeholders})
664
+ """, correction_ids)
665
+
666
+ return True
667
+
668
+ except Exception as e:
669
+ logger.error(f"Failed to mark corrections as used: {e}")
670
+ return False
671
+
672
+ # Global database instance
673
+ try:
674
+ db = NAVADADatabase()
675
+ logger.info("Database instance created successfully")
676
+ except Exception as e:
677
+ logger.error(f"Failed to create database instance: {e}")
678
  db = None
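To orient readers, a short usage sketch for this module (a hedged example, assuming it is importable as backend.database; the crop, labels, and session id below are illustrative only):

import numpy as np
from backend.database import db  # global instance created above

if db is not None:
    crop = np.zeros((64, 64, 3), dtype=np.uint8)  # stand-in for a real detection crop
    db.save_correction(
        image_crop=crop,
        bbox_coords=[10.0, 10.0, 74.0, 74.0],  # [x1, y1, x2, y2]
        yolo_prediction="cat",                 # original YOLO label
        yolo_confidence=0.42,
        correct_label="dog",                   # user-provided correction
        session_id="demo-session",
    )
    print(db.get_training_stats())  # totals, class distribution, average difficulty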
backend/face_detection.py CHANGED
@@ -1,299 +1,299 @@
1
- """
2
- Face Detection Module for NAVADA
3
- This module provides face detection capabilities using OpenCV's Haar Cascades.
4
- It can detect faces, eyes, and smiles in images and return detailed statistics.
5
- """
6
-
7
- import cv2 # OpenCV library for computer vision tasks
8
- import numpy as np # NumPy for numerical operations on arrays
9
- from typing import Tuple, List, Dict, Optional, Union # Type hints for better code documentation
10
- import os # Operating system interface for file path operations
11
- import logging # Logging module for error tracking
12
-
13
- # Configure logging for this module
14
- logger = logging.getLogger(__name__)
15
-
16
-
17
- class FaceDetector:
18
- """
19
- A class to handle face detection using OpenCV's Haar Cascade classifiers.
20
- This detector can identify faces, eyes, and smiles in images.
21
- """
22
-
23
- def __init__(self):
24
- """
25
- Initialize the FaceDetector with pre-trained Haar Cascade classifiers.
26
- Loads classifiers for face, eye, and smile detection.
27
- """
28
- try:
29
- # Load the pre-trained Haar Cascade classifier for frontal face detection
30
- # This XML file contains trained patterns for detecting frontal faces
31
- self.face_cascade = cv2.CascadeClassifier(
32
- cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
33
- )
34
-
35
- # Load the classifier for eye detection
36
- # This works best when applied to face regions
37
- self.eye_cascade = cv2.CascadeClassifier(
38
- cv2.data.haarcascades + 'haarcascade_eye.xml'
39
- )
40
-
41
- # Load the classifier for smile detection
42
- # This detects smiling expressions in face regions
43
- self.smile_cascade = cv2.CascadeClassifier(
44
- cv2.data.haarcascades + 'haarcascade_smile.xml'
45
- )
46
-
47
- # Verify that classifiers loaded successfully
48
- if self.face_cascade.empty():
49
- raise ValueError("Failed to load face cascade classifier")
50
- if self.eye_cascade.empty():
51
- raise ValueError("Failed to load eye cascade classifier")
52
- if self.smile_cascade.empty():
53
- raise ValueError("Failed to load smile cascade classifier")
54
-
55
- logger.info("Face detection classifiers loaded successfully")
56
-
57
- except Exception as e:
58
- logger.error(f"Error initializing face detector: {str(e)}")
59
- raise
60
-
61
- def detect_faces(self, image: Optional[np.ndarray]) -> Tuple[np.ndarray, Dict]:
62
- """
63
- Detect faces in an image and return an annotated image with statistics.
64
-
65
- Parameters:
66
- -----------
67
- image : Optional[np.ndarray]
68
- Input image as a NumPy array (can be grayscale or color)
69
-
70
- Returns:
71
- --------
72
- Tuple[np.ndarray, Dict]
73
- - Annotated image with face detection boxes and labels
74
- - Dictionary containing detection statistics and face details
75
- """
76
- # Input validation - check if image is provided and valid
77
- if image is None:
78
- logger.warning("No image provided for face detection")
79
- return np.zeros((480, 640, 3), dtype=np.uint8), {
80
- 'total_faces': 0,
81
- 'faces': [],
82
- 'detection_method': 'Haar Cascade',
83
- 'features_detected': {'eyes': 0, 'smiles': 0}
84
- }
85
-
86
- # Ensure image is a numpy array
87
- if not isinstance(image, np.ndarray):
88
- logger.error("Image must be a numpy array")
89
- raise TypeError("Image must be a numpy array")
90
-
91
- # Check if image is empty
92
- if image.size == 0:
93
- logger.warning("Empty image provided")
94
- return image, {
95
- 'total_faces': 0,
96
- 'faces': [],
97
- 'detection_method': 'Haar Cascade',
98
- 'features_detected': {'eyes': 0, 'smiles': 0}
99
- }
100
-
101
- # Convert grayscale images to RGB for consistent output
102
- # Check the number of dimensions to determine if image is grayscale
103
- if len(image.shape) == 2: # Grayscale image (height, width)
104
- img_rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
105
- elif len(image.shape) == 3: # Color image (height, width, channels)
106
- img_rgb = image.copy() # Create a copy to avoid modifying original
107
- else:
108
- logger.error(f"Invalid image shape: {image.shape}")
109
- raise ValueError(f"Invalid image shape: {image.shape}")
110
-
111
- # Convert to grayscale for detection algorithms
112
- # Haar Cascades work on grayscale images for better performance
113
- try:
114
- gray = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2GRAY)
115
- except cv2.error as e:
116
- logger.error(f"Error converting image to grayscale: {str(e)}")
117
- return img_rgb, {
118
- 'total_faces': 0,
119
- 'faces': [],
120
- 'detection_method': 'Haar Cascade',
121
- 'features_detected': {'eyes': 0, 'smiles': 0}
122
- }
123
-
124
- # Detect faces using the Haar Cascade classifier
125
- # Parameters control detection sensitivity and performance
126
- faces = self.face_cascade.detectMultiScale(
127
- gray, # Grayscale image to search
128
- scaleFactor=1.1, # Image pyramid scaling factor (1.1 = 10% reduction each level)
129
- minNeighbors=5, # Minimum neighbors for detection confidence
130
- minSize=(30, 30) # Minimum face size in pixels
131
- )
132
-
133
- # List to store detailed information about each detected face
134
- face_details = []
135
-
136
- # Process each detected face
137
- for idx, (x, y, w, h) in enumerate(faces):
138
- # Draw a magenta rectangle around the detected face
139
- # Parameters: image, top-left corner, bottom-right corner, color (RGB here, matching img_rgb), thickness
140
- cv2.rectangle(img_rgb, (x, y), (x+w, y+h), (255, 0, 255), 3)
141
-
142
- # Add a label above each face
143
- # Parameters: image, text, position, font, scale, color, thickness
144
- cv2.putText(
145
- img_rgb,
146
- f"Face {idx+1}", # Label text
147
- (x, y-10), # Position (above the rectangle)
148
- cv2.FONT_HERSHEY_SIMPLEX, # Font type
149
- 0.7, # Font scale
150
- (255, 0, 255), # Color (magenta in RGB)
151
- 2 # Thickness
152
- )
153
-
154
- # Extract Region of Interest (ROI) for face area
155
- # This isolates the face region for feature detection
156
- roi_gray = gray[y:y+h, x:x+w] # Grayscale ROI for detection
157
- roi_color = img_rgb[y:y+h, x:x+w] # Color ROI for drawing
158
-
159
- # Detect eyes within the face region
160
- # Using different parameters for eye detection (more sensitive)
161
- eyes = self.eye_cascade.detectMultiScale(
162
- roi_gray,
163
- scaleFactor=1.05, # Smaller scale factor for finer detection
164
- minNeighbors=3 # Fewer neighbors required
165
- )
166
- eye_count = len(eyes) # Count the number of detected eyes
167
-
168
- # Draw green rectangles around detected eyes
169
- for (ex, ey, ew, eh) in eyes:
170
- cv2.rectangle(
171
- roi_color, # Draw on the color ROI
172
- (ex, ey), # Top-left corner
173
- (ex+ew, ey+eh), # Bottom-right corner
174
- (0, 255, 0), # Green color in RGB
175
- 2 # Thickness
176
- )
177
-
178
- # Detect smiles within the face region
179
- # Smile detection requires different parameters
180
- smiles = self.smile_cascade.detectMultiScale(
181
- roi_gray,
182
- scaleFactor=1.8, # Larger scale factor for smile detection
183
- minNeighbors=20 # More neighbors required for confidence
184
- )
185
- has_smile = len(smiles) > 0 # Boolean flag for smile presence
186
-
187
- # Draw yellow rectangles around detected smiles
188
- for (sx, sy, sw, sh) in smiles:
189
- cv2.rectangle(
190
- roi_color, # Draw on the color ROI
191
- (sx, sy), # Top-left corner
192
- (sx+sw, sy+sh), # Bottom-right corner
193
- (0, 255, 255), # Yellow color in RGB
194
- 2 # Thickness
195
- )
196
-
197
- # Store detailed information about this face
198
- face_details.append({
199
- 'face_id': idx + 1, # Sequential face ID starting from 1
200
- 'position': { # Face bounding box coordinates
201
- 'x': int(x), # X coordinate of top-left corner
202
- 'y': int(y), # Y coordinate of top-left corner
203
- 'width': int(w), # Width of face bounding box
204
- 'height': int(h) # Height of face bounding box
205
- },
206
- 'eyes_detected': eye_count, # Number of eyes detected
207
- 'smile_detected': has_smile, # Whether a smile was detected
208
- 'confidence': 0.95 # Placeholder confidence score
209
- })
210
-
211
- # Compile comprehensive statistics about all detected faces
212
- stats = {
213
- 'total_faces': len(faces), # Total number of faces detected
214
- 'faces': face_details, # List of detailed face information
215
- 'detection_method': 'Haar Cascade', # Method used for detection
216
- 'features_detected': { # Aggregate feature statistics
217
- 'eyes': sum(f['eyes_detected'] for f in face_details), # Total eyes
218
- 'smiles': sum(1 for f in face_details if f['smile_detected']) # Total smiles
219
- }
220
- }
221
-
222
- return img_rgb, stats # Return annotated image and statistics
223
-
224
- def analyze_demographics(self, face_stats: Optional[Dict]) -> str:
225
- """
226
- Create a demographic analysis report based on face detection statistics.
227
-
228
- Parameters:
229
- -----------
230
- face_stats : Dict
231
- Dictionary containing face detection statistics
232
-
233
- Returns:
234
- --------
235
- str
236
- Formatted text analysis of detected faces and their features
237
- """
238
- # Handle case where no statistics are provided
239
- if not face_stats:
240
- return "No face detection data available."
241
-
242
- # Handle case where no faces were detected
243
- if face_stats.get('total_faces', 0) == 0:
244
- return "No faces detected in the image."
245
-
246
- # Build analysis report
247
- analysis = [] # List to accumulate analysis text
248
-
249
- # Add header with total face count
250
- analysis.append(f"👥 Detected {face_stats['total_faces']} face(s) in the image\n")
251
-
252
- # Add detailed information for each face
253
- for face in face_stats.get('faces', []):
254
- # Create description for individual face
255
- face_desc = f"\n**Face {face['face_id']}:**"
256
-
257
- # Add position information
258
- pos = face.get('position', {})
259
- face_desc += f"\n • Position: ({pos.get('x', 0)}, {pos.get('y', 0)})"
260
-
261
- # Add size information
262
- face_desc += f"\n • Size: {pos.get('width', 0)}x{pos.get('height', 0)} pixels"
263
-
264
- # Add eye detection information if eyes were found
265
- if face.get('eyes_detected', 0) > 0:
266
- face_desc += f"\n • Eyes detected: {face['eyes_detected']}"
267
-
268
- # Add smile detection information
269
- if face.get('smile_detected', False):
270
- face_desc += "\n • 😊 Smile detected!"
271
-
272
- analysis.append(face_desc) # Add face description to analysis
273
-
274
- # Add summary statistics if smiles were detected
275
- features = face_stats.get('features_detected', {})
276
- smile_count = features.get('smiles', 0)
277
-
278
- if smile_count > 0:
279
- # Calculate percentage of faces that are smiling
280
- smile_ratio = (smile_count / face_stats['total_faces']) * 100
281
-
282
- # Add overall analysis section
283
- analysis.append(f"\n\n📊 **Overall Analysis:**")
284
- analysis.append(f"\n • {smile_ratio:.0f}% of faces are smiling")
285
- analysis.append(f"\n • Total eyes detected: {features.get('eyes', 0)}")
286
-
287
- # Join all analysis parts and return
288
- return "".join(analysis)
289
-
290
-
291
- # Create a global instance of FaceDetector for use throughout the application
292
- # This avoids reloading classifiers multiple times
293
- try:
294
- face_detector = FaceDetector()
295
- logger.info("Global face detector initialized successfully")
296
- except Exception as e:
297
- logger.error(f"Failed to initialize global face detector: {str(e)}")
298
- # Create a dummy detector that returns empty results
299
  face_detector = None
 
1
+ """
2
+ Face Detection Module for NAVADA
3
+ This module provides face detection capabilities using OpenCV's Haar Cascades.
4
+ It can detect faces, eyes, and smiles in images and return detailed statistics.
5
+ """
6
+
7
+ import cv2 # OpenCV library for computer vision tasks
8
+ import numpy as np # NumPy for numerical operations on arrays
9
+ from typing import Tuple, List, Dict, Optional, Union # Type hints for better code documentation
10
+ import os # Operating system interface for file path operations
11
+ import logging # Logging module for error tracking
12
+
13
+ # Configure logging for this module
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ class FaceDetector:
18
+ """
19
+ A class to handle face detection using OpenCV's Haar Cascade classifiers.
20
+ This detector can identify faces, eyes, and smiles in images.
21
+ """
22
+
23
+ def __init__(self):
24
+ """
25
+ Initialize the FaceDetector with pre-trained Haar Cascade classifiers.
26
+ Loads classifiers for face, eye, and smile detection.
27
+ """
28
+ try:
29
+ # Load the pre-trained Haar Cascade classifier for frontal face detection
30
+ # This XML file contains trained patterns for detecting frontal faces
31
+ self.face_cascade = cv2.CascadeClassifier(
32
+ cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
33
+ )
34
+
35
+ # Load the classifier for eye detection
36
+ # This works best when applied to face regions
37
+ self.eye_cascade = cv2.CascadeClassifier(
38
+ cv2.data.haarcascades + 'haarcascade_eye.xml'
39
+ )
40
+
41
+ # Load the classifier for smile detection
42
+ # This detects smiling expressions in face regions
43
+ self.smile_cascade = cv2.CascadeClassifier(
44
+ cv2.data.haarcascades + 'haarcascade_smile.xml'
45
+ )
46
+
47
+ # Verify that classifiers loaded successfully
48
+ if self.face_cascade.empty():
49
+ raise ValueError("Failed to load face cascade classifier")
50
+ if self.eye_cascade.empty():
51
+ raise ValueError("Failed to load eye cascade classifier")
52
+ if self.smile_cascade.empty():
53
+ raise ValueError("Failed to load smile cascade classifier")
54
+
55
+ logger.info("Face detection classifiers loaded successfully")
56
+
57
+ except Exception as e:
58
+ logger.error(f"Error initializing face detector: {str(e)}")
59
+ raise
60
+
61
+ def detect_faces(self, image: Optional[np.ndarray]) -> Tuple[np.ndarray, Dict]:
62
+ """
63
+ Detect faces in an image and return an annotated image with statistics.
64
+
65
+ Parameters:
66
+ -----------
67
+ image : Optional[np.ndarray]
68
+ Input image as a NumPy array (can be grayscale or color)
69
+
70
+ Returns:
71
+ --------
72
+ Tuple[np.ndarray, Dict]
73
+ - Annotated image with face detection boxes and labels
74
+ - Dictionary containing detection statistics and face details
75
+ """
76
+ # Input validation - check if image is provided and valid
77
+ if image is None:
78
+ logger.warning("No image provided for face detection")
79
+ return np.zeros((480, 640, 3), dtype=np.uint8), {
80
+ 'total_faces': 0,
81
+ 'faces': [],
82
+ 'detection_method': 'Haar Cascade',
83
+ 'features_detected': {'eyes': 0, 'smiles': 0}
84
+ }
85
+
86
+ # Ensure image is a numpy array
87
+ if not isinstance(image, np.ndarray):
88
+ logger.error("Image must be a numpy array")
89
+ raise TypeError("Image must be a numpy array")
90
+
91
+ # Check if image is empty
92
+ if image.size == 0:
93
+ logger.warning("Empty image provided")
94
+ return image, {
95
+ 'total_faces': 0,
96
+ 'faces': [],
97
+ 'detection_method': 'Haar Cascade',
98
+ 'features_detected': {'eyes': 0, 'smiles': 0}
99
+ }
100
+
101
+ # Convert grayscale images to RGB for consistent output
102
+ # Check the number of dimensions to determine if image is grayscale
103
+ if len(image.shape) == 2: # Grayscale image (height, width)
104
+ img_rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
105
+ elif len(image.shape) == 3: # Color image (height, width, channels)
106
+ img_rgb = image.copy() # Create a copy to avoid modifying original
107
+ else:
108
+ logger.error(f"Invalid image shape: {image.shape}")
109
+ raise ValueError(f"Invalid image shape: {image.shape}")
110
+
111
+ # Convert to grayscale for detection algorithms
112
+ # Haar Cascades work on grayscale images for better performance
113
+ try:
114
+ gray = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2GRAY)
115
+ except cv2.error as e:
116
+ logger.error(f"Error converting image to grayscale: {str(e)}")
117
+ return img_rgb, {
118
+ 'total_faces': 0,
119
+ 'faces': [],
120
+ 'detection_method': 'Haar Cascade',
121
+ 'features_detected': {'eyes': 0, 'smiles': 0}
122
+ }
123
+
124
+ # Detect faces using the Haar Cascade classifier
125
+ # Parameters control detection sensitivity and performance
126
+ faces = self.face_cascade.detectMultiScale(
127
+ gray, # Grayscale image to search
128
+ scaleFactor=1.1, # Image pyramid scaling factor (1.1 = 10% reduction each level)
129
+ minNeighbors=5, # Minimum neighbors for detection confidence
130
+ minSize=(30, 30) # Minimum face size in pixels
131
+ )
132
+
133
+ # List to store detailed information about each detected face
134
+ face_details = []
135
+
136
+ # Process each detected face
137
+ for idx, (x, y, w, h) in enumerate(faces):
138
+ # Draw a magenta rectangle around the detected face
139
+ # Parameters: image, top-left corner, bottom-right corner, color (RGB here, matching img_rgb), thickness
140
+ cv2.rectangle(img_rgb, (x, y), (x+w, y+h), (255, 0, 255), 3)
141
+
142
+ # Add a label above each face
143
+ # Parameters: image, text, position, font, scale, color, thickness
144
+ cv2.putText(
145
+ img_rgb,
146
+ f"Face {idx+1}", # Label text
147
+ (x, y-10), # Position (above the rectangle)
148
+ cv2.FONT_HERSHEY_SIMPLEX, # Font type
149
+ 0.7, # Font scale
150
+ (255, 0, 255), # Color (magenta in RGB)
151
+ 2 # Thickness
152
+ )
153
+
154
+ # Extract Region of Interest (ROI) for face area
155
+ # This isolates the face region for feature detection
156
+ roi_gray = gray[y:y+h, x:x+w] # Grayscale ROI for detection
157
+ roi_color = img_rgb[y:y+h, x:x+w] # Color ROI for drawing
158
+
159
+ # Detect eyes within the face region
160
+ # Using different parameters for eye detection (more sensitive)
161
+ eyes = self.eye_cascade.detectMultiScale(
162
+ roi_gray,
163
+ scaleFactor=1.05, # Smaller scale factor for finer detection
164
+ minNeighbors=3 # Fewer neighbors required
165
+ )
166
+ eye_count = len(eyes) # Count the number of detected eyes
167
+
168
+ # Draw green rectangles around detected eyes
169
+ for (ex, ey, ew, eh) in eyes:
170
+ cv2.rectangle(
171
+ roi_color, # Draw on the color ROI
172
+ (ex, ey), # Top-left corner
173
+ (ex+ew, ey+eh), # Bottom-right corner
174
+ (0, 255, 0), # Green color in RGB
175
+ 2 # Thickness
176
+ )
177
+
178
+ # Detect smiles within the face region
179
+ # Smile detection requires different parameters
180
+ smiles = self.smile_cascade.detectMultiScale(
181
+ roi_gray,
182
+ scaleFactor=1.8, # Larger scale factor for smile detection
183
+ minNeighbors=20 # More neighbors required for confidence
184
+ )
185
+ has_smile = len(smiles) > 0 # Boolean flag for smile presence
186
+
187
+ # Draw yellow rectangles around detected smiles
188
+ for (sx, sy, sw, sh) in smiles:
189
+ cv2.rectangle(
190
+ roi_color, # Draw on the color ROI
191
+ (sx, sy), # Top-left corner
192
+ (sx+sw, sy+sh), # Bottom-right corner
193
+ (0, 255, 255), # Yellow color in RGB
194
+ 2 # Thickness
195
+ )
196
+
197
+ # Store detailed information about this face
198
+ face_details.append({
199
+ 'face_id': idx + 1, # Sequential face ID starting from 1
200
+ 'position': { # Face bounding box coordinates
201
+ 'x': int(x), # X coordinate of top-left corner
202
+ 'y': int(y), # Y coordinate of top-left corner
203
+ 'width': int(w), # Width of face bounding box
204
+ 'height': int(h) # Height of face bounding box
205
+ },
206
+ 'eyes_detected': eye_count, # Number of eyes detected
207
+ 'smile_detected': has_smile, # Whether a smile was detected
208
+ 'confidence': 0.95 # Placeholder confidence score
209
+ })
210
+
211
+ # Compile comprehensive statistics about all detected faces
212
+ stats = {
213
+ 'total_faces': len(faces), # Total number of faces detected
214
+ 'faces': face_details, # List of detailed face information
215
+ 'detection_method': 'Haar Cascade', # Method used for detection
216
+ 'features_detected': { # Aggregate feature statistics
217
+ 'eyes': sum(f['eyes_detected'] for f in face_details), # Total eyes
218
+ 'smiles': sum(1 for f in face_details if f['smile_detected']) # Total smiles
219
+ }
220
+ }
221
+
222
+ return img_rgb, stats # Return annotated image and statistics
223
+
224
+ def analyze_demographics(self, face_stats: Optional[Dict]) -> str:
225
+ """
226
+ Create a demographic analysis report based on face detection statistics.
227
+
228
+ Parameters:
229
+ -----------
230
+ face_stats : Dict
231
+ Dictionary containing face detection statistics
232
+
233
+ Returns:
234
+ --------
235
+ str
236
+ Formatted text analysis of detected faces and their features
237
+ """
238
+ # Handle case where no statistics are provided
239
+ if not face_stats:
240
+ return "No face detection data available."
241
+
242
+ # Handle case where no faces were detected
243
+ if face_stats.get('total_faces', 0) == 0:
244
+ return "No faces detected in the image."
245
+
246
+ # Build analysis report
247
+ analysis = [] # List to accumulate analysis text
248
+
249
+ # Add header with total face count
250
+ analysis.append(f"👥 Detected {face_stats['total_faces']} face(s) in the image\n")
251
+
252
+ # Add detailed information for each face
253
+ for face in face_stats.get('faces', []):
254
+ # Create description for individual face
255
+ face_desc = f"\n**Face {face['face_id']}:**"
256
+
257
+ # Add position information
258
+ pos = face.get('position', {})
259
+ face_desc += f"\n • Position: ({pos.get('x', 0)}, {pos.get('y', 0)})"
260
+
261
+ # Add size information
262
+ face_desc += f"\n • Size: {pos.get('width', 0)}x{pos.get('height', 0)} pixels"
263
+
264
+ # Add eye detection information if eyes were found
265
+ if face.get('eyes_detected', 0) > 0:
266
+ face_desc += f"\n • Eyes detected: {face['eyes_detected']}"
267
+
268
+ # Add smile detection information
269
+ if face.get('smile_detected', False):
270
+ face_desc += "\n • 😊 Smile detected!"
271
+
272
+ analysis.append(face_desc) # Add face description to analysis
273
+
274
+ # Add summary statistics if smiles were detected
275
+ features = face_stats.get('features_detected', {})
276
+ smile_count = features.get('smiles', 0)
277
+
278
+ if smile_count > 0:
279
+ # Calculate percentage of faces that are smiling
280
+ smile_ratio = (smile_count / face_stats['total_faces']) * 100
281
+
282
+ # Add overall analysis section
283
+ analysis.append(f"\n\n📊 **Overall Analysis:**")
284
+ analysis.append(f"\n • {smile_ratio:.0f}% of faces are smiling")
285
+ analysis.append(f"\n • Total eyes detected: {features.get('eyes', 0)}")
286
+
287
+ # Join all analysis parts and return
288
+ return "".join(analysis)
289
+
290
+
291
+ # Create a global instance of FaceDetector for use throughout the application
292
+ # This avoids reloading classifiers multiple times
293
+ try:
294
+ face_detector = FaceDetector()
295
+ logger.info("Global face detector initialized successfully")
296
+ except Exception as e:
297
+ logger.error(f"Failed to initialize global face detector: {str(e)}")
298
+ # Create a dummy detector that returns empty results
299
  face_detector = None
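For orientation, a minimal usage sketch of this module follows — a hedged example, not part of the commit. It assumes the layout above (`backend/face_detection.py` exporting the global `face_detector`) and a hypothetical image path; the `None` guard mirrors the fallback at the end of the file.

import cv2
from backend.face_detection import face_detector

if face_detector is not None:
    bgr = cv2.imread("group_photo.jpg")         # hypothetical input path
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)  # detect_faces expects RGB
    annotated, stats = face_detector.detect_faces(rgb)
    print(f"Faces: {stats['total_faces']}, features: {stats['features_detected']}")
    print(face_detector.analyze_demographics(stats))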
backend/openai_client.py CHANGED
@@ -1,65 +1,65 @@
+ import os
+ from openai import OpenAI  # type: ignore
+ import tempfile
+
+ # Lazily initialized OpenAI client to avoid import-time errors when the
+ # API key isn't configured. Previously this module attempted to create the
+ # client on import and raised a ``ValueError`` if ``OPENAI_API_KEY`` was
+ # missing, which prevented the rest of the application from running (and
+ # broke tests that don't require the API). The client is now created only
+ # when needed.
+ _client: OpenAI | None = None
+
+
+ def _get_client() -> OpenAI:
+     """Return a cached OpenAI client instance.
+
+     Raises:
+         ValueError: If the ``OPENAI_API_KEY`` environment variable is not set.
+     """
+     global _client
+     if _client is None:
+         api_key = os.getenv("OPENAI_API_KEY")
+         if not api_key:
+             raise ValueError("OPENAI_API_KEY environment variable is required but not set")
+         _client = OpenAI(api_key=api_key)
+     return _client
+
+
+ def explain_detection(objects_list):
+     """Send detected objects to OpenAI and return an explanation."""
+     if not objects_list:
+         return "No objects detected."
+
+     prompt = f"Explain these detected objects in simple terms: {objects_list}"
+
+     client = _get_client()
+     response = client.chat.completions.create(
+         model="gpt-4o-mini",  # lightweight chat model
+         messages=[{"role": "user", "content": prompt}],
+     )
+
+     return response.choices[0].message.content
+
+
+ def generate_voice(text):
+     """Generate voice narration using OpenAI's TTS service."""
+     try:
+         client = _get_client()
+
+         # Generate speech using OpenAI TTS
+         response = client.audio.speech.create(
+             model="tts-1",
+             voice="alloy",  # You can change this to: alloy, echo, fable, onyx, nova, or shimmer
+             input=text,
+             response_format="mp3",
+         )
+
+         # Save the audio to a temporary file
+         with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_audio:
+             temp_audio.write(response.content)
+             return temp_audio.name
+
+     except Exception as e:
+         print(f"Voice generation error: {e}")
+         return None
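A quick, hedged sketch of how the lazy client is exercised — it assumes `OPENAI_API_KEY` is already exported in the environment and that the detection pipeline supplies a list of label strings:

from backend.openai_client import explain_detection, generate_voice

# _get_client() runs on first use, so importing the module alone never raises
summary = explain_detection(["person", "laptop", "cup"])
print(summary)

audio_path = generate_voice(summary)  # temp .mp3 path, or None on failure
if audio_path:
    print(f"Narration saved to {audio_path}")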
backend/prisma_client.py CHANGED
@@ -1,400 +1,400 @@
+ """
+ Prisma Client Integration for NAVADA 2.0
+ Provides enhanced database operations with Prisma ORM
+ """
+
+ import asyncio
+ import json
+ import base64
+ import logging
+ from typing import List, Dict, Optional, Any
+ from datetime import datetime, timedelta
+ import numpy as np
+ import cv2
+
+ logger = logging.getLogger(__name__)
+
+ class PrismaManager:
+     """Enhanced database manager using Prisma ORM"""
+
+     def __init__(self):
+         self.client = None
+         self._init_client()
+
+     def _init_client(self):
+         """Initialize Prisma client"""
+         try:
+             # Import Prisma client (needs to be generated first)
+             # from prisma import Prisma
+             # self.client = Prisma()
+             logger.info("Prisma client initialized")
+         except ImportError:
+             logger.warning("Prisma client not available - run 'npm run prisma:generate'")
+             self.client = None
+         except Exception as e:
+             logger.error(f"Failed to initialize Prisma client: {e}")
+             self.client = None
+
+     async def connect(self):
+         """Connect to database"""
+         if self.client:
+             try:
+                 await self.client.connect()
+                 logger.info("Connected to database via Prisma")
+                 return True
+             except Exception as e:
+                 logger.error(f"Failed to connect to database: {e}")
+                 return False
+         return False
+
+     async def disconnect(self):
+         """Disconnect from database"""
+         if self.client:
+             try:
+                 await self.client.disconnect()
+                 logger.info("Disconnected from database")
+             except Exception as e:
+                 logger.error(f"Error disconnecting: {e}")
+
+     # Document Management for Knowledge Retrieval
+     async def add_document(self, title: str, content: str, content_type: str = "text",
+                            tags: List[str] = None, category: str = None,
+                            image_data: bytes = None, image_url: str = None) -> Optional[int]:
+         """
+         Add document for knowledge retrieval
+
+         Args:
+             title: Document title
+             content: Document content (text)
+             content_type: "text", "image", "mixed"
+             tags: List of tags
+             category: Document category
+             image_data: Binary image data
+             image_url: URL to image
+
+         Returns:
+             Document ID if successful
+         """
+         if not self.client:
+             return None
+
+         try:
+             tags_str = json.dumps(tags) if tags else None
+
+             document = await self.client.document.create(
+                 data={
+                     'title': title,
+                     'content': content,
+                     'contentType': content_type,
+                     'tags': tags_str,
+                     'category': category,
+                     'imageData': image_data,
+                     'imageUrl': image_url
+                 }
+             )
+
+             # Create document chunks for better retrieval
+             await self._create_document_chunks(document.id, content)
+
+             logger.info(f"Added document: {title} (ID: {document.id})")
+             return document.id
+
+         except Exception as e:
+             logger.error(f"Failed to add document: {e}")
+             return None
+
+     async def _create_document_chunks(self, document_id: int, content: str, chunk_size: int = 500):
+         """Create chunks from document content for better retrieval"""
+         if not self.client:
+             return
+
+         try:
+             # Split content into fixed-size chunks
+             chunks = [content[i:i+chunk_size] for i in range(0, len(content), chunk_size)]
+
+             for i, chunk in enumerate(chunks):
+                 await self.client.documentchunk.create(
+                     data={
+                         'documentId': document_id,
+                         'chunkIndex': i,
+                         'content': chunk
+                     }
+                 )
+
+         except Exception as e:
+             logger.error(f"Failed to create document chunks: {e}")
+
+     async def search_documents(self, query: str, content_type: str = None,
+                                category: str = None, limit: int = 10) -> List[Dict]:
+         """
+         Search documents by content, tags, or category
+
+         Args:
+             query: Search query
+             content_type: Filter by content type
+             category: Filter by category
+             limit: Maximum results
+
+         Returns:
+             List of matching documents
+         """
+         if not self.client:
+             return []
+
+         try:
+             where_clause = {
+                 'isActive': True,
+                 'OR': [
+                     {'title': {'contains': query}},
+                     {'content': {'contains': query}},
+                     {'tags': {'contains': query}}
+                 ]
+             }
+
+             if content_type:
+                 where_clause['contentType'] = content_type
+             if category:
+                 where_clause['category'] = category
+
+             documents = await self.client.document.find_many(
+                 where=where_clause,
+                 take=limit,
+                 order_by={'createdAt': 'desc'}
+             )
+
+             return [self._document_to_dict(doc) for doc in documents]
+
+         except Exception as e:
+             logger.error(f"Document search failed: {e}")
+             return []
+
+     def _document_to_dict(self, document) -> Dict:
+         """Convert Prisma document to dictionary"""
+         return {
+             'id': document.id,
+             'title': document.title,
+             'content': document.content,
+             'content_type': document.contentType,
+             'tags': json.loads(document.tags) if document.tags else [],
+             'category': document.category,
+             'image_url': document.imageUrl,
+             'created_at': document.createdAt,
+             'updated_at': document.updatedAt
+         }
+
+     # Media File Management
+     async def add_media_file(self, filename: str, filepath: str, mime_type: str,
+                              file_size: int, image_data: bytes = None,
+                              description: str = None, tags: List[str] = None) -> Optional[int]:
+         """Add media file to database"""
+         if not self.client:
+             return None
+
+         try:
+             tags_str = json.dumps(tags) if tags else None
+
+             media_file = await self.client.mediafile.create(
+                 data={
+                     'filename': filename,
+                     'filepath': filepath,
+                     'mimeType': mime_type,
+                     'fileSize': file_size,
+                     'imageData': image_data,
+                     'description': description,
+                     'tags': tags_str
+                 }
+             )
+
+             logger.info(f"Added media file: {filename} (ID: {media_file.id})")
+             return media_file.id
+
+         except Exception as e:
+             logger.error(f"Failed to add media file: {e}")
+             return None
+
+     async def get_media_files(self, tags: List[str] = None, mime_type: str = None,
+                               limit: int = 50) -> List[Dict]:
+         """Get media files with optional filtering"""
+         if not self.client:
+             return []
+
+         try:
+             where_clause = {'isActive': True}
+
+             if mime_type:
+                 where_clause['mimeType'] = {'contains': mime_type}
+
+             if tags:
+                 # Match any of the provided tags
+                 tag_conditions = [{'tags': {'contains': tag}} for tag in tags]
+                 where_clause['OR'] = tag_conditions
+
+             media_files = await self.client.mediafile.find_many(
+                 where=where_clause,
+                 take=limit,
+                 order_by={'createdAt': 'desc'}
+             )
+
+             return [self._media_file_to_dict(file) for file in media_files]
+
+         except Exception as e:
+             logger.error(f"Failed to get media files: {e}")
+             return []
+
+     def _media_file_to_dict(self, media_file) -> Dict:
+         """Convert Prisma media file to dictionary"""
+         return {
+             'id': media_file.id,
+             'filename': media_file.filename,
+             'filepath': media_file.filepath,
+             'mime_type': media_file.mimeType,
+             'file_size': media_file.fileSize,
+             'description': media_file.description,
+             'tags': json.loads(media_file.tags) if media_file.tags else [],
+             'created_at': media_file.createdAt
+         }
+
+     # Enhanced Knowledge Base Operations
+     async def add_knowledge_entry(self, entity_type: str, entity_id: int, content: str,
+                                   title: str = None, description: str = None,
+                                   tags: List[str] = None, category: str = None,
+                                   image_url: str = None, text_content: str = None) -> Optional[int]:
+         """Add enhanced knowledge base entry"""
+         if not self.client:
+             return None
+
+         try:
+             keywords_str = json.dumps(tags) if tags else None
+
+             knowledge_entry = await self.client.knowledgebase.create(
+                 data={
+                     'entityType': entity_type,
+                     'entityId': entity_id,
+                     'content': content,
+                     'title': title,
+                     'description': description,
+                     'keywords': keywords_str,
+                     'category': category,
+                     'imageUrl': image_url,
+                     'textContent': text_content
+                 }
+             )
+
+             logger.info(f"Added knowledge entry: {title or content[:50]}")
+             return knowledge_entry.id
+
+         except Exception as e:
+             logger.error(f"Failed to add knowledge entry: {e}")
+             return None
+
+     async def search_knowledge(self, query: str, entity_type: str = None,
+                                category: str = None, limit: int = 10) -> List[Dict]:
+         """Enhanced knowledge search"""
+         if not self.client:
+             return []
+
+         try:
+             where_clause = {
+                 'OR': [
+                     {'content': {'contains': query}},
+                     {'title': {'contains': query}},
+                     {'description': {'contains': query}},
+                     {'keywords': {'contains': query}},
+                     {'textContent': {'contains': query}}
+                 ]
+             }
+
+             if entity_type:
+                 where_clause['entityType'] = entity_type
+             if category:
+                 where_clause['category'] = category
+
+             entries = await self.client.knowledgebase.find_many(
+                 where=where_clause,
+                 take=limit,
+                 order_by={'createdAt': 'desc'}
+             )
+
+             return [self._knowledge_to_dict(entry) for entry in entries]
+
+         except Exception as e:
+             logger.error(f"Knowledge search failed: {e}")
+             return []
+
+     def _knowledge_to_dict(self, entry) -> Dict:
+         """Convert Prisma knowledge entry to dictionary"""
+         return {
+             'id': entry.id,
+             'entity_type': entry.entityType,
+             'entity_id': entry.entityId,
+             'content': entry.content,
+             'title': entry.title,
+             'description': entry.description,
+             'keywords': json.loads(entry.keywords) if entry.keywords else [],
+             'category': entry.category,
+             'image_url': entry.imageUrl,
+             'text_content': entry.textContent,
+             'created_at': entry.createdAt,
+             'updated_at': entry.updatedAt
+         }
+
+     # Statistics and Analytics
+     async def get_enhanced_stats(self) -> Dict:
+         """Get comprehensive database statistics"""
+         if not self.client:
+             return {}
+
+         try:
+             stats = {}
+
+             # Basic counts
+             stats['faces'] = await self.client.face.count(where={'isActive': True})
+             stats['objects'] = await self.client.object.count(where={'isActive': True})
+             stats['documents'] = await self.client.document.count(where={'isActive': True})
+             stats['media_files'] = await self.client.mediafile.count(where={'isActive': True})
+             stats['knowledge_entries'] = await self.client.knowledgebase.count()
+             stats['training_corrections'] = await self.client.trainingcorrection.count()
+
+             # Recent activity (last 7 days)
+             seven_days_ago = datetime.now() - timedelta(days=7)
+             stats['recent_detections'] = await self.client.detectionhistory.count(
+                 where={'createdAt': {'gte': seven_days_ago}}
+             )
+
+             return stats
+
+         except Exception as e:
+             logger.error(f"Failed to get enhanced stats: {e}")
+             return {}
+
+ # Global Prisma manager instance
+ prisma_manager = PrismaManager()
+
+ # Helper functions for async operations in Streamlit
+ def run_async(coro):
+     """Run async function in Streamlit"""
+     try:
+         loop = asyncio.get_event_loop()
+     except RuntimeError:
+         loop = asyncio.new_event_loop()
+         asyncio.set_event_loop(loop)
+
+     return loop.run_until_complete(coro)
+
+ # Convenience functions
+ def add_document_sync(title: str, content: str, **kwargs) -> Optional[int]:
+     """Synchronous wrapper for adding documents"""
+     return run_async(prisma_manager.add_document(title, content, **kwargs))
+
+ def search_documents_sync(query: str, **kwargs) -> List[Dict]:
+     """Synchronous wrapper for searching documents"""
+     return run_async(prisma_manager.search_documents(query, **kwargs))
+
+ def add_media_file_sync(filename: str, filepath: str, mime_type: str,
+                         file_size: int, **kwargs) -> Optional[int]:
+     """Synchronous wrapper for adding media files"""
+     return run_async(prisma_manager.add_media_file(filename, filepath, mime_type, file_size, **kwargs))
+
+ def get_enhanced_stats_sync() -> Dict:
+     """Synchronous wrapper for getting stats"""
+     return run_async(prisma_manager.get_enhanced_stats())
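For context, a hedged sketch of the synchronous wrappers. Until the generated client import in `_init_client` is uncommented (after running `npm run prisma:generate`, as its warning suggests), `self.client` stays `None`, so these calls degrade to `None`/empty results instead of raising:

from backend.prisma_client import (
    add_document_sync, search_documents_sync, get_enhanced_stats_sync
)

doc_id = add_document_sync(
    "YOLOv8 notes",                                     # hypothetical document
    "YOLOv8 predicts boxes and classes in one pass.",   # chunked in 500-char slices
    tags=["yolo", "detection"], category="cv",
)
print(doc_id)                                  # None while the client is stubbed
print(search_documents_sync("yolo", limit=5))  # [] while the client is stubbed
print(get_enhanced_stats_sync())               # {} while the client is stubbed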
backend/recognition.py CHANGED
@@ -1,367 +1,367 @@
1
+ """
2
+ Advanced Recognition Module for NAVADA
3
+ Handles face recognition, custom object detection, and RAG-enhanced identification
4
+ """
5
+
6
+ import cv2
7
+ import numpy as np
8
+ from typing import List, Dict, Tuple, Optional
9
+ import logging
10
+ from .database import db
11
+ from .face_detection import face_detector
12
+ import time
13
+ import uuid
14
+
15
+ # Configure logging
16
+ logger = logging.getLogger(__name__)
17
+
18
+ class NAVADARecognition:
19
+ """Advanced recognition system with database integration"""
20
+
21
+ def __init__(self):
22
+ """Initialize recognition system"""
23
+ self.face_threshold = 0.6 # Face recognition threshold
24
+ self.object_threshold = 0.5 # Object recognition threshold
25
+ self.session_id = str(uuid.uuid4())
26
+
27
+ def extract_face_encoding(self, face_image: np.ndarray) -> Optional[np.ndarray]:
28
+ """
29
+ Extract face encoding for recognition
30
+ This is a simplified version - in production, use face_recognition library
31
+ """
32
+ try:
33
+ # Convert to grayscale and resize
34
+ gray = cv2.cvtColor(face_image, cv2.COLOR_RGB2GRAY)
35
+ resized = cv2.resize(gray, (128, 128))
36
+
37
+ # Flatten and normalize as simple encoding
38
+ encoding = resized.flatten().astype(np.float64)
39
+ encoding = encoding / np.linalg.norm(encoding) # Normalize
40
+
41
+ return encoding
42
+
43
+ except Exception as e:
44
+ logger.error(f"Face encoding extraction failed: {e}")
45
+ return None
46
+
47
+ def compare_face_encodings(self, encoding1: np.ndarray, encoding2: np.ndarray) -> float:
48
+ """Compare two face encodings and return similarity score"""
49
+ try:
50
+ # Calculate cosine similarity
51
+ similarity = np.dot(encoding1, encoding2) / (
52
+ np.linalg.norm(encoding1) * np.linalg.norm(encoding2)
53
+ )
54
+ return float(similarity)
55
+
56
+ except Exception as e:
57
+ logger.error(f"Face comparison failed: {e}")
58
+ return 0.0
59
+
60
+ def recognize_faces(self, image: np.ndarray) -> Tuple[np.ndarray, List[Dict]]:
61
+ """
62
+ Recognize faces in image against database
63
+
64
+ Returns:
65
+ Annotated image and list of recognition results
66
+ """
67
+ try:
68
+ if not db:
69
+ return image, []
70
+
71
+ # Detect faces first
72
+ annotated_img, face_stats = face_detector.detect_faces(image)
73
+
74
+ # Get known faces from database
75
+ known_faces = db.get_faces()
76
+
77
+ recognition_results = []
78
+
79
+ if face_stats and face_stats['faces']:
80
+ for face_info in face_stats['faces']:
81
+ # Extract face region
82
+ pos = face_info['position']
83
+ x, y, w, h = pos['x'], pos['y'], pos['width'], pos['height']
84
+ face_region = image[y:y+h, x:x+w]
85
+
86
+ if face_region.size > 0:
87
+ # Extract face encoding
88
+ face_encoding = self.extract_face_encoding(face_region)
89
+
90
+ if face_encoding is not None:
91
+ # Compare with known faces
92
+ best_match = None
93
+ best_similarity = 0.0
94
+
95
+ for known_face in known_faces:
96
+ similarity = self.compare_face_encodings(
97
+ face_encoding, known_face['encoding']
98
+ )
99
+
100
+ if similarity > best_similarity and similarity > self.face_threshold:
101
+ best_similarity = similarity
102
+ best_match = known_face
103
+
104
+ # Add recognition result
105
+ if best_match:
106
+ # Draw name on image
107
+ name = best_match['name']
108
+ cv2.putText(annotated_img, f"{name} ({best_similarity:.2f})",
109
+ (x, y-30), cv2.FONT_HERSHEY_SIMPLEX, 0.6,
110
+ (0, 255, 0), 2)
111
+
112
+ recognition_results.append({
113
+ 'face_id': face_info['face_id'],
114
+ 'name': name,
115
+ 'similarity': best_similarity,
116
+ 'position': pos,
117
+ 'database_id': best_match['id']
118
+ })
119
+ else:
120
+ # Unknown face
121
+ cv2.putText(annotated_img, "Unknown",
122
+ (x, y-30), cv2.FONT_HERSHEY_SIMPLEX, 0.6,
123
+ (0, 0, 255), 2)
124
+
125
+ recognition_results.append({
126
+ 'face_id': face_info['face_id'],
127
+ 'name': 'Unknown',
128
+ 'similarity': 0.0,
129
+ 'position': pos,
130
+ 'database_id': None
131
+ })
132
+
133
+ return annotated_img, recognition_results
134
+
135
+ except Exception as e:
136
+ logger.error(f"Face recognition failed: {e}")
137
+ return image, []
138
+
139
+ def add_new_face(self, image: np.ndarray, name: str, face_region: Tuple = None) -> bool:
140
+ """
141
+ Add a new face to the database
142
+
143
+ Args:
144
+ image: Full image containing the face
145
+ name: Person's name
146
+ face_region: Optional (x, y, w, h) region, if None will detect automatically
147
+
148
+ Returns:
149
+ Success status
150
+ """
151
+ try:
152
+ if not db:
153
+ logger.error("Database not available")
154
+ return False
155
+
156
+ if face_region:
157
+ # Use provided region
158
+ x, y, w, h = face_region
159
+ face_img = image[y:y+h, x:x+w]
160
+ else:
161
+ # Detect face automatically
162
+ _, face_stats = face_detector.detect_faces(image)
163
+
164
+ if not face_stats or not face_stats['faces']:
165
+ logger.error("No face detected in image")
166
+ return False
167
+
168
+ # Use first detected face
169
+ pos = face_stats['faces'][0]['position']
170
+ x, y, w, h = pos['x'], pos['y'], pos['width'], pos['height']
171
+ face_img = image[y:y+h, x:x+w]
172
+
173
+ # Extract encoding
174
+ encoding = self.extract_face_encoding(face_img)
175
+ if encoding is None:
176
+ logger.error("Failed to extract face encoding")
177
+ return False
178
+
179
+ # Add to database
180
+ face_id = db.add_face(
181
+ name=name,
182
+ face_encoding=encoding,
183
+ image=face_img,
184
+ confidence=0.9,
185
+ metadata={
186
+ 'source': 'user_added',
187
+ 'session_id': self.session_id,
188
+ 'timestamp': time.time()
189
+ }
190
+ )
191
+
192
+ logger.info(f"Added new face '{name}' with ID {face_id}")
193
+ return True
194
+
195
+ except Exception as e:
196
+ logger.error(f"Failed to add new face: {e}")
197
+ return False
198
+
199
+ def add_custom_object(self, image: np.ndarray, label: str, category: str,
200
+ bbox: Tuple = None) -> bool:
201
+ """
202
+ Add a custom object to the database
203
+
204
+ Args:
205
+ image: Full image containing the object
206
+ label: Object label/name
207
+ category: Object category
208
+ bbox: Optional (x, y, w, h) bounding box
209
+
210
+ Returns:
211
+ Success status
212
+ """
213
+ try:
214
+ if not db:
215
+ logger.error("Database not available")
216
+ return False
217
+
218
+ if bbox:
219
+ # Use provided bounding box
220
+ x, y, w, h = bbox
221
+ object_img = image[y:y+h, x:x+w]
222
+ else:
223
+ # Use entire image as object
224
+ object_img = image
225
+ bbox = (0, 0, image.shape[1], image.shape[0])
226
+
227
+ # Extract simple features (can be enhanced with deep learning)
228
+ features = self.extract_object_features(object_img)
229
+
230
+ # Add to database
231
+ object_id = db.add_object(
232
+ label=label,
233
+ category=category,
234
+ features=features,
235
+ image=object_img,
236
+ bounding_box=bbox,
237
+ confidence=0.8,
238
+ metadata={
239
+ 'source': 'user_added',
240
+ 'session_id': self.session_id,
241
+ 'timestamp': time.time()
242
+ }
243
+ )
244
+
245
+ logger.info(f"Added custom object '{label}' with ID {object_id}")
246
+ return True
247
+
248
+ except Exception as e:
249
+ logger.error(f"Failed to add custom object: {e}")
250
+ return False
251
+
252
+ def extract_object_features(self, object_img: np.ndarray) -> np.ndarray:
253
+ """Extract features from object image (simplified implementation)"""
254
+ try:
255
+ # Convert to grayscale and resize
256
+ gray = cv2.cvtColor(object_img, cv2.COLOR_RGB2GRAY)
257
+ resized = cv2.resize(gray, (64, 64))
258
+
259
+ # Extract histogram features
260
+ hist = cv2.calcHist([resized], [0], None, [256], [0, 256])
261
+ hist_normalized = hist.flatten() / hist.sum()
262
+
263
+ # Extract edge features
264
+ edges = cv2.Canny(resized, 50, 150)
265
+ edge_density = edges.sum() / edges.size
266
+
267
+ # Combine features
268
+ features = np.concatenate([hist_normalized, [edge_density]])
269
+
270
+ return features.astype(np.float64)
271
+
272
+ except Exception as e:
273
+ logger.error(f"Feature extraction failed: {e}")
274
+ return np.array([])
275
+
276
+ def enhance_with_rag(self, detections: List, face_matches: List = None) -> str:
277
+ """
278
+ Use RAG to enhance detection results with context
279
+
280
+ Args:
281
+ detections: List of detected objects
282
+ face_matches: List of face recognition results
283
+
284
+ Returns:
285
+ Enhanced description with context
286
+ """
287
+ try:
288
+ if not db:
289
+ return "Enhanced analysis not available (database offline)"
290
+
291
+ # Build search queries from detections
292
+ queries = []
293
+
294
+ # Add object queries
295
+ for detection in detections:
296
+ queries.append(detection)
297
+
298
+ # Add face queries
299
+ if face_matches:
300
+ for match in face_matches:
301
+ if match['name'] != 'Unknown':
302
+ queries.append(match['name'])
303
+
304
+ # Search knowledge base
305
+ knowledge_results = []
306
+ for query in queries:
307
+ results = db.search_knowledge(query)
308
+ knowledge_results.extend(results)
309
+
310
+ # Build enhanced description
311
+ if knowledge_results:
312
+ enhanced_desc = "🧠 **Enhanced Analysis with Context:**\n\n"
313
+
314
+ # Group by entity type
315
+ face_context = [r for r in knowledge_results if r['entity_type'] == 'face']
316
+ object_context = [r for r in knowledge_results if r['entity_type'] == 'object']
317
+
318
+ if face_context:
319
+ enhanced_desc += "👥 **Known Individuals:**\n"
320
+ for ctx in face_context[:3]: # Limit to 3 results
321
+ enhanced_desc += f" • {ctx['content']}\n"
322
+ enhanced_desc += "\n"
323
+
324
+ if object_context:
325
+ enhanced_desc += "🏷️ **Recognized Objects:**\n"
326
+ for ctx in object_context[:3]: # Limit to 3 results
327
+ enhanced_desc += f" • {ctx['content']}\n"
328
+ enhanced_desc += "\n"
329
+
330
+ enhanced_desc += "📊 **Context Insights:**\n"
331
+ enhanced_desc += f" • Found {len(knowledge_results)} relevant knowledge entries\n"
332
+ enhanced_desc += f" • Analysis includes both detected and learned objects\n"
333
+
334
+ return enhanced_desc
335
+ else:
336
+ return "🔍 **Context Analysis:** No additional context found in knowledge base."
337
+
338
+ except Exception as e:
339
+ logger.error(f"RAG enhancement failed: {e}")
340
+ return "❌ Enhanced analysis unavailable due to processing error."
341
+
342
+ def save_session_data(self, image: np.ndarray, detections: List,
343
+ face_matches: List = None, processing_time: float = 0.0):
344
+ """Save current session data to database"""
345
+ try:
346
+ if db:
347
+ db.save_detection_history(
348
+ session_id=self.session_id,
349
+ image=image,
350
+ detections=detections,
351
+ face_matches=face_matches,
352
+ processing_time=processing_time,
353
+ metadata={
354
+ 'timestamp': time.time(),
355
+ 'version': '2.0'
356
+ }
357
+ )
358
+ except Exception as e:
359
+ logger.error(f"Failed to save session data: {e}")
360
+
361
+ # Global recognition instance
362
+ try:
363
+ recognition_system = NAVADARecognition()
364
+ logger.info("Recognition system initialized successfully")
365
+ except Exception as e:
366
+ logger.error(f"Failed to initialize recognition system: {e}")
367
  recognition_system = None
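For reference, the appearance features computed above reduce an image to a 257-value vector: a normalized 256-bin grayscale histogram plus one edge-density scalar. A standalone sketch of the same idea (the grayscale conversion and resize happen earlier in the method, above this excerpt, so the 64×64 size and the synthetic input are assumptions):

```python
import cv2
import numpy as np

def appearance_features(image_bgr: np.ndarray) -> np.ndarray:
    gray = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2GRAY)
    resized = cv2.resize(gray, (64, 64))                        # assumed size
    hist = cv2.calcHist([resized], [0], None, [256], [0, 256])  # 256-bin histogram
    hist_normalized = hist.flatten() / hist.sum()               # sums to 1
    edges = cv2.Canny(resized, 50, 150)
    edge_density = edges.sum() / edges.size                     # 0/255 map, so 0..255 scale
    return np.concatenate([hist_normalized, [edge_density]]).astype(np.float64)

# Synthetic stand-in for an object crop so the sketch runs without files:
rng = np.random.default_rng(0)
crop = rng.integers(0, 256, (120, 160, 3), dtype=np.uint8)
print(appearance_features(crop).shape)  # (257,)
```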
backend/two_stage_inference.py CHANGED
@@ -1,285 +1,285 @@
1
- """
2
- Two-Stage Inference System
3
- Combines YOLO detection with custom classifier for improved accuracy
4
- """
5
- import torch
6
- import torchvision.transforms as transforms
7
- import numpy as np
8
- import cv2
9
- from typing import List, Dict, Tuple, Optional
10
- import pickle
11
- from pathlib import Path
12
- import logging
13
-
14
- from .yolo_enhanced import detect_objects_enhanced, model as yolo_model
15
- from .custom_trainer import CustomClassifier
16
-
17
- logger = logging.getLogger(__name__)
18
-
19
- class TwoStageInference:
20
- """Two-stage detection and classification system"""
21
-
22
- def __init__(self, models_dir='models/'):
23
- """
24
- Initialize two-stage inference system
25
-
26
- Args:
27
- models_dir: Directory containing trained custom models
28
- """
29
- self.models_dir = Path(models_dir)
30
- self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
31
-
32
- # Load active custom model if available
33
- self.custom_model = None
34
- self.class_info = None
35
- self.load_active_model()
36
-
37
- # Image preprocessing for custom classifier
38
- self.preprocess = transforms.Compose([
39
- transforms.ToPILImage(),
40
- transforms.Resize((224, 224)),
41
- transforms.ToTensor(),
42
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
43
- ])
44
-
45
- def load_active_model(self):
46
- """Load the most recent trained custom model"""
47
- try:
48
- # Find latest model file
49
- model_files = list(self.models_dir.glob('custom_classifier_*.pkl'))
50
- if not model_files:
51
- logger.info("No custom models found. Using YOLO only.")
52
- return
53
-
54
- # Get most recent model
55
- latest_model = max(model_files, key=lambda x: x.stat().st_mtime)
56
-
57
- # Load model info
58
- with open(latest_model, 'rb') as f:
59
- model_info = pickle.load(f)
60
-
61
- self.class_info = model_info['class_info']
62
-
63
- # Initialize and load custom model
64
- self.custom_model = CustomClassifier(
65
- num_classes=self.class_info['num_classes'],
66
- backbone=model_info['training_config']['backbone']
67
- )
68
- self.custom_model.load_state_dict(model_info['model_state'])
69
- self.custom_model = self.custom_model.to(self.device)
70
- self.custom_model.eval()
71
-
72
- logger.info(f"Loaded custom model: {latest_model.name}")
73
- logger.info(f"Custom classes: {list(self.class_info['idx_to_label'].values())}")
74
-
75
- except Exception as e:
76
- logger.error(f"Failed to load custom model: {e}")
77
- self.custom_model = None
78
- self.class_info = None
79
-
80
- def classify_object(self, object_crop: np.ndarray) -> Tuple[Optional[str], float]:
81
- """
82
- Classify object crop using custom model
83
-
84
- Args:
85
- object_crop: Cropped image region
86
-
87
- Returns:
88
- Tuple of (predicted_label, confidence)
89
- """
90
- if self.custom_model is None:
91
- return None, 0.0
92
-
93
- try:
94
- # Preprocess image
95
- if object_crop.size == 0:
96
- return None, 0.0
97
-
98
- # Convert BGR to RGB
99
- if len(object_crop.shape) == 3 and object_crop.shape[2] == 3:
100
- object_crop = cv2.cvtColor(object_crop, cv2.COLOR_BGR2RGB)
101
-
102
- # Preprocess for model
103
- input_tensor = self.preprocess(object_crop).unsqueeze(0).to(self.device)
104
-
105
- # Inference
106
- with torch.no_grad():
107
- outputs = self.custom_model(input_tensor)
108
- probabilities = torch.softmax(outputs, dim=1)
109
- confidence, predicted = torch.max(probabilities, 1)
110
-
111
- predicted_idx = predicted.item()
112
- confidence_score = confidence.item()
113
-
114
- # Convert to label
115
- predicted_label = self.class_info['idx_to_label'][predicted_idx]
116
-
117
- return predicted_label, confidence_score
118
-
119
- except Exception as e:
120
- logger.error(f"Custom classification failed: {e}")
121
- return None, 0.0
122
-
123
- def should_override_yolo(self, yolo_label: str, yolo_confidence: float,
124
- custom_label: str, custom_confidence: float) -> bool:
125
- """
126
- Decide whether to override YOLO prediction with custom model
127
-
128
- Args:
129
- yolo_label: YOLO predicted label
130
- yolo_confidence: YOLO confidence
131
- custom_label: Custom model predicted label
132
- custom_confidence: Custom model confidence
133
-
134
- Returns:
135
- True if the custom model's prediction should be used
136
- """
137
- # Don't override if custom model not confident enough
138
- if custom_confidence < 0.7:
139
- return False
140
-
141
- # Always override if YOLO has low confidence and custom has high
142
- if yolo_confidence < 0.5 and custom_confidence > 0.8:
143
- return True
144
-
145
- # Override if custom model is significantly more confident
146
- if custom_confidence > yolo_confidence + 0.2:
147
- return True
148
-
149
- # Override if we have training data for this custom class
150
- if custom_label in self.class_info.get('valid_classes', []):
151
- return True
152
-
153
- return False
154
-
155
- def detect_with_custom_model(self, image: np.ndarray, confidence_threshold: float = 0.5) -> Tuple[np.ndarray, List[str], List[Dict]]:
156
- """
157
- Two-stage detection: YOLO + Custom Classification
158
-
159
- Args:
160
- image: Input image
161
- confidence_threshold: YOLO confidence threshold
162
-
163
- Returns:
164
- Tuple of (annotated_image, detected_objects, detailed_attributes)
165
- """
166
- # Stage 1: YOLO Detection
167
- try:
168
- annotated_img, detected_objects, detailed_attributes = detect_objects_enhanced(
169
- image, confidence_threshold
170
- )
171
- except Exception:
172
- # Fallback to basic YOLO
173
- from .yolo import detect_objects
174
- annotated_img, detected_objects = detect_objects(image)
175
- detailed_attributes = []
176
-
177
- # Stage 2: Custom Classification (if model available)
178
- if self.custom_model is None or not detailed_attributes:
179
- return annotated_img, detected_objects, detailed_attributes
180
-
181
- # Process each detection with custom model
182
- enhanced_attributes = []
183
- enhanced_objects = []
184
-
185
- for attr in detailed_attributes:
186
- yolo_label = attr['label']
187
- yolo_confidence = float(attr['confidence'].rstrip('%')) / 100.0
188
- bbox = attr.get('bbox', [0, 0, 100, 100])
189
-
190
- # Extract object region
191
- x1, y1, x2, y2 = [int(coord) for coord in bbox]
192
- object_crop = image[max(0, y1):min(image.shape[0], y2),
193
- max(0, x1):min(image.shape[1], x2)]
194
-
195
- # Classify with custom model
196
- custom_label, custom_confidence = self.classify_object(object_crop)
197
-
198
- # Decide which prediction to use
199
- if custom_label and self.should_override_yolo(yolo_label, yolo_confidence,
200
- custom_label, custom_confidence):
201
- # Use custom model prediction
202
- final_label = custom_label
203
- final_confidence = custom_confidence
204
- attr['prediction_source'] = 'custom_model'
205
- attr['original_yolo'] = {'label': yolo_label, 'confidence': yolo_confidence}
206
- else:
207
- # Use YOLO prediction
208
- final_label = yolo_label
209
- final_confidence = yolo_confidence
210
- attr['prediction_source'] = 'yolo'
211
- if custom_label:
212
- attr['custom_alternative'] = {'label': custom_label, 'confidence': custom_confidence}
213
-
214
- # Update attributes
215
- attr['label'] = final_label
216
- attr['confidence'] = f"{final_confidence:.2%}"
217
-
218
- enhanced_attributes.append(attr)
219
- enhanced_objects.append(final_label)
220
-
221
- # Update annotated image if we made changes
222
- if any(attr.get('prediction_source') == 'custom_model' for attr in enhanced_attributes):
223
- # Re-annotate image with updated predictions
224
- annotated_img = self.annotate_image_with_predictions(image, enhanced_attributes)
225
-
226
- return annotated_img, enhanced_objects, enhanced_attributes
227
-
228
- def annotate_image_with_predictions(self, image: np.ndarray, attributes: List[Dict]) -> np.ndarray:
229
- """
230
- Annotate image with updated predictions
231
-
232
- Args:
233
- image: Original image
234
- attributes: Detection attributes with updated labels
235
-
236
- Returns:
237
- Annotated image
238
- """
239
- annotated = image.copy()
240
-
241
- for attr in attributes:
242
- bbox = attr.get('bbox', [0, 0, 100, 100])
243
- label = attr['label']
244
- confidence = attr['confidence']
245
- source = attr.get('prediction_source', 'yolo')
246
-
247
- x1, y1, x2, y2 = [int(coord) for coord in bbox]
248
-
249
- # Choose color based on source
250
- if source == 'custom_model':
251
- color = (0, 255, 0) # Green for custom model
252
- label_text = f"{label} {confidence} (Custom)"
253
- else:
254
- color = (255, 0, 0) # Red for YOLO
255
- label_text = f"{label} {confidence}"
256
-
257
- # Draw bounding box
258
- cv2.rectangle(annotated, (x1, y1), (x2, y2), color, 2)
259
-
260
- # Draw label
261
- cv2.putText(annotated, label_text, (x1, y1-10),
262
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
263
-
264
- return annotated
265
-
266
- def get_model_info(self) -> Dict:
267
- """Get information about loaded models"""
268
- info = {
269
- 'yolo_model': 'YOLOv8m',
270
- 'custom_model_loaded': self.custom_model is not None,
271
- 'device': self.device
272
- }
273
-
274
- if self.custom_model is not None and self.class_info is not None:
275
- info.update({
276
- 'custom_classes': list(self.class_info['idx_to_label'].values()),
277
- 'num_custom_classes': self.class_info['num_classes'],
278
- 'training_samples': self.class_info.get('train_samples', 0),
279
- 'validation_samples': self.class_info.get('val_samples', 0)
280
- })
281
-
282
- return info
283
-
284
- # Global two-stage inference instance
285
  two_stage_inference = TwoStageInference()
 
1
+ """
2
+ Two-Stage Inference System
3
+ Combines YOLO detection with custom classifier for improved accuracy
4
+ """
5
+ import torch
6
+ import torchvision.transforms as transforms
7
+ import numpy as np
8
+ import cv2
9
+ from typing import List, Dict, Tuple, Optional
10
+ import pickle
11
+ from pathlib import Path
12
+ import logging
13
+
14
+ from .yolo_enhanced import detect_objects_enhanced, model as yolo_model
15
+ from .custom_trainer import CustomClassifier
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+ class TwoStageInference:
20
+ """Two-stage detection and classification system"""
21
+
22
+ def __init__(self, models_dir='models/'):
23
+ """
24
+ Initialize two-stage inference system
25
+
26
+ Args:
27
+ models_dir: Directory containing trained custom models
28
+ """
29
+ self.models_dir = Path(models_dir)
30
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
31
+
32
+ # Load active custom model if available
33
+ self.custom_model = None
34
+ self.class_info = None
35
+ self.load_active_model()
36
+
37
+ # Image preprocessing for custom classifier
38
+ self.preprocess = transforms.Compose([
39
+ transforms.ToPILImage(),
40
+ transforms.Resize((224, 224)),
41
+ transforms.ToTensor(),
42
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
43
+ ])
44
+
45
+ def load_active_model(self):
46
+ """Load the most recent trained custom model"""
47
+ try:
48
+ # Find latest model file
49
+ model_files = list(self.models_dir.glob('custom_classifier_*.pkl'))
50
+ if not model_files:
51
+ logger.info("No custom models found. Using YOLO only.")
52
+ return
53
+
54
+ # Get most recent model
55
+ latest_model = max(model_files, key=lambda x: x.stat().st_mtime)
56
+
57
+ # Load model info
58
+ with open(latest_model, 'rb') as f:
59
+ model_info = pickle.load(f)
60
+
61
+ self.class_info = model_info['class_info']
62
+
63
+ # Initialize and load custom model
64
+ self.custom_model = CustomClassifier(
65
+ num_classes=self.class_info['num_classes'],
66
+ backbone=model_info['training_config']['backbone']
67
+ )
68
+ self.custom_model.load_state_dict(model_info['model_state'])
69
+ self.custom_model = self.custom_model.to(self.device)
70
+ self.custom_model.eval()
71
+
72
+ logger.info(f"Loaded custom model: {latest_model.name}")
73
+ logger.info(f"Custom classes: {list(self.class_info['idx_to_label'].values())}")
74
+
75
+ except Exception as e:
76
+ logger.error(f"Failed to load custom model: {e}")
77
+ self.custom_model = None
78
+ self.class_info = None
79
+
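For context, `load_active_model` above unpickles a single dictionary written by the trainer. A sketch of the payload shape, inferred purely from the keys this method reads (all concrete values are illustrative assumptions):

```python
# Hypothetical payload for models/custom_classifier_<timestamp>.pkl:
model_info = {
    "class_info": {
        "num_classes": 2,
        "idx_to_label": {0: "my_mug", 1: "my_keyboard"},
        "valid_classes": ["my_mug", "my_keyboard"],  # optional; used by should_override_yolo
        "train_samples": 120,                        # optional; surfaced by get_model_info
        "val_samples": 30,                           # optional
    },
    "training_config": {"backbone": "resnet18"},     # forwarded to CustomClassifier
    "model_state": {},  # a torch state_dict in practice; empty only so the sketch stands alone
}
```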
80
+ def classify_object(self, object_crop: np.ndarray) -> Tuple[Optional[str], float]:
81
+ """
82
+ Classify object crop using custom model
83
+
84
+ Args:
85
+ object_crop: Cropped image region
86
+
87
+ Returns:
88
+ Tuple of (predicted_label, confidence)
89
+ """
90
+ if self.custom_model is None:
91
+ return None, 0.0
92
+
93
+ try:
94
+ # Preprocess image
95
+ if object_crop.size == 0:
96
+ return None, 0.0
97
+
98
+ # Convert BGR to RGB
99
+ if len(object_crop.shape) == 3 and object_crop.shape[2] == 3:
100
+ object_crop = cv2.cvtColor(object_crop, cv2.COLOR_BGR2RGB)
101
+
102
+ # Preprocess for model
103
+ input_tensor = self.preprocess(object_crop).unsqueeze(0).to(self.device)
104
+
105
+ # Inference
106
+ with torch.no_grad():
107
+ outputs = self.custom_model(input_tensor)
108
+ probabilities = torch.softmax(outputs, dim=1)
109
+ confidence, predicted = torch.max(probabilities, 1)
110
+
111
+ predicted_idx = predicted.item()
112
+ confidence_score = confidence.item()
113
+
114
+ # Convert to label
115
+ predicted_label = self.class_info['idx_to_label'][predicted_idx]
116
+
117
+ return predicted_label, confidence_score
118
+
119
+ except Exception as e:
120
+ logger.error(f"Custom classification failed: {e}")
121
+ return None, 0.0
122
+
123
+ def should_override_yolo(self, yolo_label: str, yolo_confidence: float,
124
+ custom_label: str, custom_confidence: float) -> bool:
125
+ """
126
+ Decide whether to override YOLO prediction with custom model
127
+
128
+ Args:
129
+ yolo_label: YOLO predicted label
130
+ yolo_confidence: YOLO confidence
131
+ custom_label: Custom model predicted label
132
+ custom_confidence: Custom model confidence
133
+
134
+ Returns:
135
+ True if the custom model's prediction should be used
136
+ """
137
+ # Don't override if custom model not confident enough
138
+ if custom_confidence < 0.7:
139
+ return False
140
+
141
+ # Always override if YOLO has low confidence and custom has high
142
+ if yolo_confidence < 0.5 and custom_confidence > 0.8:
143
+ return True
144
+
145
+ # Override if custom model is significantly more confident
146
+ if custom_confidence > yolo_confidence + 0.2:
147
+ return True
148
+
149
+ # Override if we have training data for this custom class
150
+ if custom_label in self.class_info.get('valid_classes', []):
151
+ return True
152
+
153
+ return False
154
+
155
+ def detect_with_custom_model(self, image: np.ndarray, confidence_threshold: float = 0.5) -> Tuple[np.ndarray, List[str], List[Dict]]:
156
+ """
157
+ Two-stage detection: YOLO + Custom Classification
158
+
159
+ Args:
160
+ image: Input image
161
+ confidence_threshold: YOLO confidence threshold
162
+
163
+ Returns:
164
+ Tuple of (annotated_image, detected_objects, detailed_attributes)
165
+ """
166
+ # Stage 1: YOLO Detection
167
+ try:
168
+ annotated_img, detected_objects, detailed_attributes = detect_objects_enhanced(
169
+ image, confidence_threshold
170
+ )
171
+ except Exception:
172
+ # Fallback to basic YOLO
173
+ from .yolo import detect_objects
174
+ annotated_img, detected_objects = detect_objects(image)
175
+ detailed_attributes = []
176
+
177
+ # Stage 2: Custom Classification (if model available)
178
+ if self.custom_model is None or not detailed_attributes:
179
+ return annotated_img, detected_objects, detailed_attributes
180
+
181
+ # Process each detection with custom model
182
+ enhanced_attributes = []
183
+ enhanced_objects = []
184
+
185
+ for attr in detailed_attributes:
186
+ yolo_label = attr['label']
187
+ yolo_confidence = float(attr['confidence'].rstrip('%')) / 100.0
188
+ bbox = attr.get('bbox', [0, 0, 100, 100])
189
+
190
+ # Extract object region
191
+ x1, y1, x2, y2 = [int(coord) for coord in bbox]
192
+ object_crop = image[max(0, y1):min(image.shape[0], y2),
193
+ max(0, x1):min(image.shape[1], x2)]
194
+
195
+ # Classify with custom model
196
+ custom_label, custom_confidence = self.classify_object(object_crop)
197
+
198
+ # Decide which prediction to use
199
+ if custom_label and self.should_override_yolo(yolo_label, yolo_confidence,
200
+ custom_label, custom_confidence):
201
+ # Use custom model prediction
202
+ final_label = custom_label
203
+ final_confidence = custom_confidence
204
+ attr['prediction_source'] = 'custom_model'
205
+ attr['original_yolo'] = {'label': yolo_label, 'confidence': yolo_confidence}
206
+ else:
207
+ # Use YOLO prediction
208
+ final_label = yolo_label
209
+ final_confidence = yolo_confidence
210
+ attr['prediction_source'] = 'yolo'
211
+ if custom_label:
212
+ attr['custom_alternative'] = {'label': custom_label, 'confidence': custom_confidence}
213
+
214
+ # Update attributes
215
+ attr['label'] = final_label
216
+ attr['confidence'] = f"{final_confidence:.2%}"
217
+
218
+ enhanced_attributes.append(attr)
219
+ enhanced_objects.append(final_label)
220
+
221
+ # Update annotated image if we made changes
222
+ if any(attr.get('prediction_source') == 'custom_model' for attr in enhanced_attributes):
223
+ # Re-annotate image with updated predictions
224
+ annotated_img = self.annotate_image_with_predictions(image, enhanced_attributes)
225
+
226
+ return annotated_img, enhanced_objects, enhanced_attributes
227
+
228
+ def annotate_image_with_predictions(self, image: np.ndarray, attributes: List[Dict]) -> np.ndarray:
229
+ """
230
+ Annotate image with updated predictions
231
+
232
+ Args:
233
+ image: Original image
234
+ attributes: Detection attributes with updated labels
235
+
236
+ Returns:
237
+ Annotated image
238
+ """
239
+ annotated = image.copy()
240
+
241
+ for attr in attributes:
242
+ bbox = attr.get('bbox', [0, 0, 100, 100])
243
+ label = attr['label']
244
+ confidence = attr['confidence']
245
+ source = attr.get('prediction_source', 'yolo')
246
+
247
+ x1, y1, x2, y2 = [int(coord) for coord in bbox]
248
+
249
+ # Choose color based on source
250
+ if source == 'custom_model':
251
+ color = (0, 255, 0) # Green for custom model
252
+ label_text = f"{label} {confidence} (Custom)"
253
+ else:
254
+ color = (255, 0, 0) # Red for YOLO
255
+ label_text = f"{label} {confidence}"
256
+
257
+ # Draw bounding box
258
+ cv2.rectangle(annotated, (x1, y1), (x2, y2), color, 2)
259
+
260
+ # Draw label
261
+ cv2.putText(annotated, label_text, (x1, y1-10),
262
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
263
+
264
+ return annotated
265
+
266
+ def get_model_info(self) -> Dict:
267
+ """Get information about loaded models"""
268
+ info = {
269
+ 'yolo_model': 'YOLOv8m',
270
+ 'custom_model_loaded': self.custom_model is not None,
271
+ 'device': self.device
272
+ }
273
+
274
+ if self.custom_model is not None and self.class_info is not None:
275
+ info.update({
276
+ 'custom_classes': list(self.class_info['idx_to_label'].values()),
277
+ 'num_custom_classes': self.class_info['num_classes'],
278
+ 'training_samples': self.class_info.get('train_samples', 0),
279
+ 'validation_samples': self.class_info.get('val_samples', 0)
280
+ })
281
+
282
+ return info
283
+
284
+ # Global two-stage inference instance
285
  two_stage_inference = TwoStageInference()
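A hedged usage sketch for the pipeline (the image path and the presence of a trained custom model are assumptions), followed by worked examples of how `should_override_yolo` resolves disagreements:

```python
import cv2
from backend.two_stage_inference import two_stage_inference

image = cv2.imread("sample.jpg")  # assumed local image; cv2.imread returns None if missing
annotated, labels, attrs = two_stage_inference.detect_with_custom_model(
    image, confidence_threshold=0.5
)
print(two_stage_inference.get_model_info())

# Worked examples of the override rule (the custom score must clear the 0.7 floor first):
#   YOLO "cup" @ 0.40 vs custom "my_mug" @ 0.85 -> override (YOLO < 0.5, custom > 0.8)
#   YOLO "cup" @ 0.60 vs custom "my_mug" @ 0.85 -> override (margin 0.25 > 0.2)
#   YOLO "cup" @ 0.75 vs custom "my_mug" @ 0.65 -> keep YOLO (custom < 0.7)
```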
backend/yolo.py CHANGED
@@ -1,34 +1,34 @@
1
- from ultralytics import YOLO # type: ignore
2
- import cv2
3
- import numpy as np
4
-
5
- # Load a pre-trained YOLOv8 model (nano version = small & fast)
6
- model = YOLO("yolov8n.pt")
7
-
8
- def detect_objects(image):
9
- """
10
- Run YOLO on the input image.
11
- Returns:
12
- - annotated image with bounding boxes
13
- - list of detected object names
14
- """
15
- # Handle different image formats and channel counts
16
- if isinstance(image, np.ndarray):
17
- # If image has 4 channels (RGBA), convert to RGB
18
- if image.ndim == 3 and image.shape[-1] == 4:
19
- image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
20
- # If image has 1 channel (grayscale), convert to RGB
21
- elif len(image.shape) == 2 or image.shape[-1] == 1:
22
- image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
23
-
24
- results = model(image)
25
- annotated_img = results[0].plot()
26
-
27
- # Extract detected object names
28
- detected_objects = []
29
- for box in results[0].boxes:
30
- cls_id = int(box.cls[0].item()) # class ID
31
- label = results[0].names[cls_id] # class name
32
- detected_objects.append(label)
33
-
34
- return annotated_img, detected_objects
 
1
+ from ultralytics import YOLO # type: ignore
2
+ import cv2
3
+ import numpy as np
4
+
5
+ # Load a pre-trained YOLOv8 model (nano version = small & fast)
6
+ model = YOLO("yolov8n.pt")
7
+
8
+ def detect_objects(image):
9
+ """
10
+ Run YOLO on the input image.
11
+ Returns:
12
+ - annotated image with bounding boxes
13
+ - list of detected object names
14
+ """
15
+ # Handle different image formats and channel counts
16
+ if isinstance(image, np.ndarray):
17
+ # If image has 4 channels (RGBA), convert to RGB
18
+ if image.shape[-1] == 4:
19
+ image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
20
+ # If image has 1 channel (grayscale), convert to RGB
21
+ elif len(image.shape) == 2 or image.shape[-1] == 1:
22
+ image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
23
+
24
+ results = model(image)
25
+ annotated_img = results[0].plot()
26
+
27
+ # Extract detected object names
28
+ detected_objects = []
29
+ for box in results[0].boxes:
30
+ cls_id = int(box.cls[0].item()) # class ID
31
+ label = results[0].names[cls_id] # class name
32
+ detected_objects.append(label)
33
+
34
+ return annotated_img, detected_objects
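A minimal usage sketch for the baseline detector (the image path is an assumption; RGBA and grayscale inputs are also handled, as shown above):

```python
import cv2
from backend.yolo import detect_objects

image = cv2.imread("street.jpg")                # BGR ndarray; assumed path
annotated, labels = detect_objects(image)
print(labels)                                   # e.g. ["person", "car", "car"]
cv2.imwrite("street_annotated.jpg", annotated)
```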
backend/yolo_enhanced.py CHANGED
@@ -1,231 +1,231 @@
1
- """
2
- Enhanced YOLO detection with improved accuracy, color detection, and detailed attributes
3
- """
4
- from ultralytics import YOLO # type: ignore
5
- import cv2 # type: ignore
6
- import numpy as np # type: ignore
7
- from collections import Counter
8
- import webcolors # type: ignore
9
- # from sklearn.cluster import KMeans # type: ignore # Temporarily disabled due to numpy compatibility
10
- import torch # type: ignore
11
-
12
- # Load a more accurate YOLO model
13
- # For better accuracy, use yolov8m.pt or yolov8l.pt instead of yolov8n.pt
14
- model_size = 'yolov8m.pt' # Medium model for better accuracy vs speed balance
15
- model = YOLO(model_size)
16
-
17
- # Set higher confidence threshold for better accuracy
18
- CONFIDENCE_THRESHOLD = 0.5 # Increase this for fewer but more accurate detections
19
- NMS_THRESHOLD = 0.45 # Non-maximum suppression threshold
20
-
21
- def get_dominant_colors(image, n_colors=3):
22
- """
23
- Extract dominant colors from an image region using simple averaging
24
- (K-means temporarily disabled due to numpy compatibility)
25
- """
26
- try:
27
- # Simple color detection without sklearn
28
- # Get average color
29
- avg_color = np.mean(image.reshape(-1, 3), axis=0).astype(int)
30
-
31
- # Get corners for variety
32
- h, w = image.shape[:2]
33
- corners = [
34
- image[0, 0], # Top-left
35
- image[0, w-1] if w > 0 else image[0, 0], # Top-right
36
- image[h-1, 0] if h > 0 else image[0, 0], # Bottom-left
37
- image[h//2, w//2] if h > 0 and w > 0 else image[0, 0] # Center
38
- ]
39
-
40
- color_names = []
41
- # Add average color
42
- try:
43
- color_names.append(get_color_name(avg_color))
44
- except:
45
- color_names.append(f"RGB({avg_color[0]},{avg_color[1]},{avg_color[2]})")
46
-
47
- # Add dominant corner color if different
48
- for corner in corners[:n_colors-1]:
49
- try:
50
- name = get_color_name(corner)
51
- if name not in color_names:
52
- color_names.append(name)
53
- if len(color_names) >= n_colors:
54
- break
55
- except Exception:
56
- pass
57
-
58
- return color_names if color_names else ["Unknown"]
59
- except Exception:
60
- return ["Unknown"]
61
-
62
- def get_color_name(rgb_color):
63
- """
64
- Convert RGB values to a human-readable color name
65
- """
66
- min_colors = {}
67
- for key, name in webcolors.CSS3_HEX_TO_NAMES.items():
68
- r_c, g_c, b_c = webcolors.hex_to_rgb(key)
69
- rd = (r_c - rgb_color[0]) ** 2
70
- gd = (g_c - rgb_color[1]) ** 2
71
- bd = (b_c - rgb_color[2]) ** 2
72
- min_colors[(rd + gd + bd)] = name
73
- return min_colors[min(min_colors.keys())]
74
-
75
- def analyze_object_attributes(image, box, label):
76
- """
77
- Analyze detailed attributes of detected objects
78
- """
79
- x1, y1, x2, y2 = box
80
- object_region = image[int(y1):int(y2), int(x1):int(x2)]
81
-
82
- attributes = {
83
- 'label': label,
84
- 'position': get_position_description(x1, y1, x2, y2, image.shape),
85
- 'size': get_size_description(x2-x1, y2-y1, image.shape),
86
- 'colors': get_dominant_colors(object_region, n_colors=2),
87
- 'confidence': None, # Will be set from detection
88
- 'bbox': [float(x1), float(y1), float(x2), float(y2)] # Add bounding box coordinates
89
- }
90
-
91
- return attributes
92
-
93
- def get_position_description(x1, y1, x2, y2, image_shape):
94
- """
95
- Describe object position in human terms
96
- """
97
- h, w = image_shape[:2]
98
- center_x = (x1 + x2) / 2
99
- center_y = (y1 + y2) / 2
100
-
101
- # Horizontal position
102
- if center_x < w / 3:
103
- h_pos = "left"
104
- elif center_x > 2 * w / 3:
105
- h_pos = "right"
106
- else:
107
- h_pos = "center"
108
-
109
- # Vertical position
110
- if center_y < h / 3:
111
- v_pos = "top"
112
- elif center_y > 2 * h / 3:
113
- v_pos = "bottom"
114
- else:
115
- v_pos = "middle"
116
-
117
- if h_pos == "center" and v_pos == "middle":
118
- return "center"
119
- elif v_pos == "middle":
120
- return h_pos
121
- elif h_pos == "center":
122
- return v_pos
123
- else:
124
- return f"{v_pos}-{h_pos}"
125
-
126
- def get_size_description(width, height, image_shape):
127
- """
128
- Describe object size relative to image
129
- """
130
- img_area = image_shape[0] * image_shape[1]
131
- obj_area = width * height
132
- ratio = obj_area / img_area
133
-
134
- if ratio > 0.5:
135
- return "very large"
136
- elif ratio > 0.25:
137
- return "large"
138
- elif ratio > 0.1:
139
- return "medium"
140
- elif ratio > 0.05:
141
- return "small"
142
- else:
143
- return "tiny"
144
-
145
- def detect_objects_enhanced(image, confidence_threshold=CONFIDENCE_THRESHOLD):
146
- """
147
- Enhanced YOLO detection with improved accuracy and detailed attributes
148
- Returns:
149
- - annotated image with bounding boxes
150
- - list of detected object names
151
- - detailed attributes for each detection
152
- """
153
- # Handle different image formats
154
- if isinstance(image, np.ndarray):
155
- if image.ndim == 3 and image.shape[-1] == 4:
156
- image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
157
- elif len(image.shape) == 2 or image.shape[-1] == 1:
158
- image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
159
-
160
- # Run YOLO with custom parameters for better accuracy
161
- results = model(
162
- image,
163
- conf=confidence_threshold, # Confidence threshold
164
- iou=NMS_THRESHOLD, # NMS IoU threshold
165
- imgsz=640, # Image size (can increase for better accuracy)
166
- device='cuda' if torch.cuda.is_available() else 'cpu'
167
- )
168
-
169
- # Get annotated image
170
- annotated_img = results[0].plot(
171
- conf=True, # Show confidence scores
172
- line_width=2,
173
- font_size=10
174
- )
175
-
176
- # Extract detailed information
177
- detected_objects = []
178
- detailed_attributes = []
179
-
180
- for box in results[0].boxes:
181
- if box.conf[0] >= confidence_threshold: # Double-check confidence
182
- cls_id = int(box.cls[0].item())
183
- label = results[0].names[cls_id]
184
- confidence = float(box.conf[0].item())
185
-
186
- # Get box coordinates
187
- xyxy = box.xyxy[0].tolist()
188
-
189
- # Analyze attributes
190
- attributes = analyze_object_attributes(image, xyxy, label)
191
- attributes['confidence'] = f"{confidence:.2%}"
192
-
193
- detected_objects.append(label)
194
- detailed_attributes.append(attributes)
195
-
196
- return annotated_img, detected_objects, detailed_attributes
197
-
198
- def get_intelligence_report(detailed_attributes):
199
- """
200
- Generate an intelligent report about detected objects
201
- """
202
- if not detailed_attributes:
203
- return "No objects detected in the image."
204
-
205
- report = []
206
- report.append(f"Detected {len(detailed_attributes)} object(s):")
207
-
208
- for attr in detailed_attributes:
209
- colors_str = " and ".join(attr['colors'][:2]) if attr['colors'] else "unknown colors"
210
- report.append(
211
- f"- A {attr['size']} {colors_str} {attr['label']} "
212
- f"in the {attr['position']} of the image "
213
- f"(confidence: {attr['confidence']})"
214
- )
215
-
216
- # Add summary statistics
217
- object_types = Counter([attr['label'] for attr in detailed_attributes])
218
- if len(object_types) > 1:
219
- report.append("\nSummary:")
220
- for obj_type, count in object_types.most_common():
221
- report.append(f" • {count} {obj_type}(s)")
222
-
223
- return "\n".join(report)
224
-
225
- # Backward compatibility wrapper
226
- def detect_objects(image):
227
- """
228
- Wrapper for backward compatibility with original function
229
- """
230
- annotated_img, detected_objects, _ = detect_objects_enhanced(image)
231
  return annotated_img, detected_objects
 
1
+ """
2
+ Enhanced YOLO detection with improved accuracy, color detection, and detailed attributes
3
+ """
4
+ from ultralytics import YOLO # type: ignore
5
+ import cv2 # type: ignore
6
+ import numpy as np # type: ignore
7
+ from collections import Counter
8
+ import webcolors # type: ignore
9
+ # from sklearn.cluster import KMeans # type: ignore # Temporarily disabled due to numpy compatibility
10
+ import torch # type: ignore
11
+
12
+ # Load a more accurate YOLO model
13
+ # For better accuracy, use yolov8m.pt or yolov8l.pt instead of yolov8n.pt
14
+ model_size = 'yolov8m.pt' # Medium model for better accuracy vs speed balance
15
+ model = YOLO(model_size)
16
+
17
+ # Set higher confidence threshold for better accuracy
18
+ CONFIDENCE_THRESHOLD = 0.5 # Increase this for fewer but more accurate detections
19
+ NMS_THRESHOLD = 0.45 # Non-maximum suppression threshold
20
+
21
+ def get_dominant_colors(image, n_colors=3):
22
+ """
23
+ Extract dominant colors from an image region using simple averaging
24
+ (K-means temporarily disabled due to numpy compatibility)
25
+ """
26
+ try:
27
+ # Simple color detection without sklearn
28
+ # Get average color
29
+ avg_color = np.mean(image.reshape(-1, 3), axis=0).astype(int)
30
+
31
+ # Get corners for variety
32
+ h, w = image.shape[:2]
33
+ corners = [
34
+ image[0, 0], # Top-left
35
+ image[0, w-1] if w > 0 else image[0, 0], # Top-right
36
+ image[h-1, 0] if h > 0 else image[0, 0], # Bottom-left
37
+ image[h//2, w//2] if h > 0 and w > 0 else image[0, 0] # Center
38
+ ]
39
+
40
+ color_names = []
41
+ # Add average color
42
+ try:
43
+ color_names.append(get_color_name(avg_color))
44
+ except Exception:
45
+ color_names.append(f"RGB({avg_color[0]},{avg_color[1]},{avg_color[2]})")
46
+
47
+ # Add dominant corner color if different
48
+ for corner in corners[:n_colors-1]:
49
+ try:
50
+ name = get_color_name(corner)
51
+ if name not in color_names:
52
+ color_names.append(name)
53
+ if len(color_names) >= n_colors:
54
+ break
55
+ except Exception:
56
+ pass
57
+
58
+ return color_names if color_names else ["Unknown"]
59
+ except Exception:
60
+ return ["Unknown"]
61
+
62
+ def get_color_name(rgb_color):
63
+ """
64
+ Convert RGB values to a human-readable color name
65
+ """
66
+ min_colors = {}
67
+ for key, name in webcolors.CSS3_HEX_TO_NAMES.items():
68
+ r_c, g_c, b_c = webcolors.hex_to_rgb(key)
69
+ rd = (r_c - rgb_color[0]) ** 2
70
+ gd = (g_c - rgb_color[1]) ** 2
71
+ bd = (b_c - rgb_color[2]) ** 2
72
+ min_colors[(rd + gd + bd)] = name
73
+ return min_colors[min(min_colors.keys())]
74
+
75
+ def analyze_object_attributes(image, box, label):
76
+ """
77
+ Analyze detailed attributes of detected objects
78
+ """
79
+ x1, y1, x2, y2 = box
80
+ object_region = image[int(y1):int(y2), int(x1):int(x2)]
81
+
82
+ attributes = {
83
+ 'label': label,
84
+ 'position': get_position_description(x1, y1, x2, y2, image.shape),
85
+ 'size': get_size_description(x2-x1, y2-y1, image.shape),
86
+ 'colors': get_dominant_colors(object_region, n_colors=2),
87
+ 'confidence': None, # Will be set from detection
88
+ 'bbox': [float(x1), float(y1), float(x2), float(y2)] # Add bounding box coordinates
89
+ }
90
+
91
+ return attributes
92
+
93
+ def get_position_description(x1, y1, x2, y2, image_shape):
94
+ """
95
+ Describe object position in human terms
96
+ """
97
+ h, w = image_shape[:2]
98
+ center_x = (x1 + x2) / 2
99
+ center_y = (y1 + y2) / 2
100
+
101
+ # Horizontal position
102
+ if center_x < w / 3:
103
+ h_pos = "left"
104
+ elif center_x > 2 * w / 3:
105
+ h_pos = "right"
106
+ else:
107
+ h_pos = "center"
108
+
109
+ # Vertical position
110
+ if center_y < h / 3:
111
+ v_pos = "top"
112
+ elif center_y > 2 * h / 3:
113
+ v_pos = "bottom"
114
+ else:
115
+ v_pos = "middle"
116
+
117
+ if h_pos == "center" and v_pos == "middle":
118
+ return "center"
119
+ elif v_pos == "middle":
120
+ return h_pos
121
+ elif h_pos == "center":
122
+ return v_pos
123
+ else:
124
+ return f"{v_pos}-{h_pos}"
125
+
126
+ def get_size_description(width, height, image_shape):
127
+ """
128
+ Describe object size relative to image
129
+ """
130
+ img_area = image_shape[0] * image_shape[1]
131
+ obj_area = width * height
132
+ ratio = obj_area / img_area
133
+
134
+ if ratio > 0.5:
135
+ return "very large"
136
+ elif ratio > 0.25:
137
+ return "large"
138
+ elif ratio > 0.1:
139
+ return "medium"
140
+ elif ratio > 0.05:
141
+ return "small"
142
+ else:
143
+ return "tiny"
144
+
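A quick worked example of the two heuristics above, using an assumed 640×480 frame and detection box:

```python
w, h = 640, 480                      # assumed frame size
x1, y1, x2, y2 = 100, 50, 260, 290   # assumed detection box
ratio = ((x2 - x1) * (y2 - y1)) / (w * h)
print(f"{ratio:.3f}")                # 0.125 -> "medium" per get_size_description
cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
print(cx < w / 3)                    # True  -> h_pos = "left"
print(h / 3 <= cy <= 2 * h / 3)      # True  -> v_pos = "middle", so position is "left"
```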
145
+ def detect_objects_enhanced(image, confidence_threshold=CONFIDENCE_THRESHOLD):
146
+ """
147
+ Enhanced YOLO detection with improved accuracy and detailed attributes
148
+ Returns:
149
+ - annotated image with bounding boxes
150
+ - list of detected object names
151
+ - detailed attributes for each detection
152
+ """
153
+ # Handle different image formats
154
+ if isinstance(image, np.ndarray):
155
+ if image.ndim == 3 and image.shape[-1] == 4:
156
+ image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
157
+ elif len(image.shape) == 2 or image.shape[-1] == 1:
158
+ image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
159
+
160
+ # Run YOLO with custom parameters for better accuracy
161
+ results = model(
162
+ image,
163
+ conf=confidence_threshold, # Confidence threshold
164
+ iou=NMS_THRESHOLD, # NMS IoU threshold
165
+ imgsz=640, # Image size (can increase for better accuracy)
166
+ device='cuda' if torch.cuda.is_available() else 'cpu'
167
+ )
168
+
169
+ # Get annotated image
170
+ annotated_img = results[0].plot(
171
+ conf=True, # Show confidence scores
172
+ line_width=2,
173
+ font_size=10
174
+ )
175
+
176
+ # Extract detailed information
177
+ detected_objects = []
178
+ detailed_attributes = []
179
+
180
+ for box in results[0].boxes:
181
+ if box.conf[0] >= confidence_threshold: # Double-check confidence
182
+ cls_id = int(box.cls[0].item())
183
+ label = results[0].names[cls_id]
184
+ confidence = float(box.conf[0].item())
185
+
186
+ # Get box coordinates
187
+ xyxy = box.xyxy[0].tolist()
188
+
189
+ # Analyze attributes
190
+ attributes = analyze_object_attributes(image, xyxy, label)
191
+ attributes['confidence'] = f"{confidence:.2%}"
192
+
193
+ detected_objects.append(label)
194
+ detailed_attributes.append(attributes)
195
+
196
+ return annotated_img, detected_objects, detailed_attributes
197
+
198
+ def get_intelligence_report(detailed_attributes):
199
+ """
200
+ Generate an intelligent report about detected objects
201
+ """
202
+ if not detailed_attributes:
203
+ return "No objects detected in the image."
204
+
205
+ report = []
206
+ report.append(f"Detected {len(detailed_attributes)} object(s):")
207
+
208
+ for attr in detailed_attributes:
209
+ colors_str = " and ".join(attr['colors'][:2]) if attr['colors'] else "unknown colors"
210
+ report.append(
211
+ f"- A {attr['size']} {colors_str} {attr['label']} "
212
+ f"in the {attr['position']} of the image "
213
+ f"(confidence: {attr['confidence']})"
214
+ )
215
+
216
+ # Add summary statistics
217
+ object_types = Counter([attr['label'] for attr in detailed_attributes])
218
+ if len(object_types) > 1:
219
+ report.append("\nSummary:")
220
+ for obj_type, count in object_types.most_common():
221
+ report.append(f" • {count} {obj_type}(s)")
222
+
223
+ return "\n".join(report)
224
+
225
+ # Backward compatibility wrapper
226
+ def detect_objects(image):
227
+ """
228
+ Wrapper for backward compatibility with original function
229
+ """
230
+ annotated_img, detected_objects, _ = detect_objects_enhanced(image)
231
  return annotated_img, detected_objects
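A usage sketch for the enhanced detector (the image path and the sample output are assumptions). One compatibility caveat: recent webcolors releases (24.x) dropped the `CSS3_HEX_TO_NAMES` constant, so with the unpinned `webcolors>=1.13` requirement `get_color_name` may raise and the color fields will fall back to raw RGB strings; pinning `webcolors<24` is one mitigation.

```python
import cv2
from backend.yolo_enhanced import detect_objects_enhanced, get_intelligence_report

image = cv2.imread("desk.jpg")  # assumed local image
annotated, labels, attrs = detect_objects_enhanced(image, confidence_threshold=0.5)
print(get_intelligence_report(attrs))
# Hypothetical output:
# Detected 2 object(s):
# - A medium black laptop in the center of the image (confidence: 91.32%)
# - A small white cup in the bottom-left of the image (confidence: 78.04%)
```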
packages.txt CHANGED
@@ -1,5 +1,5 @@
1
- ffmpeg
2
- libsm6
3
- libxext6
4
- libxrender-dev
5
  libglib2.0-0
 
1
+ ffmpeg
2
+ libsm6
3
+ libxext6
4
+ libxrender-dev
5
  libglib2.0-0
requirements.txt CHANGED
@@ -1,16 +1,13 @@
1
- streamlit>=1.28.0
2
- ultralytics>=8.0.0
3
- openai>=1.0.0
4
- opencv-python-headless>=4.8.0
5
- pillow>=10.0.0
6
- numpy>=1.24.0,<2.0.0
7
- torch>=2.0.0
8
- torchvision>=0.15.0
9
- python-dotenv>=1.0.0
10
- plotly>=5.17.0
11
- kaleido>=0.2.1
12
- requests>=2.31.0
13
- pandas>=1.5.0,<2.1.0
14
- webcolors>=1.13
15
- face-recognition>=1.3.0
16
- dlib>=19.24.0
 
1
+ streamlit>=1.28.0
2
+ ultralytics>=8.0.0
3
+ openai>=1.0.0
4
+ opencv-python-headless>=4.8.0
5
+ pillow>=10.0.0
6
+ numpy>=1.24.0,<2.0.0
7
+ torch>=2.0.0,<2.5.0
8
+ torchvision>=0.15.0,<0.20.0
9
+ python-dotenv>=1.0.0
10
+ plotly>=5.17.0
11
+ requests>=2.31.0
12
+ pandas>=1.5.0,<2.1.0
13
+ webcolors>=1.13
 
 
 
requirements_lite.txt ADDED
@@ -0,0 +1,14 @@
1
+ streamlit>=1.28.0
2
+ ultralytics>=8.0.0
3
+ openai>=1.0.0
4
+ opencv-python-headless>=4.8.0
5
+ pillow>=10.0.0
6
+ numpy>=1.24.0,<2.0.0
7
+ torch>=2.0.0
8
+ torchvision>=0.15.0
9
+ python-dotenv>=1.0.0
10
+ plotly>=5.17.0
11
+ kaleido>=0.2.1
12
+ requests>=2.31.0
13
+ pandas>=1.5.0,<2.1.0
14
+ webcolors>=1.13