import os

try:
    import chromadb
except ImportError:  # Degrade gracefully: CVEVectorStorage.__init__ disables itself instead.
    chromadb = None


class CVEVectorStorage:
    """Semantic storage and similarity search for CVE records, backed by ChromaDB.

    Each CVE description (optionally combined with an AI-generated summary) is
    embedded with a SentenceTransformer model and stored in a persistent Chroma
    collection keyed by CVE ID. If any dependency fails to load, the instance is
    created with ``enabled = False`` and every public method becomes a no-op that
    returns an empty result.
    """

    # CVEs fetched to seed a (nearly) empty collection so similarity search has
    # something to compare against.
    EXAMPLE_CVE_IDS = (
        'CVE-2021-44228',  # Log4Shell
        'CVE-2021-4104',   # Log4j 1.x
        'CVE-2022-22965',  # Spring4Shell
        'CVE-2020-1472',   # Zerologon
        'CVE-2021-34527',  # PrintNightmare
        'CVE-2020-0601',   # CurveBall
        'CVE-2021-45046',  # Log4j incomplete fix
        'CVE-2020-5421',   # Spring Framework
        'CVE-2021-36934',  # HiveNightmare
        'CVE-2022-21999',  # Windows Print Spooler
        'CVE-2020-1350',   # SIGRed (Windows DNS Server RCE)
        'CVE-2019-0708',   # BlueKeep (Windows RDP RCE)
    )

    # Metric keys tried in order by _get_severity. CVSS v3.1 comes first so the
    # result is unchanged for every record that carries a v3.1 score.
    _SEVERITY_METRIC_KEYS = ('cvssMetricV31', 'cvssMetricV40', 'cvssMetricV30', 'cvssMetricV2')

    def __init__(self, persist_path="./data/chroma_db", collection_name="cve_collection",
                 model_name='all-MiniLM-L6-v2'):
        """Open the persistent collection and load the embedding model.

        Args:
            persist_path: Directory handed to ChromaDB's ``PersistentClient``.
            collection_name: Name of the Chroma collection holding CVE documents.
            model_name: SentenceTransformer model used to embed text.

        Any failure (missing package, model download error, ...) is printed and
        leaves the instance disabled rather than raising.
        """
        self.enabled = False
        try:
            if chromadb is None:
                raise ImportError("chromadb is not installed")
            # Initialise ChromaDB with persistent storage
            self.client = chromadb.PersistentClient(path=persist_path)
            self.collection = self.client.get_or_create_collection(name=collection_name)
            # Imported lazily so a missing package disables storage instead of
            # breaking the import of this module.
            from sentence_transformers import SentenceTransformer
            self.model = SentenceTransformer(model_name)
            self.enabled = True
        except Exception as e:
            print(f"Vector storage disabled: {e}")
            self.enabled = False

    def add_cve(self, cve_id, cve_data, ai_summary=None):
        """Store (or replace) a CVE in the vector database for similarity search.

        Args:
            cve_id: Chroma document id, also stored as ``metadata['cve_id']``.
            cve_data: Mapping with a ``'cve'`` entry holding ``descriptions`` and,
                optionally, ``published`` and ``metrics`` (NVD API 2.0 layout assumed).
            ai_summary: Optional text appended to the description before embedding.

        Returns:
            True if the CVE was stored, False if storage is disabled or failed.
            Failures are printed as warnings, never raised.
        """
        if not self.enabled:
            return False
        try:
            cve = cve_data['cve']
            description = self._get_description(cve)
            # Combine description with AI summary for better search
            text = f"{description} {ai_summary}" if ai_summary else description
            embedding = self.model.encode(text).tolist()
            # Metadata is returned with query hits and shown to the user.
            metadata = {
                "cve_id": cve_id,
                "published": (cve.get('published') or '')[:10],
                "severity": self._get_severity(cve),
                "description": description[:200],
            }
            # upsert replaces any existing entry with the same id, so re-analysing
            # a CVE never creates duplicates.
            self.collection.upsert(
                embeddings=[embedding],
                documents=[text],
                metadatas=[metadata],
                ids=[cve_id],
            )
            return True
        except Exception as e:
            print(f"Warning: Vector storage failed for {cve_id}: {e}")
            return False

    def find_similar_cves(self, cve_id, top_k=3, similarity_threshold=0.001):
        """Find stored CVEs whose text is semantically close to the stored text of *cve_id*.

        If fewer than 3 CVEs are stored, the collection is first seeded with
        EXAMPLE_CVE_IDS. Returns at most *top_k* result dicts (see _format_results),
        never including *cve_id* itself, or [] when *cve_id* is not stored, storage
        is disabled, or the search fails.
        """
        if not self.enabled or top_k <= 0:
            return []
        try:
            count_result = self.collection.count()
            # Too few stored CVEs for a meaningful comparison: seed with examples.
            if count_result < 3:
                self._populate_example_cves(cve_id)
                count_result = self.collection.count()
                if count_result < 2:
                    return []

            target_doc = self.collection.get(ids=[cve_id], include=['documents'])
            if not target_doc['documents']:
                return []
            target_embedding = self.model.encode(target_doc['documents'][0]).tolist()

            # Request one extra hit because the target itself is normally the nearest.
            results = self.collection.query(
                query_embeddings=[target_embedding],
                n_results=min(top_k + 1, count_result),
                include=['metadatas', 'distances'],
            )
            return self._format_results(results, similarity_threshold, exclude_id=cve_id)[:top_k]
        except Exception as e:
            print(f"AI similarity search failed: {e}")
            return []

    def _existing_ids(self, ids):
        """Return the subset of *ids* already stored; an empty set if the lookup fails."""
        try:
            existing = self.collection.get(ids=list(ids))
        except Exception as e:
            print(f"Warning: could not list existing CVEs: {e}")
            return set()
        return set(existing.get('ids') or [])

    def _populate_example_cves(self, target_cve_id=None):
        """Seed the collection with every CVE in EXAMPLE_CVE_IDS that is not yet stored.

        Args:
            target_cve_id: Accepted for compatibility with existing callers; unused.

        Each CVE is fetched via ``api_client.fetch_cve_simple``; per-CVE failures are
        printed and skipped, and any other failure is printed rather than raised.
        """
        try:
            from api_client import fetch_cve_simple

            print("Pre-populating vector database with ALL example CVEs...")
            existing_cves = self._existing_ids(self.EXAMPLE_CVE_IDS)

            for cve_id in self.EXAMPLE_CVE_IDS:
                if cve_id in existing_cves:
                    print(f" • {cve_id} already exists")
                    continue
                try:
                    cve_data = fetch_cve_simple(cve_id)
                    if not cve_data:
                        continue
                    # Add a basic summary so example documents resemble analysed CVEs.
                    description = self._get_description(cve_data['cve'])
                    words = description.split()
                    subject = words[0] if words else 'software'
                    basic_summary = f"Security vulnerability affecting {subject}. This is a critical vulnerability that requires immediate attention and proper security measures to mitigate potential risks."
                    if self.add_cve(cve_id, cve_data, basic_summary):
                        print(f" • Added {cve_id}")
                except Exception as e:
                    print(f" • Failed to add {cve_id}: {e}")

            print("Vector database ready with ALL examples!")
        except Exception as e:
            print(f"Failed to populate example CVEs: {e}")

    def search_by_text(self, query_text, top_k=5, similarity_threshold=0.3):
        """Search stored CVEs with a natural-language query.

        Returns at most *top_k* result dicts (see _format_results) whose score is at
        least *similarity_threshold*, or [] when storage is disabled, empty, or the
        search fails.
        """
        if not self.enabled or top_k <= 0:
            return []
        try:
            count_result = self.collection.count()
            if count_result == 0:
                return []
            query_embedding = self.model.encode(query_text).tolist()
            # Over-fetch so threshold filtering can still leave up to top_k hits.
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=min(top_k * 2, count_result),
                include=['metadatas', 'distances'],
            )
            return self._format_results(results, similarity_threshold)[:top_k]
        except Exception as e:
            print(f"AI semantic search failed: {e}")
            return []

    @staticmethod
    def _format_results(results, similarity_threshold, exclude_id=None):
        """Convert a single-query Chroma result into a list of display dicts.

        Each distance d becomes a score max(0, 1 - d). Hits scoring below
        *similarity_threshold*, or whose ``cve_id`` equals *exclude_id*, are dropped.
        Each dict has keys ``cve_id``, ``similarity`` (score formatted to 2 decimals),
        ``severity`` and ``description``.
        """
        # NOTE(review): the collection is created without an "hnsw:space" setting and
        # Chroma's documented default is squared L2, so 1 - d is not a cosine
        # similarity — confirm before relying on absolute score values/thresholds.
        formatted = []
        if not results or not results.get('metadatas'):
            return formatted
        for metadata, distance in zip(results['metadatas'][0], results['distances'][0]):
            score = max(0, 1 - distance)
            if metadata['cve_id'] == exclude_id or score < similarity_threshold:
                continue
            formatted.append({
                'cve_id': metadata['cve_id'],
                'similarity': f"{score:.2f}",
                'severity': metadata.get('severity', 'Unknown'),
                'description': metadata.get('description', ''),
            })
        return formatted

    @staticmethod
    def _get_description(cve):
        """Return the English description of *cve*, falling back to the first listed.

        Raises:
            ValueError: if the record has no descriptions.
            KeyError: if the chosen description entry has no ``'value'``.
        """
        descriptions = cve.get('descriptions') or []
        if not descriptions:
            raise ValueError("CVE record has no descriptions")
        for entry in descriptions:
            if entry.get('lang') == 'en':
                return entry['value']
        return descriptions[0]['value']

    def _get_severity(self, cve):
        """Return the CVSS base severity of *cve*, or "Unknown" if none can be found.

        Metric versions are tried in _SEVERITY_METRIC_KEYS order. CVSS v2 entries
        keep ``baseSeverity`` beside (not inside) ``cvssData``, so both are checked.
        """
        for key in self._SEVERITY_METRIC_KEYS:
            try:
                entry = cve['metrics'][key][0]
                severity = entry.get('cvssData', {}).get('baseSeverity') or entry.get('baseSeverity')
            except (AttributeError, IndexError, KeyError, TypeError):
                continue
            if severity:
                return severity
        return "Unknown"


# Global instance for easy access
vector_storage = CVEVectorStorage()