| import numpy as np |
|
|
|
|
def create_cds_track(t, list_of_exons=None, list_of_cds_intervals=None):
    """Create a coding-sequence (CDS) track for a transcript on either strand.

    The track is laid out in transcript order (5' end first):

    - The final track length = sum of all exon lengths.
    - The region before the CDS is zeros (the 5' UTR).
    - The CDS region is 1 at its first base and at every third base after
      that, 0 elsewhere.
    - The region after the CDS is zeros (the 3' UTR).

    Args:
        t (gk.Transcript): The transcript object, or None when the intervals
            are passed explicitly. Supplies `t.exons`, `t.cdss` and
            `t.strand` when the corresponding list is not given.
        list_of_exons (list, optional): Exon intervals to use instead of
            `t.exons`. Required when `t` is None.
        list_of_cds_intervals (list, optional): CDS intervals to use instead
            of `t.cdss`. Required when `t` is None.

    Returns:
        np.ndarray: 1D integer array of shape (transcript_length,).
        An empty array when there are no exons (or they sum to length 0);
        all zeros when there are no CDS intervals.

    Raises:
        AssertionError: If neither `t` nor the matching list is provided, if
            exons/CDS intervals are on mixed strands, if the 5' end of the
            first CDS interval disagrees with the strand, or if the 5' UTR
            plus CDS is longer than the summed exon length.
    """
    # An explicitly supplied exon list takes precedence over the transcript,
    # matching the other track builders in this module.
    if list_of_exons is not None:
        exon_list = list_of_exons
    else:
        assert t is not None, "Either t or list_of_exons must be provided"
        exon_list = t.exons

    # Nothing to encode. Checked before reading exon_list[0] so that an
    # empty exon list returns an empty track instead of raising IndexError.
    if not exon_list:
        return np.array([], dtype=int)

    strand = t.strand if t is not None else exon_list[0].strand
    assert all(exon.strand == strand for exon in exon_list), (
        "All exons must be on the same strand"
    )

    transcript_length = sum(len(exon) for exon in exon_list)
    if transcript_length == 0:
        return np.array([], dtype=int)

    if list_of_cds_intervals is not None:
        cds_intervals = list_of_cds_intervals
    else:
        assert t is not None, (
            "Either t or list_of_cds_intervals must be provided"
        )
        cds_intervals = t.cdss

    # Non-coding transcript: the whole track is zeros.
    if not cds_intervals:
        return np.zeros(transcript_length, dtype=int)

    assert all(cds.strand == strand for cds in cds_intervals), (
        "All CDS intervals must be on the same strand as exons"
    )

    # Order intervals from the 5' end to the 3' end of the transcript:
    # ascending start on "+", descending end on any other strand value.
    forward = strand == "+"
    if forward:
        sorted_cds_intervals = sorted(cds_intervals, key=lambda x: x.start)
        sorted_exons = sorted(exon_list, key=lambda x: x.start)
    else:
        sorted_cds_intervals = sorted(
            cds_intervals, key=lambda x: x.end, reverse=True
        )
        sorted_exons = sorted(exon_list, key=lambda x: x.end, reverse=True)

    # Sanity check that the interval's own notion of its 5' end agrees with
    # the coordinate we are about to treat as the CDS start.
    first_cds = sorted_cds_intervals[0]
    if forward:
        assert first_cds.end5.start == first_cds.start, (
            "On positive strand, end5 should equal start"
        )
    else:
        assert first_cds.end5.start == first_cds.end, (
            "On negative strand, end5 should equal end"
        )

    cds_length = sum(len(c) for c in sorted_cds_intervals)

    # Walk exons 5'->3', accumulating everything that lies before the first
    # CDS base. Stop at the exon that contains the CDS start (or at the
    # first exon that is neither upstream of nor overlapping the CDS).
    five_utr_length = 0
    for exon in sorted_exons:
        if forward:
            entirely_upstream = exon.end <= first_cds.start
            upstream_part = first_cds.start - exon.start
        else:
            entirely_upstream = exon.start >= first_cds.end
            upstream_part = exon.end - first_cds.end

        if entirely_upstream:
            five_utr_length += len(exon)
            continue
        if exon.overlaps(first_cds):
            # Only the exon bases before the CDS start are UTR.
            five_utr_length += max(0, upstream_part)
        break

    three_utr_length = transcript_length - (five_utr_length + cds_length)
    assert three_utr_length >= 0, "3' UTR length cannot be negative"

    # Mark the first base of every codon.
    cds_track = np.zeros(cds_length, dtype=int)
    cds_track[0::3] = 1

    return np.concatenate(
        [
            np.zeros(five_utr_length, dtype=int),
            cds_track,
            np.zeros(three_utr_length, dtype=int),
        ]
    )
|
|
|
|
def create_splice_track(t, list_of_exons=None):
    """Create a track marking the last base of every exon of a transcript.

    The result is a 1D integer array whose length is the summed exon length.
    Positions are counted along the concatenated exons, in the order the
    exons are given; the final position of each exon is set to 1 and every
    other position is 0.

    Args:
        t (gk.Transcript): The transcript object. Only `t.exons` is read, and
            only when `list_of_exons` is None.
        list_of_exons (list, optional): Exon intervals to use instead of
            `t.exons`.

    Returns:
        np.ndarray: 1D integer array of shape (sum of exon lengths,).
    """
    exons = t.exons if list_of_exons is None else list_of_exons

    exon_lengths = np.array([len(exon) for exon in exons], dtype=int)
    track = np.zeros(int(exon_lengths.sum()), dtype=int)

    # Running total after each exon == one past that exon's last position.
    # A total of 0 (leading zero-length exons) has no position to mark.
    exon_ends = np.cumsum(exon_lengths)
    track[exon_ends[exon_ends > 0] - 1] = 1

    return track
|
|
|
|
| |
def seq_to_oh(seq):
    """One-hot encode a nucleotide sequence.

    Columns are ordered A, C, G, T. Only these exact upper-case symbols are
    recognised; any other symbol (e.g. "N" or lower-case bases) yields an
    all-zero row.

    Args:
        seq: Sequence of single-character bases (typically a str).

    Returns:
        np.ndarray: Integer array of shape (len(seq), 4).
    """
    column_of_base = {"A": 0, "C": 1, "G": 2, "T": 3}

    encoded = np.zeros((len(seq), 4), dtype=int)
    for row, symbol in enumerate(seq):
        column = column_of_base.get(symbol)
        if column is not None:
            encoded[row, column] = 1
    return encoded
|
|
|
|
def create_one_hot_encoding(t, genome, list_of_exons=None):
    """Create a one-hot track of the spliced sequence of a transcript.

    The DNA of every exon is fetched with `genome.dna(exon)`, the pieces are
    joined in the order the exons are given, and the joined sequence is
    encoded with `seq_to_oh`.

    Args:
        t (gk.Transcript): The transcript object. Only `t.exons` is read, and
            only when `list_of_exons` is None.
        genome (gk.Genome): Genome reference used to look up exon sequences.
        list_of_exons (list, optional): Exon intervals to use instead of
            `t.exons`.

    Returns:
        np.ndarray: 2D array with one row per position and one column per
        base, as produced by `seq_to_oh`.
    """
    exons = t.exons if list_of_exons is None else list_of_exons
    spliced_sequence = "".join(genome.dna(exon) for exon in exons)
    return seq_to_oh(spliced_sequence)
|
|
|
|
def create_six_track_encoding(
    t,
    genome,
    list_of_exons=None,
    list_of_cds_intervals=None,
    channels_last=False,
):
    """Create the six-channel encoding of a transcript.

    The channels are, in order: the four one-hot base channels, the CDS
    track and the splice-site track.

    Args:
        t (gk.Transcript): The transcript object. When it is not None the
            explicit interval lists are not used.
        genome (gk.Genome): Genome reference.
        list_of_exons (list, optional): Exon intervals; required when `t` is
            None.
        list_of_cds_intervals (list, optional): CDS intervals; required when
            `t` is None.
        channels_last (bool): If True, output is (L, 6). Otherwise, (6, L).

    Returns:
        np.ndarray: A 2D array with 6 channels
        (one-hot base encoding + CDS + splice).
    """
    if t is None:
        assert list_of_exons is not None
        assert list_of_cds_intervals is not None
        exon_override = {"list_of_exons": list_of_exons}
        cds_override = {"list_of_cds_intervals": list_of_cds_intervals}
    else:
        # With a transcript, every track is derived from the transcript only.
        exon_override = {}
        cds_override = {}

    oh = create_one_hot_encoding(t, genome, **exon_override)
    cds_track = create_cds_track(t, **exon_override, **cds_override)
    splice_track = create_splice_track(t, **exon_override)

    annotation_tracks = (cds_track, splice_track)

    if channels_last:
        # (L, 4) followed by two (L, 1) columns -> (L, 6)
        columns = [oh] + [track[:, None] for track in annotation_tracks]
        return np.concatenate(columns, axis=1)

    # (4, L) followed by two (1, L) rows -> (6, L)
    rows = [oh.T] + [track[None, :] for track in annotation_tracks]
    return np.concatenate(rows, axis=0)
|
|
|
|
def find_transcript_by_gene_name(genome, gene_name):
    """Collect the transcripts of every gene with the given name.

    Args:
        genome (object): The genome object; its `genes` are searched.
        gene_name (str): The name of the gene whose transcripts are wanted.

    Returns:
        list: The transcripts of all genes whose `name` equals `gene_name`,
        concatenated in the order the genes appear in `genome.genes`.

    Raises:
        ValueError: If no gene with the given name is found.

    Example:
        >>> # Find transcripts by gene name
        >>> transcripts = find_transcript_by_gene_name(genome, 'PKP1')
        >>> print(transcripts)
        [<Transcript ENST00000367324.7 of PKP1>,
        <Transcript ENST00000263946.7 of PKP1>,
        <Transcript ENST00000352845.3 of PKP1>,
        <Transcript ENST00000475988.1 of PKP1>,
        <Transcript ENST00000477817.1 of PKP1>]
        >>> # If gene name is not found
        >>> find_transcript_by_gene_name(genome, 'XYZ')
        ValueError: No genes found for gene name XYZ.
    """
    matching_genes = [gene for gene in genome.genes if gene.name == gene_name]

    if len(matching_genes) == 0:
        raise ValueError(f"No genes found for gene name {gene_name}.")

    # A name shared by several genes is reported on stdout, not rejected.
    if len(matching_genes) > 1:
        print(f"Warning: More than one gene found for gene name {gene_name}.")
        print("Concatenating transcripts from all genes.")

    return [
        transcript
        for gene in matching_genes
        for transcript in gene.transcripts
    ]
|
|