import math
import random
import time

import torch
import torch.nn as nn

import dataset_helper


class AnimationTransformer(nn.Module):
    def __init__(
            self,
            dim_model,  # hidden size; corresponds to the embedding length
            num_heads,
            num_encoder_layers,
            num_decoder_layers,
            dropout_p,
            use_positional_encoder=True
    ):
        super().__init__()
        self.model_type = "Transformer"
        self.dim_model = dim_model

        # TODO: Currently left out, as the input sequence is shuffled. Later check if its use is beneficial.
        self.use_positional_encoder = use_positional_encoder
        self.positional_encoder = PositionalEncoding(
            dim_model=dim_model,
            dropout_p=dropout_p
        )

        self.transformer = nn.Transformer(
            d_model=dim_model,
            nhead=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dropout=dropout_p,
            batch_first=True
        )
    def forward(self, src, tgt, tgt_mask=None, src_key_padding_mask=None, tgt_key_padding_mask=None):
        # With batch_first=True, src and tgt must have shape (batch_size, sequence length, dim_model).
        if self.use_positional_encoder:
            src = self.positional_encoder(src)
            tgt = self.positional_encoder(tgt)

        # Transformer blocks - output shape is (batch_size, tgt sequence length, dim_model).
        out = self.transformer(src, tgt, tgt_mask=tgt_mask, src_key_padding_mask=src_key_padding_mask,
                               tgt_key_padding_mask=tgt_key_padding_mask)
        return out


def get_tgt_mask(size) -> torch.Tensor:
    # Generates a square causal mask where each row allows one token more to be seen.
    mask = torch.tril(torch.ones(size, size) == 1)  # Lower triangular matrix
    mask = mask.float()
    mask = mask.masked_fill(mask == 0, float('-inf'))  # Convert zeros to -inf
    mask = mask.masked_fill(mask == 1, float(0.0))  # Convert ones to 0

    # Example for size=5:
    # [[0., -inf, -inf, -inf, -inf],
    #  [0., 0., -inf, -inf, -inf],
    #  [0., 0., 0., -inf, -inf],
    #  [0., 0., 0., 0., -inf],
    #  [0., 0., 0., 0., 0.]]
    return mask
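

# Note: this hand-rolled mask should produce the same result as PyTorch's
# built-in helper (up to default dtype/device), so an equivalent one-liner is:
#
#   tgt_mask = nn.Transformer.generate_square_subsequent_mask(size)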


def create_pad_mask(matrix: torch.Tensor) -> torch.Tensor:
    pad_masks = []

    # Iterate over each sequence in the batch.
    for i in range(matrix.size(0)):
        sequence = []
        # Iterate over each token in the sequence and append True if it is a padding value.
        for j in range(matrix.size(1)):
            sequence.append(matrix[i, j, 0] == dataset_helper.PADDING_VALUE)
        pad_masks.append(sequence)

    return torch.tensor(pad_masks)
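

# A minimal vectorized sketch of the same mask, assuming (as above) that a
# token is padding exactly when its first feature equals
# dataset_helper.PADDING_VALUE; it avoids the Python-level double loop.
def create_pad_mask_vectorized(matrix: torch.Tensor) -> torch.Tensor:
    # Returns a (batch_size, seq_len) bool tensor, True where the token is padding.
    return matrix[:, :, 0] == dataset_helper.PADDING_VALUE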


def _transformer_call_in_loops(model, batch, device, loss_function):
    source, target = batch[0], batch[1]
    source, target = source.to(device), target.to(device)

    # First index selects all batch entries, second the sequence positions.
    target_input = target[:, :-1]  # tgt input is offset by one (keeps SOS, drops the last token)
    target_expected = target[:, 1:]  # expected output is offset by one (drops SOS)
    # SOS - 1 - 2 - 3 - 4 - EOS - PAD - PAD    // target_input
    # 1 - 2 - 3 - 4 - EOS - PAD - PAD - PAD    // target_expected

    # Get mask to prevent attention to future tokens.
    tgt_mask = get_tgt_mask(target_input.size(1)).to(device)

    # Standard training, except that we pass in target_input and tgt_mask.
    prediction = model(source, target_input,
                       tgt_mask=tgt_mask,
                       src_key_padding_mask=create_pad_mask(source).to(device),
                       # Mask padding based on target_expected: the trailing EOS in target_input
                       # carries no further prediction (see diagram above).
                       tgt_key_padding_mask=create_pad_mask(target_expected).to(device))

    return loss_function(prediction, target_expected, create_pad_mask(target_expected).to(device))
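

# The three-argument call above implies a loss that ignores padded positions.
# A minimal sketch of such a loss, assuming a plain MSE over the embedding
# dimensions is appropriate (the actual loss_function is supplied by the
# caller and may differ):
def masked_mse_loss(prediction: torch.Tensor, target: torch.Tensor, pad_mask: torch.Tensor) -> torch.Tensor:
    keep = ~pad_mask  # pad_mask is True at padded positions; keep only real tokens
    per_token = ((prediction - target) ** 2).mean(dim=-1)  # (batch_size, seq_len)
    return per_token[keep].mean()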


def train_loop(model, opt, loss_function, dataloader, device):
    model.train()
    total_loss = 0

    t0 = time.time()
    for i, batch in enumerate(dataloader, start=1):
        loss = _transformer_call_in_loops(model, batch, device, loss_function)

        opt.zero_grad()
        loss.backward()
        opt.step()

        total_loss += loss.detach().item()

        if i == 1 or i % 10 == 0:
            elapsed_time = time.time() - t0
            total_expected = elapsed_time / i * len(dataloader)
            print(f">> {i}: Time per Batch {elapsed_time / i : .2f}s | "
                  f"Total expected {total_expected / 60 : .2f} min | "
                  f"Remaining {(total_expected - elapsed_time) / 60 : .2f} min ")

    print(f">> Epoch time: {(time.time() - t0) / 60:.2f} min")
    return total_loss / len(dataloader)


def validation_loop(model, loss_function, dataloader, device):
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for batch in dataloader:
            loss = _transformer_call_in_loops(model, batch, device, loss_function)
            total_loss += loss.detach().item()

    return total_loss / len(dataloader)


def fit(model, optimizer, loss_function, train_dataloader, val_dataloader, epochs, device):
    train_loss_list, validation_loss_list = [], []

    print("Training and validating model")
    for epoch in range(epochs):
        print("-" * 25, f"Epoch {epoch + 1}", "-" * 25)

        train_loss = train_loop(model, optimizer, loss_function, train_dataloader, device)
        train_loss_list.append(train_loss)

        validation_loss = validation_loop(model, loss_function, val_dataloader, device)
        validation_loss_list.append(validation_loss)

        print(f"Training loss: {train_loss:.4f}")
        print(f"Validation loss: {validation_loss:.4f}")
        print()

    return train_loss_list, validation_loss_list
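

# Hedged usage sketch: `train_dl`/`val_dl` are placeholders for dataloaders
# yielding (source, target) batches of embeddings padded with
# dataset_helper.PADDING_VALUE, and the hyperparameters are arbitrary examples:
#
#   model = AnimationTransformer(dim_model=256, num_heads=8,
#                                num_encoder_layers=6, num_decoder_layers=6,
#                                dropout_p=0.1).to("cuda")
#   opt = torch.optim.Adam(model.parameters(), lr=1e-4)
#   losses = fit(model, opt, masked_mse_loss, train_dl, val_dl, epochs=10, device="cuda")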


def predict(model, source_sequence, sos_token: torch.Tensor, device, max_length=32, eos_scaling=1,
            backpropagate=False, show_result=True, temperature=1):
    if backpropagate:
        model.train()
    else:
        model.eval()

    source_sequence = source_sequence.float().to(device)
    y_input = torch.unsqueeze(sos_token, dim=0).float().to(device)

    i = 0
    while i < max_length:
        prediction = model(source_sequence.unsqueeze(0), y_input.unsqueeze(0),  # un-squeeze for batch
                           # tgt_mask=get_tgt_mask(y_input.size(0)).to(device),
                           src_key_padding_mask=create_pad_mask(source_sequence.unsqueeze(0)).to(device))

        next_embedding = prediction[0, -1, :]  # prediction for the last token
        pred_deep_svg, pred_type, pred_parameters = dataset_helper.unpack_embedding(next_embedding, dim=0)
        pred_deep_svg, pred_type, pred_parameters = (
            pred_deep_svg.to(device), pred_type.to(device), pred_parameters.to(device)
        )

        pred_type = pred_type / temperature

        # === TYPE ===
        # Apply softmax and sample the animation type from the top-3 candidates.
        type_softmax = torch.softmax(pred_type, dim=0)
        type_softmax[0] = type_softmax[0] * eos_scaling  # Scale down EOS probability
        indices = torch.argsort(type_softmax, descending=True)
        animation_type = int(random.choice(indices[:3].tolist()))

        # Break if EOS was sampled
        if animation_type == 0:
            print("END OF ANIMATION")
            y_input = torch.cat((y_input, sos_token.unsqueeze(0).to(device)), dim=0)
            return y_input

        # One-hot encode the chosen animation type.
        pred_type = torch.zeros(11)
        pred_type[animation_type] = 1

        # === DEEP SVG ===
        # Find the closest path in the source sequence (the last 26 dims hold type and parameters).
        distances = [torch.norm(pred_deep_svg - embedding[:-26]) for embedding in source_sequence]
        closest_index = distances.index(min(distances))
        closest_token = source_sequence[closest_index]

        # === PARAMETERS ===
        # Overwrite parameters that are unused by the chosen animation type.
        for j in range(len(pred_parameters)):
            if j in dataset_helper.ANIMATION_PARAMETER_INDICES[animation_type]:
                continue
            pred_parameters[j] = -1

        # === SEQUENCE ===
        y_new = torch.concat([closest_token[:-26], pred_type.to(device), pred_parameters], dim=0)
        y_input = torch.cat((y_input, y_new.unsqueeze(0)), dim=0)

        # === INFO PRINT ===
        if show_result:
            print(f"{int(y_input.size(0))}: Path {closest_index} ({round(float(distances[closest_index]), 3)}) "
                  f"got animation {animation_type} ({round(float(type_softmax[animation_type]), 3)}%) "
                  f"with parameters {[round(num, 2) for num in pred_parameters.tolist()]}")

        i += 1

    return y_input


class PositionalEncoding(nn.Module):
    def __init__(self, dim_model, dropout_p, max_len=5000):
        """
        Initializes the PositionalEncoding module, which injects information about the relative or absolute position
        of the tokens in the sequence. The positional encodings have the same dimension as the embeddings so that the
        two can be summed. Uses the standard sinusoidal pattern for positional encoding.

        Args:
            dim_model (int): The dimension of the embeddings and the expected dimension of the positional encoding.
            dropout_p (float): Dropout probability to be applied to the summed embeddings and positional encodings.
            max_len (int): The max length of the sequences for which positional encodings are precomputed and stored.
        """
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout_p)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, dim_model, 2).float() * (-math.log(10000.0) / dim_model))

        # Shape (1, max_len, dim_model) so the encoding broadcasts over the batch
        # dimension of batch-first inputs (batch_size, seq_len, dim_model).
        pos_encoding = torch.zeros(1, max_len, dim_model)
        pos_encoding[0, :, 0::2] = torch.sin(position * div_term)
        pos_encoding[0, :, 1::2] = torch.cos(position * div_term)

        self.register_buffer('pos_encoding', pos_encoding)

    def forward(self, embedding: torch.Tensor) -> torch.Tensor:
        """
        Applies positional encoding to the input embeddings, followed by dropout.

        Args:
            embedding (torch.Tensor): The input embeddings with shape [batch_size, seq_len, dim_model].

        Returns:
            torch.Tensor: The embeddings with positional encoding and dropout applied, with the same shape as the
            input [batch_size, seq_len, dim_model].
        """
        return self.dropout(embedding + self.pos_encoding[:, :embedding.size(1), :])
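

# A minimal smoke test under assumed, arbitrary dimensions (dim_model=32 with
# 4 heads; real values depend on the dataset_helper embedding layout). It only
# checks that shapes line up in a forward pass.
if __name__ == "__main__":
    model = AnimationTransformer(dim_model=32, num_heads=4, num_encoder_layers=2,
                                 num_decoder_layers=2, dropout_p=0.1)
    src = torch.randn(2, 10, 32)  # (batch_size, src seq len, dim_model)
    tgt = torch.randn(2, 7, 32)   # (batch_size, tgt seq len, dim_model)
    out = model(src, tgt, tgt_mask=get_tgt_mask(7))
    print(out.shape)  # expected: torch.Size([2, 7, 32])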