File size: 1,414 Bytes
c302dd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
from torch_geometric.data import Data
from collections import defaultdict

class TransactionGraphBuilder:
    """Incrementally build a transaction graph as a PyG ``Data`` object.

    Each transaction contributes three nodes -- an account (type 0), a
    merchant (type 1) and a device (type 2) -- and two directed edges,
    account -> merchant and account -> device.  Nodes are de-duplicated by
    ``(node_type, node_key)``, so equal raw IDs coming from different ID
    spaces (e.g. account ``"42"`` and merchant ``"42"``) remain distinct
    nodes.  Node features are one-hot encodings of the node type.
    """

    def __init__(self, num_node_types=3):
        """Create an empty builder.

        Args:
            num_node_types: Width of the one-hot node-type feature vector.
                Defaults to 3 (account, merchant, device); must be at least 3
                for ``add_transaction`` to work.
        """
        self.num_node_types = num_node_types
        # (node_type, node_key) -> contiguous integer node id.  A plain dict
        # (not a defaultdict) so reading an unknown key raises instead of
        # silently inserting it and aliasing node 0.
        self.node_index = {}
        # Next id to hand out; equals the number of registered nodes.
        self.current_id = 0
        # (src_id, dst_id) pairs in insertion order.
        self.edges = []
        # One one-hot feature row per node, indexed by node id.
        self.node_features = []
        # Node type per node id, parallel to node_features.
        self.node_types = []

    def get_node_id(self, node_key, node_type):
        """Return the id of ``(node_type, node_key)``, registering it if new.

        Args:
            node_key: Raw identifier of the entity (must be hashable).
            node_type: Integer type index in ``range(self.num_node_types)``.

        Returns:
            The node's integer id.  New nodes receive ids 0, 1, 2, ... in
            order of first appearance.

        Raises:
            ValueError: If ``node_type`` is outside
                ``range(self.num_node_types)`` (it would otherwise produce an
                all-zero feature row).
        """
        if not 0 <= node_type < self.num_node_types:
            raise ValueError(
                f"node_type must be in [0, {self.num_node_types}), got {node_type!r}"
            )
        # Namespace the key by type so IDs from different entity kinds
        # cannot collide.
        key = (node_type, node_key)
        node_id = self.node_index.get(key)
        if node_id is None:
            node_id = self.current_id
            self.node_index[key] = node_id
            self.current_id += 1
            self.node_features.append(
                [1.0 if i == node_type else 0.0 for i in range(self.num_node_types)]
            )
            self.node_types.append(node_type)
        return node_id

    def add_transaction(self, transaction, *, return_graph=True):
        """Add one transaction's nodes and edges to the graph.

        Args:
            transaction: Object subscriptable by ``'AccountID'``,
                ``'MerchantID'`` and ``'DeviceID'`` (e.g. a dict or a pandas
                row -- TODO confirm against callers).
            return_graph: If true (the default), build and return a snapshot
                of the whole graph.  Pass ``False`` when bulk-loading many
                transactions to avoid re-materializing every tensor on each
                call, then call :meth:`build` once at the end.

        Returns:
            A ``torch_geometric.data.Data`` for the full graph so far, or
            ``None`` when ``return_graph`` is false.
        """
        # Read every field before mutating, so a malformed transaction does
        # not leave half-registered nodes behind.
        account_key = transaction['AccountID']
        merchant_key = transaction['MerchantID']
        device_key = transaction['DeviceID']

        acc_id = self.get_node_id(account_key, 0)       # account node
        merchant_id = self.get_node_id(merchant_key, 1)  # merchant node
        device_id = self.get_node_id(device_key, 2)     # device node

        self.edges.append((acc_id, merchant_id))
        self.edges.append((acc_id, device_id))

        return self.build() if return_graph else None

    def build(self):
        """Materialize the current graph as a ``torch_geometric.data.Data``.

        Returns:
            ``Data`` whose ``x`` is a float tensor of shape
            ``(num_nodes, num_node_types)`` and whose ``edge_index`` is a long
            tensor of shape ``(2, num_edges)``.  An empty builder yields
            correctly shaped empty tensors.
        """
        if self.edges:
            # (E, 2) -> (2, E), the layout PyG expects for edge_index.
            edge_index = torch.tensor(self.edges, dtype=torch.long).t().contiguous()
        else:
            edge_index = torch.empty((2, 0), dtype=torch.long)

        if self.node_features:
            x = torch.tensor(self.node_features, dtype=torch.float)
        else:
            x = torch.empty((0, self.num_node_types), dtype=torch.float)

        return Data(x=x, edge_index=edge_index)