| from typing import Any, Optional | |
| import torch | |
| from transformers import BertForSequenceClassification, Pipeline | |
| from transformers.pipelines import PIPELINE_REGISTRY | |
| from transformers.utils.generic import ModelOutput | |
# Lookup table from IUPAC nucleotide codes (plus intron/exon/unknown markers)
# to the tokenizer's special DNA tokens.
NUCLEOTIDE_MAP = {
    "A": "[DNA_A]",
    "C": "[DNA_C]",
    "G": "[DNA_G]",
    "T": "[DNA_T]",
    "R": "[DNA_R]",
    "Y": "[DNA_Y]",
    "S": "[DNA_S]",
    "W": "[DNA_W]",
    "K": "[DNA_K]",
    "M": "[DNA_M]",
    "B": "[DNA_B]",
    "D": "[DNA_D]",
    "H": "[DNA_H]",
    "V": "[DNA_V]",
    "N": "[DNA_N]",
    "I": "[INTRON]",
    "E": "[EXON]",
    "U": "[DNA_UNKNOWN]"
}
def process_sequence(seq: str) -> str:
    """Translate a raw nucleotide string into a run of special DNA tokens.

    The input is stripped of surrounding whitespace and upper-cased; any
    character not present in NUCLEOTIDE_MAP becomes "[DNA_INVALID]".
    """
    normalized = seq.strip().upper()
    tokens = [NUCLEOTIDE_MAP.get(letter, "[DNA_INVALID]") for letter in normalized]
    return "".join(tokens)
def process_label(p: int) -> str:
    """Map a predicted class id to its human-readable label.

    The caller passes ``torch.argmax(logits).item()`` (an int), so the
    parameter is annotated ``int`` — the previous ``str`` annotation was
    wrong: a string argument could never equal ``0`` or ``1`` and would
    always fall through to "UNKNOWN".

    Returns "EXON" for 0, "INTRON" for 1, and "UNKNOWN" for anything else.
    """
    if p == 0:
        return "EXON"
    if p == 1:
        return "INTRON"
    return "UNKNOWN"
def ensure_optional_str(value: Any) -> str:
    """Coerce *value* to a string: strings pass through, everything else (including None) becomes ""."""
    if isinstance(value, str):
        return value
    return ""
class BERTNucleotideClassificationPipeline(Pipeline):
    """Pipeline that classifies a nucleotide position as EXON / INTRON / UNKNOWN.

    The input sequence (plus optional flanking context and organism tag) is
    rendered into a special-token prompt, tokenized, run through a
    ``BertForSequenceClassification`` model, and the argmax class id is mapped
    to a string label via ``process_label``.
    """

    def _build_prompt(
        self,
        sequence: str,
        before: str,
        after: str,
        organism: Optional[str]
    ) -> str:
        """Assemble the special-token prompt fed to the tokenizer.

        Layout: <|SEQUENCE|>…<|FLANK_BEFORE|>…<|FLANK_AFTER|>…[<|ORGANISM|>…]<|TARGET|>
        """
        # NOTE(review): only the FIRST character of `sequence` is encoded
        # (`sequence[0]`), and an empty sequence raises IndexError. This looks
        # deliberate for single-nucleotide classification with flanking
        # context — confirm against callers before "fixing".
        out = f"<|SEQUENCE|>{process_sequence(sequence[0])}"
        # Flanking context is capped at 24 raw characters on each side.
        before_p = process_sequence(before[:24])
        out += f"<|FLANK_BEFORE|>{before_p}"
        after_p = process_sequence(after[:24])
        out += f"<|FLANK_AFTER|>{after_p}"
        if organism:
            # Organism tag: first 10 characters, lower-cased.
            out += f"<|ORGANISM|>{organism[:10].lower()}"
        out += "<|TARGET|>"
        return out

    def _sanitize_parameters(
        self,
        **kwargs
    ):
        """Split pipeline kwargs into (preprocess, forward, postprocess) dicts.

        ``organism``/``before``/``after``/``max_length`` go to preprocess;
        everything else is forwarded verbatim to the model call.
        """
        preprocess_kwargs = {}
        for k in ("organism", "before", "after", "max_length"):
            if k in kwargs:
                preprocess_kwargs[k] = kwargs[k]
        forward_kwargs = {
            k: v for k, v in kwargs.items()
            if k not in preprocess_kwargs
        }
        postprocess_kwargs = {}
        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def preprocess(
        self,
        input_,
        **preprocess_parameters
    ):
        """Build the prompt and tokenize it.

        ``input_`` is either the raw sequence string or a dict with a
        ``sequence`` key (which may also carry ``organism``/``before``/``after``
        as fallbacks for the corresponding kwargs).

        Returns ``{"prompt": str, "inputs": BatchEncoding}`` with tensors on
        the model device.

        Raises TypeError for unsupported ``input_`` or non-int ``max_length``.
        """
        assert self.tokenizer
        if isinstance(input_, str):
            sequence = input_
        elif isinstance(input_, dict):
            sequence = input_.get("sequence", "")
        else:
            raise TypeError("input_ must be str or dict with 'sequence' key")
        organism_raw = preprocess_parameters.get("organism", None)
        before_raw = preprocess_parameters.get("before", None)
        after_raw = preprocess_parameters.get("after", None)
        # Dict input may supply context fields when kwargs did not.
        if isinstance(input_, dict):
            if organism_raw is None:
                organism_raw = input_.get("organism", None)
            if before_raw is None:
                before_raw = input_.get("before", None)
            if after_raw is None:
                after_raw = input_.get("after", None)
        # ensure_optional_str guarantees plain str ("" for missing values),
        # so these are annotated str — the previous Optional[str] on
        # `organism` was inaccurate.
        before: str = ensure_optional_str(before_raw)
        after: str = ensure_optional_str(after_raw)
        organism: str = ensure_optional_str(organism_raw)
        max_length = preprocess_parameters.get("max_length", 256)
        if not isinstance(max_length, int):
            raise TypeError("max_length must be an int")
        prompt = self._build_prompt(sequence, before=before, after=after, organism=organism)
        token_kwargs: dict[str, Any] = {"return_tensors": "pt"}
        token_kwargs["max_length"] = max_length
        token_kwargs["truncation"] = True
        enc = self.tokenizer(prompt, **token_kwargs).to(self.model.device)
        return {"prompt": prompt, "inputs": enc}

    def _forward(self, input_tensors: dict, **forward_params):
        """Run the classifier and return the argmax class id.

        Accepts the mapping produced by ``preprocess`` (key ``"inputs"``);
        also tolerates a bare tensor or a plain nested list of ids.
        """
        assert isinstance(self.model, BertForSequenceClassification)
        inputs = input_tensors.get("inputs")
        if inputs is None:
            raise ValueError("Model inputs missing in input_tensors (expected key 'inputs').")
        if hasattr(inputs, "items") and not isinstance(inputs, torch.Tensor):
            # Mapping case (e.g. BatchEncoding): move each tensor to the model
            # device. The original wrapped this in a try/except whose handler
            # re-ran the exact same loop — a no-op fallback, removed.
            expanded_inputs: dict[str, torch.Tensor] = {
                k: v.to(self.model.device) if isinstance(v, torch.Tensor) else v
                for k, v in dict(inputs).items()
            }
        elif isinstance(inputs, torch.Tensor):
            expanded_inputs = {"input_ids": inputs.to(self.model.device)}
        else:
            expanded_inputs = {"input_ids": torch.tensor(inputs, device=self.model.device)}
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(**expanded_inputs, **forward_params)
        pred_id = torch.argmax(outputs.logits, dim=-1).item()
        return ModelOutput({"pred_id": pred_id})

    def postprocess(self, model_outputs: dict, **kwargs):
        """Map the predicted class id to its string label ("EXON"/"INTRON"/"UNKNOWN")."""
        assert self.tokenizer
        pred_id = model_outputs["pred_id"]
        return process_label(pred_id)
# Register the custom task at import time so
# pipeline("bert-nucleotide-classification", ...) resolves to this class
# with BertForSequenceClassification as the default PyTorch model class.
PIPELINE_REGISTRY.register_pipeline(
    "bert-nucleotide-classification",
    pipeline_class=BERTNucleotideClassificationPipeline,
    pt_model=BertForSequenceClassification,
)