Erva Ulusoy committed on
Commit
23f3d8f
·
1 Parent(s): cdefd25

fix to avoid overwriting graph edges during inference + add prediction generation threshold

Browse files
Files changed (2) hide show
  1. ProtHGT_app.py +12 -1
  2. run_prothgt_app.py +50 -35
ProtHGT_app.py CHANGED
@@ -270,6 +270,16 @@ with st.sidebar:
270
  disabled=disabled
271
  )
272
 
 
 
 
 
 
 
 
 
 
 
273
  if selected_proteins and selected_go_category:
274
 
275
  button_disabled = st.session_state.submitted
@@ -355,7 +365,8 @@ if st.session_state.submitted:
355
  protein_ids=selected_proteins,
356
  model_paths=model_paths,
357
  model_config_paths=model_config_paths,
358
- go_category=go_categories
 
359
  )
360
 
361
  st.session_state.heterodata = heterodata
 
270
  disabled=disabled
271
  )
272
 
273
+ generation_threshold = st.number_input(
274
+ "Generation threshold (optional)",
275
+ min_value=0.0,
276
+ max_value=1.0,
277
+ value=0.0,
278
+ step=0.05,
279
+ help="If > 0, only predictions with Probability >= threshold are generated. This reduces output size and speeds up the app.",
280
+ disabled=disabled,
281
+ )
282
+
283
  if selected_proteins and selected_go_category:
284
 
285
  button_disabled = st.session_state.submitted
 
365
  protein_ids=selected_proteins,
366
  model_paths=model_paths,
367
  model_config_paths=model_config_paths,
368
+ go_category=go_categories,
369
+ threshold=generation_threshold,
370
  )
371
 
372
  st.session_state.heterodata = heterodata
run_prothgt_app.py CHANGED
@@ -48,44 +48,44 @@ class ProtHGT(torch.nn.Module):
48
 
49
  return self.mlp(z).view(-1), x_dict
50
 
51
- def _load_data(heterodata, protein_ids, go_category):
52
- """Process the loaded heterodata for specific proteins and GO categories."""
53
- # Get protein indices for all input proteins
54
- protein_indices = [heterodata['Protein']['id_mapping'][pid] for pid in protein_ids]
 
 
 
 
 
 
 
55
  n_terms = len(heterodata[go_category]['id_mapping'])
 
56
 
57
- all_edges = []
58
- for protein_idx in protein_indices:
59
- for term_idx in range(n_terms):
60
- all_edges.append([protein_idx, term_idx])
61
-
62
- edge_index = torch.tensor(all_edges).t()
63
-
64
- heterodata[('Protein', 'protein_function', go_category)].edge_index = edge_index
65
- heterodata[(go_category, 'rev_protein_function', 'Protein')].edge_index = torch.stack([edge_index[1], edge_index[0]])
66
-
67
- return heterodata
68
 
69
  def get_available_proteins(name_file='data/name_info.json.gz'):
70
  with gzip.open(name_file, 'rt', encoding='utf-8') as file:
71
  name_info = json.load(file)
72
  return list(name_info['Protein'].keys())
73
 
74
- def _generate_predictions(heterodata, model, target_type):
75
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
76
 
77
  model.to(device)
78
  model.eval()
79
  heterodata = heterodata.to(device)
 
80
 
81
  with torch.no_grad():
82
- edge_label_index = heterodata.edge_index_dict[('Protein', 'protein_function', target_type)]
83
  predictions, _ = model(heterodata.x_dict, heterodata.edge_index_dict, edge_label_index, target_type)
84
  predictions = torch.sigmoid(predictions)
85
 
86
  return predictions.cpu()
87
 
88
- def _create_prediction_df(predictions, heterodata, protein_ids, go_category):
89
  go_category_dict = {
90
  'GO_term_F': 'Molecular Function',
91
  'GO_term_P': 'Biological Process',
@@ -96,8 +96,8 @@ def _create_prediction_df(predictions, heterodata, protein_ids, go_category):
96
  with gzip.open('data/name_info.json.gz', 'rt', encoding='utf-8') as file:
97
  name_info = json.load(file)
98
 
99
- # Get number of GO terms for this category
100
- n_go_terms = len(heterodata[go_category]['id_mapping'])
101
 
102
  # Create lists to store the data
103
  all_proteins = []
@@ -107,8 +107,10 @@ def _create_prediction_df(predictions, heterodata, protein_ids, go_category):
107
  all_categories = []
108
  all_probabilities = []
109
 
110
- # Get list of GO terms once
111
- go_terms = list(heterodata[go_category]['id_mapping'].keys())
 
 
112
 
113
  # Process predictions for each protein
114
  for i, protein_id in enumerate(protein_ids):
@@ -116,16 +118,29 @@ def _create_prediction_df(predictions, heterodata, protein_ids, go_category):
116
  start_idx = i * n_go_terms
117
  end_idx = (i + 1) * n_go_terms
118
  protein_predictions = predictions[start_idx:end_idx]
 
 
 
 
 
 
 
 
 
 
 
119
 
120
  # Get protein name
121
  protein_name = name_info['Protein'].get(protein_id, protein_id)
122
 
123
  # Extend the lists
124
- all_proteins.extend([protein_id] * n_go_terms)
125
- all_protein_names.extend([protein_name] * n_go_terms)
126
- all_go_terms.extend(go_terms)
127
- all_go_term_names.extend([name_info['GO_term'].get(term_id, term_id) for term_id in go_terms])
128
- all_categories.extend([go_category_dict[go_category]] * n_go_terms)
 
 
129
  all_probabilities.extend(protein_predictions.tolist())
130
 
131
  # Create DataFrame
@@ -140,7 +155,7 @@ def _create_prediction_df(predictions, heterodata, protein_ids, go_category):
140
 
141
  return prediction_df
142
 
143
- def generate_prediction_df(protein_ids, model_paths, model_config_paths, go_category):
144
  all_predictions = []
145
 
146
  # Convert single protein ID to list if necessary
@@ -171,9 +186,9 @@ def generate_prediction_df(protein_ids, model_paths, model_config_paths, go_cate
171
 
172
  for go_cat, model_config_path, model_path in zip(go_category, model_config_paths, model_paths):
173
  print(f'Generating predictions for {go_cat}...')
174
-
175
- # Process data for current GO category
176
- processed_data = _load_data(copy.deepcopy(heterodata), protein_ids, go_cat)
177
 
178
  # Load model config
179
  with open(model_config_path, 'r') as file:
@@ -181,7 +196,7 @@ def generate_prediction_df(protein_ids, model_paths, model_config_paths, go_cate
181
 
182
  # Initialize model with configuration
183
  model = ProtHGT(
184
- processed_data,
185
  hidden_channels=model_config['hidden_channels'][0],
186
  num_heads=model_config['num_heads'],
187
  num_layers=model_config['num_layers'],
@@ -194,12 +209,12 @@ def generate_prediction_df(protein_ids, model_paths, model_config_paths, go_cate
194
  print(f'Loaded model weights from {model_path}')
195
 
196
  # Generate predictions
197
- predictions = _generate_predictions(processed_data, model, go_cat)
198
- prediction_df = _create_prediction_df(predictions, processed_data, protein_ids, go_cat)
199
  all_predictions.append(prediction_df)
200
 
201
  # Clean up memory
202
- del processed_data
203
  del model
204
  del predictions
205
  torch.cuda.empty_cache() # Clear CUDA cache if using GPU
 
48
 
49
  return self.mlp(z).view(-1), x_dict
50
 
51
+ def _build_edge_label_index(heterodata, protein_ids, go_category):
52
+ """
53
+ Build a dense candidate edge_label_index (Protein × GO terms) for inference.
54
+
55
+ IMPORTANT: Do NOT overwrite heterodata.edge_index_dict here.
56
+ Graph edges are used for message passing; candidate edges are only for scoring.
57
+ """
58
+ protein_indices = torch.tensor(
59
+ [heterodata['Protein']['id_mapping'][pid] for pid in protein_ids],
60
+ dtype=torch.long,
61
+ )
62
  n_terms = len(heterodata[go_category]['id_mapping'])
63
+ term_indices = torch.arange(n_terms, dtype=torch.long)
64
 
65
+ row = protein_indices.repeat_interleave(n_terms)
66
+ col = term_indices.repeat(len(protein_indices))
67
+ return torch.stack([row, col], dim=0)
 
 
 
 
 
 
 
 
68
 
69
  def get_available_proteins(name_file='data/name_info.json.gz'):
70
  with gzip.open(name_file, 'rt', encoding='utf-8') as file:
71
  name_info = json.load(file)
72
  return list(name_info['Protein'].keys())
73
 
74
+ def _generate_predictions(heterodata, model, edge_label_index, target_type):
75
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
76
 
77
  model.to(device)
78
  model.eval()
79
  heterodata = heterodata.to(device)
80
+ edge_label_index = edge_label_index.to(device)
81
 
82
  with torch.no_grad():
 
83
  predictions, _ = model(heterodata.x_dict, heterodata.edge_index_dict, edge_label_index, target_type)
84
  predictions = torch.sigmoid(predictions)
85
 
86
  return predictions.cpu()
87
 
88
+ def _create_prediction_df(predictions, heterodata, protein_ids, go_category, threshold: float = 0.0):
89
  go_category_dict = {
90
  'GO_term_F': 'Molecular Function',
91
  'GO_term_P': 'Biological Process',
 
96
  with gzip.open('data/name_info.json.gz', 'rt', encoding='utf-8') as file:
97
  name_info = json.load(file)
98
 
99
+ id_mapping = heterodata[go_category]['id_mapping'] # dict: GO_id -> index
100
+ n_go_terms = len(id_mapping)
101
 
102
  # Create lists to store the data
103
  all_proteins = []
 
107
  all_categories = []
108
  all_probabilities = []
109
 
110
+ # Build GO terms list aligned with their numeric indices (critical for correctness)
111
+ go_terms = [None] * n_go_terms
112
+ for go_id, idx in id_mapping.items():
113
+ go_terms[int(idx)] = go_id
114
 
115
  # Process predictions for each protein
116
  for i, protein_id in enumerate(protein_ids):
 
118
  start_idx = i * n_go_terms
119
  end_idx = (i + 1) * n_go_terms
120
  protein_predictions = predictions[start_idx:end_idx]
121
+
122
+ # Optional pre-filter for performance
123
+ if threshold and threshold > 0.0:
124
+ keep_mask = protein_predictions >= float(threshold)
125
+ if keep_mask.any():
126
+ keep_idx = torch.nonzero(keep_mask, as_tuple=False).view(-1)
127
+ protein_predictions = protein_predictions[keep_idx]
128
+ else:
129
+ continue
130
+ else:
131
+ keep_idx = torch.arange(n_go_terms)
132
 
133
  # Get protein name
134
  protein_name = name_info['Protein'].get(protein_id, protein_id)
135
 
136
  # Extend the lists
137
+ k = int(protein_predictions.numel())
138
+ all_proteins.extend([protein_id] * k)
139
+ all_protein_names.extend([protein_name] * k)
140
+ kept_go_ids = [go_terms[int(j)] for j in keep_idx.tolist()]
141
+ all_go_terms.extend(kept_go_ids)
142
+ all_go_term_names.extend([name_info['GO_term'].get(term_id, term_id) for term_id in kept_go_ids])
143
+ all_categories.extend([go_category_dict[go_category]] * k)
144
  all_probabilities.extend(protein_predictions.tolist())
145
 
146
  # Create DataFrame
 
155
 
156
  return prediction_df
157
 
158
+ def generate_prediction_df(protein_ids, model_paths, model_config_paths, go_category, threshold: float = 0.0):
159
  all_predictions = []
160
 
161
  # Convert single protein ID to list if necessary
 
186
 
187
  for go_cat, model_config_path, model_path in zip(go_category, model_config_paths, model_paths):
188
  print(f'Generating predictions for {go_cat}...')
189
+
190
+ # Build candidate edges for inference (do NOT modify graph edges)
191
+ edge_label_index = _build_edge_label_index(heterodata, protein_ids, go_cat)
192
 
193
  # Load model config
194
  with open(model_config_path, 'r') as file:
 
196
 
197
  # Initialize model with configuration
198
  model = ProtHGT(
199
+ heterodata,
200
  hidden_channels=model_config['hidden_channels'][0],
201
  num_heads=model_config['num_heads'],
202
  num_layers=model_config['num_layers'],
 
209
  print(f'Loaded model weights from {model_path}')
210
 
211
  # Generate predictions
212
+ predictions = _generate_predictions(heterodata, model, edge_label_index, go_cat)
213
+ prediction_df = _create_prediction_df(predictions, heterodata, protein_ids, go_cat, threshold=threshold)
214
  all_predictions.append(prediction_df)
215
 
216
  # Clean up memory
217
+ del edge_label_index
218
  del model
219
  del predictions
220
  torch.cuda.empty_cache() # Clear CUDA cache if using GPU