import torch
import torch.nn.functional as F
from dataset.ogbn_link_pred_dataset import OGBNLinkPredDataset
from pathlib import Path
import structlog
from sentence_transformers import SentenceTransformer
from model.mlp import edge_features, PairMLP


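# Run on GPU when available; the encoder, the scorer, and all tensors live on this device.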
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

structlog.configure(
    processors=[
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.processors.JSONRenderer(indent=4, sort_keys=True),
    ]
)
logger = structlog.get_logger()


def abstract_to_vector(
    title: str, abstract_text: str, st_model: SentenceTransformer
) -> torch.Tensor:
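    """Encode a paper's title and abstract into a single sentence embedding."""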
    text = title + "\n" + abstract_text
    with torch.no_grad():
        vector = st_model.encode(text, convert_to_tensor=True, device=DEVICE)
    return vector


def get_citation_predictions(
    vector: torch.Tensor,
    model: PairMLP,
    z_all: torch.Tensor,
    num_nodes: int,
) -> torch.Tensor:
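    """Score the new paper (node 0) against every corpus paper.

    Returns a 1-D tensor of citation probabilities, one per corpus node.
    """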
    model.eval()

    with torch.no_grad():
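        # Prepend the query vector so it occupies row 0, ahead of the corpus rows.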
        combined_embeddings = torch.cat([vector.view(1, -1), z_all], dim=0)
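        # Candidate edges: node 0 (the new paper) paired with nodes 1..num_nodes (corpus).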
        edge_index = torch.tensor([[0] * num_nodes, list(range(1, num_nodes + 1))]).to(
            DEVICE
        )
        feat = edge_features(combined_embeddings, edge_index).to(DEVICE)
        scores = torch.sigmoid(model(feat))
        return scores.squeeze()


def format_top_k_predictions(
    probs: torch.Tensor,
    dataset: OGBNLinkPredDataset,
    top_k: int = 10,
    show_prob: bool = False,
) -> str:
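    """Render the top_k highest-probability citations as a readable list."""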
    probs = probs.cpu()
    top_probs, top_indices = torch.topk(probs, k=top_k)
    output_lines = []
    header = f"Top {top_k} Citation Predictions:"
    output_lines.append(header)

    for i in range(top_k):
        paper_idx = top_indices[i].item()
        prob = top_probs[i].item()
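        # Each corpus entry stores the paper title on its first line.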
        paper_info = dataset.corpus[paper_idx]
        paper_title = paper_info.split("\n")[0]
        if show_prob:
            line = f"  - Title: '{paper_title.strip()}', Probability: {prob:.4f}"
        else:
            line = f"  - Title: '{paper_title.strip()}'"
        output_lines.append(line)

    return "\n".join(output_lines)


def prepare_system(
    model_path: Path,
) -> tuple[PairMLP, SentenceTransformer, OGBNLinkPredDataset, torch.Tensor]:
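    """Load the dataset, sentence encoder, cached corpus embeddings, and the PairMLP scorer."""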
    logger.info("system_preparation.start")

    dataset = OGBNLinkPredDataset()
    logger.info("dataset.load.success")

    model_name = "bongsoo/kpf-sbert-128d-v1"
    logger.info(
        "model.load.start", model_type="SentenceTransformer", model_name=model_name
    )
    st_model = SentenceTransformer(model_name, device=DEVICE)
    logger.info("model.load.success", model_type="SentenceTransformer")

    # Load cached corpus embeddings if present; otherwise encode and cache them.
    embeddings_path = Path("model/embeddings.pth")
    if embeddings_path.exists():
        corpus_embeddings = torch.load(embeddings_path, map_location=DEVICE)
        logger.info("embeddings.load.success")
    else:
        logger.info("embeddings.calculation.start")
        corpus_embeddings = st_model.encode(
            dataset.corpus, convert_to_tensor=True, show_progress_bar=True
        )
        embeddings_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(corpus_embeddings, embeddings_path)
        logger.info("embeddings.calculation.success")

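    # L2-normalize so pair features depend on embedding direction, not magnitude.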
    corpus_embeddings = F.normalize(corpus_embeddings.to(DEVICE), p=2, dim=1)

    # Initialize the PairMLP scorer; its input is a pairwise edge feature,
    # which is twice the embedding dimension.
    embedding_dim = corpus_embeddings.size(1)
    pair_mlp = PairMLP(embedding_dim * 2).to(DEVICE)

    if model_path.exists():
        pair_mlp.load_state_dict(torch.load(model_path, map_location=DEVICE))
        logger.info("model.load.success", model_type="PairMLP", path=str(model_path))
    else:
        logger.warning(
            "model.load.failure",
            model_type="PairMLP",
            path=str(model_path),
            reason="File not found, using random weights.",
        )

    pair_mlp.eval()

    logger.info(
        "embeddings.ready",
        embedding_name="corpus_embeddings",
        shape=list(corpus_embeddings.shape),
    )
    logger.info("system_preparation.finish", status="ready_for_predictions")

    return pair_mlp, st_model, dataset, corpus_embeddings


if __name__ == "__main__":
    MODEL_PATH = Path("model.pth")
    pair_model, st_model, dataset, corpus_embeddings = prepare_system(MODEL_PATH)

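    # Example query: a new paper's title and abstract to score against the corpus.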
    my_title = "A Survey of Graph Neural Networks for Link Prediction"
    my_abstract = (
        "Link prediction is a critical task in graph analysis. "
        "In this paper, we review various GNN architectures like GCN "
        "and GraphSAGE for predicting edges."
    )

    new_vector = abstract_to_vector(my_title, my_abstract, st_model)
    new_vector = F.normalize(
        new_vector.view(1, -1), p=2, dim=1
    )  # Normalize like corpus embeddings

    probabilities = get_citation_predictions(
        vector=new_vector,
        model=pair_model,
        z_all=corpus_embeddings,
        num_nodes=dataset.data.num_nodes,
    )

    references = format_top_k_predictions(
        probabilities, dataset, top_k=5, show_prob=True
    )
    print(references)