Spaces:
Sleeping
Sleeping
| import chromadb | |
| import os | |
class CVEVectorStorage:
    """Persist CVE descriptions in a ChromaDB collection and search them by embedding.

    Each CVE is stored under its CVE ID with an embedding of its description
    (optionally extended with an AI summary) plus a small metadata record
    (``cve_id``, ``published``, ``severity``, ``description``).

    If ChromaDB or the sentence-transformer model cannot be initialised, the
    instance sets ``enabled = False`` and every public method becomes a no-op
    that returns ``None`` / ``[]`` instead of raising.

    NOTE(review): similarity is reported as ``max(0, 1 - distance)``. That is
    only a true similarity if the collection uses cosine distance; Chroma
    collections created without an explicit space are documented to use
    squared L2. Confirm the intended metric before relying on the absolute
    scores or thresholds.
    """

    # NVD metric blocks consulted for a severity label, in priority order.
    # CVSS v3.1 stays first so records that already resolved keep their value.
    _SEVERITY_METRIC_KEYS = (
        "cvssMetricV31",
        "cvssMetricV40",
        "cvssMetricV30",
        "cvssMetricV2",
    )

    # Well-known CVEs used to seed an (almost) empty collection for demos.
    _EXAMPLE_CVES = (
        'CVE-2021-44228',  # Log4Shell
        'CVE-2021-4104',   # Log4j 1.x
        'CVE-2022-22965',  # Spring4Shell
        'CVE-2020-1472',   # Zerologon
        'CVE-2021-34527',  # PrintNightmare
        'CVE-2020-0601',   # CurveBall
        'CVE-2021-45046',  # Log4j incomplete fix
        'CVE-2020-5421',   # Spring Framework
        'CVE-2021-36934',  # HiveNightmare
        'CVE-2022-21999',  # Windows Print Spooler
        'CVE-2020-1350',   # SIGRed (Windows DNS Server)
        'CVE-2019-0708',   # BlueKeep (Windows RDP)
    )

    def __init__(self):
        """Open the persistent collection and load the embedding model.

        Any failure (missing package, unreadable storage path, model download
        error, ...) is reported on stdout and leaves the instance disabled.
        """
        try:
            # Initialise ChromaDB with persistent storage
            self.client = chromadb.PersistentClient(path="./data/chroma_db")
            self.collection = self.client.get_or_create_collection(name="cve_collection")
            # Imported lazily so a missing package only disables this feature
            from sentence_transformers import SentenceTransformer
            self.model = SentenceTransformer('all-MiniLM-L6-v2')
            self.enabled = True
        except Exception as e:
            print(f"Vector storage disabled: {e}")
            self.enabled = False

    def add_cve(self, cve_id, cve_data, ai_summary=None):
        """Store (or replace) one CVE in the vector database.

        Args:
            cve_id: Identifier used both as the Chroma record ID and in metadata.
            cve_data: Mapping with a ``'cve'`` entry that holds ``'descriptions'``
                (list of mappings with a ``'value'``) and optionally
                ``'published'`` and ``'metrics'``.
            ai_summary: Optional text appended to the description before
                embedding, to give the search more context.

        Returns:
            None. Failures are printed as a warning and otherwise ignored so a
            storage problem never breaks the caller.
        """
        if not self.enabled:
            return
        try:
            cve = cve_data['cve']
            description = cve['descriptions'][0]['value']
            # Combine description with AI summary for better search
            text = f"{description} {ai_summary}" if ai_summary else description
            embedding = self.model.encode(text).tolist()
            metadata = {
                "cve_id": cve_id,
                # `or ''` also covers an explicit None, which .get's default does not
                "published": (cve.get('published') or '')[:10],
                "severity": self._get_severity(cve),
                "description": description[:200],
            }
            # upsert replaces an existing record with the same ID in one call,
            # instead of a delete-then-add pair with silenced errors.
            self.collection.upsert(
                embeddings=[embedding],
                documents=[text],
                metadatas=[metadata],
                ids=[cve_id],
            )
        except Exception as e:
            print(f"Warning: Vector storage failed for {cve_id}: {e}")

    def find_similar_cves(self, cve_id, top_k=3, similarity_threshold=0.001):
        """Return up to ``top_k`` stored CVEs most similar to ``cve_id``.

        If the collection holds fewer than three records it is first seeded
        with the example CVEs. The target CVE must already be stored; its
        stored document text is re-embedded and used as the query.

        Args:
            cve_id: ID of the CVE to compare against (excluded from results).
            top_k: Maximum number of matches; values <= 0 yield ``[]``.
            similarity_threshold: Minimum ``max(0, 1 - distance)`` score.

        Returns:
            A list of dicts with ``cve_id``, ``similarity`` (two-decimal
            string), ``severity`` and ``description``; ``[]`` when disabled,
            when the target is unknown, or on any error.
        """
        if not self.enabled or top_k <= 0:
            return []
        try:
            total = self.collection.count()
            # Need at least 3 CVEs for a meaningful similarity demo
            if total < 3:
                self._populate_example_cves(cve_id)
                total = self.collection.count()
            if total < 2:
                return []

            stored = self.collection.get(ids=[cve_id], include=['documents'])
            if not stored['documents']:
                return []
            target_embedding = self.model.encode(stored['documents'][0]).tolist()

            # Ask for one extra hit because the target itself is expected to match
            results = self.collection.query(
                query_embeddings=[target_embedding],
                n_results=min(top_k + 1, total),
                include=['metadatas', 'distances'],
            )
            matches = self._format_matches(results, similarity_threshold, exclude_id=cve_id)
            return matches[:top_k]
        except Exception as e:
            print(f"AI similarity search failed: {e}")
            return []

    def _populate_example_cves(self, target_cve_id=None):
        """Seed the collection with the example CVEs that are not stored yet.

        Args:
            target_cve_id: Accepted for interface compatibility; currently not
                used when choosing which examples to load.

        Each missing example is fetched through ``api_client.fetch_cve_simple``
        and stored with a generic summary. Per-CVE failures are printed and
        skipped; a failure of the whole step is printed and swallowed.
        """
        try:
            from api_client import fetch_cve_simple

            print("Pre-populating vector database with ALL example CVEs...")

            # Collect IDs already present; only metadata is needed for that
            existing_cves = set()
            try:
                existing = self.collection.get(include=['metadatas'])
                if existing and 'metadatas' in existing:
                    existing_cves = {
                        meta.get('cve_id')
                        for meta in (existing['metadatas'] or [])
                        if meta and meta.get('cve_id')
                    }
            except Exception:
                # Best effort: if the lookup fails, re-adding simply overwrites
                pass

            for cve_id in self._EXAMPLE_CVES:
                if cve_id in existing_cves:
                    print(f" • {cve_id} already exists")
                    continue
                try:
                    cve_data = fetch_cve_simple(cve_id)
                    if cve_data:
                        # Add a basic summary to match format with analyzed CVEs
                        description = cve_data['cve']['descriptions'][0]['value']
                        # split() is empty for blank/whitespace-only descriptions
                        words = description.split() if description else []
                        subject = words[0] if words else 'software'
                        basic_summary = f"Security vulnerability affecting {subject}. This is a critical vulnerability that requires immediate attention and proper security measures to mitigate potential risks."
                        self.add_cve(cve_id, cve_data, basic_summary)
                        print(f" • Added {cve_id}")
                except Exception as e:
                    print(f" • Failed to add {cve_id}: {e}")
                    continue
            print("Vector database ready with ALL examples!")
        except Exception as e:
            print(f"Failed to populate example CVEs: {e}")

    def search_by_text(self, query_text, top_k=5, similarity_threshold=0.3):
        """Search stored CVEs with a natural-language query.

        Args:
            query_text: Free-text query; blank or non-string input yields ``[]``.
            top_k: Maximum number of matches; values <= 0 yield ``[]``.
            similarity_threshold: Minimum ``max(0, 1 - distance)`` score.

        Returns:
            A list of dicts with ``cve_id``, ``similarity`` (two-decimal
            string), ``severity`` and ``description``; ``[]`` when disabled,
            when the collection is empty, or on any error.
        """
        if not self.enabled or top_k <= 0:
            return []
        if not isinstance(query_text, str) or not query_text.strip():
            return []
        try:
            total = self.collection.count()
            if total == 0:
                return []
            query_embedding = self.model.encode(query_text).tolist()
            # Over-fetch so threshold filtering can still leave top_k results
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=min(top_k * 2, total),
                include=['metadatas', 'distances'],
            )
            return self._format_matches(results, similarity_threshold)[:top_k]
        except Exception as e:
            print(f"AI semantic search failed: {e}")
            return []

    @staticmethod
    def _format_matches(results, similarity_threshold, exclude_id=None):
        """Convert a Chroma query result into the public match dictionaries.

        Reads the first (only) query's ``metadatas`` and ``distances`` rows,
        skips rows without usable metadata, the optional ``exclude_id``, and
        anything scoring below ``similarity_threshold``.
        """
        if not results or not results.get('metadatas') or not results.get('distances'):
            return []
        matches = []
        for metadata, distance in zip(results['metadatas'][0], results['distances'][0]):
            if not metadata:
                continue
            match_id = metadata.get('cve_id')
            if not match_id or match_id == exclude_id:
                continue
            similarity_score = max(0, 1 - distance)
            if similarity_score >= similarity_threshold:
                matches.append({
                    'cve_id': match_id,
                    'similarity': f"{similarity_score:.2f}",
                    'severity': metadata.get('severity', 'Unknown'),
                    'description': metadata.get('description', ''),
                })
        return matches

    def _get_severity(self, cve):
        """Extract a CVSS severity label from a CVE record.

        Looks at the first entry of each metric block listed in
        ``_SEVERITY_METRIC_KEYS`` and returns the first ``baseSeverity`` found,
        checking both inside ``cvssData`` and directly on the metric entry
        (NVD's v2 block is understood to place it beside ``cvssData``).

        Returns:
            The severity string, or ``"Unknown"`` when none is available.
        """
        metrics = cve.get('metrics') if isinstance(cve, dict) else None
        if not isinstance(metrics, dict):
            return "Unknown"
        for key in self._SEVERITY_METRIC_KEYS:
            entries = metrics.get(key)
            if not isinstance(entries, list) or not entries:
                continue
            primary = entries[0]
            if not isinstance(primary, dict):
                continue
            cvss_data = primary.get('cvssData')
            severity = cvss_data.get('baseSeverity') if isinstance(cvss_data, dict) else None
            severity = severity or primary.get('baseSeverity')
            if isinstance(severity, str) and severity:
                return severity
        return "Unknown"
# Global instance for easy access.
# NOTE: this runs CVEVectorStorage.__init__ at import time, which opens the
# persistent ChromaDB store and loads the embedding model; on failure the
# instance is created in its disabled state rather than raising.
vector_storage = CVEVectorStorage()