File size: 3,538 Bytes
91aa26b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from typing import Any, Optional

import torch
from transformers import BertForSequenceClassification, Pipeline
from transformers.pipelines import PIPELINE_REGISTRY
from transformers.utils.generic import ModelOutput


def process_label(p: int) -> str:
	"""Map a predicted class id to its human-readable label.

	Args:
		p: Integer class id produced by ``argmax`` over the model logits.
			(The previous ``str`` annotation was wrong: the value is always
			compared against ints and supplied as an int.)

	Returns:
		``"EXON"`` for 0, ``"INTRON"`` for 1, ``"UNKNOWN"`` otherwise.
	"""
	# Dict lookup replaces the if-chain; .get() supplies the fallback.
	return {0: "EXON", 1: "INTRON"}.get(p, "UNKNOWN")

def ensure_optional_str(value: Any) -> str:
	"""Return *value* unchanged when it is a str; otherwise return ``""``.

	Used to coerce optional/absent context fields (e.g. ``None``) into a
	safe empty string before prompt construction.
	"""
	if isinstance(value, str):
		return value
	return ""

class DNABERT2NucleotideClassificationPipeline(Pipeline):
	"""Pipeline that classifies a nucleotide sequence (with optional flanking
	context) into a single label via a BERT sequence-classification head.

	Input may be a raw sequence string or a dict with keys ``sequence`` and
	optionally ``before`` / ``after`` (flanking context). Output is the label
	string produced by :func:`process_label`.
	"""

	def _build_prompt(
		self,
		sequence: str,
		before: str,
		after: str
	) -> str:
		"""Join the before-context, sequence, and after-context with [SEP]."""
		return (
			f"{before}[SEP]"
			f"{sequence}[SEP]"
			f"{after}"
		)

	def _sanitize_parameters(
		self,
		**kwargs
	):
		"""Split caller kwargs into (preprocess, forward, postprocess) dicts.

		``before``/``after``/``max_length`` go to preprocess; everything else
		is forwarded verbatim to the model call in ``_forward``.
		"""
		preprocess_kwargs = {
			k: kwargs[k]
			for k in ("before", "after", "max_length")
			if k in kwargs
		}
		forward_kwargs = {
			k: v for k, v in kwargs.items()
			if k not in preprocess_kwargs
		}
		return preprocess_kwargs, forward_kwargs, {}

	def preprocess(
		self,
		input_,
		**preprocess_parameters
	):
		"""Tokenize the prompt built from the sequence and its context.

		Args:
			input_: Either the sequence string itself, or a dict with a
				``sequence`` key (and optional ``before``/``after`` keys).

		Returns:
			dict with the raw ``prompt`` string and the tokenized ``inputs``
			already moved to the model device.

		Raises:
			TypeError: if ``input_`` is neither str nor dict, or if
				``max_length`` is not an int.
		"""
		assert self.tokenizer  # pipeline must be constructed with a tokenizer

		if isinstance(input_, str):
			sequence = input_
		elif isinstance(input_, dict):
			sequence = input_.get("sequence", "")
		else:
			raise TypeError("input_ must be str or dict with 'sequence' key")

		# Explicit kwargs win; dict input fields are the fallback.
		before_raw = preprocess_parameters.get("before", None)
		after_raw = preprocess_parameters.get("after", None)
		if before_raw is None and isinstance(input_, dict):
			before_raw = input_.get("before", None)
		if after_raw is None and isinstance(input_, dict):
			after_raw = input_.get("after", None)

		# Missing/non-str context collapses to "" so the prompt stays valid.
		before: str = ensure_optional_str(before_raw)
		after: str = ensure_optional_str(after_raw)

		max_length = preprocess_parameters.get("max_length", 256)
		if not isinstance(max_length, int):
			raise TypeError("max_length must be an int")

		prompt = self._build_prompt(sequence, before=before, after=after)

		enc = self.tokenizer(
			prompt,
			return_tensors="pt",
			max_length=max_length,
			truncation=True,
		).to(self.model.device)

		return {"prompt": prompt, "inputs": enc}

	def _forward(self, input_tensors: dict, **forward_params):
		"""Run the model on the tokenized inputs and return the argmax id.

		Accepts either a BatchEncoding/mapping of tensors or a bare
		``input_ids`` tensor/sequence under the ``inputs`` key.

		Raises:
			ValueError: if ``input_tensors`` has no ``inputs`` entry.
		"""
		inputs = input_tensors.get("inputs")
		if inputs is None:
			raise ValueError("Model inputs missing in input_tensors (expected key 'inputs').")

		# Normalize to a kwargs dict of tensors on the model device. The
		# original wrapped the dict comprehension in a try/except whose
		# handler re-ran the identical logic item by item — dead code, removed.
		if hasattr(inputs, "items") and not isinstance(inputs, torch.Tensor):
			# Mapping case (e.g. BatchEncoding). The .to() is a no-op when
			# preprocess() already moved the encoding to the model device.
			expanded_inputs: dict[str, torch.Tensor] = {
				k: v.to(self.model.device) if isinstance(v, torch.Tensor) else v
				for k, v in dict(inputs).items()
			}
		elif isinstance(inputs, torch.Tensor):
			expanded_inputs = {"input_ids": inputs.to(self.model.device)}
		else:
			expanded_inputs = {"input_ids": torch.tensor(inputs, device=self.model.device)}

		self.model.eval()
		with torch.no_grad():
			outputs = self.model(**expanded_inputs, **forward_params)

		# NOTE(review): .item() assumes a single sequence per call (batch
		# size 1) — confirm before feeding batched encodings.
		pred_id = torch.argmax(outputs.logits, dim=-1).item()

		return ModelOutput({"pred_id": pred_id})

	def postprocess(self, model_outputs: dict, **kwargs):
		"""Convert the predicted class id into its string label."""
		# The original asserted self.tokenizer here but never used it; the
		# unused guard has been dropped.
		return process_label(model_outputs["pred_id"])

# Module-level side effect: registers this task name with the transformers
# pipeline registry on import, so pipeline("dnabert2-nucleotide-classification", ...)
# resolves to the class above backed by a BertForSequenceClassification model.
PIPELINE_REGISTRY.register_pipeline(
	"dnabert2-nucleotide-classification",
	pipeline_class=DNABERT2NucleotideClassificationPipeline,
	pt_model=BertForSequenceClassification,
)