# orthrus/src/gk_utils.py
# Last commit by antichronology: "feat: updated with new Orthrus models" (20e1cef)
import warnings

import numpy as np
def _sort_5_to_3(intervals, strand):
    """Return ``intervals`` sorted from the transcript's 5' end toward its 3' end.

    On the "+" strand this is ascending ``start``; on any other strand it is
    descending ``end``.
    """
    if strand == "+":
        return sorted(intervals, key=lambda iv: iv.start)
    return sorted(intervals, key=lambda iv: iv.end, reverse=True)


def _five_utr_length(sorted_exons, first_cds, strand):
    """Count the transcript bases that lie 5' of ``first_cds`` (the 5' UTR length).

    ``sorted_exons`` must already be in 5'->3' order. Exons entirely upstream
    of the CDS are counted in full. The exon overlapping ``first_cds``
    contributes only its part upstream of the CDS start. Counting stops at
    that exon, or at the first exon that is neither upstream nor overlapping.
    """
    utr_length = 0
    for exon in sorted_exons:
        if strand == "+":
            upstream = exon.end <= first_cds.start
            partial = first_cds.start - exon.start
        else:  # negative strand: 5' is the high-coordinate end
            upstream = exon.start >= first_cds.end
            partial = exon.end - first_cds.end
        if upstream:
            utr_length += len(exon)
            continue
        if exon.overlaps(first_cds):
            utr_length += max(0, partial)
        break
    return utr_length


def create_cds_track(t, list_of_exons=None, list_of_cds_intervals=None):
    """
    Create a track for the coding sequence of a transcript, handling both strands correctly.

    - The final track length = sum of all exon lengths.
    - The region before the CDS is zeros (the 5' UTR).
    - The CDS region has a 1 on every third base, starting at the first CDS base.
    - The region after the CDS is zeros (the 3' UTR).

    Explicitly passed interval lists take precedence over the transcript's own
    ``t.exons`` / ``t.cdss``. This matches how ``list_of_exons`` is handled by
    the splice-track and one-hot helpers.

    Args:
        t (gk.Transcript or None): The transcript object. May be None when both
            ``list_of_exons`` and ``list_of_cds_intervals`` are given.
        list_of_exons (list, optional): Exon intervals to use instead of t.exons.
        list_of_cds_intervals (list, optional): CDS intervals to use instead of t.cdss.

    Returns:
        np.ndarray: 1D int array of shape (transcript_length,).
        0 for noncoding positions, 1 every third base in the CDS region.
        An empty transcript yields an empty array.

    Raises:
        AssertionError: if neither ``t`` nor the needed interval list is given,
            if strands disagree, if the first CDS interval's ``end5`` does not
            match its strand, or if the CDS overruns the transcript.
    """
    # 1) Resolve the exons and the total transcript length (sum of exon lengths).
    if list_of_exons is not None:
        exon_list = list_of_exons
    else:
        assert t is not None, "Either t or list_of_exons must be provided"
        exon_list = t.exons
    # Guard before reading exon_list[0]: no exons means an empty transcript.
    if not exon_list:
        return np.array([], dtype=int)
    # Without a transcript object, infer the strand from the exons.
    strand = t.strand if t is not None else exon_list[0].strand
    assert all(exon.strand == strand for exon in exon_list), (
        "All exons must be on the same strand"
    )
    transcript_length = sum(len(exon) for exon in exon_list)
    if transcript_length == 0:
        return np.array([], dtype=int)

    # 2) Resolve the CDS intervals; no CDS means an all-zero track.
    if list_of_cds_intervals is not None:
        cds_intervals = list_of_cds_intervals
    else:
        assert t is not None, "Either t or list_of_cds_intervals must be provided"
        cds_intervals = t.cdss
    if not cds_intervals:
        return np.zeros(transcript_length, dtype=int)
    assert all(cds.strand == strand for cds in cds_intervals), (
        "All CDS intervals must be on the same strand as exons"
    )

    # 3) Find the most 5' CDS interval and check that its 5' end matches the strand.
    sorted_cds_intervals = _sort_5_to_3(cds_intervals, strand)
    first_cds = sorted_cds_intervals[0]
    if strand == "+":
        assert first_cds.end5.start == first_cds.start, (
            "On positive strand, end5 should equal start"
        )
    else:
        assert first_cds.end5.start == first_cds.end, (
            "On negative strand, end5 should equal end"
        )
    cds_length = sum(len(c) for c in sorted_cds_intervals)

    # 4) Split the transcript into 5' UTR / CDS / 3' UTR lengths.
    five_utr_length = _five_utr_length(
        _sort_5_to_3(exon_list, strand), first_cds, strand
    )
    three_utr_length = transcript_length - (five_utr_length + cds_length)
    assert three_utr_length >= 0, "3' UTR length cannot be negative"

    # 5) Zeros everywhere, then a 1 on every third base of the CDS region.
    track = np.zeros(transcript_length, dtype=int)
    track[five_utr_length : five_utr_length + cds_length : 3] = 1
    return track
def create_splice_track(t, list_of_exons=None):
    """Create a binary splice-site track for a transcript.

    The track has one entry per transcript base (the sum of exon lengths).
    The last base of every exon is set to 1, including the final base of the
    last exon. Exons are taken in the order given.

    Args:
        t (gk.Transcript): The transcript object. Only ``t.exons`` is read, and
            only when ``list_of_exons`` is None.
        list_of_exons (list, optional): Exon intervals to use instead of t.exons.

    Returns:
        np.ndarray: 1D int array of length sum(len(exon)).
    """
    exons = t.exons if list_of_exons is None else list_of_exons
    exon_lengths = [len(exon) for exon in exons]
    splicing_track = np.zeros(sum(exon_lengths), dtype=int)
    exon_ends = np.cumsum(exon_lengths, dtype=int)
    # A cumulative end of 0 (leading zero-length exons) marks nothing.
    splicing_track[exon_ends[exon_ends > 0] - 1] = 1
    return splicing_track
def seq_to_oh(seq):
    """Convert a nucleotide sequence into a one-hot matrix.

    Columns are ordered A, C, G, T. Any other character, such as "N" or a
    lowercase letter, produces an all-zero row.

    Args:
        seq (str): Nucleotide sequence.

    Returns:
        np.ndarray: int array of shape (len(seq), 4).
    """
    base_to_column = {"A": 0, "C": 1, "G": 2, "T": 3}
    one_hot = np.zeros((len(seq), 4), dtype=int)
    for position, nucleotide in enumerate(seq):
        column = base_to_column.get(nucleotide)
        if column is not None:
            one_hot[position, column] = 1
    return one_hot
def create_one_hot_encoding(t, genome, list_of_exons=None):
    """Create a one-hot track of a transcript's sequence.

    Calls ``genome.dna(exon)`` for each exon, in the order given, and joins
    the results into one string. That string is passed to ``seq_to_oh``.

    Args:
        t (gk.Transcript): The transcript object. Only ``t.exons`` is read, and
            only when ``list_of_exons`` is None.
        genome (gk.Genome): Object whose ``dna(interval)`` returns sequence text.
        list_of_exons (list, optional): Exon intervals to use instead of t.exons.

    Returns:
        np.ndarray: 2D array; rows are positions, columns the A/C/G/T one-hot code.
    """
    exons = t.exons if list_of_exons is None else list_of_exons
    transcript_seq = "".join(genome.dna(exon) for exon in exons)
    return seq_to_oh(transcript_seq)
def create_six_track_encoding(
    t,
    genome,
    list_of_exons=None,
    list_of_cds_intervals=None,
    channels_last=False,
):
    """Create the six-channel encoding of a transcript.

    Channels, in order: the four one-hot base columns from
    ``create_one_hot_encoding``, then the CDS track, then the splice track.

    Args:
        t (gk.Transcript or None): The transcript object. When None, both
            ``list_of_exons`` and ``list_of_cds_intervals`` must be given.
            When ``t`` is given, the two lists are not used.
        genome (gk.Genome): Genome reference.
        list_of_exons (list, optional): Exon intervals, used only when t is None.
        list_of_cds_intervals (list, optional): CDS intervals, used only when t is None.
        channels_last (bool): If True, output is (L, 6). Otherwise, (6, L).

    Returns:
        np.ndarray: A 2D array with 6 channels (one-hot base encoding + CDS + splice).
    """
    if t is None:
        assert list_of_exons is not None
        assert list_of_cds_intervals is not None
        oh = create_one_hot_encoding(
            t=None, list_of_exons=list_of_exons, genome=genome
        )
        cds_track = create_cds_track(
            t=None,
            list_of_exons=list_of_exons,
            list_of_cds_intervals=list_of_cds_intervals,
        )
        splice_track = create_splice_track(t=None, list_of_exons=list_of_exons)
    else:
        oh = create_one_hot_encoding(t, genome)
        cds_track = create_cds_track(t)
        splice_track = create_splice_track(t)

    # (L, 4) + (L,) + (L,) -> (L, 6)
    stacked = np.column_stack([oh, cds_track, splice_track])
    if channels_last:
        return stacked
    # Channels first: materialise a C-contiguous (6, L) copy, not a strided view.
    return np.ascontiguousarray(stacked.T)
def find_transcript_by_gene_name(genome, gene_name):
    """Find all transcripts in a genome by gene name.

    Args:
        genome (object): The genome object. Its ``genes`` are scanned for genes
            whose ``name`` equals ``gene_name``.
        gene_name (str): The name of the gene whose transcripts are to be found.

    Returns:
        list: The transcripts of every matching gene, concatenated in the
        order the genes appear in ``genome.genes``.

    Raises:
        ValueError: If no gene with the given name is found.

    Warns:
        UserWarning: If more than one gene has the given name.

    Example:
        >>> # Find transcripts by gene name
        >>> transcripts = find_transcript_by_gene_name(genome, 'PKP1')
        >>> print(transcripts)
        [<Transcript ENST00000367324.7 of PKP1>,
        <Transcript ENST00000263946.7 of PKP1>,
        <Transcript ENST00000352845.3 of PKP1>,
        <Transcript ENST00000475988.1 of PKP1>,
        <Transcript ENST00000477817.1 of PKP1>]
        >>> # If gene name is not found
        >>> find_transcript_by_gene_name(genome, 'XYZ')
        ValueError: No genes found for gene name XYZ.
    """
    genes = [gene for gene in genome.genes if gene.name == gene_name]
    if not genes:
        raise ValueError(f"No genes found for gene name {gene_name}.")
    if len(genes) > 1:
        # Library code should not print: a warning can be filtered, logged,
        # or escalated to an error by the caller.
        warnings.warn(
            f"More than one gene found for gene name {gene_name}. "
            "Concatenating transcripts from all genes.",
            stacklevel=2,
        )
    return [transcript for gene in genes for transcript in gene.transcripts]