Spaces:
Build error
Build error
Commit ·
ca764d6
1
Parent(s): ba1c7a0
Add real-time inference
Browse files- .gitignore +5 -1
- README.md +1 -1
- data/kg_edge_types.csv +75 -0
- data/kg_node_types.csv +11 -0
- pages/about.py +4 -1
- pages/explore.py +4 -1
- pages/input.py +54 -1
- pages/validate.py +98 -1
- requirements.txt +2 -1
.gitignore
CHANGED
|
@@ -5,6 +5,10 @@
|
|
| 5 |
# Ignore python cache files
|
| 6 |
__pycache__/
|
| 7 |
|
|
|
|
|
|
|
|
|
|
| 8 |
# Ignore secrets
|
| 9 |
.streamlit/secrets.toml
|
| 10 |
-
.streamlit/gravity-user-db.json
|
|
|
|
|
|
| 5 |
# Ignore python cache files
|
| 6 |
__pycache__/
|
| 7 |
|
| 8 |
+
# Ignore model files
|
| 9 |
+
data/*.pt
|
| 10 |
+
|
| 11 |
# Ignore secrets
|
| 12 |
.streamlit/secrets.toml
|
| 13 |
+
.streamlit/gravity-user-db.json
|
| 14 |
+
test.ipynb
|
README.md
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
emoji: 💻
|
| 4 |
colorFrom: red
|
| 5 |
colorTo: purple
|
|
|
|
| 1 |
---
|
| 2 |
+
title: GRAVITY
|
| 3 |
emoji: 💻
|
| 4 |
colorFrom: red
|
| 5 |
colorTo: purple
|
data/kg_edge_types.csv
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
x_type,relation,display_relation,y_type,direction,N
|
| 2 |
+
anatomy,anatomy_protein_present,expression present,gene/protein,forward,3831782
|
| 3 |
+
gene/protein,rev_anatomy_protein_present,expression present,anatomy,reverse,3831782
|
| 4 |
+
drug,drug_drug,synergistic interaction,drug,forward,1433261
|
| 5 |
+
drug,rev_drug_drug,synergistic interaction,drug,reverse,1433261
|
| 6 |
+
anatomy,anatomy_protein_absent,expression absent,gene/protein,forward,324186
|
| 7 |
+
gene/protein,rev_anatomy_protein_absent,expression absent,anatomy,reverse,324186
|
| 8 |
+
gene/protein,protein_protein,ppi,gene/protein,forward,321090
|
| 9 |
+
gene/protein,rev_protein_protein,ppi,gene/protein,reverse,321090
|
| 10 |
+
disease,disease_phenotype_positive,phenotype present,effect/phenotype,forward,200354
|
| 11 |
+
effect/phenotype,rev_disease_phenotype_positive,phenotype present,disease,reverse,200354
|
| 12 |
+
disease,disease_protein,associated with,gene/protein,forward,147984
|
| 13 |
+
gene/protein,rev_disease_protein,associated with,disease,reverse,147984
|
| 14 |
+
biological_process,bioprocess_protein,interacts with,gene/protein,forward,138297
|
| 15 |
+
gene/protein,rev_bioprocess_protein,interacts with,biological_process,reverse,138297
|
| 16 |
+
cellular_component,cellcomp_protein,interacts with,gene/protein,forward,83089
|
| 17 |
+
gene/protein,rev_cellcomp_protein,interacts with,cellular_component,reverse,83089
|
| 18 |
+
disease,disease_protein_negative,expression downregulated,gene/protein,forward,71135
|
| 19 |
+
gene/protein,rev_disease_protein_negative,expression downregulated,disease,reverse,71135
|
| 20 |
+
gene/protein,molfunc_protein,interacts with,molecular_function,forward,70291
|
| 21 |
+
molecular_function,rev_molfunc_protein,interacts with,gene/protein,reverse,70291
|
| 22 |
+
disease,disease_protein_positive,expression upregulated,gene/protein,forward,69488
|
| 23 |
+
gene/protein,rev_disease_protein_positive,expression upregulated,disease,reverse,69488
|
| 24 |
+
drug,drug_effect,side effect,effect/phenotype,forward,64249
|
| 25 |
+
effect/phenotype,rev_drug_effect,side effect,drug,reverse,64249
|
| 26 |
+
biological_process,bioprocess_bioprocess,parent-child,biological_process,forward,50232
|
| 27 |
+
biological_process,rev_bioprocess_bioprocess,parent-child,biological_process,reverse,50232
|
| 28 |
+
gene/protein,pathway_protein,interacts with,pathway,forward,44116
|
| 29 |
+
pathway,rev_pathway_protein,interacts with,gene/protein,reverse,44116
|
| 30 |
+
disease,disease_disease,parent-child,disease,forward,37808
|
| 31 |
+
disease,rev_disease_disease,parent-child,disease,reverse,37808
|
| 32 |
+
disease,contraindication,contraindication,drug,forward,26899
|
| 33 |
+
drug,rev_contraindication,contraindication,disease,reverse,26899
|
| 34 |
+
effect/phenotype,phenotype_phenotype,parent-child,effect/phenotype,forward,20183
|
| 35 |
+
effect/phenotype,rev_phenotype_phenotype,parent-child,effect/phenotype,reverse,20183
|
| 36 |
+
drug,drug_protein,target,gene/protein,forward,18513
|
| 37 |
+
gene/protein,rev_drug_protein,target,drug,reverse,18513
|
| 38 |
+
disease,weak_clinical_evidence,clinical candidate,drug,forward,16111
|
| 39 |
+
drug,rev_weak_clinical_evidence,clinical candidate,disease,reverse,16111
|
| 40 |
+
anatomy,anatomy_anatomy,parent-child,anatomy,forward,14383
|
| 41 |
+
anatomy,rev_anatomy_anatomy,parent-child,anatomy,reverse,14383
|
| 42 |
+
molecular_function,molfunc_molfunc,parent-child,molecular_function,forward,13735
|
| 43 |
+
molecular_function,rev_molfunc_molfunc,parent-child,molecular_function,reverse,13735
|
| 44 |
+
disease,indication,indication,drug,forward,12608
|
| 45 |
+
drug,rev_indication,indication,disease,reverse,12608
|
| 46 |
+
drug,drug_protein,enzyme,gene/protein,forward,5919
|
| 47 |
+
gene/protein,rev_drug_protein,enzyme,drug,reverse,5919
|
| 48 |
+
disease,strong_clinical_evidence,clinical candidate,drug,forward,5352
|
| 49 |
+
drug,rev_strong_clinical_evidence,clinical candidate,disease,reverse,5352
|
| 50 |
+
cellular_component,cellcomp_cellcomp,parent-child,cellular_component,forward,4683
|
| 51 |
+
cellular_component,rev_cellcomp_cellcomp,parent-child,cellular_component,reverse,4683
|
| 52 |
+
effect/phenotype,phenotype_protein,associated with,gene/protein,forward,4437
|
| 53 |
+
gene/protein,rev_phenotype_protein,associated with,effect/phenotype,reverse,4437
|
| 54 |
+
drug,drug_protein,transporter,gene/protein,forward,3349
|
| 55 |
+
gene/protein,rev_drug_protein,transporter,drug,reverse,3349
|
| 56 |
+
pathway,pathway_pathway,parent-child,pathway,forward,2647
|
| 57 |
+
pathway,rev_pathway_pathway,parent-child,pathway,reverse,2647
|
| 58 |
+
disease,exposure_disease,linked to,exposure,forward,2421
|
| 59 |
+
exposure,rev_exposure_disease,linked to,disease,reverse,2421
|
| 60 |
+
disease,off_label_use,off-label use,drug,forward,2370
|
| 61 |
+
drug,rev_off_label_use,off-label use,disease,reverse,2370
|
| 62 |
+
exposure,exposure_exposure,parent-child,exposure,forward,2263
|
| 63 |
+
exposure,rev_exposure_exposure,parent-child,exposure,reverse,2263
|
| 64 |
+
exposure,exposure_protein,interacts with,gene/protein,forward,2012
|
| 65 |
+
gene/protein,rev_exposure_protein,interacts with,exposure,reverse,2012
|
| 66 |
+
biological_process,exposure_bioprocess,interacts with,exposure,forward,1990
|
| 67 |
+
exposure,rev_exposure_bioprocess,interacts with,biological_process,reverse,1990
|
| 68 |
+
drug,drug_protein,carrier,gene/protein,forward,993
|
| 69 |
+
gene/protein,rev_drug_protein,carrier,drug,reverse,993
|
| 70 |
+
disease,disease_phenotype_negative,phenotype absent,effect/phenotype,forward,508
|
| 71 |
+
effect/phenotype,rev_disease_phenotype_negative,phenotype absent,disease,reverse,508
|
| 72 |
+
exposure,exposure_molfunc,interacts with,molecular_function,forward,45
|
| 73 |
+
molecular_function,rev_exposure_molfunc,interacts with,exposure,reverse,45
|
| 74 |
+
cellular_component,exposure_cellcomp,interacts with,exposure,forward,12
|
| 75 |
+
exposure,rev_exposure_cellcomp,interacts with,cellular_component,reverse,12
|
data/kg_node_types.csv
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
node_type,N
|
| 2 |
+
gene/protein,35198
|
| 3 |
+
biological_process,27668
|
| 4 |
+
disease,22201
|
| 5 |
+
effect/phenotype,16711
|
| 6 |
+
anatomy,14384
|
| 7 |
+
molecular_function,11228
|
| 8 |
+
drug,8160
|
| 9 |
+
cellular_component,4054
|
| 10 |
+
pathway,2629
|
| 11 |
+
exposure,860
|
pages/about.py
CHANGED
|
@@ -14,4 +14,7 @@ menu_with_redirect()
|
|
| 14 |
st.image(str(project_config.MEDIA_DIR / 'about_header.svg'), use_column_width=True)
|
| 15 |
|
| 16 |
# Main content
|
| 17 |
-
st.markdown(f"Hello, {st.session_state.name}! Welcome to GRAVITY, a
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
st.image(str(project_config.MEDIA_DIR / 'about_header.svg'), use_column_width=True)
|
| 15 |
|
| 16 |
# Main content
|
| 17 |
+
st.markdown(f"Hello, {st.session_state.name}! Welcome to GRAVITY, a **GR**aph **A**I **VI**sualization **T**ool to query and visualize knowledge graph-grounded biomedical AI models.")
|
| 18 |
+
|
| 19 |
+
# Subheader
|
| 20 |
+
st.subheader("About GRAVITY", divider = "grey")
|
pages/explore.py
CHANGED
|
@@ -14,4 +14,7 @@ menu_with_redirect()
|
|
| 14 |
st.image(str(project_config.MEDIA_DIR / 'explore_header.svg'), use_column_width=True)
|
| 15 |
|
| 16 |
# Main content
|
| 17 |
-
st.markdown(f"Hello, {st.session_state.name}!")
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
st.image(str(project_config.MEDIA_DIR / 'explore_header.svg'), use_column_width=True)
|
| 15 |
|
| 16 |
# Main content
|
| 17 |
+
# st.markdown(f"Hello, {st.session_state.name}!")
|
| 18 |
+
|
| 19 |
+
# Coming soon
|
| 20 |
+
st.write("Coming soon...")
|
pages/input.py
CHANGED
|
@@ -1,6 +1,10 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
from menu import menu_with_redirect
|
| 3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
# Path manipulation
|
| 5 |
from pathlib import Path
|
| 6 |
|
|
@@ -14,4 +18,53 @@ menu_with_redirect()
|
|
| 14 |
st.image(str(project_config.MEDIA_DIR / 'input_header.svg'), use_column_width=True)
|
| 15 |
|
| 16 |
# Main content
|
| 17 |
-
st.markdown(f"Hello, {st.session_state.name}!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
from menu import menu_with_redirect
|
| 3 |
|
| 4 |
+
# Standard imports
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
|
| 8 |
# Path manipulation
|
| 9 |
from pathlib import Path
|
| 10 |
|
|
|
|
| 18 |
st.image(str(project_config.MEDIA_DIR / 'input_header.svg'), use_column_width=True)
|
| 19 |
|
| 20 |
# Main content
|
| 21 |
+
# st.markdown(f"Hello, {st.session_state.name}!")
|
| 22 |
+
|
| 23 |
+
st.subheader("Construct Query", divider = "red")
|
| 24 |
+
|
| 25 |
+
# Checkbox to allow reverse edges
|
| 26 |
+
allow_reverse_edges = st.checkbox("Reverse Edges", value = False)
|
| 27 |
+
|
| 28 |
+
with st.spinner('Loading knowledge graph...'):
|
| 29 |
+
kg_nodes = nodes = pd.read_csv(project_config.DATA_DIR / 'kg_nodes.csv', dtype = {'node_index': int}, low_memory = False)
|
| 30 |
+
node_types = pd.read_csv(project_config.DATA_DIR / 'kg_node_types.csv')
|
| 31 |
+
edge_types = pd.read_csv(project_config.DATA_DIR / 'kg_edge_types.csv')
|
| 32 |
+
|
| 33 |
+
if not allow_reverse_edges:
|
| 34 |
+
edge_types = edge_types[edge_types.direction == 'forward']
|
| 35 |
+
|
| 36 |
+
# Select source node type
|
| 37 |
+
source_node_type = st.selectbox("Source Node Type", node_types['node_type'],
|
| 38 |
+
format_func = lambda x: x.replace("_", " "))
|
| 39 |
+
|
| 40 |
+
# Select source node
|
| 41 |
+
source_node = st.selectbox("Source Node", kg_nodes[kg_nodes['node_type'] == source_node_type]['node_name'])
|
| 42 |
+
|
| 43 |
+
# Select target node type
|
| 44 |
+
target_node_type = st.selectbox("Target Node Type", edge_types[edge_types.x_type == source_node_type].y_type.unique(),
|
| 45 |
+
format_func = lambda x: x.replace("_", " "))
|
| 46 |
+
|
| 47 |
+
# Select relation
|
| 48 |
+
relation = st.selectbox("Edge Type", edge_types[(edge_types.x_type == source_node_type) & (edge_types.y_type == target_node_type)].relation.unique(),
|
| 49 |
+
format_func = lambda x: x.replace("_", "-"))
|
| 50 |
+
|
| 51 |
+
# Button to submit query
|
| 52 |
+
if st.button("Submit Query"):
|
| 53 |
+
|
| 54 |
+
# Save query to session state
|
| 55 |
+
st.session_state.query = {
|
| 56 |
+
"source_node_type": source_node_type,
|
| 57 |
+
"source_node": source_node,
|
| 58 |
+
"target_node_type": target_node_type,
|
| 59 |
+
"relation": relation
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
# # Write query to console
|
| 63 |
+
# st.write("Current Query:")
|
| 64 |
+
# st.write(st.session_state.query)
|
| 65 |
+
st.write("Query submitted.")
|
| 66 |
+
|
| 67 |
+
st.subheader("Knowledge Graph", divider = "red")
|
| 68 |
+
display_data = kg_nodes[['node_id', 'node_type', 'node_name', 'node_source']].copy()
|
| 69 |
+
display_data = display_data.rename(columns = {'node_id': 'ID', 'node_type': 'Type', 'node_name': 'Name', 'node_source': 'Database'})
|
| 70 |
+
st.dataframe(display_data, use_container_width = True)
|
pages/validate.py
CHANGED
|
@@ -1,8 +1,16 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
from menu import menu_with_redirect
|
| 3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
# Path manipulation
|
| 5 |
from pathlib import Path
|
|
|
|
| 6 |
|
| 7 |
# Custom and other imports
|
| 8 |
import project_config
|
|
@@ -14,4 +22,93 @@ menu_with_redirect()
|
|
| 14 |
st.image(str(project_config.MEDIA_DIR / 'validate_header.svg'), use_column_width=True)
|
| 15 |
|
| 16 |
# Main content
|
| 17 |
-
st.markdown(f"Hello, {st.session_state.name}!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
from menu import menu_with_redirect
|
| 3 |
|
| 4 |
+
# Standard imports
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
# Path manipulation
|
| 12 |
from pathlib import Path
|
| 13 |
+
from huggingface_hub import hf_hub_download
|
| 14 |
|
| 15 |
# Custom and other imports
|
| 16 |
import project_config
|
|
|
|
| 22 |
st.image(str(project_config.MEDIA_DIR / 'validate_header.svg'), use_column_width=True)
|
| 23 |
|
| 24 |
# Main content
|
| 25 |
+
# st.markdown(f"Hello, {st.session_state.name}!")
|
| 26 |
+
|
| 27 |
+
st.subheader("Model Predictions", divider = "green")
|
| 28 |
+
|
| 29 |
+
# Print current query
|
| 30 |
+
st.markdown(f"**Query:** {st.session_state.query['source_node']} ➡️ {st.session_state.query['relation']} ➡️ {st.session_state.query['target_node_type']}")
|
| 31 |
+
|
| 32 |
+
with st.spinner('Loading knowledge graph...'):
|
| 33 |
+
kg_nodes = nodes = pd.read_csv(project_config.DATA_DIR / 'kg_nodes.csv', dtype = {'node_index': int}, low_memory = False)
|
| 34 |
+
|
| 35 |
+
# Get paths to embeddings, relation weights, and edge types
|
| 36 |
+
with st.spinner('Downloading AI model...'):
|
| 37 |
+
embed_path = hf_hub_download(repo_id="ayushnoori/galaxy",
|
| 38 |
+
filename="2024_03_29_04_12_52_epoch=3-step=54291_embeddings.pt",
|
| 39 |
+
token=st.secrets["HF_TOKEN"])
|
| 40 |
+
relation_weights_path = hf_hub_download(repo_id="ayushnoori/galaxy",
|
| 41 |
+
filename="2024_03_29_04_12_52_epoch=3-step=54291_relation_weights.pt",
|
| 42 |
+
token=st.secrets["HF_TOKEN"])
|
| 43 |
+
edge_types_path = hf_hub_download(repo_id="ayushnoori/galaxy",
|
| 44 |
+
filename="2024_03_29_04_12_52_epoch=3-step=54291_edge_types.pt",
|
| 45 |
+
token=st.secrets["HF_TOKEN"])
|
| 46 |
+
|
| 47 |
+
# Load embeddings, relation weights, and edge types
|
| 48 |
+
with st.spinner('Loading AI model...'):
|
| 49 |
+
embeddings = torch.load(embed_path)
|
| 50 |
+
relation_weights = torch.load(relation_weights_path)
|
| 51 |
+
edge_types = torch.load(edge_types_path)
|
| 52 |
+
|
| 53 |
+
# # Print source node type
|
| 54 |
+
# st.write(f"Source Node Type: {st.session_state.query['source_node_type']}")
|
| 55 |
+
|
| 56 |
+
# # Print source node
|
| 57 |
+
# st.write(f"Source Node: {st.session_state.query['source_node']}")
|
| 58 |
+
|
| 59 |
+
# # Print relation
|
| 60 |
+
# st.write(f"Edge Type: {st.session_state.query['relation']}")
|
| 61 |
+
|
| 62 |
+
# # Print target node type
|
| 63 |
+
# st.write(f"Target Node Type: {st.session_state.query['target_node_type']}")
|
| 64 |
+
|
| 65 |
+
# Compute predictions
|
| 66 |
+
with st.spinner('Computing predictions...'):
|
| 67 |
+
|
| 68 |
+
source_node_type = st.session_state.query['source_node_type']
|
| 69 |
+
source_node = st.session_state.query['source_node']
|
| 70 |
+
relation = st.session_state.query['relation']
|
| 71 |
+
target_node_type = st.session_state.query['target_node_type']
|
| 72 |
+
|
| 73 |
+
# Get source node index
|
| 74 |
+
src_index = kg_nodes[(kg_nodes.node_type == source_node_type) & (kg_nodes.node_name == source_node)].node_index.values[0]
|
| 75 |
+
|
| 76 |
+
# Get relation index
|
| 77 |
+
edge_type_index = [i for i, etype in enumerate(edge_types) if etype == (source_node_type, relation, target_node_type)][0]
|
| 78 |
+
|
| 79 |
+
# Get target nodes indices
|
| 80 |
+
target_nodes = kg_nodes[kg_nodes.node_type == target_node_type]
|
| 81 |
+
dst_indices = target_nodes.node_index.values
|
| 82 |
+
src_indices = np.repeat(src_index, len(dst_indices))
|
| 83 |
+
|
| 84 |
+
# Retrieve cached embeddings
|
| 85 |
+
src_embeddings = embeddings[src_indices]
|
| 86 |
+
dst_embeddings = embeddings[dst_indices]
|
| 87 |
+
|
| 88 |
+
# Apply activation function
|
| 89 |
+
src_embeddings = F.leaky_relu(src_embeddings)
|
| 90 |
+
dst_embeddings = F.leaky_relu(dst_embeddings)
|
| 91 |
+
|
| 92 |
+
# Get relation weights
|
| 93 |
+
rel_weights = relation_weights[edge_type_index]
|
| 94 |
+
|
| 95 |
+
# Compute weighted dot product
|
| 96 |
+
scores = torch.sum(src_embeddings * rel_weights * dst_embeddings, dim = 1)
|
| 97 |
+
scores = torch.sigmoid(scores)
|
| 98 |
+
|
| 99 |
+
# Add scores to dataframe
|
| 100 |
+
target_nodes['score'] = scores.detach().numpy()
|
| 101 |
+
|
| 102 |
+
# Rank target nodes by score
|
| 103 |
+
target_nodes = target_nodes.sort_values(by = 'score', ascending = False)
|
| 104 |
+
|
| 105 |
+
# Add rank to dataframe
|
| 106 |
+
target_nodes['rank'] = np.arange(1, target_nodes.shape[0] + 1)
|
| 107 |
+
|
| 108 |
+
# Show top ranked nodes
|
| 109 |
+
top_k = st.slider('Select number of top ranked nodes to show.', 1, target_nodes.shape[0], 50)
|
| 110 |
+
|
| 111 |
+
# Rename columns
|
| 112 |
+
display_data = target_nodes[['rank', 'node_id', 'node_name', 'node_source', 'score']].iloc[:top_k].copy()
|
| 113 |
+
display_data = display_data.rename(columns = {'rank': 'Rank', 'node_id': 'ID', 'node_name': 'Name', 'node_source': 'Database', 'score': 'Score'})
|
| 114 |
+
st.dataframe(display_data, use_container_width = True)
|
requirements.txt
CHANGED
|
@@ -7,4 +7,5 @@ pathlib
|
|
| 7 |
torch
|
| 8 |
altair<5
|
| 9 |
gspread
|
| 10 |
-
oauth2client
|
|
|
|
|
|
| 7 |
torch
|
| 8 |
altair<5
|
| 9 |
gspread
|
| 10 |
+
oauth2client
|
| 11 |
+
huggingface_hub
|