File size: 1,224 Bytes
832456c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
import torch.nn as nn
from constants import *
import torch.optim as optim
from edge_generator import EdgeGenerator
from matrix_generator import MatrixGenerator
from combined_generator import CombinedGenerator

# Build the two sub-generators and wrap them in the combined generator whose
# weights are restored below. The size constants are presumably supplied by
# the star import from `constants` -- confirm there.
edge_generator = EdgeGenerator(INPUT_SIZE_GEN, HIDDEN_SIZE_GEN, OUTPUT_SIZE_EDGE_GEN)
matrix_generator = MatrixGenerator(INPUT_SIZE_GEN, HIDDEN_SIZE_GEN, OUTPUT_SIZE_MAT_GEN)

model = CombinedGenerator(edge_generator, matrix_generator)
# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on
# a CPU-only machine; the noise fed to the model further down is created with
# plain torch.randn, i.e. on the default device.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider weights_only=True if the installed torch has it).
model.load_state_dict(torch.load("model.pth", map_location="cpu"))
# Switch to evaluation mode before sampling.
model.eval()

def get_fake_data(batch_size, combined):
    """Sample `batch_size` outputs from `combined`, one noise tensor at a time.

    For every sample a fresh standard-normal tensor of shape
    (1, INPUT_SIZE_GEN) is drawn and passed to `combined`; the results are
    returned as a list in generation order.
    """
    return [
        combined(torch.randn(1, INPUT_SIZE_GEN))
        for _ in range(batch_size)
    ]
# Sampling here is inference only, so run it without autograd bookkeeping;
# this avoids keeping a computation graph alive for each of the 64 outputs.
with torch.no_grad():
    fake_data = get_fake_data(64, model)

def deconstructor(matrix, split_col=3):
    """Split `matrix` into two parts along its second axis.

    Args:
        matrix: Any object supporting ``matrix[:, a:b]`` indexing, such as a
            torch tensor or a NumPy array.
        split_col: Index on the second axis at which to split. Defaults to 3,
            the boundary that was previously hard-coded.

    Returns:
        A ``(mat1, mat2)`` tuple where ``mat1`` is ``matrix[:, :split_col]``
        and ``mat2`` is ``matrix[:, split_col:]``.
    """
    mat1 = matrix[:, :split_col]
    mat2 = matrix[:, split_col:]
    return mat1, mat2

def adj_matrix_to_dict(adj_matrix):
    """Convert an adjacency matrix into an adjacency-list dictionary.

    Every row index becomes a key. Its value lists, in ascending order, the
    column indices whose entry in that row is not equal to zero. A row with
    no such entry maps to an empty list.
    """
    return {
        node: [col for col, weight in enumerate(cells) if weight != 0]
        for node, cells in enumerate(adj_matrix)
    }

# Turn the second part of every generated sample into an adjacency-list dict.
dict_list = []
for data in fake_data:
    _, adjacency = deconstructor(data)
    dict_list.append(adj_matrix_to_dict(adjacency))