ProtHGT / run_prothgt_app.py
Erva Ulusoy
potential fix for torch import error
89a4f79
import os
import site
import glob
import subprocess
def _maybe_fix_torch_execstack() -> None:
"""
Fix for environments that refuse to load shared objects requiring an executable stack
(e.g., newer glibc hardening). Some PyTorch wheels ship libtorch_cpu.so with the
execstack flag set, causing:
ImportError: libtorch_cpu.so: cannot enable executable stack ...
We clear the flag *before* importing torch.
"""
if os.environ.get("PROTHGT_TORCH_EXECSTACK_FIXED") == "1":
return
# patchelf is installed via packages.txt in this repo.
patchelf = "patchelf"
paths = []
for fn in (getattr(site, "getsitepackages", None), getattr(site, "getusersitepackages", None)):
if fn is None:
continue
try:
p = fn()
if isinstance(p, str):
paths.append(p)
else:
paths.extend(list(p))
except Exception:
pass
targets = []
for p in paths:
targets += glob.glob(os.path.join(p, "torch", "lib", "libtorch_cpu.so"))
targets += glob.glob(os.path.join(p, "torch", "lib", "libtorch_python.so"))
for so in sorted(set(targets)):
try:
if os.path.exists(so):
subprocess.run([patchelf, "--clear-execstack", so], check=False, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
except Exception:
# Best-effort; if it fails we'll still try importing torch and surface the real error.
pass
os.environ["PROTHGT_TORCH_EXECSTACK_FIXED"] = "1"
# Must run before the `import torch` below: the execstack flag on torch's
# shared libraries has to be cleared before those libraries are loaded.
_maybe_fix_torch_execstack()
import torch
from torch.nn import Linear
from torch_geometric.nn import HGTConv, MLP
import pandas as pd
import yaml
from datasets import load_dataset
import gdown
import copy
import json
import gzip
class ProtHGT(torch.nn.Module):
    """Heterogeneous Graph Transformer that scores Protein -> target-node links.

    Pipeline: a per-node-type Linear projection (followed by ReLU) maps raw
    features to ``hidden_channels``; ``num_layers`` HGTConv layers then refine
    the embeddings; finally an MLP scores each candidate (protein, target) pair
    from the concatenation of the two node embeddings.
    """

    def __init__(self, data, hidden_channels, num_heads, num_layers, mlp_hidden_layers, mlp_dropout):
        super().__init__()

        # One input projection per node type, sized from that type's feature width.
        self.lin_dict = torch.nn.ModuleDict()
        for node_type in data.node_types:
            feature_dim = data[node_type].x.size(1)
            self.lin_dict[node_type] = Linear(feature_dim, hidden_channels)

        # Stack of message-passing layers over the heterogeneous graph schema.
        self.convs = torch.nn.ModuleList(
            HGTConv(hidden_channels, hidden_channels, data.metadata(), num_heads, group='sum')
            for _ in range(num_layers)
        )

        # Pair scorer; mlp_hidden_layers is the channel list passed straight to PyG's MLP.
        self.mlp = MLP(mlp_hidden_layers, dropout=mlp_dropout, norm=None)

    def generate_embeddings(self, x_dict, edge_index_dict):
        """Project every node type's features, then run them through all HGT layers.

        Returns a dict mapping node type -> updated embedding tensor.
        """
        hidden = {}
        for node_type, features in x_dict.items():
            hidden[node_type] = self.lin_dict[node_type](features).relu_()

        for conv in self.convs:
            hidden = conv(hidden, edge_index_dict)
        return hidden

    def forward(self, x_dict, edge_index_dict, tr_edge_label_index, target_type, test=False):
        """Score candidate edges between proteins and nodes of ``target_type``.

        ``tr_edge_label_index`` is a pair (protein indices, target indices).
        Returns ``(scores, embeddings)``: the flattened raw MLP outputs (one per
        candidate edge) and the embedding dict from ``generate_embeddings``.
        ``test`` is accepted for API compatibility and not used.
        """
        embeddings = self.generate_embeddings(x_dict, edge_index_dict)

        src_idx, dst_idx = tr_edge_label_index
        pair_features = torch.cat(
            (embeddings["Protein"][src_idx], embeddings[target_type][dst_idx]),
            dim=-1,
        )
        scores = self.mlp(pair_features).view(-1)
        return scores, embeddings
def _build_edge_label_index(heterodata, protein_ids, go_category):
"""
Build a dense candidate edge_label_index (Protein × GO terms) for inference.
IMPORTANT: Do NOT overwrite heterodata.edge_index_dict here.
Graph edges are used for message passing; candidate edges are only for scoring.
"""
protein_indices = torch.tensor(
[heterodata['Protein']['id_mapping'][pid] for pid in protein_ids],
dtype=torch.long,
)
n_terms = len(heterodata[go_category]['id_mapping'])
term_indices = torch.arange(n_terms, dtype=torch.long)
row = protein_indices.repeat_interleave(n_terms)
col = term_indices.repeat(len(protein_indices))
return torch.stack([row, col], dim=0)
def get_available_proteins(name_file='data/name_info.json.gz'):
    """Return the protein IDs (the keys of the "Protein" section) stored in *name_file*.

    *name_file* is a gzip-compressed UTF-8 JSON document with a top-level
    "Protein" mapping; the IDs are returned in the file's key order.
    """
    with gzip.open(name_file, mode='rt', encoding='utf-8') as handle:
        protein_section = json.load(handle)['Protein']
    return [*protein_section.keys()]
def _generate_predictions(heterodata, model, edge_label_index, target_type):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()
heterodata = heterodata.to(device)
edge_label_index = edge_label_index.to(device)
with torch.no_grad():
predictions, _ = model(heterodata.x_dict, heterodata.edge_index_dict, edge_label_index, target_type)
predictions = torch.sigmoid(predictions)
return predictions.cpu()
def _create_prediction_df(predictions, heterodata, protein_ids, go_category, threshold: float = 0.0):
go_category_dict = {
'GO_term_F': 'Molecular Function',
'GO_term_P': 'Biological Process',
'GO_term_C': 'Cellular Component'
}
# Load name information from gzipped file
with gzip.open('data/name_info.json.gz', 'rt', encoding='utf-8') as file:
name_info = json.load(file)
id_mapping = heterodata[go_category]['id_mapping'] # dict: GO_id -> index
n_go_terms = len(id_mapping)
# Create lists to store the data
all_proteins = []
all_protein_names = []
all_go_terms = []
all_go_term_names = []
all_categories = []
all_probabilities = []
# Build GO terms list aligned with their numeric indices (critical for correctness)
go_terms = [None] * n_go_terms
for go_id, idx in id_mapping.items():
go_terms[int(idx)] = go_id
# Process predictions for each protein
for i, protein_id in enumerate(protein_ids):
# Get predictions for this protein
start_idx = i * n_go_terms
end_idx = (i + 1) * n_go_terms
protein_predictions = predictions[start_idx:end_idx]
# Optional pre-filter for performance
if threshold and threshold > 0.0:
keep_mask = protein_predictions >= float(threshold)
if keep_mask.any():
keep_idx = torch.nonzero(keep_mask, as_tuple=False).view(-1)
protein_predictions = protein_predictions[keep_idx]
else:
continue
else:
keep_idx = torch.arange(n_go_terms)
# Get protein name
protein_name = name_info['Protein'].get(protein_id, protein_id)
# Extend the lists
k = int(protein_predictions.numel())
all_proteins.extend([protein_id] * k)
all_protein_names.extend([protein_name] * k)
kept_go_ids = [go_terms[int(j)] for j in keep_idx.tolist()]
all_go_terms.extend(kept_go_ids)
all_go_term_names.extend([name_info['GO_term'].get(term_id, term_id) for term_id in kept_go_ids])
all_categories.extend([go_category_dict[go_category]] * k)
all_probabilities.extend(protein_predictions.tolist())
# Create DataFrame
prediction_df = pd.DataFrame({
'UniProt_ID': all_proteins,
'Protein': all_protein_names,
'GO_ID': all_go_terms,
'GO_term': all_go_term_names,
'GO_category': all_categories,
'Probability': all_probabilities
})
return prediction_df
def generate_prediction_df(protein_ids, model_paths, model_config_paths, go_category, threshold: float = 0.0):
all_predictions = []
# Convert single protein ID to list if necessary
if isinstance(protein_ids, str):
protein_ids = [protein_ids]
# Load dataset once
# heterodata = load_dataset('HUBioDataLab/ProtHGT-KG', data_files="prothgt-kg.json.gz")
print('Loading data...')
file_id = "18u1o2sm8YjMo9joFw4Ilwvg0-rUU0PXK"
output = "data/prothgt-kg.pt"
if not os.path.exists(output):
try:
url = f"https://drive.google.com/uc?id={file_id}"
print(f"Downloading file from {url}...")
gdown.download(url, output, quiet=False)
print(f"File downloaded to {output}")
except Exception as e:
print(f"Error downloading file: {e}")
raise
else:
print(f"File already exists at {output}")
heterodata = torch.load(output)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
for go_cat, model_config_path, model_path in zip(go_category, model_config_paths, model_paths):
print(f'Generating predictions for {go_cat}...')
# Build candidate edges for inference (do NOT modify graph edges)
edge_label_index = _build_edge_label_index(heterodata, protein_ids, go_cat)
# Load model config
with open(model_config_path, 'r') as file:
model_config = yaml.safe_load(file)
# Initialize model with configuration
model = ProtHGT(
heterodata,
hidden_channels=model_config['hidden_channels'][0],
num_heads=model_config['num_heads'],
num_layers=model_config['num_layers'],
mlp_hidden_layers=model_config['hidden_channels'][1],
mlp_dropout=model_config['mlp_dropout']
)
# Load model weights
model.load_state_dict(torch.load(model_path, map_location=device))
print(f'Loaded model weights from {model_path}')
# Generate predictions
predictions = _generate_predictions(heterodata, model, edge_label_index, go_cat)
prediction_df = _create_prediction_df(predictions, heterodata, protein_ids, go_cat, threshold=threshold)
all_predictions.append(prediction_df)
# Clean up memory
del edge_label_index
del model
del predictions
torch.cuda.empty_cache() # Clear CUDA cache if using GPU
# Combine all predictions
final_df = pd.concat(all_predictions, ignore_index=True)
# Clean up
del all_predictions
torch.cuda.empty_cache()
return heterodata, final_df