from typing import Any, Optional
import torch
from transformers import BertForSequenceClassification, Pipeline
from transformers.pipelines import PIPELINE_REGISTRY
from transformers.utils.generic import ModelOutput
def process_label(p: int) -> str:
    """Map a predicted class id to its human-readable label.

    Args:
        p: Class index produced by ``argmax`` over the model logits
           (see ``DNABERT2NucleotideClassificationPipeline._forward``).
           The original annotation said ``str``, but every comparison —
           and the only caller — uses ints.

    Returns:
        ``"EXON"`` for 0, ``"INTRON"`` for 1, ``"UNKNOWN"`` otherwise.
    """
    labels = {0: "EXON", 1: "INTRON"}
    return labels.get(p, "UNKNOWN")
def ensure_optional_str(value: Any) -> str:
    """Coerce *value* to ``str``: pass strings through unchanged and
    collapse everything else (including ``None``) to the empty string."""
    if isinstance(value, str):
        return value
    return ""
class DNABERT2NucleotideClassificationPipeline(Pipeline):
    """Classify a nucleotide sequence (with optional flanking context) as
    EXON / INTRON using a BERT sequence-classification head.

    Inputs may be a raw sequence string or a dict with ``sequence`` and
    optional ``before`` / ``after`` context keys. The three pieces are
    joined with ``[SEP]`` markers before tokenization.
    """

    def _build_prompt(self, sequence: str, before: str, after: str) -> str:
        """Join flanking context and sequence with [SEP] separators."""
        return f"{before}[SEP]{sequence}[SEP]{after}"

    def _sanitize_parameters(self, **kwargs):
        """Split caller kwargs into (preprocess, forward, postprocess) dicts.

        ``before``, ``after`` and ``max_length`` are routed to preprocess;
        everything else is passed through to the model's forward call.
        """
        preprocess_kwargs = {
            k: kwargs[k] for k in ("before", "after", "max_length") if k in kwargs
        }
        forward_kwargs = {
            k: v for k, v in kwargs.items() if k not in preprocess_kwargs
        }
        return preprocess_kwargs, forward_kwargs, {}

    def preprocess(self, input_, **preprocess_parameters):
        """Tokenize one input into model-ready tensors.

        Args:
            input_: Either the sequence itself (``str``) or a dict holding
                ``sequence`` plus optional ``before`` / ``after`` context.
            **preprocess_parameters: ``before`` / ``after`` overrides and
                ``max_length`` (default 256) for truncation.

        Returns:
            ``{"prompt": <str>, "inputs": <tokenizer encoding on device>}``

        Raises:
            TypeError: If ``input_`` is neither str nor dict, or
                ``max_length`` is not an int.
        """
        assert self.tokenizer
        if isinstance(input_, str):
            sequence = input_
        elif isinstance(input_, dict):
            sequence = input_.get("sequence", "")
        else:
            raise TypeError("input_ must be str or dict with 'sequence' key")
        # Explicit kwargs win; dict-embedded context is the fallback.
        before_raw = preprocess_parameters.get("before", None)
        after_raw = preprocess_parameters.get("after", None)
        if before_raw is None and isinstance(input_, dict):
            before_raw = input_.get("before", None)
        if after_raw is None and isinstance(input_, dict):
            after_raw = input_.get("after", None)
        before: str = ensure_optional_str(before_raw)
        after: str = ensure_optional_str(after_raw)
        max_length = preprocess_parameters.get("max_length", 256)
        if not isinstance(max_length, int):
            raise TypeError("max_length must be an int")
        prompt = self._build_prompt(sequence, before=before, after=after)
        token_kwargs: dict[str, Any] = {
            "return_tensors": "pt",
            "max_length": max_length,
            "truncation": True,
        }
        enc = self.tokenizer(prompt, **token_kwargs).to(self.model.device)
        return {"prompt": prompt, "inputs": enc}

    def _forward(self, input_tensors: dict, **forward_params):
        """Run the classifier and return the argmax class id.

        Args:
            input_tensors: Output of :meth:`preprocess`; must contain
                ``"inputs"``.
            **forward_params: Extra kwargs forwarded to the model call.

        Returns:
            ``ModelOutput`` with a single ``pred_id`` int.

        Raises:
            ValueError: If the ``"inputs"`` key is missing.
        """
        inputs = input_tensors.get("inputs")
        if inputs is None:
            raise ValueError("Model inputs missing in input_tensors (expected key 'inputs').")
        device = self.model.device
        if hasattr(inputs, "items") and not isinstance(inputs, torch.Tensor):
            # BatchEncoding / mapping: move every tensor field to the model
            # device. (The original wrapped this comprehension in a broad
            # try/except whose fallback re-ran the identical logic in a
            # loop — dead code that could never behave differently; removed.)
            expanded_inputs: dict[str, torch.Tensor] = {
                k: v.to(device) if isinstance(v, torch.Tensor) else v
                for k, v in dict(inputs).items()
            }
        elif isinstance(inputs, torch.Tensor):
            expanded_inputs = {"input_ids": inputs.to(device)}
        else:
            expanded_inputs = {"input_ids": torch.tensor(inputs, device=device)}
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(**expanded_inputs, **forward_params)
        # NOTE(review): .item() assumes exactly one example per call
        # (batch size 1) — confirm against how the pipeline is invoked.
        pred_id = torch.argmax(outputs.logits, dim=-1).item()
        return ModelOutput({"pred_id": pred_id})

    def postprocess(self, model_outputs: dict, **kwargs):
        """Convert the predicted class id into its string label.

        (Removed a dead ``assert self.tokenizer`` — the tokenizer is never
        used here, and ``assert`` is stripped under ``-O`` anyway.)
        """
        return process_label(model_outputs["pred_id"])
# Register the custom task so it is reachable via
# transformers.pipeline("dnabert2-nucleotide-classification", ...).
# (Removed a stray trailing "|" left over from a bad copy/extraction,
# which made the file a syntax error.)
PIPELINE_REGISTRY.register_pipeline(
    "dnabert2-nucleotide-classification",
    pipeline_class=DNABERT2NucleotideClassificationPipeline,
    pt_model=BertForSequenceClassification,
)