# Hugging Face Space (status: Sleeping)
# Author: Samarth Naik
# Commit a7252f1: "Update /compute endpoint to run all 3 models
# simultaneously with packet counting"
import csv
import os
import subprocess
import sys
import tempfile
import uuid

from flask import Flask, jsonify, request
from flask_cors import CORS
app = Flask(__name__)
# Allow cross-origin requests so a separately hosted frontend can call the API.
CORS(app)

# Supported model types and how each script accepts its input CSV:
#   'hardcoded' - the script reads a fixed filename ("network_logs.csv",
#                 per the handling in compute())
#   'argparse'  - the script takes the CSV path via a --logfile argument
MODEL_CONFIGS = {
    'lightGBM': {'file': 'lightGBM.py', 'interface': 'hardcoded'},
    'autoencoder': {'file': 'autoencoder.py', 'interface': 'hardcoded'},
    'XGB_lstm': {'file': 'XGB_lstm.py', 'interface': 'argparse'}
}
def validate_input_data(file_data):
    """Check that *file_data* looks like a usable list of network-log rows.

    Returns a ``(bool, str)`` pair: ``(True, "Valid")`` on success,
    otherwise ``(False, <reason>)``.
    """
    if not isinstance(file_data, list) or len(file_data) == 0:
        return False, "File data must be a non-empty list"

    # Every row must carry exactly the same column set as the first row.
    expected_keys = set(file_data[0].keys())
    for idx, row in enumerate(file_data[1:], 1):
        if set(row.keys()) != expected_keys:
            return False, f"Row {idx+1} has different columns than the first row"

    # The downstream models need these network-log fields at a minimum.
    required_columns = {'timestamp', 'src_ip', 'dst_ip', 'src_port', 'dst_port'}
    missing = required_columns - expected_keys
    if missing:
        return False, f"Missing required columns: {missing}"

    return True, "Valid"
# NOTE(review): no @app.route decorator is visible here — presumably
# @app.route('/compute', methods=['POST']) was lost in transcription
# (the commit message names a /compute endpoint); confirm before deploying.
def compute():
    """Run every configured model against uploaded network-log rows.

    Expects a JSON body of the form ``{"file": [{row}, ...]}`` where each
    row is a dict of column -> value. Writes the rows to a per-request
    temp CSV, runs each model script as a subprocess, and returns a JSON
    summary with packet/flow counts and per-model results.

    Responses:
        200 - all models ran successfully
        207 - at least one model failed (partial success)
        400 - missing/invalid input
        500 - unexpected server error
    """
    temp_filename = None
    # Short unique suffix so concurrent requests don't clobber each
    # other's temp/backup/output files.
    unique_id = str(uuid.uuid4())[:8]
    try:
        data = request.get_json()
        if not data:
            return jsonify({"error": "No JSON data provided"}), 400

        file_data = data.get('file')
        if not file_data:
            return jsonify({"error": "file is required"}), 400

        is_valid, validation_msg = validate_input_data(file_data)
        if not is_valid:
            return jsonify({"error": f"Invalid input data: {validation_msg}"}), 400

        # One packet per row; a flow is a unique 4-tuple of endpoints.
        num_packets = len(file_data)
        flows = {
            (row['src_ip'], row['src_port'], row['dst_ip'], row['dst_port'])
            for row in file_data
        }
        num_flows = len(flows)

        # Materialize the JSON rows as a CSV file for the model scripts.
        temp_filename = f"temp_input_{unique_id}.csv"
        fieldnames = file_data[0].keys()
        with open(temp_filename, 'w', newline='') as temp_file:
            writer = csv.DictWriter(temp_file, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(file_data)

        results = {
            "success": True,
            "packets": {
                "total": num_packets,
                "unique_flows": num_flows
            },
            "models": {}
        }

        for model_type, model_config in MODEL_CONFIGS.items():
            model_file = model_config['file']
            if not os.path.exists(model_file):
                # A missing script is reported per-model but does not mark
                # the overall request as failed (matches prior behavior).
                results["models"][model_type] = {
                    "success": False,
                    "error": f"Model file {model_file} not found"
                }
                continue

            model_result = _run_model(model_type, model_config,
                                      temp_filename, unique_id)
            results["models"][model_type] = model_result
            if not model_result["success"]:
                results["success"] = False

        if os.path.exists(temp_filename):
            os.unlink(temp_filename)

        # 207 Multi-Status signals partial success.
        status_code = 200 if results["success"] else 207
        return jsonify(results), status_code

    except Exception as e:
        return jsonify({"error": f"Server error: {str(e)}"}), 500
    finally:
        # Ensure the temp CSV is removed even on the error paths above.
        if temp_filename and os.path.exists(temp_filename):
            try:
                os.unlink(temp_filename)
            except OSError:
                pass


def _run_model(model_type, model_config, temp_filename, unique_id):
    """Execute one model script and return its per-model result dict."""
    model_file = model_config['file']
    hardcoded = model_config['interface'] == 'hardcoded'
    expected_filename = "network_logs.csv"
    backup_filename = None
    try:
        if hardcoded:
            # The script reads a fixed filename: stash any existing file,
            # then point the fixed name at our temp CSV.
            if os.path.exists(expected_filename):
                backup_filename = f"backup_{expected_filename}_{unique_id}"
                os.rename(expected_filename, backup_filename)
            try:
                os.symlink(os.path.abspath(temp_filename), expected_filename)
            except OSError:
                # Symlinks may be unavailable (e.g. some containers/filesystems).
                import shutil
                shutil.copy2(temp_filename, expected_filename)
            cmd = [sys.executable, model_file]
        else:
            # Scripts with an argparse interface take the CSV path directly.
            cmd = [sys.executable, model_file, '--logfile', temp_filename]

        # sys.executable (not a bare 'python') guarantees the same
        # interpreter that is serving the app.
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            timeout=300,  # 5 minute timeout
            cwd=os.getcwd()
        )
    except subprocess.TimeoutExpired:
        return {
            "success": False,
            "error": "Model execution timed out after 5 minutes"
        }
    except Exception as e:
        return {
            "success": False,
            "error": f"Execution error: {str(e)}"
        }
    finally:
        # BUG FIX: restore the hardcoded input file even when the
        # subprocess times out or raises. Previously this cleanup only ran
        # on the non-exception path, leaving the symlink in place and the
        # backup never restored after a timeout.
        if hardcoded:
            if os.path.exists(expected_filename):
                try:
                    os.unlink(expected_filename)
                except OSError:
                    pass
            if backup_filename and os.path.exists(backup_filename):
                try:
                    os.rename(backup_filename, expected_filename)
                except OSError:
                    pass

    if result.returncode != 0:
        return {
            "success": False,
            "output": result.stdout,
            "error": result.stderr
        }

    return {
        "success": True,
        "output": result.stdout,
        "predictions": _read_predictions(model_type, unique_id),
        "error": result.stderr if result.stderr else None
    }


def _read_predictions(model_type, unique_id):
    """Load a model's prediction CSV as records, or None if unavailable."""
    # Each script writes its own well-known output filename.
    output_files = {
        'lightGBM': 'lightgbm_breach_predictions.csv',
        'autoencoder': 'breach_predictions.csv',
        'XGB_lstm': 'xgb_lstm_predictions.csv'
    }
    output_file = output_files.get(model_type)
    if not (output_file and os.path.exists(output_file)):
        return None
    try:
        import pandas as pd
        df = pd.read_csv(output_file)
        output_data = df.to_dict('records')
        # Rename the output so concurrent/later requests don't collide.
        os.rename(output_file, f"{unique_id}_{output_file}")
        return output_data
    except Exception as e:
        # Best-effort: a readable stdout result is still returned upstream.
        print(f"Warning: Could not read output file: {e}")
        return None
# NOTE(review): no @app.route decorator visible here — presumably a
# health-check route was lost in transcription; confirm.
def health():
    """Liveness probe: report that the service is up."""
    payload = {"status": "healthy"}
    return jsonify(payload)
# NOTE(review): no @app.route decorator visible here — presumably a
# models-listing route was lost in transcription; confirm.
def get_models():
    """Describe each configured model and whether its script is present."""
    models_info = {
        name: {
            "file": cfg["file"],
            "available": os.path.exists(cfg["file"]),
            "interface": cfg["interface"]
        }
        for name, cfg in MODEL_CONFIGS.items()
    }
    return jsonify({
        "available_models": models_info,
        "required_columns": ["timestamp", "src_ip", "dst_ip", "src_port", "dst_port"],
        "note": "All available models will run automatically. No need to specify model_type."
    }), 200
if __name__ == '__main__':
    # Removed a redundant function-local `import os` that shadowed the
    # module-level import. PORT (set by hosting platforms such as
    # Hugging Face Spaces) overrides the default 7860.
    port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=port, debug=False, threaded=True)