import base64
import hashlib
import logging
import random
import re
import threading
import time
from collections import deque
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO

import cv2
import numpy as np
import torch
from torch.cuda.amp import autocast
from flask import Flask, render_template, request
from flask_socketio import SocketIO, emit
from PIL import Image, ImageEnhance
from transformers import BlipProcessor, BlipForConditionalGeneration

# ---- 1. ENHANCED SETUP ----

# Suppress excessive logging from libraries
logging.getLogger('engineio').setLevel(logging.ERROR)
logging.getLogger('socketio').setLevel(logging.ERROR)

# --- Enhanced Configuration ---
FRAME_SKIP = 3  # Adaptive frame skipping
IMAGE_SIZE = 224  # Optimized size for BLIP
BUFFER_SIZE = 5  # Smart buffering
MIN_CONFIDENCE_DIFF = 0.03
MAX_WORKERS = 6  # Increased thread pool
CACHE_SIZE = 500  # Larger cache with LRU
BATCH_SIZE = 4  # Batch processing capability

# Advanced performance settings
ADAPTIVE_QUALITY = True
MIN_PROCESSING_INTERVAL = 0.1  # Minimum time between processing
SCENE_CHANGE_THRESHOLD = 0.15  # For scene change detection
CAPTION_HISTORY_SIZE = 10  # Keep caption history for context

# --- Flask & SocketIO App Initialization ---
app = Flask(__name__, template_folder='../templates', static_folder='../static')
app.config['SECRET_KEY'] = 'your-very-secret-key!'
socketio = SocketIO(app, async_mode='threading', logger=False, engineio_logger=False, 
                   cors_allowed_origins="*", ping_timeout=60, ping_interval=25)

# --- Enhanced AI Model Setup ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Thread pool for background caption work
executor = ThreadPoolExecutor(max_workers=MAX_WORKERS, thread_name_prefix="caption_worker")

# Load BLIP model with advanced optimizations
try:
    print("Loading BLIP model with optimizations...")
    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
    model = model.to(device)
    model.eval()
    
    # CUDA optimizations
    if device.type == 'cuda':
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False
        # Note: torch.jit.script() is not applied here because generate() is
        # not available on a scripted HuggingFace module; mixed precision
        # (autocast) provides the main inference speedup instead.
        USE_AMP = True
        print("CUDA optimizations enabled (cuDNN benchmark + mixed precision)")
    else:
        USE_AMP = False
    
    # Warm up the model
    dummy_image = Image.new('RGB', (IMAGE_SIZE, IMAGE_SIZE), color='black')
    dummy_inputs = processor(dummy_image, return_tensors="pt").to(device)
    with torch.no_grad():
        _ = model.generate(**dummy_inputs, max_length=10)
    print("Model warmed up successfully!")
    
except Exception as e:
    print(f"Error loading BLIP model: {e}")
    exit()

# --- Advanced Caching System ---
class LRUCache:
    def __init__(self, max_size):
        self.max_size = max_size
        self.cache = {}
        self.access_order = deque()
        self.lock = threading.Lock()
    
    def get(self, key):
        with self.lock:
            if key in self.cache:
                # Move to end (most recently used)
                self.access_order.remove(key)
                self.access_order.append(key)
                return self.cache[key]
            return None
    
    def put(self, key, value):
        with self.lock:
            if key in self.cache:
                self.access_order.remove(key)
            elif len(self.cache) >= self.max_size:
                # Remove least recently used
                oldest = self.access_order.popleft()
                del self.cache[oldest]
            
            self.cache[key] = value
            self.access_order.append(key)
    
    def clear(self):
        with self.lock:
            self.cache.clear()
            self.access_order.clear()
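
# Illustrative LRUCache usage (a sketch, not executed at import time):
#   cache = LRUCache(2)
#   cache.put('a', 1)
#   cache.put('b', 2)
#   cache.get('a')      # 'a' becomes the most recently used entry
#   cache.put('c', 3)   # evicts 'b', the least recently used
#   cache.get('b')      # -> None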

# --- Advanced Frame Processing ---
frame_counters = {}
processing_locks = {}
caption_buffers = {}
last_captions = {}
processing_times = {}
caption_history = {}
last_processed_time = {}
scene_features = {}  # For scene change detection

# Enhanced caching
caption_cache = LRUCache(CACHE_SIZE)
batch_queue = {}

# --- Smart Performance Monitor ---
class AdvancedPerformanceMonitor:
    def __init__(self):
        self.metrics = {
            'total_frames': 0,
            'processed_frames': 0,
            'cache_hits': 0,
            'cache_misses': 0,
            'batch_processed': 0,
            'scene_changes': 0,
            'processing_times': deque(maxlen=100),
            'start_time': time.time()
        }
        self.lock = threading.Lock()
    
    def log_frame(self, processing_time=None, cache_hit=False, batch_size=1, scene_change=False):
        with self.lock:
            self.metrics['total_frames'] += 1
            if processing_time is not None:
                self.metrics['processed_frames'] += 1
                self.metrics['processing_times'].append(processing_time)
                if batch_size > 1:
                    self.metrics['batch_processed'] += batch_size
            
            if cache_hit:
                self.metrics['cache_hits'] += 1
            else:
                self.metrics['cache_misses'] += 1
            
            if scene_change:
                self.metrics['scene_changes'] += 1
    
    def get_stats(self):
        with self.lock:
            if not self.metrics['processing_times']:
                return {"avg_time": 0, "cache_hit_rate": 0, "fps": 0, "efficiency": 0}
            
            total_time = time.time() - self.metrics['start_time']
            avg_processing_time = np.mean(self.metrics['processing_times'])
            cache_hit_rate = self.metrics['cache_hits'] / max(1, self.metrics['total_frames'])
            processing_fps = self.metrics['processed_frames'] / max(1e-6, total_time)
            efficiency = self.metrics['processed_frames'] / max(1, self.metrics['total_frames'])
            
            return {
                "avg_time": avg_processing_time,
                "cache_hit_rate": cache_hit_rate,
                "processing_fps": processing_fps,
                "efficiency": efficiency,
                "total_frames": self.metrics['total_frames'],
                "scene_changes": self.metrics['scene_changes'],
                "batch_efficiency": self.metrics['batch_processed'] / max(1, self.metrics['processed_frames'])
            }

perf_monitor = AdvancedPerformanceMonitor()

# --- Smart Image Preprocessing ---
def smart_preprocess_image(image, enhance_quality=True):
    """Enhanced image preprocessing with quality improvements."""
    # Convert to RGB if needed
    if image.mode != 'RGB':
        image = image.convert('RGB')
    
    if enhance_quality:
        # Enhance image quality
        # Sharpening
        enhancer = ImageEnhance.Sharpness(image)
        image = enhancer.enhance(1.2)
        
        # Contrast enhancement
        enhancer = ImageEnhance.Contrast(image)
        image = enhancer.enhance(1.1)
        
        # Color enhancement
        enhancer = ImageEnhance.Color(image)
        image = enhancer.enhance(1.05)
    
    # Smart resizing with aspect ratio preservation
    original_size = image.size
    if original_size[0] != original_size[1]:  # Non-square image
        # Crop to square from center
        min_dim = min(original_size)
        left = (original_size[0] - min_dim) // 2
        top = (original_size[1] - min_dim) // 2
        image = image.crop((left, top, left + min_dim, top + min_dim))
    
    # Resize with high-quality resampling
    image = image.resize((IMAGE_SIZE, IMAGE_SIZE), Image.LANCZOS)
    
    return image
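
# Worked example (illustrative): a 640x480 frame has min_dim = 480, so it is
# center-cropped to the box (80, 0, 560, 480) and then resized to 224x224.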

def advanced_hash_image(image):
    """Generate robust hash for image similarity detection."""
    # Create perceptual hash using multiple features
    img_small = image.resize((16, 16), Image.LANCZOS)
    img_gray = img_small.convert('L')
    
    # Get pixel values
    pixels = list(img_gray.getdata())
    
    # Create hash from average and differences
    avg = sum(pixels) / len(pixels)
    hash_bits = ''.join('1' if pixel > avg else '0' for pixel in pixels)
    
    # Additional feature: edge detection hash
    img_array = np.array(img_gray)
    edges = cv2.Canny(img_array, 50, 150)
    edge_hash = hashlib.md5(edges.tobytes()).hexdigest()[:8]
    
    return hash_bits + edge_hash

def detect_scene_change(sid, current_features):
    """Detect significant scene changes."""
    if sid not in scene_features:
        scene_features[sid] = current_features
        return True
    
    # Compare with previous features
    prev_features = scene_features[sid]
    
    # Calculate similarity (Hamming distance for hash)
    if len(current_features) == len(prev_features):
        diff_count = sum(c1 != c2 for c1, c2 in zip(current_features[:256], prev_features[:256]))
        similarity = 1 - (diff_count / 256)
        
        scene_features[sid] = current_features
        return similarity < (1 - SCENE_CHANGE_THRESHOLD)
    
    scene_features[sid] = current_features
    return True
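
# Illustrative threshold math: only the first 256 bits of the perceptual hash
# are compared, so with SCENE_CHANGE_THRESHOLD = 0.15 a scene change is
# reported once more than ~38 of those 256 bits differ (similarity < 0.85).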

# ---- 2. ENHANCED WEBSOCKET HANDLERS ----

@socketio.on('connect')
def handle_connect():
    """Enhanced client connection handler."""
    print(f"Client connected: {request.sid}")
    sid = request.sid
    
    # Initialize client data
    frame_counters[sid] = 0
    processing_locks[sid] = threading.Lock()
    caption_buffers[sid] = deque(maxlen=BUFFER_SIZE)
    last_captions[sid] = ""
    processing_times[sid] = deque(maxlen=20)
    caption_history[sid] = deque(maxlen=CAPTION_HISTORY_SIZE)
    last_processed_time[sid] = 0
    scene_features[sid] = ""
    batch_queue[sid] = []
    
    # Send initial status
    emit('status', {'connected': True, 'device': str(device)})

@socketio.on('disconnect')
def handle_disconnect():
    """Enhanced client disconnection handler."""
    print(f"Client disconnected: {request.sid}")
    cleanup_client(request.sid)

def cleanup_client(sid):
    """Enhanced client cleanup."""
    for data_dict in [frame_counters, processing_locks, caption_buffers, 
                      last_captions, processing_times, caption_history,
                      last_processed_time, scene_features, batch_queue]:
        if sid in data_dict:
            del data_dict[sid]

@socketio.on('image')
def handle_image(data_image):
    """Enhanced image handling with smart processing."""
    sid = request.sid
    
    # Initialize if not exists
    if sid not in frame_counters:
        handle_connect()
    
    frame_counters[sid] += 1
    current_time = time.time()
    
    # Adaptive frame skipping based on processing load
    skip_factor = FRAME_SKIP
    if sid in processing_times and processing_times[sid]:
        avg_time = np.mean(processing_times[sid])
        if avg_time > 0.5:  # If processing is slow, skip more frames
            skip_factor = FRAME_SKIP * 2
        elif avg_time < 0.1:  # If processing is fast, skip fewer frames
            skip_factor = max(1, FRAME_SKIP // 2)
    
    if frame_counters[sid] % skip_factor != 0:
        perf_monitor.log_frame()  # Count skipped frames
        return
    
    # Rate limiting
    if current_time - last_processed_time.get(sid, 0) < MIN_PROCESSING_INTERVAL:
        return
    
    # Check if we're already processing
    if not processing_locks[sid].acquire(blocking=False):
        return
    
    last_processed_time[sid] = current_time
    
    # Hand the frame to the worker pool; the client's lock is released there
    executor.submit(process_frame_advanced, sid, data_image)
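
# Illustrative skip behaviour: with FRAME_SKIP = 3, a client averaging more
# than 0.5 s per frame is sampled every 6th frame, while one averaging under
# 0.1 s is sampled every frame (the skip factor is clamped to at least 1).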

def process_frame_advanced(sid, data_image):
    """Advanced frame processing with multiple optimizations."""
    start_time = time.time()
    
    try:
        # Decode image
        image_data = base64.b64decode(data_image.split(',')[1])
        img = Image.open(BytesIO(image_data))
        
        # Smart preprocessing
        img = smart_preprocess_image(img, enhance_quality=ADAPTIVE_QUALITY)
        
        # Generate advanced hash
        img_hash = advanced_hash_image(img)
        
        # Scene change detection
        scene_changed = detect_scene_change(sid, img_hash)
        
        # Check cache first
        cached_caption = caption_cache.get(img_hash)
        if cached_caption and not scene_changed:
            caption = cached_caption
            cache_hit = True
        else:
            # Generate new caption
            caption = generate_caption_advanced(img)
            caption_cache.put(img_hash, caption)
            cache_hit = False
        
        # Smart caption updating with context
        if should_update_caption_advanced(sid, caption, scene_changed):
            # Add to caption history
            caption_history[sid].append({
                'caption': caption,
                'timestamp': time.time(),
                'scene_changed': scene_changed
            })
            
            last_captions[sid] = caption
            
            # Enhanced caption with context
            contextual_caption = add_context_to_caption(sid, caption)
            
            print(f"New caption for {sid}: {contextual_caption}")
            
            # Send enhanced response
            socketio.emit('caption', {
                'caption': contextual_caption,
                'raw_caption': caption,
                'timestamp': time.time(),
                'confidence': 0.95 if not cache_hit else 1.0,
                'scene_changed': scene_changed,
                'processing_time': time.time() - start_time
            }, room=sid)
        
        # Update performance metrics
        processing_time = time.time() - start_time
        processing_times[sid].append(processing_time)
        perf_monitor.log_frame(processing_time, cache_hit, scene_change=scene_changed)
        
        # Periodic performance logging
        if frame_counters[sid] % 100 == 0:
            stats = perf_monitor.get_stats()
            print(f"Client {sid}: Avg: {stats['avg_time']:.3f}s, Cache: {stats['cache_hit_rate']:.2f}, "
                  f"Efficiency: {stats['efficiency']:.2f}, Scene changes: {stats['scene_changes']}")
    
    except Exception as e:
        print(f"Error processing frame for {sid}: {e}")
        socketio.emit('caption', {
            'caption': f"Processing error: {str(e)[:50]}...",
            'timestamp': time.time(),
            'confidence': 0.0,
            'error': True
        }, room=sid)
    
    finally:
        if sid in processing_locks:
            processing_locks[sid].release()

def should_update_caption_advanced(sid, new_caption, scene_changed):
    """Advanced caption update logic with context awareness."""
    if sid not in last_captions or scene_changed:
        return True
    
    last_caption = last_captions[sid]
    
    # Always update on errors or initial state
    if not last_caption or "error" in last_caption.lower() or last_caption == "Processing...":
        return True
    
    # Check caption history for patterns
    if sid in caption_history and len(caption_history[sid]) > 1:
        recent_captions = [item['caption'] for item in list(caption_history[sid])[-3:]]
        if len(set(recent_captions)) == 1 and new_caption not in recent_captions:
            return True  # Break repetition
    
    # Enhanced semantic similarity with weighted keywords
    words_old = set(last_caption.lower().split())
    words_new = set(new_caption.lower().split())
    
    # Weighted keywords for different importance levels
    high_priority_words = {'walking', 'running', 'sitting', 'standing', 'jumping', 'dancing', 
                          'eating', 'drinking', 'driving', 'flying', 'swimming', 'climbing'}
    medium_priority_words = {'holding', 'wearing', 'looking', 'pointing', 'smiling', 'talking',
                            'reading', 'writing', 'playing', 'working', 'sleeping'}
    objects_words = {'car', 'bike', 'phone', 'book', 'cup', 'computer', 'dog', 'cat', 'bird'}
    
    # Check for high priority changes
    old_high = words_old.intersection(high_priority_words)
    new_high = words_new.intersection(high_priority_words)
    if old_high != new_high:
        return True
    
    # Check for significant object changes
    old_objects = words_old.intersection(objects_words)
    new_objects = words_new.intersection(objects_words)
    if len(old_objects.symmetric_difference(new_objects)) > 1:
        return True
    
    # Advanced similarity calculation
    intersection = words_old.intersection(words_new)
    union = words_old.union(words_new)
    
    if len(union) == 0:
        return True
    
    # Weighted similarity based on word importance
    weight_old = sum(3 if word in high_priority_words else 2 if word in medium_priority_words else 1 
                    for word in words_old)
    weight_new = sum(3 if word in high_priority_words else 2 if word in medium_priority_words else 1 
                    for word in words_new)
    weight_intersection = sum(3 if word in high_priority_words else 2 if word in medium_priority_words else 1 
                             for word in intersection)
    
    weighted_similarity = (2 * weight_intersection) / (weight_old + weight_new) if (weight_old + weight_new) > 0 else 0
    
    return weighted_similarity < 0.75
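
# Worked example (illustrative): old = "a man walking a dog" and
# new = "a man running a dog" differ in the high-priority words 'walking' and
# 'running', so the function returns True before the weighted-similarity check.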

def add_context_to_caption(sid, caption):
    """Add temporal context to captions."""
    if sid not in caption_history or len(caption_history[sid]) < 2:
        return caption
    
    recent_captions = [item['caption'] for item in list(caption_history[sid])[-3:]]
    
    # Detect action continuity
    action_words = {'walking', 'running', 'sitting', 'standing', 'eating', 'drinking'}
    current_actions = set(caption.lower().split()).intersection(action_words)
    
    if current_actions:
        for prev_caption in recent_captions[:-1]:
            prev_actions = set(prev_caption.lower().split()).intersection(action_words)
            if current_actions == prev_actions:
                return f"{caption} (continuing)"
    
    return caption

def generate_caption_advanced(image):
    """Advanced caption generation with optimizations."""
    try:
        inputs = processor(image, return_tensors="pt").to(device)
        
        # Enhanced generation parameters
        generation_kwargs = {
            'max_length': 30,
            'min_length': 8,
            'num_beams': 5,
            'do_sample': True,
            'temperature': 0.8,
            'top_p': 0.95,
            'top_k': 50,
            'early_stopping': True,
            'no_repeat_ngram_size': 3,
            'length_penalty': 1.1,
            'repetition_penalty': 1.2
        }
        
        if USE_AMP and device.type == 'cuda':
            with autocast():
                with torch.no_grad():
                    generated_ids = model.generate(**inputs, **generation_kwargs)
        else:
            with torch.no_grad():
                generated_ids = model.generate(**inputs, **generation_kwargs)
        
        caption = processor.decode(generated_ids[0], skip_special_tokens=True)
        return enhance_caption_advanced(caption)
        
    except Exception as e:
        print(f"Error in generate_caption_advanced: {e}")
        return "Processing scene..."

def enhance_caption_advanced(caption):
    """Advanced caption enhancement with NLP improvements."""
    caption = caption.strip()
    if not caption:
        return "Analyzing scene..."
    
    # Remove common prefixes more intelligently
    prefixes_to_remove = [
        "a picture of ", "an image of ", "this is ", "there is ", "there are ",
        "the image shows ", "this image shows ", "a photo of ", "a photograph of "
    ]
    
    caption_lower = caption.lower()
    for prefix in prefixes_to_remove:
        if caption_lower.startswith(prefix):
            caption = caption[len(prefix):]
            break
    
    # Advanced replacements for more natural language. The preposition swaps
    # are mapped explicitly so substring replacement cannot corrupt words
    # (e.g. the 'in' inside 'sitting').
    preposition_map = {'sitting on': 'sitting at',
                       'standing in': 'standing within',
                       'walking on': 'walking at'}
    replacements = {
        r'\b(man|woman|person) (is )?(sitting on|standing in|walking on)\b':
            lambda m: f"{m.group(1)} {preposition_map[m.group(3).lower()]}",
        r'\bholding a\b': 'holding',
        r'\bwearing a\b': 'wearing',
        r'\blooking at the\b': 'observing the',
        r'\bstanding next to\b': 'beside',
        r'\bwalking down\b': 'walking along',
        r'\bsitting at\b': 'seated at'
    }
    
    for pattern, replacement in replacements.items():
        # re.sub accepts both string and callable replacements
        caption = re.sub(pattern, replacement, caption, flags=re.IGNORECASE)
    
    # Capitalize appropriately
    if caption and not caption[0].isupper():
        caption = caption[0].upper() + caption[1:]
    
    # Add descriptive variety
    action_variations = {
        'walking': ['strolling', 'moving', 'walking'],
        'sitting': ['seated', 'resting', 'sitting'],
        'standing': ['positioned', 'standing', 'upright'],
        'holding': ['grasping', 'carrying', 'holding'],
        'looking': ['observing', 'viewing', 'watching', 'looking at']
    }
    
    # Deterministically vary some common actions (seeded on the caption so the
    # same caption always produces the same variation)
    random.seed(hash(caption) % 1000)
    
    for base_action, variations in action_variations.items():
        if base_action in caption.lower():
            if random.random() < 0.3:  # 30% chance to vary
                # Word-boundary substitution avoids corrupting longer words
                caption = re.sub(rf'\b{base_action}\b', random.choice(variations),
                                 caption, flags=re.IGNORECASE)
    
    return caption



# --- Audio Transcription ---
import speech_recognition as sr

def process_audio_chunk(audio_data_b64):
    """
    Processes a base64 encoded audio chunk using SpeechRecognition.
    Assumes the audio_data_b64 is a data URL (e.g., "data:audio/wav;base64,...").
    """
    try:
        # Extract only the base64 part, removing the "data:audio/wav;base64," prefix
        header, b64_data = audio_data_b64.split(',', 1)
        audio_bytes = base64.b64decode(b64_data)

        # Wrap the bytes in an in-memory file-like object for SpeechRecognition
        audio_file_in_memory = BytesIO(audio_bytes)

        r = sr.Recognizer()
        # SpeechRecognition should be able to read a valid WAV format from BytesIO
        with sr.AudioFile(audio_file_in_memory) as source:
            audio = r.record(source) # Read the entire audio data from the in-memory file

        # Use Google Web Speech API for transcription
        transcription = r.recognize_google(audio)
        print(f"Audio Transcription: {transcription}")
        return transcription
    except sr.UnknownValueError:
        print("Speech Recognition could not understand audio")
        return ""
    except sr.RequestError as e:
        print(f"Could not request results from Google Web Speech API service; {e}")
        return ""
    except Exception as e:
        print(f"Error processing audio chunk: {e}")
        import traceback
        traceback.print_exc() # Print full traceback for debugging
        return f"Audio processing error: {e}"
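
# Quick local smoke test (a sketch; assumes a PCM WAV file 'sample.wav' exists
# next to this script):
#   with open('sample.wav', 'rb') as f:
#       payload = 'data:audio/wav;base64,' + base64.b64encode(f.read()).decode()
#   print(process_audio_chunk(payload))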


# WebSocket handler for the audio feed
@socketio.on('audio_feed')
def handle_audio_feed(data):
    """
    Handles incoming audio chunks for transcription.
    Submits processing to executor to avoid blocking the main SocketIO thread.
    """
    sid = request.sid
    if sid not in frame_counters: # Ensure client is initialized
        handle_connect()

    audio_data_b64 = data.get('audio')
    if not audio_data_b64:
        return

    # Submit audio processing to the thread pool and emit the transcription
    # back to this client when it completes
    executor.submit(lambda: socketio.emit('transcription_update', {
        'transcription': process_audio_chunk(audio_data_b64)
    }, room=sid))
    # Note: 'transcription_update' must be handled by your app.js frontend


# ---- 3. ENHANCED FLASK ROUTES ----

@app.route('/')
def index():
    """Render the main HTML page."""
    return render_template('index.html')

@app.route('/status')
def status():
    """Enhanced server status with detailed metrics."""
    stats = perf_monitor.get_stats()
    return {
        'active_connections': len(frame_counters),
        'device': str(device),
        'configuration': {
            'frame_skip': FRAME_SKIP,
            'image_size': IMAGE_SIZE,
            'buffer_size': BUFFER_SIZE,
            'cache_size': CACHE_SIZE,
            'batch_size': BATCH_SIZE,
            'adaptive_quality': ADAPTIVE_QUALITY
        },
        'performance': stats,
        'cache_info': {
            'size': len(caption_cache.cache),
            'max_size': CACHE_SIZE
        },
        'optimizations': {
            'mixed_precision': USE_AMP,
            'thread_pool_size': MAX_WORKERS
        }
    }

@app.route('/metrics')
def metrics():
    """Detailed performance metrics endpoint."""
    stats = perf_monitor.get_stats()
    
    # Client-specific metrics
    client_metrics = {}
    for sid in frame_counters:
        if sid in processing_times and processing_times[sid]:
            client_metrics[sid] = {
                'frames_processed': frame_counters[sid],
                'avg_processing_time': np.mean(processing_times[sid]),
                'caption_history_size': len(caption_history.get(sid, [])),
                'last_caption': last_captions.get(sid, "None")
            }
    
    return {
        'global_metrics': stats,
        'client_metrics': client_metrics,
        'system_info': {
            'device': str(device),
            'cuda_available': torch.cuda.is_available(),
            'cuda_memory': torch.cuda.get_device_properties(0).total_memory if torch.cuda.is_available() else None
        }
    }

@app.route('/clear_cache')
def clear_cache():
    """Clear all caches."""
    caption_cache.clear()
    return {'status': 'cache_cleared', 'timestamp': time.time()}

@app.route('/config', methods=['GET', 'POST'])
def config():
    """Dynamic configuration endpoint."""
    global FRAME_SKIP, ADAPTIVE_QUALITY, SCENE_CHANGE_THRESHOLD
    
    if request.method == 'POST':
        config_data = request.get_json()
        if 'frame_skip' in config_data:
            FRAME_SKIP = max(1, int(config_data['frame_skip']))
        if 'adaptive_quality' in config_data:
            ADAPTIVE_QUALITY = bool(config_data['adaptive_quality'])
        if 'scene_change_threshold' in config_data:
            SCENE_CHANGE_THRESHOLD = float(config_data['scene_change_threshold'])
        
        return {'status': 'updated', 'config': {
            'frame_skip': FRAME_SKIP,
            'adaptive_quality': ADAPTIVE_QUALITY,
            'scene_change_threshold': SCENE_CHANGE_THRESHOLD
        }}
    
    return {
        'frame_skip': FRAME_SKIP,
        'adaptive_quality': ADAPTIVE_QUALITY,
        'scene_change_threshold': SCENE_CHANGE_THRESHOLD
    }
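
# Example client calls (a sketch; assumes the 'requests' package and a server
# listening on localhost:5000):
#   import requests
#   requests.post('http://localhost:5000/config', json={'frame_skip': 5})
#   print(requests.get('http://localhost:5000/config').json())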

# ---- 4. ENHANCED STARTUP ----
if __name__ == '__main__':
    print("=" * 60)
    print("🚀 Starting Enhanced Real-Time Video Captioning Server")
    print("=" * 60)
    print(f"📱 Device: {device}")
    print(f"🎯 Image Processing: {IMAGE_SIZE}x{IMAGE_SIZE}")
    print(f"⚡ Frame Skip: {FRAME_SKIP} (adaptive)")
    print(f"🧠 Mixed Precision: {USE_AMP}")
    print(f"🔄 Thread Pool: {MAX_WORKERS} workers")
    print(f"💾 Cache Size: {CACHE_SIZE} entries (LRU)")
    print(f"🎨 Quality Enhancement: {ADAPTIVE_QUALITY}")
    print(f"🔍 Scene Change Detection: Enabled")
    print("=" * 60)
    
    socketio.run(app, host='0.0.0.0', port=5000, debug=False, allow_unsafe_werkzeug=True)