Commit ·
15e8d2f
1
Parent(s): 6b9513a
Upload 2 files
Browse files
mapping_adapter_checkpoint_114000steps.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cb6e95ae3b9cd81f6d5bdcad0387c65aa804ec172336db5f89ba7ad7ffc1f8d2
|
| 3 |
+
size 125866547
|
representation_mapping.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import AutoModel, AutoTokenizer, AutoConfig, AdamW, get_linear_schedule_with_warmup
|
| 2 |
+
from torch.utils.data import DataLoader
|
| 3 |
+
import transformers
|
| 4 |
+
from sklearn.model_selection import train_test_split
|
| 5 |
+
from datasets import load_dataset, DatasetDict
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch
|
| 8 |
+
import wandb
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
|
| 11 |
+
# --- Training hyperparameters (referenced by main()) -------------------------
args_max_epoch = 1  # number of passes over the training split
args_batch_size = 64  # prompts per batch
args_learning_rate = 3e-5  # peak AdamW learning rate (linear warmup + decay)
args_num_warmup_steps = 100  # scheduler warmup steps
args_gradient_accumulation_steps_default = 2  # optimizer step every N batches
adapter_hidden_dim = 4096  # hidden width of the MappingAdapter MLP

# Hard-coded CUDA device; assumes a GPU is present — TODO confirm no CPU fallback is needed.
device = 'cuda'
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def main():
    """Train the mapping adapter that projects decoder hidden states into the
    encoder's sentence-embedding space.

    Both backbone models are frozen; only parameters whose name contains
    'mapping' receive gradients. Checkpoints (adapter + optimizer + scheduler
    state) are written every 2000 batches; validation runs every 8000 steps.
    """
    # NOTE(review): "Adapater" typo kept verbatim so runs land in the existing
    # wandb project.
    wandb.init(project="MappingAdapater_training_v6", name="training_run")

    model = MappingStructure(checkpointE = "sentence-transformers/stsb-roberta-large",
                             checkpointD = "mistralai/Mistral-7B-Instruct-v0.1",
                             hidden_dim = adapter_hidden_dim,
                             torch_dtype = torch.float16,
                             flash_attn = True,
                             ).to(device)

    # Freeze everything except the mapping adapter.
    for n, p in model.named_parameters():
        p.requires_grad = 'mapping' in n

    dataset = load_dataset("sade-adrien/redpajama_v2_sample_10M")['train']
    train_dataset, val_dataset = split_dataset(dataset, train_size=.989333)
    datasets = DatasetDict({
        'train': train_dataset,
        'val': val_dataset
    })

    train_dataloader = DataLoader(datasets['train'], batch_size=args_batch_size, shuffle=True)
    val_dataloader = DataLoader(datasets['val'], batch_size=args_batch_size, shuffle=False)

    optimizer = AdamW(model.parameters(), lr=args_learning_rate)
    # BUGFIX: scheduler.step() runs once per *optimizer* step (every
    # `args_gradient_accumulation_steps_default` batches), so the schedule
    # horizon must be the optimizer-step count, not the raw batch count —
    # otherwise the LR never finishes decaying.
    num_optimizer_steps = args_max_epoch * len(train_dataloader) // args_gradient_accumulation_steps_default
    scheduler = get_linear_schedule_with_warmup(optimizer, args_num_warmup_steps, num_optimizer_steps)

    global_step = 0
    for epoch in range(args_max_epoch):
        # Re-seed dataloader workers per epoch for a reproducible shuffle.
        train_dataloader = DataLoader(datasets['train'], batch_size=args_batch_size, shuffle=True,
                                      worker_init_fn=lambda _: torch.manual_seed(epoch))

        for batch in tqdm(train_dataloader):
            input_prompt = batch['raw_content']
            outputs = model(input_prompt=input_prompt, compute_loss=True)
            loss = outputs['loss']

            # Gradient accumulation: scale so accumulated gradients average out.
            loss = loss / args_gradient_accumulation_steps_default
            loss.backward()

            if (global_step + 1) % args_gradient_accumulation_steps_default == 0:
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()

            # Periodic checkpoint: only the adapter weights are trainable, so
            # only its state_dict (plus optimizer/scheduler state) is saved.
            if (global_step + 1) % 2000 == 0:
                torch.save({
                    'epoch': epoch,
                    'mapping_state_dict': model.mapping.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'global_step': global_step,
                }, f'models/mapping_adapter_checkpoint_{global_step + 1}steps.pth')

            global_step += 1
            val_loss = None
            # NOTE(review): this check uses the *post*-increment step while the
            # checkpoint check above uses the pre-increment step; kept as-is to
            # preserve the original trigger cadence.
            if (global_step + 1) % 8000 == 0:
                model.eval()
                val_loss = 0.0
                with torch.no_grad():
                    for val_batch in tqdm(val_dataloader):
                        val_inputs = val_batch['raw_content']
                        val_outputs = model(input_prompt=val_inputs, compute_loss=True)
                        val_loss += val_outputs['loss']
                val_loss /= len(val_dataloader)

                model.train()

            wandb.log({
                'step': global_step + 1,
                'learning_rate': scheduler.get_last_lr()[0],
                'train_loss': loss.item() * args_gradient_accumulation_steps_default,
                # BUGFIX: `if val_loss` treated a 0.0 validation loss as missing;
                # test explicitly for None instead of relying on truthiness.
                'val_loss': val_loss.item() if val_loss is not None else None
            })
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def split_dataset(dataset, train_size=.9):
    """Split *dataset* into a (train, validation) pair.

    The first ``train_size`` fraction of rows becomes the training split and
    the remainder the validation split; rows are not shuffled.
    """
    cutoff = int(len(dataset) * train_size)
    train_part = dataset.select(range(cutoff))
    val_part = dataset.select(range(cutoff, len(dataset)))
    return train_part, val_part
|
| 105 |
+
|
| 106 |
+
class MappingAdapter(nn.Module):
    """Two-layer MLP mapping decoder hidden states to the encoder's
    embedding dimension: input_dim -> hidden_dim -> output_dim, with a
    LeakyReLU(0.01) between the layers.
    """

    def __init__(self, input_dim, output_dim, hidden_dim):
        super().__init__()
        # Attribute names are load-bearing: checkpoints are saved/restored via
        # state_dict, so `layer1` / `layer2` / `activation` must keep their names.
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, output_dim)
        self.activation = nn.LeakyReLU(.01)

    def forward(self, x):
        hidden = self.activation(self.layer1(x))
        return self.layer2(hidden)
|
| 118 |
+
|
| 119 |
+
class MappingStructure(nn.Module):
    """Frozen encoder/decoder pair bridged by a trainable MappingAdapter.

    Decoder hidden states are projected by the adapter into the encoder's
    embedding space; the training loss is 1 minus the mean cosine similarity
    between mean-pooled sentence embeddings from each side.
    """

    def __init__(self, checkpointE, checkpointD, hidden_dim=2048, torch_dtype=torch.float32, flash_attn=False):
        super(MappingStructure, self).__init__()

        # Sentence encoder — defines the target embedding space.
        self.configE = AutoConfig.from_pretrained(checkpointE)
        self.Encoder = AutoModel.from_pretrained(checkpointE,
                                                 low_cpu_mem_usage = True,
                                                 torch_dtype = torch_dtype,
                                                 config = self.configE
                                                 )

        # Decoder LM whose hidden states will be mapped into the encoder space.
        self.configD = AutoConfig.from_pretrained(checkpointD)
        if flash_attn:
            # NOTE(review): `_flash_attn_2_enabled` is a private config flag;
            # newer transformers versions use attn_implementation="flash_attention_2"
            # instead — confirm against the pinned transformers version.
            self.configD.update({'_flash_attn_2_enabled' : True})
        self.Decoder = AutoModel.from_pretrained(checkpointD,
                                                 low_cpu_mem_usage = True,
                                                 torch_dtype = torch_dtype,
                                                 config = self.configD
                                                 )

        # Adapter: decoder hidden size -> encoder hidden size, cast to the
        # backbones' dtype (e.g. float16) so the pipeline stays in one dtype.
        self.mapping = MappingAdapter(self.configD.hidden_size, self.configE.hidden_size, hidden_dim=hidden_dim).to(torch_dtype)

        self._init_tokenizers(checkpointE, checkpointD)

    def _init_tokenizers(self, checkpointE, checkpointD):
        # Both tokenizers pad on the left so the most recent tokens align.
        self.tokenizerE = AutoTokenizer.from_pretrained(checkpointE, use_fast = False, revision = 'main', config = self.configE, padding_side='left')
        self.tokenizerD = AutoTokenizer.from_pretrained(checkpointD, use_fast = False, revision = 'main', config = self.configD, padding_side='left')
        # The decoder tokenizer has no pad token; reuse <unk> for padding.
        self.tokenizerD.pad_token_id = self.tokenizerD.unk_token_id

    def cosine_sim(self, u, v):
        """Row-wise cosine similarity between two equally shaped 2-D tensors."""
        assert u.shape == v.shape, "u and v must have the same shape"
        u_normalized = u / torch.norm(u, dim=1, keepdim=True)
        v_normalized = v / torch.norm(v, dim=1, keepdim=True)

        # Compute cosine similarity using dot product
        return torch.sum(u_normalized * v_normalized, dim=1)


    def mean_pooling(self, hidden_state, attention_mask):
        """Average token embeddings over the sequence, ignoring padding."""
        token_embeddings = hidden_state
        # Broadcast the mask over the hidden dimension; clamp avoids a
        # divide-by-zero for all-padding rows.
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


    def build_batch(self, input_prompt):
        """Tokenize prompts with the encoder tokenizer, truncated to one
        randomly sampled length shared by the whole batch, then pad."""
        # Sample a length strictly below the encoder's positional limit
        # (hard constraint of the encoder backbone).
        size = torch.randint(1, self.configE.max_position_embeddings-2, (1,)).item()
        targets = []

        for prompt in input_prompt:
            tokenized_input = self.tokenizerE(prompt)
            # Truncate each sample to the sampled length.
            tokenized_input = {'input_ids': tokenized_input['input_ids'][:size],
                               'attention_mask': tokenized_input['attention_mask'][:size],

                               }
            targets.append(tokenized_input)
        targets = self.tokenizerE.pad(targets, padding=True, return_tensors='pt')

        return targets


    def forward(self, input_prompt, compute_loss=False):
        """Run the decoder (+ adapter); with compute_loss=True also run the
        encoder on the same (truncated) text and return the alignment loss.

        Returns a dict with keys 'loss' (None unless compute_loss),
        'last_hidden_state', and 'last_hidden_state_mapped'.
        """
        loss = None

        # Slice prompt if needed to fit encoder max position embeddings (hard constraint)
        if not compute_loss:
            # Inference path: decoder + adapter only.
            inputs = self.tokenizerD(input_prompt, return_tensors='pt', padding=True).to(device)

            hidden_state_D = self.Decoder(**inputs).last_hidden_state
            hidden_state_D_mapped = self.mapping(hidden_state_D)

        else:
            # Training path: align adapter output with the encoder embedding.
            targets = self.build_batch(input_prompt).to(device)

            # Round-trip the truncated encoder tokens back to text so both
            # models see exactly the same content.
            input_prompt_sliced = self.tokenizerE.batch_decode(targets['input_ids'], skip_special_tokens=True)
            inputs = self.tokenizerD(input_prompt_sliced, return_tensors='pt', padding=True).to(device)

            hidden_state_D = self.Decoder(**inputs).last_hidden_state
            hidden_state_D_mapped = self.mapping(hidden_state_D)

            hidden_state_E = self.Encoder(**targets).last_hidden_state

            # Mean-pool each side with its own attention mask.
            proj_E = self.mean_pooling(hidden_state_E, targets['attention_mask'])
            proj_D = self.mean_pooling(hidden_state_D_mapped, inputs['attention_mask'])

            # Loss = 1 - mean cosine similarity between pooled embeddings.
            loss = 1 - torch.mean(self.cosine_sim(proj_E, proj_D))

            # Eagerly release large intermediates to curb GPU memory pressure.
            del inputs
            del targets
            del input_prompt_sliced
            del hidden_state_E
            del proj_E
            del proj_D
            torch.cuda.empty_cache()

        return {'loss': loss,
                'last_hidden_state': hidden_state_D,
                'last_hidden_state_mapped': hidden_state_D_mapped,
                }
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
# Script entry point: run training when executed directly.
if __name__ == '__main__':
    main()
|