| | import os |
| | import numpy as np |
| | import pandas as pd |
| | import torch |
| | import re |
| | import math |
| | import torch.nn.functional as F |
| | from tqdm import tqdm |
| | from src.interface.model_interface import MInterface_base |
| | from src.model.pretrain_model_interface import PretrainModelInterface |
| | from scipy.stats import spearmanr |
| | from src.data.esm.sdk.api import ESMProtein |
| | from src.model.pretrain_modules import ( |
| | ESM2Model, ESMC600MModel, ESM3Model, VenusPLMModel, |
| | ProSTModel, ProstT5Model, ProTrekModel, SaPortModel, |
| | ProtT5, DPLMModel, PGLMModel, ANKHBase, ProtGPT2Model |
| | ) |
| |
|
class MInterface(MInterface_base):
    """Test-time interface that scores deep mutational scanning (DMS) assays
    with a pretrained protein model and logs the Spearman correlation between
    predicted and experimental fitness scores.

    Two scoring strategies are used depending on the wrapped model:

    * masked-language models (``self.mlm_models``): masked-marginal log-odds
      of mutant vs. wild-type residue, pre-computed once per mutated position;
    * all other models: negative local L2 distance between wild-type and
      mutant logits around the mutated positions ("similarity" scoring).
    """

    def __init__(self, model_name=None, loss=None, lr=None, **kargs):
        super().__init__()
        self.save_hyperparameters()
        # self.model is the generic wrapper; the concrete pretrained model
        # lives in self.model.pretrain_model.
        self.model = PretrainModelInterface(
            self.hparams.pretrain_model_name,
            task_type=self.hparams.task_type
        )
        self.tokenizer = self.model.pretrain_model.get_tokenizer()
        # Models that consume structure alongside sequence; their input
        # sequence must be cropped to the region covered by the PDB file.
        self.multimodal_models = [
            ProTrekModel,
            SaPortModel
        ]
        # Models scored with the masked-marginal (MLM) scheme in test_step.
        self.mlm_models = [
            ESM2Model,
            ESMC600MModel,
            ESM3Model,
            PGLMModel,
            VenusPLMModel,
            ProTrekModel,
            DPLMModel,
        ]
        # Logits slice [start:end] that strips special tokens (BOS/EOS).
        self.start, self.end = 1, -1
        # BUG FIX: the original tested ``type(self.model) == ProtGPT2Model``,
        # but self.model is always the PretrainModelInterface wrapper, so the
        # adjustment never fired.  ProtGPT2 prepends no BOS token, hence its
        # logits window starts at index 0.
        if isinstance(self.model.pretrain_model, ProtGPT2Model):
            self.start -= 1
        self._context = {
            "test": {
                "spearmans": []
            },
        }
        os.makedirs(os.path.join(self.hparams.res_dir, self.hparams.ex_name), exist_ok=True)

    def test_step(self, batch, batch_idx):
        """Score one DMS assay (one CSV of mutants) and log its Spearman.

        ``batch`` carries one assay: its id, the DMS CSV path, the wild-type
        sequence, the PDB path and the residue range covered by the PDB.
        """
        dms_id = batch["dms_id"][0]
        dms_csv_path = batch["dms_csv_path"][0]
        pdb_range = batch["pdb_range"]
        target_sequence = batch["target_sequence"][0]
        pdb_file_path = batch["pdb_file_path"][0]
        dms_df = pd.read_csv(dms_csv_path)
        true_dms_scores = dms_df["DMS_score"].tolist()
        predict_dms = []

        # Length of the uncropped sequence; the similarity branch stores it
        # in the batch before any cropping happens.
        full_sequence_length = len(target_sequence)

        # BUG FIX: the original tested ``type(self.model)`` (always the
        # PretrainModelInterface wrapper), so the multimodal crop was dead
        # code and ``offset`` was always 0.  Check the wrapped model instead.
        is_multimodal = type(self.model.pretrain_model) in self.multimodal_models
        if is_multimodal:
            # Structure-aware models only see the region covered by the PDB;
            # mutation positions must then be shifted by the crop offset.
            target_sequence = target_sequence[pdb_range[0]:pdb_range[1]]
            offset = pdb_range[0].item() if isinstance(pdb_range[0], torch.Tensor) else pdb_range[0]
        else:
            offset = 0

        if pdb_file_path is not None:
            structure = ESMProtein.from_pdb(pdb_file_path)
            coordinates = structure.coordinates
        else:
            coordinates = None

        if type(self.model.pretrain_model) in self.mlm_models:
            # ---- Masked-marginal scoring -------------------------------
            sequence = target_sequence
            mutations = dms_df["mutant"].tolist()
            model = self.model.pretrain_model
            tokenizer = self.tokenizer
            batch_size = 8
            window_size = 1022  # max residues per model forward pass

            if len(sequence) == 0:
                raise ValueError("Empty wildtype sequence provided")
            print(f"Working with sequence of length {len(sequence)} using optimized MLM approach")

            # Parse "A123G" and multi-site "A123G:T45S" mutation strings into
            # (wt, positions, mt, sequence-indices, original-name) tuples and
            # collect every position that needs a masked forward pass.
            parsed_mutations, unique_positions = [], set()
            for mutation in mutations:
                if ":" in mutation:
                    sub_mutations = mutation.split(":")
                    multi_wt, multi_mt = "", ""
                    multi_pos, multi_seq_pos, valid_multi = [], [], True
                    for sub_mut in sub_mutations:
                        match = re.match(r"([A-Z])(\d+)([A-Z])", sub_mut)
                        if not match:
                            print(f"Warning: Could not parse mutation {sub_mut}, skipping")
                            valid_multi = False
                            break
                        wt, pos_str, mt = match.groups()
                        pos = int(pos_str)
                        # Mutation positions are 1-based; shift by the crop offset.
                        seq_pos = pos - (1 + offset)

                        if seq_pos < 0 or seq_pos >= len(sequence):
                            print(f"Warning: Position {pos} out of range, skipping")
                            valid_multi = False
                            break
                        if sequence[seq_pos] != wt:
                            print(f"Warning: Wild-type {wt} at pos {pos} doesn't match sequence {sequence[seq_pos]}, skipping")
                            valid_multi = False
                            break

                        multi_wt += wt
                        multi_mt += mt
                        multi_pos.append(pos)
                        multi_seq_pos.append(seq_pos)
                        unique_positions.add(pos)

                    if valid_multi:
                        parsed_mutations.append((multi_wt, multi_pos, multi_mt, multi_seq_pos, mutation))
                else:
                    match = re.match(r"([A-Z])(\d+)([A-Z])", mutation)
                    if not match:
                        print(f"Warning: Could not parse mutation {mutation}, skipping")
                        continue

                    wt, pos_str, mt = match.groups()
                    pos = int(pos_str)
                    seq_pos = pos - (1 + offset)

                    if seq_pos < 0 or seq_pos >= len(sequence):
                        print(f"Warning: Position {pos} out of range, skipping")
                        continue

                    if sequence[seq_pos] != wt:
                        print(f"Warning: Wild-type {wt} at pos {pos} doesn't match sequence {sequence[seq_pos]}, skipping")
                        continue

                    parsed_mutations.append((wt, [pos], mt, [seq_pos], mutation))
                    unique_positions.add(pos)

            if not parsed_mutations:
                print("No valid mutations to score")
                predict_dms = [0.0] * len(mutations)
            else:
                unique_positions = sorted(unique_positions)
                print(f"Found {len(unique_positions)} unique mutation positions to pre-compute")

                # Map each canonical amino acid to its single token id so we
                # can read its log-probability straight out of the logits.
                aa_to_token = {}
                amino_acids = "ACDEFGHIKLMNPQRSTVWY"
                for aa in amino_acids:
                    tokens = tokenizer.encode(aa, add_special_tokens=False)
                    aa_to_token[aa] = tokens[0]

                # position_aa_scores[pos][aa] = log P(aa | sequence with pos masked)
                position_aa_scores = {}
                num_batches = math.ceil(len(unique_positions) / batch_size)
                progress_bar = tqdm(total=num_batches, desc="Pre-computing position scores")
                # NOTE: loop variable renamed from ``batch_idx`` so it no
                # longer shadows the test_step parameter.
                for chunk_start in range(0, len(unique_positions), batch_size):
                    batch_positions = unique_positions[chunk_start:chunk_start + batch_size]
                    # Group positions by the sequence window they fall into so
                    # long sequences are cropped to ``window_size`` residues.
                    window_groups = {}

                    for pos in batch_positions:
                        seq_pos = pos - (1 + offset)
                        if len(sequence) > window_size:
                            # Center the window on the mutated position,
                            # clamped to the sequence boundaries.
                            window_half = window_size // 2
                            start_pos = max(0, seq_pos - window_half)
                            end_pos = min(len(sequence), start_pos + window_size)
                            if end_pos == len(sequence):
                                start_pos = max(0, len(sequence) - window_size)
                            seq_window = sequence[start_pos:end_pos]
                            rel_pos = seq_pos - start_pos
                        else:
                            seq_window = sequence
                            rel_pos = seq_pos

                        window_key = (seq_window, start_pos if len(sequence) > window_size else 0)
                        if window_key not in window_groups:
                            window_groups[window_key] = []
                        window_groups[window_key].append((pos, seq_pos, rel_pos))

                    for (seq_window, window_start), positions_in_window in window_groups.items():
                        unique_rel_positions = set(info[2] for info in positions_in_window)
                        input_items = []
                        rel_pos_map = {}

                        # One masked copy of the window per mutated position.
                        for rel_pos in unique_rel_positions:
                            masked_seq = list(seq_window)
                            masked_seq[rel_pos] = tokenizer.mask_token
                            input_items.append({
                                "seq": ''.join(masked_seq),
                                "X": coordinates,
                                "name": f"masked_pos_{rel_pos}",
                                "label": 1.0
                            })
                            rel_pos_map[len(input_items) - 1] = rel_pos

                        with torch.no_grad():
                            # NOTE: renamed from ``batch`` so the test_step
                            # argument is not clobbered.
                            model_batch = model.construct_batch(input_items)
                            outputs = model.forward(batch=model_batch, return_logits=True)
                            # Strip special tokens so logits align with residues.
                            batch_logits = outputs[:, self.start:self.end, :]

                        for idx, rel_pos in rel_pos_map.items():
                            logits = batch_logits[idx, rel_pos, :]
                            log_probs = torch.log_softmax(logits, dim=-1)
                            for pos, seq_pos, pos_rel_pos in positions_in_window:
                                if pos_rel_pos == rel_pos:
                                    if pos not in position_aa_scores:
                                        position_aa_scores[pos] = {}
                                    for aa in amino_acids:
                                        token_id = aa_to_token[aa]
                                        position_aa_scores[pos][aa] = log_probs[token_id].item()

                    if progress_bar is not None:
                        progress_bar.update(1)

                if progress_bar is not None:
                    progress_bar.close()

                mutation_scores = {}
                print("Calculating scores for all mutations using pre-computed values")

                # Multi-site mutations score as the sum of per-site log-odds.
                for wt, pos_list, mt, seq_pos_list, mutation_name in tqdm(parsed_mutations, desc="Scoring mutations"):
                    cumulative_score = 0.0
                    for i, (pos, aa_mt) in enumerate(zip(pos_list, mt)):
                        aa_wt = wt[i] if i < len(wt) else wt
                        if pos in position_aa_scores:
                            wt_score = position_aa_scores[pos][aa_wt]
                            mt_score = position_aa_scores[pos][aa_mt]
                            cumulative_score += (mt_score - wt_score)
                        else:
                            print(f"Warning: Position {pos} not found in pre-computed scores, mutation {mutation_name} may be incomplete")
                    mutation_scores[mutation_name] = cumulative_score

                # Unparseable mutations fall back to a neutral score of 0.0.
                predict_dms = [mutation_scores.get(mut, 0.0) for mut in mutations]
        else:
            # ---- Similarity (local logit L2) scoring -------------------
            batch["max_length"][0] = full_sequence_length
            if not is_multimodal:
                # Models that skipped the multimodal crop above are cropped
                # here; multimodal ones are already cropped (avoid doing it twice).
                target_sequence = target_sequence[pdb_range[0]:pdb_range[1]]
                offset = pdb_range[0].item() if isinstance(pdb_range[0], torch.Tensor) else pdb_range[0]
            if isinstance(self.model.pretrain_model, ProSTModel):
                # ProST is capped at 1022 residues.
                target_sequence = target_sequence[:1022]
                self.model.pretrain_model.max_length = len(target_sequence)
            mutants = dms_df["mutant"].tolist()
            coordinates = ESMProtein.from_pdb(pdb_file_path).coordinates

            # Wild-type reference logits, computed once.
            wt_input = [{
                "seq": target_sequence,
                "X": coordinates,
                "name": "wildtype",
                "label": 1.0
            }]
            with torch.no_grad():
                try:
                    wt_batch = self.model.pretrain_model.construct_batch(wt_input)
                    wt_logits = self.model.pretrain_model.forward(batch=wt_batch, return_logits=True).squeeze(0)[self.start:self.end, :]
                except Exception:
                    # Without a wild-type reference the assay cannot be scored.
                    return None

            # Build one mutated sequence per mutant row.
            mutant_inputs = []
            selected_true_dms_scores = []
            for j, mutant in enumerate(mutants):
                mutated_seq = list(target_sequence)
                mut_positions = []
                for mut in mutant.split(":"):
                    wt_res, pos_str, mut_res = mut[0], mut[1:-1], mut[-1]
                    pos = int(pos_str) - (1 + offset)
                    if isinstance(self.model.pretrain_model, ProSTModel):
                        # Sites beyond the 1022-residue crop cannot be applied.
                        if pos > 1020:
                            continue
                    mutated_seq[pos] = mut_res
                    mut_positions.append(pos)
                mutant_inputs.append({
                    "seq": ''.join(mutated_seq),
                    "X": coordinates,
                    "name": f"mutant",
                    "label": 1.0,
                    "mut_positions": mut_positions
                })
                selected_true_dms_scores.append(true_dms_scores[j])

            predict_dms = []
            # BUG FIX: the original removed failed batches from
            # ``selected_true_dms_scores`` by index while still iterating by
            # ``i`` (misaligning after the first removal) and then compared
            # ``predict_dms`` against the *full* true_dms_scores list, which
            # made the final assert fail whenever any batch errored.  Instead
            # accumulate the true scores of successfully scored mutants.
            kept_true_scores = []
            batch_size = 8
            for i in tqdm(range(0, len(mutant_inputs), batch_size), desc=f"Processing {dms_id} (Similarity)..."):
                batch_mutants = mutant_inputs[i:i + batch_size]
                batch_true = selected_true_dms_scores[i:i + batch_size]
                mut_positions_batch = [x["mut_positions"] for x in batch_mutants]
                with torch.no_grad():
                    try:
                        mut_batch = self.model.pretrain_model.construct_batch(batch_mutants)
                        mut_logits = self.model.pretrain_model.forward(batch=mut_batch, return_logits=True)[:, self.start:self.end, :]
                    except Exception as e:
                        # Best-effort: drop the failed batch and keep going.
                        print(f"Batch {i}-{i+batch_size} failed with error: {e}")
                        continue

                kept_true_scores.extend(batch_true[:mut_logits.size(0)])
                for j in range(mut_logits.size(0)):
                    mut_emb = mut_logits[j]
                    if isinstance(self.model.pretrain_model, ProtGPT2Model):
                        # Causal LM: compare mean-pooled logits globally.
                        wt_emb_mean = wt_logits.mean(0)
                        mut_emb_mean = mut_emb.mean(0)
                        similarity = -1.0 * local_l2_difference_single(wt_emb_mean, mut_emb_mean, [0], window_size=0)
                    else:
                        # Compare logits locally around the mutated sites.
                        mut_pos = mut_positions_batch[j]
                        similarity = -1.0 * local_l2_difference_single(wt_logits, mut_emb, mut_pos, window_size=5)
                    score = similarity
                    predict_dms.append(score)

            # Align the ground truth with the mutants that were scored.
            true_dms_scores = kept_true_scores

        assert len(predict_dms) == len(true_dms_scores)

        spearman = spearmanr(np.array(predict_dms), np.array(true_dms_scores)).statistic
        log_dict = {
            "test_spearman": spearman
        }
        # NaN occurs e.g. when all predictions are constant; skip averaging it.
        if not np.isnan(spearman):
            self._context['test']['spearmans'].append(spearman)
        self.log_dict(log_dict, prog_bar=True, logger=True, on_step=True)

    def on_test_epoch_end(self):
        """Average the per-assay Spearman correlations and log the result."""
        spearmans = np.array(self._context['test']['spearmans']).mean()
        metric = {
            "avg_spearman": spearmans
        }
        self.log_dict(metric, prog_bar=True, logger=True, on_epoch=True)
| |
|
| |
|
def local_l2_difference_single(wt_emb, mut_emb, mut_pos_list, window_size=5):
    """Local L2 distance between wild-type and mutant embeddings.

    Only positions within ``window_size`` of any mutated site contribute, so
    the score reflects the local perturbation caused by the mutation(s).

    Args:
        wt_emb: (L, d) wild-type embedding, or a (d,) pooled embedding.
        mut_emb: mutant embedding with the same shape as ``wt_emb``.
        mut_pos_list: list of int, 0-based mutated positions.
        window_size: half-width of the window taken around each position.

    Returns:
        float: L2 norm of the embedding difference restricted to the union
        of the windows (0.0 when ``mut_pos_list`` is empty).
    """
    # BUG FIX: callers that compare mean-pooled embeddings (the ProtGPT2
    # path) pass 1-D (d,) tensors, which crashed the original unpacking of
    # ``wt_emb.shape``.  Promote them to a single-row (1, d) matrix so the
    # [0] window covers the whole pooled vector.
    if wt_emb.dim() == 1:
        wt_emb = wt_emb.unsqueeze(0)
    if mut_emb.dim() == 1:
        mut_emb = mut_emb.unsqueeze(0)

    seq_len = wt_emb.shape[0]
    window_indices = set()
    for pos in mut_pos_list:
        lo = max(pos - window_size, 0)
        hi = min(pos + window_size + 1, seq_len)  # clamp window to sequence
        window_indices.update(range(lo, hi))
    window_indices = sorted(window_indices)

    # Frobenius norm of the difference restricted to the window rows.
    delta = mut_emb[window_indices, :] - wt_emb[window_indices, :]
    return torch.norm(delta).item()
| |
|