import torch
import torch.nn.functional as F
from torch import nn
class ProbeClassification(nn.Module):
def __init__(self, device, probe_class, input_dim=512, hidden_neurons=128): # from 0 to 15
super().__init__()
self.input_dim = input_dim
self.probe_class = probe_class
self.proj = nn.Sequential(
nn.Linear(self.input_dim, hidden_neurons),
nn.ReLU(True),
nn.Linear(hidden_neurons, self.probe_class),
)
self.apply(self._init_weights)
# logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
self.to(device)
def forward(self, act, y=None):
# [B, f], [B]
logits = self.proj(act)#.reshape(-1, self.probe_number, self.probe_class) # [B, C]
if y is None:
return logits, None
else:
targets = y.to(torch.long)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-100)
return logits, loss
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def configure_optimizers(self, train_config):
"""
This long function is unfortunately doing something very simple and is being very defensive:
We are separating out all parameters of the model into two buckets: those that will experience
weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
We are then returning the PyTorch optimizer object.
"""
# separate out all parameters to those that will and won't experience regularizing weight decay
decay = set()
no_decay = set()
whitelist_weight_modules = (torch.nn.Linear, )
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
for mn, m in self.named_modules():
for pn, p in m.named_parameters():
fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
if pn.endswith('bias'):
                    # all biases (for any module type) will be weight decayed
decay.add(fpn)
elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
# weights of whitelist modules will be weight decayed
decay.add(fpn)
elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
# weights of blacklist modules will NOT be weight decayed
no_decay.add(fpn)
# special case the position embedding parameter in the root GPT module as not decayed
# no_decay.add('pos_emb')
# validate that we considered every parameter
param_dict = {pn: p for pn, p in self.named_parameters()}
inter_params = decay & no_decay
union_params = decay | no_decay
assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
% (str(param_dict.keys() - union_params), )
print("Decayed:", decay)
# create the pytorch optimizer object
optim_groups = [
{"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
{"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
]
optimizer = torch.optim.Adam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.75, patience=0)
return optimizer, scheduler
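# Illustrative sketch (not part of the original training pipeline): one training step
# for a ProbeClassification head fit on cached hidden activations. The tensors below
# are random stand-ins for real activations; shapes, batch size, and hyperparameters
# are assumptions for illustration only.
def _example_probe_training_step():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    probe = ProbeClassification(device, probe_class=2, input_dim=512, hidden_neurons=128)
    config = TrainerConfig(learning_rate=1e-3, betas=(0.9, 0.95), weight_decay=0.1)
    optimizer, scheduler = probe.configure_optimizers(config)
    acts = torch.randn(32, 512, device=device)          # [B, f] cached activations
    labels = torch.randint(0, 2, (32,), device=device)  # [B] integer class labels
    logits, loss = probe(acts, labels)                   # forward returns (logits, loss)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step(loss.item())  # ReduceLROnPlateau steps on a monitored metric
    return logits, loss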
class LinearProbeClassification(nn.Module):
def __init__(self, device, probe_class, input_dim=512, logistic=False, Relu=False, TanH=False): # from 0 to 15
super().__init__()
self.input_dim = input_dim
self.probe_class = probe_class
if logistic:
self.proj = nn.Sequential(
nn.Linear(self.input_dim, self.probe_class),
nn.Sigmoid()
)
elif Relu:
self.proj = nn.Sequential(
nn.Linear(self.input_dim, self.probe_class),
nn.ReLU(True)
)
elif TanH:
self.proj = nn.Sequential(
nn.Linear(self.input_dim, self.probe_class),
# nn.Hardtanh(inplace=True, min_val=0.001, max_val=0.999)
nn.Hardsigmoid(inplace=True)
)
else:
self.proj = nn.Sequential(
nn.Linear(self.input_dim, self.probe_class),
)
self.apply(self._init_weights)
# logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
self.to(device)
def forward(self, act, y=None):
# [B, f], [B]
logits = self.proj(act)#.reshape(-1, self.probe_number, self.probe_class) # [B, C]
if y is None:
return logits, None
else:
targets = y.to(torch.long)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-100)
return logits, loss
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def configure_optimizers(self, train_config):
"""
This long function is unfortunately doing something very simple and is being very defensive:
We are separating out all parameters of the model into two buckets: those that will experience
weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
We are then returning the PyTorch optimizer object.
"""
# separate out all parameters to those that will and won't experience regularizing weight decay
decay = set()
no_decay = set()
whitelist_weight_modules = (torch.nn.Linear, )
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
for mn, m in self.named_modules():
for pn, p in m.named_parameters():
fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
if pn.endswith('bias'):
                    # all biases (for any module type) will be weight decayed
decay.add(fpn)
elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
# weights of whitelist modules will be weight decayed
decay.add(fpn)
elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
# weights of blacklist modules will NOT be weight decayed
no_decay.add(fpn)
# special case the position embedding parameter in the root GPT module as not decayed
# no_decay.add('pos_emb')
# validate that we considered every parameter
param_dict = {pn: p for pn, p in self.named_parameters()}
inter_params = decay & no_decay
union_params = decay | no_decay
assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
% (str(param_dict.keys() - union_params), )
print("Decayed:", decay)
# create the pytorch optimizer object
optim_groups = [
{"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
{"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
]
optimizer = torch.optim.Adam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.75, patience=0)
return optimizer, scheduler
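# Illustrative note (assumption, not from this repo's training scripts):
# LinearProbeClassification is a single linear layer over one activation vector;
# the logistic / Relu / TanH flags only change the output nonlinearity, and
# cross-entropy is still applied to whatever proj() returns when labels are given.
def _example_linear_probe_variants():
    device = "cpu"
    acts = torch.randn(4, 512)                                               # [B, f]
    plain = LinearProbeClassification(device, probe_class=3)                 # raw logits
    logistic = LinearProbeClassification(device, probe_class=3, logistic=True)  # sigmoid outputs
    return plain(acts)[0].shape, logistic(acts)[0].shape                     # both [4, 3]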
class TwoLayerLinearProbeClassification(nn.Module):
def __init__(self, device, probe_class, input_dim=512, logistic=False): # from 0 to 15
super().__init__()
self.input_dim = input_dim
self.probe_class = probe_class
if not logistic:
self.proj = nn.Sequential(
nn.Linear(self.input_dim, self.input_dim),
nn.Linear(self.input_dim, self.probe_class),
)
else:
self.proj = nn.Sequential(
nn.Linear(self.input_dim, self.input_dim),
nn.Linear(self.input_dim, self.probe_class),
nn.Sigmoid()
)
self.apply(self._init_weights)
# logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
self.to(device)
def forward(self, act, y=None):
# [B, f], [B]
logits = self.proj(act)#.reshape(-1, self.probe_number, self.probe_class) # [B, C]
if y is None:
return logits, None
else:
targets = y.to(torch.long)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-100)
return logits, loss
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def configure_optimizers(self, train_config):
"""
This long function is unfortunately doing something very simple and is being very defensive:
We are separating out all parameters of the model into two buckets: those that will experience
weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
We are then returning the PyTorch optimizer object.
"""
# separate out all parameters to those that will and won't experience regularizing weight decay
decay = set()
no_decay = set()
whitelist_weight_modules = (torch.nn.Linear, )
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
for mn, m in self.named_modules():
for pn, p in m.named_parameters():
fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
if pn.endswith('bias'):
                    # all biases (for any module type) will be weight decayed
decay.add(fpn)
elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
# weights of whitelist modules will be weight decayed
decay.add(fpn)
elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
# weights of blacklist modules will NOT be weight decayed
no_decay.add(fpn)
# special case the position embedding parameter in the root GPT module as not decayed
# no_decay.add('pos_emb')
# validate that we considered every parameter
param_dict = {pn: p for pn, p in self.named_parameters()}
inter_params = decay & no_decay
union_params = decay | no_decay
assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
% (str(param_dict.keys() - union_params), )
print("Decayed:", decay)
# create the pytorch optimizer object
optim_groups = [
{"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
{"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
]
optimizer = torch.optim.Adam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.75, patience=0)
return optimizer, scheduler
class ProbeClassificationMixScaler(nn.Module):
def __init__(self, device, probe_class, input_dim=512, num_layers=41, soft_weight_lr_rate=1e-1,
hidden_neurons=128): # from 0 to 15
super().__init__()
self.input_dim = input_dim
self.probe_class = probe_class
self.num_layers = num_layers
# self.mix_weights = torch.nn.Parameter(1 / num_layers * torch.ones(num_layers), requires_grad=True)
self.mix_weights = nn.Linear(num_layers, 1, bias=False)
torch.nn.init.constant_(self.mix_weights.weight, 1 / num_layers)
        self.soft_weight_lr_rate = soft_weight_lr_rate
self.proj = nn.Sequential(
nn.Linear(self.input_dim, hidden_neurons),
nn.ReLU(True),
nn.Linear(hidden_neurons, self.probe_class),
)
self.apply(self._init_weights)
# logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
self.to(device)
def forward(self, act, y=None):
        # act: [B, num_layers, f], y: [B]
softmaxed_weights = torch.nn.functional.softmax(self.mix_weights.weight, dim=1)
act = act.permute([0, 2, 1])
act = (act @ softmaxed_weights.T)[..., 0]
logits = self.proj(act)#.reshape(-1, self.probe_number, self.probe_class) # [B, C]
if y is None:
return logits, None
else:
targets = y.to(torch.long)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-100)
return logits, loss
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def configure_optimizers(self, train_config):
"""
This long function is unfortunately doing something very simple and is being very defensive:
We are separating out all parameters of the model into two buckets: those that will experience
weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
We are then returning the PyTorch optimizer object.
"""
# separate out all parameters to those that will and won't experience regularizing weight decay
decay = set()
no_decay = set()
whitelist_weight_modules = (torch.nn.Linear, )
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
for mn, m in self.named_modules():
for pn, p in m.named_parameters():
fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
if pn.endswith('bias'):
                    # all biases (for any module type) will be weight decayed
decay.add(fpn)
                elif pn.endswith('weight') and ("mix" not in fpn) and isinstance(m, whitelist_weight_modules):
# weights of whitelist modules will be weight decayed
decay.add(fpn)
elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
# weights of blacklist modules will NOT be weight decayed
no_decay.add(fpn)
# special case the position embedding parameter in the root GPT module as not decayed
# no_decay.add('pos_emb')
# validate that we considered every parameter
param_dict = {pn: p for pn, p in self.named_parameters()}
inter_params = decay & no_decay
union_params = decay | no_decay
assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
# assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
# % (str(param_dict.keys() - union_params), )
print("Decayed:", decay)
# create the pytorch optimizer object
optim_groups = [
{"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
{"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
{'params': self.mix_weights.weight, "lr": self.soft_weight_lr_rate, "weight_decay": train_config.weight_decay},
]
optimizer = torch.optim.Adam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.75, patience=0)
return optimizer, scheduler
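# Illustrative sketch of the layer-mixing step used by the *MixScaler probes:
# the input carries one activation vector per transformer layer, and a learned
# softmax over layers collapses them into a single vector before classification.
# Shapes and values below are assumptions for illustration only.
def _example_mix_scaler_forward():
    device = "cpu"
    probe = ProbeClassificationMixScaler(device, probe_class=2, input_dim=512, num_layers=41)
    acts = torch.randn(8, 41, 512)                     # [B, num_layers, f]
    logits, _ = probe(acts)                            # no labels -> loss is None
    # Equivalent mixing done manually:
    w = F.softmax(probe.mix_weights.weight, dim=1)     # [1, num_layers], sums to 1
    mixed = (acts.permute(0, 2, 1) @ w.T)[..., 0]      # [B, f] convex combination over layers
    return logits, mixed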
class LinearProbeClassificationMixScaler(nn.Module):
def __init__(self, device, probe_class, input_dim=512, num_layers=41, soft_weight_lr_rate=1e-1,
logistic=False): # from 0 to 15
super().__init__()
self.input_dim = input_dim
self.probe_class = probe_class
self.num_layers = num_layers
# self.mix_weights = torch.nn.Parameter(1 / num_layers * torch.ones(num_layers), requires_grad=True)
self.mix_weights = nn.Linear(num_layers, 1, bias=False)
torch.nn.init.constant_(self.mix_weights.weight, 1 / num_layers)
        self.soft_weight_lr_rate = soft_weight_lr_rate
if not logistic:
self.proj = nn.Sequential(
nn.Linear(self.input_dim, self.probe_class),
)
else:
self.proj = nn.Sequential(
nn.Linear(self.input_dim, self.probe_class),
nn.Sigmoid()
)
self.apply(self._init_weights)
# logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
self.to(device)
def forward(self, act, y=None):
        # act: [B, num_layers, f], y: [B]
softmaxed_weights = torch.nn.functional.softmax(self.mix_weights.weight, dim=1)
act = act.permute([0, 2, 1])
act = (act @ softmaxed_weights.T)[..., 0]
logits = self.proj(act)#.reshape(-1, self.probe_number, self.probe_class) # [B, C]
if y is None:
return logits, None
else:
targets = y.to(torch.long)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-100)
return logits, loss
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def configure_optimizers(self, train_config):
"""
This long function is unfortunately doing something very simple and is being very defensive:
We are separating out all parameters of the model into two buckets: those that will experience
weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
We are then returning the PyTorch optimizer object.
"""
# separate out all parameters to those that will and won't experience regularizing weight decay
decay = set()
no_decay = set()
whitelist_weight_modules = (torch.nn.Linear, )
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
for mn, m in self.named_modules():
for pn, p in m.named_parameters():
fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
if pn.endswith('bias'):
                    # all biases (for any module type) will be weight decayed
decay.add(fpn)
                elif pn.endswith('weight') and ("mix" not in fpn) and isinstance(m, whitelist_weight_modules):
# weights of whitelist modules will be weight decayed
decay.add(fpn)
elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
# weights of blacklist modules will NOT be weight decayed
no_decay.add(fpn)
# special case the position embedding parameter in the root GPT module as not decayed
# no_decay.add('pos_emb')
# validate that we considered every parameter
param_dict = {pn: p for pn, p in self.named_parameters()}
inter_params = decay & no_decay
union_params = decay | no_decay
assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
# assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
# % (str(param_dict.keys() - union_params), )
print("Decayed:", decay)
# create the pytorch optimizer object
optim_groups = [
{"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
{"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
{'params': self.mix_weights.weight, "lr": self.soft_weight_lr_rate, "weight_decay": train_config.weight_decay},
]
optimizer = torch.optim.Adam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.75, patience=0)
return optimizer, scheduler
class TwoLayerLinearProbeClassificationMixScaler(nn.Module):
def __init__(self, device, probe_class, input_dim=512, num_layers=41, soft_weight_lr_rate=1e-1,
logistic=False): # from 0 to 15
super().__init__()
self.input_dim = input_dim
self.probe_class = probe_class
self.num_layers = num_layers
# self.mix_weights = torch.nn.Parameter(1 / num_layers * torch.ones(num_layers), requires_grad=True)
self.mix_weights = nn.Linear(num_layers, 1, bias=False)
torch.nn.init.constant_(self.mix_weights.weight, 1 / num_layers)
        self.soft_weight_lr_rate = soft_weight_lr_rate
        # one linear map per layer, applied to that layer's activation vector before mixing
        self.rotates = nn.ModuleList([nn.Linear(self.input_dim, self.input_dim) for _ in range(num_layers)])
if not logistic:
self.proj = nn.Sequential(
nn.Linear(self.input_dim, self.probe_class),
)
else:
self.proj = nn.Sequential(
nn.Linear(self.input_dim, self.probe_class),
nn.Sigmoid()
)
self.apply(self._init_weights)
# logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
self.to(device)
def forward(self, act, y=None):
        # act: [B, num_layers, f], y: [B]
outputs = []
        for i in range(self.num_layers):
output_i = self.rotates[i](act[:, i, :]) # shape: (batch_size, 5120)
outputs.append(output_i)
# Stack the outputs back together
act = torch.stack(outputs, dim=1)
softmaxed_weights = torch.nn.functional.softmax(self.mix_weights.weight, dim=1)
act = act.permute([0, 2, 1])
act = (act @ softmaxed_weights.T)[..., 0]
logits = self.proj(act)#.reshape(-1, self.probe_number, self.probe_class) # [B, C]
if y is None:
return logits, None
else:
targets = y.to(torch.long)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-100)
return logits, loss
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def configure_optimizers(self, train_config):
"""
This long function is unfortunately doing something very simple and is being very defensive:
We are separating out all parameters of the model into two buckets: those that will experience
weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
We are then returning the PyTorch optimizer object.
"""
# separate out all parameters to those that will and won't experience regularizing weight decay
decay = set()
no_decay = set()
whitelist_weight_modules = (torch.nn.Linear, )
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
for mn, m in self.named_modules():
for pn, p in m.named_parameters():
fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
if pn.endswith('bias'):
                    # all biases (for any module type) will be weight decayed
decay.add(fpn)
                elif pn.endswith('weight') and ("mix" not in fpn) and isinstance(m, whitelist_weight_modules):
# weights of whitelist modules will be weight decayed
decay.add(fpn)
elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
# weights of blacklist modules will NOT be weight decayed
no_decay.add(fpn)
# special case the position embedding parameter in the root GPT module as not decayed
# no_decay.add('pos_emb')
# validate that we considered every parameter
param_dict = {pn: p for pn, p in self.named_parameters()}
inter_params = decay & no_decay
union_params = decay | no_decay
assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
# assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
# % (str(param_dict.keys() - union_params), )
print("Decayed:", decay)
# create the pytorch optimizer object
optim_groups = [
{"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
{"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
{'params': self.mix_weights.weight, "lr": self.soft_weight_lr_rate, "weight_decay": train_config.weight_decay},
]
optimizer = torch.optim.Adam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.75, patience=0)
return optimizer, scheduler
class TrainerConfig:
# optimization parameters
learning_rate = 1e-3
betas = (0.9, 0.95)
weight_decay = 0.1 # only applied on matmul weights
def __init__(self, **kwargs):
for k,v in kwargs.items():
setattr(self, k, v)
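# Illustrative usage of TrainerConfig (assumed values, not part of the original interface):
# keyword arguments override the class-level defaults above, and configure_optimizers()
# reads learning_rate, betas, and weight_decay from the resulting object.
if __name__ == "__main__":
    cfg = TrainerConfig(learning_rate=5e-4, weight_decay=0.01)
    demo_probe = LinearProbeClassification("cpu", probe_class=2, input_dim=512)
    optimizer, scheduler = demo_probe.configure_optimizers(cfg)
    print(optimizer)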