File size: 9,443 Bytes
61ad322
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
# app.py (Flask Backend)

import os
import random
import time
from flask import Flask, render_template, request, jsonify
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
import torch
import torch.nn as nn
import json # For debugging/logging if needed

# --- Configuration ---
# Module-level singletons: the Flask app, model location, and the runtime
# state (scaler / dimensions / threshold) shared by the detection helpers.
app = Flask(__name__)
MODEL_PATH = 'zero_day_encoder_model.pth'
# Assuming the scaler was fit on data with the same number of features as input_dim
# For a real system, you'd save/load the scaler as well.
# For this demo, we'll re-initialize a dummy scaler and use the exact number of features
# determined by the loaded model's input layer.
GLOBAL_SCALER = None # Will be initialized after model loads
MODEL_INPUT_DIM = None # Will be set by the loaded model
MODEL_LATENT_DIM = 32 # Must match the latent_dim used during training
ANOMALY_THRESHOLD = 5.0 # Adjustable threshold for flagging attacks (Euclidean distance)

# --- PyTorch Model Architecture (Must match training script) ---
# --- PyTorch Model Architecture (Must match training script) ---
class Encoder(nn.Module):
    """Feed-forward encoder mapping a feature vector to a latent embedding.

    Architecture: input_dim -> 128 -> 64 -> latent_dim with ReLU between
    layers. The attribute names (fc1/fc2/fc3) are load-bearing: the saved
    state_dict is keyed on them, so they must match the training script.
    """

    def __init__(self, input_dim, latent_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 64)
        # Final projection into the latent (embedding) space
        self.fc3 = nn.Linear(64, latent_dim)

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        hidden = self.relu(self.fc2(hidden))
        return self.fc3(hidden)

# --- Load the Trained Model and Initialize Scaler/Centroid ---
# This will be run once when the Flask app starts
def _init_fallback_state(centroid_scale=0.1):
    """Install dummy scaler/centroid globals and return an untrained Encoder.

    Used whenever the real checkpoint is missing or fails to load, so the
    app can still start (without real anomaly detection). `centroid_scale`
    shrinks the random placeholder centroid toward the origin.
    """
    global GLOBAL_SCALER, MODEL_INPUT_DIM, GLOBAL_CENTROID
    MODEL_INPUT_DIM = 7 + 39  # Assumed: 7 financial + 39 CICIDS features
    GLOBAL_SCALER = StandardScaler()
    GLOBAL_CENTROID = np.random.rand(MODEL_LATENT_DIM) * centroid_scale
    return Encoder(MODEL_INPUT_DIM, MODEL_LATENT_DIM)


def load_model_and_params():
    """Load the trained encoder from MODEL_PATH and initialize globals.

    Returns the encoder in eval mode. On any failure (missing file, bad
    checkpoint) falls back to a dummy encoder so the Flask app still boots.
    Side effects: sets GLOBAL_SCALER, MODEL_INPUT_DIM, GLOBAL_CENTROID.
    """
    global GLOBAL_SCALER, MODEL_INPUT_DIM, GLOBAL_CENTROID

    if not os.path.exists(MODEL_PATH):
        print(f"Error: Model file not found at {MODEL_PATH}. Please train the Jupyter Notebook first.")
        return _init_fallback_state()

    try:
        # Map to CPU as Flask runs on CPU. weights_only=True restricts
        # unpickling to tensor data, preventing arbitrary-code execution
        # from a tampered checkpoint file.
        state_dict = torch.load(MODEL_PATH, map_location=torch.device('cpu'),
                                weights_only=True)

        # Infer input_dim from the first linear layer's weight matrix
        # (nn.Linear weight shape is [out_features, in_features]).
        if 'fc1.weight' in state_dict:
            MODEL_INPUT_DIM = state_dict['fc1.weight'].shape[1]
        else:
            MODEL_INPUT_DIM = 7 + 39  # Base assumption: 7 financial + 39 CICIDS features
            print(f"Warning: Could not infer input_dim from model state_dict. Using assumed: {MODEL_INPUT_DIM}")

        encoder = Encoder(MODEL_INPUT_DIM, MODEL_LATENT_DIM)
        encoder.load_state_dict(state_dict)
        encoder.eval()  # Inference mode: disables dropout/batch-norm updates

        # NOTE(review): in production the fitted scaler parameters and the
        # centroid of *benign* training embeddings should be saved at
        # training time and loaded here; this demo uses placeholders and
        # relies on the model's embeddings alone.
        GLOBAL_SCALER = StandardScaler()
        GLOBAL_CENTROID = np.random.rand(MODEL_LATENT_DIM)  # Placeholder centroid.

        print(f"Model loaded successfully. Input Dimension: {MODEL_INPUT_DIM}")
        return encoder

    except Exception as e:
        print(f"Error loading model: {e}")
        print("Using dummy model for application startup.")
        return _init_fallback_state()

# Load (or fall back to a dummy) encoder once at import time so request
# handlers reuse a single model instance instead of re-loading per request.
ENCODER_MODEL = load_model_and_params()

# --- Helper Function for Anomaly Detection ---
def _coerce_feature_vector(raw_data_point):
    """Return a 1-D float32 array padded/truncated to MODEL_INPUT_DIM features."""
    features = np.asarray(raw_data_point, dtype=np.float32).ravel()
    if features.shape[0] != MODEL_INPUT_DIM:
        print(f"Input data dimension mismatch: Expected {MODEL_INPUT_DIM}, got {features.shape[0]}")
        # Pad with zeros or truncate so the demo stays robust to bad input.
        if features.shape[0] < MODEL_INPUT_DIM:
            features = np.pad(features, (0, MODEL_INPUT_DIM - features.shape[0]), 'constant')
        else:
            features = features[:MODEL_INPUT_DIM]
    return features


def _ensure_scaler_fitted():
    """Fit GLOBAL_SCALER on placeholder data if it is unfitted or mis-sized.

    In a real app the scaler (mean_/scale_) would be saved during training
    and loaded at startup; this keeps the demo functional without it.
    """
    if (not hasattr(GLOBAL_SCALER, 'mean_') or GLOBAL_SCALER.mean_ is None
            or GLOBAL_SCALER.mean_.shape[0] != MODEL_INPUT_DIM):
        print("Scaler not fitted or dimension mismatch, fitting dummy scaler...")
        # Seeded generator: repeated runs produce reproducible demo scores.
        dummy_fit_data = np.random.default_rng(0).random((100, MODEL_INPUT_DIM))
        GLOBAL_SCALER.fit(dummy_fit_data)


def detect_anomaly(raw_data_point):
    """
    Processes a single raw data point and returns its anomaly score and classification.

    The point is scaled, embedded by the encoder, and scored as the Euclidean
    distance of its embedding from GLOBAL_CENTROID; distances above
    ANOMALY_THRESHOLD are flagged as attacks.

    Returns a dict with keys: score, is_anomaly, status, reaction, attack_type.
    """
    features = _coerce_feature_vector(raw_data_point)
    _ensure_scaler_fitted()

    # Scaler requires a 2-D array: (n_samples, n_features)
    scaled_data_point = GLOBAL_SCALER.transform(features.reshape(1, -1))
    data_tensor = torch.tensor(scaled_data_point, dtype=torch.float32)

    with torch.no_grad():
        embedding = ENCODER_MODEL(data_tensor).cpu().numpy().flatten()

    # Anomaly score: Euclidean distance of the embedding from the centroid.
    anomaly_score = float(np.linalg.norm(embedding - GLOBAL_CENTROID))

    is_anomaly = anomaly_score > ANOMALY_THRESHOLD
    attack_status = "Attack Detected!" if is_anomaly else "Normal Behavior"
    reaction_message = ""
    if is_anomaly:
        reaction_message = "Immediate transaction review triggered. Connection flagged."
        # Simulate prevention by, e.g., setting a flag, initiating a block, etc.
        # In a real system, this would trigger actual security measures.
        print(f"ALERT: Zero-Day Attack Detected! Score: {anomaly_score:.2f}")

    return {
        'score': anomaly_score,
        # Explicitly convert numpy.bool_ to Python bool for jsonify compatibility
        'is_anomaly': bool(is_anomaly),
        'status': attack_status,
        'reaction': reaction_message,
        # Simulated label for the dashboard; the model itself only scores.
        'attack_type': random.choice(['Phishing', 'Malware', 'DDoS', 'Insider Threat', 'Zero-Day Exploitation']) if is_anomaly else 'Benign'
    }

# --- Flask Routes ---
# --- Flask Routes ---
@app.route('/')
def index():
    """Serve the main dashboard page."""
    # The frontend's simulation logic sizes its feature vectors from
    # MODEL_INPUT_DIM, so expose it to the template.
    context = {'MODEL_INPUT_DIM': MODEL_INPUT_DIM}
    return render_template('index.html', **context)

@app.route('/api/analyze_log', methods=['POST'])
def analyze_log():
    """
    API endpoint to receive simulated log data, run anomaly detection,
    and return results.

    Expects a JSON body with a 'log_features' list of numbers; responds
    with the detect_anomaly() result dict, or a JSON error (400 for bad
    input, 500 for unexpected failures).
    """
    try:
        # silent=True: a missing/malformed JSON body yields None instead of
        # raising, so the client gets a 400 rather than a generic 500.
        data = request.get_json(silent=True)
        if data is None:
            return jsonify({'error': 'Request body must be JSON'}), 400
        raw_log_data = data.get('log_features')
        if not raw_log_data:
            return jsonify({'error': 'No log_features provided'}), 400

        # Convert list to numpy array
        raw_log_data = np.array(raw_log_data, dtype=np.float32)

        result = detect_anomaly(raw_log_data)
        return jsonify(result)

    except Exception as e:
        # Boundary handler: surface unexpected failures as a JSON 500.
        print(f"Error in /api/analyze_log: {e}")
        return jsonify({'error': str(e)}), 500

@app.route('/api/metrics')
def get_metrics():
    """Return randomly simulated dashboard metrics.

    Demo stand-in: a real deployment would query a database or monitoring
    system instead of generating numbers on each request.
    """
    total = random.randint(100000, 1000000)
    threats = random.randint(50, 500)
    blocked = random.randint(30, threats)  # never exceeds detections
    users = random.randint(1000, 50000)
    payload = {
        'totalTransactions': total,
        'threatsDetected': threats,
        'blockedAttempts': blocked,
        'activeUsers': users,
        'timestamp': time.strftime("%Y-%m-%d %H:%M:%S"),
    }
    return jsonify(payload)

if __name__ == '__main__':
    # debug=True enables the Werkzeug interactive debugger, which permits
    # arbitrary code execution — never pair it with host="0.0.0.0" on a
    # reachable network. Opt in explicitly for local development via
    # FLASK_DEBUG=1.
    debug_mode = os.environ.get('FLASK_DEBUG', '').lower() in ('1', 'true', 'yes')
    app.run(host="0.0.0.0", port=7860, debug=debug_mode)