| from typing import Any, Optional | |
| import torch | |
| from transformers import BertForSequenceClassification, Pipeline | |
| from transformers.pipelines import PIPELINE_REGISTRY | |
| from transformers.utils.generic import ModelOutput | |
def process_label(p: int) -> str:
    """Map a predicted class id to its nucleotide-region label.

    Args:
        p: Class id produced by the classification head's argmax
           (0 = exon, 1 = intron).

    Returns:
        ``"EXON"`` for 0, ``"INTRON"`` for 1, ``"UNKNOWN"`` for anything else.
    """
    # Fix: the parameter was annotated `str`, but callers (the pipeline's
    # postprocess step) pass the int produced by `torch.argmax(...).item()`.
    if p == 0:
        return "EXON"
    if p == 1:
        return "INTRON"
    return "UNKNOWN"
def ensure_optional_str(value: Any) -> str:
    """Coerce *value* to a string.

    Strings pass through unchanged; any non-string (including ``None``)
    becomes the empty string.
    """
    if isinstance(value, str):
        return value
    return ""
class DNABERT2NucleotideClassificationPipeline(Pipeline):
    """Classify a nucleotide sequence, with optional flanking context.

    The ``before`` context, the ``sequence`` itself, and the ``after``
    context are joined with ``[SEP]`` markers into a single prompt,
    tokenized, run through a sequence-classification model, and the argmax
    class id is mapped to a label string via ``process_label``.
    """

    def _build_prompt(
        self,
        sequence: str,
        before: str,
        after: str,
    ) -> str:
        """Join context and sequence into one ``[SEP]``-delimited prompt."""
        return (
            f"{before}[SEP]"
            f"{sequence}[SEP]"
            f"{after}"
        )

    def _sanitize_parameters(self, **kwargs):
        """Split caller kwargs into (preprocess, forward, postprocess) dicts.

        ``before``/``after``/``max_length`` are consumed by ``preprocess``;
        every other kwarg is passed through to the model call in ``_forward``.
        """
        preprocess_kwargs = {
            k: kwargs[k] for k in ("before", "after", "max_length") if k in kwargs
        }
        forward_kwargs = {
            k: v for k, v in kwargs.items() if k not in preprocess_kwargs
        }
        return preprocess_kwargs, forward_kwargs, {}

    def preprocess(self, input_, **preprocess_parameters):
        """Tokenize one input into model-ready tensors.

        Args:
            input_: Either the raw sequence string, or a dict with a
                ``sequence`` key and optional ``before``/``after`` keys.
            **preprocess_parameters: ``before``, ``after`` (context strings)
                and ``max_length`` (int, default 256, truncation limit).

        Returns:
            ``{"prompt": <str>, "inputs": <tokenizer encoding on device>}``.

        Raises:
            ValueError: if no tokenizer is configured on the pipeline.
            TypeError: if ``input_`` is neither str nor dict, or
                ``max_length`` is not an int.
        """
        # Fix: was `assert self.tokenizer`, which is stripped under `-O`;
        # validate explicitly instead.
        if self.tokenizer is None:
            raise ValueError("This pipeline requires a tokenizer.")
        if isinstance(input_, str):
            sequence = input_
        elif isinstance(input_, dict):
            sequence = input_.get("sequence", "")
        else:
            raise TypeError("input_ must be str or dict with 'sequence' key")
        # Explicit kwargs take precedence over keys embedded in a dict input.
        before_raw = preprocess_parameters.get("before")
        after_raw = preprocess_parameters.get("after")
        if isinstance(input_, dict):
            if before_raw is None:
                before_raw = input_.get("before")
            if after_raw is None:
                after_raw = input_.get("after")
        before = ensure_optional_str(before_raw)
        after = ensure_optional_str(after_raw)
        max_length = preprocess_parameters.get("max_length", 256)
        if not isinstance(max_length, int):
            raise TypeError("max_length must be an int")
        prompt = self._build_prompt(sequence, before=before, after=after)
        enc = self.tokenizer(
            prompt,
            return_tensors="pt",
            max_length=max_length,
            truncation=True,
        ).to(self.model.device)
        return {"prompt": prompt, "inputs": enc}

    def _forward(self, input_tensors: dict, **forward_params):
        """Run the model and return the predicted class id.

        Args:
            input_tensors: Dict produced by ``preprocess`` (must contain
                ``"inputs"``).
            **forward_params: Extra kwargs forwarded to the model call.

        Returns:
            ``ModelOutput`` with a single ``"pred_id"`` int entry.

        Raises:
            ValueError: if the ``"inputs"`` key is missing.
        """
        inputs = input_tensors.get("inputs")
        if inputs is None:
            raise ValueError("Model inputs missing in input_tensors (expected key 'inputs').")
        if hasattr(inputs, "items") and not isinstance(inputs, torch.Tensor):
            # Mapping / BatchEncoding: move each tensor to the model device.
            # Fix: the original wrapped this comprehension in a
            # try/except whose except branch re-ran the identical code —
            # dead duplication, removed.
            expanded_inputs: dict[str, torch.Tensor] = {
                k: v.to(self.model.device) if isinstance(v, torch.Tensor) else v
                for k, v in dict(inputs).items()
            }
        elif isinstance(inputs, torch.Tensor):
            expanded_inputs = {"input_ids": inputs.to(self.model.device)}
        else:
            # Last resort: assume raw token ids convertible to a tensor.
            expanded_inputs = {"input_ids": torch.tensor(inputs, device=self.model.device)}
        self.model.eval()
        with torch.no_grad():
            # Fix: no need to copy forward_params into a new dict first.
            outputs = self.model(**expanded_inputs, **forward_params)
        # Single-example pipeline: argmax over the class dim, scalar id.
        pred_id = torch.argmax(outputs.logits, dim=-1).item()
        return ModelOutput({"pred_id": pred_id})

    def postprocess(self, model_outputs: dict, **kwargs):
        """Map the predicted class id to its label string.

        Fix: dropped a dead ``assert self.tokenizer`` — this step never
        uses the tokenizer.
        """
        return process_label(model_outputs["pred_id"])
# Module-level side effect: make the custom task name resolvable via
# `pipeline("dnabert2-nucleotide-classification", ...)`.
# NOTE(review): registered with BertForSequenceClassification as the
# PyTorch model class — presumably DNABERT-2 checkpoints load through a
# BERT-compatible architecture; confirm against the intended checkpoint.
PIPELINE_REGISTRY.register_pipeline(
    "dnabert2-nucleotide-classification",
    pipeline_class=DNABERT2NucleotideClassificationPipeline,
    pt_model=BertForSequenceClassification,
)