# CVE_Analysis_Agent / vector_storage.py
# ChienChung's picture
# Upload 11 files
# be9311f verified
import chromadb
import os
class CVEVectorStorage:
    """Store CVE records in a persistent ChromaDB collection for similarity search.

    Each CVE is embedded with a sentence-transformers model and stored under
    its CVE ID together with a small metadata record (ID, publication date,
    severity, truncated description).

    If the database or the embedding model cannot be initialised, the instance
    stays usable but disabled: ``enabled`` is ``False``, writes are skipped and
    searches return an empty list.
    """

    # Well-known CVEs used to seed a nearly empty collection so that
    # similarity search has something to compare against.
    _EXAMPLE_CVE_IDS = (
        'CVE-2021-44228',  # Log4Shell
        'CVE-2021-4104',   # Log4j 1.x
        'CVE-2022-22965',  # Spring4Shell
        'CVE-2020-1472',   # Zerologon
        'CVE-2021-34527',  # PrintNightmare
        'CVE-2020-0601',   # CurveBall
        'CVE-2021-45046',  # Log4j incomplete fix
        'CVE-2020-5421',   # Spring Framework
        'CVE-2021-36934',  # HiveNightmare
        'CVE-2022-21999',  # Windows Print Spooler
        'CVE-2020-1350',   # SIGRed (Windows DNS Server)
        'CVE-2019-0708',   # BlueKeep (Windows RDP)
    )

    # Metric keys tried in order when extracting a severity. v3.1 stays first
    # so records that carry a v3.1 score keep the value they always had; the
    # other versions are only fallbacks for records without one.
    _CVSS_METRIC_KEYS = (
        'cvssMetricV31',
        'cvssMetricV40',
        'cvssMetricV30',
        'cvssMetricV2',
    )

    def __init__(self, db_path="./data/chroma_db",
                 collection_name="cve_collection",
                 model_name='all-MiniLM-L6-v2'):
        """Open the persistent collection and load the embedding model.

        Args:
            db_path: Directory of the persistent ChromaDB store.
            collection_name: Collection to open, created if missing.
            model_name: sentence-transformers model used for embeddings.

        Never raises: any initialisation error is printed and leaves the
        instance disabled.
        """
        # Define every attribute up front so a failed initialisation leaves
        # a consistent (disabled) object instead of missing attributes.
        self.client = None
        self.collection = None
        self.model = None
        self.enabled = False
        try:
            # Initialise ChromaDB with persistent storage
            self.client = chromadb.PersistentClient(path=db_path)
            self.collection = self.client.get_or_create_collection(
                name=collection_name
            )
            # Imported lazily so a missing package only disables the storage
            from sentence_transformers import SentenceTransformer
            self.model = SentenceTransformer(model_name)
            self.enabled = True
        except Exception as e:
            print(f"Vector storage disabled: {e}")
            self.enabled = False

    def add_cve(self, cve_id, cve_data, ai_summary=None):
        """Store (or replace) a CVE in the vector database.

        Args:
            cve_id: Identifier used both as the record ID and in the metadata.
            cve_data: Mapping with a ``'cve'`` entry holding ``descriptions``
                and, optionally, ``published`` and ``metrics``.
            ai_summary: Optional text appended to the description before
                embedding, to enrich the searchable text.

        Returns:
            ``True`` if the record was written, ``False`` if storage is
            disabled or the write failed (a warning is printed in that case).
        """
        if not self.enabled:
            return False
        try:
            cve = cve_data['cve']
            description = self._get_description(cve)
            # Combine description with AI summary for better search
            text = f"{description} {ai_summary}" if ai_summary else description
            embedding = self.model.encode(text).tolist()
            # Metadata returned with search hits
            metadata = {
                "cve_id": cve_id,
                # `or ''` also covers an explicit None value
                "published": (cve.get('published') or '')[:10],
                "severity": self._get_severity(cve),
                "description": description[:200],
            }
            # upsert replaces an existing record with the same ID, so no
            # separate delete-then-add is needed to avoid duplicates.
            self.collection.upsert(
                ids=[cve_id],
                embeddings=[embedding],
                documents=[text],
                metadatas=[metadata],
            )
            return True
        except Exception as e:
            print(f"Warning: Vector storage failed for {cve_id}: {e}")
            return False

    def find_similar_cves(self, cve_id, top_k=3, similarity_threshold=0.001):
        """Find stored CVEs similar to the stored CVE with the given ID.

        Args:
            cve_id: ID of a CVE that is already in the collection.
            top_k: Maximum number of results; values <= 0 yield ``[]``.
            similarity_threshold: Minimum score for a result to be kept.

        Returns:
            A list of dicts with ``cve_id``, ``similarity`` (string with two
            decimals), ``severity`` and ``description``. The target CVE itself
            is excluded. Returns ``[]`` when storage is disabled, the target
            is not stored, or the search fails.
        """
        if not self.enabled or top_k <= 0:
            return []
        try:
            total = self.collection.count()
            # Seed the collection with example CVEs when it holds too few
            # records for a meaningful comparison.
            if total < 3:
                self._populate_example_cves(cve_id)
                total = self.collection.count()
            if total < 2:
                return []

            # The query text is the document stored for the target CVE
            target = self.collection.get(ids=[cve_id], include=['documents'])
            documents = target.get('documents') if target else None
            if not documents:
                return []
            target_embedding = self.model.encode(documents[0]).tolist()

            # Ask for one extra hit because the target matches itself
            results = self.collection.query(
                query_embeddings=[target_embedding],
                n_results=min(top_k + 1, total),
                include=['metadatas', 'distances'],
            )
            hits = self._format_results(
                results, similarity_threshold, exclude_id=cve_id
            )
            return hits[:top_k]
        except Exception as e:
            print(f"AI similarity search failed: {e}")
            return []

    def _populate_example_cves(self, target_cve_id=None):
        """Add the example CVEs that are not yet stored.

        Each missing CVE is fetched with ``api_client.fetch_cve_simple`` and
        stored with a generic summary. Failures are printed and skipped.

        Args:
            target_cve_id: Accepted for interface compatibility; not used.
        """
        try:
            from api_client import fetch_cve_simple

            print("Pre-populating vector database with ALL example CVEs...")

            # Only look up the example IDs rather than loading every record
            existing_ids = set()
            try:
                existing = self.collection.get(ids=list(self._EXAMPLE_CVE_IDS))
                if existing:
                    existing_ids = set(existing.get('ids') or [])
            except Exception as e:
                # Best effort: without this lookup every example is re-stored
                print(f" • Could not check existing CVEs: {e}")

            for cve_id in self._EXAMPLE_CVE_IDS:
                if cve_id in existing_ids:
                    print(f" • {cve_id} already exists")
                    continue
                try:
                    cve_data = fetch_cve_simple(cve_id)
                    if not cve_data:
                        continue
                    # Add a basic summary to match format with analyzed CVEs
                    description = self._get_description(cve_data['cve'])
                    words = description.split()
                    # Guard against empty / whitespace-only descriptions
                    subject = words[0] if words else 'software'
                    basic_summary = (
                        f"Security vulnerability affecting {subject}. "
                        "This is a critical vulnerability that requires "
                        "immediate attention and proper security measures "
                        "to mitigate potential risks."
                    )
                    # add_cve prints its own warning on failure; only report
                    # success when the record was actually written.
                    if self.add_cve(cve_id, cve_data, basic_summary):
                        print(f" • Added {cve_id}")
                except Exception as e:
                    print(f" • Failed to add {cve_id}: {e}")
                    continue
            print("Vector database ready with ALL examples!")
        except Exception as e:
            print(f"Failed to populate example CVEs: {e}")

    def search_by_text(self, query_text, top_k=5, similarity_threshold=0.3):
        """Search stored CVEs with a natural-language query.

        Args:
            query_text: Free text to embed and search for.
            top_k: Maximum number of results; values <= 0 yield ``[]``.
            similarity_threshold: Minimum score for a result to be kept.

        Returns:
            A list of result dicts in the same format as
            ``find_similar_cves``; ``[]`` when storage is disabled, the
            collection is empty, or the search fails.
        """
        if not self.enabled or top_k <= 0:
            return []
        try:
            total = self.collection.count()
            if total == 0:
                return []
            query_embedding = self.model.encode(query_text).tolist()
            # Over-fetch so threshold filtering can still leave top_k hits
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=min(top_k * 2, total),
                include=['metadatas', 'distances'],
            )
            return self._format_results(results, similarity_threshold)[:top_k]
        except Exception as e:
            print(f"AI semantic search failed: {e}")
            return []

    @staticmethod
    def _format_results(results, similarity_threshold, exclude_id=None):
        """Convert a query result into a list of result dicts.

        Hits scoring below ``similarity_threshold`` are dropped, as is the hit
        whose ``cve_id`` equals ``exclude_id`` (when given).
        """
        if not results:
            return []
        metadatas = results.get('metadatas')
        distances = results.get('distances')
        if not metadatas or not distances:
            return []

        formatted = []
        # Index 0: results for the single query embedding that was sent
        for metadata, distance in zip(metadatas[0], distances[0]):
            # NOTE(review): 1 - distance is a cosine similarity only when the
            # collection uses the cosine space; the collection is created
            # without an explicit distance setting — confirm the metric.
            score = max(0, 1 - distance)
            if score < similarity_threshold:
                continue
            if exclude_id is not None and metadata['cve_id'] == exclude_id:
                continue
            formatted.append({
                'cve_id': metadata['cve_id'],
                'similarity': f"{score:.2f}",
                'severity': metadata.get('severity', 'Unknown'),
                'description': metadata.get('description', ''),
            })
        return formatted

    @staticmethod
    def _get_description(cve):
        """Return the English description, or the first one if none is English.

        Raises ``KeyError``/``IndexError`` when the record has no
        descriptions; callers handle that in their own ``try`` blocks.
        """
        descriptions = cve['descriptions']
        for entry in descriptions:
            if entry.get('lang') == 'en':
                return entry['value']
        return descriptions[0]['value']

    def _get_severity(self, cve):
        """Extract the CVSS base severity from CVE data.

        Tries each key in ``_CVSS_METRIC_KEYS`` in order and returns the first
        severity found, or ``"Unknown"`` when none is present.
        """
        metrics = cve.get('metrics') if isinstance(cve, dict) else None
        if not isinstance(metrics, dict):
            return "Unknown"
        for key in self._CVSS_METRIC_KEYS:
            entries = metrics.get(key)
            if not isinstance(entries, list) or not entries:
                continue
            entry = entries[0]
            if not isinstance(entry, dict):
                continue
            severity = None
            cvss_data = entry.get('cvssData')
            if isinstance(cvss_data, dict):
                severity = cvss_data.get('baseSeverity')
            # Some entries carry baseSeverity next to cvssData instead of
            # inside it, so check the metric entry itself as well.
            severity = severity or entry.get('baseSeverity')
            if severity:
                return severity
        return "Unknown"
# Global instance for easy access.
# NOTE: this runs CVEVectorStorage.__init__ at import time. If initialisation
# fails, the constructor prints a message and leaves the instance with
# enabled = False instead of raising.
vector_storage = CVEVectorStorage()