# NuclBERT / bert_nucleotide_classification.py
# Author: GustavoHCruz — uploaded via huggingface_hub (revision 778e7ea, verified)
from typing import Any, Optional
import torch
from transformers import BertForSequenceClassification, Pipeline
from transformers.pipelines import PIPELINE_REGISTRY
from transformers.utils.generic import ModelOutput
# Maps one-letter sequence codes to the model's special tokens.
# The 15 IUPAC nucleotide codes (including ambiguity codes) all follow the
# regular "[DNA_X]" pattern; the remaining three symbols are special-cased.
NUCLEOTIDE_MAP = {
    **{code: f"[DNA_{code}]" for code in "ACGTRYSWKMBDHVN"},
    "I": "[INTRON]",
    "E": "[EXON]",
    "U": "[DNA_UNKNOWN]",
}
def process_sequence(seq: str) -> str:
    """Convert a raw sequence string into a run of special tokens.

    Whitespace is stripped and the sequence is upper-cased before lookup;
    any character not present in NUCLEOTIDE_MAP becomes "[DNA_INVALID]".
    """
    normalized = seq.strip().upper()
    tokens = [NUCLEOTIDE_MAP.get(ch, "[DNA_INVALID]") for ch in normalized]
    return "".join(tokens)
def process_label(p: int) -> str:
    """Map a predicted class id to its human-readable label.

    Args:
        p: Integer class id produced by ``torch.argmax(logits).item()``
           in the pipeline's ``_forward`` step.

    Returns:
        ``"EXON"`` for 0, ``"INTRON"`` for 1, ``"UNKNOWN"`` for anything else.
    """
    # Fix: the parameter was annotated as ``str`` but is compared against
    # ints and always receives an int from the caller; annotate as ``int``.
    if p == 0:
        return "EXON"
    if p == 1:
        return "INTRON"
    return "UNKNOWN"
def ensure_optional_str(value: Any) -> str:
    """Return ``value`` if it is a string; otherwise return the empty string."""
    if isinstance(value, str):
        return value
    return ""
class BERTNucleotideClassificationPipeline(Pipeline):
    """Custom transformers pipeline for exon/intron nucleotide classification.

    Inputs may be a raw sequence string or a dict with a ``"sequence"`` key
    and optional ``"before"``, ``"after"`` and ``"organism"`` keys. The input
    is rendered as a special-token prompt (see ``_build_prompt``), tokenized,
    classified by a ``BertForSequenceClassification`` model, and the argmax
    class id is mapped to ``"EXON"`` / ``"INTRON"`` / ``"UNKNOWN"``.
    """

    def _build_prompt(
        self,
        sequence: str,
        before: str,
        after: str,
        organism: Optional[str],
    ) -> str:
        """Assemble the special-token prompt fed to the tokenizer.

        NOTE(review): only the first character of ``sequence`` is encoded
        (``sequence[0]``). Given the 24-char flanks this looks like a
        deliberate single-target-nucleotide format, so it is preserved —
        but confirm against the training data layout.
        """
        out = f"<|SEQUENCE|>{process_sequence(sequence[0])}"
        # Flanking context is capped at 24 characters on each side.
        out += f"<|FLANK_BEFORE|>{process_sequence(before[:24])}"
        out += f"<|FLANK_AFTER|>{process_sequence(after[:24])}"
        if organism:
            # Organism tag: at most 10 chars, lower-cased.
            out += f"<|ORGANISM|>{organism[:10].lower()}"
        out += "<|TARGET|>"
        return out

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into (preprocess, forward, postprocess) dicts."""
        preprocess_kwargs = {
            key: kwargs[key]
            for key in ("organism", "before", "after", "max_length")
            if key in kwargs
        }
        # Everything not consumed by preprocessing is forwarded to the model.
        forward_kwargs = {
            key: value for key, value in kwargs.items()
            if key not in preprocess_kwargs
        }
        return preprocess_kwargs, forward_kwargs, {}

    def preprocess(self, input_, **preprocess_parameters):
        """Build the prompt for one input and tokenize it.

        Args:
            input_: Raw sequence string, or dict with ``"sequence"`` plus
                optional ``"before"``/``"after"``/``"organism"`` fields.
            **preprocess_parameters: ``organism``, ``before``, ``after``
                (call-time values take precedence over dict fields) and
                ``max_length`` (int, default 256).

        Returns:
            Dict with the rendered ``"prompt"`` and tokenized ``"inputs"``.

        Raises:
            TypeError: If ``input_`` is neither str nor dict, or
                ``max_length`` is not an int.
        """
        assert self.tokenizer
        if isinstance(input_, str):
            sequence = input_
            extras: dict = {}
        elif isinstance(input_, dict):
            sequence = input_.get("sequence", "")
            extras = input_
        else:
            raise TypeError("input_ must be str or dict with 'sequence' key")
        # Call-time kwargs win; fall back to per-item dict fields when None.
        organism_raw = preprocess_parameters.get("organism", None)
        if organism_raw is None:
            organism_raw = extras.get("organism", None)
        before_raw = preprocess_parameters.get("before", None)
        if before_raw is None:
            before_raw = extras.get("before", None)
        after_raw = preprocess_parameters.get("after", None)
        if after_raw is None:
            after_raw = extras.get("after", None)
        before: str = ensure_optional_str(before_raw)
        after: str = ensure_optional_str(after_raw)
        # Fix: ensure_optional_str always returns str, so annotate as str
        # (the original Optional[str] annotation was inaccurate).
        organism: str = ensure_optional_str(organism_raw)
        max_length = preprocess_parameters.get("max_length", 256)
        if not isinstance(max_length, int):
            raise TypeError("max_length must be an int")
        prompt = self._build_prompt(
            sequence, before=before, after=after, organism=organism
        )
        enc = self.tokenizer(
            prompt,
            return_tensors="pt",
            max_length=max_length,
            truncation=True,
        ).to(self.model.device)
        return {"prompt": prompt, "inputs": enc}

    def _forward(self, input_tensors: dict, **forward_params):
        """Run the classifier and return the argmax class id.

        Fix: the original wrapped the device-move comprehension in a
        try/except whose fallback performed the exact same operations
        (and would therefore re-raise identically); the redundant branch
        has been removed.
        """
        assert isinstance(self.model, BertForSequenceClassification)
        inputs = input_tensors.get("inputs")
        if inputs is None:
            raise ValueError("Model inputs missing in input_tensors (expected key 'inputs').")
        if hasattr(inputs, "items") and not isinstance(inputs, torch.Tensor):
            # BatchEncoding / mapping of tensors: move each tensor to the
            # model's device, passing non-tensor values through unchanged.
            expanded_inputs: dict[str, torch.Tensor] = {
                k: v.to(self.model.device) if isinstance(v, torch.Tensor) else v
                for k, v in dict(inputs).items()
            }
        elif isinstance(inputs, torch.Tensor):
            expanded_inputs = {"input_ids": inputs.to(self.model.device)}
        else:
            expanded_inputs = {"input_ids": torch.tensor(inputs, device=self.model.device)}
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(**expanded_inputs, **forward_params)
        # Single-item pipeline: batch dim is 1, so .item() yields one class id.
        pred_id = torch.argmax(outputs.logits, dim=-1).item()
        return ModelOutput({"pred_id": pred_id})

    def postprocess(self, model_outputs: dict, **kwargs):
        """Map the predicted class id to its label string."""
        assert self.tokenizer
        pred_id = model_outputs["pred_id"]
        return process_label(pred_id)
# Register the custom pipeline under the task name
# "bert-nucleotide-classification" so it can be constructed via
# transformers.pipeline(...) with a PyTorch BertForSequenceClassification model.
PIPELINE_REGISTRY.register_pipeline(
"bert-nucleotide-classification",
pipeline_class=BERTNucleotideClassificationPipeline,
pt_model=BertForSequenceClassification,
)