from typing import Any, Optional

import torch
from transformers import BertForSequenceClassification, Pipeline
from transformers.pipelines import PIPELINE_REGISTRY
from transformers.utils.generic import ModelOutput


def process_label(p: int) -> str:
    """Map a predicted class index to its human-readable label.

    Args:
        p: Class index, as produced by ``argmax`` over the model logits.

    Returns:
        ``"EXON"`` for 0, ``"INTRON"`` for 1, ``"UNKNOWN"`` for anything else.
    """
    if p == 0:
        return "EXON"
    if p == 1:
        return "INTRON"
    return "UNKNOWN"


def ensure_optional_str(value: Any) -> str:
    """Return *value* if it is a ``str``, otherwise the empty string.

    Normalises optional context arguments (``before`` / ``after``) that may be
    absent (``None``) or of an unexpected type.
    """
    return value if isinstance(value, str) else ""


class DNABERT2NucleotideClassificationPipeline(Pipeline):
    """Classify a nucleotide sequence, optionally with flanking context.

    The prompt is built as ``"{before}[SEP]{sequence}[SEP]{after}"``. The
    prompt is tokenized and passed through the classification model. The
    argmax class index is then mapped to a label by :func:`process_label`.

    Accepted inputs:
        * a ``str``, which is the sequence itself;
        * a ``dict`` with a ``"sequence"`` key and optional ``"before"`` /
          ``"after"`` keys.

    Call-time keyword arguments:
        * ``before`` and ``after`` configure preprocessing. A non-``None``
          call-time value overrides the same key in a dict input.
        * ``max_length`` (default 256) also configures preprocessing.
        * Every other keyword argument is forwarded to the model call.
    """

    def _build_prompt(self, sequence: str, before: str, after: str) -> str:
        """Join the flanking context and the target sequence with ``[SEP]``."""
        return f"{before}[SEP]{sequence}[SEP]{after}"

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into (preprocess, forward, postprocess) dicts.

        ``before``, ``after`` and ``max_length`` go to :meth:`preprocess`.
        Every other keyword is passed through to the model in :meth:`_forward`.
        """
        preprocess_kwargs = {
            k: kwargs[k] for k in ("before", "after", "max_length") if k in kwargs
        }
        forward_kwargs = {
            k: v for k, v in kwargs.items() if k not in preprocess_kwargs
        }
        postprocess_kwargs: dict[str, Any] = {}
        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def preprocess(self, input_, **preprocess_parameters):
        """Build the prompt for a single input and tokenize it.

        Args:
            input_: The sequence as a ``str``, or a ``dict`` with a
                ``"sequence"`` key (default ``""``) and optional ``"before"``
                and ``"after"`` keys.
            **preprocess_parameters: Optional ``before``, ``after`` and
                ``max_length`` (default 256), as routed by
                :meth:`_sanitize_parameters`.

        Returns:
            ``{"prompt": str, "inputs": <tokenizer output>}``. The tokenizer
            output holds PyTorch tensors truncated to ``max_length`` tokens.
            Tensors are not moved here; :meth:`_forward` places them on the
            model's device.

        Raises:
            ValueError: If the pipeline has no tokenizer, or ``max_length`` is
                not positive.
            TypeError: If ``input_`` is neither a ``str`` nor a ``dict``, if
                the dict's ``"sequence"`` is not a ``str``, or if
                ``max_length`` is not an ``int``.
        """
        # Explicit check instead of ``assert``, which disappears under ``-O``.
        if self.tokenizer is None:
            raise ValueError("This pipeline requires a tokenizer.")

        if isinstance(input_, str):
            sequence = input_
        elif isinstance(input_, dict):
            sequence = input_.get("sequence", "")
            if not isinstance(sequence, str):
                raise TypeError("'sequence' must be a str")
        else:
            raise TypeError("input_ must be str or dict with 'sequence' key")

        before_raw = preprocess_parameters.get("before", None)
        after_raw = preprocess_parameters.get("after", None)
        # Fall back to context carried inside a dict input when the caller
        # did not supply it at call time.
        if isinstance(input_, dict):
            if before_raw is None:
                before_raw = input_.get("before", None)
            if after_raw is None:
                after_raw = input_.get("after", None)
        before: str = ensure_optional_str(before_raw)
        after: str = ensure_optional_str(after_raw)

        max_length = preprocess_parameters.get("max_length", 256)
        # bool is a subclass of int; reject it so True/False are not treated
        # as lengths of 1/0.
        if isinstance(max_length, bool) or not isinstance(max_length, int):
            raise TypeError("max_length must be an int")
        if max_length <= 0:
            raise ValueError("max_length must be a positive int")

        prompt = self._build_prompt(sequence, before=before, after=after)
        enc = self.tokenizer(
            prompt,
            return_tensors="pt",
            max_length=max_length,
            truncation=True,
        )
        return {"prompt": prompt, "inputs": enc}

    def _forward(self, input_tensors: dict, **forward_params):
        """Run the classifier and return the argmax class index.

        Args:
            input_tensors: Output of :meth:`preprocess`. It must carry the
                model inputs under ``"inputs"``:
                * a mapping (e.g. tokenizer output) is used as model keyword
                  arguments;
                * a bare tensor is used as ``input_ids``;
                * any other value is converted with ``torch.tensor`` and used
                  as ``input_ids``.
            **forward_params: Extra keyword arguments passed to the model.

        Returns:
            ``ModelOutput`` with ``pred_id``, a Python ``int``. ``.item()`` is
            used, so this assumes the logits hold a single prediction (a batch
            of one).

        Raises:
            ValueError: If ``"inputs"`` is missing from ``input_tensors``.
        """
        inputs = input_tensors.get("inputs")
        if inputs is None:
            raise ValueError(
                "Model inputs missing in input_tensors (expected key 'inputs')."
            )

        if isinstance(inputs, torch.Tensor):
            raw_inputs: dict[str, Any] = {"input_ids": inputs}
        elif hasattr(inputs, "items"):
            raw_inputs = dict(inputs)
        else:
            raw_inputs = {"input_ids": torch.tensor(inputs)}

        device = self.model.device
        model_inputs = {
            k: v.to(device) if isinstance(v, torch.Tensor) else v
            for k, v in raw_inputs.items()
        }

        self.model.eval()
        with torch.no_grad():
            outputs = self.model(**model_inputs, **forward_params)
        pred_id = torch.argmax(outputs.logits, dim=-1).item()
        return ModelOutput({"pred_id": pred_id})

    def postprocess(self, model_outputs: dict, **kwargs):
        """Convert the predicted class index into its label string."""
        return process_label(model_outputs["pred_id"])


PIPELINE_REGISTRY.register_pipeline(
    "dnabert2-nucleotide-classification",
    pipeline_class=DNABERT2NucleotideClassificationPipeline,
    pt_model=BertForSequenceClassification,
)