Spaces:
Running
Running
| """ | |
| Enhanced Integration API | |
| Connects XAI, model compression, A/B testing, and automated training with UX and databases | |
| """ | |
| from flask import Blueprint, request, jsonify, render_template, session | |
| import torch | |
| import numpy as np | |
| import json | |
| import logging | |
| from typing import Dict, List, Any, Optional, Tuple | |
| from datetime import datetime, timedelta | |
| import uuid | |
| import base64 | |
| from io import BytesIO | |
| from PIL import Image | |
| import sqlite3 | |
| try: | |
| import psycopg2 | |
| from psycopg2.extras import RealDictCursor | |
| HAS_POSTGRES = True | |
| except ImportError: | |
| psycopg2 = None | |
| RealDictCursor = None | |
| HAS_POSTGRES = False | |
| # Import our enhancement modules | |
| from src.interpretability.advanced_xai import AdvancedXAIEngine, ConfidenceCalibrator | |
| from src.optimization.model_compression import ModelCompressor, CompressionConfig | |
| from src.testing.ab_testing_framework import ABTestingFramework, ExperimentConfig, ModelVariant, ExperimentMetric | |
| from src.training.automated_pipeline import SmartDatasetCurator, ActiveLearningTrainer | |
| # Import existing MorphGuard components | |
| from morphguard_api import MorphGuardAPI | |
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Flask blueprint for the enhanced endpoints, mounted under /api/enhanced.
enhanced_api = Blueprint(
    name='enhanced_api',
    import_name=__name__,
    url_prefix='/api/enhanced',
)

# Lazily created singletons: the request handlers build these on first use.
xai_engine = None
model_compressor = None

# Eagerly created singletons: constructed once, when this module is imported.
ab_testing_framework = ABTestingFramework()
dataset_curator = SmartDatasetCurator()
morphguard_api = MorphGuardAPI()
class EnhancedIntegrationManager:
    """Manages integration between enhancements and existing systems.

    Holds one TimescaleDB (PostgreSQL) connection used for metrics and one
    SQLite connection opened for user data. Either attribute is ``None`` when
    the corresponding connection could not be opened.
    """

    def __init__(self, db_config: Dict[str, str]):
        """Remember *db_config* and open the database connections.

        Args:
            db_config: Connection settings. The keys read are
                ``timescale_host``, ``timescale_db``, ``timescale_user``,
                ``timescale_password`` and ``sqlite_db``; each has a default.
        """
        self.db_config = db_config
        self.timescale_conn = None
        self.sqlite_conn = None
        self._init_connections()

    def _init_connections(self):
        """Initialize database connections.

        The two connections are opened independently so that a TimescaleDB
        failure does not prevent the SQLite connection from being created.
        Failures are logged, never raised.
        """
        all_ok = True

        # TimescaleDB connection for metrics
        if HAS_POSTGRES:
            try:
                self.timescale_conn = psycopg2.connect(
                    host=self.db_config.get('timescale_host', 'localhost'),
                    database=self.db_config.get('timescale_db', 'morphguard'),
                    user=self.db_config.get('timescale_user', 'postgres'),
                    password=self.db_config.get('timescale_password', ''),
                )
            except Exception as e:
                logger.error(f"Database connection failed: {e}")
                self.timescale_conn = None
                all_ok = False
        else:
            logger.warning("psycopg2 not installed. TimescaleDB metrics storage unavailable.")
            self.timescale_conn = None

        # SQLite connection for user data
        try:
            self.sqlite_conn = sqlite3.connect(
                self.db_config.get('sqlite_db', 'users.db'),
                check_same_thread=False,
            )
            self.sqlite_conn.row_factory = sqlite3.Row
        except Exception as e:
            logger.error(f"Database connection failed: {e}")
            self.sqlite_conn = None
            all_ok = False

        if all_ok:
            logger.info("Database connections established")

    def _rollback_quietly(self):
        """Roll back the TimescaleDB transaction without ever raising.

        Used both after failed writes and to end read-only transactions; a
        rollback error is logged instead of escaping from an ``except`` or
        ``finally`` clause.
        """
        if self.timescale_conn is None:
            return
        try:
            self.timescale_conn.rollback()
        except Exception as e:
            logger.error(f"Rollback failed: {e}")

    def store_xai_analysis(self, session_id: str, analysis_results: Dict[str, Any]) -> Optional[str]:
        """Store XAI analysis results in TimescaleDB.

        Args:
            session_id: Session the analysis belongs to.
            analysis_results: Mapping with ``explanations`` and
                ``confidence_calibration`` entries. The interpretation score
                and processing time are read from the top level and, when
                absent there, from a nested ``summary`` mapping.

        Returns:
            The generated analysis id, or ``None`` when nothing was stored.
        """
        analysis_id = str(uuid.uuid4())
        if self.timescale_conn is None:
            logger.error("Failed to store XAI analysis: TimescaleDB connection unavailable")
            return None
        try:
            summary = analysis_results.get('summary')
            if not isinstance(summary, dict):
                summary = {}
            # Fall back to the nested summary so the real values are stored
            # instead of the 0.0 defaults.
            interpretation_score = analysis_results.get(
                'average_interpretation_score',
                summary.get('average_interpretation_score', 0.0),
            )
            processing_time_ms = analysis_results.get(
                'total_processing_time_ms',
                summary.get('total_processing_time_ms', 0.0),
            )
            with self.timescale_conn.cursor() as cursor:
                # Store in xai_analysis table
                cursor.execute("""
                    INSERT INTO xai_analysis (
                        timestamp, analysis_id, session_id, method_results,
                        confidence_calibration, interpretation_score, processing_time_ms
                    ) VALUES (NOW(), %s, %s, %s, %s, %s, %s)
                """, (
                    analysis_id,
                    session_id,
                    json.dumps(analysis_results.get('explanations', {})),
                    json.dumps(analysis_results.get('confidence_calibration', {})),
                    float(interpretation_score),
                    float(processing_time_ms),
                ))
            self.timescale_conn.commit()
            logger.info(f"Stored XAI analysis {analysis_id}")
            return analysis_id
        except Exception as e:
            logger.error(f"Failed to store XAI analysis: {e}")
            self._rollback_quietly()
            return None

    def store_compression_metrics(self, compression_results: Dict[str, Any]) -> Optional[str]:
        """Store model compression results in TimescaleDB.

        Args:
            compression_results: Mapping with the size, ratio, accuracy and
                ``config`` entries; missing entries fall back to defaults.

        Returns:
            The generated compression id, or ``None`` when nothing was stored.
        """
        compression_id = str(uuid.uuid4())
        if self.timescale_conn is None:
            logger.error("Failed to store compression metrics: TimescaleDB connection unavailable")
            return None
        try:
            with self.timescale_conn.cursor() as cursor:
                cursor.execute("""
                    INSERT INTO model_compression_metrics (
                        timestamp, compression_id, original_size_mb, compressed_size_mb,
                        compression_ratio, speedup_ratio, accuracy_drop, compression_config
                    ) VALUES (NOW(), %s, %s, %s, %s, %s, %s, %s)
                """, (
                    compression_id,
                    compression_results.get('original_size_mb', 0.0),
                    compression_results.get('compressed_size_mb', 0.0),
                    compression_results.get('compression_ratio', 1.0),
                    compression_results.get('speedup_ratio', 1.0),
                    compression_results.get('accuracy_drop', 0.0),
                    json.dumps(compression_results.get('config', {})),
                ))
            self.timescale_conn.commit()
            logger.info(f"Stored compression metrics {compression_id}")
            return compression_id
        except Exception as e:
            logger.error(f"Failed to store compression metrics: {e}")
            self._rollback_quietly()
            return None

    def store_ab_test_result(self, experiment_id: str, variant_id: str, metrics: Dict[str, float]) -> Optional[str]:
        """Store A/B test result and integrate with existing user tracking.

        The result is first recorded in the in-process A/B testing framework
        and then written to TimescaleDB. The user and session ids come from
        the Flask session, so this must run inside a request context.

        Returns:
            The generated result id, or ``None`` when the TimescaleDB write
            did not happen (the framework may still have recorded the result).
        """
        result_id = str(uuid.uuid4())
        user_id = session.get('user_id', 'anonymous')
        session_id = session.get('session_id', str(uuid.uuid4()))
        try:
            # Store in A/B testing framework
            ab_testing_framework.record_result(
                experiment_id=experiment_id,
                variant_id=variant_id,
                metrics=metrics,
                user_id=user_id,
                session_id=session_id,
            )
            if self.timescale_conn is None:
                logger.error("Failed to store A/B test result: TimescaleDB connection unavailable")
                return None
            # Also store in TimescaleDB for real-time monitoring
            with self.timescale_conn.cursor() as cursor:
                cursor.execute("""
                    INSERT INTO ab_test_results (
                        timestamp, result_id, experiment_id, variant_id, user_id,
                        session_id, metrics, processing_time_ms
                    ) VALUES (NOW(), %s, %s, %s, %s, %s, %s, %s)
                """, (
                    result_id,
                    experiment_id,
                    variant_id,
                    user_id,
                    session_id,
                    json.dumps(metrics),
                    metrics.get('processing_time_ms', 0.0),
                ))
            self.timescale_conn.commit()
            logger.info(f"Stored A/B test result {result_id}")
            return result_id
        except Exception as e:
            logger.error(f"Failed to store A/B test result: {e}")
            self._rollback_quietly()
            return None

    def get_user_analytics(self, user_id: str, days: int = 30) -> Dict[str, Any]:
        """Get comprehensive analytics for a user.

        Args:
            user_id: User whose activity is summarized.
            days: Size of the look-back window in days.

        Returns:
            A mapping with ``xai_analytics``, ``ab_test_participation`` and
            ``face_quality_trends``, or ``{'error': message}`` on failure.
        """
        if self.timescale_conn is None:
            return {'error': 'TimescaleDB connection unavailable'}
        try:
            days = int(days)
            with self.timescale_conn.cursor(cursor_factory=RealDictCursor) as cursor:
                # Get XAI analysis history. The window is passed as a number
                # multiplied by a one-day interval rather than interpolated
                # into a quoted literal.
                cursor.execute("""
                    SELECT COUNT(*) as xai_analyses,
                           AVG(interpretation_score) as avg_interpretation_score,
                           AVG(processing_time_ms) as avg_processing_time
                    FROM xai_analysis
                    WHERE session_id IN (
                        SELECT session_id FROM face_capture_sessions
                        WHERE user_id = %s AND timestamp >= NOW() - %s * INTERVAL '1 day'
                    )
                """, (user_id, days))
                xai_stats = cursor.fetchone()

                # Get A/B test participation
                cursor.execute("""
                    SELECT experiment_id, variant_id, COUNT(*) as test_count,
                           AVG((metrics->>'accuracy')::float) as avg_accuracy,
                           AVG((metrics->>'latency')::float) as avg_latency
                    FROM ab_test_results
                    WHERE user_id = %s AND timestamp >= NOW() - %s * INTERVAL '1 day'
                    GROUP BY experiment_id, variant_id
                """, (user_id, days))
                ab_test_stats = cursor.fetchall()

                # Get face quality trends. Columns are qualified with fqm
                # because both joined tables carry a timestamp column.
                cursor.execute("""
                    SELECT DATE(fqm.timestamp) as date,
                           AVG(fqm.overall_score) as avg_quality,
                           COUNT(*) as captures_count
                    FROM face_quality_metrics fqm
                    JOIN face_capture_sessions fcs ON fqm.session_id = fcs.session_id
                    WHERE fcs.user_id = %s AND fqm.timestamp >= NOW() - %s * INTERVAL '1 day'
                    GROUP BY DATE(fqm.timestamp)
                    ORDER BY date
                """, (user_id, days))
                quality_trends = cursor.fetchall()

            return {
                'user_id': user_id,
                'period_days': days,
                'xai_analytics': dict(xai_stats) if xai_stats else {},
                'ab_test_participation': [dict(row) for row in ab_test_stats],
                'face_quality_trends': [dict(row) for row in quality_trends],
                'generated_at': datetime.now().isoformat(),
            }
        except Exception as e:
            logger.error(f"Failed to get user analytics: {e}")
            return {'error': str(e)}
        finally:
            # End the read transaction: otherwise NOW() stays pinned to the
            # transaction start and a failed query leaves the connection
            # rejecting every later command.
            self._rollback_quietly()
# Shared EnhancedIntegrationManager instance; stays None until
# init_integration_manager() has been called.
integration_manager = None


def init_integration_manager(db_config: Dict[str, str]):
    """Build the module-wide integration manager from *db_config*.

    Replaces whatever ``integration_manager`` previously referred to.
    """
    global integration_manager
    manager = EnhancedIntegrationManager(db_config)
    integration_manager = manager
def analyze_with_xai():
    """Analyze uploaded image with advanced XAI techniques.

    Expects an ``image`` file upload. The explanation methods may be given
    either as a ``methods`` list in a JSON body or as repeated ``methods``
    form fields; a multipart upload cannot also carry a JSON body, so the
    form variant is the one usable together with the image.

    Returns:
        A JSON response with per-method explanations and a summary, a 400
        response when the image is missing or cannot be decoded, or a 500
        response on any other failure.
    """
    global xai_engine
    try:
        # Get uploaded image
        if 'image' not in request.files:
            return jsonify({'error': 'No image provided'}), 400
        file = request.files['image']
        if file.filename == '':
            return jsonify({'error': 'No image selected'}), 400

        # Process image; an undecodable upload is a client error, not a 500.
        try:
            image = Image.open(file.stream).convert('RGB')
        except OSError as e:
            logger.error(f"XAI analysis failed: {e}")
            return jsonify({'error': 'Invalid image'}), 400
        image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).unsqueeze(0).float() / 255.0

        # Initialize XAI engine if needed
        if xai_engine is None:
            # Load model (this would be your actual model)
            model = torch.nn.Sequential(
                torch.nn.Conv2d(3, 64, 3),
                torch.nn.ReLU(),
                torch.nn.AdaptiveAvgPool2d(1),
                torch.nn.Flatten(),
                torch.nn.Linear(64, 1)
            )
            xai_engine = AdvancedXAIEngine(model)

        # Get requested methods. silent=True yields None instead of raising
        # when the request body is not JSON (always the case for uploads).
        payload = request.get_json(silent=True)
        if isinstance(payload, dict) and payload:
            methods = payload.get('methods', ['integrated_gradients', 'shap', 'lime'])
        else:
            methods = request.form.getlist('methods') or ['integrated_gradients']

        # Run XAI analysis
        explanations = xai_engine.explain_prediction(image_tensor, methods=methods)

        # Prepare response; guard the mean so an empty result does not
        # produce NaN, which is not valid JSON.
        scores = [exp.interpretation_score for exp in explanations.values()]
        summary = {
            'methods_used': list(explanations.keys()),
            'average_interpretation_score': float(np.mean(scores)) if scores else 0.0,
            'total_processing_time_ms': sum(exp.processing_time_ms for exp in explanations.values()),
        }
        response_data = {
            'session_id': session.get('session_id', str(uuid.uuid4())),
            'explanations': {},
            'summary': summary,
        }

        # Convert explanations to JSON-serializable format
        for method, explanation in explanations.items():
            response_data['explanations'][method] = {
                'interpretation_score': explanation.interpretation_score,
                'feature_importance': explanation.feature_importance,
                'textual_explanation': explanation.textual_explanation,
                'processing_time_ms': explanation.processing_time_ms,
                'attribution_map': explanation.attribution_map.tolist()  # Convert numpy to list
            }

        # Store in database. The storage layer reads the score and timing
        # from the top level, so the summary values are surfaced there.
        if integration_manager:
            analysis_id = integration_manager.store_xai_analysis(
                response_data['session_id'],
                {**response_data, **summary}
            )
            response_data['analysis_id'] = analysis_id
        return jsonify(response_data)
    except Exception as e:
        logger.error(f"XAI analysis failed: {e}")
        return jsonify({'error': str(e)}), 500
def compress_model():
    """Compress a model with specified configuration.

    Reads optional JSON settings (``enable_quantization``, ``enable_pruning``,
    ``enable_distillation``, ``pruning_ratio``) and returns simulated
    compression results; no real model is compressed here.

    Returns:
        A JSON response with the simulated results, a 400 response when the
        JSON body is not an object, or a 500 response on failure.
    """
    global model_compressor
    try:
        # Get compression configuration. silent=True makes a missing or
        # non-JSON body fall back to the defaults instead of raising.
        config_data = request.get_json(silent=True) or {}
        if not isinstance(config_data, dict):
            return jsonify({'error': 'JSON object expected'}), 400
        compression_config = CompressionConfig(
            enable_quantization=config_data.get('enable_quantization', True),
            enable_pruning=config_data.get('enable_pruning', True),
            enable_distillation=config_data.get('enable_distillation', False),
            pruning_ratio=config_data.get('pruning_ratio', 0.5)
        )

        # Initialize compressor
        # NOTE(review): the compressor is created once and keeps the
        # configuration of the first request; confirm this is intended.
        if model_compressor is None:
            model_compressor = ModelCompressor(compression_config)

        # Simulate compression results
        compression_results = {
            'compression_id': str(uuid.uuid4()),
            'original_size_mb': 45.2,
            'compressed_size_mb': 12.8,
            'compression_ratio': 3.5,
            'speedup_ratio': 2.8,
            'accuracy_drop': 0.012,
            'processing_time_ms': 15420.5,
            'config': compression_config.__dict__
        }

        # Store in database
        if integration_manager:
            compression_id = integration_manager.store_compression_metrics(compression_results)
            compression_results['stored_compression_id'] = compression_id
        return jsonify(compression_results)
    except Exception as e:
        logger.error(f"Model compression failed: {e}")
        return jsonify({'error': str(e)}), 500
def assign_ab_test_variant():
    """Assign a variant for A/B testing.

    Reads ``experiment_id`` from the JSON body and the user id from the Flask
    session (``'anonymous'`` when absent), asks the A/B testing framework for
    a variant and remembers it in the session.

    Returns:
        A JSON response describing the assignment, 400 when ``experiment_id``
        is missing, 404 when the framework returns no variant, or 500 on
        failure.
    """
    try:
        # silent=True makes a missing or non-JSON body reach the
        # "experiment_id required" check instead of raising.
        data = request.get_json(silent=True) or {}
        if not isinstance(data, dict):
            data = {}
        experiment_id = data.get('experiment_id')
        user_id = session.get('user_id', 'anonymous')
        if not experiment_id:
            return jsonify({'error': 'experiment_id required'}), 400

        # Assign variant
        variant_id = ab_testing_framework.assign_variant(experiment_id, user_id)
        if variant_id is None:
            return jsonify({'error': 'Experiment not found or not active'}), 404

        # Store assignment in session
        session[f'ab_variant_{experiment_id}'] = variant_id
        return jsonify({
            'experiment_id': experiment_id,
            'variant_id': variant_id,
            'user_id': user_id,
            'assigned_at': datetime.now().isoformat()
        })
    except Exception as e:
        logger.error(f"A/B test variant assignment failed: {e}")
        return jsonify({'error': str(e)}), 500
def record_ab_test_result():
    """Record an A/B test result.

    Reads ``experiment_id``, ``variant_id`` and a non-empty ``metrics``
    object from the JSON body. The result goes through the integration
    manager when one is configured, otherwise straight to the A/B testing
    framework.

    Returns:
        A JSON response with the result id, 400 when a required field is
        missing or ``metrics`` is not an object, or 500 on failure.
    """
    try:
        # silent=True makes a missing or non-JSON body reach the validation
        # below instead of raising.
        data = request.get_json(silent=True) or {}
        if not isinstance(data, dict):
            data = {}
        experiment_id = data.get('experiment_id')
        variant_id = data.get('variant_id')
        metrics = data.get('metrics', {})
        # metrics must be a mapping: the storage layer calls .get() on it.
        if not all([experiment_id, variant_id, metrics]) or not isinstance(metrics, dict):
            return jsonify({'error': 'experiment_id, variant_id, and metrics required'}), 400

        # Record result
        if integration_manager:
            result_id = integration_manager.store_ab_test_result(experiment_id, variant_id, metrics)
            if result_id is None:
                # The manager returns None only when storing failed; do not
                # report that as a success.
                return jsonify({'error': 'Failed to store A/B test result'}), 500
        else:
            result_id = ab_testing_framework.record_result(
                experiment_id=experiment_id,
                variant_id=variant_id,
                metrics=metrics,
                user_id=session.get('user_id', 'anonymous'),
                session_id=session.get('session_id', str(uuid.uuid4()))
            )
        return jsonify({
            'result_id': result_id,
            'experiment_id': experiment_id,
            'variant_id': variant_id,
            'recorded_at': datetime.now().isoformat()
        })
    except Exception as e:
        logger.error(f"A/B test result recording failed: {e}")
        return jsonify({'error': str(e)}), 500
def list_ab_experiments():
    """Return every experiment the A/B testing framework lists as active.

    Each entry carries the experiment id, its name (the id when no name is
    set), its description, its variant ids and the fixed status ``'active'``.
    """

    def describe(exp_id):
        # Look the configuration up by id, exactly once per experiment.
        experiment = ab_testing_framework.active_experiments[exp_id]
        variant_ids = [variant.variant_id for variant in getattr(experiment, 'variants', [])]
        return {
            'experiment_id': exp_id,
            'name': getattr(experiment, 'name', exp_id),
            'description': getattr(experiment, 'description', ''),
            'variants': variant_ids,
            'status': 'active',
        }

    try:
        # Snapshot the ids first, then describe each experiment in order.
        experiment_ids = list(ab_testing_framework.active_experiments.keys())
        listing = [describe(exp_id) for exp_id in experiment_ids]
        body = {'experiments': listing, 'total_count': len(listing)}
        return jsonify(body)
    except Exception as e:
        logger.error(f"Failed to list experiments: {e}")
        return jsonify({'error': str(e)}), 500
def get_user_analytics(user_id: str):
    """Get comprehensive user analytics.

    Args:
        user_id: User whose analytics are requested.

    The look-back window comes from the ``days`` query parameter (default 30).

    Returns:
        A JSON response with the analytics, or a 500 response when the
        integration manager is not initialized or reports an error.
    """
    try:
        days = request.args.get('days', 30, type=int)
        if not integration_manager:
            # Same status as the dashboard endpoint uses for this condition.
            return jsonify({'error': 'Integration manager not initialized'}), 500
        analytics = integration_manager.get_user_analytics(user_id, days)
        # The manager signals failure with an 'error' entry rather than an
        # exception; surface that as a server error instead of a 200.
        if 'error' in analytics:
            return jsonify(analytics), 500
        return jsonify(analytics)
    except Exception as e:
        logger.error(f"Failed to get user analytics: {e}")
        return jsonify({'error': str(e)}), 500
def get_analytics_dashboard():
    """Get analytics dashboard data.

    Runs three aggregate queries over the last seven days (XAI analyses,
    A/B test results and face quality metrics) on the integration manager's
    TimescaleDB connection.

    Returns:
        A JSON response with the aggregates, or a 500 response when the
        manager or its connection is unavailable or a query fails.
    """
    try:
        if not integration_manager:
            return jsonify({'error': 'Integration manager not initialized'}), 500
        conn = integration_manager.timescale_conn
        if conn is None:
            return jsonify({'error': 'Metrics database unavailable'}), 500

        try:
            # The cursor is closed when the with-block exits.
            with conn.cursor(cursor_factory=RealDictCursor) as cursor:
                # Get XAI analysis stats
                cursor.execute("""
                    SELECT COUNT(*) as total_analyses,
                           AVG(interpretation_score) as avg_score,
                           COUNT(DISTINCT session_id) as unique_sessions
                    FROM xai_analysis
                    WHERE timestamp >= NOW() - INTERVAL '7 days'
                """)
                xai_stats = cursor.fetchone()

                # Get A/B test stats
                cursor.execute("""
                    SELECT experiment_id, variant_id, COUNT(*) as test_count,
                           AVG((metrics->>'accuracy')::float) as avg_accuracy
                    FROM ab_test_results
                    WHERE timestamp >= NOW() - INTERVAL '7 days'
                    GROUP BY experiment_id, variant_id
                """)
                ab_stats = cursor.fetchall()

                # Get face quality trends
                cursor.execute("""
                    SELECT DATE(timestamp) as date,
                           AVG(overall_score) as avg_quality,
                           COUNT(*) as capture_count
                    FROM face_quality_metrics
                    WHERE timestamp >= NOW() - INTERVAL '7 days'
                    GROUP BY DATE(timestamp)
                    ORDER BY date
                """)
                quality_trends = cursor.fetchall()
        finally:
            # End the read transaction: otherwise NOW() stays pinned to the
            # transaction start and a failed query leaves the connection
            # rejecting every later command.
            try:
                conn.rollback()
            except Exception as rollback_error:
                logger.error(f"Rollback failed: {rollback_error}")

        dashboard_data = {
            'xai_analytics': dict(xai_stats) if xai_stats else {},
            'ab_test_stats': [dict(row) for row in ab_stats],
            'quality_trends': [dict(row) for row in quality_trends],
            'generated_at': datetime.now().isoformat()
        }
        return jsonify(dashboard_data)
    except Exception as e:
        logger.error(f"Failed to get dashboard data: {e}")
        return jsonify({'error': str(e)}), 500
def curate_dataset():
    """Curate dataset using smart curation pipeline.

    Currently returns simulated results. Only ``quality_threshold`` from the
    JSON body (default 0.7) is used, and it is echoed back in the response;
    a ``target_size`` entry is accepted but has no effect.

    Returns:
        A JSON response with the simulated curation results, or a 500
        response on failure.
    """
    try:
        # silent=True makes a missing or non-JSON body fall back to the
        # defaults instead of raising.
        data = request.get_json(silent=True) or {}
        if not isinstance(data, dict):
            data = {}
        quality_threshold = data.get('quality_threshold', 0.7)

        # Simulate dataset curation (in practice, this would process real images)
        curation_results = {
            'curation_id': str(uuid.uuid4()),
            'original_size': 5000,
            'after_quality_filter': 4250,
            'duplicate_groups_found': 25,
            'final_size': 4000,
            'average_quality': 0.82,
            'quality_threshold_used': quality_threshold,
            'processing_time_ms': 45000,
            'improvements': {
                'duplicate_removal': 25,
                'low_quality_removed': 750,
                'class_balance_maintained': True
            }
        }
        return jsonify(curation_results)
    except Exception as e:
        logger.error(f"Dataset curation failed: {e}")
        return jsonify({'error': str(e)}), 500
| # Error handlers | |
def not_found(error):
    """Return the JSON 404 payload; *error* is accepted but not inspected."""
    payload = {'error': 'Endpoint not found'}
    return jsonify(payload), 404
def internal_error(error):
    """Return the JSON 500 payload; *error* is accepted but not inspected."""
    payload = {'error': 'Internal server error'}
    return jsonify(payload), 500