Upload 2 files
Browse files
Scripts for data prep and model training
- prepare_data_script.py +91 -0
- train_script.py +102 -0
prepare_data_script.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import pandas as pd
import os

import torch
from transformers import RobertaTokenizerFast, RobertaForMaskedLM, DataCollatorWithPadding

import datasets
from datasets import disable_caching

# avoid writing an intermediate Arrow cache file for every map/filter call
disable_caching()

DEVICE = 'cuda:0'                                   # model device
ENCODER_MODEL_NAME = "entropy/roberta_zinc_480m"    # encoder name
ENCODER_BATCH_SIZE = 1024                           # batch size for computing embeddings

TOKENIZER_MAX_LEN = 256                             # max_length param on tokenizer
TOKENIZATION_NUM_PROC = 32                          # number of processes for tokenization

'''
Data source is expected to be a CSV file with a column of SMILES strings
denoted by `SMILES_COLUMN`. The CSV is processed in chunks of size `PROCESS_CHUNKSIZE`.

Processed chunks are saved to `SAVE_PATH` with the format `SAVE_PATH/processed_shard_{i}.hf`
'''

DATASET_CSV_FILENAME = None     # path to data csv
PROCESS_CHUNKSIZE = 1000000     # how many rows to process/save for each dataset shard
SMILES_COLUMN = 'smiles'        # csv column holding smiles strings
MAX_CHUNKS = None               # total number of chunks to process (if None, all chunks are processed)
MAX_SMILES_LENGTH = 90          # max smiles string length (exclusive)
MIN_SMILES_LENGTH = 5           # min smiles string length (exclusive)
FILTER_NUM_PROC = 32            # number of processes for filtering
SAVE_PATH = None                # directory to save data shards to

# fail fast with a real exception: `assert` statements are stripped when the
# interpreter runs with -O, which would let the script continue with None paths
if DATASET_CSV_FILENAME is None:
    raise ValueError("must specify dataset filename")
if SAVE_PATH is None:
    raise ValueError("must specify save path")
| 38 |
+
def tokenization(example):
    """Tokenize a batch's SMILES strings with the module-level `tokenizer`."""
    smiles_batch = example[SMILES_COLUMN]
    return tokenizer(
        smiles_batch,
        add_special_tokens=True,
        truncation=True,
        max_length=TOKENIZER_MAX_LEN,
    )
+
def embed(inputs):
    """Mean-pool the encoder's final hidden layer over non-padding tokens.

    Returns a dict with key 'encoder_hidden_states' holding one pooled
    embedding vector per input sequence.
    """
    batch = {key: inputs[key] for key in ('input_ids', 'attention_mask')}
    batch = collator(batch)  # pad to the longest sequence in the batch
    batch = {key: tensor.to(DEVICE) for key, tensor in batch.items()}

    with torch.no_grad():
        outputs = model(**batch, output_hidden_states=True)

    # outputs[-1] is the per-layer hidden-state tuple; [-1] selects the last layer
    token_embeddings = outputs[-1][-1]
    mask = batch['attention_mask']

    # masked mean over the sequence dimension: zero out padding positions,
    # then divide by the count of real tokens
    summed = (token_embeddings * mask.unsqueeze(-1)).sum(1)
    token_counts = mask.sum(-1).unsqueeze(-1)

    return {'encoder_hidden_states': summed / token_counts}
|
| 55 |
+
|
| 56 |
+
def length_filter_smiles(example):
    """Return True if the example's SMILES entry is a string whose length is
    strictly between MIN_SMILES_LENGTH and MAX_SMILES_LENGTH.

    A bound set to None disables that side of the length check.
    """
    smiles = example[SMILES_COLUMN]
    # Check the type first: calling len() on a non-string cell (e.g. a NaN
    # from a blank CSV row) would raise TypeError instead of filtering it out.
    if not isinstance(smiles, str):
        return False
    # Bugfix: the max-length check was previously guarded by
    # `MIN_SMILES_LENGTH is not None`, so setting MIN_SMILES_LENGTH=None
    # silently disabled the max filter as well.
    min_check = (len(smiles) > MIN_SMILES_LENGTH) if (MIN_SMILES_LENGTH is not None) else True
    max_check = (len(smiles) < MAX_SMILES_LENGTH) if (MAX_SMILES_LENGTH is not None) else True
    return min_check and max_check
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# Encoder tokenizer/model used to produce the conditioning embeddings.
tokenizer = RobertaTokenizerFast.from_pretrained(ENCODER_MODEL_NAME, max_len=TOKENIZER_MAX_LEN)
collator = DataCollatorWithPadding(tokenizer, padding=True, return_tensors='pt')

model = RobertaForMaskedLM.from_pretrained(ENCODER_MODEL_NAME)
model.to(DEVICE)
model.eval()  # inference only: disable dropout etc.

# Stream the CSV in chunks so arbitrarily large files never have to fit in memory.
df_iter = pd.read_csv(DATASET_CSV_FILENAME, chunksize=PROCESS_CHUNKSIZE, usecols=[SMILES_COLUMN])

for i, df in enumerate(df_iter):
    print(f'processing dataset chunk {i}')

    dataset = datasets.Dataset.from_pandas(df)

    # Pass the predicate directly; the previous `lambda example: ...` wrapper
    # added an indirection layer for no benefit.
    dataset = dataset.filter(length_filter_smiles, num_proc=FILTER_NUM_PROC)

    dataset = dataset.map(tokenization, batched=True, num_proc=TOKENIZATION_NUM_PROC)

    # Single process here: the embedding step runs on the GPU-resident model.
    dataset = dataset.map(embed, batched=True, batch_size=ENCODER_BATCH_SIZE)

    dataset.save_to_disk(f'{SAVE_PATH}/processed_shard_{i}.hf')

    if (MAX_CHUNKS is not None) and (i >= MAX_CHUNKS-1):
        break

print('finished data processing')
| 91 |
+
|
train_script.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import pandas as pd
import os

import torch
import torch.nn as nn

from transformers import GPT2TokenizerFast, GPT2LMHeadModel
from transformers import DataCollatorWithPadding, GPT2Config, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments, RobertaTokenizerFast

import datasets
from datasets import disable_caching
disable_caching()  # avoid writing Arrow cache files during preprocessing
from datasets import IterableDataset

from conditional_gpt2_model import ConditionalGPT2LMHeadModel


ENCODER_MODEL_NAME = "entropy/roberta_zinc_480m"    # encoder model name
TOKENIZER_MAX_LEN = 256                             # max_length param on tokenizer

DATA_SUBSHARDS = 10     # number of shards to break each data chunk into

DATA_DIR = None             # directory with saved data shards
TRAINER_SAVE_DIR = None     # directory to save model checkpoints

# Fail fast with a real exception: `assert` statements are stripped when the
# interpreter runs with -O, which would let the script continue with None paths.
if DATA_DIR is None:
    raise ValueError("data directory must be specified")
if TRAINER_SAVE_DIR is None:
    raise ValueError("trainer save directory must be specified")
| 31 |
+
|
| 32 |
+
def gen_dataset():
    """Yield training examples from every saved `.hf` shard in DATA_DIR.

    Shards are loaded one file at a time and split into contiguous
    sub-shards, so only a slice of each saved dataset is iterated at once.
    Each yielded example carries 'input_ids' and 'encoder_hidden_states'
    (the latter with a leading unit axis added).
    """
    data_filenames = sorted([i for i in os.listdir(DATA_DIR) if '.hf' in i])

    for filename in data_filenames:
        # Bugfix: the path previously did not interpolate the loop variable,
        # so every iteration tried to load the same invalid dataset path.
        dataset = datasets.Dataset.load_from_disk(f'{DATA_DIR}/{filename}')

        keep_cols = ['input_ids', 'encoder_hidden_states']

        dataset = dataset.remove_columns([i for i in dataset.column_names
                                          if not i in keep_cols]).with_format("torch")

        # contiguous shards for faster loading
        shards = [dataset.shard(num_shards=DATA_SUBSHARDS, index=index, contiguous=True)
                  for index in range(DATA_SUBSHARDS)]

        for i, shard in enumerate(shards):
            for example in shard:
                # need to add unit axis to hidden states
                example['encoder_hidden_states'] = example['encoder_hidden_states'][None, :]
                yield example
|
| 54 |
+
|
| 55 |
+
# Wrap the shard generator in a lazily evaluated dataset so training can
# start before all shards have been read from disk.
dataset = IterableDataset.from_generator(gen_dataset)
dataset = dataset.with_format("torch")

# The decoder reuses the encoder's tokenizer so token ids line up with the
# precomputed embeddings produced by the data-prep script.
tokenizer = RobertaTokenizerFast.from_pretrained(ENCODER_MODEL_NAME, max_len=TOKENIZER_MAX_LEN)
# mlm=False: collate for causal language modeling rather than masked LM
collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)


config = GPT2Config(
    vocab_size=len(tokenizer),
    n_positions=TOKENIZER_MAX_LEN,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    n_layer=6,
    n_head=8,
    # cross-attention layers let the decoder attend to the precomputed
    # `encoder_hidden_states` supplied with each training example
    add_cross_attention=True,
)

model = ConditionalGPT2LMHeadModel(config)

# change trainer args as needed
args = TrainingArguments(
    output_dir=TRAINER_SAVE_DIR,
    per_device_train_batch_size=192,
    logging_steps=25,
    gradient_accumulation_steps=8,
    num_train_epochs=1,
    weight_decay=0.1,
    warmup_steps=1000,
    lr_scheduler_type="cosine",
    learning_rate=1e-5,
    save_steps=200,
    save_total_limit=30,
    fp16=True,
    push_to_hub=False,
    # NOTE(review): max_steps presumably bounds training because the
    # iterable dataset has no known length — confirm against HF Trainer docs
    max_steps=50000,
)


trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=collator,
    train_dataset=dataset,
)

trainer.train()