import os
import sys

# Disable Triton-based attention kernels before transformers is imported, so the
# model falls back to the standard (non-flash) attention implementation.
os.environ["DISABLE_TRITON"] = "1"
sys.modules["triton"] = None
sys.modules["flash_attn_triton"] = None

import wandb

# Read the API key from the WANDB_API_KEY environment variable instead of
# hardcoding a secret in the source.
wandb.login()

import csv
import itertools
import json
import logging
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import sklearn.metrics
import torch
import transformers
from torch.utils.data import Dataset

# PEFT provides LoraConfig and get_peft_model for the optional LoRA path below.
from peft import LoraConfig, get_peft_model


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    trust_remote_code: bool = field(
        default=False,
        metadata={"help": "Whether to trust and execute custom code shipped with the model "
                          "(e.g., custom architectures, tokenizers, or modeling files), "
                          "whether loaded locally or from the Hub."},
    )
    use_lora: bool = field(default=False, metadata={"help": "Whether to use LoRA."})
    lora_r: int = field(default=8, metadata={"help": "Hidden dimension (rank) for LoRA."})
    lora_alpha: int = field(default=32, metadata={"help": "Alpha for LoRA."})
    lora_dropout: float = field(default=0.05, metadata={"help": "Dropout rate for LoRA."})
    lora_target_modules: str = field(default="query,value", metadata={"help": "Comma-separated module names to apply LoRA to."})
    tokenizer_path: Optional[str] = field(default="facebook/opt-125m")


@dataclass
class DataArguments:
    data_path: str = field(default=None, metadata={"help": "Path to the training data."})
    kmer: int = field(default=-1, metadata={"help": "k-mer for input sequence. -1 means not using k-mer."})
    customized_tokenizer: Optional[str] = field(default=None)


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    vocab_file: Optional[str] = field(
        default=None,
        metadata={"help": "Path to custom vocabulary file (overrides Hugging Face default)."}
    )
    cache_dir: Optional[str] = field(default=None)
    run_name: str = field(default="run")
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(default=512, metadata={"help": "Maximum sequence length."})
    gradient_accumulation_steps: int = field(default=1)
    per_device_train_batch_size: int = field(default=1)
    per_device_eval_batch_size: int = field(default=1)
    num_train_epochs: int = field(default=1)
    fp16: bool = field(default=False)
    logging_steps: int = field(default=100)
    save_steps: int = field(default=100)
    eval_steps: int = field(default=100)
    evaluation_strategy: str = field(default="steps")
    warmup_steps: int = field(default=50)
    weight_decay: float = field(default=0.01)
    learning_rate: float = field(default=1e-4)
    save_total_limit: int = field(default=3)
    load_best_model_at_end: bool = field(default=False)
    output_dir: str = field(default="output")
    find_unused_parameters: bool = field(default=False)
    checkpointing: bool = field(default=False)
    dataloader_pin_memory: bool = field(default=False)
    eval_and_save_results: bool = field(default=True)
    save_model: bool = field(default=False)
    seed: int = field(default=42)
    project_name: Optional[str] = field(default=None)


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collect the state dict and dump it to disk."""
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        # Move all tensors to CPU before saving so no extra copy lives on the GPU.
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)


| """ |
| Get the reversed complement of the original DNA sequence. |
| """ |
| def get_alter_of_dna_sequence(sequence: str): |
| MAP = {"A": "T", "T": "A", "C": "G", "G": "C"} |
| |
| return "".join([MAP[c] for c in sequence]) |
|
|
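# Illustration (not part of the training pipeline): each base is mapped to its
# complement in place, e.g. get_alter_of_dna_sequence("ATCG") -> "TAGC".
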
| """ |
| Transform a dna sequence to k-mer string |
| """ |
| def generate_kmer_str(sequence: str, k: int) -> str: |
| """Generate k-mer string from DNA sequence.""" |
| return " ".join([sequence[i:i+k] for i in range(len(sequence) - k + 1)]) |
|
|
|
|
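# Illustration: overlapping k-mers are joined with spaces, e.g.
# generate_kmer_str("ATCGAT", 3) -> "ATC TCG CGA GAT".
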
| """ |
| Load or generate k-mer string for each DNA sequence. The generated k-mer string will be saved to the same directory as the original data with the same name but with a suffix of "_{k}mer". |
| """ |
| def load_or_generate_kmer(data_path: str, texts: List[str], k: int) -> List[str]: |
| """Load or generate k-mer string for each DNA sequence.""" |
| kmer_path = data_path.replace(".csv", f"_{k}mer.json") |
| if os.path.exists(kmer_path): |
| logging.warning(f"Loading k-mer from {kmer_path}...") |
| with open(kmer_path, "r") as f: |
| kmer = json.load(f) |
| else: |
| logging.warning(f"Generating k-mer...") |
| kmer = [generate_kmer_str(text, k) for text in texts] |
| with open(kmer_path, "w") as f: |
| logging.warning(f"Saving k-mer to {kmer_path}...") |
| json.dump(kmer, f) |
| |
| return kmer |
|
|
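# Illustration with a hypothetical file name: for data_path "train.csv" and k=6,
# the cached k-mer strings are read from / written to "train_6mer.json".
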
def load_customized_data(data_path: str, texts: List[str], customized_tokenizer: str) -> List[str]:
    """Load data that was pre-tokenized by a customized tokenizer."""
    customize_path = data_path.replace(".csv", f"_{customized_tokenizer}.json")
    if os.path.exists(customize_path):
        logging.warning(f"Loading data by customized tokenizer from {customize_path}...")
        with open(customize_path, "r") as f:
            data = json.load(f)
    else:
        raise FileNotFoundError(f"Pre-tokenized data not found at {customize_path}.")

    return data


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self,
                 data_path: str,
                 tokenizer: transformers.PreTrainedTokenizer,
                 kmer: int = -1,
                 customized_tokenizer=None):

        super(SupervisedDataset, self).__init__()

        with open(data_path, "r") as f:
            data = list(csv.reader(f))[1:]
        if len(data[0]) == 2:
            # Data rows are [sequence, label].
            logging.warning("Perform single sequence classification...")
            texts = [d[0] for d in data]
            labels = [int(d[1]) for d in data]
        elif len(data[0]) == 3:
            # Data rows are [sequence_1, sequence_2, label].
            logging.warning("Perform sequence-pair classification...")
            texts = [[d[0], d[1]] for d in data]
            labels = [int(d[2]) for d in data]
        else:
            raise ValueError("Data format not supported.")

        if kmer != -1:
            logging.warning(f"Using {kmer}-mer as input...")
            texts = load_or_generate_kmer(data_path, texts, kmer)
        elif kmer == -1 and customized_tokenizer:
            logging.warning(f"Using {customized_tokenizer} as input...")
            texts = load_customized_data(data_path, texts, customized_tokenizer)

        output = tokenizer(
            texts,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )

        self.input_ids = output["input_ids"]
        self.attention_mask = output["attention_mask"]
        self.labels = labels
        self.num_labels = len(set(labels))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.Tensor(labels).long()
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

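# Illustration (hypothetical batch): two examples with 5 and 3 tokens are padded to
# length 5 with pad_token_id; attention_mask is True wherever input_ids differ from
# pad_token_id, so the shorter example ends with two masked positions.
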
| """ |
| Manually calculate the accuracy, f1, matthews_correlation, precision, recall with sklearn. |
| """ |
| def calculate_metric_with_sklearn(predictions: np.ndarray, labels: np.ndarray): |
| valid_mask = labels != -100 |
| valid_predictions = predictions[valid_mask] |
| valid_labels = labels[valid_mask] |
| return { |
| "accuracy": sklearn.metrics.accuracy_score(valid_labels, valid_predictions), |
| "f1": sklearn.metrics.f1_score( |
| valid_labels, valid_predictions, average="macro", zero_division=0 |
| ), |
| "matthews_correlation": sklearn.metrics.matthews_corrcoef( |
| valid_labels, valid_predictions |
| ), |
| "precision": sklearn.metrics.precision_score( |
| valid_labels, valid_predictions, average="macro", zero_division=0 |
| ), |
| "recall": sklearn.metrics.recall_score( |
| valid_labels, valid_predictions, average="macro", zero_division=0 |
| ), |
| } |
|
|
| |
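# Illustration of the -100 masking: with predictions [1, 0, 1] and labels [1, -100, 0],
# the middle position is ignored and the metrics are computed on predictions [1, 1]
# versus labels [1, 0].
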
def preprocess_logits_for_metrics(logits: Union[torch.Tensor, Tuple[torch.Tensor, Any]], _):
    """Reduce logits to predicted class ids before they are gathered for metrics,
    which keeps evaluation memory usage low."""
    if isinstance(logits, tuple):
        # Some models return (logits, extra outputs); keep only the logits.
        logits = logits[0]

    if logits.ndim == 3:
        # Flatten [batch, seq_len, num_labels] to [batch * seq_len, num_labels].
        logits = logits.reshape(-1, logits.shape[-1])

    return torch.argmax(logits, dim=-1)


| """ |
| Compute metrics used for huggingface trainer. |
| """ |
| def compute_metrics(eval_pred): |
| predictions, labels = eval_pred |
| return calculate_metric_with_sklearn(predictions, labels) |
|
|
def load_token_v5_1(tokenizer_kwargs):
    # MotifTokenizer is the project's custom tokenizer class and is assumed to be
    # importable from elsewhere in this repository.
    tokenizer = MotifTokenizer(**tokenizer_kwargs)

    bases = ['A', 'T', 'C', 'G']

    token_wc = [
        f"{operator}_POS_{i}_*_{char}"
        for operator, i, char in itertools.product(['WC'], range(12), bases)
    ]

    motif_wildcarded = []
    # Site-specific path to the wildcard motif list.
    with open(os.path.join('/storage2/fs1/btc/Active/yeli/xiaoxiao.zhou/tokenize/tokenizers/tokenizer_v5.1/hg38_NOOP', "motifs_wildcard.txt"), "r") as file:
        for line in file:
            seq, operations = line.strip().split(maxsplit=1)
            motif_wildcarded.append(operations.split()[0])

    tokenizer.add_tokens(token_wc + motif_wildcarded)
    return tokenizer


def load_token_v4(tokenizer_kwargs):
    # MotifTokenizer is the project's custom tokenizer class and is assumed to be
    # importable from elsewhere in this repository.
    tokenizer = MotifTokenizer(**tokenizer_kwargs)

    bases = ['A', 'T', 'C', 'G']
    token_del = [
        f"{operator}_POS_{i}_{char}"
        for operator, i, char in itertools.product(['DEL'], range(12), bases)
    ]
    token_rep = [
        f"{operator}_POS_{i}_{char1}_{char2}"
        for operator, i, char1, char2 in itertools.product(['SUB'], range(12), bases, bases)
        if char1 != char2
    ]

    token_wc = [
        f"{operator}_POS_{i}_*_{char}"
        for operator, i, char in itertools.product(['WC'], range(12), bases)
    ]

    token_ins = [
        f"{operator}_POS_{i}_{char}"
        for operator, i, char in itertools.product(['INS'], range(13), bases)
    ]

    motif_wildcarded = []
    # Site-specific path to the wildcard motif list.
    with open(os.path.join('/storage2/fs1/btc/Active/yeli/xiaoxiao.zhou/tokenize/tokenizers/tokenizer_v4/hg38', "motifs_wildcard.txt"), "r") as file:
        for line in file:
            seq, operations = line.strip().split(maxsplit=1)
            motif_wildcarded.append(operations.split()[0])

    tokenizer.add_tokens(token_del + token_rep + token_wc + token_ins + motif_wildcarded)
    return tokenizer

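# Illustration of the added edit tokens: the comprehensions above produce strings such
# as "DEL_POS_0_A", "SUB_POS_3_A_G", "WC_POS_5_*_T", and "INS_POS_12_C".
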
def train():
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    wandb.init(
        project=training_args.project_name,
    )

    tokenizer_kwargs = {
        "cache_dir": training_args.cache_dir,
        "model_max_length": training_args.model_max_length,
        "padding_side": "right",
        "use_fast": True,
        "trust_remote_code": model_args.trust_remote_code,
    }

    if training_args.vocab_file is not None:
        if not os.path.exists(training_args.vocab_file):
            raise ValueError(f"Vocab file not found at: {training_args.vocab_file}")
        tokenizer_kwargs["vocab_file"] = training_args.vocab_file

    if data_args.customized_tokenizer == 'token_v4':
        tokenizer = load_token_v4(tokenizer_kwargs)
    elif data_args.customized_tokenizer == 'token_v5_1':
        tokenizer = load_token_v5_1(tokenizer_kwargs)
    else:
        tokenizer = transformers.PreTrainedTokenizerFast(
            tokenizer_file=model_args.tokenizer_path,
            **tokenizer_kwargs
        )

    tokenizer.pad_token = "[PAD]"
    tokenizer.unk_token = "[UNK]"
    tokenizer.cls_token = "[CLS]"
    tokenizer.sep_token = "[SEP]"
    tokenizer.mask_token = "[MASK]"

    if "InstaDeepAI" in model_args.model_name_or_path:
        tokenizer.eos_token = tokenizer.pad_token

    train_dataset = SupervisedDataset(tokenizer=tokenizer,
                                      data_path=os.path.join(data_args.data_path, "train.csv"),
                                      kmer=data_args.kmer,
                                      customized_tokenizer=data_args.customized_tokenizer)
    val_dataset = SupervisedDataset(tokenizer=tokenizer,
                                    data_path=os.path.join(data_args.data_path, "dev.csv"),
                                    kmer=data_args.kmer,
                                    customized_tokenizer=data_args.customized_tokenizer)
    test_dataset = SupervisedDataset(tokenizer=tokenizer,
                                     data_path=os.path.join(data_args.data_path, "test.csv"),
                                     kmer=data_args.kmer,
                                     customized_tokenizer=data_args.customized_tokenizer)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

    config = transformers.AutoConfig.from_pretrained(
        model_args.model_name_or_path,
        num_labels=train_dataset.num_labels,
        trust_remote_code=model_args.trust_remote_code
    )

    model = transformers.AutoModelForSequenceClassification.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        config=config,
        trust_remote_code=model_args.trust_remote_code
    ).to("cuda")

    if model_args.use_lora:
        lora_config = LoraConfig(
            r=model_args.lora_r,
            lora_alpha=model_args.lora_alpha,
            target_modules=list(model_args.lora_target_modules.split(",")),
            lora_dropout=model_args.lora_dropout,
            bias="none",
            task_type="SEQ_CLS",
            inference_mode=False,
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    trainer = transformers.Trainer(model=model,
                                   tokenizer=tokenizer,
                                   args=training_args,
                                   preprocess_logits_for_metrics=preprocess_logits_for_metrics,
                                   compute_metrics=compute_metrics,
                                   train_dataset=train_dataset,
                                   eval_dataset=val_dataset,
                                   data_collator=data_collator)
    trainer.train()

    if training_args.save_model:
        trainer.save_state()
        safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)

    if training_args.eval_and_save_results:
        results_path = os.path.join(training_args.output_dir, "results", training_args.run_name)
        results = trainer.evaluate(eval_dataset=test_dataset)
        os.makedirs(results_path, exist_ok=True)
        with open(os.path.join(results_path, "eval_results.json"), "w") as f:
            json.dump(results, f)


if __name__ == "__main__":
    train()
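
# Example invocation (a sketch; paths and the model name are placeholders, and the
# flags map to the dataclass fields defined above; data_path must contain
# train.csv, dev.csv, and test.csv):
#
#   python train.py \
#       --model_name_or_path <hf-model-or-local-path> \
#       --data_path <dir-with-train/dev/test-csv> \
#       --kmer -1 \
#       --output_dir ./output \
#       --run_name demo \
#       --project_name <wandb-project> \
#       --num_train_epochs 3 \
#       --per_device_train_batch_size 8 \
#       --learning_rate 1e-4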