import os import gradio as gr import shap from transformers import pipeline import torch import xgboost as xgb from sklearn.ensemble import RandomForestClassifier, VotingClassifier from sklearn.svm import SVC from sklearn.linear_model import LogisticRegression from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.preprocessing import LabelEncoder from sklearn.pipeline import FeatureUnion from sklearn.base import BaseEstimator, TransformerMixin from sentence_transformers import SentenceTransformer import pandas as pd import numpy as np import matplotlib.pyplot as plt import requests from bs4 import BeautifulSoup import json import joblib import re import time import warnings from urllib.parse import urlparse from requests.adapters import HTTPAdapter from urllib3.util.retry import Retry import socket import threading from concurrent.futures import ThreadPoolExecutor, as_completed import ssl from datetime import datetime warnings.filterwarnings('ignore') plt.switch_backend('Agg') def load_from_drive(path): """Load model components from disk with error handling""" if os.path.exists(path): try: return joblib.load(path) except Exception as e: print(f"Error loading {path}: {e}") return None else: return None # =========== Enhanced Feature Engineering ================ class SentenceTransformerFeatures(BaseEstimator, TransformerMixin): """Transformer for generating sentence embeddings""" def __init__(self, model_name='all-MiniLM-L6-v2'): self.model_name = model_name self.model = None def fit(self, X, y=None): self.model = SentenceTransformer(self.model_name) return self def transform(self, X): if self.model is None: self.model = SentenceTransformer(self.model_name) return self.model.encode(X, show_progress_bar=False) class AdvancedFeatureEngine: """Enhanced feature engineering with multiple feature types""" def __init__(self, use_embeddings=True): self.use_embeddings = use_embeddings self.vectorizer = TfidfVectorizer( max_features=2000, ngram_range=(1, 
3), stop_words='english', min_df=2, max_df=0.8, analyzer='word', sublinear_tf=True) self.sentence_transformer = None self.feature_union = None def build_feature_pipeline(self): """Build feature union pipeline""" if self.use_embeddings: self.sentence_transformer = SentenceTransformerFeatures() self.feature_union = FeatureUnion([ ('tfidf', self.vectorizer), ('embeddings', self.sentence_transformer) ]) else: self.feature_union = self.vectorizer return self.feature_union # =========== Enhanced Model Class ======================== class CalibratedVulnerabilityClassifier: """Enhanced vulnerability classifier with improved accuracy""" def __init__(self, use_embeddings=True, model_path_prefix="models/"): self.feature_engine = AdvancedFeatureEngine(use_embeddings) self.label_encoder = LabelEncoder() self.models = {} self.explainer = None self.training_complete = False self.calibration_thresholds = {} self.model_path_prefix = model_path_prefix self.xgb_model = None self.rf_model = None self.svm_model = None self.lr_model = None self.ensemble = None self.ensemble_calibrated = None self.load_models() def load_models(self): """Load trained models with fallback mechanisms""" try: self.feature_engine.vectorizer = load_from_drive( os.path.join(self.model_path_prefix, "tfidf_vectorizer.joblib")) self.label_encoder = load_from_drive( os.path.join(self.model_path_prefix, "label_encoder.joblib")) or LabelEncoder() # Initialize models with better parameters self.xgb_model = load_from_drive( os.path.join(self.model_path_prefix, "xgb_model.joblib")) or xgb.XGBClassifier( n_estimators=300, max_depth=10, learning_rate=0.1, subsample=0.8, random_state=42) self.rf_model = load_from_drive( os.path.join(self.model_path_prefix, "rf_model.joblib")) or RandomForestClassifier( n_estimators=300, max_depth=20, min_samples_split=5, random_state=42) self.lr_model = load_from_drive( os.path.join(self.model_path_prefix, "lr_model.joblib")) or LogisticRegression( C=1.0, max_iter=2000, solver='liblinear', 
random_state=42) self.svm_model = load_from_drive( os.path.join(self.model_path_prefix, "svm_model.joblib")) or SVC( probability=True, kernel='rbf', C=1.0, gamma='scale', random_state=42) self.ensemble = load_from_drive( os.path.join(self.model_path_prefix, "ensemble_model.joblib")) self.ensemble_calibrated = load_from_drive( os.path.join(self.model_path_prefix, "calibrated_ensemble.joblib")) self.calibration_thresholds = load_from_drive( os.path.join(self.model_path_prefix, "calibration_thresholds.joblib")) or { 'SQL Injection': 0.65, 'XSS': 0.68, 'CSRF': 0.55, 'Information Disclosure': 0.58, 'Authentication Bypass': 0.62, 'Secure Config': 0.52, 'File Inclusion': 0.60, 'Command Injection': 0.70, 'XXE': 0.65, 'SSRF': 0.63, 'IDOR': 0.58, 'Buffer Overflow': 0.72 } # Initialize label encoder with comprehensive classes if not hasattr(self.label_encoder, 'classes_') or len(self.label_encoder.classes_) < 8: self.label_encoder.fit([ 'SQL Injection', 'XSS', 'CSRF', 'Information Disclosure', 'Authentication Bypass', 'Secure Config', 'File Inclusion', 'Command Injection', 'XXE', 'SSRF', 'IDOR', 'Buffer Overflow' ]) self.training_complete = False else: self.training_complete = True except Exception as e: print(f"Model loading error: {e}") self.training_complete = False def get_meaningful_predictions(self, text): """Get vulnerability predictions with improved accuracy""" preds = [] thresholds = self.calibration_thresholds if isinstance(self.calibration_thresholds, dict) else { 'SQL Injection': 0.65, 'XSS': 0.68, 'CSRF': 0.55, 'Information Disclosure': 0.58, 'Authentication Bypass': 0.62, 'Secure Config': 0.52, 'File Inclusion': 0.60, 'Command Injection': 0.70, 'XXE': 0.65, 'SSRF': 0.63, 'IDOR': 0.58, 'Buffer Overflow': 0.72 } classes = [ 'SQL Injection', 'XSS', 'CSRF', 'Information Disclosure', 'Authentication Bypass', 'Secure Config', 'File Inclusion', 'Command Injection', 'XXE', 'SSRF', 'IDOR', 'Buffer Overflow' ] # Use ensemble model if available if self.ensemble_calibrated 
and self.feature_engine.vectorizer: try: X = self.feature_engine.vectorizer.transform([text]) proba = self.ensemble_calibrated.predict_proba(X)[0] sorted_idx = np.argsort(proba)[::-1] for i in sorted_idx[:8]: # Top 8 predictions if i < len(self.label_encoder.classes_): cl = self.label_encoder.classes_[i] confidence = proba[i] # Enhanced risk assessment if confidence > 0.8: risk = 'Critical' elif confidence > 0.65: risk = 'High' elif confidence > 0.45: risk = 'Medium' else: risk = 'Low' preds.append({ 'type': cl, 'confidence': float(confidence), 'threshold': thresholds.get(cl, 0.5), 'above_threshold': confidence > thresholds.get(cl, 0.5), 'risk_level': risk }) except Exception as e: print(f"Model prediction error: {e}") # Fall through to keyword analysis # Enhanced keyword-based fallback with pattern matching if not preds or len(preds) < 3: keyword_preds = self._keyword_based_analysis(text, thresholds) # Merge with existing predictions existing_types = {p['type'] for p in preds} for pred in keyword_preds: if pred['type'] not in existing_types: preds.append(pred) return sorted(preds, key=lambda x: x['confidence'], reverse=True)[:8] def _keyword_based_analysis(self, text, thresholds): """Enhanced keyword-based vulnerability analysis with improved patterns""" preds = [] text_lower = text.lower() # Enhanced SQL Injection patterns sql_patterns = [ r'\b(select|insert|update|delete|union|drop|alter|create)\b.*\b(from|into|table|database)\b', r'.*\b(sql|query).*(injection|bypass|escape)\b', r'.*(union.*select|1=1|or\s+1=1|--|;)\b', r'.*(exec\s*\(|sp_|xp_)\b' ] sql_matches = sum(len(re.findall(pattern, text_lower, re.IGNORECASE)) for pattern in sql_patterns) if sql_matches > 0: confidence = min(0.85 + sql_matches * 0.08, 0.95) preds.append({ 'type': 'SQL Injection', 'confidence': confidence, 'threshold': thresholds.get('SQL Injection', 0.65), 'above_threshold': True, 'risk_level': 'Critical' if confidence > 0.8 else 'High' }) # Enhanced XSS patterns xss_patterns = [ 
r'.*(script|alert|document\.cookie|onclick|onload|onerror)\b',
# NOTE(review): extraction damage begins here. The remainder of the XSS
# pattern list, the XSS/CSRF/other-category scoring blocks of
# _keyword_based_analysis, and the `def` header plus `keywords = {` opener of
# the keyword-importance helper below were all lost — most likely because they
# contained literal HTML such as '<script>' that was stripped on export.
# The surviving tokens are preserved verbatim; do not "repair" this region
# without recovering the original source.
r'.*(': 0.90, 'eval': 0.82,
# Command Injection
'command': 0.80, 'injection': 0.85, 'exec': 0.85, 'system': 0.80,
'shell': 0.75, 'popen': 0.80, 'passthru': 0.80, 'subprocess': 0.78,
# File Inclusion
'file': 0.75, 'include': 0.80, 'require': 0.75, 'path': 0.70,
'traversal': 0.85, 'directory': 0.65, '../': 0.88,
# XXE
'xxe': 0.82, 'xml': 0.75, 'entity': 0.78, 'DOCTYPE': 0.80,
# SSRF
'ssrf': 0.80, 'server.side': 0.75, 'request.forgery': 0.75, 'curl': 0.70,
# Authentication
'authentication': 0.80, 'bypass': 0.85, 'login': 0.75, 'password': 0.80,
'session': 0.70, 'credential': 0.75, 'admin': 0.65, 'jwt': 0.72,
# Information Disclosure
# NOTE(review): 'password' and 'credential' repeat keys from the block above;
# in a dict literal the later entry wins ('credential' ends up 0.85).
'information': 0.65, 'disclosure': 0.75, 'exposed': 0.70, 'leak': 0.75,
'password': 0.80, 'credential': 0.85, 'key': 0.80, 'token': 0.75,
'config': 0.65, 'debug': 0.70, 'error': 0.60,
# Buffer Overflow
'buffer': 0.78, 'overflow': 0.82, 'stack': 0.75, 'strcpy': 0.80
}
# Tail of the (damaged) keyword-importance helper. It scores each keyword
# found in the text as: base importance + frequency boost (0.1 per hit,
# capped at +0.3) + a flat +0.1 context boost when generic security terms
# appear anywhere in the text; scores are clamped to 1.0 and the top_k
# highest-scoring features are returned.
features = []
text_lower = text.lower()
for word, base_importance in keywords.items():
    # Count occurrences and calculate frequency-based importance
    count = text_lower.count(word)
    if count > 0:
        # Adjust importance based on frequency and context
        frequency_boost = min(count * 0.1, 0.3)
        context_boost = 0.1 if any(ctx in text_lower for ctx in ['vulnerability', 'security', 'attack', 'exploit', 'injection']) else 0
        adjusted_importance = base_importance + frequency_boost + context_boost
        features.append({
            'feature': word,
            'importance': float(min(adjusted_importance, 1.0)),
            'in_text': True,
            'count': count
        })
# Sort by importance and return top features
features.sort(key=lambda x: x['importance'], reverse=True)
return {'features': features[:top_k]}

# Initialize classifier
classifier = CalibratedVulnerabilityClassifier(use_embeddings=True, model_path_prefix="models/")

# =========== Enhanced Port Scanner ================
class PortScanner:
    """Enhanced port scanner with common vulnerability ports"""
    def __init__(self):
self.common_ports = { 21: 'FTP', 22: 'SSH', 23: 'Telnet', 25: 'SMTP', 53: 'DNS', 80: 'HTTP', 110: 'POP3', 443: 'HTTPS', 993: 'IMAPS', 995: 'POP3S', 1433: 'MSSQL', 3306: 'MySQL', 3389: 'RDP', 5432: 'PostgreSQL', 5900: 'VNC', 27017: 'MongoDB', 8080: 'HTTP-Alt', 8443: 'HTTPS-Alt', 9200: 'Elasticsearch', 11211: 'Memcached', 6379: 'Redis', 5984: 'CouchDB' } self.vulnerable_ports = { 21: 'FTP - Anonymous access possible', 23: 'Telnet - Unencrypted communication', 80: 'HTTP - Potential web vulnerabilities', 443: 'HTTPS - SSL/TLS configuration issues', 3389: 'RDP - Remote Desktop vulnerabilities', 5900: 'VNC - Unencrypted remote access', 8080: 'HTTP-Alt - Alternative web service', 9200: 'Elasticsearch - Database exposure risk', 11211: 'Memcached - Unauthenticated access', 6379: 'Redis - Unauthenticated access' } def scan_port(self, host, port, timeout=2): """Scan individual port""" try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: sock.settimeout(timeout) result = sock.connect_ex((host, port)) if result == 0: service = self.common_ports.get(port, 'Unknown') vulnerability = self.vulnerable_ports.get(port, '') return { 'port': port, 'status': 'open', 'service': service, 'vulnerability_note': vulnerability } except: pass return None def quick_scan(self, host, max_workers=20): """Quick port scan with common ports""" open_ports = [] with ThreadPoolExecutor(max_workers=max_workers) as executor: future_to_port = { executor.submit(self.scan_port, host, port): port for port in self.common_ports.keys() } for future in as_completed(future_to_port): result = future.result() if result: open_ports.append(result) return sorted(open_ports, key=lambda x: x['port']) # =========== Enhanced Passive Website Analyzer ======== class EnhancedPassiveAnalyzer: """Enhanced website analyzer with port scanning""" def __init__(self, classifier): self.classifier = classifier self.port_scanner = PortScanner() self.session = requests.Session() self.session.headers.update({ 'User-Agent': 
'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36' }) retry_strategy = Retry( total=3, backoff_factor=1, status_forcelist=[429, 500, 502, 503, 504], ) adapter = HTTPAdapter(max_retries=retry_strategy) self.session.mount("http://", adapter) self.session.mount("https://", adapter) requests.packages.urllib3.disable_warnings() def analyze_website(self, url, quick_mode=False, enable_port_scan=False): """Comprehensive website analysis""" analysis = { 'url': url, 'timestamp': pd.Timestamp.now().isoformat(), 'quick_mode': quick_mode, 'network_info': {}, 'content_analysis': {}, 'security_headers': {}, 'technologies': [], 'vulnerability_predictions': [], 'risk_assessment': {}, 'enhanced_features': [], 'port_scan': {}, 'ssl_info': {} } try: if not url.startswith(('http://', 'https://')): url = 'https://' + url parsed_url = urlparse(url) if not parsed_url.netloc: analysis['error'] = "Invalid URL format" return analysis domain = parsed_url.netloc # Basic request response = self.session.get( url, timeout=10 if quick_mode else 15, verify=False, allow_redirects=True ) # Network and domain information analysis['network_info'] = self.get_network_info(domain) # SSL/TLS information analysis['ssl_info'] = self.get_ssl_info(domain) # Port scanning (if enabled) if enable_port_scan and not quick_mode: try: analysis['port_scan'] = self.port_scanner.quick_scan(domain) except Exception as e: analysis['port_scan'] = {'error': f'Port scan failed: {str(e)}'} # Security headers analysis['security_headers'] = self.analyze_security_headers(response) # Technology detection analysis['technologies'] = self.detect_technologies(response) if not quick_mode: # Content analysis analysis['content_analysis'] = self.analyze_content(response) # Enhanced features analysis['enhanced_features'] = self.extract_enhanced_passive_features(url, response) # Vulnerability predictions analysis['vulnerability_predictions'] = 
self.predict_vulnerabilities(analysis, quick_mode) # Risk assessment analysis['risk_assessment'] = self.assess_risk(analysis) except requests.exceptions.Timeout: analysis['error'] = "Request timeout - site may be unavailable" except requests.exceptions.SSLError: analysis['error'] = "SSL certificate verification failed" except requests.exceptions.ConnectionError: analysis['error'] = "Connection error - site may be unreachable" except Exception as e: analysis['error'] = f"Analysis error: {str(e)}" return analysis def get_network_info(self, domain): """Get comprehensive network information""" try: ip = socket.gethostbyname(domain) return { 'domain': domain, 'ip_address': ip, 'resolved': True } except: return { 'domain': domain, 'ip_address': 'Unresolvable', 'resolved': False } def get_ssl_info(self, domain): """Get SSL certificate information""" try: context = ssl.create_default_context() with socket.create_connection((domain, 443), timeout=5) as sock: with context.wrap_socket(sock, server_hostname=domain) as ssock: cert = ssock.getpeercert() # Check certificate expiration expiry_date = datetime.strptime(cert['notAfter'], '%b %d %H:%M:%S %Y %Z') days_until_expiry = (expiry_date - datetime.now()).days return { 'has_ssl': True, 'issuer': dict(x[0] for x in cert['issuer']) if isinstance(cert['issuer'], tuple) else str(cert['issuer']), 'subject': dict(x[0] for x in cert['subject']) if isinstance(cert['subject'], tuple) else str(cert['subject']), 'expires_in_days': days_until_expiry, 'valid': days_until_expiry > 0 } except Exception as e: return { 'has_ssl': False, 'valid': False, 'error': str(e) } def analyze_security_headers(self, response): """Analyze security headers with enhanced checks""" headers = response.headers security_headers = {} important_headers = { 'X-Frame-Options': {'purpose': 'Clickjacking protection', 'required': True}, 'X-Content-Type-Options': {'purpose': 'MIME sniffing protection', 'required': True}, 'Strict-Transport-Security': {'purpose': 'HTTPS 
enforcement', 'required': True}, 'Content-Security-Policy': {'purpose': 'XSS protection', 'required': True}, 'X-XSS-Protection': {'purpose': 'XSS protection', 'required': False}, 'Referrer-Policy': {'purpose': 'Referrer info control', 'required': False}, 'Permissions-Policy': {'purpose': 'Browser features control', 'required': False} } for header, info in important_headers.items(): value = headers.get(header, 'MISSING') security_headers[header] = { 'value': value, 'status': 'PRESENT' if value != 'MISSING' else 'MISSING', 'purpose': info['purpose'], 'required': info['required'] } return security_headers def detect_technologies(self, response): """Enhanced technology detection""" technologies = [] server = response.headers.get('Server', '').lower() content = response.text.lower() # Server detection if 'apache' in server: technologies.append('Apache Web Server') elif 'nginx' in server: technologies.append('Nginx Web Server') elif 'iis' in server: technologies.append('Microsoft IIS') elif 'cloudflare' in server: technologies.append('Cloudflare') # Framework detection tech_patterns = { 'WordPress': ['wp-content', 'wp-includes', 'wordpress'], 'React': ['react', 'next.js', 'gatsby'], 'Angular': ['angular', 'ng-'], 'Vue.js': ['vue', 'vue.js'], 'Django': ['django', 'csrfmiddleware'], 'Laravel': ['laravel'], 'PHP': ['.php', 'php/'], 'jQuery': ['jquery'], 'Bootstrap': ['bootstrap'], 'Google Analytics': ['ga.js', 'google-analytics'], 'Font Awesome': ['font-awesome'] } for tech, patterns in tech_patterns.items(): if any(pattern in content for pattern in patterns): technologies.append(tech) return list(set(technologies)) def analyze_content(self, response): """Enhanced content analysis""" try: soup = BeautifulSoup(response.content, 'html.parser') text_content = soup.get_text()[:2000] # Enhanced security indicators security_indicators = { 'exposed_emails': len(re.findall(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b', text_content)), 'php_errors': 'php' in 
text_content.lower() and any(err in text_content.lower() for err in ['error', 'warning', 'notice']), 'database_errors': any(db in text_content.lower() for db in ['mysql', 'postgresql', 'oracle', 'sql server', 'database error']), 'debug_info': any(term in text_content.lower() for term in ['debug', 'test', 'development', 'staging']), 'exposed_paths': len(re.findall(r'/[\w/.-]+', text_content)) > 50, 'comments_with_info': len(re.findall(r'', text_content, re.IGNORECASE)) > 0 } return { 'text_sample': text_content[:800], 'security_indicators': security_indicators, 'forms_count': len(soup.find_all('form')), 'scripts_count': len(soup.find_all('script')), 'inputs_count': len(soup.find_all('input')), 'links_count': len(soup.find_all('a')) } except Exception as e: return { 'text_sample': f'Content analysis failed: {str(e)}', 'security_indicators': {}, 'forms_count': 0, 'scripts_count': 0, 'inputs_count': 0, 'links_count': 0 } def extract_enhanced_passive_features(self, url, response): """Extract enhanced passive security features""" features = [] domain = urlparse(url).netloc try: # Check robots.txt robots_features = self.check_robots_txt(url) features.extend(robots_features) # Check sitemap.xml sitemap_features = self.check_sitemap(url) features.extend(sitemap_features) # Check common sensitive files sensitive_files = self.check_sensitive_files(url) features.extend(sensitive_files) except Exception as e: features.append(f"Feature extraction error: {str(e)}") return features def check_robots_txt(self, url): """Check robots.txt for sensitive information""" features = [] try: robots_url = f"{url.rstrip('/')}/robots.txt" response = self.session.get(robots_url, timeout=3, verify=False) if response.status_code == 200: features.append("robots.txt present") content = response.text.lower() sensitive_paths = ['admin', 'login', 'config', 'backup', 'database', 'sql'] if any(path in content for path in sensitive_paths): features.append("sensitive paths exposed in robots.txt") except: 
pass return features def check_sitemap(self, url): """Check sitemap.xml for information disclosure""" features = [] try: sitemap_url = f"{url.rstrip('/')}/sitemap.xml" response = self.session.get(sitemap_url, timeout=3, verify=False) if response.status_code == 200: features.append("sitemap.xml present") except: pass return features def check_sensitive_files(self, url): """Check for common sensitive files""" features = [] sensitive_files = [ '.env', 'config.php', 'backup.sql', 'wp-config.php', 'web.config', '.git/config', 'phpinfo.php' ] for file in sensitive_files[:3]: # Check first 3 to avoid too many requests try: file_url = f"{url.rstrip('/')}/{file}" response = self.session.get(file_url, timeout=2, verify=False) if response.status_code == 200: features.append(f"sensitive file accessible: {file}") except: pass return features def predict_vulnerabilities(self, analysis, quick_mode=False): """Predict vulnerabilities based on analysis""" feature_text = self.create_feature_text(analysis) if feature_text: try: return self.classifier.get_meaningful_predictions(feature_text) except Exception as e: print(f"Prediction error: {e}") return [] return [] def create_feature_text(self, analysis): """Create feature text for vulnerability prediction""" text_parts = [] # Content analysis if 'content_analysis' in analysis: content = analysis['content_analysis'] text_parts.append(content.get('text_sample', '')) indicators = content.get('security_indicators', {}) if indicators.get('php_errors'): text_parts.append("php error messages exposed") if indicators.get('database_errors'): text_parts.append("database errors visible") if indicators.get('exposed_emails', 0) > 0: text_parts.append(f"{indicators['exposed_emails']} emails exposed") if indicators.get('comments_with_info'): text_parts.append("sensitive information in comments") # Technologies tech_text = " ".join(analysis.get('technologies', [])) text_parts.append(tech_text) # Security headers missing_headers = [ h for h, info in 
analysis.get('security_headers', {}).items()
            if info.get('status') == 'MISSING' and info.get('required', False)
        ]
        if missing_headers:
            text_parts.append(f"missing security headers: {', '.join(missing_headers)}")
        # Enhanced features
        enhanced_features = analysis.get('enhanced_features', [])
        text_parts.extend(enhanced_features)
        # Port scan results (may be a dict {'error': ...} on failure, hence the
        # isinstance guard)
        open_ports = analysis.get('port_scan', [])
        if open_ports and isinstance(open_ports, list):
            vulnerable_ports = [p for p in open_ports if p.get('vulnerability_note')]
            if vulnerable_ports:
                text_parts.append(f"vulnerable ports open: {[p['port'] for p in vulnerable_ports]}")
        # SSL information
        ssl_info = analysis.get('ssl_info', {})
        if not ssl_info.get('valid', False):
            text_parts.append("ssl certificate issues")
        return " ".join(text_parts)

    def assess_risk(self, analysis):
        """Enhanced risk assessment.

        Accumulates an additive score from weighted findings (weights are
        hard-coded heuristics), clamps it to 100, and maps it onto a
        五-band level with a display color. Returns
        {'level', 'score', 'color', 'factors'}.
        """
        risk_score = 0
        factors = []
        # Security headers: 12 points per missing *required* header
        missing_headers = sum(
            1 for h, info in analysis.get('security_headers', {}).items()
            if info.get('status') == 'MISSING' and info.get('required', False)
        )
        if missing_headers > 0:
            risk_score += missing_headers * 12
            factors.append(f"Missing {missing_headers} critical security headers")
        # Content analysis indicators
        content = analysis.get('content_analysis', {})
        indicators = content.get('security_indicators', {})
        if indicators.get('php_errors'):
            risk_score += 25
            factors.append("PHP errors exposed to users")
        if indicators.get('database_errors'):
            risk_score += 30
            factors.append("Database errors visible")
        if indicators.get('exposed_emails', 0) > 0:
            risk_score += indicators['exposed_emails'] * 5
            factors.append(f"{indicators['exposed_emails']} email addresses exposed")
        if indicators.get('comments_with_info'):
            risk_score += 20
            factors.append("Sensitive information in HTML comments")
        # Vulnerability predictions: 40/25/15 points per Critical/High/Medium
        vuln_predictions = analysis.get('vulnerability_predictions', [])
        critical_risk_vulns = sum(1 for v in vuln_predictions if v['risk_level'] == 'Critical')
        high_risk_vulns = sum(1 for v in vuln_predictions if v['risk_level'] == 'High')
        medium_risk_vulns = sum(1 for v in vuln_predictions if v['risk_level'] == 'Medium')
        if critical_risk_vulns > 0:
            risk_score += critical_risk_vulns * 40
            factors.append(f"{critical_risk_vulns} critical-risk vulnerabilities predicted")
        if high_risk_vulns > 0:
            risk_score += high_risk_vulns * 25
            factors.append(f"{high_risk_vulns} high-risk vulnerabilities predicted")
        if medium_risk_vulns > 0:
            risk_score += medium_risk_vulns * 15
            factors.append(f"{medium_risk_vulns} medium-risk vulnerabilities predicted")
        # Port scan results: 10 points per port with a vulnerability note
        open_ports = analysis.get('port_scan', [])
        if open_ports and isinstance(open_ports, list):
            vulnerable_ports = [p for p in open_ports if p.get('vulnerability_note')]
            if vulnerable_ports:
                risk_score += len(vulnerable_ports) * 10
                factors.append(f"{len(vulnerable_ports)} potentially vulnerable ports open")
        # SSL issues (note: also fires when ssl_info is empty, e.g. quick paths)
        ssl_info = analysis.get('ssl_info', {})
        if not ssl_info.get('valid', False):
            risk_score += 20
            factors.append("SSL certificate issues detected")
        # Determine risk level
        if risk_score >= 85:
            level, color = "CRITICAL", "#dc2626"
        elif risk_score >= 65:
            level, color = "HIGH", "#ea580c"
        elif risk_score >= 45:
            level, color = "MEDIUM", "#d97706"
        elif risk_score >= 20:
            level, color = "LOW", "#2563eb"
        else:
            level, color = "MINIMAL", "#16a34a"
        return {
            'level': level,
            'score': min(risk_score, 100),
            'color': color,
            'factors': factors
        }

# Initialize analyzer
analyzer = EnhancedPassiveAnalyzer(classifier)

# =============== Enhanced UI Visualization ============
def create_confidence_chart(result):
    """Create enhanced confidence chart with proper visualization"""
    vulns = result.get('vulnerability_predictions', [])
    # Create figure with better styling
    plt.style.use('default')
    fig, ax = plt.subplots(figsize=(14, 8))
    if not vulns:
        # Create a proper empty chart with message
        ax.text(0.5, 0.5, 'No vulnerabilities detected\nAll systems secure!',
                ha='center', va='center', transform=ax.transAxes, fontsize=18,
bbox=dict(boxstyle="round,pad=0.5", facecolor="#d1fae5", edgecolor="#10b981", alpha=0.8))
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.axis('off')
    else:
        # Prepare data for chart (at most 8 predictions are plotted)
        vuln_types = [v['type'] for v in vulns[:8]]
        confidences = [v['confidence'] for v in vulns[:8]]
        thresholds = [v['threshold'] for v in vulns[:8]]
        # Enhanced color coding based on risk level
        colors = []
        risk_colors = {
            'Critical': '#991b1b',
            'High': '#dc2626',
            'Medium': '#ea580c',
            'Low': '#2563eb'
        }
        for v in vulns[:8]:
            colors.append(risk_colors.get(v['risk_level'], '#6b7280'))
        # Create horizontal bar chart
        y_pos = np.arange(len(vuln_types))
        bar_height = 0.6
        # Create main bars
        bars = ax.barh(y_pos, confidences, color=colors, alpha=0.85,
                       height=bar_height, label='Confidence')
        # Add threshold markers: a dashed vertical tick spanning each bar's
        # vertical extent (ymin/ymax are in axes fractions, hence the division).
        for i, (confidence, threshold) in enumerate(zip(confidences, thresholds)):
            ax.axvline(x=threshold,
                       ymin=(i-bar_height/2)/len(vuln_types),
                       ymax=(i+bar_height/2)/len(vuln_types),
                       color='#6b7280', linestyle='--', alpha=0.8, linewidth=2)
            # Add threshold label
            ax.text(threshold + 0.01, i, f'Threshold: {threshold:.0%}',
                    va='center', fontsize=9, color='#6b7280', fontweight='bold')
        # Customize the chart
        ax.set_yticks(y_pos)
        ax.set_yticklabels(vuln_types, fontsize=12, fontweight='bold')
        ax.set_xlabel('Confidence Score', fontsize=14, fontweight='bold', color='#374151')
        ax.set_title('Vulnerability Confidence Analysis', fontsize=16,
                     fontweight='bold', color='#1f2937', pad=20)
        ax.set_xlim(0, 1.1)  # Extra space for labels
        # Remove spines and add grid
        for spine in ['top', 'right']:
            ax.spines[spine].set_visible(False)
        ax.spines['left'].set_color('#d1d5db')
        ax.spines['bottom'].set_color('#d1d5db')
        ax.grid(axis='x', alpha=0.3, linestyle='--', color='#9ca3af')
        ax.set_axisbelow(True)
        # Add value labels on bars with better positioning
        for i, (bar, confidence, threshold) in enumerate(zip(bars, confidences, thresholds)):
            width = bar.get_width()
            label_x = width + 0.02
            label_color = '#1f2937'
            # Add confidence percentage
            ax.text(label_x, bar.get_y() + bar.get_height()/2, f'{confidence:.1%}',
                    ha='left', va='center', fontweight='bold', fontsize=11, color=label_color)
            # Add risk level inside bar if space permits
            if width > 0.15:
                ax.text(width/2, bar.get_y() + bar.get_height()/2, vulns[i]['risk_level'],
                        ha='center', va='center', fontweight='bold', fontsize=10, color='white')
        # Add legend for risk levels
        legend_elements = [
            plt.Rectangle((0,0), 1, 1, facecolor=risk_colors['Critical'], alpha=0.85, label='Critical'),
            plt.Rectangle((0,0), 1, 1, facecolor=risk_colors['High'], alpha=0.85, label='High'),
            plt.Rectangle((0,0), 1, 1, facecolor=risk_colors['Medium'], alpha=0.85, label='Medium'),
            plt.Rectangle((0,0), 1, 1, facecolor=risk_colors['Low'], alpha=0.85, label='Low')
        ]
        ax.legend(handles=legend_elements, loc='lower right', framealpha=0.9)
    plt.tight_layout()
    # Save with higher quality
    chart_path = "conf_chart.png"
    plt.savefig(chart_path, bbox_inches='tight', dpi=150, facecolor='white',
                edgecolor='none', transparent=False, pad_inches=0.1)
    plt.close()
    return chart_path

def format_results(result):
    """Format analysis results with enhanced display"""
    risk = result.get('risk_assessment', {})
    risk_color = risk.get('color', 'gray')
    # Main Card
    # NOTE(review): the HTML template body of this f-string (and of the sibling
    # rendering helpers that followed) was destroyed when the file was exported
    # (tags stripped). The lines after this opener are unreconstructable string
    # residue; recover them from the original source before editing.
    html_output = f"""
Unable to generate confidence chart: {str(e)}
All predictions are below calibrated confidence thresholds or no vulnerabilities were identified in the input.
Please check your input and try again.
No significant features identified in the input.
" def create_enhanced_dashboard(): """Create the enhanced Gradio dashboard""" legal_notice = """ ## ⚠️ Legal & Ethical Notice **Security AI is for authorized security research only.** By using this tool, you agree to: - Only scan sites you own or have explicit permission to test - Comply with all applicable laws and regulations - Not use for malicious purposes or unauthorized testing - Accept full responsibility for your actions **Analysis is PASSIVE only.** No active exploitation or intrusive scanning is performed. Port scanning is limited to common ports and should only be used on authorized systems. """ with gr.Blocks( theme=gr.themes.Soft(primary_hue="blue", secondary_hue="orange"), title="Security AI Vulnerability Analyzer", css=""" .gradio-container { background: linear-gradient(135deg, #e0eaff 0%, #f8fafc 100%); font-family: 'Inter', sans-serif; } .container { max-width: 1400px; margin: 0 auto; } .footer { background: #1f2937; color: white; padding: 30px 20px; border-radius: 12px; margin-top: 30px; } """ ) as interface: gr.Markdown("""Agree to the terms, enter data, and click 'Analyze Security'