Spaces:
Running
Running
| """ | |
| File: CodonData.py | |
| --------------------- | |
| Includes helper functions for preprocessing NCBI or Kazusa databases and | |
| preparing the data for training and inference of the CodonTransformer model. | |
| """ | |
| import json | |
| import os | |
| import random | |
| from typing import Dict, List, Optional, Tuple, Union | |
| import pandas as pd | |
| import python_codon_tables as pct | |
| from Bio import SeqIO | |
| from Bio.Seq import Seq | |
| from sklearn.utils import shuffle as sk_shuffle | |
| from tqdm import tqdm | |
| from CodonTransformer.CodonUtils import ( | |
| AMBIGUOUS_AMINOACID_MAP, | |
| AMINO2CODON_TYPE, | |
| AMINO_ACIDS, | |
| ORGANISM2ID, | |
| START_CODONS, | |
| STOP_CODONS, | |
| STOP_SYMBOL, | |
| STOP_SYMBOLS, | |
| ProteinConfig, | |
| find_pattern_in_fasta, | |
| get_taxonomy_id, | |
| sort_amino2codon_skeleton, | |
| ) | |
def prepare_training_data(
    dataset: Union[str, pd.DataFrame], output_file: str, shuffle: bool = True
) -> None:
    """
    Prepare a JSON dataset for training the CodonTransformer model.

    Input dataset should have columns below:
        - dna: str (DNA sequence)
        - protein: str (Protein sequence)
        - organism: Union[int, str] (ID or Name of the organism)

    The output JSON dataset will have the following format:
        {"idx": 0, "codons": "M_ATG R_AGG L_TTG L_CTA R_CGA __TAG", "organism": 51}
        {"idx": 1, "codons": "M_ATG K_AAG C_TGC F_TTT F_TTC __TAA", "organism": 59}

    A DataFrame passed in is not modified. The merged codon tokens and the
    resolved organism IDs are collected in a separate frame that shares the
    input's index; dataframe_to_json writes those index labels as "idx".

    Args:
        dataset (Union[str, pd.DataFrame]): Path to a CSV file, or a DataFrame.
        output_file (str): Path to save the output JSON dataset.
        shuffle (bool, optional): Whether to shuffle the dataset before saving.
            Defaults to True.

    Returns:
        None

    Raises:
        ValueError: If a required column is missing, or if a row's sequences
            or organism are invalid (raised by get_merged_seq / process_organism).
        TypeError: If an organism value is neither a name nor an integer ID
            (raised by process_organism).
    """
    if isinstance(dataset, str):
        dataset = pd.read_csv(dataset)

    required_columns = {"dna", "protein", "organism"}
    if not required_columns.issubset(dataset.columns):
        raise ValueError(f"Input dataset must have columns: {required_columns}")

    # Build the training frame separately instead of assigning new columns onto
    # `dataset`. This leaves a caller-supplied DataFrame untouched and avoids
    # copying columns that are not needed for the output.
    prepared = pd.DataFrame(index=dataset.index)

    # Merge protein and DNA into "AA_CODON" tokens. result_type="reduce" keeps
    # the result a Series even when the dataset has no rows.
    prepared["codons"] = dataset.apply(
        lambda row: get_merged_seq(row["protein"], row["dna"], separator="_"),
        axis=1,
        result_type="reduce",
    )

    # Replace organism names with IDs (or validate given IDs) using ORGANISM2ID.
    prepared["organism"] = dataset["organism"].apply(
        lambda org: process_organism(org, ORGANISM2ID)
    )

    # Save the dataset to a JSON Lines file.
    dataframe_to_json(prepared, output_file, shuffle=shuffle)
def dataframe_to_json(df: pd.DataFrame, output_file: str, shuffle: bool = True) -> None:
    """
    Write a DataFrame to a JSON Lines file for CodonTransformer training.

    Each row becomes one JSON object on its own line, with the keys "idx"
    (the row's index label), "codons" and "organism", in that order.

    Args:
        df (pd.DataFrame): Frame containing at least the 'codons' and
            'organism' columns; any other columns are ignored.
        output_file (str): Destination path; an existing file is overwritten.
        shuffle (bool, optional): If True, rows are shuffled with
            sklearn.utils.shuffle before writing. Defaults to True.

    Returns:
        None

    Raises:
        ValueError: If 'codons' or 'organism' is missing from ``df``.
    """
    required_columns = {"codons", "organism"}
    if not required_columns.issubset(df.columns):
        raise ValueError(f"DataFrame must contain columns: {required_columns}")

    print(f"\nStarted writing to {output_file}...")

    records = sk_shuffle(df) if shuffle else df
    total = len(records)

    with open(output_file, "w") as handle:
        progress = tqdm(
            records.iterrows(), total=total, desc="Writing JSON...", unit=" records"
        )
        for label, row in progress:
            payload = {
                "idx": label,
                "codons": row["codons"],
                "organism": row["organism"],
            }
            handle.write(json.dumps(payload) + "\n")

    print(f"\nTotal Entries Saved: {total}, JSON data saved to {output_file}")
def process_organism(organism: Union[str, int], organism_to_id: Dict[str, int]) -> int:
    """
    Process and validate the organism input, converting it to a valid organism ID.

    A string is treated as an organism name and looked up in ``organism_to_id``.
    An integer (a built-in ``int`` or any ``numbers.Integral`` such as a NumPy
    integer scalar, but not ``bool``) is treated as an ID and must be one of
    the mapping's values.

    Args:
        organism (Union[str, int]): Input organism, either as a name (str) or ID (int).
        organism_to_id (Dict[str, int]): Dictionary mapping organism names to their
            corresponding IDs.

    Returns:
        int: The validated organism ID, always as a built-in ``int``.

    Raises:
        ValueError: If the input is an unknown organism name or ID.
        TypeError: If the input is neither a string nor an integer (bool is
            rejected as well).
    """
    import numbers  # local import: only needed for the integer-like type check

    if isinstance(organism, str):
        if organism not in organism_to_id:
            raise ValueError(f"Invalid organism name: {organism}")
        return organism_to_id[organism]

    # numbers.Integral covers built-in ints and NumPy integer scalars (e.g.
    # values taken from an int64 DataFrame column). bool is an Integral
    # subclass but is never a meaningful organism ID, so it is excluded.
    if isinstance(organism, numbers.Integral) and not isinstance(organism, bool):
        organism_id = int(organism)
        if organism_id not in organism_to_id.values():
            raise ValueError(f"Invalid organism ID: {organism}")
        return organism_id

    raise TypeError(
        f"Organism must be a string or integer, not {type(organism).__name__}"
    )
def preprocess_protein_sequence(protein: str) -> str:
    """
    Preprocess a protein sequence by cleaning, standardizing, and handling
    ambiguous amino acids.

    Steps:
        1. Upper-case the sequence and remove every whitespace character
           (spaces, tabs, newlines and carriage returns).
        2. Resolve ambiguous amino acids according to the ProteinConfig
           setting "ambiguous_aminoacid_behavior" ("raise_error",
           "standardize_deterministic" or "standardize_random"), using
           AMBIGUOUS_AMINOACID_MAP updated with the config's
           "ambiguous_aminoacid_map_override".
        3. Check that every symbol is in AMINO_ACIDS or STOP_SYMBOLS.
        4. Replace a trailing "*" with STOP_SYMBOL and append STOP_SYMBOL if
           the sequence does not already end with it.

    Args:
        protein (str): The input protein sequence.

    Returns:
        str: The preprocessed protein sequence, ending with STOP_SYMBOL.

    Raises:
        ValueError: If the sequence is empty (or only whitespace), contains
            ambiguous amino acids under "raise_error", contains invalid
            characters, or if the configured behavior is unknown.
    """
    if not protein:
        raise ValueError("Protein sequence is empty.")

    # Clean and standardize. str.split() with no argument splits on all
    # whitespace, so this also drops "\r" left over from CRLF line endings.
    protein = "".join(protein.upper().split())
    if not protein:
        # Whitespace-only input is empty once cleaned; without this guard the
        # later protein[-1]-style checks would fail with an IndexError.
        raise ValueError("Protein sequence is empty.")

    # Handle ambiguous amino acids based on the specified behavior
    config = ProteinConfig()
    ambiguous_aminoacid_map_override = config.get("ambiguous_aminoacid_map_override")
    ambiguous_aminoacid_behavior = config.get("ambiguous_aminoacid_behavior")

    # Start from the default map and apply user overrides on top of it.
    ambiguous_aminoacid_map = AMBIGUOUS_AMINOACID_MAP.copy()
    for aminoacid, standard_aminoacids in ambiguous_aminoacid_map_override.items():
        ambiguous_aminoacid_map[aminoacid] = standard_aminoacids

    if ambiguous_aminoacid_behavior == "raise_error":
        if any(aminoacid in ambiguous_aminoacid_map for aminoacid in protein):
            raise ValueError("Ambiguous amino acids found in protein sequence.")
    elif ambiguous_aminoacid_behavior == "standardize_deterministic":
        # Always pick the first listed standard amino acid.
        protein = "".join(
            ambiguous_aminoacid_map.get(aminoacid, [aminoacid])[0]
            for aminoacid in protein
        )
    elif ambiguous_aminoacid_behavior == "standardize_random":
        # Pick one of the listed standard amino acids at random.
        protein = "".join(
            random.choice(ambiguous_aminoacid_map.get(aminoacid, [aminoacid]))
            for aminoacid in protein
        )
    else:
        raise ValueError(
            f"Invalid ambiguous_aminoacid_behavior: {ambiguous_aminoacid_behavior}."
        )

    # Check for sequence validity. Build the lookup set once instead of
    # concatenating AMINO_ACIDS + STOP_SYMBOLS for every residue. Because every
    # character is checked here, the last one is valid too.
    valid_symbols = set(AMINO_ACIDS) | set(STOP_SYMBOLS)
    if any(aminoacid not in valid_symbols for aminoacid in protein):
        raise ValueError("Invalid characters in protein sequence.")

    # Replace '*' at the end of protein with STOP_SYMBOL if present
    if protein.endswith("*"):
        protein = protein[:-1] + STOP_SYMBOL

    # Add stop symbol to end of protein
    if not protein.endswith(STOP_SYMBOL):
        protein += STOP_SYMBOL

    return protein
def replace_ambiguous_codons(dna: str) -> str:
    """
    Replace every codon that is not made only of A/C/G/T with "UNK".

    The sequence is upper-cased and cut into consecutive 3-character chunks.
    A chunk is kept if it is a full codon of A, T, C and G only. Otherwise,
    including a trailing partial codon shorter than 3 characters, it becomes
    "UNK". The result therefore always has a length that is a multiple of 3.

    Args:
        dna (str): The DNA sequence to process.

    Returns:
        str: The processed DNA sequence with ambiguous codons replaced by "UNK".
    """
    nucleotides = frozenset("ATCG")
    sequence = dna.upper()
    chunks = (sequence[start : start + 3] for start in range(0, len(sequence), 3))
    return "".join(
        chunk if len(chunk) == 3 and set(chunk) <= nucleotides else "UNK"
        for chunk in chunks
    )
def preprocess_dna_sequence(dna: str) -> str:
    """
    Clean and preprocess a DNA sequence by standardizing it and replacing
    ambiguous codons.

    Steps:
        1. Upper-case the sequence and remove every whitespace character
           (spaces, tabs, newlines and carriage returns).
        2. Replace codons that are not made only of A/C/G/T with "UNK"
           (see replace_ambiguous_codons).
        3. Append "UNK" as a placeholder stop codon if the last codon is not
           in STOP_CODONS.

    Args:
        dna (str): The DNA sequence to preprocess.

    Returns:
        str: The cleaned and preprocessed DNA sequence, or "" if the input is
        empty or contains only whitespace.
    """
    if not dna:
        return ""

    # Clean and standardize. str.split() with no argument splits on all
    # whitespace, so an embedded "\r" from CRLF files is removed too; left in
    # place it would shift the reading frame of every following codon.
    dna = "".join(dna.upper().split())
    if not dna:
        # Whitespace-only input counts as "no DNA", same as an empty string.
        return ""

    # Replace codons with ambiguous nucleotides with "UNK"
    dna = replace_ambiguous_codons(dna)

    # Add unknown stop codon to end of DNA sequence if not present
    if dna[-3:] not in STOP_CODONS:
        dna += "UNK"

    return dna
def get_merged_seq(protein: str, dna: str = "", separator: str = "_") -> str:
    """
    Return the merged sequence of protein amino acids and DNA codons in the form
    of tokens separated by space, where each token is composed of an amino acid +
    separator + codon.

    Both sequences are first normalized with preprocess_protein_sequence and
    preprocess_dna_sequence. When no DNA is given, every codon is "UNK".

    Args:
        protein (str): Protein sequence.
        dna (str): DNA sequence.
        separator (str): Separator between amino acid and codon.

    Returns:
        str: Merged sequence.

    Raises:
        ValueError: If DNA is given and its number of codons (including the
            stop codon) differs from the protein length (including the stop
            symbol), or if preprocessing the protein fails.

    Example:
        >>> get_merged_seq(protein="MAV_", dna="ATGGCTGTGTAA", separator="_")
        'M_ATG A_GCT V_GTG __TAA'

        >>> get_merged_seq(protein="QHH_", dna="", separator="_")
        'Q_UNK H_UNK H_UNK __UNK'
    """
    # Prepare protein and dna sequences
    dna = preprocess_dna_sequence(dna)
    protein = preprocess_protein_sequence(protein)

    # Check that there is exactly one codon per protein symbol. Integer
    # arithmetic avoids comparing an int against a float quotient.
    if dna and len(dna) != 3 * len(protein):
        raise ValueError(
            'Length of protein (including stop symbol such as "_") and '
            "the number of codons in DNA sequence (including stop codon) "
            "must be equal."
        )

    if dna:
        codons = [dna[start : start + 3] for start in range(0, len(dna), 3)]
    else:
        codons = ["UNK"] * len(protein)

    # Merge protein and DNA sequences into space-separated tokens.
    return " ".join(
        f"{aminoacid}{separator}{codon}" for aminoacid, codon in zip(protein, codons)
    )
def is_correct_seq(dna: str, protein: str, stop_symbol: str = STOP_SYMBOL) -> bool:
    """
    Check if the given DNA and protein pair is correct, that is:
        1. The length of dna is divisible by 3
        2. There is an initiator codon (one of START_CODONS) at the beginning of dna
        3. The protein ends with the stop symbol
        4. The stop symbol occurs exactly once in the protein, so the only
           stop is the last position
        5. The DNA contains only the nucleotides A, C, G and T (case-insensitive)

    Note: stops are detected through the translated protein (``stop_symbol``),
    not by scanning the DNA for stop codons. Codon-table-specific reassignments
    (e.g. TGA read as Tryptophan in Codon Table 3) are therefore respected, as
    long as the protein was translated with the matching table.

    Args:
        dna (str): DNA sequence (a str or a Bio.Seq.Seq).
        protein (str): Protein sequence.
        stop_symbol (str): Stop symbol.

    Returns:
        bool: True if the sequence is correct, False otherwise.
    """
    dna_upper = str(dna).upper()
    return (
        len(dna_upper) % 3 == 0  # DNA length is divisible by 3
        and dna_upper[:3] in START_CODONS  # Initiator codon at the start
        and protein.endswith(stop_symbol)  # Last protein symbol is the stop symbol
        and protein.count(stop_symbol) == 1  # Only one stop symbol overall
        # Only A, C, G and T may appear. The previous `len(set(dna)) == 4`
        # accepted e.g. "ATGNNNTAA" and rejected valid sequences that happen
        # to lack one of the four nucleotides.
        and set(dna_upper) <= set("ACGT")
    )
def get_amino_acid_sequence(
    dna: str,
    stop_symbol: str = "_",
    codon_table: int = 1,
    return_correct_seq: bool = False,
) -> Union[str, Tuple[str, bool]]:
    """
    Translate a DNA sequence into its protein sequence with Biopython.

    Surrounding whitespace is stripped from the DNA before translation and
    from the resulting protein afterwards. The whole sequence is translated
    (to_stop=False, cds=False), so internal stop codons appear in the output
    as ``stop_symbol``.

    Args:
        dna (str): DNA sequence.
        stop_symbol (str): Symbol emitted for stop codons.
        codon_table (int): NCBI codon table number passed to Seq.translate.
        return_correct_seq (bool): If True, also return the is_correct_seq
            verdict for the stripped DNA and the translated protein.

    Returns:
        Union[str, Tuple[str, bool]]: The protein sequence, or a
        ``(protein, is_correct)`` tuple when return_correct_seq is True.
    """
    stripped_dna = Seq(dna).strip()
    translation = stripped_dna.translate(
        table=codon_table,
        stop_symbol=stop_symbol,
        to_stop=False,
        cds=False,
    )
    protein_seq = str(translation).strip()

    if not return_correct_seq:
        return protein_seq
    return protein_seq, is_correct_seq(stripped_dna, protein_seq, stop_symbol)
def read_fasta_file(
    input_file: str,
    save_to_file: Optional[str] = None,
    organism: str = "",
    buffer_size: int = 50000,
) -> pd.DataFrame:
    """
    Read a FASTA file of DNA sequences and convert it to a Pandas DataFrame.
    Optionally, also append the rows to a CSV file in chunks.

    For each record the DNA is upper-cased, translated with the organism's
    codon table (get_codon_table), checked with is_correct_seq (via
    get_amino_acid_sequence) and tokenized with get_merged_seq.

    Args:
        input_file (str): Path to the input FASTA file.
        save_to_file (Optional[str]): Path of a CSV file to append rows to
            (through write_buffer_to_csv). If None, nothing is written.
        organism (str): Name of the organism. If empty, it is extracted from
            each record's FASTA description.
        buffer_size (int): Number of records to accumulate before each CSV write.

    Returns:
        pd.DataFrame: One row per FASTA record with the columns "dna",
        "protein", "correct_seq", "organism", "GeneID", "description" and
        "tokenized". It is returned whether or not save_to_file is set.

    Raises:
        FileNotFoundError: If the input file does not exist.
        ValueError: Propagated from get_merged_seq when a record's protein and
            DNA cannot be tokenized together.
    """
    if not os.path.exists(input_file):
        raise FileNotFoundError(f"Input file not found: {input_file}")

    columns = [
        "dna",
        "protein",
        "correct_seq",
        "organism",
        "GeneID",
        "description",
        "tokenized",
    ]
    buffer = []  # rows not yet flushed to save_to_file
    rows = []  # every parsed row; turned into a DataFrame once at the end

    with open(input_file, "r") as fasta_file:
        for record in tqdm(
            SeqIO.parse(fasta_file, "fasta"),
            desc=f"Processing {organism}",
            unit=" Records",
        ):
            dna = str(record.seq).strip().upper()  # Ensure uppercase DNA sequence

            # Determine the organism from the record if not provided
            current_organism = organism or find_pattern_in_fasta(
                "organism", record.description
            )
            gene_id = find_pattern_in_fasta("GeneID", record.description)

            # Get the appropriate codon table for the organism
            codon_table = get_codon_table(current_organism)

            # Translate DNA to protein sequence
            protein, correct_seq = get_amino_acid_sequence(
                dna,
                stop_symbol=STOP_SYMBOL,
                codon_table=codon_table,
                return_correct_seq=True,
            )
            # Description text before the first "[" (bracketed tags follow it).
            description = record.description.split("[", 1)[0].strip()
            tokenized = get_merged_seq(protein, dna, separator=STOP_SYMBOL)

            data_row = {
                "dna": dna,
                "protein": protein,
                "correct_seq": correct_seq,
                "organism": current_organism,
                "GeneID": gene_id,
                "description": description,
                "tokenized": tokenized,
            }
            buffer.append(data_row)
            rows.append(data_row)

            # Write buffer to CSV file when buffer size is reached
            if save_to_file and len(buffer) >= buffer_size:
                write_buffer_to_csv(buffer, save_to_file, columns)
                buffer = []

    # Write remaining buffer to CSV file
    if save_to_file and buffer:
        write_buffer_to_csv(buffer, save_to_file, columns)

    # Build the DataFrame once. Concatenating one row at a time is O(n^2).
    return pd.DataFrame(rows, columns=columns)
def write_buffer_to_csv(buffer: List[Dict], output_path: str, columns: List[str]):
    """
    Append a list of row dicts to a CSV file.

    The header row is written only when ``output_path`` does not exist yet.
    The buffer's own 0-based index is written as the first column.

    Args:
        buffer (List[Dict]): Rows to write, one dict per row.
        output_path (str): CSV file to append to (created if missing).
        columns (List[str]): Column order for the written rows.
    """
    needs_header = not os.path.exists(output_path)
    frame = pd.DataFrame(buffer, columns=columns)
    frame.to_csv(output_path, mode="a", header=needs_header, index=True)
def download_codon_frequencies_from_kazusa(
    taxonomy_id: Optional[int] = None,
    organism: Optional[str] = None,
    taxonomy_reference: Optional[str] = None,
    return_original_format: bool = False,
) -> AMINO2CODON_TYPE:
    """
    Return the codon table of the given taxonomy ID from the Kazusa Database.

    The table is fetched with ``python_codon_tables.get_codons_table``. If
    ``taxonomy_reference`` is given, the taxonomy ID is first resolved from it
    (and ``organism``) with get_taxonomy_id, overriding ``taxonomy_id``.

    Args:
        taxonomy_id (Optional[int]): Taxonomy ID.
        organism (Optional[str]): Name of the organism (used only together
            with taxonomy_reference).
        taxonomy_reference (Optional[str]): Taxonomy reference.
        return_original_format (bool): If True, return the table exactly as
            python_codon_tables provides it.

    Returns:
        AMINO2CODON_TYPE: Mapping amino acid -> (codons, frequencies), with the
        stop entry keyed by STOP_SYMBOL instead of "*", passed through
        sort_amino2codon_skeleton.
    """
    if taxonomy_reference:
        taxonomy_id = get_taxonomy_id(taxonomy_reference, organism=organism)

    kazusa_amino2codon = pct.get_codons_table(table_name=taxonomy_id)
    if return_original_format:
        return kazusa_amino2codon

    # Build a new dict rather than popping "*" out of the fetched table.
    # python_codon_tables may hand back a cached object; mutating it would make
    # a later request for the same table miss its "*" entry.
    amino2codon = {
        (STOP_SYMBOL if aminoacid == "*" else aminoacid): (
            list(codon2freq.keys()),
            list(codon2freq.values()),
        )
        for aminoacid, codon2freq in kazusa_amino2codon.items()
    }

    return sort_amino2codon_skeleton(amino2codon)
def build_amino2codon_skeleton(organism: str) -> AMINO2CODON_TYPE:
    """
    Return the empty skeleton of the amino2codon dictionary, needed for
    get_codon_frequencies.

    All 64 codons are translated with the organism's codon table (see
    get_codon_table) and grouped by the symbol they translate to. Every codon
    starts with a count of 0, and the result is passed through
    sort_amino2codon_skeleton.

    Args:
        organism (str): Name of the organism; selects the codon table.

    Returns:
        AMINO2CODON_TYPE: amino acid -> (codons, zero counts).
    """
    possible_codons = [f"{i}{j}{k}" for i in "ACGT" for j in "ACGT" for k in "ACGT"]

    # Translating the 64 codons back to back yields one symbol per codon.
    possible_aminoacids = get_amino_acid_sequence(
        dna="".join(possible_codons),
        codon_table=get_codon_table(organism),
        return_correct_seq=False,
    )

    # Group codons under their amino acid, each with an initial count of 0.
    amino2codon = {}
    for codon, amino in zip(possible_codons, possible_aminoacids):
        codons, frequencies = amino2codon.setdefault(amino, ([], []))
        codons.append(codon)
        frequencies.append(0)

    # Sort the dictionary and each list of codon frequency alphabetically
    return sort_amino2codon_skeleton(amino2codon)
def get_codon_frequencies(
    dna_sequences: List[str],
    protein_sequences: Optional[List[str]] = None,
    organism: Optional[str] = None,
) -> AMINO2CODON_TYPE:
    """
    Return a dictionary mapping each amino acid to its codons and their
    relative usage frequencies in the given DNA/protein sequences.

    If ``organism`` is given, the DNA sequences are translated with that
    organism's codon table and any ``protein_sequences`` passed in are
    ignored. Codon i of each DNA sequence is counted under amino acid i of
    the matching protein sequence.

    Args:
        dna_sequences (List[str]): List of DNA sequences.
        protein_sequences (Optional[List[str]]): List of protein sequences,
            aligned with dna_sequences. Required when organism is not given.
        organism (Optional[str]): Name of the organism.

    Returns:
        AMINO2CODON_TYPE: Dictionary mapping each amino acid to a tuple of
        codons and frequencies. The frequencies of each amino acid sum to 1,
        or are all 0 if it never occurs.

    Raises:
        ValueError: If neither protein_sequences nor organism is provided, or
            if a codon is not listed under its amino acid in the skeleton.
        KeyError: If a protein symbol has no entry in the skeleton.
    """
    if organism:
        codon_table = get_codon_table(organism)
        protein_sequences = [
            get_amino_acid_sequence(
                dna, codon_table=codon_table, return_correct_seq=False
            )
            for dna in dna_sequences
        ]
    elif protein_sequences is None:
        # Fail with a clear message instead of an opaque error from zip(..., None).
        raise ValueError("Either protein_sequences or organism must be provided.")

    amino2codon = build_amino2codon_skeleton(organism)

    # Count the frequencies of each codon for each amino acid
    for dna, protein in zip(dna_sequences, protein_sequences):
        for i, amino in enumerate(protein):
            codon = dna[i * 3 : (i + 1) * 3]
            codon_loc = amino2codon[amino][0].index(codon)
            amino2codon[amino][1][codon_loc] += 1

    # Normalize codon frequencies per amino acid so they sum to 1. The total is
    # computed once per amino acid. The tiny epsilon leaves unseen amino acids
    # at all-zero frequencies instead of dividing by zero.
    normalized = {}
    for amino, (codons, frequencies) in amino2codon.items():
        total = sum(frequencies) + 1e-100
        normalized[amino] = (codons, [freq / total for freq in frequencies])
    return normalized
def get_organism_to_codon_frequencies(
    dataset: pd.DataFrame, organisms: List[str]
) -> Dict[str, AMINO2CODON_TYPE]:
    """
    Compute a codon frequency distribution for each requested organism.

    For every name in ``organisms``, the rows of ``dataset`` whose "organism"
    column equals that name are selected. Their "dna" and "protein" columns
    are then passed to get_codon_frequencies.

    Args:
        dataset (pd.DataFrame): DataFrame with "organism", "dna" and
            "protein" columns.
        organisms (List[str]): Organism names to compute distributions for.

    Returns:
        Dict[str, AMINO2CODON_TYPE]: Organism name -> codon frequency
        distribution, in the order the names were given.
    """

    def frequencies_for(name):
        # Restrict to this organism's rows before counting codons.
        subset = dataset.loc[dataset["organism"] == name]
        return get_codon_frequencies(
            subset["dna"].to_list(), subset["protein"].to_list()
        )

    return {
        name: frequencies_for(name)
        for name in tqdm(
            organisms, desc="Calculating Codon Frequencies: ", unit="Organism"
        )
    }
def get_codon_table(organism: str) -> int:
    """
    Return the NCBI codon table number to use for an organism.

    Organisms in the built-in list of model eukaryotes use the standard code
    (Table 1). Every other name falls back to Table 11. This includes the
    chloroplast entries "Chlamydomonas reinhardtii chloroplast" and
    "Nicotiana tabacum chloroplast", as well as bacteria and archaea.

    Args:
        organism (str): Name of the organism.

    Returns:
        int: 1 for the listed organisms, otherwise 11.
    """
    standard_code_organisms = (
        "Arabidopsis thaliana",
        "Caenorhabditis elegans",
        "Chlamydomonas reinhardtii",
        "Saccharomyces cerevisiae",
        "Danio rerio",
        "Drosophila melanogaster",
        "Homo sapiens",
        "Mus musculus",
        "Nicotiana tabacum",
        "Solanum tuberosum",
        "Solanum lycopersicum",
        "Oryza sativa",
        "Glycine max",
        "Zea mays",
    )
    return 1 if organism in standard_code_organisms else 11