# NuclDNABERT2 / dnabert2_nucleotide_classification.py
# Uploaded by GustavoHCruz via huggingface_hub (commit 91aa26b, verified).
from typing import Any, Optional
import torch
from transformers import BertForSequenceClassification, Pipeline
from transformers.pipelines import PIPELINE_REGISTRY
from transformers.utils.generic import ModelOutput
def process_label(p: int) -> str:
	"""Map a predicted class id to its human-readable label.

	Args:
		p: Class index produced by the model head (``argmax(...).item()``),
			so an ``int`` — the previous ``str`` annotation was wrong.

	Returns:
		``"EXON"`` for 0, ``"INTRON"`` for 1, ``"UNKNOWN"`` otherwise.
	"""
	if p == 0:
		return "EXON"
	if p == 1:
		return "INTRON"
	return "UNKNOWN"
def ensure_optional_str(value: Any) -> str:
	"""Coerce *value* to a string: pass strings through, anything else becomes ""."""
	if isinstance(value, str):
		return value
	return ""
class DNABERT2NucleotideClassificationPipeline(Pipeline):
	"""Sequence-classification pipeline for DNABERT-2 nucleotide inputs.

	Builds a ``before[SEP]sequence[SEP]after`` prompt, tokenizes it, runs the
	underlying classification model, and maps the argmax class id to a label
	string via :func:`process_label` ("EXON"/"INTRON"/"UNKNOWN").
	"""

	def _build_prompt(
		self,
		sequence: str,
		before: str,
		after: str
	) -> str:
		"""Join the flanking context and the target sequence with [SEP] markers."""
		return (
			f"{before}[SEP]"
			f"{sequence}[SEP]"
			f"{after}"
		)

	def _sanitize_parameters(
		self,
		**kwargs
	):
		"""Split call-time kwargs into (preprocess, forward, postprocess) dicts.

		``before``/``after``/``max_length`` are routed to preprocessing;
		every other kwarg is passed through to the model forward call.
		"""
		preprocess_kwargs = {}
		for k in ("before", "after", "max_length"):
			if k in kwargs:
				preprocess_kwargs[k] = kwargs[k]
		forward_kwargs = {
			k: v for k, v in kwargs.items()
			if k not in preprocess_kwargs
		}
		postprocess_kwargs = {}
		return preprocess_kwargs, forward_kwargs, postprocess_kwargs

	def preprocess(
		self,
		input_,
		**preprocess_parameters
	):
		"""Tokenize the prompt built from *input_* plus optional context.

		Args:
			input_: The raw sequence string, or a dict with a ``sequence``
				key (and optionally ``before``/``after`` context keys).
			**preprocess_parameters: ``before``, ``after`` (context strings,
				call-time values take precedence over dict keys) and
				``max_length`` (tokenizer truncation length, default 256).

		Returns:
			dict with the assembled ``prompt`` and the tokenized ``inputs``.

		Raises:
			ValueError: if no tokenizer is attached to the pipeline.
			TypeError: if *input_* is neither str nor dict, or ``max_length``
				is not an int.
		"""
		# Explicit check instead of `assert`, which `python -O` strips.
		if self.tokenizer is None:
			raise ValueError("A tokenizer is required for preprocessing.")
		if isinstance(input_, str):
			sequence = input_
		elif isinstance(input_, dict):
			sequence = input_.get("sequence", "")
		else:
			raise TypeError("input_ must be str or dict with 'sequence' key")
		# Call-time kwargs win; fall back to keys embedded in the input dict.
		before_raw = preprocess_parameters.get("before", None)
		after_raw = preprocess_parameters.get("after", None)
		if isinstance(input_, dict):
			if before_raw is None:
				before_raw = input_.get("before", None)
			if after_raw is None:
				after_raw = input_.get("after", None)
		before: str = ensure_optional_str(before_raw)
		after: str = ensure_optional_str(after_raw)
		max_length = preprocess_parameters.get("max_length", 256)
		if not isinstance(max_length, int):
			raise TypeError("max_length must be an int")
		prompt = self._build_prompt(sequence, before=before, after=after)
		enc = self.tokenizer(
			prompt,
			return_tensors="pt",
			max_length=max_length,
			truncation=True,
		).to(self.model.device)
		return {"prompt": prompt, "inputs": enc}

	def _forward(self, input_tensors: dict, **forward_params):
		"""Run the model on tokenized inputs and return the argmax class id.

		Accepts a BatchEncoding/mapping, a raw ``input_ids`` tensor, or any
		sequence convertible to one; tensors are moved to the model device.
		"""
		inputs = input_tensors.get("inputs")
		if inputs is None:
			raise ValueError("Model inputs missing in input_tensors (expected key 'inputs').")
		if hasattr(inputs, "items") and not isinstance(inputs, torch.Tensor):
			# Mapping case (e.g. BatchEncoding): move every tensor to the
			# model device.  The original wrapped this in a try/except whose
			# fallback re-ran the exact same logic — dead code, removed.
			expanded_inputs: dict[str, torch.Tensor] = {
				k: v.to(self.model.device) if isinstance(v, torch.Tensor) else v
				for k, v in dict(inputs).items()
			}
		elif isinstance(inputs, torch.Tensor):
			expanded_inputs = {"input_ids": inputs.to(self.model.device)}
		else:
			expanded_inputs = {"input_ids": torch.tensor(inputs, device=self.model.device)}
		self.model.eval()
		with torch.no_grad():
			outputs = self.model(**expanded_inputs, **forward_params)
		# One prompt per call, so logits has batch size 1 and .item() is safe.
		pred_id = torch.argmax(outputs.logits, dim=-1).item()
		return ModelOutput({"pred_id": pred_id})

	def postprocess(self, model_outputs: dict, **kwargs):
		"""Map the predicted class id to its label string."""
		# The tokenizer is not needed here; the stale assert was dropped.
		return process_label(model_outputs["pred_id"])
# Register the custom task so this pipeline can be instantiated with
# `pipeline("dnabert2-nucleotide-classification", ...)`; the PyTorch model
# class used for auto-loading is BertForSequenceClassification.
PIPELINE_REGISTRY.register_pipeline(
	"dnabert2-nucleotide-classification",
	pipeline_class=DNABERT2NucleotideClassificationPipeline,
	pt_model=BertForSequenceClassification,
)