import re
from typing import Sequence
import torch
# from src.data.esm.models.function_decoder import (
# FunctionTokenDecoder,
# merge_annotations,
# )
from src.data.esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer,
)
from src.data.esm.tokenization.residue_tokenizer import (
ResidueAnnotationsTokenizer,
)
from src.data.esm.utils.constants import esm3 as C
from src.data.esm.utils.types import FunctionAnnotation
def encode_function_annotations(
    sequence: str,
    function_annotations: Sequence[FunctionAnnotation],
    function_tokens_tokenizer: InterProQuantizedTokenizer,
    residue_annotations_tokenizer: ResidueAnnotationsTokenizer,
    add_special_tokens: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Encodes user-provided function annotations into token-id tensors.

    Each annotation is routed to the function-token track (InterPro accessions
    and TF-IDF keywords) and/or the residue-annotation track, then tokenized
    and encoded with the matching tokenizer.

    Args:
        sequence: protein sequence the annotations refer to.
        function_annotations: annotations with 1-indexed, inclusive ranges.
        function_tokens_tokenizer: tokenizer for InterPro / keyword labels.
        residue_annotations_tokenizer: tokenizer for residue annotations.
        add_special_tokens: forwarded to both tokenizers' ``encode``.

    Returns:
        Tuple of (function_token_ids, residue_annotation_ids) tensors.

    Raises:
        ValueError: if an annotation label is not recognized by any tokenizer.
    """
    assert isinstance(
        residue_annotations_tokenizer, ResidueAnnotationsTokenizer
    ), "residue_annotations_tokenizer must be of type ResidueAnnotationsTokenizer"

    seq_len = len(sequence)

    # Route each annotation to the track(s) whose vocabulary recognizes it.
    function_track: list[FunctionAnnotation] = []
    residue_track: list[FunctionAnnotation] = []
    for annotation in function_annotations:
        assert (
            1 <= annotation.start <= annotation.end <= seq_len
        ), f"Invalid (start, end) in function annotation {annotation}. Indices 1-indexed and [inclusive, inclusive]"
        recognized = False
        # InterPro accession embedded in the label?
        ipr_match = re.search(r"IPR\d+", annotation.label)
        if ipr_match and ipr_match.group() in function_tokens_tokenizer.interpro_to_index:
            function_track.append(annotation)
            recognized = True
        # TF-IDF function keyword?
        if annotation.label in function_tokens_tokenizer._tfidf.vocab_to_index:
            function_track.append(annotation)
            recognized = True
        # Residue-level annotation label?
        if annotation.label in residue_annotations_tokenizer._labels:
            residue_track.append(annotation)
            recognized = True
        if not recognized:
            raise ValueError(f"Unknown label in FunctionAnnotation: {annotation.label}")

    # Function-track annotations -> token-id tensor.
    function_token_ids = function_tokens_tokenizer.encode(
        function_tokens_tokenizer.tokenize(
            annotations=function_track, seqlen=seq_len
        ),
        add_special_tokens=add_special_tokens,
    )

    # Residue-track annotations -> token-id tensor. The tokenizer takes
    # parallel sequences of descriptions / starts / ends (or None when empty).
    if residue_track:
        descriptions = tuple(a.label for a in residue_track)
        starts = tuple(a.start for a in residue_track)
        ends = tuple(a.end for a in residue_track)
    else:
        descriptions = starts = ends = None
    residue_annotation_ids = residue_annotations_tokenizer.encode(
        residue_annotations_tokenizer.tokenize(
            {
                "interpro_site_descriptions": descriptions,
                "interpro_site_starts": starts,
                "interpro_site_ends": ends,
            },
            sequence=sequence,
            fail_on_mismatch=True,
        ),
        add_special_tokens=add_special_tokens,
    )
    return function_token_ids, residue_annotation_ids
def decode_function_tokens(
    function_token_ids: torch.Tensor,
    # function_token_decoder: FunctionTokenDecoder,
    function_token_decoder,
    function_tokens_tokenizer: InterProQuantizedTokenizer,
    decoder_annotation_threshold: float = 0.1,
    annotation_min_length: int | None = 5,
    annotation_gap_merge_max: int | None = 3,
) -> list[FunctionAnnotation]:
    """Decodes function token ids into function annotation predictions.

    Combines the decoder's keyword predictions with its InterPro predictions
    into a single list of FunctionAnnotations.

    Args:
        function_token_ids: Tensor <int>[length, depth] of function token ids.
        function_token_decoder: decoder exposing ``decode(...)`` (the
            FunctionTokenDecoder import is commented out above, hence untyped).
        function_tokens_tokenizer: InterPro annotation tokenizer.
        decoder_annotation_threshold: predicted probability threshold for
            emitting a predicted annotation.
        annotation_min_length: minimum length a predicted annotation must have
            to be kept; None disables the filter.
        annotation_gap_merge_max: maximum gap between same-label predictions
            for them to be merged; None disables merging.

    Returns:
        Predicted function annotations (keywords first, then InterPro).
    """
    assert (
        function_token_ids.ndim == 2
    ), "function_token_ids must be of shape (length, depth)"

    # Decode function tokens into keyword + InterPro predictions.
    decoded = function_token_decoder.decode(
        function_token_ids,
        tokenizer=function_tokens_tokenizer,
        annotation_threshold=decoder_annotation_threshold,
        annotation_min_length=annotation_min_length,
        annotation_gap_merge_max=annotation_gap_merge_max,
    )

    # Keyword predictions are already FunctionAnnotations.
    annotations: list[FunctionAnnotation] = list(decoded["function_keywords"])
    # Re-label InterPro predictions with human-readable formatted labels.
    for annotation in decoded["interpro_annotations"]:
        label = function_tokens_tokenizer.format_annotation(annotation)
        annotations.append(
            FunctionAnnotation(label=label, start=annotation.start, end=annotation.end)
        )
    return annotations
def decode_residue_annotation_tokens(
    residue_annotations_token_ids: torch.Tensor,
    residue_annotations_tokenizer: ResidueAnnotationsTokenizer,
    annotation_min_length: int | None = 5,
    annotation_gap_merge_max: int | None = 3,
) -> list[FunctionAnnotation]:
    """Decodes residue annotation tokens into FunctionAnnotations.

    Args:
        residue_annotations_token_ids: Tensor <int>[length, MAX_RESIDUE_ANNOTATIONS]
            of residue annotation token ids.
        residue_annotations_tokenizer: Tokenizer of residue annotations.
        annotation_min_length: minimum length (in residues) a merged annotation
            must have to be kept; None disables the filter.
        annotation_gap_merge_max: maximum gap between same-label annotations
            for them to be merged into one span.

    Returns:
        Predicted residue annotations.
    """
    # BUG FIX: the module-level import of `merge_annotations` is commented out
    # at the top of this file, so the call below raised NameError at runtime.
    # Import it lazily here to keep module import decoupled from the decoder
    # module (presumably why the top-level import was disabled).
    from src.data.esm.models.function_decoder import merge_annotations

    assert (
        residue_annotations_token_ids.ndim == 2
    ), "residue_annotations_token_ids must be of shape (length, MAX_RESIDUE_ANNOTATIONS)"

    annotations: list[FunctionAnnotation] = []
    # Each depth column is an independent annotation track.
    for depth in range(C.MAX_RESIDUE_ANNOTATIONS):
        token_ids = residue_annotations_token_ids[:, depth]
        nonzero_indices = torch.nonzero(token_ids).squeeze(dim=1).cpu().numpy()
        for loc in nonzero_indices:
            vocab_index: int = token_ids[loc].item()  # type: ignore
            label = residue_annotations_tokenizer.vocabulary[vocab_index]
            # Skip special tokens and the explicit "<none>" label.
            if label not in [*residue_annotations_tokenizer.special_tokens, "<none>"]:
                # NOTE(review): `loc` is a 0-based tensor index, while the
                # encode path treats positions as 1-indexed — confirm callers
                # expect 0-based positions here.
                annotation = FunctionAnnotation(label=label, start=loc, end=loc)
                annotations.append(annotation)
    # Merge per-residue hits of the same label into contiguous spans.
    annotations = merge_annotations(annotations, merge_gap_max=annotation_gap_merge_max)
    # Drop very small annotations.
    if annotation_min_length is not None:
        annotations = [
            annotation
            for annotation in annotations
            if annotation.end - annotation.start + 1 >= annotation_min_length
        ]
    return annotations