Spaces:
Running
Running
| import os | |
| import site | |
| import glob | |
| import subprocess | |
def _find_torch_lib_dirs():
    """
    Return candidate ``torch/lib`` directories without importing torch.

    Two sources are combined:
    1. ``importlib.util.find_spec("torch")``: the package location that
       ``import torch`` will resolve to. For a top-level name this locates the
       package without executing it.
    2. Every site-packages / user site-packages directory, in case several
       torch copies are installed.

    Directories are returned even if they do not exist; callers must check.
    """
    import importlib.util

    lib_dirs = []
    try:
        spec = importlib.util.find_spec("torch")
    except (ImportError, ValueError):
        # ValueError: torch already in sys.modules with __spec__ set to None.
        spec = None
    if spec is not None and spec.submodule_search_locations:
        lib_dirs.extend(os.path.join(loc, "lib") for loc in spec.submodule_search_locations)

    for getter in (getattr(site, "getsitepackages", None), getattr(site, "getusersitepackages", None)):
        if getter is None:
            continue
        try:
            found = getter()
        except Exception:
            # Some embedded/virtualenv setups make these helpers fail; best-effort.
            continue
        bases = [found] if isinstance(found, str) else list(found)
        lib_dirs.extend(os.path.join(base, "torch", "lib") for base in bases)
    return lib_dirs


def _maybe_fix_torch_execstack() -> None:
    """
    Clear the executable-stack flag on PyTorch's shared libraries before torch is imported.

    Fix for environments that refuse to load shared objects requiring an executable stack
    (e.g., newer glibc hardening). Some PyTorch wheels ship libtorch_cpu.so with the
    execstack flag set, causing:
        ImportError: libtorch_cpu.so: cannot enable executable stack ...
    We clear the flag *before* importing torch by running
    ``patchelf --clear-execstack`` on ``torch/lib/libtorch_cpu.so`` and
    ``torch/lib/libtorch_python.so`` in every candidate torch install.

    The fix is best-effort. A missing ``patchelf``, missing libraries, or a failing
    patchelf run are all ignored, so the later ``import torch`` surfaces the real
    error. Afterwards the environment variable ``PROTHGT_TORCH_EXECSTACK_FIXED=1``
    is set. Repeat calls, and child processes that inherit the environment, then
    skip the work.
    """
    import shutil

    guard = "PROTHGT_TORCH_EXECSTACK_FIXED"
    if os.environ.get(guard) == "1":
        return

    # patchelf is installed via packages.txt in this repo. Without it on PATH
    # there is nothing we can do, so skip the filesystem scan entirely.
    patchelf = shutil.which("patchelf")
    if patchelf is not None:
        # realpath collapses duplicate paths (symlinks, overlapping site dirs) to one file.
        targets = {
            os.path.realpath(os.path.join(lib_dir, lib_name))
            for lib_dir in _find_torch_lib_dirs()
            for lib_name in ("libtorch_cpu.so", "libtorch_python.so")
        }
        for so in sorted(t for t in targets if os.path.isfile(t)):
            try:
                subprocess.run(
                    [patchelf, "--clear-execstack", so],
                    check=False,
                    stdout=subprocess.DEVNULL,
                    stderr=subprocess.DEVNULL,
                )
            except (OSError, subprocess.SubprocessError):
                # Best-effort; if it fails we'll still try importing torch and surface the real error.
                pass
    os.environ[guard] = "1"
# Clear the execstack flag on torch's shared libraries (best-effort) before the
# `import torch` that follows loads them; this call must stay above that import.
_maybe_fix_torch_execstack()
| import torch | |
| from torch.nn import Linear | |
| from torch_geometric.nn import HGTConv, MLP | |
| import pandas as pd | |
| import yaml | |
| from datasets import load_dataset | |
| import gdown | |
| import copy | |
| import json | |
| import gzip | |
class ProtHGT(torch.nn.Module):
    """
    Heterogeneous-graph link predictor.

    Pipeline:
    - Each node type gets a Linear projection (followed by ReLU) into a shared
      ``hidden_channels`` space.
    - A stack of ``HGTConv`` layers then performs message passing over the graph.
    - An ``MLP`` head scores a candidate edge from the concatenation of the
      Protein embedding and the target-node embedding.

    NOTE: the attribute names ``lin_dict``, ``convs`` and ``mlp`` determine the
    ``state_dict`` keys and must match the saved checkpoints.
    """

    def __init__(self, data, hidden_channels, num_heads, num_layers, mlp_hidden_layers, mlp_dropout):
        """
        Args:
            data: heterogeneous graph. ``data.node_types``, each type's feature
                width ``data[t].x.size(1)`` and ``data.metadata()`` are read.
            hidden_channels: width of the projected embeddings and of each HGTConv.
            num_heads: attention heads per HGTConv layer.
            num_layers: number of stacked HGTConv layers.
            mlp_hidden_layers: passed as ``MLP``'s first positional argument
                (presumably its channel list; its input must be 2 * hidden_channels).
            mlp_dropout: dropout used inside the MLP head.
        """
        super().__init__()
        # A list of (key, module) pairs keeps node-type insertion order on every torch version.
        self.lin_dict = torch.nn.ModuleDict([
            (node_type, Linear(data[node_type].x.size(1), hidden_channels))
            for node_type in data.node_types
        ])
        graph_metadata = data.metadata()
        self.convs = torch.nn.ModuleList([
            HGTConv(hidden_channels, hidden_channels, graph_metadata, num_heads, group='sum')
            for _ in range(num_layers)
        ])
        self.mlp = MLP(mlp_hidden_layers, dropout=mlp_dropout, norm=None)

    def generate_embeddings(self, x_dict, edge_index_dict):
        """Project every node type, apply ReLU, then run all HGTConv layers; return the embedding dict."""
        hidden = {}
        for node_type, features in x_dict.items():
            hidden[node_type] = self.lin_dict[node_type](features).relu_()
        for layer in self.convs:
            hidden = layer(hidden, edge_index_dict)
        return hidden

    def forward(self, x_dict, edge_index_dict, tr_edge_label_index, target_type, test=False):
        """
        Score candidate (Protein, ``target_type``) edges.

        ``tr_edge_label_index`` unpacks into (protein indices, target indices).
        ``test`` is accepted for call compatibility but is not used.

        Returns:
            (flattened MLP scores with one value per candidate edge, the embedding dict)
        """
        embeddings = self.generate_embeddings(x_dict, edge_index_dict)
        src_idx, dst_idx = tr_edge_label_index
        pair_repr = torch.cat([embeddings["Protein"][src_idx], embeddings[target_type][dst_idx]], dim=-1)
        scores = self.mlp(pair_repr).view(-1)
        return scores, embeddings
def _build_edge_label_index(heterodata, protein_ids, go_category):
    """
    Build a dense candidate edge_label_index (Protein × GO terms) for inference.

    IMPORTANT: Do NOT overwrite heterodata.edge_index_dict here.
    Graph edges are used for message passing; candidate edges are only for scoring.

    Returns:
        A ``(2, len(protein_ids) * n_terms)`` long tensor, where
        ``n_terms = len(heterodata[go_category]['id_mapping'])``.
        - Row 0 holds protein node indices; row 1 holds GO-term indices 0..n_terms-1.
        - Pairs are protein-major: every term for ``protein_ids[0]``, then every
          term for ``protein_ids[1]``, and so on.

    Raises:
        KeyError: a protein id is absent from ``heterodata['Protein']['id_mapping']``.
    """
    protein_lookup = heterodata['Protein']['id_mapping']
    src_nodes = torch.tensor([protein_lookup[pid] for pid in protein_ids], dtype=torch.long)
    num_terms = len(heterodata[go_category]['id_mapping'])

    # Protein-major ordering: repeat each protein once per term, and tile the term range once per protein.
    rows = src_nodes.repeat_interleave(num_terms)
    cols = torch.arange(num_terms, dtype=torch.long).repeat(src_nodes.numel())
    return torch.stack((rows, cols))
def get_available_proteins(name_file='data/name_info.json.gz'):
    """
    List the protein IDs known to the name file.

    ``name_file`` is a gzip-compressed UTF-8 JSON document. The keys of its
    top-level ``'Protein'`` mapping are returned, in file order.
    """
    with gzip.open(name_file, 'rt', encoding='utf-8') as handle:
        protein_names = json.load(handle)['Protein']
    return list(protein_names)
def _generate_predictions(heterodata, model, edge_label_index, target_type):
    """
    Score candidate edges with ``model`` and return sigmoid probabilities on the CPU.

    Side effects:
    - ``model`` is moved to the selected device (CUDA if available, else CPU)
      and switched to eval mode.
    - ``heterodata.to(device)`` is called. For PyG HeteroData this likely moves
      the caller's object in place — verify before relying on its device afterwards.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    model.to(device)
    model.eval()
    graph = heterodata.to(device)
    candidates = edge_label_index.to(device)

    with torch.no_grad():
        logits, _embeddings = model(graph.x_dict, graph.edge_index_dict, candidates, target_type)
    return torch.sigmoid(logits).cpu()
def _create_prediction_df(predictions, heterodata, protein_ids, go_category, threshold: float = 0.0,
                          name_file='data/name_info.json.gz'):
    """
    Convert a flat vector of (protein, GO term) probabilities into a tidy DataFrame.

    Args:
        predictions: 1-D tensor of length ``len(protein_ids) * n_go_terms``.
            - ``n_go_terms = len(heterodata[go_category]['id_mapping'])``.
            - Entries are protein-major: ``predictions[i * n_go_terms + j]``
              scores ``protein_ids[i]`` against the GO term whose id_mapping index is ``j``.
        heterodata: graph supporting ``heterodata[go_category]['id_mapping']``
            (a dict from GO id to integer index).
        protein_ids: protein IDs in the same order used to build ``predictions``.
        go_category: one of ``'GO_term_F'``, ``'GO_term_P'``, ``'GO_term_C'``.
        threshold: when > 0, keep only pairs with probability >= threshold.
            Proteins left with no rows are omitted.
        name_file: gzip-compressed UTF-8 JSON with ``'Protein'`` and ``'GO_term'``
            id -> display-name maps. IDs without a name fall back to the ID itself.

    Returns:
        DataFrame with columns UniProt_ID, Protein, GO_ID, GO_term, GO_category
        and Probability.

    Raises:
        KeyError: ``go_category`` is not one of the three known categories.
        ValueError: ``predictions`` does not have ``len(protein_ids) * n_go_terms`` entries.
    """
    go_category_dict = {
        'GO_term_F': 'Molecular Function',
        'GO_term_P': 'Biological Process',
        'GO_term_C': 'Cellular Component'
    }
    category_label = go_category_dict[go_category]

    # Load name information from gzipped file
    with gzip.open(name_file, 'rt', encoding='utf-8') as file:
        name_info = json.load(file)
    protein_names = name_info.get('Protein', {})
    go_term_names = name_info.get('GO_term', {})

    id_mapping = heterodata[go_category]['id_mapping']  # dict: GO_id -> index
    n_go_terms = len(id_mapping)

    # Misaligned input would otherwise silently shift GO terms between proteins,
    # or fail later with an opaque pandas length error.
    expected = len(protein_ids) * n_go_terms
    if predictions.numel() != expected:
        raise ValueError(
            f"Expected {expected} predictions ({len(protein_ids)} proteins x {n_go_terms} GO terms), "
            f"got {predictions.numel()}"
        )

    # Build GO terms list aligned with their numeric indices (critical for correctness)
    go_terms = [None] * n_go_terms
    for go_id, idx in id_mapping.items():
        go_terms[int(idx)] = go_id

    use_threshold = bool(threshold) and threshold > 0.0

    all_proteins = []
    all_protein_names = []
    all_go_terms = []
    all_go_term_names = []
    all_categories = []
    all_probabilities = []

    for i, protein_id in enumerate(protein_ids):
        protein_predictions = predictions[i * n_go_terms:(i + 1) * n_go_terms]

        if use_threshold:
            # Optional pre-filter for performance
            keep_idx = torch.nonzero(protein_predictions >= float(threshold), as_tuple=False).view(-1)
            if keep_idx.numel() == 0:
                continue
            protein_predictions = protein_predictions[keep_idx]
        else:
            keep_idx = torch.arange(n_go_terms)

        protein_name = protein_names.get(protein_id, protein_id)
        kept_go_ids = [go_terms[j] for j in keep_idx.tolist()]
        k = len(kept_go_ids)

        all_proteins.extend([protein_id] * k)
        all_protein_names.extend([protein_name] * k)
        all_go_terms.extend(kept_go_ids)
        all_go_term_names.extend([go_term_names.get(term_id, term_id) for term_id in kept_go_ids])
        all_categories.extend([category_label] * k)
        all_probabilities.extend(protein_predictions.tolist())

    return pd.DataFrame({
        'UniProt_ID': all_proteins,
        'Protein': all_protein_names,
        'GO_ID': all_go_terms,
        'GO_term': all_go_term_names,
        'GO_category': all_categories,
        'Probability': all_probabilities
    })
def _as_list(value):
    """Return ``[value]`` for a bare string, otherwise ``list(value)``."""
    if isinstance(value, str):
        return [value]
    return list(value)


def _ensure_kg_file(file_id, output):
    """Download the knowledge-graph file from Google Drive to ``output`` unless it already exists."""
    if os.path.exists(output):
        print(f"File already exists at {output}")
        return
    try:
        # Create the target directory up front; gdown is not assumed to do it (TODO confirm).
        parent = os.path.dirname(output)
        if parent:
            os.makedirs(parent, exist_ok=True)
        url = f"https://drive.google.com/uc?id={file_id}"
        print(f"Downloading file from {url}...")
        gdown.download(url, output, quiet=False)
    except Exception as e:
        print(f"Error downloading file: {e}")
        raise
    # Some gdown versions may report failures without raising (verify for the pinned
    # version); fail here with a clear message rather than inside torch.load.
    if not os.path.exists(output):
        raise RuntimeError(f"Download did not produce {output}")
    print(f"File downloaded to {output}")


def _load_kg_file(path):
    """Load the pickled knowledge graph saved with ``torch.save``, mapping all tensors to the CPU."""
    import inspect

    load_kwargs = {'map_location': 'cpu'}
    # torch>=2.6 defaults to weights_only=True, which only unpickles tensors, primitives
    # and allow-listed classes, so a pickled graph object would be rejected. Opt out
    # explicitly where the argument exists (torch>=1.13).
    # SECURITY NOTE(review): this is a full pickle load of a file downloaded from
    # Google Drive; it is only acceptable because that source is trusted.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    return torch.load(path, **load_kwargs)


def generate_prediction_df(protein_ids, model_paths, model_config_paths, go_category, threshold: float = 0.0):
    """
    Predict GO-term probabilities for proteins with one ProtHGT model per GO category.

    Args:
        protein_ids: one protein ID or a list of IDs. Each must be a key of
            ``heterodata['Protein']['id_mapping']``.
        model_paths: state-dict file(s), parallel to ``go_category``.
        model_config_paths: YAML config file(s), parallel to ``go_category``.
            Keys read: ``hidden_channels`` ([0] = HGT width, [1] = MLP layers),
            ``num_heads``, ``num_layers``, ``mlp_dropout``.
        go_category: one category key (e.g. ``'GO_term_F'``) or a list of them.
        threshold: forwarded to the per-category DataFrame builder; when > 0 only
            pairs with probability >= threshold are kept.

    Returns:
        ``(heterodata, final_df)``: the loaded graph and all categories' predictions
        concatenated with a fresh index.

    Raises:
        ValueError: no GO category was given, or a category has no matching
            model / config path.

    Side effects: downloads ``data/prothgt-kg.pt`` on first use.
    """
    # Bare strings would otherwise be zipped character by character.
    protein_ids = _as_list(protein_ids)
    go_categories = _as_list(go_category)
    model_paths = _as_list(model_paths)
    model_config_paths = _as_list(model_config_paths)

    # Validate before any download so bad arguments fail fast.
    if not go_categories:
        raise ValueError("go_category must name at least one GO category")
    if len(go_categories) > min(len(model_paths), len(model_config_paths)):
        raise ValueError(
            f"{len(go_categories)} GO categories requested but only {len(model_paths)} model paths "
            f"and {len(model_config_paths)} config paths were given"
        )

    print('Loading data...')
    file_id = "18u1o2sm8YjMo9joFw4Ilwvg0-rUU0PXK"
    output = "data/prothgt-kg.pt"
    _ensure_kg_file(file_id, output)
    heterodata = _load_kg_file(output)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    all_predictions = []
    for go_cat, model_config_path, model_path in zip(go_categories, model_config_paths, model_paths):
        print(f'Generating predictions for {go_cat}...')

        # Build candidate edges for inference (do NOT modify graph edges)
        edge_label_index = _build_edge_label_index(heterodata, protein_ids, go_cat)

        with open(model_config_path, 'r') as file:
            model_config = yaml.safe_load(file)

        model = ProtHGT(
            heterodata,
            hidden_channels=model_config['hidden_channels'][0],
            num_heads=model_config['num_heads'],
            num_layers=model_config['num_layers'],
            mlp_hidden_layers=model_config['hidden_channels'][1],
            mlp_dropout=model_config['mlp_dropout']
        )
        model.load_state_dict(torch.load(model_path, map_location=device))
        print(f'Loaded model weights from {model_path}')

        predictions = _generate_predictions(heterodata, model, edge_label_index, go_cat)
        prediction_df = _create_prediction_df(predictions, heterodata, protein_ids, go_cat, threshold=threshold)
        all_predictions.append(prediction_df)

        # Release per-category tensors/models before the next category is processed.
        del edge_label_index
        del model
        del predictions
        torch.cuda.empty_cache()  # Clear CUDA cache if using GPU

    final_df = pd.concat(all_predictions, ignore_index=True)
    del all_predictions
    torch.cuda.empty_cache()
    return heterodata, final_df