# amide-models / app.py
# Commit a7252f1 (Samarth Naik): update /compute endpoint to run all 3
# models with packet counting.
from flask import Flask, jsonify, request
from flask_cors import CORS
import subprocess
import csv
import os
import tempfile
import uuid
app = Flask(__name__)
# Allow cross-origin requests so a separately hosted frontend can call the API.
CORS(app)

# Supported model scripts and how each one receives its input file:
#   'argparse'  -> the script accepts the input path via a --logfile argument
#   'hardcoded' -> the script reads a fixed filename ("network_logs.csv"),
#                  so /compute stages the input under that name before running it
MODEL_CONFIGS = {
    'lightGBM': {'file': 'lightGBM.py', 'interface': 'hardcoded'},
    'autoencoder': {'file': 'autoencoder.py', 'interface': 'hardcoded'},
    'XGB_lstm': {'file': 'XGB_lstm.py', 'interface': 'argparse'}
}
def validate_input_data(file_data):
    """Validate the structure of the uploaded network-log rows.

    Args:
        file_data: Parsed JSON payload; expected to be a non-empty list of
            dicts, each mapping column name -> value for one log record.

    Returns:
        A ``(is_valid, message)`` tuple: ``(True, "Valid")`` on success, or
        ``(False, reason)`` describing the first problem found.
    """
    if not isinstance(file_data, list) or len(file_data) == 0:
        return False, "File data must be a non-empty list"
    # Reject non-object rows up front; a stray scalar or list would
    # otherwise raise AttributeError on .keys() below instead of producing
    # a clean 400 response.
    for i, row in enumerate(file_data, 1):
        if not isinstance(row, dict):
            return False, f"Row {i} is not an object"
    # Check if all rows have the same keys as the first row
    first_row_keys = set(file_data[0].keys())
    for i, row in enumerate(file_data[1:], 1):
        if set(row.keys()) != first_row_keys:
            return False, f"Row {i+1} has different columns than the first row"
    # Basic validation for the network log columns the models rely on
    required_columns = {'timestamp', 'src_ip', 'dst_ip', 'src_port', 'dst_port'}
    if not required_columns.issubset(first_row_keys):
        return False, f"Missing required columns: {required_columns - first_row_keys}"
    return True, "Valid"
@app.route('/compute', methods=['POST'])
def compute():
    """Run every configured model script against an uploaded log dataset.

    Expects a JSON body of the form ``{"file": [<row dict>, ...]}`` where
    each row is one CSV record.  The rows are validated, counted (total
    packets and unique (src_ip, src_port, dst_ip, dst_port) flows), written
    to a temporary CSV, and then each script in MODEL_CONFIGS is executed as
    a subprocess against that file.

    Returns JSON containing the packet/flow counts plus, per model, stdout,
    stderr, and any predictions read back from the model's output CSV.
    Status codes: 200 when every model succeeded, 207 (Multi-Status) when at
    least one model failed, 400 on bad input, 500 on an unexpected error.

    NOTE(review): despite the commit message ("run all 3 models
    simultaneously"), the loop below runs the models sequentially, one
    subprocess at a time.
    """
    temp_filename = None
    # Short random suffix so concurrent requests (threaded=True in app.run)
    # don't clobber each other's temp/backup/output files.
    unique_id = str(uuid.uuid4())[:8]
    try:
        data = request.get_json()
        if not data:
            return jsonify({"error": "No JSON data provided"}), 400
        file_data = data.get('file')
        if not file_data:
            return jsonify({"error": "file is required"}), 400
        # Validate input data
        is_valid, validation_msg = validate_input_data(file_data)
        if not is_valid:
            return jsonify({"error": f"Invalid input data: {validation_msg}"}), 400
        # Count packets (one row == one packet) and unique 4-tuple flows.
        num_packets = len(file_data)
        flows = set()
        for row in file_data:
            flow_key = (row['src_ip'], row['src_port'], row['dst_ip'], row['dst_port'])
            flows.add(flow_key)
        num_flows = len(flows)
        # Create temporary CSV file with unique name
        temp_filename = f"temp_input_{unique_id}.csv"
        # Convert JSON rows to CSV; column order follows the first row's keys.
        fieldnames = file_data[0].keys()
        with open(temp_filename, 'w', newline='') as temp_file:
            writer = csv.DictWriter(temp_file, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(file_data)
        # Aggregate response; "success" stays True only if every model succeeds.
        results = {
            "success": True,
            "packets": {
                "total": num_packets,
                "unique_flows": num_flows
            },
            "models": {}
        }
        for model_type, model_config in MODEL_CONFIGS.items():
            model_file = model_config['file']
            # Skip (but report) models whose script is missing from the deploy.
            if not os.path.exists(model_file):
                results["models"][model_type] = {
                    "success": False,
                    "error": f"Model file {model_file} not found"
                }
                continue
            try:
                # Handle different model interfaces
                if model_config['interface'] == 'argparse':
                    # For XGB_lstm.py which uses a --logfile argument
                    cmd = ['python', model_file, '--logfile', temp_filename]
                else:
                    # 'hardcoded' scripts read a fixed filename, so stage the
                    # input under that name, preserving any existing file.
                    expected_filename = "network_logs.csv"
                    backup_filename = None
                    # Backup existing file if it exists
                    if os.path.exists(expected_filename):
                        backup_filename = f"backup_{expected_filename}_{unique_id}"
                        os.rename(expected_filename, backup_filename)
                    # Prefer a symlink (cheap); fall back to a copy on
                    # platforms/filesystems where symlinks fail.
                    try:
                        os.symlink(os.path.abspath(temp_filename), expected_filename)
                    except OSError:
                        # Fallback to copy if symlink fails
                        import shutil
                        shutil.copy2(temp_filename, expected_filename)
                    cmd = ['python', model_file]
                # Run the model as a child process; 'python' is resolved from
                # PATH — presumably the same interpreter as this app. TODO confirm.
                result = subprocess.run(
                    cmd,
                    capture_output=True,
                    text=True,
                    timeout=300,  # 5 minute timeout
                    cwd=os.getcwd()
                )
                # Undo the staging for 'hardcoded' models: remove the staged
                # file and restore the backup, if any.
                # NOTE(review): this restore does NOT run if subprocess.run
                # raised (e.g. TimeoutExpired) — the backup file is then left
                # behind and the original network_logs.csv is not restored.
                if model_config['interface'] == 'hardcoded':
                    if os.path.exists("network_logs.csv"):
                        os.unlink("network_logs.csv")
                    if backup_filename and os.path.exists(backup_filename):
                        os.rename(backup_filename, "network_logs.csv")
                if result.returncode == 0:
                    # Each script is expected to write its predictions CSV to
                    # the working directory under these names — TODO confirm
                    # against the model scripts themselves.
                    output_files = {
                        'lightGBM': 'lightgbm_breach_predictions.csv',
                        'autoencoder': 'breach_predictions.csv',
                        'XGB_lstm': 'xgb_lstm_predictions.csv'
                    }
                    output_data = None
                    output_file = output_files.get(model_type)
                    if output_file and os.path.exists(output_file):
                        try:
                            import pandas as pd
                            df = pd.read_csv(output_file)
                            output_data = df.to_dict('records')
                            # Rename output file to avoid conflicts with
                            # concurrent requests' model runs.
                            os.rename(output_file, f"{unique_id}_{output_file}")
                        except Exception as e:
                            # Best effort: predictions stay None but the model
                            # run is still reported as successful.
                            print(f"Warning: Could not read output file: {e}")
                    results["models"][model_type] = {
                        "success": True,
                        "output": result.stdout,
                        "predictions": output_data,
                        "error": result.stderr if result.stderr else None
                    }
                else:
                    # Non-zero exit: surface stdout/stderr and mark the
                    # overall request as a partial failure.
                    results["models"][model_type] = {
                        "success": False,
                        "output": result.stdout,
                        "error": result.stderr
                    }
                    results["success"] = False
            except subprocess.TimeoutExpired:
                results["models"][model_type] = {
                    "success": False,
                    "error": f"Model execution timed out after 5 minutes"
                }
                results["success"] = False
            except Exception as e:
                results["models"][model_type] = {
                    "success": False,
                    "error": f"Execution error: {str(e)}"
                }
                results["success"] = False
        # Clean up temp file (the finally block below repeats this defensively
        # for early-return/exception paths).
        if os.path.exists(temp_filename):
            os.unlink(temp_filename)
        status_code = 200 if results["success"] else 207  # 207 Multi-Status for partial success
        return jsonify(results), status_code
    except Exception as e:
        return jsonify({"error": f"Server error: {str(e)}"}), 500
    finally:
        # Ensure cleanup even on early returns or unexpected exceptions.
        if temp_filename and os.path.exists(temp_filename):
            try:
                os.unlink(temp_filename)
            except:
                pass
@app.route('/health', methods=['GET'])
def health():
    """Liveness probe: unconditionally report the service as healthy."""
    status_payload = {"status": "healthy"}
    return jsonify(status_payload)
@app.route('/models', methods=['GET'])
def get_models():
    """Describe each configured model and whether its script file is present."""
    models_info = {
        model_type: {
            "file": config["file"],
            "available": os.path.exists(config["file"]),
            "interface": config["interface"],
        }
        for model_type, config in MODEL_CONFIGS.items()
    }
    payload = {
        "available_models": models_info,
        "required_columns": ["timestamp", "src_ip", "dst_ip", "src_port", "dst_port"],
        "note": "All available models will run automatically. No need to specify model_type.",
    }
    return jsonify(payload), 200
if __name__ == '__main__':
    # `os` is already imported at module level — the previous local
    # re-import here was redundant and has been removed.
    # Honor the platform's PORT env var (default 7860, the Hugging Face
    # Spaces convention) and bind all interfaces so the container is reachable.
    port = int(os.environ.get('PORT', 7860))
    # threaded=True lets multiple /compute requests run concurrently; each
    # request's unique_id keeps their temp files separate.
    app.run(host='0.0.0.0', port=port, debug=False, threaded=True)