CMeKG-JK-GATv2-Text2KG
Introduction
This model is a graph embedding model for the Chinese Medical Knowledge Graph (CMeKG), built on GATv2 + JK-Net and trained with edge contrastive learning. It bridges natural-language text and the structural embeddings of the medical knowledge graph.
The model takes arbitrary Chinese medical text as input, matches it to the most similar knowledge graph node, and outputs a graph-aware embedding for that node, enriched with structural knowledge.
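The matching step described above can be pictured as a nearest-neighbour search over node embeddings. The following is a minimal sketch, assuming cosine similarity is the matching criterion (the model card does not state the exact metric); `l2_normalize` and `match_node` are hypothetical helper names, and the 3-dimensional vectors are toy stand-ins for real BGE-M3 embeddings.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit length (returned unchanged if it is all zeros)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)


def match_node(query_emb, node_embs):
    """Return (index, cosine similarity) of the node closest to the query."""
    q = l2_normalize(query_emb)
    best_idx, best_sim = -1, -math.inf
    for idx, node in enumerate(node_embs):
        n = l2_normalize(node)
        sim = sum(a * b for a, b in zip(q, n))
        if sim > best_sim:
            best_idx, best_sim = idx, sim
    return best_idx, best_sim


# Toy embeddings (illustrative only, not real model output).
node_embs = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.6, 0.8, 0.0]]
query_emb = [0.5, 0.9, 0.0]
idx, sim = match_node(query_emb, node_embs)
print(idx, round(sim, 4))
```

In practice the same search would run over the precomputed node embedding matrix as a single matrix-vector product rather than a Python loop.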
Model Architecture
Text Encoder: BAAI/bge-m3 (frozen, used for semantic encoding of medical text)
Graph Encoder: GATv2 + JK-Net (two graph attention layers with jumping-knowledge aggregation of the layer outputs)
Training Objective: Edge contrastive learning on the medical knowledge graph
Embedding Dimension: 256-dimensional normalized graph embedding
Task Scenarios: medical text-KG matching, medical semantic retrieval, knowledge-enhanced representation, link prediction
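To make the training objective more concrete, here is a minimal sketch of one common way to formulate edge contrastive learning: an InfoNCE-style loss in which the true neighbour of a source node is scored against sampled non-neighbours. This formulation, the function name `edge_info_nce`, and the temperature value are assumptions for illustration; the model card does not specify the exact loss used.

```python
import math


def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def edge_info_nce(src, pos_dst, neg_dsts, temperature=0.1):
    """InfoNCE loss for one edge: negative log-softmax of the positive pair.

    src      -- embedding of the source node
    pos_dst  -- embedding of a node that really is connected to src
    neg_dsts -- embeddings of sampled nodes that are not connected to src
    """
    logits = [dot(src, pos_dst) / temperature]
    logits += [dot(src, neg) / temperature for neg in neg_dsts]
    max_logit = max(logits)  # subtract the max for numerical stability
    log_denominator = max_logit + math.log(
        sum(math.exp(value - max_logit) for value in logits)
    )
    return log_denominator - logits[0]


# Toy unit vectors: the loss is small when the positive pair is the most similar.
loss_good = edge_info_nce([1.0, 0.0], [1.0, 0.0], [[0.0, 1.0], [-1.0, 0.0]])
loss_bad = edge_info_nce([1.0, 0.0], [0.0, 1.0], [[1.0, 0.0]])
print(loss_good < loss_bad)
```

Minimising a loss of this shape pulls the embeddings of connected nodes together and pushes unconnected ones apart, which is the property link prediction relies on.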
Performance
Trained on the full CMeKG medical knowledge graph with edge-level contrastive learning, achieving stable performance on medical knowledge graph link prediction and text-node matching tasks.
Environment Requirements
pip install torch torch-geometric transformers huggingface-hub numpy neo4j
Usage Guide
Important: This model does not contain the raw knowledge graph data. You must export your local CMeKG graph to a file and precompute the node text embeddings yourself.
Step 1: Export CMeKG Knowledge Graph (graph_export.pkl)
Export graph nodes and edges from local Neo4j CMeKG:
import pickle

from neo4j import GraphDatabase


def export_graph():
    driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "your_password"))
    try:
        with driver.session() as session:
            # One record per node: internal id, first label, and all properties
            nodes = session.run(
                "MATCH (n) RETURN id(n) as id, labels(n)[0] as label, properties(n) as props"
            ).data()
            # One record per directed relationship
            edges = session.run(
                "MATCH (a)-[r]->(b) RETURN id(a) as src, id(b) as dst, type(r) as rel_type"
            ).data()
    finally:
        driver.close()  # release the connection pool

    graph_data = {
        "nodes": nodes,
        "edges": edges,
        "num_nodes": len(nodes),
        "num_edges": len(edges),
    }
    with open("graph_export.pkl", "wb") as f:
        pickle.dump(graph_data, f)


export_graph()
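Neo4j internal ids are not guaranteed to be contiguous or to start at zero, whereas graph libraries such as torch-geometric index nodes by position. The sketch below shows one way to remap the exported ids to row positions; `build_edge_index` is a hypothetical helper, and whether `load_graph` already performs this mapping internally is not stated in the model card, so treat it as an illustration of the data layout rather than a required step.

```python
def build_edge_index(nodes, edges):
    """Map Neo4j ids to 0-based positions and return a [2, num_edges] edge list."""
    # Position in the exported node list becomes the node index.
    id_to_idx = {node["id"]: idx for idx, node in enumerate(nodes)}
    src_row, dst_row = [], []
    for edge in edges:
        src = id_to_idx.get(edge["src"])
        dst = id_to_idx.get(edge["dst"])
        if src is None or dst is None:
            continue  # skip edges whose endpoint was not exported
        src_row.append(src)
        dst_row.append(dst)
    return id_to_idx, [src_row, dst_row]


# Toy records in the same shape as graph_export.pkl (ids and types are illustrative).
nodes = [{"id": 10}, {"id": 42}, {"id": 7}]
edges = [
    {"src": 10, "dst": 42, "rel_type": "REL_A"},
    {"src": 7, "dst": 10, "rel_type": "REL_B"},
    {"src": 42, "dst": 99, "rel_type": "REL_C"},  # dangling: id 99 is not a node
]
id_to_idx, edge_index = build_edge_index(nodes, edges)
print(edge_index)
```

With this layout, row `i` of the precomputed embedding file corresponds to `nodes[i]`, so the node order must stay the same between export and embedding.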
Step 2: Precompute BGE-M3 Node Embeddings (bge_text_cache.npy)
Generate fixed text embeddings for all KG nodes:
import pickle

import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

# Load graph nodes
with open("graph_export.pkl", "rb") as f:
    graph_data = pickle.load(f)
nodes = graph_data["nodes"]

# Build node description text (类型 = type, 属性 = properties)
node_texts = [f"类型:{n['label']} 属性:{n['props']}" for n in nodes]

# Load BGE-M3
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")
bge_model = AutoModel.from_pretrained("BAAI/bge-m3").eval()

emb_list = []
with torch.no_grad():
    # Encode in batches of 32 to bound memory use
    for i in range(0, len(node_texts), 32):
        batch_text = node_texts[i:i + 32]
        inputs = tokenizer(
            batch_text,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        )
        out = bge_model(**inputs)
        # Use the L2-normalized [CLS] token as the text embedding
        cls_emb = F.normalize(out.last_hidden_state[:, 0], p=2, dim=-1)
        emb_list.append(cls_emb.cpu().numpy())

# Row i of the cache corresponds to nodes[i]
full_emb = np.concatenate(emb_list, axis=0)
np.save("bge_text_cache.npy", full_emb)
Step 3: Load Model & Inference
Load the model through Hugging Face AutoModel. Because the model class ships with the repository, trust_remote_code=True is required:
from transformers import AutoModel

# Auto load the full graph embedding pipeline
model = AutoModel.from_pretrained(
    "MaxinT23/CMeKG-JK-GATv2-Text2KG",
    trust_remote_code=True
)

# Load local graph and precomputed embeddings
model.load_graph("graph_export.pkl", "bge_text_cache.npy")

# Inference with arbitrary medical text
text = "头痛伴随发热、乏力"
kg_emb, match_node_idx, similarity_score = model(text)
print("Matched Node Index:", match_node_idx.item())
print("Text-Node Similarity:", similarity_score.item())
print("Graph Embedding Shape:", kg_emb.shape)  # [1, 256]
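The matched index is only useful once it is mapped back to a node record. The following sketch assumes that the returned index refers to a position in the `nodes` list stored in graph_export.pkl (the same order used for bge_text_cache.npy); `describe_match`, the labels in the toy data, and the 0.5 similarity threshold are hypothetical choices for illustration, not part of the model's API.

```python
def describe_match(nodes, match_idx, similarity, min_similarity=0.5):
    """Look up the matched node record and flag low-similarity matches."""
    if not 0 <= match_idx < len(nodes):
        raise IndexError(f"match index {match_idx} is outside the exported node list")
    node = nodes[match_idx]
    return {
        "label": node["label"],
        "props": node["props"],
        "similarity": similarity,
        # The threshold is arbitrary; tune it on your own data.
        "confident": similarity >= min_similarity,
    }


# Toy node records in the shape produced by Step 1 (values are illustrative).
nodes = [
    {"id": 10, "label": "Disease", "props": {"name": "偏头痛"}},
    {"id": 42, "label": "Symptom", "props": {"name": "发热"}},
]
result = describe_match(nodes, match_idx=1, similarity=0.83)
print(result["label"], result["props"]["name"], result["confident"])
```

With the real model, `match_idx` and `similarity` would come from `match_node_idx.item()` and `similarity_score.item()` in the inference snippet above.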
Citation
If you use this model in your research, please cite our related work.
Training Results
License
MIT License