Spaces:
Sleeping
Sleeping
File size: 5,323 Bytes
d97a439 0bee7fb d97a439 0bee7fb d97a439 fdfe8da d97a439 fdfe8da d97a439 fdfe8da d97a439 fdfe8da d97a439 0bee7fb d97a439 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
import pandas as pd
from src.config import config
class GNNClassifier(torch.nn.Module):
    """Stack of ``GCNConv`` layers for node classification.

    Every layer except the last is followed by ReLU and dropout; the last layer
    returns raw per-class scores of width ``output_dim`` (no softmax applied).
    Layers are registered as attributes ``conv1`` ... ``conv{layers}`` so that
    state dicts saved from the earlier 2-/3-layer implementation load unchanged.

    Args:
        input_dim: Size of each node feature vector.
        hidden_dim: Width of every intermediate GCN layer.
        layers: Number of ``GCNConv`` layers; must be at least 1.
        output_dim: Number of output classes.
        dropout_rate: Dropout probability applied after each non-final layer.

    Raises:
        ValueError: If ``layers`` is less than 1.
    """

    def __init__(self, input_dim, hidden_dim, layers, output_dim, dropout_rate=0.5):
        super().__init__()
        # Fail fast: previously an unsupported depth built a model with no layers
        # and only crashed later inside forward().
        if layers < 1:
            raise ValueError(f"layers must be >= 1, got {layers}")
        self.dropout_rate = dropout_rate
        self.hidden_dim = hidden_dim
        self.layers = layers
        self.output_dim = output_dim
        # Feature width entering/leaving each layer: in -> hidden ... hidden -> out.
        dims = [input_dim] + [hidden_dim] * (layers - 1) + [output_dim]
        for i in range(layers):
            # setattr on an nn.Module registers the submodule under this exact
            # name, preserving the conv1/conv2/... parameter keys.
            setattr(self, f"conv{i + 1}", GCNConv(dims[i], dims[i + 1]))

    def forward(self, data):
        """Return per-node class scores for a graph exposing ``x`` and ``edge_index``."""
        x, edge_index = data.x, data.edge_index
        for i in range(1, self.layers):
            x = getattr(self, f"conv{i}")(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout_rate, training=self.training)
        return getattr(self, f"conv{self.layers}")(x, edge_index)
def load_data(version: str = "undirected"):
    """Load the graph, title index and label mapping for one graph variant.

    ``"undirected"`` reads the three paths from ``config`` as-is; ``"no_edge"``
    reads the same paths after substituting ``"undirected_gnn"`` with
    ``"no_edge_gnn"``. All objects are loaded onto the CPU.

    Returns:
        ``(graph_data, title_to_id, label_mapping)`` exactly as ``torch.load``
        returns them.

    Raises:
        ValueError: If ``version`` is neither ``"undirected"`` nor ``"no_edge"``.
    """
    if version == "undirected":
        substitution = None
    elif version == "no_edge":
        substitution = ("undirected_gnn", "no_edge_gnn")
    else:
        raise ValueError(f"Unknown version: {version}")

    def _load(path):
        if substitution is not None:
            path = path.replace(*substitution)
        # NOTE(review): depending on the torch version's weights_only default,
        # torch.load may unpickle arbitrary objects — only use trusted files.
        return torch.load(path, map_location=torch.device("cpu"))

    # Tuple elements evaluate left to right: graph, title index, labels.
    return (
        _load(config.GNN_GRAPH_DATA_PATH),
        _load(config.TITLE_TO_ID_PATH),
        _load(config.LABEL_MAPPING_PATH),
    )
def _edges_to_new_node(referenced_titles, title_to_id, new_id, make_undirected):
    """Return ``(src, tgt)`` index lists connecting ``new_id`` to its known references.

    Titles absent from ``title_to_id`` are skipped. Each known title ``t`` adds the
    edge ``title_to_id[t] -> new_id``; if ``make_undirected`` is true, the reverse
    edge ``new_id -> title_to_id[t]`` is appended right after it. Duplicate titles
    produce duplicate edges.
    """
    src, tgt = [], []
    for title in referenced_titles:
        if title not in title_to_id:
            continue
        old_id = title_to_id[title]
        src.append(old_id)
        tgt.append(new_id)
        if make_undirected:
            src.append(new_id)
            tgt.append(old_id)
    return src, tgt


def infer_new_node(
    base_data: Data,
    model: torch.nn.Module,
    new_embedding,
    referenced_titles: list[str],
    title_to_id: dict[str, int],
    label_mapping: dict[str, int],
    device: torch.device,
    make_undirected_for_new_node: bool = True,
    use_edges: bool = True,
):
    """Classify an unseen node by appending it to the graph and running ``model``.

    The new node's features are ``new_embedding`` flattened to one row and appended
    after the existing nodes, so its index is ``base_data.x.size(0)``. When
    ``use_edges`` is true it is linked to every title in ``referenced_titles`` found
    in ``title_to_id``; the graph's existing edges are always kept.

    Args:
        base_data: Graph providing ``x`` and ``edge_index``.
        model: Module called as ``model(Data(x=..., edge_index=...))``; put in eval mode.
        new_embedding: Array-like or tensor holding the new node's feature vector.
        referenced_titles: Titles of existing nodes the new node links to.
        title_to_id: Maps a title to its node index in ``base_data``.
        label_mapping: Maps a class label to its output column index.
        device: Device to run inference on.
        make_undirected_for_new_node: Also add edges from the new node back to
            its references (otherwise only reference -> new node edges are added).
        use_edges: If false, no edges are added for the new node.

    Returns:
        ``pandas.DataFrame`` with columns ``"Class"`` and ``"Score"`` (softmax
        probability), one row per output class, sorted by descending score.
    """
    model.eval()
    model = model.to(device)
    # NOTE(review): PyG's Data.to() appears to move tensors in place, so the
    # caller's base_data probably ends up on `device` as well — confirm if that matters.
    base_data = base_data.to(device)
    x_old = base_data.x

    # as_tensor avoids torch.tensor()'s copy-construct warning when the embedder
    # already returns a tensor; reshape also tolerates non-contiguous input.
    new_x = torch.as_tensor(new_embedding, dtype=x_old.dtype, device=device).reshape(1, -1)
    x = torch.cat([x_old, new_x], dim=0)
    new_id = x.size(0) - 1

    edge_index = base_data.edge_index
    if use_edges:
        src, tgt = _edges_to_new_node(
            referenced_titles, title_to_id, new_id, make_undirected_for_new_node
        )
        if src:
            new_edges = torch.tensor([src, tgt], dtype=torch.long, device=device)
            edge_index = torch.cat([edge_index, new_edges], dim=1)

    data_aug = Data(x=x, edge_index=edge_index).to(device)
    with torch.no_grad():
        out = model(data_aug)

    # Row-wise log_softmax then exp == softmax probabilities for the new node.
    probs = F.log_softmax(out, dim=1)[new_id].exp().detach().cpu()
    inv_label_mapping = {v: k for k, v in label_mapping.items()}
    return pd.DataFrame(
        [(inv_label_mapping[i], score) for i, score in enumerate(probs.tolist())],
        columns=["Class", "Score"],
    ).sort_values(by="Score", ascending=False)
if __name__ == "__main__":
    # Demo: classify a new Turkish-language text about Istanbul against the graph.
    from src.embedding import Embedder

    # NOTE(review): machine-specific absolute paths — adjust for your checkout.
    model_path = r"C:\Users\pc\Desktop\Projects\Masters\data_mining\semantic_knowledge_graph\demo\models\gnn\gnn_classifier_model.pth"
    embedder_path = r"C:\Users\pc\Desktop\Projects\Masters\data_mining\semantic_knowledge_graph\demo\models\embedding\gte-multilingual-base"

    graph_data, title_to_id, label_mapping = load_data()
    model = GNNClassifier(input_dim=768, hidden_dim=128, layers=2, output_dim=len(label_mapping), dropout_rate=0.5)
    # map_location is a torch.load argument; Module.load_state_dict has no such
    # parameter and raised TypeError when it was passed there.
    model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))

    new_node_content = "Istanbul Türkiye'nin en büyük şehri ve kültürel başkentidir. Tarih boyunca birçok medeniyete ev sahipliği yapmıştır."
    embedder = Embedder(path=embedder_path)
    new_embedding = embedder.generate_embedding(new_node_content)
    referenced_titles = ["forum istanbul", "istanbul film festivali", "akıllı şehir"]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    result = infer_new_node(
        base_data=graph_data,
        model=model,
        new_embedding=new_embedding,
        referenced_titles=referenced_titles,
        title_to_id=title_to_id,
        label_mapping=label_mapping,
        device=device,
        make_undirected_for_new_node=True,
    )
    print("Prediction Results for New Node:")
    print(result)