from typing import Any, Optional
import torch
from transformers import BertForSequenceClassification, Pipeline
from transformers.pipelines import PIPELINE_REGISTRY
from transformers.utils.generic import ModelOutput
# Mapping from single-character sequence codes to the tokenizer's special
# tokens. The 15 standard IUPAC nucleotide codes all follow the
# "[DNA_<code>]" pattern; I/E/U carry structural meaning and get
# dedicated tokens.
NUCLEOTIDE_MAP = {code: f"[DNA_{code}]" for code in "ACGTRYSWKMBDHVN"}
NUCLEOTIDE_MAP.update({
    "I": "[INTRON]",
    "E": "[EXON]",
    "U": "[DNA_UNKNOWN]",
})
def process_sequence(seq: str) -> str:
    """Render a raw sequence string as its special-token representation.

    The input is stripped of surrounding whitespace and upper-cased, then
    each character is looked up in ``NUCLEOTIDE_MAP``; characters with no
    mapping are rendered as ``"[DNA_INVALID]"``.
    """
    normalized = seq.strip().upper()
    tokens = []
    for ch in normalized:
        tokens.append(NUCLEOTIDE_MAP.get(ch, "[DNA_INVALID]"))
    return "".join(tokens)
def process_label(p: int) -> str:
    """Map a predicted class id to its human-readable label.

    Args:
        p: Predicted class index produced by ``argmax`` over the
           classifier logits (0 -> exon, 1 -> intron).

    Returns:
        ``"EXON"``, ``"INTRON"``, or ``"UNKNOWN"`` for any other value.
    """
    # Bug fix: the parameter was annotated `str`, but the caller passes the
    # integer from `torch.argmax(...).item()` and the comparisons below are
    # against ints — the annotation now matches actual usage.
    if p == 0:
        return "EXON"
    if p == 1:
        return "INTRON"
    return "UNKNOWN"
def ensure_optional_str(value: Any) -> str:
    """Coerce *value* to a string: strings pass through unchanged, any
    other value (including ``None``) becomes the empty string."""
    if isinstance(value, str):
        return value
    return ""
class BERTNucleotideClassificationPipeline(Pipeline):
    """Custom ``transformers`` pipeline that classifies a nucleotide
    position as EXON / INTRON / UNKNOWN.

    A special-token prompt is built from the target sequence, optional
    flanking context, and an optional organism tag, then fed to a
    ``BertForSequenceClassification`` head.
    """

    def _build_prompt(
        self,
        sequence: str,
        before: str,
        after: str,
        organism: Optional[str]
    ) -> str:
        """Assemble the special-token prompt string for the tokenizer.

        Args:
            sequence: Target sequence; only its first character is encoded.
            before: Upstream flanking context (truncated to 24 chars).
            after: Downstream flanking context (truncated to 24 chars).
            organism: Optional organism tag; falsy values are omitted.
        """
        # NOTE(review): only `sequence[0]` (the first character) is encoded
        # here — presumably the pipeline targets a single nucleotide with
        # flanking context, but confirm that full-sequence encoding is not
        # intended before changing this.
        out = f"<|SEQUENCE|>{process_sequence(sequence[0])}"
        # Flanking context is capped at 24 characters per side.
        before_p = process_sequence(before[:24])
        out += f"<|FLANK_BEFORE|>{before_p}"
        after_p = process_sequence(after[:24])
        out += f"<|FLANK_AFTER|>{after_p}"
        if organism:
            # Organism tag is truncated to 10 characters and lower-cased.
            out += f"<|ORGANISM|>{organism[:10].lower()}"
        out += "<|TARGET|>"
        return out

    def _sanitize_parameters(
        self,
        **kwargs
    ):
        """Split caller kwargs into (preprocess, forward, postprocess) dicts
        per the ``Pipeline`` contract. Known preprocessing keys are routed to
        ``preprocess``; everything else is passed through to ``_forward``.
        """
        preprocess_kwargs = {}
        for k in ("organism", "before", "after", "max_length"):
            if k in kwargs:
                preprocess_kwargs[k] = kwargs[k]
        forward_kwargs = {
            k: v for k, v in kwargs.items()
            if k not in preprocess_kwargs
        }
        postprocess_kwargs = {}
        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def preprocess(
        self,
        input_,
        **preprocess_parameters
    ):
        """Build the prompt and tokenize it.

        Args:
            input_: Either the sequence itself (str) or a dict with a
                ``"sequence"`` key and optional ``"organism"``/``"before"``/
                ``"after"`` keys.
            **preprocess_parameters: ``organism``/``before``/``after``
                overrides (call-level values win over dict values) and
                ``max_length`` for tokenizer truncation (default 256).

        Returns:
            Dict with the raw ``"prompt"`` string and the tokenized
            ``"inputs"`` moved to the model device.

        Raises:
            TypeError: If ``input_`` is neither str nor dict, or if
                ``max_length`` is not an int.
        """
        assert self.tokenizer
        if isinstance(input_, str):
            sequence = input_
        elif isinstance(input_, dict):
            sequence = input_.get("sequence", "")
        else:
            raise TypeError("input_ must be str or dict with 'sequence' key")
        # Explicit kwargs take precedence; fall back to dict fields.
        organism_raw = preprocess_parameters.get("organism", None)
        before_raw = preprocess_parameters.get("before", None)
        after_raw = preprocess_parameters.get("after", None)
        if organism_raw is None and isinstance(input_, dict):
            organism_raw = input_.get("organism", None)
        if before_raw is None and isinstance(input_, dict):
            before_raw = input_.get("before", None)
        if after_raw is None and isinstance(input_, dict):
            after_raw = input_.get("after", None)
        before: str = ensure_optional_str(before_raw)
        after: str = ensure_optional_str(after_raw)
        # ensure_optional_str always returns str ("" for non-strings), so the
        # local is annotated `str` (was misleadingly `Optional[str]`).
        organism: str = ensure_optional_str(organism_raw)
        max_length = preprocess_parameters.get("max_length", 256)
        if not isinstance(max_length, int):
            raise TypeError("max_length must be an int")
        prompt = self._build_prompt(sequence, before=before, after=after, organism=organism)
        token_kwargs: dict[str, Any] = {"return_tensors": "pt"}
        token_kwargs["max_length"] = max_length
        token_kwargs["truncation"] = True
        enc = self.tokenizer(prompt, **token_kwargs).to(self.model.device)
        return {"prompt": prompt, "inputs": enc}

    def _forward(self, input_tensors: dict, **forward_params):
        """Run the classifier and return the argmax class id.

        Accepts the dict produced by :meth:`preprocess`; ``"inputs"`` may be
        a mapping (BatchEncoding/dict), a tensor of input ids, or any
        tensor-convertible sequence.

        Raises:
            ValueError: If the ``"inputs"`` key is missing.
        """
        assert isinstance(self.model, BertForSequenceClassification)
        inputs = input_tensors.get("inputs")
        if inputs is None:
            raise ValueError("Model inputs missing in input_tensors (expected key 'inputs').")
        if hasattr(inputs, "items") and not isinstance(inputs, torch.Tensor):
            # Mapping-like input: move every tensor value onto the model
            # device, pass non-tensor values through unchanged.
            # (The original wrapped this in a try/except whose except body
            # duplicated the try body verbatim — dead code, removed.)
            expanded_inputs: dict[str, torch.Tensor] = {
                k: v.to(self.model.device) if isinstance(v, torch.Tensor) else v
                for k, v in dict(inputs).items()
            }
        elif isinstance(inputs, torch.Tensor):
            expanded_inputs = {"input_ids": inputs.to(self.model.device)}
        else:
            expanded_inputs = {"input_ids": torch.tensor(inputs, device=self.model.device)}
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(**expanded_inputs, **forward_params)
        pred_id = torch.argmax(outputs.logits, dim=-1).item()
        return ModelOutput({"pred_id": pred_id})

    def postprocess(self, model_outputs: dict, **kwargs):
        """Convert the predicted class id into its string label."""
        assert self.tokenizer
        pred_id = model_outputs["pred_id"]
        return process_label(pred_id)
# Register the custom pipeline under the "bert-nucleotide-classification"
# task name so it can be constructed via `transformers.pipeline(...)`
# with a BertForSequenceClassification model.
PIPELINE_REGISTRY.register_pipeline(
    "bert-nucleotide-classification",
    pipeline_class=BERTNucleotideClassificationPipeline,
    pt_model=BertForSequenceClassification,
)