import torch
import torch.nn.functional as F
from dataset.ogbn_link_pred_dataset import OGBNLinkPredDataset
from pathlib import Path
import structlog
from sentence_transformers import SentenceTransformer
from model.mlp import edge_features, PairMLP
# Run on GPU when available; all tensors and models below are moved to this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Structured JSON logging with ISO timestamps for every pipeline event.
structlog.configure(
    processors=[
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.processors.JSONRenderer(indent=4, sort_keys=True),
    ]
)
logger = structlog.get_logger()
def abstract_to_vector(
    title: str, abstract_text: str, st_model: SentenceTransformer
) -> torch.Tensor:
    """Embed a paper's title and abstract into a single sentence-transformer vector.

    The title and abstract are joined with a newline and encoded in one pass;
    gradients are disabled since this is inference-only.
    """
    combined_text = f"{title}\n{abstract_text}"
    with torch.no_grad():
        return st_model.encode(combined_text, convert_to_tensor=True, device=DEVICE)
def get_citation_predictions(
    vector: torch.Tensor,
    model: PairMLP,
    z_all: torch.Tensor,
    num_nodes: int,
) -> torch.Tensor:
    """Score the new paper against every corpus node with the pair scorer.

    Returns a 1-D tensor of sigmoid probabilities, one per corpus node.
    """
    model.eval()
    with torch.no_grad():
        # Node 0 is the new paper; nodes 1..num_nodes are the corpus embeddings.
        all_embeddings = torch.cat([vector.view(1, -1), z_all], dim=0)
        sources = torch.zeros(num_nodes, dtype=torch.long)
        targets = torch.arange(1, num_nodes + 1)
        edge_index = torch.stack([sources, targets]).to(DEVICE)
        pair_features = edge_features(all_embeddings, edge_index).to(DEVICE)
        probabilities = torch.sigmoid(model(pair_features))
    return probabilities.squeeze()
def format_top_k_predictions(
    probs: torch.Tensor, dataset: OGBNLinkPredDataset, top_k=10, show_prob=False
) -> str:
    """Render the top-k highest-probability citation candidates as printable text.

    Each line shows the candidate paper's title (first line of its corpus
    entry) and, when ``show_prob`` is set, the predicted probability.
    """
    top_probs, top_indices = torch.topk(probs.cpu(), k=top_k)
    lines = [f"Top {top_k} Citation Predictions:"]
    for prob, paper_idx in zip(top_probs.tolist(), top_indices.tolist()):
        # The corpus entry's first line is the paper title.
        title = dataset.corpus[paper_idx].split("\n")[0].strip()
        suffix = f", Probability: {prob:.4f}" if show_prob else ""
        lines.append(f" - Title: '{title}'{suffix}")
    return "\n".join(lines)
def prepare_system(model_path: Path):
    """Load everything needed for citation prediction.

    Parameters
    ----------
    model_path : Path
        Location of the trained PairMLP state dict. If the file does not
        exist, a warning is logged and random weights are used.

    Returns
    -------
    tuple
        ``(pair_mlp, st_model, dataset, corpus_embeddings)`` ready for
        inference; ``corpus_embeddings`` is L2-normalized and on ``DEVICE``.
    """
    logger.info("system_preparation.start")
    dataset = OGBNLinkPredDataset()
    logger.info("dataset.load.success")

    model_name = "bongsoo/kpf-sbert-128d-v1"
    logger.info(
        "model.load.start", model_type="SentenceTransformer", model_name=model_name
    )
    st_model = SentenceTransformer(model_name, device=DEVICE)
    logger.info("model.load.success", model_type="SentenceTransformer")

    # Load cached corpus embeddings when present; otherwise compute and cache.
    embeddings_path = Path("model/embeddings.pth")
    if embeddings_path.exists():
        corpus_embeddings = torch.load(embeddings_path, map_location=DEVICE)
        logger.info("embeddings.load.success")
    else:
        logger.info("embeddings.calculation.start")
        corpus_embeddings = st_model.encode(
            dataset.corpus, convert_to_tensor=True, show_progress_bar=True
        )
        embeddings_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(corpus_embeddings, embeddings_path)
        logger.info("embeddings.calculation.success")

    # Normalize so pair features behave consistently regardless of vector norms.
    corpus_embeddings = F.normalize(corpus_embeddings.to(DEVICE), p=2, dim=1)

    # PairMLP consumes concatenated (src, dst) embeddings, hence 2 * dim.
    embedding_dim = corpus_embeddings.size(1)
    pair_mlp = PairMLP(embedding_dim * 2).to(DEVICE)
    if model_path.exists():
        pair_mlp.load_state_dict(torch.load(model_path, map_location=DEVICE))
        logger.info("model.load.success", model_type="PairMLP", path=str(model_path))
    else:
        logger.warning(
            "model.load.failure",
            model_type="PairMLP",
            path=str(model_path),
            reason="File not found, using random weights.",
        )
    pair_mlp.eval()

    # Bug fix: this event previously reused "embeddings.calculation.success",
    # which fired even when embeddings were loaded from the cache and duplicated
    # the event emitted in the computation branch. Use a distinct event name.
    logger.info(
        "embeddings.ready",
        embedding_name="corpus_embeddings",
        shape=list(corpus_embeddings.shape),
    )
    logger.info("system_preparation.finish", status="ready_for_predictions")
    return pair_mlp, st_model, dataset, corpus_embeddings
if __name__ == "__main__":
    MODEL_PATH = Path("model.pth")
    pair_model, st_model, dataset, corpus_embeddings = prepare_system(MODEL_PATH)

    query_title = "A Survey of Graph Neural Networks for Link Prediction"
    query_abstract = """Link prediction is a critical task in graph analysis.
In this paper, we review various GNN architectures like GCN and GraphSAGE for predicting edges."""

    # Encode the query paper and normalize it the same way as the corpus embeddings.
    query_vector = abstract_to_vector(query_title, query_abstract, st_model)
    query_vector = F.normalize(query_vector.view(1, -1), p=2, dim=1)

    probabilities = get_citation_predictions(
        vector=query_vector,
        model=pair_model,
        z_all=corpus_embeddings,
        num_nodes=dataset.data.num_nodes,
    )
    print(format_top_k_predictions(probabilities, dataset, top_k=5, show_prob=True))