Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import xgboost as xgb | |
| import gradio as gr | |
| from scapy.all import rdpcap | |
| from collections import defaultdict | |
| import os | |
def transform_new_input(new_input):
    """Min-max scale a 13-value feature vector with fixed per-feature bounds.

    Each feature is mapped with ``(value - lower) / (upper - lower)`` using the
    hard-coded bounds below, so a value equal to its lower bound becomes 0.0
    and a value equal to its upper bound becomes 1.0. Values outside the
    bounds are not clipped.

    Args:
        new_input: Sequence (or array) of 13 numbers, in the same order as the
            bounds defined here.

    Returns:
        numpy.ndarray holding the scaled features.
    """
    # Per-feature lower bounds (one entry per model feature).
    lower_bounds = np.array([
        1.0, 10.0, 856.0, 5775.0, 42.0, 26.0, 0.0, 278.0, 4.0, 1.0,
        -630355.0, 4.0, 50.0
    ])
    # Per-feature upper bounds, in the same order as the lower bounds.
    upper_bounds = np.array([
        4.0, 352752.0, 271591638.0, 239241314.0, 421552.0, 3317.0,
        6302708.0, 6302708.0, 5.0, 5.0, 1746749.0, 608.0, 1012128.0
    ])
    feature_span = upper_bounds - lower_bounds
    shifted = np.array(new_input) - lower_bounds
    return shifted / feature_span
class PcapProcessor:
    """Aggregate per-port traffic statistics from a PCAP capture.

    Every TCP or UDP packet contributes to the statistics of both its source
    port (as transmitted traffic) and its destination port (as received
    traffic). After all packets are consumed, each observed port is turned
    into a 13-value feature vector for the IDS model.
    """

    def __init__(self, pcap_file):
        """Load all packets from ``pcap_file`` and start with empty statistics.

        Args:
            pcap_file: Path of the capture file, passed to ``rdpcap``.
        """
        self.packets = rdpcap(pcap_file)
        self.start_time = None
        self.port_stats = self._new_port_stats()

    @staticmethod
    def _new_port_stats():
        """Return an empty port -> statistics mapping with zeroed defaults."""
        return defaultdict(lambda: {
            'rx_packets': 0,
            'rx_bytes': 0,
            'tx_packets': 0,
            'tx_bytes': 0,
            'first_seen': None,
            'last_seen': None,
            'active_flows': set(),
            'packets_matched': 0
        })

    def process_packets(self, window_size=60):
        """Process all packets and extract one feature vector per port.

        Args:
            window_size: Upper bound (seconds) applied to a port's alive
                duration before it is used as the "delta" duration feature.

        Returns:
            A list of 13-element feature lists, one per port seen in a TCP or
            UDP packet, in order of first appearance. Empty when the capture
            holds no packets.
        """
        if not self.packets:
            return []

        # Start from a clean slate so a repeated call does not count the same
        # packets twice.
        self.port_stats = self._new_port_stats()
        self.start_time = float(self.packets[0].time)

        for packet in self.packets:
            if 'TCP' not in packet and 'UDP' not in packet:
                continue
            current_time = float(packet.time)
            try:
                src_port = packet.sport
                dst_port = packet.dport
                pkt_size = len(packet)

                # Prefer the IPv4 header; fall back to IPv6 so that IPv6
                # flows are counted instead of being dropped as errors.
                network_layer = 'IP' if 'IP' in packet else 'IPv6'
                flow_tuple = (packet[network_layer].src,
                              packet[network_layer].dst,
                              src_port, dst_port)

                # The packet is "sent" by its source port and "received" by
                # its destination port.
                self._update_port_stats(src_port, pkt_size, True,
                                        current_time, flow_tuple)
                self._update_port_stats(dst_port, pkt_size, False,
                                        current_time, flow_tuple)
            except Exception as e:
                # Best effort: report the malformed packet and keep going.
                print(f"Error processing packet {packet}: {str(e)}")
                continue

        features_list = []
        for port, stats in self.port_stats.items():
            if stats['first_seen'] is not None:
                features_list.append(
                    self._extract_port_features(port, stats, window_size))
        return features_list

    def _update_port_stats(self, port, pkt_size, is_source, current_time,
                           flow_tuple):
        """Fold one packet into the statistics of ``port``.

        Args:
            port: Port number whose entry is updated.
            pkt_size: Packet length in bytes.
            is_source: True to count the packet as transmitted by ``port``,
                False to count it as received.
            current_time: Packet timestamp (seconds, float).
            flow_tuple: ``(src_ip, dst_ip, src_port, dst_port)`` identifying
                the flow the packet belongs to.
        """
        stats = self.port_stats[port]
        if stats['first_seen'] is None:
            stats['first_seen'] = current_time
        stats['last_seen'] = current_time

        direction = 'tx' if is_source else 'rx'
        stats[f'{direction}_packets'] += 1
        stats[f'{direction}_bytes'] += pkt_size

        stats['active_flows'].add(flow_tuple)
        stats['packets_matched'] += 1

    def _extract_port_features(self, port, stats, window_size):
        """Build the 13-value feature vector for one port.

        Args:
            port: Port number the statistics belong to.
            stats: Statistics dict as maintained by ``_update_port_stats``.
            window_size: Cap (seconds) for the delta alive duration.

        Returns:
            List of 13 values in the order expected by the scaler/model.
        """
        port_alive_duration = stats['last_seen'] - stats['first_seen']
        delta_alive_duration = min(port_alive_duration, window_size)

        # Bytes per second over the port's lifetime; durations under one
        # second are treated as one second to avoid dividing by ~0.
        total_load = (stats['rx_bytes'] + stats['tx_bytes']) / \
            max(port_alive_duration, 1)

        return [
            port % 4 + 1,                   # Port Number, folded into 1-4
            stats['rx_packets'],            # Received Packets
            stats['rx_bytes'],              # Received Bytes
            stats['tx_bytes'],              # Sent Bytes
            stats['tx_packets'],            # Sent Packets
            port_alive_duration,            # Port alive Duration
            stats['rx_bytes'],              # Delta Received Bytes (cumulative total reused)
            stats['tx_bytes'],              # Delta Sent Bytes (cumulative total reused)
            min(delta_alive_duration, 5),   # Delta Port alive Duration, capped at 5
            port % 5 + 1,                   # Connection Point, folded into 1-5
            total_load,                     # Total Load/Rate
            len(stats['active_flows']),     # Active Flow Entries
            stats['packets_matched']        # Packets Matched
        ]
def process_pcap_for_ids(pcap_file):
    """Extract IDS feature vectors from a capture file.

    Args:
        pcap_file: Path of the PCAP file to analyse.

    Returns:
        The list produced by ``PcapProcessor.process_packets`` with its
        default window size: one feature list per observed port.
    """
    return PcapProcessor(pcap_file).process_packets()
def predict_from_features(features, model):
    """Classify one feature vector and return the matching message pair.

    Args:
        features: The 13 raw (unscaled) feature values.
        model: Object whose ``predict`` accepts an ``xgb.DMatrix``; its first
            output row is indexed per class here, so it is presumably a
            Booster that emits class probabilities (verify against the
            trained model).

    Returns:
        The ``(headline, details)`` tuple from ``get_prediction_message``.
    """
    # Scale, then wrap as a single-row matrix for the booster.
    scaled_row = transform_new_input(features).reshape(1, -1)
    class_scores = model.predict(xgb.DMatrix(scaled_row))[0]
    top_class = np.argmax(class_scores)

    # Class 0 (normal) wins with more than 60% confidence.
    confident_normal = top_class == 0 and class_scores[0] > 0.6
    # No class reaches 40%: too uncertain to raise an alert.
    nothing_confident = np.max(class_scores) < 0.4

    if confident_normal or nothing_confident:
        return get_prediction_message(0)
    return get_prediction_message(top_class)
def get_prediction_message(prediction):
    """Translate a predicted class index into user-facing text.

    Args:
        prediction: Class index; 0 is normal traffic and 1-4 are attack types.

    Returns:
        A ``(headline, details)`` tuple of strings. Indices without an entry
        yield a generic "unknown" pair.
    """
    # Position in this tuple is the class index.
    ordered_messages = (
        ("NORMAL TRAFFIC - No indication of attack.",
         "Traffic patterns appear to be within normal parameters."),
        ("ALERT: Potential BLACKHOLE attack detected.",
         "Information: BLACKHOLE attacks occur when a router maliciously drops "
         "packets it should forward. Investigate affected routes and traffic patterns."),
        ("ALERT: Potential TCP-SYN flood attack detected.",
         "Information: TCP-SYN flood is a DDoS attack exhausting server resources "
         "with half-open connections. Check connection states and implement SYN cookies."),
        ("ALERT: PORTSCAN activity detected.",
         "Information: Port scanning detected - systematic probing of system ports. "
         "Review firewall rules and implement connection rate limiting."),
        ("ALERT: Potential DIVERSION attack detected.",
         "Information: Traffic diversion detected. Verify routing integrity and "
         "check for signs of traffic manipulation or social engineering attempts."),
    )
    by_class = dict(enumerate(ordered_messages))
    fallback = ("Unknown Traffic Pattern", "Additional analysis required.")
    return by_class.get(prediction, fallback)
def process_pcap_input(pcap_file):
    """Analyse an uploaded PCAP file and describe every detected pattern.

    Args:
        pcap_file: The upload handed over by the UI. May be ``None`` when the
            user has not selected a file, an object exposing the file path as
            ``.name``, or a plain path string.

    Returns:
        A human-readable report with one "Traffic Pattern N" section per
        extracted feature vector, or a single explanatory message when no
        file was given, no usable traffic was found, or processing failed.
    """
    # Without an upload there is nothing to analyse; say so clearly instead
    # of surfacing an attribute error on None.
    if pcap_file is None:
        return "No PCAP file provided. Please upload a .pcap or .pcapng file."

    # Use the wrapper's path when present, otherwise treat the value as the
    # path itself.
    pcap_path = getattr(pcap_file, "name", pcap_file)

    try:
        model = xgb.Booster()
        model.load_model("m3_xg_boost.model")

        features_list = process_pcap_for_ids(pcap_path)
        if not features_list:
            return "No valid network traffic found in PCAP file."

        results = []
        for idx, features in enumerate(features_list):
            result_msg, result_info = predict_from_features(features, model)
            results.append(f"Traffic Pattern {idx + 1}:\n{result_msg}\n{result_info}\n")
        return "\n".join(results)
    except Exception as e:
        # UI boundary: report the failure as text rather than crashing the app.
        return f"Error processing PCAP file: {str(e)}"
def process_manual_input(port_num, rx_packets, rx_bytes, tx_bytes, tx_packets,
                         port_duration, delta_rx_bytes, delta_tx_bytes,
                         delta_duration, conn_point, total_load, active_flows,
                         packets_matched):
    """Classify a manually entered feature vector.

    The thirteen arguments are the raw model features, in the order the
    scaler expects them.

    Returns:
        The prediction headline and details joined by a newline, or an
        "Error processing manual input: ..." message if anything fails.
    """
    feature_vector = [
        port_num, rx_packets, rx_bytes, tx_bytes, tx_packets,
        port_duration, delta_rx_bytes, delta_tx_bytes, delta_duration,
        conn_point, total_load, active_flows, packets_matched
    ]
    try:
        booster = xgb.Booster()
        booster.load_model("m3_xg_boost.model")
        headline, details = predict_from_features(feature_vector, booster)
    except Exception as exc:
        # UI boundary: return the failure as text instead of raising.
        return f"Error processing manual input: {str(exc)}"
    return f"{headline}\n{details}"
# Script entry point: build the two-tab Gradio UI and serve it.
if __name__ == "__main__":
    with gr.Blocks(theme="default") as demo:
        gr.Markdown("""
        # Network Intrusion Detection System
        Upload a PCAP file or use manual input to detect potential network attacks.
        """)

        # Tab 1: upload a capture and analyse every port found in it.
        with gr.Tab("PCAP Analysis"):
            upload_box = gr.File(
                label="Upload PCAP File",
                file_types=[".pcap", ".pcapng"]
            )
            upload_report = gr.Textbox(label="Analysis Results")
            analyze_upload = gr.Button("Analyze PCAP")
            analyze_upload.click(
                fn=process_pcap_input,
                inputs=[upload_box],
                outputs=upload_report
            )

        # Tab 2: enter the 13 model features by hand.
        with gr.Tab("Manual Input"):
            with gr.Row():
                port_num = gr.Slider(
                    1, 4, value=1,
                    label="Port Number - The switch port through which the flow passed")
                rx_packets = gr.Slider(
                    0, 352772, value=0,
                    label="Received Packets - Number of packets received by the port")
            with gr.Row():
                rx_bytes = gr.Slider(
                    0, 2.715916e08, value=0,
                    label="Received Bytes - Number of bytes received by the port")
                tx_bytes = gr.Slider(
                    0, 2.392430e08, value=0,
                    label="Sent Bytes - Number of bytes sent by the port")
            with gr.Row():
                tx_packets = gr.Slider(
                    0, 421598, value=0,
                    label="Sent Packets - Number of packets sent by the port")
                port_duration = gr.Slider(
                    0, 3317, value=0,
                    label="Port alive Duration (S) - The time port has been alive in seconds")
            with gr.Row():
                delta_rx_bytes = gr.Slider(
                    0, 6500000, value=0, label="Delta Received Bytes")
                delta_tx_bytes = gr.Slider(
                    0, 6500000, value=0, label="Delta Sent Bytes")
            with gr.Row():
                delta_duration = gr.Slider(
                    0, 5, value=0, label="Delta Port alive Duration (S)")
                conn_point = gr.Slider(
                    1, 5, value=1, label="Connection Point")
            with gr.Row():
                total_load = gr.Slider(
                    0, 1800000, value=0, label="Total Load/Rate")
                active_flows = gr.Slider(
                    0, 610, value=0, label="Active Flow Entries")
            with gr.Row():
                packets_matched = gr.Slider(
                    0, 1020000, value=0, label="Packets Matched")

            manual_report = gr.Textbox(label="Analysis Results")
            analyze_manual = gr.Button("Analyze Manual Input")

            # Same component order as the parameters of process_manual_input;
            # shared by the button handler and the example rows.
            manual_fields = [
                port_num, rx_packets, rx_bytes, tx_bytes, tx_packets,
                port_duration, delta_rx_bytes, delta_tx_bytes,
                delta_duration, conn_point, total_load, active_flows,
                packets_matched
            ]

            analyze_manual.click(
                fn=process_manual_input,
                inputs=manual_fields,
                outputs=manual_report
            )

            # Pre-filled example rows, one value per manual input field.
            gr.Examples(
                examples=[
                    [4, 350188, 14877116, 101354648, 159524, 2910, 278, 280,
                     5, 4, 0, 6, 667324],
                    [2, 2326, 12856942, 31777516, 2998, 2497, 560, 560,
                     5, 2, 0, 4, 7259],
                    [4, 150, 19774, 6475473, 3054, 166, 556, 6068,
                     5, 4, 502, 6, 7418],
                    [2, 209, 20671, 6316631, 274, 96, 3527, 2757949,
                     5, 2, 183877, 8, 90494],
                    [2, 1733, 37865130, 38063670, 3187, 2152, 0, 556,
                     5, 3, 0, 4, 14864]
                ],
                inputs=manual_fields
            )

    # Start the web server.
    demo.launch()