| from typing import Any, Optional | |
| import torch | |
| from transformers import BertForSequenceClassification, Pipeline | |
| from transformers.pipelines import PIPELINE_REGISTRY | |
| from transformers.utils.generic import ModelOutput | |
# Lookup from a single-letter nucleotide code to the bracketed token used in
# prompts. Every entry has the form ``X -> "[X]"``; the letters are the four
# bases A/C/G/T followed by what appear to be the IUPAC ambiguity codes
# (NOTE(review): confirm against the tokenizer vocabulary).
DNA_MAP: dict[str, str] = {
    code: f"[{code}]" for code in "ACGTRYSWKMBDHVN"
}
def process_sequence(seq: str) -> str:
    """Convert a nucleotide string into its bracketed-token representation.

    Surrounding whitespace is stripped and the text is upper-cased before
    each character is looked up in ``DNA_MAP``. Characters without an entry
    (including interior whitespace) are emitted as ``"[N]"``.

    Args:
        seq: Raw sequence text.

    Returns:
        The concatenation of one token per character of the normalized input.
    """
    normalized = seq.strip().upper()
    tokens = [DNA_MAP.get(base, "[N]") for base in normalized]
    return "".join(tokens)
def process_label(p: int) -> str:
    """Translate a predicted class id into its label name.

    Args:
        p: Predicted class index. The pipeline passes the integer argmax of
            the model logits, so this is an ``int`` (not a ``str``).

    Returns:
        ``"EXON"`` when ``p`` equals 0, ``"INTRON"`` for every other value.
    """
    return "EXON" if p == 0 else "INTRON"
def ensure_optional_str(value: Any) -> Optional[str]:
    """Pass strings through unchanged and replace anything else with ``None``.

    Args:
        value: Arbitrary object.

    Returns:
        ``value`` itself when it is a ``str`` (including the empty string),
        otherwise ``None``.
    """
    if not isinstance(value, str):
        return None
    return value
class BERTExonIntronClassificationPipeline(Pipeline):
    """Pipeline that labels a DNA sequence as ``"EXON"`` or ``"INTRON"``.

    An input is either a raw sequence string or a dict holding the sequence
    under ``"sequence"`` plus optional ``"organism"``, ``"gene"``,
    ``"before"`` and ``"after"`` entries. The same optional fields can be
    supplied as call keyword arguments, which take precedence over the dict.
    """

    def _build_prompt(
        self,
        sequence: str,
        organism: Optional[str],
        gene: Optional[str],
        before: Optional[str],
        after: Optional[str],
    ) -> str:
        """Assemble the tagged prompt string fed to the tokenizer.

        Args:
            sequence: Target sequence; only the first 256 characters are used.
            organism: Optional organism name, cut to 10 characters.
            gene: Optional gene name, cut to 10 characters.
            before: Optional upstream flank, cut to 25 characters.
            after: Optional downstream flank, cut to 25 characters.

        Returns:
            ``<|SEQUENCE|>...`` followed by one tagged section per optional
            field that is non-empty. Sequence-like fields are tokenized with
            ``process_sequence``; organism and gene are inserted verbatim.
        """
        parts = [f"<|SEQUENCE|>{process_sequence(sequence[:256])}"]
        if organism:
            parts.append(f"<|ORGANISM|>{organism[:10]}")
        if gene:
            parts.append(f"<|GENE|>{gene[:10]}")
        if before:
            parts.append(f"<|FLANK_BEFORE|>{process_sequence(before[:25])}")
        if after:
            parts.append(f"<|FLANK_AFTER|>{process_sequence(after[:25])}")
        return "".join(parts)

    def _sanitize_parameters(self, **kwargs):
        """Split call keyword arguments into per-stage parameter dicts.

        Returns:
            A ``(preprocess, forward, postprocess)`` tuple. The known prompt
            options and ``max_length`` go to preprocessing, every other
            keyword is forwarded to the model call, and postprocessing takes
            no parameters.
        """
        preprocess_keys = ("organism", "gene", "before", "after", "max_length")
        preprocess_kwargs = {
            key: kwargs[key] for key in preprocess_keys if key in kwargs
        }
        forward_kwargs = {
            key: value
            for key, value in kwargs.items()
            if key not in preprocess_kwargs
        }
        postprocess_kwargs = {}
        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def preprocess(self, input_, **preprocess_parameters):
        """Build the prompt for one input and tokenize it.

        Args:
            input_: A sequence string, or a dict with a ``"sequence"`` entry
                and optional ``"organism"``/``"gene"``/``"before"``/``"after"``
                entries. A dict without ``"sequence"`` is treated as an empty
                sequence.
            **preprocess_parameters: Optional overrides for the four prompt
                fields, plus ``max_length`` (default 256) for tokenization.

        Returns:
            ``{"prompt": <str>, "inputs": <tokenizer output on the model
            device>}``.

        Raises:
            TypeError: If ``input_`` is neither ``str`` nor ``dict``, if the
                sequence is not a ``str``, or if ``max_length`` is not an
                ``int``.
        """
        assert self.tokenizer

        if isinstance(input_, str):
            sequence = input_
            fields: dict = {}
        elif isinstance(input_, dict):
            sequence = input_.get("sequence", "")
            fields = input_
        else:
            raise TypeError("input_ must be str or dict with 'sequence' key")

        # Fail early and clearly: a bytes value would otherwise be encoded as
        # all "[N]" tokens, and None/list values would crash further down
        # with an unrelated error.
        if not isinstance(sequence, str):
            raise TypeError("'sequence' must be a str")

        # Explicit call parameters win; dict entries are only a fallback.
        # Non-string values are discarded by ensure_optional_str.
        optional: dict[str, Optional[str]] = {}
        for name in ("organism", "gene", "before", "after"):
            raw = preprocess_parameters.get(name)
            if raw is None:
                raw = fields.get(name)
            optional[name] = ensure_optional_str(raw)

        max_length = preprocess_parameters.get("max_length", 256)
        # bool is a subclass of int but is never a meaningful length.
        if isinstance(max_length, bool) or not isinstance(max_length, int):
            raise TypeError("max_length must be an int")

        prompt = self._build_prompt(
            sequence,
            optional["organism"],
            optional["gene"],
            optional["before"],
            optional["after"],
        )

        token_kwargs: dict[str, Any] = {
            "return_tensors": "pt",
            "max_length": max_length,
            "truncation": True,
        }
        enc = self.tokenizer(prompt, **token_kwargs).to(self.model.device)
        return {"prompt": prompt, "inputs": enc}

    def _forward(self, input_tensors: dict, **forward_params):
        """Run the classifier and return the predicted class id.

        Args:
            input_tensors: Output of ``preprocess``; must contain ``"inputs"``,
                which may be a mapping of model arguments, a tensor of input
                ids, or anything ``torch.tensor`` can convert.
            **forward_params: Extra keyword arguments passed to the model.

        Returns:
            A ``ModelOutput`` with a single ``"pred_id"`` entry (a Python
            int).

        Raises:
            ValueError: If ``"inputs"`` is missing from ``input_tensors``.
        """
        assert isinstance(self.model, BertForSequenceClassification)

        inputs = input_tensors.get("inputs")
        if inputs is None:
            raise ValueError("Model inputs missing in input_tensors (expected key 'inputs').")

        device = self.model.device
        if isinstance(inputs, torch.Tensor):
            model_inputs = {"input_ids": inputs.to(device)}
        elif hasattr(inputs, "items"):
            # Mapping-like tokenizer output: move tensors, keep the rest.
            model_inputs = {
                key: value.to(device) if isinstance(value, torch.Tensor) else value
                for key, value in inputs.items()
            }
        else:
            model_inputs = {"input_ids": torch.tensor(inputs, device=device)}

        self.model.eval()
        with torch.no_grad():
            outputs = self.model(**model_inputs, **forward_params)

        # .item() only works on a single-element tensor, so this assumes one
        # example per call (NOTE(review): batched inputs are not supported).
        pred_id = torch.argmax(outputs.logits, dim=-1).item()
        return ModelOutput({"pred_id": pred_id})

    def postprocess(self, model_outputs: dict, **kwargs):
        """Convert the ``"pred_id"`` from ``_forward`` into its label string.

        Returns:
            The value of ``process_label`` for the predicted id.
        """
        # Same tokenizer precondition as preprocess.
        assert self.tokenizer
        return process_label(model_outputs["pred_id"])
# Import-time side effect: expose the custom pipeline under the task name
# "bert-exon-intron-classification", backed by a PyTorch BERT classifier.
PIPELINE_REGISTRY.register_pipeline(
    "bert-exon-intron-classification",
    pt_model=BertForSequenceClassification,
    pipeline_class=BERTExonIntronClassificationPipeline,
)