Ymak7 committed on
Commit
3c7592d
·
verified ·
1 Parent(s): ee922b9

Upload 4 files

Browse files
Files changed (4) hide show
  1. gnn_aml.py +146 -0
  2. graph_aml.py +97 -0
  3. test_model.py +159 -0
  4. trained_model.pth +3 -0
gnn_aml.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ import torch.nn.functional as F # βœ… Fix: Import missing F module
5
+ import json
6
+ import numpy as np
7
+ from torch_geometric.data import Data
8
+ from torch_geometric.nn import GATConv
9
+ from graph_aml import add_transaction, detect_pattern, transaction_graphs
10
+ from collections import defaultdict
11
+ from sklearn.utils.class_weight import compute_class_weight
12
+
13
# Load Simulated Transactions
# This runs at import time: any module that imports this file (for example to
# reuse GAT or prepare_graph) also triggers this read.
print("Loading simulated transactions...")
# JSON is UTF-8 by specification; do not depend on the platform's locale
# default encoding when decoding it.
with open("simulated_transactions.json", "r", encoding="utf-8") as f:
    transactions = json.load(f)
print(f"Loaded {len(transactions)} transactions.")
18
+
19
+ # Define AI Model
20
class GAT(nn.Module):
    """Two-layer graph attention network producing per-node class scores.

    The first ``GATConv`` layer runs ``heads`` attention heads and concatenates
    their outputs; the second layer uses a single head and emits
    ``output_dim`` values per node.
    """

    def __init__(self, num_node_features, hidden_dim, output_dim, heads=3):
        super().__init__()
        # The attribute names below become state_dict keys, so they must stay
        # stable for previously saved weights to load.
        self.conv1 = GATConv(
            num_node_features, hidden_dim, heads=heads, concat=True
        )
        # conv1 concatenates its heads, hence hidden_dim * heads inputs here.
        self.conv2 = GATConv(
            hidden_dim * heads, output_dim, heads=1, concat=False
        )
        # Dropout between the two layers to reduce overfitting.
        self.dropout = nn.Dropout(0.3)

    def forward(self, data):
        """Return log-probabilities (log-softmax over dim 1) for ``data``.

        ``data`` must expose ``x`` and ``edge_index`` attributes.
        """
        node_features = data.x
        connectivity = data.edge_index
        hidden = F.relu(self.conv1(node_features, connectivity))
        hidden = self.dropout(hidden)
        scores = self.conv2(hidden, connectivity)
        return F.log_softmax(scores, dim=1)
33
+
34
+
35
+ # def normalize_feature(x):
36
+ # """Normalize feature vector"""
37
+ # x = np.array(x)
38
+ # return (x - np.min(x, axis=0)) / (np.max(x, axis=0) - np.min(x, axis=0) + 1e-8)
39
+
40
+
41
+ # Prepare Graph Data
42
def normalize_feature(x):
    """Min-max scale ``x`` using its global minimum and maximum.

    When the value range is exactly zero (a scalar, or an array whose entries
    are all equal) ``x`` is returned unchanged. Otherwise the result is
    ``(x - min) / (max - min + 1e-8)``; the small epsilon keeps the division
    well-behaved for very narrow ranges.
    """
    lowest = np.min(x)
    spread = np.max(x) - lowest
    if spread == 0:
        return x
    return (x - lowest) / (spread + 1e-8)
44
+
45
+
46
def prepare_graph(txns=None):
    """Build a single PyG ``Data`` object and per-node labels from transactions.

    Every transaction is passed to ``add_transaction``; each graph then held
    in ``transaction_graphs`` contributes one feature row and one label per
    node, plus its edges re-indexed to consecutive integer node ids.

    Args:
        txns: Optional iterable of transaction dicts. When omitted, the
            module-level ``transactions`` list is used, which keeps the
            original zero-argument behaviour.

    Returns:
        ``(data, labels)`` where ``data`` is a ``Data`` with ``x`` (float
        features, one row per node) and ``edge_index``, and ``labels`` is a
        list of 0/1 ints aligned with the rows of ``x``. Returns
        ``(None, None)`` when no node features were produced.

    Note:
        ``transaction_graphs`` is shared module state in ``graph_aml``; graphs
        added by earlier calls are still present and are included again.
    """
    print("Preparing graph data...")
    if txns is None:
        txns = transactions

    for txn in txns:
        add_transaction(txn)  # Add transaction to graph

    graph_list = list(transaction_graphs.values())
    print(f"Total transaction graphs created: {len(graph_list)}")

    features = []
    edge_list = []
    labels = []
    account_map = {}

    for graph in graph_list:
        # The pattern is a property of the whole graph, so classify it once
        # per graph instead of twice per node.
        is_suspicious = 1 if detect_pattern(graph) != "Normal" else 0

        for node in graph.nodes:
            if node not in account_map:
                account_map[node] = len(account_map)

        # Row order must match account_map order. This presumably relies on
        # graphs not sharing nodes (add_transaction merges graphs that do).
        for node in graph.nodes:
            # NOTE(review): the third feature is identical to the label, so
            # the model can read the answer from its input. Kept as-is because
            # GAT is constructed with num_node_features=3 and a saved
            # checkpoint was presumably trained on this layout; revisit
            # together with retraining.
            #
            # Values are left unscaled on purpose: the previous per-scalar
            # normalize_feature(x) call always hit its zero-range branch and
            # returned x unchanged, so dropping it does not alter the features.
            features.append([
                len(list(graph.successors(node))),    # Outgoing connections
                len(list(graph.predecessors(node))),  # Incoming connections
                is_suspicious,                        # AML flag of the graph
            ])
            labels.append(is_suspicious)

        # Every edge endpoint is one of graph.nodes and was indexed above.
        for sender, receiver in graph.edges:
            edge_list.append([account_map[sender], account_map[receiver]])

    print("Graph preparation complete.")

    if not features:
        print("❌ No valid features found. Exiting.")
        return None, None

    # Debug: check class balance
    print(f"Label Distribution: {np.bincount(labels)}")

    x = torch.tensor(features, dtype=torch.float)
    # reshape(-1, 2) keeps the result at shape (2, 0) even with no edges.
    edge_index = (
        torch.tensor(edge_list, dtype=torch.long).reshape(-1, 2).t().contiguous()
    )
    return Data(x=x, edge_index=edge_index), labels
93
+
94
+
95
+ # Train AI Model
96
def train_gnn(epochs=200, lr=0.005, model_path="trained_model.pth"):
    """Train the GAT node classifier and save its weights.

    Args:
        epochs: Number of full-graph optimisation steps (default 200).
        lr: Adam learning rate (default 0.005).
        model_path: Destination file for the trained ``state_dict``
            (default ``"trained_model.pth"``).

    Returns:
        None. Training is skipped when ``prepare_graph`` yields no data.
    """
    print("Starting GNN training...")
    data, labels = prepare_graph()
    if data is None:
        print("❌ Training aborted. No valid data available.")
        return

    model = GAT(num_node_features=3, hidden_dim=16, output_dim=2)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    labels_np = np.array(labels).flatten()  # Ensure it's 1D

    # compute_class_weight needs both classes present in y. When only one
    # class occurs, pad the label array (used for the weights only, never for
    # the loss targets) with placeholder labels of the missing class.
    if len(np.unique(labels_np)) < 2:
        print("⚠️ Warning: Only one class present in dataset! Generating synthetic samples to balance.")

        missing_class = 1 if np.all(labels_np == 0) else 0
        # About 20% of the dataset, but never zero: with fewer than five
        # samples `len // 5` is 0 and the weight computation would still fail.
        n_synthetic = max(1, len(labels_np) // 5)
        labels_np = np.concatenate(
            [labels_np, np.full((n_synthetic,), missing_class)]
        )
        print(f"✅ New Label Distribution: {np.bincount(labels_np)}")

    class_weights = compute_class_weight(
        class_weight="balanced", classes=np.array([0, 1]), y=labels_np
    )
    class_weights = torch.tensor(class_weights, dtype=torch.float)

    # GAT.forward already returns log-probabilities, so the matching criterion
    # is NLLLoss. CrossEntropyLoss would apply log-softmax a second time, which
    # gives the same value here but hides what the model actually outputs.
    loss_fn = nn.NLLLoss(weight=class_weights)
    targets = torch.tensor(labels, dtype=torch.long)
    print("Training started...")

    model.train()  # make sure dropout is active during optimisation
    for epoch in range(epochs):
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, targets)
        loss.backward()
        optimizer.step()
        if epoch % 20 == 0:
            print(f"Epoch {epoch}, Loss: {loss.item()}")

    print("✅ GNN Training Complete.")
    torch.save(model.state_dict(), model_path)
    print(f"✅ Model saved as {model_path}")


if __name__ == "__main__":
    train_gnn()
graph_aml.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import networkx as nx
3
+ import torch
4
+
5
# Global graph storage: maps a graph hash to the graph stored under it.
transaction_graphs = {}


# Generate Unique Hash for Transaction Groups
def generate_graph_hash(transactions):
    """Return a SHA-256 hex digest identifying a group of string keys.

    The keys are sorted before being joined with "-", so the digest does not
    depend on the order in which they are supplied.
    """
    canonical = "-".join(sorted(transactions))
    digest = hashlib.sha256()
    digest.update(canonical.encode())
    return digest.hexdigest()
12
+
13
+ # Hash Function for Keys
14
def hash_key(value):
    """Return the SHA-256 hex digest of ``value``.

    Strings are hashed from their UTF-8 encoding, exactly as before. Any other
    value (for example an account id that was stored as a JSON number) is
    converted with ``str()`` first instead of failing on the missing
    ``.encode`` method.
    """
    return hashlib.sha256(str(value).encode("utf-8")).hexdigest()
16
+
17
+ # Add Transaction to Graph
18
def add_transaction(txn):
    """Insert one transaction as a directed edge into ``transaction_graphs``.

    Sender and receiver account ids are hashed with ``hash_key``. Every stored
    graph that already contains either hashed account is merged into a single
    graph, which then receives the new edge and replaces the graphs it was
    built from. If no stored graph contains either account, a new
    single-edge graph is stored. All fields of ``txn`` are attached to the
    edge as attributes.
    """
    src = hash_key(txn["SenderAccount"])
    dst = hash_key(txn["ReceiverAccount"])

    # Keys of all stored graphs that already know the sender or the receiver.
    touching = [
        key
        for key, existing in transaction_graphs.items()
        if src in existing or dst in existing
    ]

    if not touching:
        # Neither account has been seen before: start a new graph.
        fresh = nx.DiGraph()
        fresh.add_edge(src, dst, **txn)
        storage_key = generate_graph_hash([src, dst])
        transaction_graphs[storage_key] = fresh
        return

    # Merge everything the transaction touches, then add the new edge.
    combined_key = generate_graph_hash(touching)
    combined = nx.compose_all([transaction_graphs[key] for key in touching])
    combined.add_edge(src, dst, **txn)

    # Drop the graphs that were merged and store the combined one.
    for key in touching:
        del transaction_graphs[key]
    transaction_graphs[combined_key] = combined
40
+
41
+ # Detect Laundering Patterns
42
+
43
+
44
def _nodes_on_cycles(nodes, successors):
    """Return the set of nodes that lie on at least one directed cycle.

    Uses an iterative Tarjan strongly-connected-components pass: a node is on
    a cycle if its component has more than one member or it has a self-loop.
    ``successors`` maps every node to an iterable of its direct successors.
    """
    index = {}
    lowlink = {}
    stack = []
    on_stack = set()
    cyclic = set()
    counter = 0

    for root in nodes:
        if root in index:
            continue
        index[root] = lowlink[root] = counter
        counter += 1
        stack.append(root)
        on_stack.add(root)
        work = [(root, iter(successors[root]))]

        while work:
            node, children = work[-1]
            descended = False
            for child in children:
                if child not in index:
                    index[child] = lowlink[child] = counter
                    counter += 1
                    stack.append(child)
                    on_stack.add(child)
                    work.append((child, iter(successors[child])))
                    descended = True
                    break
                if child in on_stack:
                    lowlink[node] = min(lowlink[node], index[child])
            if descended:
                continue

            # All children of `node` handled: propagate and maybe close an SCC.
            work.pop()
            if work:
                parent = work[-1][0]
                lowlink[parent] = min(lowlink[parent], lowlink[node])
            if lowlink[node] == index[node]:
                component = []
                while True:
                    member = stack.pop()
                    on_stack.discard(member)
                    component.append(member)
                    if member == node:
                        break
                if len(component) > 1 or node in successors[node]:
                    cyclic.update(component)

    return cyclic


def detect_pattern(graph):
    """Detect laundering patterns in the transaction graph.

    Args:
        graph: One of
            * an object with an ``edge_index`` attribute (e.g. a PyG ``Data``),
            * a raw edge-index tensor (assumed to be shaped
              ``(2, num_edges)`` with senders in row 0 -- confirm with
              callers), or
            * a NetworkX-style directed graph exposing ``nodes``,
              ``successors`` and ``predecessors``.

    Returns:
        The first pattern matched while scanning nodes in order; for each node
        the checks run in this priority:
        ``"Fan-Out"`` (more than 5 outgoing), ``"Fan-In"`` (more than 5
        incoming), ``"Cycle"`` (node lies on a directed cycle, self-loops
        included), ``"Scatter Gather"`` (more than 2 outgoing and more than 2
        incoming). ``"Normal"`` if no node matches.

    Raises:
        ValueError: If ``graph`` is none of the supported types.
    """
    # Tensor-based input (PyG object or the edge index itself).
    if isinstance(graph, torch.Tensor) or hasattr(graph, "edge_index"):
        # A bare tensor has no `.edge_index`; treat it as the edge index.
        edge_index = graph if isinstance(graph, torch.Tensor) else graph.edge_index
        nodes = torch.unique(edge_index).tolist()
        successors = {node: [] for node in nodes}
        predecessors = {node: [] for node in nodes}

        # One bulk conversion instead of slicing the tensor once per edge.
        for sender, receiver in edge_index.t().tolist():
            successors[sender].append(receiver)
            predecessors[receiver].append(sender)

    # NetworkX-style input.
    elif hasattr(graph, "nodes"):
        nodes = list(graph.nodes)
        successors = {node: list(graph.successors(node)) for node in nodes}
        predecessors = {node: list(graph.predecessors(node)) for node in nodes}

    else:
        raise ValueError("Unsupported graph type")

    # Previously only self-loops were reported as "Cycle"; a loop routed
    # through other accounts (A -> B -> A) was missed.
    cyclic_nodes = _nodes_on_cycles(nodes, successors)

    # Pattern detection logic
    for node in nodes:
        outgoing = successors[node]
        incoming = predecessors[node]

        if len(outgoing) > 5:
            return "Fan-Out"  # One sender, many receivers
        if len(incoming) > 5:
            return "Fan-In"  # Many senders, one receiver
        if node in cyclic_nodes:
            return "Cycle"  # Circular laundering
        if len(outgoing) > 2 and len(incoming) > 2:
            return "Scatter Gather"  # Money moves across multiple accounts

    return "Normal"
83
+
84
+
85
+
86
# Store Suspicious AML Clusters: graph hash -> graph flagged as suspicious.
aml_clusters = {}


def flag_suspicious_graph(graph_hash):
    """Mark a graph as an AML cluster if laundering is detected.

    Looks up ``graph_hash`` in ``transaction_graphs``; unknown hashes are
    ignored. When ``detect_pattern`` reports anything other than "Normal",
    the graph is stored in ``aml_clusters`` under the same hash and an alert
    line is printed.
    """
    graph = transaction_graphs.get(graph_hash)
    if graph is None:
        return

    pattern = detect_pattern(graph)
    if pattern != "Normal":
        aml_clusters[graph_hash] = graph
        # The alert emoji had been stored as mis-decoded bytes; restored.
        print(f"🚨 AML Detected: {pattern} | Cluster ID: {graph_hash}")
96
+
97
+
test_model.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import json
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ import seaborn as sns
6
+ from sklearn.metrics import confusion_matrix, classification_report
7
+ from torch_geometric.data import Data
8
+ from gnn_aml import GAT, prepare_graph
9
+ from graph_aml import detect_pattern
10
+
11
# Load Model
print("🔍 Loading Trained Model...")
model = GAT(num_node_features=3, hidden_dim=16, output_dim=2)
# map_location keeps loading independent of the device the weights were saved on.
model.load_state_dict(torch.load("trained_model.pth", map_location="cpu"))
model.eval()

# Load New Test Data
print("📥 Loading New Test Transactions...")
with open("test_transactions.json", "r", encoding="utf-8") as f:
    test_transactions = json.load(f)

# Prepare Graph Data
# NOTE(review): prepare_graph() is called without arguments, so the graph is
# built from gnn_aml's module-level `transactions` (the simulated training
# file), not from `test_transactions` loaded above -- confirm this is intended.
print("🔄 Preparing Test Graph Data...")
test_graph, _ = prepare_graph()
if test_graph is None:
    raise SystemExit("No graph data available for evaluation.")

# Run Model Predictions
print("🧠 Running Predictions...")
with torch.no_grad():
    output = model(test_graph)
    # The model returns log-probabilities; softmax over them yields the
    # class probabilities.
    probs = torch.softmax(output, dim=1)
    predictions = (probs[:, 1] > 0.75).long()  # 1 = AML, 0 = Normal

# Store predictions
test_results = []
y_true = []  # True labels
y_pred = []  # Predicted labels

# The pattern is a property of the whole graph, so it is computed at most once
# instead of once per flagged transaction.
graph_pattern = None

# NOTE(review): `predictions` holds one entry per graph node, while this loop
# pairs entries with transactions by position; zip() also stops silently at
# the shorter sequence. Verify this alignment is what is wanted.
for txn, prediction in zip(test_transactions, predictions):
    risk_score = txn["RiskScore"]
    true_label = 1 if txn["AML_Flag"] == 1 else 0  # True AML label
    predicted_label = prediction.item()

    # Update labels for confusion matrix
    y_true.append(true_label)
    y_pred.append(predicted_label)

    if risk_score < 0.5:
        predicted_pattern = "None"
    elif predicted_label == 1:
        if graph_pattern is None:
            graph_pattern = detect_pattern(test_graph)
        predicted_pattern = graph_pattern
    else:
        predicted_pattern = "None"

    test_results.append({
        "TransactionID": txn["TransactionID"],
        "TrueLabel": true_label,
        "PredictedLabel": predicted_label,
        "PredictedPattern": predicted_pattern,
        "RiskScore": risk_score
    })

# Save results to file
with open("new_test_results_v2.json", "w") as f:
    json.dump(test_results, f, indent=4)

# Compute Accuracy Metrics
# labels=[0, 1] pins both classes so the matrix is always 2x2 and the report
# always matches the two target names, even if one class never occurs.
CLASS_IDS = [0, 1]
CLASS_NAMES = ["Normal", "AML"]

print("\n📊 **Final Test Results:**")
cm = confusion_matrix(y_true, y_pred, labels=CLASS_IDS)
report = classification_report(
    y_true, y_pred, labels=CLASS_IDS, target_names=CLASS_NAMES,
    digits=4, zero_division=0)

print("\n🔢 **Confusion Matrix:**\n", cm)
print("\n📄 **Classification Report:**\n", report)

# Plot Confusion Matrix
plt.figure(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES)
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.title("Confusion Matrix")
plt.show()

# Plot Prediction Distribution
# bincount with minlength=2 always gives [normal_count, aml_count]; np.unique
# drops absent classes, which would attach the wrong count to a bar.
counts = np.bincount(np.asarray(y_pred, dtype=int), minlength=2)
plt.figure(figsize=(6, 5))
plt.bar(CLASS_NAMES, counts, color=["green", "red"])
plt.xlabel("Transaction Classification")
plt.ylabel("Number of Transactions")
plt.title("AML vs. Normal Transactions Detected")
plt.show()

print("✅ Accuracy analysis complete! Check charts & logs.")
94
+
95
+
96
+ # import torch
97
+ # import json
98
+ # from torch_geometric.data import Data
99
+ # from gnn_aml import GAT, prepare_graph
100
+ # from graph_aml import detect_pattern
101
+
102
+ # # Load Model
103
+ # print("πŸ” Loading Trained Model...")
104
+ # model = GAT(num_node_features=3, hidden_dim=16, output_dim=2)
105
+ # model.load_state_dict(torch.load("trained_model.pth"))
106
+ # model.eval()
107
+
108
+ # # Load New Test Data
109
+ # print("πŸ“₯ Loading New Test Transactions...")
110
+ # with open("test_transactions.json", "r") as f:
111
+ # test_transactions = json.load(f)
112
+
113
+ # # Prepare Graph Data
114
+ # print("πŸ”„ Preparing Test Graph Data...")
115
+ # test_graph, _ = prepare_graph()
116
+
117
+ # # Run Model Predictions
118
+ # print("🧠 Running Predictions...")
119
+ # with torch.no_grad():
120
+ # output = model(test_graph)
121
+ # probs = torch.softmax(output, dim=1) # Convert logits to probabilities
122
+ # predictions = (probs[:, 1] > 0.75).long() # 1 = AML, 0 = Normal
123
+
124
+ # # Store predictions
125
+ # test_results = []
126
+ # aml_count = 0
127
+ # normal_count = 0
128
+
129
+ # for txn, prediction in zip(test_transactions, predictions):
130
+ # risk_score = txn["RiskScore"]
131
+ # predicted_label = prediction.item()
132
+
133
+ # if risk_score < 0.5:
134
+ # predicted_pattern = "None" # βœ… Mark as safe
135
+ # normal_count += 1 # βœ… Count normal transactions
136
+ # elif predicted_label == 1:
137
+ # predicted_pattern = detect_pattern(
138
+ # test_graph) # βœ… Detect actual pattern
139
+ # aml_count += 1 # βœ… Count AML transactions
140
+ # else:
141
+ # predicted_pattern = "None"
142
+ # normal_count += 1 # βœ… Count normal transactions
143
+
144
+ # test_results.append({
145
+ # "TransactionID": txn["TransactionID"],
146
+ # "PredictedPattern": predicted_pattern,
147
+ # "RiskScore": risk_score
148
+ # })
149
+
150
+ # # **βœ… Move logging here, after results are fully analyzed**
151
+ # print("\nπŸ“Š **Final Test Results:**")
152
+ # print(f"πŸ”΄ AML Detected: {aml_count}")
153
+ # print(f"🟒 Normal Transactions: {normal_count}")
154
+
155
+ # # Save results to file
156
+ # with open("new_test_results_v2.json", "w") as f:
157
+ # json.dump(test_results, f, indent=4)
158
+
159
+ # print("βœ… Test results saved to `new_test_results_v2.json`")
trained_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:95cf97d804d6e6bf66a4da24e577ad6a9328272f6cafbea6df078185d6214275
3
+ size 4872