Spaces:
Running
Running
Erva Ulusoy
committed on
Commit
·
23f3d8f
1
Parent(s):
cdefd25
fix to avoid overwriting graph edges during inference + add prediction generation threshold
Browse files- ProtHGT_app.py +12 -1
- run_prothgt_app.py +50 -35
ProtHGT_app.py
CHANGED
|
@@ -270,6 +270,16 @@ with st.sidebar:
|
|
| 270 |
disabled=disabled
|
| 271 |
)
|
| 272 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
if selected_proteins and selected_go_category:
|
| 274 |
|
| 275 |
button_disabled = st.session_state.submitted
|
|
@@ -355,7 +365,8 @@ if st.session_state.submitted:
|
|
| 355 |
protein_ids=selected_proteins,
|
| 356 |
model_paths=model_paths,
|
| 357 |
model_config_paths=model_config_paths,
|
| 358 |
-
go_category=go_categories
|
|
|
|
| 359 |
)
|
| 360 |
|
| 361 |
st.session_state.heterodata = heterodata
|
|
|
|
| 270 |
disabled=disabled
|
| 271 |
)
|
| 272 |
|
| 273 |
+
generation_threshold = st.number_input(
|
| 274 |
+
"Generation threshold (optional)",
|
| 275 |
+
min_value=0.0,
|
| 276 |
+
max_value=1.0,
|
| 277 |
+
value=0.0,
|
| 278 |
+
step=0.05,
|
| 279 |
+
help="If > 0, only predictions with Probability >= threshold are generated. This reduces output size and speeds up the app.",
|
| 280 |
+
disabled=disabled,
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
if selected_proteins and selected_go_category:
|
| 284 |
|
| 285 |
button_disabled = st.session_state.submitted
|
|
|
|
| 365 |
protein_ids=selected_proteins,
|
| 366 |
model_paths=model_paths,
|
| 367 |
model_config_paths=model_config_paths,
|
| 368 |
+
go_category=go_categories,
|
| 369 |
+
threshold=generation_threshold,
|
| 370 |
)
|
| 371 |
|
| 372 |
st.session_state.heterodata = heterodata
|
run_prothgt_app.py
CHANGED
|
@@ -48,44 +48,44 @@ class ProtHGT(torch.nn.Module):
|
|
| 48 |
|
| 49 |
return self.mlp(z).view(-1), x_dict
|
| 50 |
|
| 51 |
-
def
|
| 52 |
-
"""
|
| 53 |
-
|
| 54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
n_terms = len(heterodata[go_category]['id_mapping'])
|
|
|
|
| 56 |
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
all_edges.append([protein_idx, term_idx])
|
| 61 |
-
|
| 62 |
-
edge_index = torch.tensor(all_edges).t()
|
| 63 |
-
|
| 64 |
-
heterodata[('Protein', 'protein_function', go_category)].edge_index = edge_index
|
| 65 |
-
heterodata[(go_category, 'rev_protein_function', 'Protein')].edge_index = torch.stack([edge_index[1], edge_index[0]])
|
| 66 |
-
|
| 67 |
-
return heterodata
|
| 68 |
|
| 69 |
def get_available_proteins(name_file='data/name_info.json.gz'):
|
| 70 |
with gzip.open(name_file, 'rt', encoding='utf-8') as file:
|
| 71 |
name_info = json.load(file)
|
| 72 |
return list(name_info['Protein'].keys())
|
| 73 |
|
| 74 |
-
def _generate_predictions(heterodata, model, target_type):
|
| 75 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 76 |
|
| 77 |
model.to(device)
|
| 78 |
model.eval()
|
| 79 |
heterodata = heterodata.to(device)
|
|
|
|
| 80 |
|
| 81 |
with torch.no_grad():
|
| 82 |
-
edge_label_index = heterodata.edge_index_dict[('Protein', 'protein_function', target_type)]
|
| 83 |
predictions, _ = model(heterodata.x_dict, heterodata.edge_index_dict, edge_label_index, target_type)
|
| 84 |
predictions = torch.sigmoid(predictions)
|
| 85 |
|
| 86 |
return predictions.cpu()
|
| 87 |
|
| 88 |
-
def _create_prediction_df(predictions, heterodata, protein_ids, go_category):
|
| 89 |
go_category_dict = {
|
| 90 |
'GO_term_F': 'Molecular Function',
|
| 91 |
'GO_term_P': 'Biological Process',
|
|
@@ -96,8 +96,8 @@ def _create_prediction_df(predictions, heterodata, protein_ids, go_category):
|
|
| 96 |
with gzip.open('data/name_info.json.gz', 'rt', encoding='utf-8') as file:
|
| 97 |
name_info = json.load(file)
|
| 98 |
|
| 99 |
-
|
| 100 |
-
n_go_terms = len(
|
| 101 |
|
| 102 |
# Create lists to store the data
|
| 103 |
all_proteins = []
|
|
@@ -107,8 +107,10 @@ def _create_prediction_df(predictions, heterodata, protein_ids, go_category):
|
|
| 107 |
all_categories = []
|
| 108 |
all_probabilities = []
|
| 109 |
|
| 110 |
-
#
|
| 111 |
-
go_terms =
|
|
|
|
|
|
|
| 112 |
|
| 113 |
# Process predictions for each protein
|
| 114 |
for i, protein_id in enumerate(protein_ids):
|
|
@@ -116,16 +118,29 @@ def _create_prediction_df(predictions, heterodata, protein_ids, go_category):
|
|
| 116 |
start_idx = i * n_go_terms
|
| 117 |
end_idx = (i + 1) * n_go_terms
|
| 118 |
protein_predictions = predictions[start_idx:end_idx]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
# Get protein name
|
| 121 |
protein_name = name_info['Protein'].get(protein_id, protein_id)
|
| 122 |
|
| 123 |
# Extend the lists
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
|
|
|
|
|
|
| 129 |
all_probabilities.extend(protein_predictions.tolist())
|
| 130 |
|
| 131 |
# Create DataFrame
|
|
@@ -140,7 +155,7 @@ def _create_prediction_df(predictions, heterodata, protein_ids, go_category):
|
|
| 140 |
|
| 141 |
return prediction_df
|
| 142 |
|
| 143 |
-
def generate_prediction_df(protein_ids, model_paths, model_config_paths, go_category):
|
| 144 |
all_predictions = []
|
| 145 |
|
| 146 |
# Convert single protein ID to list if necessary
|
|
@@ -171,9 +186,9 @@ def generate_prediction_df(protein_ids, model_paths, model_config_paths, go_cate
|
|
| 171 |
|
| 172 |
for go_cat, model_config_path, model_path in zip(go_category, model_config_paths, model_paths):
|
| 173 |
print(f'Generating predictions for {go_cat}...')
|
| 174 |
-
|
| 175 |
-
#
|
| 176 |
-
|
| 177 |
|
| 178 |
# Load model config
|
| 179 |
with open(model_config_path, 'r') as file:
|
|
@@ -181,7 +196,7 @@ def generate_prediction_df(protein_ids, model_paths, model_config_paths, go_cate
|
|
| 181 |
|
| 182 |
# Initialize model with configuration
|
| 183 |
model = ProtHGT(
|
| 184 |
-
|
| 185 |
hidden_channels=model_config['hidden_channels'][0],
|
| 186 |
num_heads=model_config['num_heads'],
|
| 187 |
num_layers=model_config['num_layers'],
|
|
@@ -194,12 +209,12 @@ def generate_prediction_df(protein_ids, model_paths, model_config_paths, go_cate
|
|
| 194 |
print(f'Loaded model weights from {model_path}')
|
| 195 |
|
| 196 |
# Generate predictions
|
| 197 |
-
predictions = _generate_predictions(
|
| 198 |
-
prediction_df = _create_prediction_df(predictions,
|
| 199 |
all_predictions.append(prediction_df)
|
| 200 |
|
| 201 |
# Clean up memory
|
| 202 |
-
del
|
| 203 |
del model
|
| 204 |
del predictions
|
| 205 |
torch.cuda.empty_cache() # Clear CUDA cache if using GPU
|
|
|
|
| 48 |
|
| 49 |
return self.mlp(z).view(-1), x_dict
|
| 50 |
|
| 51 |
+
def _build_edge_label_index(heterodata, protein_ids, go_category):
|
| 52 |
+
"""
|
| 53 |
+
Build a dense candidate edge_label_index (Protein × GO terms) for inference.
|
| 54 |
+
|
| 55 |
+
IMPORTANT: Do NOT overwrite heterodata.edge_index_dict here.
|
| 56 |
+
Graph edges are used for message passing; candidate edges are only for scoring.
|
| 57 |
+
"""
|
| 58 |
+
protein_indices = torch.tensor(
|
| 59 |
+
[heterodata['Protein']['id_mapping'][pid] for pid in protein_ids],
|
| 60 |
+
dtype=torch.long,
|
| 61 |
+
)
|
| 62 |
n_terms = len(heterodata[go_category]['id_mapping'])
|
| 63 |
+
term_indices = torch.arange(n_terms, dtype=torch.long)
|
| 64 |
|
| 65 |
+
row = protein_indices.repeat_interleave(n_terms)
|
| 66 |
+
col = term_indices.repeat(len(protein_indices))
|
| 67 |
+
return torch.stack([row, col], dim=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
def get_available_proteins(name_file='data/name_info.json.gz'):
|
| 70 |
with gzip.open(name_file, 'rt', encoding='utf-8') as file:
|
| 71 |
name_info = json.load(file)
|
| 72 |
return list(name_info['Protein'].keys())
|
| 73 |
|
| 74 |
+
def _generate_predictions(heterodata, model, edge_label_index, target_type):
|
| 75 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 76 |
|
| 77 |
model.to(device)
|
| 78 |
model.eval()
|
| 79 |
heterodata = heterodata.to(device)
|
| 80 |
+
edge_label_index = edge_label_index.to(device)
|
| 81 |
|
| 82 |
with torch.no_grad():
|
|
|
|
| 83 |
predictions, _ = model(heterodata.x_dict, heterodata.edge_index_dict, edge_label_index, target_type)
|
| 84 |
predictions = torch.sigmoid(predictions)
|
| 85 |
|
| 86 |
return predictions.cpu()
|
| 87 |
|
| 88 |
+
def _create_prediction_df(predictions, heterodata, protein_ids, go_category, threshold: float = 0.0):
|
| 89 |
go_category_dict = {
|
| 90 |
'GO_term_F': 'Molecular Function',
|
| 91 |
'GO_term_P': 'Biological Process',
|
|
|
|
| 96 |
with gzip.open('data/name_info.json.gz', 'rt', encoding='utf-8') as file:
|
| 97 |
name_info = json.load(file)
|
| 98 |
|
| 99 |
+
id_mapping = heterodata[go_category]['id_mapping'] # dict: GO_id -> index
|
| 100 |
+
n_go_terms = len(id_mapping)
|
| 101 |
|
| 102 |
# Create lists to store the data
|
| 103 |
all_proteins = []
|
|
|
|
| 107 |
all_categories = []
|
| 108 |
all_probabilities = []
|
| 109 |
|
| 110 |
+
# Build GO terms list aligned with their numeric indices (critical for correctness)
|
| 111 |
+
go_terms = [None] * n_go_terms
|
| 112 |
+
for go_id, idx in id_mapping.items():
|
| 113 |
+
go_terms[int(idx)] = go_id
|
| 114 |
|
| 115 |
# Process predictions for each protein
|
| 116 |
for i, protein_id in enumerate(protein_ids):
|
|
|
|
| 118 |
start_idx = i * n_go_terms
|
| 119 |
end_idx = (i + 1) * n_go_terms
|
| 120 |
protein_predictions = predictions[start_idx:end_idx]
|
| 121 |
+
|
| 122 |
+
# Optional pre-filter for performance
|
| 123 |
+
if threshold and threshold > 0.0:
|
| 124 |
+
keep_mask = protein_predictions >= float(threshold)
|
| 125 |
+
if keep_mask.any():
|
| 126 |
+
keep_idx = torch.nonzero(keep_mask, as_tuple=False).view(-1)
|
| 127 |
+
protein_predictions = protein_predictions[keep_idx]
|
| 128 |
+
else:
|
| 129 |
+
continue
|
| 130 |
+
else:
|
| 131 |
+
keep_idx = torch.arange(n_go_terms)
|
| 132 |
|
| 133 |
# Get protein name
|
| 134 |
protein_name = name_info['Protein'].get(protein_id, protein_id)
|
| 135 |
|
| 136 |
# Extend the lists
|
| 137 |
+
k = int(protein_predictions.numel())
|
| 138 |
+
all_proteins.extend([protein_id] * k)
|
| 139 |
+
all_protein_names.extend([protein_name] * k)
|
| 140 |
+
kept_go_ids = [go_terms[int(j)] for j in keep_idx.tolist()]
|
| 141 |
+
all_go_terms.extend(kept_go_ids)
|
| 142 |
+
all_go_term_names.extend([name_info['GO_term'].get(term_id, term_id) for term_id in kept_go_ids])
|
| 143 |
+
all_categories.extend([go_category_dict[go_category]] * k)
|
| 144 |
all_probabilities.extend(protein_predictions.tolist())
|
| 145 |
|
| 146 |
# Create DataFrame
|
|
|
|
| 155 |
|
| 156 |
return prediction_df
|
| 157 |
|
| 158 |
+
def generate_prediction_df(protein_ids, model_paths, model_config_paths, go_category, threshold: float = 0.0):
|
| 159 |
all_predictions = []
|
| 160 |
|
| 161 |
# Convert single protein ID to list if necessary
|
|
|
|
| 186 |
|
| 187 |
for go_cat, model_config_path, model_path in zip(go_category, model_config_paths, model_paths):
|
| 188 |
print(f'Generating predictions for {go_cat}...')
|
| 189 |
+
|
| 190 |
+
# Build candidate edges for inference (do NOT modify graph edges)
|
| 191 |
+
edge_label_index = _build_edge_label_index(heterodata, protein_ids, go_cat)
|
| 192 |
|
| 193 |
# Load model config
|
| 194 |
with open(model_config_path, 'r') as file:
|
|
|
|
| 196 |
|
| 197 |
# Initialize model with configuration
|
| 198 |
model = ProtHGT(
|
| 199 |
+
heterodata,
|
| 200 |
hidden_channels=model_config['hidden_channels'][0],
|
| 201 |
num_heads=model_config['num_heads'],
|
| 202 |
num_layers=model_config['num_layers'],
|
|
|
|
| 209 |
print(f'Loaded model weights from {model_path}')
|
| 210 |
|
| 211 |
# Generate predictions
|
| 212 |
+
predictions = _generate_predictions(heterodata, model, edge_label_index, go_cat)
|
| 213 |
+
prediction_df = _create_prediction_df(predictions, heterodata, protein_ids, go_cat, threshold=threshold)
|
| 214 |
all_predictions.append(prediction_df)
|
| 215 |
|
| 216 |
# Clean up memory
|
| 217 |
+
del edge_label_index
|
| 218 |
del model
|
| 219 |
del predictions
|
| 220 |
torch.cuda.empty_cache() # Clear CUDA cache if using GPU
|