Rithvickkr commited on
Commit
6e9c0fd
·
1 Parent(s): 21b1e62

Integrated NVD API via Space variables, fixed multi-threat detection and irrelevant CVEs

Browse files
Files changed (2) hide show
  1. .gitignore +1 -0
  2. app.py +92 -69
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ corpus/cache/
app.py CHANGED
@@ -1,7 +1,5 @@
1
  import gradio as gr
2
  import requests
3
- from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings
4
- from llama_index.embeddings.huggingface import HuggingFaceEmbedding
5
  import os
6
  import re
7
  import ast
@@ -12,32 +10,20 @@ import time
12
  import logging
13
  from retrying import retry
14
  import base64
 
15
 
16
  # Set up logging
17
  logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
18
  logger = logging.getLogger(__name__)
19
 
20
- # Suppress Hugging Face symlink warning
21
  os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
22
 
23
  # Modal Mistral-7B API endpoint
24
- MODAL_API = "https://rithvickkumar27--mistral-7b-api-analyze.modal.run"
25
-
26
- # Configure LlamaIndex
27
- Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
28
-
29
- # Initialize LlamaIndex
30
- def init_llama_index():
31
- try:
32
- documents = SimpleDirectoryReader("corpus", filename_as_id=True).load_data()
33
- logger.info(f"Loaded {len(documents)} corpus documents")
34
- return VectorStoreIndex.from_documents(documents)
35
- except Exception as e:
36
- logger.error(f"Error loading corpus: {e}")
37
- return None
38
-
39
- index = init_llama_index()
40
- query_engine = index.as_retriever() if index else None
41
 
42
  # Retry decorator for Mistral-7B API
43
  @retry(stop_max_attempt_number=3, wait_fixed=2000)
@@ -61,25 +47,63 @@ def call_mistral_llm(prompt):
61
  logger.error(f"Mistral API request failed: {e}")
62
  raise
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  # Basic Python code analysis
65
  def analyze_python_code(content: str) -> dict:
66
  try:
67
  tree = ast.parse(content)
68
  suspicious_patterns = []
69
  for node in ast.walk(tree):
70
- # Check for Base64 decoding
71
  if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute):
72
  if node.func.attr == 'b64decode' and isinstance(node.func.value, ast.Name) and node.func.value.id == 'base64':
73
  suspicious_patterns.append("Base64 decoding detected")
74
- # Check for exec usage
75
  if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id == 'exec':
76
- suspicious_patterns.append("Dynamic code execution (exec) detected")
77
- # Check for urllib.request or similar imports
78
  if isinstance(node, ast.Import) or isinstance(node, ast.ImportFrom):
79
  for name in (node.names if isinstance(node, ast.Import) else node.names):
80
  if name.name in ['urllib', 'urllib.request', 'requests']:
81
  suspicious_patterns.append(f"Suspicious import: {name.name}")
82
- # Check for suspicious URLs in strings
83
  if isinstance(node, ast.Str) or (isinstance(node, ast.Constant) and isinstance(node.value, str)):
84
  if re.search(r'http[s]?://.*(evil|malicious|bad)[^\s]*', node.value, re.IGNORECASE):
85
  suspicious_patterns.append(f"Suspicious URL: {node.value}")
@@ -88,7 +112,7 @@ def analyze_python_code(content: str) -> dict:
88
  "classification": "Malware Detected",
89
  "severity": "Critical",
90
  "mitigation": "Quarantine file, run antivirus, block suspicious URLs",
91
- "confidence": 0.95,
92
  "details": suspicious_patterns
93
  }
94
  except SyntaxError:
@@ -233,10 +257,10 @@ def dsatp_parse_log(text: str) -> dict:
233
  severity_order = {"Critical": 3, "High": 2, "Medium": 1, "Safe": 0}
234
  highest_threat = max(detected_threats, key=lambda x: (severity_order.get(x["severity"], 0), x["confidence"]))
235
  logger.info(f"Detected threats: {len(detected_threats)}, Selected: {highest_threat}")
236
- return highest_threat
237
 
238
  logger.info("No threats detected")
239
- return {"classification": "No Threat", "severity": "Safe", "mitigation": "None", "confidence": 0.5}
240
 
241
  # Enhanced DSATP YARA scanning
242
  def dsatp_yara_scan(file_path: str) -> dict:
@@ -249,7 +273,7 @@ def dsatp_yara_scan(file_path: str) -> dict:
249
  if file_path.endswith('.py'):
250
  python_analysis = analyze_python_code(content)
251
  if python_analysis:
252
- return python_analysis
253
 
254
  import yara
255
  rules = yara.compile(source="""
@@ -352,7 +376,7 @@ def dsatp_yara_scan(file_path: str) -> dict:
352
  "classification": "Malware Detected",
353
  "severity": "Critical",
354
  "mitigation": "Quarantine file, run antivirus",
355
- "confidence": 0.95
356
  })
357
  elif match.rule == "SuspiciousBehavior":
358
  detected_threats.append({
@@ -402,18 +426,19 @@ def dsatp_yara_scan(file_path: str) -> dict:
402
  severity_order = {"Critical": 3, "High": 2, "Medium": 1, "Safe": 0}
403
  highest_threat = max(detected_threats, key=lambda x: (severity_order.get(x["severity"], 0), x["confidence"]))
404
  logger.info(f"YARA scan detected threats: {len(detected_threats)}, Selected: {highest_threat}")
405
- return highest_threat
406
 
407
  logger.info("YARA scan: No threats detected")
408
  return {
409
  "classification": "No Malware",
410
  "severity": "Safe",
411
  "mitigation": "None",
412
- "confidence": 0.7
 
413
  }
414
  except Exception as e:
415
  logger.error(f"YARA scan error: {e}")
416
- return {"error": str(e), "severity": "Unknown", "mitigation": "Check file format"}
417
 
418
  # Chatbot function
419
  def chatbot_response(user_input, file, history, state):
@@ -428,56 +453,54 @@ def chatbot_response(user_input, file, history, state):
428
  try:
429
  input_text = open(file.name, "r").read()
430
  scan_result = dsatp_yara_scan(file.name)
431
- all_threats.append(scan_result)
432
  except Exception as e:
433
- scan_result = {"error": f"File error: {e}", "severity": "Unknown", "mitigation": "Check file"}
434
  else:
435
  scan_result = dsatp_parse_log(input_text)
436
- all_threats.append(scan_result)
437
 
 
438
  context_str = "No relevant vulnerabilities found."
439
- if query_engine:
440
  try:
441
  # Map classification to precise keywords for relevant CVEs
442
  threat_keywords = {
443
- "Brute-Force Attempt": "brute force, ssh, login attempt, authentication failure, openssh, password attack, cwe-287, cwe-307",
444
- "Malware Detected": "malware, trojan, ransomware, payload, malicious script, backdoor, virus, cwe-94, cwe-506, cwe-119",
445
- "Network Intrusion": "firewall, intrusion, ufw, network attack, port scan, unauthorized access, cwe-284",
446
- "Privilege Escalation": "privilege escalation, sudo, root, unauthorized access, cwe-269, cwe-250",
447
- "Persistence Mechanism": "ssh tunnel, reverse ssh, persistence, backdoor, remote access, cwe-284",
448
- "System Compromise": "compromise, breach, unauthorized access, cwe-284",
449
- "Unauthorized Access": "unauthorized access, login failure, cwe-287",
450
- "Resource Abuse": "resource abuse, crypto-miner, denial of service, cwe-400",
451
- "Firmware Vulnerability": "firmware, vulnerability, iot, cwe-119",
452
- "DDoS Attack": "ddos, denial of service, network flood, cwe-400",
453
- "Phishing Attempt": "phishing, malicious url, social engineering, cwe-601",
454
- "SQL Injection": "sql injection, database attack, cwe-89",
455
- "Cross-Site Scripting": "xss, cross-site scripting, web attack, cwe-79",
456
- "Suspicious Activity": "suspicious activity, anomaly, heuristic, cwe-693"
457
  }
458
  classification = scan_result.get("classification", "unknown")
459
- keywords = threat_keywords.get(classification, "security threat")
460
- query = f"Mitigation for: {keywords}"
461
- results = query_engine.retrieve(query)
462
- context_items = []
463
- seen_cves = set()
464
- for res in results[:3]:
465
- # Extract CVE IDs and filter duplicates
466
- cve_matches = re.findall(r'CVE-\d{4}-\d{5,7}', res.text)
467
- if cve_matches and cve_matches[0] not in seen_cves:
468
- context_items.append(res.text)
469
- seen_cves.add(cve_matches[0])
470
- context_str = "\n\n".join(context_items) if context_items else "No relevant vulnerabilities found for this threat."
471
- logger.debug(f"LlamaIndex query: {query}, Results: {len(results)}, Unique CVEs: {len(seen_cves)}")
472
  except Exception as e:
473
- logger.error(f"LlamaIndex error: {e}")
474
  context_str = f"Context error: {e}"
475
 
476
  if "error" not in scan_result:
477
- # Prepare list of all detected threats for Mistral-7B
478
- other_threats = [t for t in all_threats if t != scan_result and "error" not in t]
479
  other_threats_summary = ""
480
- if other_threats:
 
481
  other_threats_summary = "\nOther detected threats include:\n" + "\n".join(
482
  [f"- {t['classification']} (Severity: {t['severity']}, Confidence: {t['confidence']:.1f})" for t in other_threats]
483
  )
@@ -489,7 +512,7 @@ def chatbot_response(user_input, file, history, state):
489
  Mitigation: {scan_result['mitigation']}
490
  Confidence: {scan_result['confidence']}
491
  Additional Threats: {other_threats_summary}
492
- Provide a concise response to the user, summarizing the primary threat and recommended actions in a professional tone. If additional threats are detected, briefly mention them but focus on the primary threat. Include actionable steps tailored to the primary threat. Do not mention vulnerabilities from the context unless they are directly related to the detected threat.
493
  """
494
  try:
495
  llm_response = call_mistral_llm(prompt)
 
1
  import gradio as gr
2
  import requests
 
 
3
  import os
4
  import re
5
  import ast
 
10
  import logging
11
  from retrying import retry
12
  import base64
13
+ import pickle
14
 
15
  # Set up logging
16
  logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
17
  logger = logging.getLogger(__name__)
18
 
19
+ # Suppress warnings
20
  os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
21
 
22
  # Modal Mistral-7B API endpoint
23
+ MODAL_API = os.getenv("MODAL_API", "https://rithvickkumar27--mistral-7b-api-analyze.modal.run")
24
+ NVD_API_KEY = os.getenv("NVD_API_KEY")
25
+ if not NVD_API_KEY:
26
+ logger.error("NVD_API_KEY not set in environment variables, API queries will fail")
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  # Retry decorator for Mistral-7B API
29
  @retry(stop_max_attempt_number=3, wait_fixed=2000)
 
47
  logger.error(f"Mistral API request failed: {e}")
48
  raise
49
 
50
+ # NVD API query with caching
51
+ def query_nvd(keywords):
52
+ cache_dir = "corpus/cache"
53
+ os.makedirs(cache_dir, exist_ok=True)
54
+ cache_file = f"{cache_dir}/{keywords.replace(' ', '_')}.pkl"
55
+
56
+ # Check cache
57
+ if os.path.exists(cache_file):
58
+ try:
59
+ with open(cache_file, "rb") as f:
60
+ cached_data = pickle.load(f)
61
+ if time.time() - cached_data["timestamp"] < 86400: # Cache valid for 24 hours
62
+ logger.debug(f"Using cached NVD data for: {keywords}")
63
+ return cached_data["results"]
64
+ except Exception as e:
65
+ logger.warning(f"Cache read error: {e}")
66
+
67
+ # Query NVD API
68
+ try:
69
+ url = "https://services.nvd.nist.gov/rest/json/cves/2.0"
70
+ params = {"keywordSearch": keywords, "resultsPerPage": 10}
71
+ headers = {"apiKey": NVD_API_KEY}
72
+ response = requests.get(url, params=params, headers=headers, timeout=10)
73
+ if response.status_code == 200:
74
+ data = response.json()
75
+ results = [
76
+ f"{item['cve']['id']}: {item['cve']['descriptions'][0]['value']}"
77
+ for item in data.get("vulnerabilities", [])
78
+ ]
79
+ # Save to cache
80
+ with open(cache_file, "wb") as f:
81
+ pickle.dump({"timestamp": time.time(), "results": results}, f)
82
+ logger.info(f"Fetched {len(results)} CVEs from NVD for: {keywords}")
83
+ return results
84
+ elif response.status_code == 429:
85
+ logger.error("NVD rate limit exceeded")
86
+ else:
87
+ logger.error(f"NVD API error: Status {response.status_code}")
88
+ except Exception as e:
89
+ logger.error(f"NVD API request failed: {e}")
90
+ return None
91
+
92
  # Basic Python code analysis
93
  def analyze_python_code(content: str) -> dict:
94
  try:
95
  tree = ast.parse(content)
96
  suspicious_patterns = []
97
  for node in ast.walk(tree):
 
98
  if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute):
99
  if node.func.attr == 'b64decode' and isinstance(node.func.value, ast.Name) and node.func.value.id == 'base64':
100
  suspicious_patterns.append("Base64 decoding detected")
 
101
  if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id == 'exec':
102
+ suspicious_patterns.append("Dynamic code execution (exec)")
 
103
  if isinstance(node, ast.Import) or isinstance(node, ast.ImportFrom):
104
  for name in (node.names if isinstance(node, ast.Import) else node.names):
105
  if name.name in ['urllib', 'urllib.request', 'requests']:
106
  suspicious_patterns.append(f"Suspicious import: {name.name}")
 
107
  if isinstance(node, ast.Str) or (isinstance(node, ast.Constant) and isinstance(node.value, str)):
108
  if re.search(r'http[s]?://.*(evil|malicious|bad)[^\s]*', node.value, re.IGNORECASE):
109
  suspicious_patterns.append(f"Suspicious URL: {node.value}")
 
112
  "classification": "Malware Detected",
113
  "severity": "Critical",
114
  "mitigation": "Quarantine file, run antivirus, block suspicious URLs",
115
+ "confidence": 0.9,
116
  "details": suspicious_patterns
117
  }
118
  except SyntaxError:
 
257
  severity_order = {"Critical": 3, "High": 2, "Medium": 1, "Safe": 0}
258
  highest_threat = max(detected_threats, key=lambda x: (severity_order.get(x["severity"], 0), x["confidence"]))
259
  logger.info(f"Detected threats: {len(detected_threats)}, Selected: {highest_threat}")
260
+ return highest_threat | {"all_threats": detected_threats}
261
 
262
  logger.info("No threats detected")
263
+ return {"classification": "No Threat", "severity": "Safe", "mitigation": "None", "confidence": 0.5, "all_threats": []}
264
 
265
  # Enhanced DSATP YARA scanning
266
  def dsatp_yara_scan(file_path: str) -> dict:
 
273
  if file_path.endswith('.py'):
274
  python_analysis = analyze_python_code(content)
275
  if python_analysis:
276
+ return python_analysis | {"all_threats": [python_analysis]}
277
 
278
  import yara
279
  rules = yara.compile(source="""
 
376
  "classification": "Malware Detected",
377
  "severity": "Critical",
378
  "mitigation": "Quarantine file, run antivirus",
379
+ "confidence": 0.9
380
  })
381
  elif match.rule == "SuspiciousBehavior":
382
  detected_threats.append({
 
426
  severity_order = {"Critical": 3, "High": 2, "Medium": 1, "Safe": 0}
427
  highest_threat = max(detected_threats, key=lambda x: (severity_order.get(x["severity"], 0), x["confidence"]))
428
  logger.info(f"YARA scan detected threats: {len(detected_threats)}, Selected: {highest_threat}")
429
+ return highest_threat | {"all_threats": detected_threats}
430
 
431
  logger.info("YARA scan: No threats detected")
432
  return {
433
  "classification": "No Malware",
434
  "severity": "Safe",
435
  "mitigation": "None",
436
+ "confidence": 0.7,
437
+ "all_threats": []
438
  }
439
  except Exception as e:
440
  logger.error(f"YARA scan error: {e}")
441
+ return {"error": str(e), "severity": "Unknown", "mitigation": "Check file format", "all_threats": []}
442
 
443
  # Chatbot function
444
  def chatbot_response(user_input, file, history, state):
 
453
  try:
454
  input_text = open(file.name, "r").read()
455
  scan_result = dsatp_yara_scan(file.name)
 
456
  except Exception as e:
457
+ scan_result = {"error": f"File error: {e}", "severity": "Unknown", "mitigation": "Check file", "all_threats": []}
458
  else:
459
  scan_result = dsatp_parse_log(input_text)
 
460
 
461
+ all_threats = scan_result.get("all_threats", [])
462
  context_str = "No relevant vulnerabilities found."
463
+ if NVD_API_KEY:
464
  try:
465
  # Map classification to precise keywords for relevant CVEs
466
  threat_keywords = {
467
+ "Brute-Force Attempt": "brute force ssh login attempt authentication failure openssh password attack cwe-287 cwe-307",
468
+ "Malware Detected": "malware trojan ransomware payload malicious script backdoor virus cwe-94 cwe-506 cwe-119 code injection python",
469
+ "Network Intrusion": "firewall intrusion ufw network attack port scan unauthorized access cwe-284",
470
+ "Privilege Escalation": "privilege escalation sudo root unauthorized access cwe-269 cwe-250",
471
+ "Persistence Mechanism": "ssh tunnel reverse ssh persistence backdoor remote access cwe-284",
472
+ "System Compromise": "compromise breach unauthorized access cwe-284",
473
+ "Unauthorized Access": "unauthorized access login failure cwe-287",
474
+ "Resource Abuse": "resource abuse crypto-miner denial of service cwe-400",
475
+ "Firmware Vulnerability": "firmware vulnerability iot cwe-119",
476
+ "DDoS Attack": "ddos denial of service network flood cwe-400",
477
+ "Phishing Attempt": "phishing malicious url social engineering cwe-601",
478
+ "SQL Injection": "sql injection database attack cwe-89",
479
+ "Cross-Site Scripting": "xss cross-site scripting web attack cwe-79",
480
+ "Suspicious Activity": "suspicious activity anomaly heuristic cwe-693"
481
  }
482
  classification = scan_result.get("classification", "unknown")
483
+ keywords = threat_keywords.get(classification, "security threat").replace(',', '')
484
+ nvd_results = query_nvd(keywords)
485
+ if nvd_results:
486
+ context_items = []
487
+ seen_cves = set()
488
+ for result in nvd_results[:3]:
489
+ cve_matches = re.findall(r'CVE-\d{4}-\d{5,7}', result)
490
+ if cve_matches and cve_matches[0] not in seen_cves:
491
+ if any(keyword.lower() in result.lower() for keyword in keywords.split()):
492
+ context_items.append(result)
493
+ seen_cves.add(cve_matches[0])
494
+ context_str = "\n\n".join(context_items) if context_items else "No relevant vulnerabilities found for this threat."
495
+ logger.debug(f"NVD query: {keywords}, Results: {len(nvd_results) if nvd_results else 0}")
496
  except Exception as e:
497
+ logger.error(f"NVD error: {e}")
498
  context_str = f"Context error: {e}"
499
 
500
  if "error" not in scan_result:
 
 
501
  other_threats_summary = ""
502
+ if len(all_threats) > 1:
503
+ other_threats = [t for t in all_threats if t != scan_result]
504
  other_threats_summary = "\nOther detected threats include:\n" + "\n".join(
505
  [f"- {t['classification']} (Severity: {t['severity']}, Confidence: {t['confidence']:.1f})" for t in other_threats]
506
  )
 
512
  Mitigation: {scan_result['mitigation']}
513
  Confidence: {scan_result['confidence']}
514
  Additional Threats: {other_threats_summary}
515
+ Provide a concise response to the user, summarizing the primary threat and recommended actions in a professional tone. If additional threats are detected, briefly mention them but focus on the primary threat. Include actionable steps tailored to the primary threat. Do not mention vulnerabilities from the context unless explicitly confirmed as related to the detected threat.
516
  """
517
  try:
518
  llm_response = call_mistral_llm(prompt)