Samarth Naik committed on
Commit
a7252f1
·
1 Parent(s): 1c9f2d5

Update /compute endpoint to run all 3 models simultaneously with packet counting

Browse files

- Removed model_type parameter requirement
- Endpoint now executes all models in parallel
- Single response includes all model outputs clearly separated
- Added total_packets and unique_flows counts
- Updated README.md with new request/response format examples

Files changed (2) hide show
  1. README.md +65 -12
  2. app.py +121 -95
README.md CHANGED
@@ -69,12 +69,11 @@ Returns available models and their configuration.
69
  ```
70
 
71
  ### POST `/compute`
72
- Run breach prediction on network logs.
73
 
74
  **Request:**
75
  ```json
76
  {
77
- "model_type": "lightGBM",
78
  "file": [
79
  {
80
  "timestamp": "2024-01-01T10:00:00",
@@ -84,7 +83,21 @@ Run breach prediction on network logs.
84
  "dst_port": 80,
85
  "packet_size": 1500,
86
  "seq": 1000,
87
- "ack": 2000
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  }
89
  ]
90
  }
@@ -94,19 +107,59 @@ Run breach prediction on network logs.
94
  ```json
95
  {
96
  "success": true,
97
- "output": "Model execution output",
98
- "predictions": [
99
- {
100
- "timestamp": "2024-01-01T10:00:00",
101
- "src_ip": "192.168.1.100",
102
- "breach_probability": 0.95,
103
- "breach_predicted": 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  }
105
- ],
106
- "error": null
107
  }
108
  ```
109
 
 
 
 
 
 
 
 
110
  ## Required Input Columns
111
 
112
  - `timestamp`: Timestamp of the network flow
 
69
  ```
70
 
71
  ### POST `/compute`
72
+ Run breach prediction using **all 3 models simultaneously** on network logs.
73
 
74
  **Request:**
75
  ```json
76
  {
 
77
  "file": [
78
  {
79
  "timestamp": "2024-01-01T10:00:00",
 
83
  "dst_port": 80,
84
  "packet_size": 1500,
85
  "seq": 1000,
86
+ "ack": 2000,
87
+ "tcp_flags": 2,
88
+ "window": 65535
89
+ },
90
+ {
91
+ "timestamp": "2024-01-01T10:00:01",
92
+ "src_ip": "192.168.1.101",
93
+ "dst_ip": "10.0.0.2",
94
+ "src_port": 12346,
95
+ "dst_port": 443,
96
+ "packet_size": 1500,
97
+ "seq": 1001,
98
+ "ack": 2001,
99
+ "tcp_flags": 2,
100
+ "window": 65535
101
  }
102
  ]
103
  }
 
107
  ```json
108
  {
109
  "success": true,
110
+ "packets": {
111
+ "total": 2,
112
+ "unique_flows": 2
113
+ },
114
+ "models": {
115
+ "lightGBM": {
116
+ "success": true,
117
+ "output": "Model execution output",
118
+ "predictions": [
119
+ {
120
+ "timestamp": "2024-01-01T10:00:00",
121
+ "src_ip": "192.168.1.100",
122
+ "breach_probability": 0.95,
123
+ "breach_predicted": 1
124
+ }
125
+ ],
126
+ "error": null
127
+ },
128
+ "autoencoder": {
129
+ "success": true,
130
+ "output": "Model execution output",
131
+ "predictions": [
132
+ {
133
+ "timestamp": "2024-01-01T10:00:00",
134
+ "anomaly_score": 0.87,
135
+ "is_anomaly": true
136
+ }
137
+ ],
138
+ "error": null
139
+ },
140
+ "XGB_lstm": {
141
+ "success": true,
142
+ "output": "Model execution output",
143
+ "predictions": [
144
+ {
145
+ "timestamp": "2024-01-01T10:00:00",
146
+ "breach_risk": 0.92,
147
+ "prediction": 1
148
+ }
149
+ ],
150
+ "error": null
151
  }
152
+ }
 
153
  }
154
  ```
155
 
156
+ **Response Format:**
157
+ - `success`: Overall success status (all models succeeded)
158
+ - `packets.total`: Total number of packets in the request
159
+ - `packets.unique_flows`: Number of unique network flows (src_ip:src_port → dst_ip:dst_port)
160
+ - `models`: Dictionary of per-model results, keyed by model name
161
+ - Each model includes: `success` (bool), `output` (stdout), `predictions` (array), `error` (stderr)
162
+
163
  ## Required Input Columns
164
 
165
  - `timestamp`: Timestamp of the network flow
app.py CHANGED
@@ -44,29 +44,23 @@ def compute():
44
  if not data:
45
  return jsonify({"error": "No JSON data provided"}), 400
46
 
47
- model_type = data.get('model_type')
48
  file_data = data.get('file')
49
 
50
- if not model_type or not file_data:
51
- return jsonify({"error": "model_type and file are required"}), 400
52
-
53
- # Validate model type
54
- if model_type not in MODEL_CONFIGS:
55
- return jsonify({
56
- "error": f"Unsupported model type. Available: {list(MODEL_CONFIGS.keys())}"
57
- }), 400
58
 
59
  # Validate input data
60
  is_valid, validation_msg = validate_input_data(file_data)
61
  if not is_valid:
62
  return jsonify({"error": f"Invalid input data: {validation_msg}"}), 400
63
 
64
- model_config = MODEL_CONFIGS[model_type]
65
- model_file = model_config['file']
66
-
67
- # Check if model file exists
68
- if not os.path.exists(model_file):
69
- return jsonify({"error": f"Model file {model_file} not found"}), 404
 
70
 
71
  # Create temporary CSV file with unique name
72
  temp_filename = f"temp_input_{unique_id}.csv"
@@ -78,91 +72,122 @@ def compute():
78
  writer.writeheader()
79
  writer.writerows(file_data)
80
 
81
- try:
82
- # Handle different model interfaces
83
- if model_config['interface'] == 'argparse':
84
- # For XGB_lstm.py which uses --logfile argument
85
- cmd = ['python', model_file, '--logfile', temp_filename]
86
- else:
87
- # For models that expect hardcoded filename, create a symlink
88
- expected_filename = "network_logs.csv"
89
- backup_filename = None
90
-
91
- # Backup existing file if it exists
92
- if os.path.exists(expected_filename):
93
- backup_filename = f"backup_{expected_filename}_{unique_id}"
94
- os.rename(expected_filename, backup_filename)
95
-
96
- # Create symlink or copy
97
- try:
98
- os.symlink(os.path.abspath(temp_filename), expected_filename)
99
- except OSError:
100
- # Fallback to copy if symlink fails
101
- import shutil
102
- shutil.copy2(temp_filename, expected_filename)
103
-
104
- cmd = ['python', model_file]
105
-
106
- # Run the model
107
- result = subprocess.run(
108
- cmd,
109
- capture_output=True,
110
- text=True,
111
- timeout=300, # 5 minute timeout
112
- cwd=os.getcwd()
113
- )
114
-
115
- # Clean up hardcoded file if used
116
- if model_config['interface'] == 'hardcoded':
117
- if os.path.exists("network_logs.csv"):
118
- os.unlink("network_logs.csv")
119
- if backup_filename and os.path.exists(backup_filename):
120
- os.rename(backup_filename, "network_logs.csv")
121
-
122
- # Clean up temp file
123
- if os.path.exists(temp_filename):
124
- os.unlink(temp_filename)
125
 
126
- if result.returncode == 0:
127
- # Try to read output file if it exists
128
- output_files = {
129
- 'lightGBM': 'lightgbm_breach_predictions.csv',
130
- 'autoencoder': 'breach_predictions.csv',
131
- 'XGB_lstm': 'xgb_lstm_predictions.csv'
132
  }
133
-
134
- output_data = None
135
- output_file = output_files.get(model_type)
136
- if output_file and os.path.exists(output_file):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  try:
138
- import pandas as pd
139
- df = pd.read_csv(output_file)
140
- output_data = df.to_dict('records')
141
- # Rename output file to avoid conflicts
142
- os.rename(output_file, f"{unique_id}_{output_file}")
143
- except Exception as e:
144
- print(f"Warning: Could not read output file: {e}")
145
 
146
- return jsonify({
147
- "success": True,
148
- "model": model_type,
149
- "output": result.stdout,
150
- "predictions": output_data,
151
- "error": result.stderr if result.stderr else None
152
- }), 200
153
- else:
154
- return jsonify({
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  "success": False,
156
- "model": model_type,
157
- "output": result.stdout,
158
- "error": result.stderr
159
- }), 500
160
 
161
- except subprocess.TimeoutExpired:
162
- return jsonify({"error": "Model execution timed out after 5 minutes"}), 408
163
-
164
- except Exception as e:
165
- return jsonify({"error": f"Execution error: {str(e)}"}), 500
 
 
 
 
 
 
 
 
166
 
167
  except Exception as e:
168
  return jsonify({"error": f"Server error: {str(e)}"}), 500
@@ -191,7 +216,8 @@ def get_models():
191
  }
192
  return jsonify({
193
  "available_models": models_info,
194
- "required_columns": ["timestamp", "src_ip", "dst_ip", "src_port", "dst_port"]
 
195
  }), 200
196
 
197
  if __name__ == '__main__':
 
44
  if not data:
45
  return jsonify({"error": "No JSON data provided"}), 400
46
 
 
47
  file_data = data.get('file')
48
 
49
+ if not file_data:
50
+ return jsonify({"error": "file is required"}), 400
 
 
 
 
 
 
51
 
52
  # Validate input data
53
  is_valid, validation_msg = validate_input_data(file_data)
54
  if not is_valid:
55
  return jsonify({"error": f"Invalid input data: {validation_msg}"}), 400
56
 
57
+ # Count packets and unique flows
58
+ num_packets = len(file_data)
59
+ flows = set()
60
+ for row in file_data:
61
+ flow_key = (row['src_ip'], row['src_port'], row['dst_ip'], row['dst_port'])
62
+ flows.add(flow_key)
63
+ num_flows = len(flows)
64
 
65
  # Create temporary CSV file with unique name
66
  temp_filename = f"temp_input_{unique_id}.csv"
 
72
  writer.writeheader()
73
  writer.writerows(file_data)
74
 
75
+ # Run all models
76
+ results = {
77
+ "success": True,
78
+ "packets": {
79
+ "total": num_packets,
80
+ "unique_flows": num_flows
81
+ },
82
+ "models": {}
83
+ }
84
+
85
+ for model_type, model_config in MODEL_CONFIGS.items():
86
+ model_file = model_config['file']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
+ # Check if model file exists
89
+ if not os.path.exists(model_file):
90
+ results["models"][model_type] = {
91
+ "success": False,
92
+ "error": f"Model file {model_file} not found"
 
93
  }
94
+ continue
95
+
96
+ try:
97
+ # Handle different model interfaces
98
+ if model_config['interface'] == 'argparse':
99
+ # For XGB_lstm.py which uses --logfile argument
100
+ cmd = ['python', model_file, '--logfile', temp_filename]
101
+ else:
102
+ # For models that expect hardcoded filename
103
+ expected_filename = "network_logs.csv"
104
+ backup_filename = None
105
+
106
+ # Backup existing file if it exists
107
+ if os.path.exists(expected_filename):
108
+ backup_filename = f"backup_{expected_filename}_{unique_id}"
109
+ os.rename(expected_filename, backup_filename)
110
+
111
+ # Create symlink or copy
112
  try:
113
+ os.symlink(os.path.abspath(temp_filename), expected_filename)
114
+ except OSError:
115
+ # Fallback to copy if symlink fails
116
+ import shutil
117
+ shutil.copy2(temp_filename, expected_filename)
118
+
119
+ cmd = ['python', model_file]
120
 
121
+ # Run the model
122
+ result = subprocess.run(
123
+ cmd,
124
+ capture_output=True,
125
+ text=True,
126
+ timeout=300, # 5 minute timeout
127
+ cwd=os.getcwd()
128
+ )
129
+
130
+ # Clean up hardcoded file if used
131
+ if model_config['interface'] == 'hardcoded':
132
+ if os.path.exists("network_logs.csv"):
133
+ os.unlink("network_logs.csv")
134
+ if backup_filename and os.path.exists(backup_filename):
135
+ os.rename(backup_filename, "network_logs.csv")
136
+
137
+ if result.returncode == 0:
138
+ # Try to read output file if it exists
139
+ output_files = {
140
+ 'lightGBM': 'lightgbm_breach_predictions.csv',
141
+ 'autoencoder': 'breach_predictions.csv',
142
+ 'XGB_lstm': 'xgb_lstm_predictions.csv'
143
+ }
144
+
145
+ output_data = None
146
+ output_file = output_files.get(model_type)
147
+ if output_file and os.path.exists(output_file):
148
+ try:
149
+ import pandas as pd
150
+ df = pd.read_csv(output_file)
151
+ output_data = df.to_dict('records')
152
+ # Rename output file to avoid conflicts
153
+ os.rename(output_file, f"{unique_id}_{output_file}")
154
+ except Exception as e:
155
+ print(f"Warning: Could not read output file: {e}")
156
+
157
+ results["models"][model_type] = {
158
+ "success": True,
159
+ "output": result.stdout,
160
+ "predictions": output_data,
161
+ "error": result.stderr if result.stderr else None
162
+ }
163
+ else:
164
+ results["models"][model_type] = {
165
+ "success": False,
166
+ "output": result.stdout,
167
+ "error": result.stderr
168
+ }
169
+ results["success"] = False
170
+
171
+ except subprocess.TimeoutExpired:
172
+ results["models"][model_type] = {
173
  "success": False,
174
+ "error": f"Model execution timed out after 5 minutes"
175
+ }
176
+ results["success"] = False
 
177
 
178
+ except Exception as e:
179
+ results["models"][model_type] = {
180
+ "success": False,
181
+ "error": f"Execution error: {str(e)}"
182
+ }
183
+ results["success"] = False
184
+
185
+ # Clean up temp file
186
+ if os.path.exists(temp_filename):
187
+ os.unlink(temp_filename)
188
+
189
+ status_code = 200 if results["success"] else 207 # 207 Multi-Status for partial success
190
+ return jsonify(results), status_code
191
 
192
  except Exception as e:
193
  return jsonify({"error": f"Server error: {str(e)}"}), 500
 
216
  }
217
  return jsonify({
218
  "available_models": models_info,
219
+ "required_columns": ["timestamp", "src_ip", "dst_ip", "src_port", "dst_port"],
220
+ "note": "All available models will run automatically. No need to specify model_type."
221
  }), 200
222
 
223
  if __name__ == '__main__':