# ExInBERT / bert_exon_intron_classification.py
# Uploaded via huggingface_hub by GustavoHCruz (revision 5eb484a, verified).
from typing import Any, Optional
import torch
from transformers import BertForSequenceClassification, Pipeline
from transformers.pipelines import PIPELINE_REGISTRY
from transformers.utils.generic import ModelOutput
# IUPAC nucleotide codes (canonical bases A/C/G/T plus ambiguity codes)
# mapped to the bracketed special tokens expected by the tokenizer.
DNA_MAP = {base: f"[{base}]" for base in "ACGTRYSWKMBDHVN"}
def process_sequence(seq: str) -> str:
    """Convert a raw DNA string into bracketed tokenizer tokens.

    Surrounding whitespace is stripped and the sequence is upper-cased;
    any character not found in ``DNA_MAP`` falls back to the unknown
    base token ``"[N]"``.
    """
    cleaned = seq.strip().upper()
    pieces = [DNA_MAP.get(base, "[N]") for base in cleaned]
    return "".join(pieces)
def process_label(p: int) -> str:
    """Map a predicted class id to its string label.

    Returns ``"EXON"`` for class id 0 and ``"INTRON"`` for anything else.

    Note: the original annotation declared ``p`` as ``str``, but the only
    caller (``postprocess``) passes the integer argmax of the model logits
    and the body compares against the int ``0`` — the annotation is
    corrected to ``int``.
    """
    return "EXON" if p == 0 else "INTRON"
def ensure_optional_str(value: Any) -> Optional[str]:
    """Return *value* unchanged when it is a ``str``; otherwise ``None``."""
    if isinstance(value, str):
        return value
    return None
class BERTExonIntronClassificationPipeline(Pipeline):
    """Binary exon/intron classifier pipeline for DNA sequences.

    Accepts either a raw DNA string or a dict with a ``"sequence"`` key
    plus optional ``"organism"``, ``"gene"``, ``"before"`` and ``"after"``
    fields.  Builds a tagged prompt, tokenizes it, runs a
    ``BertForSequenceClassification`` model, and returns the label
    ``"EXON"`` or ``"INTRON"``.
    """

    def _build_prompt(
        self,
        sequence: str,
        organism: Optional[str],
        gene: Optional[str],
        before: Optional[str],
        after: Optional[str],
    ) -> str:
        """Assemble the tagged prompt string fed to the tokenizer.

        The sequence is capped at 256 bases, organism/gene names at 10
        characters, and each flank at 25 bases.  Optional fields are
        omitted entirely when falsy (None or empty string).
        """
        out = f"<|SEQUENCE|>{process_sequence(sequence[:256])}"
        if organism:
            out += f"<|ORGANISM|>{organism[:10]}"
        if gene:
            out += f"<|GENE|>{gene[:10]}"
        if before:
            out += f"<|FLANK_BEFORE|>{process_sequence(before[:25])}"
        if after:
            out += f"<|FLANK_AFTER|>{process_sequence(after[:25])}"
        return out

    def _sanitize_parameters(self, **kwargs):
        """Split caller kwargs into (preprocess, forward, postprocess) dicts.

        Known preprocessing keys are routed to ``preprocess``; every
        remaining kwarg is forwarded verbatim to the model call.
        """
        preprocess_kwargs = {
            k: kwargs[k]
            for k in ("organism", "gene", "before", "after", "max_length")
            if k in kwargs
        }
        forward_kwargs = {
            k: v for k, v in kwargs.items() if k not in preprocess_kwargs
        }
        return preprocess_kwargs, forward_kwargs, {}

    def preprocess(self, input_, **preprocess_parameters):
        """Build and tokenize the prompt for a single input.

        ``input_`` may be a plain DNA string or a dict holding
        ``"sequence"`` plus the optional metadata fields.  Explicit
        keyword parameters take precedence over fields embedded in a
        dict input; non-string values are silently dropped.

        Returns a dict with the rendered ``"prompt"`` and the tokenized
        ``"inputs"`` already moved to the model device.

        Raises ``TypeError`` when ``input_`` is neither str nor dict, or
        when ``max_length`` is not an int.
        """
        assert self.tokenizer
        if isinstance(input_, str):
            sequence = input_
            meta: dict = {}
        elif isinstance(input_, dict):
            sequence = input_.get("sequence", "")
            meta = input_
        else:
            raise TypeError("input_ must be str or dict with 'sequence' key")

        def _field(name: str) -> Optional[str]:
            # Keyword argument wins; fall back to the dict input's field.
            value = preprocess_parameters.get(name)
            if value is None:
                value = meta.get(name)
            return ensure_optional_str(value)

        organism = _field("organism")
        gene = _field("gene")
        before = _field("before")
        after = _field("after")

        max_length = preprocess_parameters.get("max_length", 256)
        if not isinstance(max_length, int):
            raise TypeError("max_length must be an int")

        prompt = self._build_prompt(sequence, organism, gene, before, after)
        enc = self.tokenizer(
            prompt,
            return_tensors="pt",
            max_length=max_length,
            truncation=True,
        ).to(self.model.device)
        return {"prompt": prompt, "inputs": enc}

    def _forward(self, input_tensors: dict, **forward_params):
        """Run the classifier and return the argmax class id.

        Accepts the dict produced by ``preprocess`` (key ``"inputs"``
        holding a BatchEncoding/mapping, a tensor, or tensor-like data)
        and returns a ``ModelOutput`` with the integer ``"pred_id"``.

        Raises ``ValueError`` when the ``"inputs"`` key is missing.
        """
        assert isinstance(self.model, BertForSequenceClassification)
        inputs = input_tensors.get("inputs")
        if inputs is None:
            raise ValueError("Model inputs missing in input_tensors (expected key 'inputs').")
        device = self.model.device
        if hasattr(inputs, "items") and not isinstance(inputs, torch.Tensor):
            # Mapping (e.g. BatchEncoding): move every tensor field to the
            # model device.  The original wrapped this comprehension in a
            # try/except whose fallback re-implemented the exact same loop;
            # that dead duplication is removed.
            expanded_inputs = {
                k: v.to(device) if isinstance(v, torch.Tensor) else v
                for k, v in dict(inputs).items()
            }
        elif isinstance(inputs, torch.Tensor):
            expanded_inputs = {"input_ids": inputs.to(device)}
        else:
            expanded_inputs = {"input_ids": torch.tensor(inputs, device=device)}
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(**expanded_inputs, **forward_params)
        pred_id = torch.argmax(outputs.logits, dim=-1).item()
        return ModelOutput({"pred_id": pred_id})

    def postprocess(self, model_outputs: dict, **kwargs):
        """Convert the predicted class id to its string label.

        (The original also asserted ``self.tokenizer`` here but never
        used it; the pointless assert is removed.)
        """
        return process_label(model_outputs["pred_id"])
# Register the custom task so transformers.pipeline(...) can resolve
# "bert-exon-intron-classification" to this class for PyTorch
# BertForSequenceClassification checkpoints.
PIPELINE_REGISTRY.register_pipeline(
"bert-exon-intron-classification",
pipeline_class=BERTExonIntronClassificationPipeline,
pt_model=BertForSequenceClassification,
)