from transformers.processing_utils import ProcessorMixin from transformers.tokenization_utils_base import BatchEncoding class PDeepPPProcessor(ProcessorMixin): def __init__(self, pad_char="X", target_length=33): self.pad_char = pad_char self.target_length = target_length def pad_sequence(self, seq): """确保序列长度为 target_length,不足的部分用 pad_char 在两侧均匀填充""" if len(seq) < self.target_length: total_padding = self.target_length - len(seq) left_padding = total_padding // 2 right_padding = total_padding - left_padding seq = self.pad_char * left_padding + seq + self.pad_char * right_padding return seq[:self.target_length] def extract_ptm_sequences(self, sequences): """处理 PTM 数据,确保目标氨基酸(S、T、Y)位于序列中心""" ptm_data = [] for seq in sequences: for i in range(len(seq)): if seq[i] in {'S', 'T', 'Y'}: # 仅提取 S、T、Y 作为中心的片段 start = max(0, i - self.target_length // 2) end = min(len(seq), start + self.target_length) padded_seq = self.pad_sequence(seq[start:end]) ptm_data.append(padded_seq) return ptm_data def extract_bps_sequences(self, sequences, overlapping=True, step_size=5): """处理生物活性数据(BPS),关注整个序列,可重叠""" bioactive_data = [] for seq in sequences: if len(seq) < self.target_length: # 如果序列长度不足,直接填充到 target_length padded_seq = self.pad_sequence(seq) bioactive_data.append(padded_seq) else: # 如果序列长度足够,按照滑动窗口提取片段 for i in range(0, len(seq) - self.target_length + 1, step_size if overlapping else self.target_length): bioactive_data.append(self.pad_sequence(seq[i:i + self.target_length])) return bioactive_data def __call__( self, sequences, mode, # 去除默认值,强制外部传入 overlapping=True, step_size=5, **kwargs ): """ 预处理蛋白质序列,仅处理数据到指定长度。 Args: sequences: 序列列表或单个序列字符串。 mode: 选择处理模式,必须从外部传入,"PTM" 或 "BPS"。 overlapping: BPS 模式下是否使用重叠窗口。 step_size: BPS 模式下的步长。 """ # 确保 sequences 是列表 if isinstance(sequences, str): sequences = [sequences] # 根据模式提取序列 if mode == "PTM": processed_sequences = self.extract_ptm_sequences(sequences) elif mode == "BPS": processed_sequences = self.extract_bps_sequences( sequences, overlapping=overlapping, step_size=step_size ) else: raise ValueError("Invalid mode. Please choose 'PTM' or 'BPS'.") if len(processed_sequences) == 0: raise ValueError("No sequences processed. Check input data and processing logic.") # 创建返回字典,仅包含预处理后的序列 model_inputs = { "raw_sequences": processed_sequences, # 预处理后的序列 } return BatchEncoding(data=model_inputs) # 返回处理后的数据