CMeKG-JK-GATv2-Text2KG

Introduction

CMeKG-JK-GATv2-Text2KG is a graph embedding model for the Chinese Medical Knowledge Graph (CMeKG). It combines GATv2 with JK-Net (Jumping Knowledge Network) and is trained with edge contrastive learning. It bridges natural-language text and the structural embeddings of the medical knowledge graph.

The model takes arbitrary Chinese medical descriptive text as input, matches it to the most similar knowledge graph node, and outputs a graph-aware embedding enriched with structural knowledge.
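
The matching step can be pictured as a nearest-neighbour search over normalized embeddings. The following is a minimal sketch under that assumption, not the model's confirmed implementation; `match_node` and the toy vectors are hypothetical and only illustrate the idea.

```python
import numpy as np

def match_node(text_emb, node_embs):
    """Return (index, cosine similarity) of the node embedding closest to text_emb."""
    # After L2 normalization the dot product equals cosine similarity.
    text_emb = text_emb / np.linalg.norm(text_emb)
    node_embs = node_embs / np.linalg.norm(node_embs, axis=1, keepdims=True)
    scores = node_embs @ text_emb
    best = int(np.argmax(scores))
    return best, float(scores[best])

# Toy data (illustrative only): three "node" vectors and one "text" vector.
node_embs = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
idx, score = match_node(np.array([0.9, 0.1]), node_embs)
print(idx, round(score, 4))
```

With real data, `node_embs` would be the cached node text embeddings and `text_emb` the encoding of the input text.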

Model Architecture

  • Text Encoder: BAAI/bge-m3 (frozen; used for semantic encoding of medical text)

  • Graph Encoder: GATv2 + JK-Net (two graph attention layers with jumping knowledge aggregation)

  • Training Objective: Edge contrastive learning on the medical knowledge graph

  • Embedding Dimension: 256-dimensional normalized graph embedding

  • Task Scenarios: medical text-KG matching, medical semantic retrieval, knowledge-enhanced representation, link prediction
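
To make the graph encoder concrete, here is a minimal NumPy sketch of two GATv2 layers followed by jumping knowledge aggregation. It assumes single-head attention, ELU activations, concatenation-style JK aggregation and a linear projection to 256 dimensions; this card confirms none of those details, and all function and parameter names are hypothetical. In practice, torch_geometric provides `GATv2Conv` and `JumpingKnowledge` layers for this purpose.

```python
import numpy as np

def leaky_relu(x, slope=0.2):
    return np.where(x > 0, x, slope * x)

def elu(x):
    return np.where(x > 0, x, np.expm1(np.minimum(x, 0.0)))

def gatv2_layer(h, edges, W_src, W_dst, att):
    """Single-head GATv2: score(i, j) = att . LeakyReLU(W_src h_j + W_dst h_i)."""
    n = h.shape[0]
    h_src, h_dst = h @ W_src, h @ W_dst
    out = np.zeros((n, W_src.shape[1]))
    for i in range(n):
        # Incoming neighbours of node i, plus a self-loop.
        nbrs = [s for s, d in edges if d == i] + [i]
        scores = np.array([att @ leaky_relu(h_src[j] + h_dst[i]) for j in nbrs])
        alpha = np.exp(scores - scores.max())
        alpha /= alpha.sum()  # softmax over the neighbourhood
        out[i] = sum(a * h_src[j] for a, j in zip(alpha, nbrs))
    return out

def encode(x, edges, params, W_out):
    """Two GATv2 layers, JK concatenation of both outputs, projection, L2 norm."""
    h1 = elu(gatv2_layer(x, edges, *params[0]))
    h2 = elu(gatv2_layer(h1, edges, *params[1]))
    jk = np.concatenate([h1, h2], axis=1)  # jumping knowledge, "cat" mode
    z = jk @ W_out
    return z / np.linalg.norm(z, axis=1, keepdims=True)

# Toy graph with random weights (illustrative sizes, not the trained model).
rng = np.random.default_rng(0)
in_dim, hid, out_dim = 6, 8, 256
params = [
    (rng.normal(size=(in_dim, hid)), rng.normal(size=(in_dim, hid)), rng.normal(size=hid)),
    (rng.normal(size=(hid, hid)), rng.normal(size=(hid, hid)), rng.normal(size=hid)),
]
W_out = rng.normal(size=(2 * hid, out_dim))
x = rng.normal(size=(4, in_dim))
edges = [(0, 1), (1, 2), (2, 3), (3, 0)]  # (source, target) pairs
z = encode(x, edges, params, W_out)
print(z.shape)
```

The JK step lets the final embedding draw on both the one-hop and the two-hop representation of each node instead of only the last layer.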

Performance

The model was trained on the full CMeKG medical knowledge graph with edge-level contrastive learning. It performs stably on link prediction and text-node matching over the medical knowledge graph; see Training Results.
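
This card does not state the exact loss function. One common form of edge-level contrastive learning is InfoNCE over edges, sketched below as an assumption: for each edge, the target node is the positive and every other node is a negative. The function name and the temperature value are hypothetical.

```python
import numpy as np

def edge_infonce(z, pos_edges, temperature=0.1):
    """InfoNCE over edges: for (src, dst), dst is the positive, other nodes are negatives."""
    z = z / np.linalg.norm(z, axis=1, keepdims=True)
    src = np.array([s for s, _ in pos_edges])
    dst = np.array([d for _, d in pos_edges])
    rows = np.arange(len(src))
    logits = (z[src] @ z.T) / temperature        # shape: (num_edges, num_nodes)
    logits[rows, src] = -np.inf                  # a node is not its own negative
    logits -= logits.max(axis=1, keepdims=True)  # numerical stability
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return float(-log_prob[rows, dst].mean())

# Toy embeddings: nodes 0 and 1 are close, nodes 2 and 3 are close.
z = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
loss_good = edge_infonce(z, [(0, 1), (2, 3)])  # edges agree with the geometry
loss_bad = edge_infonce(z, [(0, 2), (1, 3)])   # edges contradict the geometry
print(loss_good < loss_bad)
```

This simple variant also treats a node's other true neighbours as negatives; a real training setup might mask them or sample negatives instead.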

Environment Requirements

pip install torch torch-geometric transformers huggingface-hub numpy neo4j

Usage Guide

Important: This model does not ship the raw knowledge graph data. You need to export your local CMeKG graph to a file and precompute the node text embeddings yourself.

Step 1: Export CMeKG Knowledge Graph (graph_export.pkl)

Export the graph nodes and edges from a local Neo4j instance that holds CMeKG:

import pickle
from neo4j import GraphDatabase

def export_graph():
    driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "your_password"))
    try:
        with driver.session() as session:
            # One record per node: internal id, first label, all properties.
            nodes = session.run(
                "MATCH (n) RETURN id(n) AS id, labels(n)[0] AS label, properties(n) AS props"
            ).data()
            # One record per directed relationship, keyed by the same internal ids.
            edges = session.run(
                "MATCH (a)-[r]->(b) RETURN id(a) AS src, id(b) AS dst, type(r) AS rel_type"
            ).data()
    finally:
        driver.close()  # release the connection pool even if a query fails

    graph_data = {
        "nodes": nodes,
        "edges": edges,
        "num_nodes": len(nodes),
        "num_edges": len(edges)
    }
    with open("graph_export.pkl", "wb") as f:
        pickle.dump(graph_data, f)

export_graph()
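
Neo4j internal ids are not guaranteed to be contiguous, whereas graph libraries usually expect node indices 0..N-1. Whether `load_graph` performs this remapping internally is not stated here, so the following is only a sketch of how the exported records could be reindexed; `build_edge_index` and the toy records are hypothetical.

```python
import numpy as np

def build_edge_index(nodes, edges):
    """Map Neo4j ids to row indices 0..N-1 and return a (2, E) int64 array."""
    id_to_idx = {n["id"]: i for i, n in enumerate(nodes)}
    # Skip edges whose endpoints are missing from the node list.
    pairs = [
        (id_to_idx[e["src"]], id_to_idx[e["dst"]])
        for e in edges
        if e["src"] in id_to_idx and e["dst"] in id_to_idx
    ]
    edge_index = np.array(pairs, dtype=np.int64).reshape(-1, 2).T
    return id_to_idx, edge_index

# Toy records in the same shape as graph_export.pkl (values are made up).
nodes = [
    {"id": 17, "label": "Disease", "props": {"name": "感冒"}},   # common cold
    {"id": 42, "label": "Symptom", "props": {"name": "头痛"}},   # headache
    {"id": 99, "label": "Drug", "props": {"name": "布洛芬"}},    # ibuprofen
]
edges = [
    {"src": 17, "dst": 42, "rel_type": "has_symptom"},
    {"src": 17, "dst": 99, "rel_type": "treated_by"},
]
id_to_idx, edge_index = build_edge_index(nodes, edges)
print(edge_index.tolist())
```

Row i of the node embedding cache then corresponds to `nodes[i]`, and `edge_index` refers to those same row positions.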

Step 2: Precompute BGE-M3 Node Embeddings (bge_text_cache.npy)

Generate fixed text embeddings for all KG nodes:

import pickle
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

# Load graph nodes
with open("graph_export.pkl", "rb") as f:
    graph_data = pickle.load(f)
nodes = graph_data["nodes"]

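# Build one description string per node. The Chinese template is kept verbatim
# because it is part of the encoder input ("类型" = type, "属性" = properties).
node_texts = [f"类型:{n['label']} 属性:{n['props']}" for n in nodes]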

# Load BGE-M3
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")
bge_model = AutoModel.from_pretrained("BAAI/bge-m3").eval()

emb_list = []
with torch.no_grad():
    for i in range(0, len(node_texts), 32):
        batch_text = node_texts[i:i+32]
        inputs = tokenizer(
            batch_text,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt"
        )
        out = bge_model(**inputs)
        # Use the [CLS] token as the sentence embedding and L2-normalize it.
        cls_emb = F.normalize(out.last_hidden_state[:, 0], p=2, dim=-1)
        emb_list.append(cls_emb.cpu().numpy())

full_emb = np.concatenate(emb_list, axis=0)
np.save("bge_text_cache.npy", full_emb)
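
Before loading the cache into the model, it can help to confirm that it has one finite, unit-norm row per node. The check below is a hedged sketch; `check_cache` is a hypothetical helper, and the demo uses synthetic data in place of `bge_text_cache.npy`.

```python
import numpy as np

def check_cache(emb, num_nodes, atol=1e-3):
    """Raise ValueError if the cache is misaligned, non-finite, or not unit-norm."""
    if emb.ndim != 2 or emb.shape[0] != num_nodes:
        raise ValueError(f"expected {num_nodes} rows, got shape {emb.shape}")
    if not np.isfinite(emb).all():
        raise ValueError("cache contains NaN or Inf")
    norms = np.linalg.norm(emb, axis=1)
    if not np.allclose(norms, 1.0, atol=atol):
        raise ValueError("rows are not L2-normalized")

# Synthetic stand-in for np.load("bge_text_cache.npy").
rng = np.random.default_rng(0)
good = rng.normal(size=(5, 8))
good /= np.linalg.norm(good, axis=1, keepdims=True)
check_cache(good, num_nodes=5)  # passes silently
```

With the real files, `num_nodes` would be `graph_data["num_nodes"]` from the exported pickle.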

Step 3: Load Model & Inference

Load the full pipeline with Hugging Face AutoModel:

from transformers import AutoModel

# Auto load the full graph embedding pipeline
model = AutoModel.from_pretrained(
    "MaxinT23/CMeKG-JK-GATv2-Text2KG",
    trust_remote_code=True
)

# Load local graph and precomputed embeddings
model.load_graph("graph_export.pkl", "bge_text_cache.npy")

# Inference with arbitrary medical text
text = "头痛伴随发热、乏力"  # "headache with fever and fatigue"
kg_emb, match_node_idx, similarity_score = model(text)

print("Matched Node Index:", match_node_idx.item())
print("Text-Node Similarity:", similarity_score.item())
print("Graph Embedding Shape:", kg_emb.shape)  # [1, 256]

Citation

If you use this model in your research, please cite our related work.

Training Results

[Figure: training results]

License

MIT License

Model Details

  • Format: Safetensors

  • Model size: 3.48M params

  • Tensor type: F32