from typing import Any

import torch
from transformers import BertForSequenceClassification, Pipeline
from transformers.pipelines import PIPELINE_REGISTRY
from transformers.utils.generic import ModelOutput


def process_label(p: int) -> str:
    """Map a predicted class id to its label.

    Args:
        p: Class index produced by the model (argmax over the logits).

    Returns:
        ``"INTRON"`` when ``p`` equals 0, ``"EXON"`` for any other value.
    """
    return "INTRON" if p == 0 else "EXON"


class DNABERT2ExonIntronClassificationPipeline(Pipeline):
    """Pipeline that labels a single sequence string as ``"EXON"`` or ``"INTRON"``.

    The input is tokenized with ``self.tokenizer``, run through ``self.model``
    and the argmax class id is translated by :func:`process_label`.
    """

    def _sanitize_parameters(self, **kwargs):
        """Split call-time keyword arguments between the pipeline stages.

        ``max_length`` is routed to :meth:`preprocess`; every other keyword is
        forwarded to :meth:`_forward` (and from there to the model call).
        :meth:`postprocess` receives no parameters.

        Returns:
            A ``(preprocess_kwargs, forward_kwargs, postprocess_kwargs)`` tuple.
        """
        # NOTE: the trailing comma matters. ``("max_length")`` is just a string
        # and iterating it yields single characters, so ``max_length`` would
        # never be recognised and would leak into the model call instead.
        preprocess_keys = ("max_length",)
        preprocess_kwargs = {k: kwargs[k] for k in preprocess_keys if k in kwargs}
        forward_kwargs = {
            k: v for k, v in kwargs.items() if k not in preprocess_kwargs
        }
        postprocess_kwargs = {}
        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def preprocess(self, input_, **preprocess_parameters):
        """Tokenize one input sequence.

        Args:
            input_: Either the sequence string itself, or a dict whose
                ``"sequence"`` entry holds it (a missing entry is treated as
                the empty string).
            **preprocess_parameters: Optional ``max_length`` (int, default
                256) passed to the tokenizer together with ``truncation=True``.

        Returns:
            A dict with ``"prompt"`` (the sequence after character truncation)
            and ``"inputs"`` (the tokenizer output moved to the model device).

        Raises:
            ValueError: If the pipeline has no tokenizer.
            TypeError: If ``input_`` is neither ``str`` nor ``dict``, or if
                ``max_length`` is not an ``int``.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if self.tokenizer is None:
            raise ValueError("A tokenizer is required by this pipeline.")

        if isinstance(input_, str):
            sequence = input_
        elif isinstance(input_, dict):
            sequence = input_.get("sequence", "")
        else:
            raise TypeError("input_ must be str or dict with 'sequence' key")

        # Character-level cap applied before tokenization; it is independent
        # of the token-level ``max_length`` below.
        sequence = sequence[:256]

        max_length = preprocess_parameters.get("max_length", 256)
        # ``bool`` is a subclass of ``int``; reject it explicitly.
        if isinstance(max_length, bool) or not isinstance(max_length, int):
            raise TypeError("max_length must be an int")

        token_kwargs: dict[str, Any] = {
            "return_tensors": "pt",
            "max_length": max_length,
            "truncation": True,
        }
        enc = self.tokenizer(sequence, **token_kwargs).to(self.model.device)
        return {"prompt": sequence, "inputs": enc}

    def _forward(self, input_tensors: dict, **forward_params):
        """Run the model and return the predicted class id.

        Args:
            input_tensors: Dict produced by :meth:`preprocess`; its
                ``"inputs"`` entry is either a mapping of model keyword
                arguments, a tensor of input ids, or something
                ``torch.tensor`` can convert into input ids.
            **forward_params: Extra keyword arguments passed to the model call.

        Returns:
            A ``ModelOutput`` with a single ``"pred_id"`` entry (Python int).

        Raises:
            ValueError: If ``input_tensors`` has no ``"inputs"`` entry.
        """
        inputs = input_tensors.get("inputs")
        if inputs is None:
            raise ValueError(
                "Model inputs missing in input_tensors (expected key 'inputs')."
            )

        device = self.model.device
        if isinstance(inputs, torch.Tensor):
            expanded_inputs = {"input_ids": inputs.to(device)}
        elif hasattr(inputs, "items"):
            # Mapping-like tokenizer output: move every tensor value to the
            # model device and pass non-tensor values through untouched.
            expanded_inputs = {
                k: v.to(device) if isinstance(v, torch.Tensor) else v
                for k, v in dict(inputs).items()
            }
        else:
            expanded_inputs = {"input_ids": torch.tensor(inputs, device=device)}

        self.model.eval()
        with torch.no_grad():
            outputs = self.model(**expanded_inputs, **forward_params)

        # ``.item()` requires exactly one prediction, i.e. a single sequence
        # per forward call.
        pred_id = torch.argmax(outputs.logits, dim=-1).item()
        return ModelOutput({"pred_id": pred_id})

    def postprocess(self, model_outputs: dict, **kwargs):
        """Translate the ``"pred_id"`` entry of ``model_outputs`` into a label.

        Returns:
            ``"INTRON"`` or ``"EXON"`` as decided by :func:`process_label`.
        """
        return process_label(model_outputs["pred_id"])


PIPELINE_REGISTRY.register_pipeline(
    "dnabert2-exon-intron-classification",
    pipeline_class=DNABERT2ExonIntronClassificationPipeline,
    pt_model=BertForSequenceClassification,
)