"""Custom ``transformers`` pipeline for BERT-based nucleotide classification.

The pipeline turns a nucleotide (plus optional flanking sequence and organism
name) into a tagged text prompt, runs it through a
``BertForSequenceClassification`` model, and maps the predicted class id to an
``"EXON"`` / ``"INTRON"`` / ``"UNKNOWN"`` label.

Importing this module registers the pipeline under the task name
``"bert-nucleotide-classification"``.
"""

from typing import Any, Optional

import torch
from transformers import BertForSequenceClassification, Pipeline
from transformers.pipelines import PIPELINE_REGISTRY
from transformers.utils.generic import ModelOutput

# Maps each (upper-cased) input character to its special tokenizer token.
# Besides A/C/G/T and the IUPAC ambiguity codes, "I", "E" and "U" map to the
# [INTRON], [EXON] and [DNA_UNKNOWN] tokens respectively.
NUCLEOTIDE_MAP = {
    "A": "[DNA_A]",
    "C": "[DNA_C]",
    "G": "[DNA_G]",
    "T": "[DNA_T]",
    "R": "[DNA_R]",
    "Y": "[DNA_Y]",
    "S": "[DNA_S]",
    "W": "[DNA_W]",
    "K": "[DNA_K]",
    "M": "[DNA_M]",
    "B": "[DNA_B]",
    "D": "[DNA_D]",
    "H": "[DNA_H]",
    "V": "[DNA_V]",
    "N": "[DNA_N]",
    "I": "[INTRON]",
    "E": "[EXON]",
    "U": "[DNA_UNKNOWN]",
}

# Prompt-format limits used by ``_build_prompt``. These are part of the prompt
# layout the model sees, so changing them changes model inputs.
_FLANK_LENGTH = 24
_ORGANISM_MAX_CHARS = 10

# Call-time kwargs consumed by ``preprocess``; all others go to the model call.
_PREPROCESS_KEYS = ("organism", "before", "after", "max_length")


def process_sequence(seq: str) -> str:
    """Convert a raw nucleotide string into a concatenation of special tokens.

    Leading/trailing whitespace is stripped and the string is upper-cased before
    mapping. Characters absent from ``NUCLEOTIDE_MAP`` become ``[DNA_INVALID]``.
    """
    seq = seq.strip().upper()
    return "".join(NUCLEOTIDE_MAP.get(ch, "[DNA_INVALID]") for ch in seq)


def process_label(p: int) -> str:
    """Map a predicted class id to its label name.

    ``0`` -> ``"EXON"``, ``1`` -> ``"INTRON"``; any other value -> ``"UNKNOWN"``.
    """
    if p == 0:
        return "EXON"
    if p == 1:
        return "INTRON"
    return "UNKNOWN"


def ensure_optional_str(value: Any) -> str:
    """Return ``value`` if it is a ``str``, otherwise an empty string."""
    return value if isinstance(value, str) else ""


class BERTNucleotideClassificationPipeline(Pipeline):
    """Classify a nucleotide (with optional flanks/organism) as EXON or INTRON.

    Accepts either a sequence string or a dict with a ``"sequence"`` key and
    optional ``"organism"``, ``"before"`` and ``"after"`` keys. The same three
    optional values, plus ``max_length``, may also be passed as call kwargs.
    """

    def _build_prompt(
        self, sequence: str, before: str, after: str, organism: Optional[str]
    ) -> str:
        """Assemble the tagged prompt fed to the tokenizer.

        Only the first non-whitespace character of ``sequence`` is encoded.
        Each flank is cut to its first ``_FLANK_LENGTH`` characters. The
        organism name is cut to its first ``_ORGANISM_MAX_CHARS`` characters
        and lower-cased, and its tag is omitted when ``organism`` is empty or
        ``None``.

        Raises:
            ValueError: if ``sequence`` is empty or whitespace-only.
        """
        # Strip first so a leading space does not become the (empty) target.
        sequence = sequence.strip()
        if not sequence:
            raise ValueError("sequence must contain at least one nucleotide")
        # NOTE(review): only sequence[0] is encoded. ``before[:N]`` also keeps
        # the characters *farthest* from the target when ``before`` is longer
        # than N. Presumably both match the format the model was trained on.
        # Confirm before changing, since any change alters model inputs.
        parts = [
            f"<|SEQUENCE|>{process_sequence(sequence[0])}",
            f"<|FLANK_BEFORE|>{process_sequence(before[:_FLANK_LENGTH])}",
            f"<|FLANK_AFTER|>{process_sequence(after[:_FLANK_LENGTH])}",
        ]
        if organism:
            parts.append(f"<|ORGANISM|>{organism[:_ORGANISM_MAX_CHARS].lower()}")
        parts.append("<|TARGET|>")
        return "".join(parts)

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into preprocess / forward / postprocess groups.

        ``organism``, ``before``, ``after`` and ``max_length`` go to
        ``preprocess``. Every other kwarg is forwarded to the model call in
        ``_forward``. ``postprocess`` takes no parameters.
        """
        preprocess_kwargs = {k: kwargs[k] for k in _PREPROCESS_KEYS if k in kwargs}
        forward_kwargs = {
            k: v for k, v in kwargs.items() if k not in preprocess_kwargs
        }
        postprocess_kwargs = {}
        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def preprocess(self, input_, **preprocess_parameters):
        """Build the prompt for one input and tokenize it.

        Args:
            input_: the sequence string itself, or a dict with a ``"sequence"``
                key and optional ``"organism"``, ``"before"``, ``"after"`` keys.
            **preprocess_parameters: optional ``organism``, ``before`` and
                ``after``, which take precedence over the same keys in a dict
                ``input_``; non-str values are treated as empty. Also
                ``max_length`` (default 256); the prompt is truncated to it.

        Returns:
            ``{"prompt": str, "inputs": <tokenizer output on the model device>}``.

        Raises:
            ValueError: if the pipeline has no tokenizer or the sequence is empty.
            TypeError: if ``input_`` or its ``"sequence"`` is not a str, or if
                ``max_length`` is not an int.
        """
        # Explicit check rather than ``assert``, which is stripped under -O.
        if self.tokenizer is None:
            raise ValueError("BERTNucleotideClassificationPipeline requires a tokenizer")

        if isinstance(input_, str):
            sequence = input_
        elif isinstance(input_, dict):
            sequence = input_.get("sequence", "")
        else:
            raise TypeError("input_ must be str or dict with 'sequence' key")
        if not isinstance(sequence, str):
            raise TypeError("'sequence' must be a str")

        dict_input = input_ if isinstance(input_, dict) else {}

        def resolve(name: str) -> str:
            # Call-time kwarg wins; fall back to the dict input; non-str -> "".
            value = preprocess_parameters.get(name)
            if value is None:
                value = dict_input.get(name)
            return ensure_optional_str(value)

        before = resolve("before")
        after = resolve("after")
        organism = resolve("organism")

        max_length = preprocess_parameters.get("max_length", 256)
        # bool is a subclass of int; reject it so max_length=True is not 1.
        if isinstance(max_length, bool) or not isinstance(max_length, int):
            raise TypeError("max_length must be an int")

        prompt = self._build_prompt(sequence, before=before, after=after, organism=organism)
        enc = self.tokenizer(
            prompt, return_tensors="pt", max_length=max_length, truncation=True
        ).to(self.model.device)
        return {"prompt": prompt, "inputs": enc}

    def _forward(self, input_tensors: dict, **forward_params):
        """Run the model on one tokenized prompt and return the arg-max class id.

        ``input_tensors["inputs"]`` may be a mapping of model inputs (such as the
        tokenizer output), a tensor of input ids, or any value ``torch.tensor``
        accepts as input ids. ``forward_params`` are passed to the model call.

        Returns:
            ``ModelOutput`` with a single ``"pred_id"`` (int) entry.

        Raises:
            TypeError: if the model is not a ``BertForSequenceClassification``.
            ValueError: if ``input_tensors`` has no ``"inputs"`` entry.
        """
        if not isinstance(self.model, BertForSequenceClassification):
            raise TypeError(
                "BERTNucleotideClassificationPipeline requires a BertForSequenceClassification model"
            )
        inputs = input_tensors.get("inputs")
        if inputs is None:
            raise ValueError("Model inputs missing in input_tensors (expected key 'inputs').")

        device = self.model.device
        if isinstance(inputs, torch.Tensor):
            model_inputs = {"input_ids": inputs.to(device)}
        elif hasattr(inputs, "items"):
            # Mapping (e.g. BatchEncoding): move tensor values, keep the rest.
            model_inputs = {
                k: v.to(device) if isinstance(v, torch.Tensor) else v
                for k, v in dict(inputs).items()
            }
        else:
            model_inputs = {"input_ids": torch.tensor(inputs, device=device)}

        self.model.eval()
        with torch.no_grad():
            outputs = self.model(**model_inputs, **forward_params)
        # .item() requires exactly one prediction, i.e. a batch of one prompt,
        # which is what preprocess() produces for a single string.
        pred_id = torch.argmax(outputs.logits, dim=-1).item()
        return ModelOutput({"pred_id": pred_id})

    def postprocess(self, model_outputs: dict, **kwargs):
        """Convert the predicted class id into its label string."""
        return process_label(model_outputs["pred_id"])


PIPELINE_REGISTRY.register_pipeline(
    "bert-nucleotide-classification",
    pipeline_class=BERTNucleotideClassificationPipeline,
    pt_model=BertForSequenceClassification,
)