File size: 4,595 Bytes
dd192e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
778e7ea
dd192e9
778e7ea
 
dd192e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
from typing import Any, Optional

import torch
from transformers import BertForSequenceClassification, Pipeline
from transformers.pipelines import PIPELINE_REGISTRY
from transformers.utils.generic import ModelOutput

# Mapping from single-letter nucleotide codes to the model's special DNA
# tokens.  A/C/G/T are the standard bases; R..N are presumably IUPAC
# ambiguity codes — confirm against the tokenizer's vocabulary.  "I" and
# "E" map to intron/exon context markers, and "U" is an explicit unknown.
# Characters absent from this map are rendered as "[DNA_INVALID]" by
# process_sequence().
NUCLEOTIDE_MAP = {
	"A": "[DNA_A]",
	"C": "[DNA_C]",
	"G": "[DNA_G]",
	"T": "[DNA_T]",
	"R": "[DNA_R]",
	"Y": "[DNA_Y]",
	"S": "[DNA_S]",
	"W": "[DNA_W]",
	"K": "[DNA_K]",
	"M": "[DNA_M]",
	"B": "[DNA_B]",
	"D": "[DNA_D]",
	"H": "[DNA_H]",
	"V": "[DNA_V]",
	"N": "[DNA_N]",
	"I": "[INTRON]",
	"E": "[EXON]",
	"U": "[DNA_UNKNOWN]"
}

def process_sequence(seq: str) -> str:
	"""Render a raw nucleotide string as a concatenation of special DNA tokens.

	The input is stripped of surrounding whitespace and upper-cased; each
	character is then looked up in NUCLEOTIDE_MAP, with "[DNA_INVALID]"
	substituted for characters the map does not know.
	"""
	cleaned = seq.strip().upper()
	tokens = (NUCLEOTIDE_MAP.get(char, "[DNA_INVALID]") for char in cleaned)
	return "".join(tokens)

def process_label(p: int) -> str:
	"""Map a predicted class id to its human-readable label.

	Fix: the parameter was annotated ``str`` but is compared against the
	ints 0/1, and the caller (``postprocess``) passes the integer class id
	from ``argmax(...).item()`` — the annotation is corrected to ``int``.

	Returns:
		"EXON" for 0, "INTRON" for 1, "UNKNOWN" for anything else.
	"""
	if p == 0:
		return "EXON"
	if p == 1:
		return "INTRON"
	return "UNKNOWN"

def ensure_optional_str(value: Any) -> str:
	"""Coerce an arbitrary value to a string.

	Returns the value unchanged when it is already a ``str``; any other
	value (including ``None``) collapses to the empty string.
	"""
	if isinstance(value, str):
		return value
	return ""

class BERTNucleotideClassificationPipeline(Pipeline):
	"""Hugging Face pipeline classifying a nucleotide position as EXON/INTRON.

	The input is rendered into a prompt of special DNA tokens (see
	NUCLEOTIDE_MAP), tokenized, and fed to a BertForSequenceClassification
	model; the argmax class id is mapped to a label via process_label().
	"""

	def _build_prompt(
		self,
		sequence: str,
		before: str,
		after: str,
		organism: Optional[str]
	) -> str:
		"""Assemble the special-token prompt string for the tokenizer.

		NOTE(review): only ``sequence[0]`` — the first nucleotide — is
		encoded into the prompt; presumably the pipeline classifies a single
		position with its flanking context.  Confirm this is intentional for
		multi-character sequences.  Raises IndexError on an empty sequence.
		"""
		out = f"<|SEQUENCE|>{process_sequence(sequence[0])}"

		# Flanking context is capped at 24 characters on each side.
		out += f"<|FLANK_BEFORE|>{process_sequence(before[:24])}"
		out += f"<|FLANK_AFTER|>{process_sequence(after[:24])}"

		# Organism tag is optional; truncated to 10 chars and lower-cased.
		if organism:
			out += f"<|ORGANISM|>{organism[:10].lower()}"

		out += "<|TARGET|>"
		return out

	def _sanitize_parameters(
		self,
		**kwargs
	):
		"""Split call-time kwargs into (preprocess, forward, postprocess) dicts.

		``organism``, ``before``, ``after`` and ``max_length`` are routed to
		``preprocess``; every remaining kwarg is forwarded verbatim to the
		model call in ``_forward``.  ``postprocess`` takes no parameters.
		"""
		preprocess_kwargs = {
			k: kwargs[k]
			for k in ("organism", "before", "after", "max_length")
			if k in kwargs
		}
		forward_kwargs = {
			k: v for k, v in kwargs.items()
			if k not in preprocess_kwargs
		}
		return preprocess_kwargs, forward_kwargs, {}

	def preprocess(
		self,
		input_,
		**preprocess_parameters
	):
		"""Build tokenized model inputs from a raw string or dict input.

		Args:
			input_: either the sequence itself (str) or a dict with a
				"sequence" key and optional "organism"/"before"/"after" keys.
			**preprocess_parameters: call-time overrides for
				organism/before/after, plus ``max_length`` (default 256).
				Call-time values take precedence over dict keys.

		Returns:
			dict with the rendered "prompt" and the tokenized "inputs"
			(moved to the model's device).

		Raises:
			TypeError: if ``input_`` is neither str nor dict, or
				``max_length`` is not an int.
		"""
		assert self.tokenizer

		if isinstance(input_, str):
			sequence = input_
			extras: dict = {}
		elif isinstance(input_, dict):
			sequence = input_.get("sequence", "")
			extras = input_
		else:
			raise TypeError("input_ must be str or dict with 'sequence' key")

		def _param(key: str):
			# Call-time kwarg wins; fall back to a key on a dict input.
			value = preprocess_parameters.get(key, None)
			if value is None:
				value = extras.get(key, None)
			return value

		before: str = ensure_optional_str(_param("before"))
		after: str = ensure_optional_str(_param("after"))
		# Fix: annotated Optional[str] before, but ensure_optional_str()
		# always returns str ("" when absent), which `if organism:` relies on.
		organism: str = ensure_optional_str(_param("organism"))

		max_length = preprocess_parameters.get("max_length", 256)
		if not isinstance(max_length, int):
			raise TypeError("max_length must be an int")

		prompt = self._build_prompt(sequence, before=before, after=after, organism=organism)

		enc = self.tokenizer(
			prompt,
			return_tensors="pt",
			max_length=max_length,
			truncation=True,
		).to(self.model.device)

		return {"prompt": prompt, "inputs": enc}

	def _forward(self, input_tensors: dict, **forward_params):
		"""Run the classifier under no_grad and return the argmax class id.

		Accepts a mapping of tensors (e.g. a BatchEncoding), a bare tensor
		(treated as input_ids), or any tensor-convertible sequence.
		"""
		assert isinstance(self.model, BertForSequenceClassification)

		inputs = input_tensors.get("inputs")
		if inputs is None:
			raise ValueError("Model inputs missing in input_tensors (expected key 'inputs').")

		device = self.model.device
		if hasattr(inputs, "items") and not isinstance(inputs, torch.Tensor):
			# Fix: the original wrapped this comprehension in a try/except
			# whose fallback re-ran the exact same logic item by item — dead
			# code with no behavioral difference, removed.
			model_inputs: dict[str, torch.Tensor] = {
				k: v.to(device) if isinstance(v, torch.Tensor) else v
				for k, v in dict(inputs).items()
			}
		elif isinstance(inputs, torch.Tensor):
			model_inputs = {"input_ids": inputs.to(device)}
		else:
			model_inputs = {"input_ids": torch.tensor(inputs, device=device)}

		self.model.eval()
		with torch.no_grad():
			outputs = self.model(**model_inputs, **forward_params)

		pred_id = torch.argmax(outputs.logits, dim=-1).item()
		return ModelOutput({"pred_id": pred_id})

	def postprocess(self, model_outputs: dict, **kwargs):
		"""Map the predicted class id to its label string (EXON/INTRON/UNKNOWN)."""
		assert self.tokenizer

		pred_id = model_outputs["pred_id"]
		return process_label(pred_id)

# Register the custom task so `pipeline("bert-nucleotide-classification", ...)`
# resolves to this class for PyTorch BERT sequence-classification models.
# Runs at import time as a module-level side effect.
PIPELINE_REGISTRY.register_pipeline(
	"bert-nucleotide-classification",
	pipeline_class=BERTNucleotideClassificationPipeline,
	pt_model=BertForSequenceClassification,
)