| import torch |
| import numpy as np |
| import random |
| import os |
| import sqlite3 |
| from typing import List, Tuple, Dict, Optional |
| from glob import glob |
| from pandas import read_csv, read_excel |
| from datasets import load_dataset, Dataset |
| from dataclasses import dataclass |
|
|
| try: |
| from utils import print_message |
| from seed_utils import get_global_seed |
| from embedder import get_embedding_filename |
| except ImportError: |
| from ..utils import print_message |
| from ..seed_utils import get_global_seed |
| from ..embedder import get_embedding_filename |
| from .supported_datasets import supported_datasets, standard_data_benchmark, vector_benchmark |
| from .utils import ( |
| AA_SET, |
| CODON_SET, |
| DNA_SET, |
| RNA_SET, |
| NONCANONICAL_AMINO_ACIDS, |
| AMINO_ACID_TO_HUMAN_CODON, |
| NONCANONICAL_ALANINE_CODON, |
| AA_TO_CODON_TOKEN, |
| CODON_TO_AA, |
| DNA_CODON_TO_AA, |
| RNA_CODON_TO_AA, |
| ) |
|
| @dataclass |
| class DataArguments: |
| """ |
| Args: |
| data_paths: List[str] |
| paths to the datasets |
| max_length: int |
| max length of sequences |
| trim: bool |
| whether to trim sequences to max_length |
| """ |
| def __init__( |
| self, |
| data_names: List[str], |
| delimiter: str = ',', |
        col_names: Optional[List[str]] = None,
| max_length: int = 1024, |
| trim: bool = False, |
        data_dirs: Optional[List[str]] = None,
| multi_column: Optional[List[str]] = None, |
| aa_to_dna: bool = False, |
| aa_to_rna: bool = False, |
| dna_to_aa: bool = False, |
| rna_to_aa: bool = False, |
| codon_to_aa: bool = False, |
| aa_to_codon: bool = False, |
| **kwargs |
| ): |
| self.data_names = data_names |
        self.data_dirs = data_dirs if data_dirs is not None else []
| self.delimiter = delimiter |
        self.col_names = col_names if col_names is not None else ['seqs', 'labels']
| self.max_length = max_length |
| self.trim = trim |
| self.protein_gym = False |
| self.multi_column = multi_column |
| self.aa_to_dna = aa_to_dna |
| self.aa_to_rna = aa_to_rna |
| self.dna_to_aa = dna_to_aa |
| self.rna_to_aa = rna_to_aa |
| self.codon_to_aa = codon_to_aa |
| self.aa_to_codon = aa_to_codon |
|
|
| if len(data_names) > 0: |
| if data_names[0] == 'standard_benchmark': |
| self.data_paths = [supported_datasets[data_name] for data_name in standard_data_benchmark] |
| elif data_names[0] == 'vector_benchmark': |
| self.data_paths = [supported_datasets[data_name] for data_name in vector_benchmark] |
| else: |
| self.data_paths = [] |
| for data_name in data_names: |
| if data_name == 'protein_gym': |
| |
| self.protein_gym = True |
| continue |
| if data_name in supported_datasets: |
| self.data_paths.append(supported_datasets[data_name]) |
| else: |
| print(f'{data_name} not found in supported datasets') |
                        print('We will attempt to load it from huggingface anyway, but this may not work')
| self.data_paths.append(data_name) |
| else: |
| self.data_paths = [] |
| |
        for dir_path in self.data_dirs:
            if not os.path.exists(dir_path):
                raise FileNotFoundError(f'{dir_path} does not exist')
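

# Example usage of DataArguments (illustrative; 'deeploc_binary' is a
# hypothetical dataset name that would need to be a key in `supported_datasets`
# or a valid Hugging Face dataset path):
#
#   data_args = DataArguments(
#       data_names=['deeploc_binary'],
#       max_length=512,
#       trim=True,
#   )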


| class DataMixin: |
| def __init__(self, data_args: Optional[DataArguments] = None): |
| |
| self._sql = False |
| self._full = False |
| self._max_length = 1024 |
| self._trim = False |
| self._delimiter = ',' |
| self._col_names = ['seqs', 'labels'] |
| self._aa_to_dna = False |
| self._aa_to_rna = False |
| self._dna_to_aa = False |
| self._rna_to_aa = False |
| self._codon_to_aa = False |
| self._aa_to_codon = False |
| self.data_args = data_args |
| self._multi_column = None if data_args is None else getattr(data_args, 'multi_column', None) |
| if data_args is not None: |
| self._aa_to_dna = data_args.aa_to_dna |
| self._aa_to_rna = data_args.aa_to_rna |
| self._dna_to_aa = data_args.dna_to_aa |
| self._rna_to_aa = data_args.rna_to_aa |
| self._codon_to_aa = data_args.codon_to_aa |
| self._aa_to_codon = data_args.aa_to_codon |
|
|
    def _not_regression(self, labels):
        # True when every label value is a whole number (classification-style),
        # False when any value has a fractional part (regression-style).
        if isinstance(labels, list) and len(labels) > 0 and isinstance(labels[0], list):
            # nested lists (e.g. multilabel): check every element of every label
            return all(isinstance(element, (int, float)) and element == int(element)
                       for label in labels for element in label)
        # flat list (or other iterable) of scalar labels
        return all(isinstance(label, (int, float)) and label == int(label) for label in labels)
|
|
| def _encode_labels(self, labels, tag2id): |
| return [torch.tensor([tag2id[tag] for tag in doc], dtype=torch.long) for doc in labels] |
|
|
    def _label_type_checker(self, labels):
        assert len(labels) > 0, f'Labels is empty: {labels}'
        ex = labels[0]
        # check strings first: _not_regression returns False for non-numeric
        # values, so string labels would otherwise fall through to 'regression'
        if isinstance(ex, str):
            label_type = 'string'
        elif self._not_regression(labels):
            if isinstance(ex, list):
                label_type = 'multilabel'
            else:
                label_type = 'singlelabel'
        else:
            label_type = 'regression'
        return label_type
|
|
    def _is_sigmoid_regression(self, labels) -> bool:
        """Heuristic: labels lie within [0, 1] and cover the range approximately.
        Uses 10-bin histogram coverage and a span threshold.
        """
        arr = []
        for label in labels:
            try:
                arr.extend(label)
            except TypeError:
                # scalar label
                arr.append(label)
        arr = np.array(arr, dtype=float).flatten()
        if arr.size == 0:
            return False

        # all values within [0, 1] (small numerical tolerance)
        min_val, max_val = float(arr.min()), float(arr.max())
        cond1 = min_val > 0.0 - 1e-6 and max_val < 1.0 + 1e-6

        # values span most of the unit interval
        cond2 = (max_val - min_val) > 0.75

        # values occupy at least 7 of 10 histogram bins
        hist, _ = np.histogram(arr, bins=10, range=(0.0, 1.0))
        cond3 = int((hist > 0).sum()) >= 7

        return cond1 and cond2 and cond3
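
    # Illustrative behavior of the heuristic above (examples, not tests):
    #   _is_sigmoid_regression(np.linspace(0.0, 1.0, 100).tolist()) -> True
    #     (within [0, 1], span 1.0 > 0.75, all 10 histogram bins occupied)
    #   _is_sigmoid_regression([0.40, 0.45, 0.50, 0.55]) -> False
    #     (span 0.15 <= 0.75 and only 2 of 10 bins occupied)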
|
|
    def _select_from_sql(self, c, seq, cast_to_torch=True):
        c.execute("SELECT embedding FROM embeddings WHERE sequence = ?", (seq,))
        row = c.fetchone()
        assert row is not None, f'No embedding found in database for sequence: {seq[:30]}'
        embedding = np.frombuffer(row[0], dtype=np.float32).reshape(1, -1)
        if self._full:
            # per-residue embeddings: (len(seq), hidden) instead of a pooled vector
            embedding = embedding.reshape(len(seq), -1)
        if cast_to_torch:
            embedding = torch.tensor(embedding)
        return embedding
|
|
    def _select_from_pth(self, emb_dict, seq, cast_to_np=False):
        # emb_dict maps sequence string -> torch.Tensor
        embedding = emb_dict[seq].reshape(1, -1)
        if self._full:
            # per-residue embeddings: (len(seq), hidden) instead of a pooled vector
            embedding = embedding.reshape(len(seq), -1)
        if cast_to_np:
            embedding = embedding.numpy()
        return embedding
|
|
| def _labels_to_numpy(self, labels): |
| if isinstance(labels[0], list): |
| return np.array(labels).flatten() |
| else: |
| return np.array([labels]).flatten() |
|
|
| def _random_order(self, seq_a, seq_b): |
| if random.random() < 0.5: |
| return seq_a, seq_b |
| else: |
| return seq_b, seq_a |
|
|
    def _truncate_pairs(self, ex):
        # Trim one residue at a time from the longer sequence until the pair
        # fits within self._max_length
        seq_a, seq_b = ex['SeqA'], ex['SeqB']
| trunc_a, trunc_b = seq_a, seq_b |
| while len(trunc_a) + len(trunc_b) > self._max_length: |
| if len(trunc_a) > len(trunc_b): |
| trunc_a = trunc_a[:-1] |
| else: |
| trunc_b = trunc_b[:-1] |
| ex['SeqA'] = trunc_a |
| ex['SeqB'] = trunc_b |
| return ex |
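
    # e.g. with self._max_length = 8, {'SeqA': 'MKTAYIAK', 'SeqB': 'MVLSP'}
    # becomes {'SeqA': 'MKTA', 'SeqB': 'MVLS'} (the longer side is trimmed first)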
|
|
| def _active_translation_mode(self): |
| mode_to_flag = { |
| 'aa_to_dna': self._aa_to_dna, |
| 'aa_to_rna': self._aa_to_rna, |
| 'dna_to_aa': self._dna_to_aa, |
| 'rna_to_aa': self._rna_to_aa, |
| 'codon_to_aa': self._codon_to_aa, |
| 'aa_to_codon': self._aa_to_codon, |
| } |
| active_modes = [mode for mode, enabled in mode_to_flag.items() if enabled] |
| assert len(active_modes) <= 1, f'Only one translation mode can be enabled at a time, found: {active_modes}' |
| return active_modes[0] if len(active_modes) == 1 else None |
|
|
| def _assert_characters_in_set(self, seq, allowed_chars, mode): |
| bad_chars = sorted({char for char in seq if char.upper() not in allowed_chars}) |
| assert len(bad_chars) == 0, f'Invalid characters for {mode}: {bad_chars}.' |
|
|
| def _validate_translated_output(self, translated_seq, allowed_chars, mode): |
| bad_chars = sorted({char for char in translated_seq if char not in allowed_chars}) |
| assert len(bad_chars) == 0, f'Translation output for {mode} contains unexpected characters: {bad_chars}.' |
|
|
| def _normalize_aa_for_nucleotide_translation(self, seq): |
| canonical_aas = set(AMINO_ACID_TO_HUMAN_CODON.keys()) |
| normalized = [] |
| for residue in seq: |
| residue = residue.upper() |
| if residue in canonical_aas: |
| normalized.append(residue) |
| else: |
| normalized.append('X') |
| return ''.join(normalized) |
|
|
| def _translate_aa_to_dna(self, seq): |
| seq = self._normalize_aa_for_nucleotide_translation(seq) |
| dna_codons = [] |
| for residue in seq: |
| residue = residue.upper() |
| if residue in AMINO_ACID_TO_HUMAN_CODON: |
| dna_codons.append(AMINO_ACID_TO_HUMAN_CODON[residue]) |
| elif residue in NONCANONICAL_AMINO_ACIDS: |
| dna_codons.append(NONCANONICAL_ALANINE_CODON) |
| else: |
| raise AssertionError(f'Unexpected amino acid token "{residue}" while converting aa_to_dna.') |
| translated = ''.join(dna_codons) |
| self._validate_translated_output(translated, DNA_SET, 'aa_to_dna') |
| return translated |
|
|
| def _translate_aa_to_rna(self, seq): |
| dna_translated = self._translate_aa_to_dna(seq) |
| translated = dna_translated.replace('T', 'U') |
| self._validate_translated_output(translated, RNA_SET, 'aa_to_rna') |
| return translated |
|
|
| def _translate_dna_to_aa(self, seq): |
| dna_seq = seq.upper() |
| self._assert_characters_in_set(dna_seq, DNA_SET, 'dna_to_aa') |
| assert len(dna_seq) % 3 == 0, f'dna_to_aa requires sequence length multiple of 3, got {len(dna_seq)}.' |
| aa_seq = [] |
| for idx in range(0, len(dna_seq), 3): |
| codon = dna_seq[idx:idx + 3] |
| assert codon in DNA_CODON_TO_AA, f'Unknown DNA codon for dna_to_aa: {codon}' |
| translated_char = DNA_CODON_TO_AA[codon] |
| if translated_char != '*': |
| aa_seq.append(translated_char) |
| translated = ''.join(aa_seq) |
| self._validate_translated_output(translated, AA_SET - {'*'}, 'dna_to_aa') |
| return translated |
|
|
| def _translate_rna_to_aa(self, seq): |
| rna_seq = seq.upper() |
| self._assert_characters_in_set(rna_seq, RNA_SET, 'rna_to_aa') |
| assert len(rna_seq) % 3 == 0, f'rna_to_aa requires sequence length multiple of 3, got {len(rna_seq)}.' |
| aa_seq = [] |
| for idx in range(0, len(rna_seq), 3): |
| codon = rna_seq[idx:idx + 3] |
| assert codon in RNA_CODON_TO_AA, f'Unknown RNA codon for rna_to_aa: {codon}' |
| translated_char = RNA_CODON_TO_AA[codon] |
| if translated_char != '*': |
| aa_seq.append(translated_char) |
| translated = ''.join(aa_seq) |
| self._validate_translated_output(translated, AA_SET - {'*'}, 'rna_to_aa') |
| return translated |
|
|
| def _translate_codon_to_aa(self, seq): |
| aa_seq = [] |
| for token in seq: |
| assert token in CODON_TO_AA, f'Unknown codon token for codon_to_aa: {token}' |
| translated_char = CODON_TO_AA[token] |
| if translated_char != '*': |
| aa_seq.append(translated_char) |
| translated = ''.join(aa_seq) |
| self._validate_translated_output(translated, AA_SET - {'*'}, 'codon_to_aa') |
| return translated |
|
|
| def _translate_aa_to_codon(self, seq): |
| codon_tokens = [] |
| for residue in seq: |
| residue = residue.upper() |
| if residue in AA_TO_CODON_TOKEN: |
| codon_tokens.append(AA_TO_CODON_TOKEN[residue]) |
| elif residue in NONCANONICAL_AMINO_ACIDS: |
| codon_tokens.append(AA_TO_CODON_TOKEN['A']) |
| else: |
| raise AssertionError(f'Unexpected amino acid token "{residue}" while converting aa_to_codon.') |
| translated = ''.join(codon_tokens) |
| self._validate_translated_output(translated, CODON_SET, 'aa_to_codon') |
| return translated |
|
|
| def _translate_sequence_for_mode(self, seq, mode): |
| if mode == 'aa_to_dna': |
| return self._translate_aa_to_dna(seq) |
| if mode == 'aa_to_rna': |
| return self._translate_aa_to_rna(seq) |
| if mode == 'dna_to_aa': |
| return self._translate_dna_to_aa(seq) |
| if mode == 'rna_to_aa': |
| return self._translate_rna_to_aa(seq) |
| if mode == 'codon_to_aa': |
| return self._translate_codon_to_aa(seq) |
| if mode == 'aa_to_codon': |
| return self._translate_aa_to_codon(seq) |
| raise AssertionError(f'Unsupported translation mode: {mode}') |
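
    # Illustrative outputs of the translators above, assuming the codon tables
    # in .utils follow the standard genetic code:
    #   _translate_dna_to_aa('ATGGCTTAA') -> 'MA' (the stop codon maps to '*',
    #     which is dropped)
    #   _translate_aa_to_rna('MA') -> 'AUG' followed by the human codon for 'A'
    #     (e.g. 'AUGGCC' if AMINO_ACID_TO_HUMAN_CODON maps 'A' to 'GCC')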
|
|
| def _find_first_present_column(self, available_columns, candidates_ordered): |
| """Return the first column from candidates_ordered that exists in available_columns (case-insensitive).""" |
| lowercase_to_actual = {col.lower(): col for col in available_columns} |
| for candidate in candidates_ordered: |
| actual = lowercase_to_actual.get(candidate.lower()) |
| if actual is not None: |
| return actual |
| raise KeyError(f"None of the candidate columns were found. Candidates: {candidates_ordered}. Available: {available_columns}") |
|
|
| def _is_ppi_from_columns(self, available_columns): |
| """Detect if dataset contains paired sequence inputs (SeqA/SeqB variants).""" |
| lowercase_columns = set(col.lower() for col in available_columns) |
| base_candidates = ['seqs', 'seq', 'sequence', 'sequences'] |
| for base in base_candidates: |
| if (base + 'a') in lowercase_columns and (base + 'b') in lowercase_columns: |
| return True |
| return False |
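
    # e.g. _is_ppi_from_columns(['SeqA', 'SeqB', 'labels']) -> True
    #      _is_ppi_from_columns(['seqs', 'labels']) -> False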
|
|
| def _find_ppi_sequence_columns(self, available_columns): |
| """Return the actual column names for A and B sequences in PPI datasets based on priority.""" |
| lowercase_to_actual = {col.lower(): col for col in available_columns} |
| |
        # matching below is case-insensitive, so one casing per pair suffices
        specific_pairs = [
            ('SeqA', 'SeqB'),
            ('SeqsA', 'SeqsB'),
        ]
| for cand_a, cand_b in specific_pairs: |
| a_actual = lowercase_to_actual.get(cand_a.lower()) |
| b_actual = lowercase_to_actual.get(cand_b.lower()) |
| if a_actual is not None and b_actual is not None: |
| return a_actual, b_actual |
|
|
| |
| base_candidates = ['seqs', 'seq', 'sequence', 'sequences'] |
| for base in base_candidates: |
| a_key = (base + 'a').lower() |
| b_key = (base + 'b').lower() |
| a_actual = lowercase_to_actual.get(a_key) |
| b_actual = lowercase_to_actual.get(b_key) |
| if a_actual is not None and b_actual is not None: |
| return a_actual, b_actual |
|
|
| raise KeyError(f"Could not find paired sequence columns for PPI. Available: {available_columns}") |
|
|
    def _is_missing_value(self, v):
        # None or NaN, either as a scalar or inside a list/tuple/array, counts as missing
        if v is None:
            return True
        if isinstance(v, float) and np.isnan(v):
            return True
        if isinstance(v, (list, tuple, np.ndarray)):
            for el in v:
                if el is None:
                    return True
                if isinstance(el, float) and np.isnan(el):
                    return True
        return False
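
    # e.g. _is_missing_value(float('nan')) -> True
    #      _is_missing_value([0.2, None]) -> True
    #      _is_missing_value('MKTAYIAK') -> False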
|
|
| def process_datasets( |
| self, |
| hf_datasets: List[Tuple[Dataset, Dataset, Dataset, bool]], |
| data_names: List[str], |
    ) -> Tuple[Dict[str, Tuple[Dataset, Dataset, Dataset, int, str, bool]], List[str]]:
| max_length = self._max_length |
| datasets, all_seqs = {}, set() |
| translation_mode = self._active_translation_mode() |
| for dataset, data_name in zip(hf_datasets, data_names): |
| print_message(f'Processing {data_name}') |
| train_set, valid_set, test_set, ppi = dataset |
| print(train_set) |
| print(valid_set) |
| print(test_set) |
| |
| |
| before_train, before_valid, before_test = len(train_set), len(valid_set), len(test_set) |
| if ppi: |
| train_set = train_set.filter(lambda x: not (self._is_missing_value(x['SeqA']) or self._is_missing_value(x['SeqB']) or self._is_missing_value(x['labels']))) |
| valid_set = valid_set.filter(lambda x: not (self._is_missing_value(x['SeqA']) or self._is_missing_value(x['SeqB']) or self._is_missing_value(x['labels']))) |
| test_set = test_set.filter(lambda x: not (self._is_missing_value(x['SeqA']) or self._is_missing_value(x['SeqB']) or self._is_missing_value(x['labels']))) |
| elif self.data_args.multi_column: |
| cols = self.data_args.multi_column |
| |
| for col in cols: |
| assert col in train_set.column_names or col in valid_set.column_names or col in test_set.column_names, f"Column {col} not found in dataset {data_name}" |
|
|
| def _filter_row(x): |
| return (not self._is_missing_value(x['labels'])) and all(not self._is_missing_value(x[col]) for col in cols) |
|
|
| train_set = train_set.filter(_filter_row) |
| valid_set = valid_set.filter(_filter_row) |
| test_set = test_set.filter(_filter_row) |
| else: |
| train_set = train_set.filter(lambda x: not (self._is_missing_value(x['seqs']) or self._is_missing_value(x['labels']))) |
| valid_set = valid_set.filter(lambda x: not (self._is_missing_value(x['seqs']) or self._is_missing_value(x['labels']))) |
| test_set = test_set.filter(lambda x: not (self._is_missing_value(x['seqs']) or self._is_missing_value(x['labels']))) |
| if any([ |
| len(train_set) != before_train, |
| len(valid_set) != before_valid, |
| len(test_set) != before_test, |
| ]): |
| print_message( |
| f"Removed None / NaN rows - train: {before_train - len(train_set)}, valid: {before_valid - len(valid_set)}, test: {before_test - len(test_set)}" |
| ) |
|
|
| |
| if translation_mode is None: |
| if ppi: |
| train_set = train_set.map(lambda x: {'SeqA': ''.join(aa for aa in x['SeqA'] if aa in AA_SET), |
| 'SeqB': ''.join(aa for aa in x['SeqB'] if aa in AA_SET)}) |
| valid_set = valid_set.map(lambda x: {'SeqA': ''.join(aa for aa in x['SeqA'] if aa in AA_SET), |
| 'SeqB': ''.join(aa for aa in x['SeqB'] if aa in AA_SET)}) |
| test_set = test_set.map(lambda x: {'SeqA': ''.join(aa for aa in x['SeqA'] if aa in AA_SET), |
| 'SeqB': ''.join(aa for aa in x['SeqB'] if aa in AA_SET)}) |
| elif self.data_args.multi_column: |
| cols = self.data_args.multi_column |
| for col in cols: |
| train_set = train_set.map(lambda x, _col=col: {_col: ''.join(aa for aa in x[_col] if aa in AA_SET)}) |
| valid_set = valid_set.map(lambda x, _col=col: {_col: ''.join(aa for aa in x[_col] if aa in AA_SET)}) |
| test_set = test_set.map(lambda x, _col=col: {_col: ''.join(aa for aa in x[_col] if aa in AA_SET)}) |
| else: |
| train_set = train_set.map(lambda x: {'seqs': ''.join(aa for aa in x['seqs'] if aa in AA_SET)}) |
| valid_set = valid_set.map(lambda x: {'seqs': ''.join(aa for aa in x['seqs'] if aa in AA_SET)}) |
| test_set = test_set.map(lambda x: {'seqs': ''.join(aa for aa in x['seqs'] if aa in AA_SET)}) |
|
|
| |
| before_train, before_valid, before_test = len(train_set), len(valid_set), len(test_set) |
| if ppi: |
| train_set = train_set.filter(lambda x: len(x['SeqA']) > 0 and len(x['SeqB']) > 0) |
| valid_set = valid_set.filter(lambda x: len(x['SeqA']) > 0 and len(x['SeqB']) > 0) |
| test_set = test_set.filter(lambda x: len(x['SeqA']) > 0 and len(x['SeqB']) > 0) |
| elif self.data_args.multi_column: |
| cols = self.data_args.multi_column |
| train_set = train_set.filter(lambda x: all(len(x[col]) > 0 for col in cols)) |
| valid_set = valid_set.filter(lambda x: all(len(x[col]) > 0 for col in cols)) |
| test_set = test_set.filter(lambda x: all(len(x[col]) > 0 for col in cols)) |
| else: |
| train_set = train_set.filter(lambda x: len(x['seqs']) > 0) |
| valid_set = valid_set.filter(lambda x: len(x['seqs']) > 0) |
| test_set = test_set.filter(lambda x: len(x['seqs']) > 0) |
|
|
| if any([ |
| len(train_set) != before_train, |
| len(valid_set) != before_valid, |
| len(test_set) != before_test, |
| ]): |
| print_message( |
| f"Removed length 0 rows - train: {before_train - len(train_set)}, valid: {before_valid - len(valid_set)}, test: {before_test - len(test_set)}" |
| ) |
|
|
| |
| before_train, before_valid, before_test = len(train_set), len(valid_set), len(test_set) |
| if self._trim: |
| if ppi: |
| train_set = train_set.filter(lambda x: len(x['SeqA']) + len(x['SeqB']) <= max_length) |
| valid_set = valid_set.filter(lambda x: len(x['SeqA']) + len(x['SeqB']) <= max_length) |
| test_set = test_set.filter(lambda x: len(x['SeqA']) + len(x['SeqB']) <= max_length) |
| elif self.data_args.multi_column: |
| cols = self.data_args.multi_column |
| train_set = train_set.filter(lambda x: all(len(x[col]) <= max_length for col in cols)) |
| valid_set = valid_set.filter(lambda x: all(len(x[col]) <= max_length for col in cols)) |
| test_set = test_set.filter(lambda x: all(len(x[col]) <= max_length for col in cols)) |
| else: |
| train_set = train_set.filter(lambda x: len(x['seqs']) <= max_length) |
| valid_set = valid_set.filter(lambda x: len(x['seqs']) <= max_length) |
| test_set = test_set.filter(lambda x: len(x['seqs']) <= max_length) |
|
|
| else: |
| if ppi: |
| train_set = train_set.map(self._truncate_pairs) |
| valid_set = valid_set.map(self._truncate_pairs) |
| test_set = test_set.map(self._truncate_pairs) |
| elif self.data_args.multi_column: |
| cols = self.data_args.multi_column |
| for col in cols: |
| train_set = train_set.map(lambda x, _col=col: { _col: x[_col][:max_length] }) |
| valid_set = valid_set.map(lambda x, _col=col: { _col: x[_col][:max_length] }) |
| test_set = test_set.map(lambda x, _col=col: { _col: x[_col][:max_length] }) |
| else: |
| train_set = train_set.map(lambda x: {'seqs': x['seqs'][:max_length]}) |
| valid_set = valid_set.map(lambda x: {'seqs': x['seqs'][:max_length]}) |
| test_set = test_set.map(lambda x: {'seqs': x['seqs'][:max_length]}) |
|
|
| if any([ |
| len(train_set) != before_train, |
| len(valid_set) != before_valid, |
| len(test_set) != before_test, |
| ]): |
                print_message(
                    f"Trimmed rows - train: {(before_train - len(train_set)) / before_train * 100:.2f}%, "
                    f"valid: {(before_valid - len(valid_set)) / before_valid * 100:.2f}%, "
                    f"test: {(before_test - len(test_set)) / before_test * 100:.2f}%"
                )
|
|
| |
| if translation_mode is not None: |
| if ppi: |
| train_set = train_set.map(lambda x: {'SeqA': self._translate_sequence_for_mode(x['SeqA'], translation_mode), |
| 'SeqB': self._translate_sequence_for_mode(x['SeqB'], translation_mode)}) |
| valid_set = valid_set.map(lambda x: {'SeqA': self._translate_sequence_for_mode(x['SeqA'], translation_mode), |
| 'SeqB': self._translate_sequence_for_mode(x['SeqB'], translation_mode)}) |
| test_set = test_set.map(lambda x: {'SeqA': self._translate_sequence_for_mode(x['SeqA'], translation_mode), |
| 'SeqB': self._translate_sequence_for_mode(x['SeqB'], translation_mode)}) |
| elif self.data_args.multi_column: |
| cols = self.data_args.multi_column |
| for col in cols: |
| train_set = train_set.map(lambda x, _col=col: {_col: self._translate_sequence_for_mode(x[_col], translation_mode)}) |
| valid_set = valid_set.map(lambda x, _col=col: {_col: self._translate_sequence_for_mode(x[_col], translation_mode)}) |
| test_set = test_set.map(lambda x, _col=col: {_col: self._translate_sequence_for_mode(x[_col], translation_mode)}) |
| else: |
| train_set = train_set.map(lambda x: {'seqs': self._translate_sequence_for_mode(x['seqs'], translation_mode)}) |
| valid_set = valid_set.map(lambda x: {'seqs': self._translate_sequence_for_mode(x['seqs'], translation_mode)}) |
| test_set = test_set.map(lambda x: {'seqs': self._translate_sequence_for_mode(x['seqs'], translation_mode)}) |
| print_message(f"Translated sequences with mode {translation_mode} (post-trim/truncate).") |
|
|
| |
| if ppi: |
| all_seqs.update(list(train_set['SeqA']) + list(train_set['SeqB'])) |
| all_seqs.update(list(valid_set['SeqA']) + list(valid_set['SeqB'])) |
| all_seqs.update(list(test_set['SeqA']) + list(test_set['SeqB'])) |
| elif self.data_args.multi_column: |
| cols = self.data_args.multi_column |
| for col in cols: |
| all_seqs.update(list(train_set[col])) |
| all_seqs.update(list(valid_set[col])) |
| all_seqs.update(list(test_set[col])) |
| else: |
| all_seqs.update(list(train_set['seqs'])) |
| all_seqs.update(list(valid_set['seqs'])) |
| all_seqs.update(list(test_set['seqs'])) |
|
|
| |
| check_labels = list(valid_set['labels']) |
| label_type = self._label_type_checker(check_labels) |
|
|
            if label_type == 'string':
                # string labels may be stringified lists, e.g. "[0, 1, 0]"
                example = list(valid_set['labels'])[0]
                try:
                    import ast
                    new_ex = ast.literal_eval(example)
                    if isinstance(new_ex, list):
                        label_type = 'multilabel'
                        train_set = train_set.map(lambda ex: {'labels': ast.literal_eval(ex['labels'])})
                        valid_set = valid_set.map(lambda ex: {'labels': ast.literal_eval(ex['labels'])})
                        test_set = test_set.map(lambda ex: {'labels': ast.literal_eval(ex['labels'])})
                except (ValueError, SyntaxError):
                    # not a parsable literal; keep treating labels as raw strings
                    label_type = 'string'
|
|
| if label_type == 'string': |
| train_labels = list(train_set['labels']) |
| unique_tags = set(tag for doc in train_labels for tag in doc) |
| tag2id = {tag: id for id, tag in enumerate(sorted(unique_tags))} |
| |
| train_set = train_set.map(lambda ex: {'labels': self._encode_labels(ex['labels'], tag2id=tag2id)}) |
| valid_set = valid_set.map(lambda ex: {'labels': self._encode_labels(ex['labels'], tag2id=tag2id)}) |
| test_set = test_set.map(lambda ex: {'labels': self._encode_labels(ex['labels'], tag2id=tag2id)}) |
| label_type = 'tokenwise' |
| num_labels = len(unique_tags) |
            else:
                if label_type == 'regression':
                    # regression heads output a single value
                    if self._is_sigmoid_regression(list(train_set['labels'])):
                        label_type = 'sigmoid_regression'
                    num_labels = 1
                else:
                    try:
                        # multilabel: number of labels is the length of each label vector
                        train_labels_list = list(train_set['labels'])
                        num_labels = len(train_labels_list[0])
                    except TypeError:
                        # singlelabel: infer the class count from the label range
                        unique = np.unique(list(train_set['labels']))
                        max_label = max(unique)
                        num_labels = len(np.arange(0, max_label + 1))
            datasets[data_name] = (train_set, valid_set, test_set, num_labels, label_type, ppi)
|
|
| print(f'Label type: {label_type}') |
| print(f'Number of labels: {num_labels}') |
|
|
| all_seqs = list(all_seqs) |
| all_seqs = sorted(all_seqs, key=len, reverse=True) |
| return datasets, all_seqs |
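
    # Returns ({data_name: (train, valid, test, num_labels, label_type, ppi)},
    # unique sequences sorted longest-first; the ordering is presumably so
    # downstream embedding code can batch similar-length sequences together).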
|
|
    def get_data(self):
        """
        Loads datasets from Hugging Face paths (self.data_args.data_paths) and
        from local directories (self.data_args.data_dirs).
        Local splits support .csv, .tsv, .txt, and .xlsx.
        TODO fasta, fa, fna, etc.
        """
| datasets, data_names = [], [] |
| label_candidates = ['labels', 'label', 'Labels', 'Label'] |
| seq_candidates = ['seqs', 'Seqs', 'seq', 'Seq', 'sequence', 'Sequence', 'sequences', 'Sequences'] |
|
|
| for data_path in self.data_args.data_paths: |
| data_name = data_path.split('/')[-1] |
| print_message(f'Loading {data_name}') |
| dataset = load_dataset(data_path) |
| if 'inverse' in data_name.lower(): |
| dataset = dataset.rename_columns({'seqs': 'labels', 'labels': 'seqs'}) |
| ppi = 'SeqA' in dataset['train'].column_names |
| |
| if not ppi: |
| ppi = self._is_ppi_from_columns(dataset['train'].column_names) |
| print_message(f'PPI (or dual sequence input dataset): {ppi}') |
|
|
| |
| assert 'train' in dataset, f'{data_name} does not have a train set' |
| assert 'valid' in dataset or 'test' in dataset, f'{data_name} does not have a valid or test set, needs at least one' |
| |
| if 'valid' not in dataset: |
| seed = get_global_seed() if get_global_seed() is not None else 42 |
| train_set = dataset['train'] |
| train_valid_set = train_set.train_test_split(test_size=0.1, seed=seed + 1) |
| train_set = train_valid_set['train'] |
| valid_set = train_valid_set['test'] |
| test_set = dataset['test'] |
| print_message(f'{data_name} does not have a valid set, created a 10% validation set') |
| elif 'test' not in dataset: |
| seed = get_global_seed() if get_global_seed() is not None else 42 |
| train_set = dataset['train'] |
| train_test_set = train_set.train_test_split(test_size=0.1, seed=seed + 2) |
| test_set = train_test_set['test'] |
| train_set = train_test_set['train'] |
| valid_set = dataset['valid'] |
| print_message(f'{data_name} does not have a test set, created a 10% test set') |
| else: |
| train_set, valid_set, test_set = dataset['train'], dataset['valid'], dataset['test'] |
| print_message(f'{data_name} has a valid and test set') |
|
|
| print_message(f'Train set: {len(train_set)}, Valid set: {len(valid_set)}, Test set: {len(test_set)}') |
| if ppi: |
| |
| print('Standardizing PPI column names') |
| try: |
| a_col, b_col = self._find_ppi_sequence_columns(train_set.column_names) |
| except KeyError: |
| |
| try: |
| a_col, b_col = self._find_ppi_sequence_columns(valid_set.column_names) |
| except KeyError: |
| a_col, b_col = self._find_ppi_sequence_columns(test_set.column_names) |
| |
| try: |
| lbl_col = self._find_first_present_column(train_set.column_names, label_candidates) |
| except KeyError: |
| try: |
| lbl_col = self._find_first_present_column(valid_set.column_names, label_candidates) |
| except KeyError: |
| lbl_col = self._find_first_present_column(test_set.column_names, label_candidates) |
|
|
| train_set = train_set.rename_columns({a_col: 'SeqA', b_col: 'SeqB', lbl_col: 'labels'}) |
| valid_set = valid_set.rename_columns({a_col: 'SeqA', b_col: 'SeqB', lbl_col: 'labels'}) |
| test_set = test_set.rename_columns({a_col: 'SeqA', b_col: 'SeqB', lbl_col: 'labels'}) |
|
|
| print('Removing extras') |
| train_set = train_set.remove_columns([col for col in train_set.column_names if col not in ['SeqA', 'SeqB', 'labels']]) |
| valid_set = valid_set.remove_columns([col for col in valid_set.column_names if col not in ['SeqA', 'SeqB', 'labels']]) |
| test_set = test_set.remove_columns([col for col in test_set.column_names if col not in ['SeqA', 'SeqB', 'labels']]) |
| else: |
| print('Standardizing column names') |
| use_multi = self.data_args.multi_column is not None |
| if not use_multi: |
| try: |
| seq_col = self._find_first_present_column(train_set.column_names, seq_candidates) |
| except KeyError: |
| try: |
| seq_col = self._find_first_present_column(valid_set.column_names, seq_candidates) |
| except KeyError: |
| seq_col = self._find_first_present_column(test_set.column_names, seq_candidates) |
|
|
| try: |
| label_col = self._find_first_present_column(train_set.column_names, label_candidates) |
| except KeyError: |
| try: |
| label_col = self._find_first_present_column(valid_set.column_names, label_candidates) |
| except KeyError: |
| label_col = self._find_first_present_column(test_set.column_names, label_candidates) |
|
|
| |
| train_set = train_set.rename_columns({label_col: 'labels'}) |
| valid_set = valid_set.rename_columns({label_col: 'labels'}) |
| test_set = test_set.rename_columns({label_col: 'labels'}) |
|
|
| if not use_multi: |
| train_set = train_set.rename_columns({seq_col: 'seqs'}) |
| valid_set = valid_set.rename_columns({seq_col: 'seqs'}) |
| test_set = test_set.rename_columns({seq_col: 'seqs'}) |
| |
| print('Removing extras') |
| train_set = train_set.remove_columns([col for col in train_set.column_names if col not in ['seqs', 'labels']]) |
| valid_set = valid_set.remove_columns([col for col in valid_set.column_names if col not in ['seqs', 'labels']]) |
| test_set = test_set.remove_columns([col for col in test_set.column_names if col not in ['seqs', 'labels']]) |
| else: |
| |
| for col in self.data_args.multi_column: |
| assert col in train_set.column_names or col in valid_set.column_names or col in test_set.column_names, f"Column {col} not found in dataset {data_name}" |
| |
| keep_cols = set(self.data_args.multi_column + ['labels']) |
| train_set = train_set.remove_columns([col for col in train_set.column_names if col not in keep_cols]) |
| valid_set = valid_set.remove_columns([col for col in valid_set.column_names if col not in keep_cols]) |
| test_set = test_set.remove_columns([col for col in test_set.column_names if col not in keep_cols]) |
|
|
| datasets.append((train_set, valid_set, test_set, ppi)) |
| data_names.append(data_name) |
|
|
| for data_dir in self.data_args.data_dirs: |
| |
            data_name = data_dir.split('/')[-1]
| |
| ppi = 'ppi' in data_dir.lower() |
| train_path = glob(os.path.join(data_dir, 'train.*'))[0] |
| valid_path = glob(os.path.join(data_dir, 'valid.*'))[0] |
| test_path = glob(os.path.join(data_dir, 'test.*'))[0] |
| if '.xlsx' in train_path: |
| train_set = read_excel(train_path) |
| valid_set = read_excel(valid_path) |
| test_set = read_excel(test_path) |
| else: |
| train_set = read_csv(train_path, delimiter=self._delimiter) |
| valid_set = read_csv(valid_path, delimiter=self._delimiter) |
| test_set = read_csv(test_path, delimiter=self._delimiter) |
|
|
| train_set = Dataset.from_pandas(train_set) |
| valid_set = Dataset.from_pandas(valid_set) |
| test_set = Dataset.from_pandas(test_set) |
|
|
| |
| if not ppi: |
| ppi = self._is_ppi_from_columns(train_set.column_names) |
|
|
| if ppi: |
| print('Standardizing PPI column names') |
| try: |
| a_col, b_col = self._find_ppi_sequence_columns(train_set.column_names) |
| except KeyError: |
| try: |
| a_col, b_col = self._find_ppi_sequence_columns(valid_set.column_names) |
| except KeyError: |
| a_col, b_col = self._find_ppi_sequence_columns(test_set.column_names) |
|
|
| try: |
| lbl_col = self._find_first_present_column(train_set.column_names, label_candidates) |
| except KeyError: |
| try: |
| lbl_col = self._find_first_present_column(valid_set.column_names, label_candidates) |
| except KeyError: |
| lbl_col = self._find_first_present_column(test_set.column_names, label_candidates) |
|
|
| train_set = train_set.rename_columns({a_col: 'SeqA', b_col: 'SeqB', lbl_col: 'labels'}) |
| valid_set = valid_set.rename_columns({a_col: 'SeqA', b_col: 'SeqB', lbl_col: 'labels'}) |
| test_set = test_set.rename_columns({a_col: 'SeqA', b_col: 'SeqB', lbl_col: 'labels'}) |
|
|
| print('Removing extras') |
| train_set = train_set.remove_columns([col for col in train_set.column_names if col not in ['SeqA', 'SeqB', 'labels']]) |
| valid_set = valid_set.remove_columns([col for col in valid_set.column_names if col not in ['SeqA', 'SeqB', 'labels']]) |
| test_set = test_set.remove_columns([col for col in test_set.column_names if col not in ['SeqA', 'SeqB', 'labels']]) |
| else: |
| print('Standardizing column names') |
| use_multi = self.data_args.multi_column is not None |
| if not use_multi: |
| try: |
| seq_col = self._find_first_present_column(train_set.column_names, seq_candidates) |
| except KeyError: |
| try: |
| seq_col = self._find_first_present_column(valid_set.column_names, seq_candidates) |
| except KeyError: |
| seq_col = self._find_first_present_column(test_set.column_names, seq_candidates) |
|
|
| try: |
| label_col = self._find_first_present_column(train_set.column_names, label_candidates) |
| except KeyError: |
| try: |
| label_col = self._find_first_present_column(valid_set.column_names, label_candidates) |
| except KeyError: |
| label_col = self._find_first_present_column(test_set.column_names, label_candidates) |
|
|
| |
| train_set = train_set.rename_columns({label_col: 'labels'}) |
| valid_set = valid_set.rename_columns({label_col: 'labels'}) |
| test_set = test_set.rename_columns({label_col: 'labels'}) |
|
|
| if not use_multi: |
| train_set = train_set.rename_columns({seq_col: 'seqs'}) |
| valid_set = valid_set.rename_columns({seq_col: 'seqs'}) |
| test_set = test_set.rename_columns({seq_col: 'seqs'}) |
| |
| print('Removing extras') |
| train_set = train_set.remove_columns([col for col in train_set.column_names if col not in ['seqs', 'labels']]) |
| valid_set = valid_set.remove_columns([col for col in valid_set.column_names if col not in ['seqs', 'labels']]) |
| test_set = test_set.remove_columns([col for col in test_set.column_names if col not in ['seqs', 'labels']]) |
| else: |
| |
| for col in self.data_args.multi_column: |
| assert col in train_set.column_names or col in valid_set.column_names or col in test_set.column_names, f"Column {col} not found in dataset {data_name}" |
| |
| keep_cols = set(self.data_args.multi_column + ['labels']) |
| train_set = train_set.remove_columns([col for col in train_set.column_names if col not in keep_cols]) |
| valid_set = valid_set.remove_columns([col for col in valid_set.column_names if col not in keep_cols]) |
| test_set = test_set.remove_columns([col for col in test_set.column_names if col not in keep_cols]) |
|
|
| datasets.append((train_set, valid_set, test_set, ppi)) |
| data_names.append(data_name) |
|
|
| return self.process_datasets(hf_datasets=datasets, data_names=data_names) |
|
|
| def get_embedding_dim_sql(self, save_path, test_seq, tokenizer): |
| test_seq_len = len(tokenizer(test_seq, return_tensors='pt')['input_ids'][0]) |
| |
| with sqlite3.connect(save_path) as conn: |
| c = conn.cursor() |
| c.execute("SELECT embedding FROM embeddings WHERE sequence = ?", (test_seq,)) |
| test_embedding = c.fetchone()[0] |
| test_embedding = torch.tensor(np.frombuffer(test_embedding, dtype=np.float32).reshape(1, -1)) |
| if self._full: |
| test_embedding = test_embedding.reshape(test_seq_len, -1) |
| embedding_dim = test_embedding.shape[-1] |
| return embedding_dim |
|
|
| def get_embedding_dim_pth(self, emb_dict, test_seq, tokenizer): |
| test_seq_len = len(tokenizer(test_seq, return_tensors='pt')['input_ids'][0]) |
| test_embedding = emb_dict[test_seq] |
| if self._full: |
| test_embedding = test_embedding.reshape(test_seq_len, -1) |
| embedding_dim = test_embedding.shape[-1] |
| return embedding_dim |
|
|
| def build_vector_numpy_dataset_from_embeddings( |
| self, |
| model_name, |
| train_seqs, |
| valid_seqs, |
| test_seqs, |
| ): |
| save_dir = self.embedding_args.embedding_save_dir |
| train_array, valid_array, test_array = [], [], [] |
| |
| pooling_types = self.embedding_args.pooling_types |
| if self._sql: |
| filename = get_embedding_filename(model_name, self._full, pooling_types, 'db') |
| save_path = os.path.join(save_dir, filename) |
| with sqlite3.connect(save_path) as conn: |
| c = conn.cursor() |
| for seq in train_seqs: |
| embedding = self._select_from_sql(c, seq, cast_to_torch=False) |
| train_array.append(embedding) |
|
|
| for seq in valid_seqs: |
| embedding = self._select_from_sql(c, seq, cast_to_torch=False) |
| valid_array.append(embedding) |
|
|
| for seq in test_seqs: |
| embedding = self._select_from_sql(c, seq, cast_to_torch=False) |
| test_array.append(embedding) |
| else: |
| filename = get_embedding_filename(model_name, self._full, pooling_types, 'pth') |
| save_path = os.path.join(save_dir, filename) |
| emb_dict = torch.load(save_path) |
| for seq in train_seqs: |
| embedding = self._select_from_pth(emb_dict, seq, cast_to_np=True) |
| train_array.append(embedding) |
| |
| for seq in valid_seqs: |
| embedding = self._select_from_pth(emb_dict, seq, cast_to_np=True) |
| valid_array.append(embedding) |
|
|
| for seq in test_seqs: |
| embedding = self._select_from_pth(emb_dict, seq, cast_to_np=True) |
| test_array.append(embedding) |
| del emb_dict |
|
|
| train_array = np.concatenate(train_array, axis=0) |
| valid_array = np.concatenate(valid_array, axis=0) |
| test_array = np.concatenate(test_array, axis=0) |
| |
| if self._full: |
| train_array = np.mean(train_array, axis=1) |
| valid_array = np.mean(valid_array, axis=1) |
| test_array = np.mean(test_array, axis=1) |
|
|
| print_message('Numpy dataset shapes') |
| print_message(f'Train: {train_array.shape}') |
| print_message(f'Valid: {valid_array.shape}') |
| print_message(f'Test: {test_array.shape}') |
| return train_array, valid_array, test_array |
|
|
| def build_pair_vector_numpy_dataset_from_embeddings( |
| self, |
| model_name, |
| train_seqs_a, |
| train_seqs_b, |
| valid_seqs_a, |
| valid_seqs_b, |
| test_seqs_a, |
| test_seqs_b, |
| ): |
| save_dir = self.embedding_args.embedding_save_dir |
| train_array, valid_array, test_array = [], [], [] |
| pooling_types = self.embedding_args.pooling_types |
| if self._sql: |
| filename = get_embedding_filename(model_name, self._full, pooling_types, 'db') |
| save_path = os.path.join(save_dir, filename) |
| with sqlite3.connect(save_path) as conn: |
| c = conn.cursor() |
| for seq_a, seq_b in zip(train_seqs_a, train_seqs_b): |
| seq_a, seq_b = self._random_order(seq_a, seq_b) |
| embedding_a = self._select_from_sql(c, seq_a, cast_to_torch=False) |
| embedding_b = self._select_from_sql(c, seq_b, cast_to_torch=False) |
| train_array.append(np.concatenate([embedding_a, embedding_b], axis=-1)) |
|
|
| for seq_a, seq_b in zip(valid_seqs_a, valid_seqs_b): |
| seq_a, seq_b = self._random_order(seq_a, seq_b) |
| embedding_a = self._select_from_sql(c, seq_a, cast_to_torch=False) |
| embedding_b = self._select_from_sql(c, seq_b, cast_to_torch=False) |
| valid_array.append(np.concatenate([embedding_a, embedding_b], axis=-1)) |
|
|
| for seq_a, seq_b in zip(test_seqs_a, test_seqs_b): |
| seq_a, seq_b = self._random_order(seq_a, seq_b) |
| embedding_a = self._select_from_sql(c, seq_a, cast_to_torch=False) |
| embedding_b = self._select_from_sql(c, seq_b, cast_to_torch=False) |
| test_array.append(np.concatenate([embedding_a, embedding_b], axis=-1)) |
| else: |
| filename = get_embedding_filename(model_name, self._full, pooling_types, 'pth') |
| save_path = os.path.join(save_dir, filename) |
| emb_dict = torch.load(save_path) |
| for seq_a, seq_b in zip(train_seqs_a, train_seqs_b): |
| seq_a, seq_b = self._random_order(seq_a, seq_b) |
| embedding_a = self._select_from_pth(emb_dict, seq_a, cast_to_np=True) |
| embedding_b = self._select_from_pth(emb_dict, seq_b, cast_to_np=True) |
| train_array.append(np.concatenate([embedding_a, embedding_b], axis=-1)) |
|
|
| for seq_a, seq_b in zip(valid_seqs_a, valid_seqs_b): |
| seq_a, seq_b = self._random_order(seq_a, seq_b) |
| embedding_a = self._select_from_pth(emb_dict, seq_a, cast_to_np=True) |
| embedding_b = self._select_from_pth(emb_dict, seq_b, cast_to_np=True) |
| valid_array.append(np.concatenate([embedding_a, embedding_b], axis=-1)) |
|
|
| for seq_a, seq_b in zip(test_seqs_a, test_seqs_b): |
| seq_a, seq_b = self._random_order(seq_a, seq_b) |
| embedding_a = self._select_from_pth(emb_dict, seq_a, cast_to_np=True) |
| embedding_b = self._select_from_pth(emb_dict, seq_b, cast_to_np=True) |
| test_array.append(np.concatenate([embedding_a, embedding_b], axis=-1)) |
| del emb_dict |
|
|
| train_array = np.concatenate(train_array, axis=0) |
| valid_array = np.concatenate(valid_array, axis=0) |
| test_array = np.concatenate(test_array, axis=0) |
| |
| if self._full: |
| train_array = np.mean(train_array, axis=1) |
| valid_array = np.mean(valid_array, axis=1) |
| test_array = np.mean(test_array, axis=1) |
|
|
| print_message('Numpy dataset shapes') |
| print_message(f'Train: {train_array.shape}') |
| print_message(f'Valid: {valid_array.shape}') |
| print_message(f'Test: {test_array.shape}') |
| return train_array, valid_array, test_array |
|
|
| def prepare_scikit_dataset(self, model_name, dataset): |
| train_set, valid_set, test_set, _, label_type, ppi = dataset |
|
|
| if ppi: |
| X_train, X_valid, X_test = self.build_pair_vector_numpy_dataset_from_embeddings( |
| model_name, |
| list(train_set['SeqA']), |
| list(train_set['SeqB']), |
| list(valid_set['SeqA']), |
| list(valid_set['SeqB']), |
| list(test_set['SeqA']), |
| list(test_set['SeqB']), |
| ) |
| else: |
| X_train, X_valid, X_test = self.build_vector_numpy_dataset_from_embeddings( |
| model_name, |
| list(train_set['seqs']), |
| list(valid_set['seqs']), |
| list(test_set['seqs']), |
| ) |
|
|
| y_train = self._labels_to_numpy(list(train_set['labels'])) |
| y_valid = self._labels_to_numpy(list(valid_set['labels'])) |
| y_test = self._labels_to_numpy(list(test_set['labels'])) |
|
|
| print_message('Numpy dataset shapes with labels') |
| print_message(f'Train: {X_train.shape}, {y_train.shape}') |
| print_message(f'Valid: {X_valid.shape}, {y_valid.shape}') |
| print_message(f'Test: {X_test.shape}, {y_test.shape}') |
| return X_train, y_train, X_valid, y_valid, X_test, y_test, label_type |
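

# Sketch of the intended end-to-end flow (illustrative; assumes an object that
# mixes in DataMixin and also provides `embedding_args` with embeddings already
# saved to disk; 'my_plm' is a hypothetical embedding model name):
#
#   datasets, all_seqs = loader.get_data()
#   for name, dataset in datasets.items():
#       X_train, y_train, X_valid, y_valid, X_test, y_test, label_type = \
#           loader.prepare_scikit_dataset('my_plm', dataset)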
|
|