HimanshuGoyal2004 committed on
Commit
fbc9b21
·
1 Parent(s): f8c7c4a

updated retriever

Browse files
Files changed (2) hide show
  1. app.py +142 -6
  2. requirements.txt +6 -1
app.py CHANGED
@@ -5,12 +5,16 @@ from typing import Dict, List, Any
5
  import requests
6
  import gradio as gr
7
  from dotenv import load_dotenv
 
 
 
 
8
 
9
  # Load environment variables
10
  load_dotenv()
11
 
12
  class GitHubMCPServer:
13
- """GitHub MCP Server for repository scanning and file access"""
14
 
15
  def __init__(self):
16
  self.github_token = os.getenv("GITHUB_TOKEN")
@@ -21,6 +25,81 @@ class GitHubMCPServer:
21
  "Authorization": f"token {self.github_token}",
22
  "Accept": "application/vnd.github.v3+json"
23
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  def get_repository_info(self, owner: str, repo: str) -> dict:
26
  """Get basic repository information"""
@@ -117,6 +196,52 @@ class GitHubMCPServer:
117
  self._scan_directory_sync(owner, repo, item["path"], extensions, all_files)
118
  except Exception:
119
  pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
  # Initialize the GitHub MCP server
122
  github_server = GitHubMCPServer()
@@ -158,19 +283,30 @@ demo = gr.TabbedInterface(
158
  title="Scan Repository for Code Files",
159
  description="Scan a GitHub repository for code files with specified extensions",
160
  api_name="scan_repository"
 
 
 
 
 
 
 
 
 
 
161
  )
162
  ],
163
  [
164
  "Repository Info",
165
  "File Content",
166
- "Repository Scanner"
 
167
  ],
168
- title="πŸ™ GitHub MCP Server"
169
  )
170
 
171
  if __name__ == "__main__":
172
- print("πŸš€ Starting GitHub MCP Server with Gradio...")
173
- print("πŸ“‘ Server will provide GitHub repository access via MCP")
174
- print("πŸ› οΈ Available tools: repository info, file content, repository scanner")
175
 
176
  demo.launch(mcp_server=True)
 
5
  import requests
6
  import gradio as gr
7
  from dotenv import load_dotenv
8
+ from datasets import load_dataset
9
+ from langchain.docstore.document import Document
10
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
11
+ from langchain_community.retrievers import BM25Retriever
12
 
13
  # Load environment variables
14
  load_dotenv()
15
 
16
  class GitHubMCPServer:
17
+ """GitHub MCP Server for repository scanning, file access, and CVE retrieval"""
18
 
19
  def __init__(self):
20
  self.github_token = os.getenv("GITHUB_TOKEN")
 
25
  "Authorization": f"token {self.github_token}",
26
  "Accept": "application/vnd.github.v3+json"
27
  }
28
+
29
+ # Initialize CVE retriever
30
+ self.cve_retriever = None
31
+ self._initialize_cve_retriever()
32
+
33
def _initialize_cve_retriever(self):
    """Build the BM25 CVE retriever from the Hugging Face ``Baction/cve`` dataset.

    Loads the ``train`` split, turns every record that has a summary into a
    ``Document`` (text body plus ``cve_id`` / ``cwe_code`` / ``cwe_name`` /
    ``cvss`` metadata), splits the documents into chunks and indexes them with
    ``BM25Retriever``. On success the retriever is stored on
    ``self.cve_retriever``; on any failure the error is printed and
    ``self.cve_retriever`` is left as ``None`` so the rest of the server can
    still start.
    """

    def first_present(record, keys, default):
        # Dataset rows may carry a column whose value is None; a plain
        # ``record.get(key, default)`` would return that None instead of
        # falling through to the next candidate key or the default.
        for key in keys:
            value = record.get(key)
            if value is not None:
                return value
        return default

    try:
        print("🔄 Loading CVE dataset from Hugging Face...")

        # Login using `huggingface-cli login` to access this dataset
        dataset = load_dataset("Baction/cve", split="train")

        print(f"📊 Loaded {len(dataset)} CVE records from Hugging Face")

        documents = []
        for idx, record in enumerate(dataset):
            cve_id = first_present(record, ('cve_id',), f'CVE-{idx}')
            cwe_code = first_present(record, ('cwe_code',), 'Unknown')
            cwe_name = first_present(record, ('cwe_name',), 'Unknown')
            cvss_score = first_present(record, ('cvss_score', 'cvss'), 'N/A')
            summary = first_present(record, ('summary', 'description'), None)

            # Skip records without essential information
            if not summary or summary == 'No summary available':
                continue

            # Join explicit lines instead of using an indented triple-quoted
            # literal, so every field label starts at column 0 and can be
            # found again with a simple prefix match when results are shown.
            content = "\n".join([
                f"CVE ID: {cve_id}",
                f"CWE Code: {cwe_code}",
                f"CWE Name: {cwe_name}",
                f"CVSS Score: {cvss_score}",
                f"Summary: {summary}",
            ])

            metadata = {
                'cve_id': str(cve_id),
                'cwe_code': str(cwe_code),
                'cwe_name': str(cwe_name),
                'cvss': cvss_score,
            }

            documents.append(Document(page_content=content.strip(), metadata=metadata))

        print(f"📝 Created {len(documents)} CVE documents")

        # Split documents for better retrieval
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,  # Increased chunk size for better context
            chunk_overlap=50,
            add_start_index=True,
            strip_whitespace=True,
            separators=["\n\n", "\n", ".", " "],
        )
        processed_docs = text_splitter.split_documents(documents)

        # Initialize BM25 retriever
        self.cve_retriever = BM25Retriever.from_documents(
            processed_docs,
            k=10,  # Return top 10 most relevant documents
        )

        print(f"✅ CVE Retriever initialized with {len(processed_docs)} document chunks")

    except Exception as e:
        # Best-effort start-up: report the problem and leave the retriever
        # unset rather than letting the exception propagate.
        print(f"❌ Error initializing CVE retriever: {str(e)}")
        print("💡 Make sure you have access to the Hugging Face dataset 'Baction/cve'")
        print("💡 You may need to login with: huggingface-cli login")
        self.cve_retriever = None
103
 
104
  def get_repository_info(self, owner: str, repo: str) -> dict:
105
  """Get basic repository information"""
 
196
  self._scan_directory_sync(owner, repo, item["path"], extensions, all_files)
197
  except Exception:
198
  pass
199
+
200
def search_cve_database(self, query: str) -> str:
    """Search the CVE knowledge base and return a Markdown-formatted report.

    Args:
        query: Free-text vulnerability description (e.g. "SQL injection").

    Returns:
        A text report with one entry per retrieved chunk (CVE id, CWE code
        and name, CVSS score, summary truncated to 200 characters) followed
        by an analysis summary. If the retriever is not initialized, the
        query is blank, nothing matches, or retrieval raises, a
        human-readable message is returned instead; this method does not
        raise.
    """
    from collections import Counter

    if not self.cve_retriever:
        return "❌ CVE retriever not properly initialized. Please check Hugging Face dataset access."

    # A blank query carries no search terms, so any "matches" would be noise.
    if not query or not query.strip():
        return "❌ Please provide a non-empty vulnerability query."

    try:
        docs = self.cve_retriever.invoke(query)

        if not docs:
            return f"No relevant CVE information found for query: '{query}'"

        parts = [f"🔍 **CVE Knowledge Base Results for: '{query}'**\n\n"]

        for i, doc in enumerate(docs, 1):
            metadata = doc.metadata

            # Find the "Summary:" line. Lines are stripped first because the
            # stored page content may carry leading indentation.
            summary = 'No summary available'
            for line in doc.page_content.splitlines():
                stripped = line.strip()
                if stripped.startswith('Summary:'):
                    summary = stripped[len('Summary:'):].strip() or summary
                    break

            ellipsis = '...' if len(summary) > 200 else ''
            parts.append(f"**Result {i}:**\n")
            parts.append(f"- **CVE ID**: {metadata.get('cve_id', 'Unknown')}\n")
            parts.append(f"- **CWE Code**: {metadata.get('cwe_code', 'Unknown')}\n")
            parts.append(f"- **CWE Name**: {metadata.get('cwe_name', 'Unknown')}\n")
            parts.append(f"- **CVSS Score**: {metadata.get('cvss', 'N/A')}\n")
            parts.append(f"- **Description**: {summary[:200]}{ellipsis}\n")
            parts.append("---\n")

        # Several chunks can belong to the same CVE; list each id only once,
        # in retrieval order.
        cve_ids = list(dict.fromkeys(
            str(doc.metadata['cve_id']) for doc in docs if doc.metadata.get('cve_id')
        ))
        # Rank CWE codes by how often they occur so "common" is meaningful
        # and the output order is deterministic.
        cwe_counts = Counter(
            str(doc.metadata['cwe_code'])
            for doc in docs
            if doc.metadata.get('cwe_code') and doc.metadata.get('cwe_code') != 'Unknown'
        )
        common_cwes = [code for code, _ in cwe_counts.most_common(5)]

        parts.append("\n**📊 Analysis Summary:**\n")
        parts.append(f"- **CVE Examples**: {', '.join(cve_ids[:3])}{'...' if len(cve_ids) > 3 else ''}\n")
        parts.append(f"- **Common CWE Codes**: {', '.join(common_cwes)}\n")
        parts.append(f"- **Total Matches**: {len(docs)}\n")

        return "".join(parts)

    except Exception as e:
        # Tool boundary: report the failure as text instead of raising.
        return f"❌ Error retrieving CVE information: {str(e)}"
245
 
246
  # Initialize the GitHub MCP server
247
  github_server = GitHubMCPServer()
 
283
  title="Scan Repository for Code Files",
284
  description="Scan a GitHub repository for code files with specified extensions",
285
  api_name="scan_repository"
286
+ ),
287
+ gr.Interface(
288
+ fn=github_server.search_cve_database,
289
+ inputs=[
290
+ gr.Textbox(label="Vulnerability Query", placeholder="SQL injection, XSS, command injection, etc.")
291
+ ],
292
+ outputs=gr.Textbox(label="CVE Search Results", lines=25),
293
+ title="Search CVE Database",
294
+ description="Search the CVE knowledge base for vulnerability patterns and CWE information",
295
+ api_name="search_cve_database"
296
  )
297
  ],
298
  [
299
  "Repository Info",
300
  "File Content",
301
+ "Repository Scanner",
302
+ "CVE Database"
303
  ],
304
+ title="πŸ™ GitHub MCP Server with CVE Knowledge Base"
305
  )
306
 
307
if __name__ == "__main__":
    # Print a short startup banner, then launch the Gradio app with its MCP
    # server flag enabled.
    print("🚀 Starting GitHub MCP Server with CVE Knowledge Base...")
    print("📡 Server will provide GitHub repository access and CVE search via MCP")
    print("🛠️ Available tools: repository info, file content, repository scanner, CVE database search")

    demo.launch(mcp_server=True)
requirements.txt CHANGED
@@ -5,4 +5,9 @@ mcp==1.10.1
5
  smolagents>=0.1.0
6
  requests>=2.28.0
7
  python-dotenv>=1.0.0
8
- pydantic>=2.11,<2.12
 
 
 
 
 
 
5
  smolagents>=0.1.0
6
  requests>=2.28.0
7
  python-dotenv>=1.0.0
8
+ pydantic>=2.11,<2.12
9
+ datasets>=2.0.0
10
+ langchain>=0.1.0
11
+ langchain-community>=0.0.20
12
+ sentence-transformers>=2.2.0
13
+ rank-bm25>=0.2.2