| import ast |
| import copy |
| import math |
| import pickle |
| import os |
| from collections import deque |
| from typing import List, Optional, Tuple, Union |
|
|
import numpy as np
from numpy import random as rnd
| import pandas as pd |
| import torch |
| import torch.utils.checkpoint |
| from torch import nn |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset |
|
|
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask |
| from transformers.generation import GenerationMixin |
| from transformers.utils import ( |
| add_end_docstrings, |
| add_start_docstrings, |
| add_start_docstrings_to_model_forward, |
| logging, |
| replace_return_docstrings, |
| ) |
| from transformers.modeling_outputs import ( |
| BaseModelOutput, |
| BaseModelOutputWithPastAndCrossAttentions, |
| CausalLMOutputWithCrossAttentions, |
| Seq2SeqLMOutput, |
| Seq2SeqModelOutput, |
| ) |
|
|
| from configuration import STLConfig |
| from nltk.translate.bleu_score import sentence_bleu |
import stl
from stl import *
| import networkx as nx |
| from datasets import load_dataset |
|
|
|
|
| |
|
|
| import re |
| import json |
| from typing import Any, Dict, List, Optional, Tuple, Union |
| from transformers import PreTrainedTokenizer |
| from transformers.utils import logging |
|
|
| logger = logging.get_logger(__name__) |
|
|
| |
|
|
| def load_pickle(path): |
| with open(path, 'rb') as f: |
| x = pickle.load(f) |
| return x |
|
|
| def dump_pickle(name, thing): |
| with open(name + '.pickle', 'wb') as f: |
| pickle.dump(thing, f) |
|
|
| def from_string_to_formula(st): |
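"""
Parses a space-separated STL formula string into a formula object (classes come from the `stl` module).

Grammar assumed by the parsing logic below (examples are illustrative):
atom:   'x_0 <= 0.5' or 'x_1 >= -0.2'
unary:  'not ( x_0 <= 0.5 )', 'eventually[0,5] ( x_0 <= 0.5 )', 'always[2,10] ( x_1 >= 0.0 )'
binary: '( x_0 <= 0.5 and x_1 <= -0.2 )', '( ... or ... )', '( ... until[0,5] ... )'
"""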
| root_arity = 2 if st.startswith('(') else 1 |
| st_split = st.split() |
| if root_arity <= 1: |
| root_op_str = copy.deepcopy(st_split[0]) |
| if root_op_str.startswith('x'): |
| atom_sign = True if st_split[1] == '<=' else False |
| root_phi = Atom(var_index=int(st_split[0][2]), lte=atom_sign, threshold=float(st_split[2])) |
| return root_phi |
| else: |
| assert (root_op_str.startswith('not') or root_op_str.startswith('eventually') |
| or root_op_str.startswith('always')) |
| current_st = copy.deepcopy(st_split[2:-1]) |
| if root_op_str == 'not': |
| root_phi = Not(child=from_string_to_formula(' '.join(current_st))) |
| elif root_op_str.startswith('eventually'): |
| unbound, right_unbound, left_time_bound, right_time_bound = set_time_thresholds(root_op_str) |
| root_phi = Eventually(child=from_string_to_formula(' '.join(current_st)), unbound=unbound, |
| right_unbound=right_unbound, left_time_bound=left_time_bound, |
| right_time_bound=right_time_bound) |
| else: |
| unbound, right_unbound, left_time_bound, right_time_bound = set_time_thresholds(root_op_str) |
| root_phi = Globally(child=from_string_to_formula(' '.join(current_st)), unbound=unbound, |
| right_unbound=right_unbound, left_time_bound=left_time_bound, |
| right_time_bound=right_time_bound) |
| else: |
| |
| current_st = copy.deepcopy(st_split[1:-1]) |
| if '(' in current_st: |
| par_queue = deque() |
| par_idx_list = [] |
| for i, sub in enumerate(current_st): |
| if sub == '(': |
| par_queue.append(i) |
| elif sub == ')': |
| par_idx_list.append(tuple([par_queue.pop(), i])) |
| |
| |
| children_range = [] |
| for begin, end in sorted(par_idx_list): |
| if children_range and children_range[-1][1] >= begin - 1: |
| children_range[-1][1] = max(children_range[-1][1], end) |
| else: |
| children_range.append([begin, end]) |
| n_children = len(children_range) |
| assert (n_children in [1, 2]) |
| if n_children == 1: |
| |
| var_child_idx = 1 if children_range[0][0] <= 1 else 0 |
| if children_range[0][0] != 0 and current_st[children_range[0][0] - 1][0:2] in ['no', 'ev', 'al']: |
| children_range[0][0] -= 1 |
| left_child_str = current_st[:3] if var_child_idx == 0 else \ |
| current_st[children_range[0][0]:children_range[0][1] + 1] |
| right_child_str = current_st[-3:] if var_child_idx == 1 else \ |
| current_st[children_range[0][0]:children_range[0][1] + 1] |
| root_op_str = current_st[children_range[0][1] + 1] if var_child_idx == 1 else \ |
| current_st[children_range[0][0] - 1] |
| assert (root_op_str[:2] in ['an', 'or', 'un']) |
| else: |
| if children_range[0][0] != 0 and current_st[children_range[0][0] - 1][0:2] in ['no', 'ev', 'al']: |
| children_range[0][0] -= 1 |
| if current_st[children_range[1][0] - 1][0:2] in ['no', 'ev', 'al']: |
| children_range[1][0] -= 1 |
| |
| root_op_str = current_st[children_range[0][1] + 1] |
| assert (root_op_str[:2] in ['an', 'or', 'un']) |
| left_child_str = current_st[children_range[0][0]:children_range[0][1] + 1] |
| right_child_str = current_st[children_range[1][0]:children_range[1][1] + 1] |
| else: |
| |
| left_child_str = current_st[:3] |
| right_child_str = current_st[-3:] |
| root_op_str = current_st[3] |
| left_child_str = ' '.join(left_child_str) |
| right_child_str = ' '.join(right_child_str) |
| if root_op_str == 'and': |
| root_phi = And(left_child=from_string_to_formula(left_child_str), |
| right_child=from_string_to_formula(right_child_str)) |
| elif root_op_str == 'or': |
| root_phi = Or(left_child=from_string_to_formula(left_child_str), |
| right_child=from_string_to_formula(right_child_str)) |
| else: |
| unbound, right_unbound, left_time_bound, right_time_bound = set_time_thresholds(root_op_str) |
| root_phi = Until(left_child=from_string_to_formula(left_child_str), |
| right_child=from_string_to_formula(right_child_str), |
| unbound=unbound, right_unbound=right_unbound, left_time_bound=left_time_bound, |
| right_time_bound=right_time_bound) |
| return root_phi |
|
|
| def load_json(path: str) -> Union[Dict, List]: |
| """ |
| Load a JSON file from the given path. |
| |
| Args: |
| path (str): The path to the JSON file to be loaded. |
| |
| Returns: |
| Union[Dict, List]: The parsed content of the JSON file, which could be a dictionary or a list. |
| """ |
| with open(path, "r") as f: |
| return json.load(f) |
|
|
| |
|
|
| class StlGenerator: |
| def __init__( |
| self, |
| leaf_prob: float = 0.3, |
| inner_node_prob: list = None, |
| threshold_mean: float = 0.0, |
| threshold_sd: float = 1.0, |
| unbound_prob: float = 0.1, |
| right_unbound_prob: float = 0.2, |
| time_bound_max_range: float = 20, |
| adaptive_unbound_temporal_ops: bool = True, |
| max_timespan: int = 100, |
| ): |
| """ |
| leaf_prob |
| probability of generating a leaf (always zero for root) |
| node_types = ["not", "and", "or", "always", "eventually", "until"] |
| Inner node types |
| inner_node_prob |
| probability vector for the different types of internal nodes |
| threshold_mean |
| threshold_sd |
| mean and std for the normal distribution of the thresholds of atoms |
| unbound_prob |
| probability of a temporal operator to have a time bound o the type [0,infty] |
| time_bound_max_range |
| maximum value of time span of a temporal operator (i.e. max value of t in [0,t]) |
| adaptive_unbound_temporal_ops |
| if true, unbounded temporal operators are computed from current point to the end of the signal, otherwise |
| they are evaluated only at time zero. |
| max_timespan |
| maximum time depth of a formula. |
| """ |
|
|
| |
| if inner_node_prob is None: |
| inner_node_prob = [0.166, 0.166, 0.166, 0.17, 0.166, 0.166] |
|
|
| self.leaf_prob = leaf_prob |
| self.inner_node_prob = inner_node_prob |
| self.threshold_mean = threshold_mean |
| self.threshold_sd = threshold_sd |
| self.unbound_prob = unbound_prob |
| self.right_unbound_prob = right_unbound_prob |
| self.time_bound_max_range = time_bound_max_range |
| self.adaptive_unbound_temporal_ops = adaptive_unbound_temporal_ops |
| self.node_types = ["not", "and", "or", "always", "eventually", "until"] |
| self.max_timespan = max_timespan |
|
|
| def sample(self, nvars): |
| """ |
| Samples a random formula with distribution defined in class instance parameters |
| |
| Parameters |
| ---------- |
| nvars : number of variables of input signals |
| how many variables the formula is expected to consider. |
| |
| Returns |
| ------- |
| TYPE |
| A random formula. |
| |
| """ |
| return self._sample_internal_node(nvars) |
|
|
| def bag_sample(self, bag_size, nvars): |
| """ |
| Samples a bag of bag_size formulae |
| |
| Parameters |
| ---------- |
| bag_size : INT |
| number of formulae. |
| nvars : INT |
| number of vars in formulae. |
| |
| Returns |
| ------- |
| a list of formulae. |
| |
| """ |
| formulae = [] |
| for _ in range(bag_size): |
| phi = self.sample(nvars) |
| formulae.append(phi) |
| return formulae |
|
|
| def _sample_internal_node(self, nvars): |
| |
| node: Union[None, Node] |
| node = None |
| |
| nodetype = rnd.choice(self.node_types, p=self.inner_node_prob) |
| while True: |
| if nodetype == "not": |
| n = self._sample_node(nvars) |
| node = stl.Not(n) |
| elif nodetype == "and": |
| n1 = self._sample_node(nvars) |
| n2 = self._sample_node(nvars) |
| node = stl.And(n1, n2) |
| elif nodetype == "or": |
| n1 = self._sample_node(nvars) |
| n2 = self._sample_node(nvars) |
| node = stl.Or(n1, n2) |
| elif nodetype == "always": |
| n = self._sample_node(nvars) |
| unbound, right_unbound, left_time_bound, right_time_bound = self._get_temporal_parameters() |
| node = stl.Globally( |
| n, unbound, right_unbound, left_time_bound, right_time_bound, self.adaptive_unbound_temporal_ops |
| ) |
| elif nodetype == "eventually": |
| n = self._sample_node(nvars) |
| unbound, right_unbound, left_time_bound, right_time_bound = self._get_temporal_parameters() |
| node = stl.Eventually( |
| n, unbound, right_unbound, left_time_bound, right_time_bound, self.adaptive_unbound_temporal_ops |
| ) |
| elif nodetype == "until": |
| n1 = self._sample_node(nvars) |
| n2 = self._sample_node(nvars) |
| unbound, right_unbound, left_time_bound, right_time_bound = self._get_temporal_parameters() |
| node = stl.Until( |
| n1, n2, unbound, right_unbound, left_time_bound, right_time_bound |
| ) |
|
|
| if (node is not None) and (node.time_depth() < self.max_timespan): |
| return node |
|
|
| def _sample_node(self, nvars): |
| if rnd.rand() < self.leaf_prob: |
| |
| var, thr, lte = self._get_atom(nvars) |
| return stl.Atom(var, thr, lte) |
| else: |
| return self._sample_internal_node(nvars) |
|
|
| def _get_temporal_parameters(self): |
| if rnd.rand() < self.unbound_prob: |
| return True, False, 0, 0 |
| elif rnd.rand() < self.right_unbound_prob: |
| return False, True, rnd.randint(self.time_bound_max_range), 1 |
| else: |
| left_bound = rnd.randint(self.time_bound_max_range) |
| return False, False, left_bound, rnd.randint(left_bound, self.time_bound_max_range) + 1 |
|
|
| def _get_atom(self, nvars): |
| variable = rnd.randint(nvars) |
| lte = rnd.rand() > 0.5 |
| threshold = rnd.normal(self.threshold_mean, self.threshold_sd) |
| return variable, threshold, lte |
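# Illustrative usage of StlGenerator (parameter values are arbitrary):
#   sampler = StlGenerator(leaf_prob=0.4, threshold_sd=1.0)
#   phi = sampler.sample(nvars=3)                    # one random formula over x_0, x_1, x_2
#   bag = sampler.bag_sample(bag_size=100, nvars=3)  # a list of 100 random formulae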
|
|
| |
|
|
| class Measure: |
| def sample(self, samples=100000, varn=2, points=100): |
| |
| pass |
|
|
| class BaseMeasure(Measure): |
| def __init__( |
| self, mu0=0.0, sigma0=1.0, mu1=0.0, sigma1=1.0, q=0.1, q0=0.5, device="cpu" |
| ): |
| """ |
| |
| Parameters |
| ---------- |
| mu0 : mean of normal distribution of initial state, optional |
| The default is 0.0. |
| sigma0 : standard deviation of normal distribution of initial state, optional |
| The default is 1.0. |
| mu1 : DOUBLE, optional |
| mean of normal distribution of total variation. The default is 0.0. |
| sigma1 : standard deviation of normal distribution of total variation, optional |
| The default is 1.0. |
| q : DOUBLE, optional |
| probability of change of sign in derivative. The default is 0.1. |
| q0 : DOUBLE, optional |
| probability of initial sign of derivative. The default is 0.5. |
| device : 'cpu' or 'cuda', optional |
| device on which to run the algorithm. The default is 'cpu'. |
| |
| Returns |
| ------- |
| None. |
| |
| """ |
| self.mu0 = mu0 |
| self.sigma0 = sigma0 |
| self.mu1 = mu1 |
| self.sigma1 = sigma1 |
| self.q = q |
| self.q0 = q0 |
| self.device = device |
|
|
| def sample(self, samples=100000, varn=2, points=100): |
| """ |
| Samples a set of trajectories from the basic measure space, with parameters |
| passed to the sampler |
| |
| Parameters |
| ---------- |
| points : INT, optional |
number of points per trajectory, including the initial one. The default is 100.
| samples : INT, optional |
| number of trajectories. The default is 100000. |
| varn : INT, optional |
| number of variables per trajectory. The default is 2. |
| |
| |
| Returns |
| ------- |
| signal : samples x varn x points double pytorch tensor |
| The sampled signals. |
| |
| """ |
| if self.device == "cuda" and not torch.cuda.is_available(): |
| raise RuntimeError("GPU card or CUDA library not available!") |
|
|
| |
| signal = torch.rand(samples, varn, points, device=self.device) |
| |
| signal[:, :, 0] = 0.0 |
| signal[:, :, -1] = 1.0 |
| |
| signal, _ = torch.sort(signal, 2) |
| |
| signal[:, :, 1:] = signal[:, :, 1:] - signal[:, :, :-1] |
| |
| signal[:, :, 0] = self.mu0 + self.sigma0 * torch.randn(signal[:, :, 0].size()) |
|
|
| |
| derivs = (1 - self.q) * torch.ones(samples, varn, points, device=self.device) |
| derivs = 2 * torch.bernoulli(derivs) - 1 |
| |
| derivs[:, :, 0] = self.q0 |
| derivs[:, :, 0] = 2 * torch.bernoulli(derivs[:, :, 0]) - 1 |
| |
| derivs = torch.cumprod(derivs, 2) |
|
|
| |
| totvar = torch.pow( |
| self.mu1 + self.sigma1 * torch.randn(samples, varn, 1, device=self.device), |
| 2, |
| ) |
| |
| derivs = derivs * totvar |
| derivs[:, :, 0] = 1.0 |
|
|
| |
| signal = signal * derivs |
| signal = torch.cumsum(signal, 2) |
| return signal |
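# Illustrative usage of BaseMeasure: sample 100 two-variable trajectories of 50 points each.
#   mu = BaseMeasure(mu0=0.0, sigma0=1.0, q=0.1)
#   signals = mu.sample(samples=100, varn=2, points=50)  # tensor of shape (100, 2, 50)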
|
|
| |
|
|
| realnum = Union[float, int] |
|
|
| class StlKernel: |
| def __init__( |
| self, |
| measure, |
| normalize=True, |
| exp_kernel=True, |
| sigma2=0.2, |
| integrate_time=False, |
| samples=100000, |
| varn=2, |
| points=100, |
| boolean=False, |
| signals=None, |
| ): |
| self.traj_measure = measure |
| self.exp_kernel = exp_kernel |
| self.normalize = normalize |
| self.sigma2 = sigma2 |
| self.samples = samples |
| self.varn = varn |
| self.points = points |
| self.integrate_time = integrate_time |
| if signals is not None: |
| self.signals = signals |
| else: |
| self.signals = measure.sample(points=points, samples=samples, varn=varn) |
| self.boolean = boolean |
|
|
| def compute(self, phi1, phi2): |
| return self.compute_one_one(phi1, phi2) |
|
|
| def compute_one_one(self, phi1, phi2): |
| phis1: list = [phi1] |
| phis2: list = [phi2] |
| ker = self.compute_bag_bag(phis1, phis2) |
| return ker[0, 0] |
|
|
| def compute_bag(self, phis, return_robustness=True): |
| if self.integrate_time: |
| rhos, selfk, len0 = self._compute_robustness_time(phis) |
| kernel_matrix = self._compute_kernel_time( |
| rhos, rhos, selfk, selfk, len0, len0 |
| ) |
| else: |
| rhos, selfk = self._compute_robustness_no_time(phis) |
| kernel_matrix = self._compute_kernel_no_time(rhos, rhos, selfk, selfk) |
| len0 = None |
| if return_robustness: |
| return kernel_matrix.cpu(), rhos, selfk, len0 |
| else: |
| return kernel_matrix.cpu() |
|
|
| def compute_one_bag(self, phi1, phis2, return_robustness=False): |
| phis1: list = [phi1] |
| return self.compute_bag_bag(phis1, phis2, return_robustness) |
|
|
| def compute_bag_bag(self, phis1, phis2, return_robustness=False): |
| if self.integrate_time: |
| rhos1, selfk1, len1 = self._compute_robustness_time(phis1) |
| rhos2, selfk2, len2 = self._compute_robustness_time(phis2) |
| kernel_matrix = self._compute_kernel_time( |
| rhos1, rhos2, selfk1, selfk2, len1, len2 |
| ) |
| else: |
| rhos1, selfk1 = self._compute_robustness_no_time(phis1) |
| rhos2, selfk2 = self._compute_robustness_no_time(phis2) |
| len1, len2 = [None, None] |
| kernel_matrix = self._compute_kernel_no_time(rhos1, rhos2, selfk1, selfk2) |
| if return_robustness: |
| return kernel_matrix.cpu(), rhos1, rhos2, selfk1, selfk2, len1, len2 |
| else: |
| return kernel_matrix.cpu() |
|
|
| def compute_one_from_robustness(self, phi, rhos, rho_self, lengths=None, return_robustness=False): |
| phis: list = [phi] |
| return self.compute_bag_from_robustness(phis, rhos, rho_self, lengths, return_robustness) |
|
|
| def compute_bag_from_robustness(self, phis, rhos, rho_self, lengths=None, return_robustness=False): |
| if self.integrate_time: |
| rhos1, selfk1, len1 = self._compute_robustness_time(phis) |
| kernel_matrix = self._compute_kernel_time( |
| rhos1, rhos, selfk1, rho_self, len1, lengths |
| ) |
| else: |
| rhos1, selfk1 = self._compute_robustness_no_time(phis) |
| len1 = None |
| kernel_matrix = self._compute_kernel_no_time(rhos1, rhos, selfk1, rho_self) |
| if return_robustness: |
| return kernel_matrix.cpu(), rhos1, selfk1, len1 |
| else: |
| return kernel_matrix.cpu() |
|
|
| def _compute_robustness_time(self, phis): |
| n = self.samples |
| p = self.points |
| k = len(phis) |
| rhos = torch.zeros((k, n, p), device="cpu") |
| lengths = torch.zeros(k) |
| self_kernels = torch.zeros((k, 1)) |
| for i, phi in enumerate(phis): |
| if self.boolean: |
| rho = phi.boolean(self.signals, evaluate_at_all_times=True).float() |
| rho[rho == 0.0] = -1.0 |
| else: |
| rho = phi.quantitative(self.signals, evaluate_at_all_times=True) |
| actual_p = rho.size()[2] |
| rho = rho.reshape(n, actual_p).cpu() |
| rhos[i, :, :actual_p] = rho |
| lengths[i] = actual_p |
| self_kernels[i] = torch.tensordot( |
| rho.reshape(1, n, -1), rho.reshape(1, n, -1), dims=[[1, 2], [1, 2]] |
| ) / (actual_p * n) |
| return rhos, self_kernels, lengths |
|
|
| def _compute_robustness_no_time(self, phis): |
| n = self.samples |
| k = len(phis) |
| rhos = torch.zeros((k, n), device=self.traj_measure.device) |
| self_kernels = torch.zeros((k, 1), device=self.traj_measure.device) |
| for i, phi in enumerate(phis): |
| if self.boolean: |
| rho = phi.boolean(self.signals, evaluate_at_all_times=False).float() |
| rho[rho == 0.0] = -1.0 |
| else: |
| rho = phi.quantitative(self.signals, evaluate_at_all_times=False) |
| self_kernels[i] = rho.dot(rho) / n |
| rhos[i, :] = rho |
| return rhos, self_kernels |
|
|
| def _compute_kernel_time(self, rhos1, rhos2, selfk1, selfk2, len1, len2): |
| kernel_matrix = torch.tensordot(rhos1, rhos2, [[1, 2], [1, 2]]) |
| length_normalizer = self._compute_trajectory_length_normalizer(len1, len2) |
| kernel_matrix = kernel_matrix * length_normalizer / self.samples |
| if self.normalize: |
| kernel_matrix = self._normalize(kernel_matrix, selfk1, selfk2) |
| if self.exp_kernel: |
| kernel_matrix = self._exponentiate(kernel_matrix, selfk1, selfk2) |
| return kernel_matrix |
|
|
| def _compute_kernel_no_time(self, rhos1, rhos2, selfk1, selfk2): |
| kernel_matrix = torch.tensordot(rhos1, rhos2, [[1], [1]]) |
| kernel_matrix = kernel_matrix / self.samples |
| if self.normalize: |
| kernel_matrix = self._normalize(kernel_matrix, selfk1, selfk2) |
| if self.exp_kernel: |
| kernel_matrix = self._exponentiate(kernel_matrix, selfk1, selfk2) |
| return kernel_matrix |
|
|
| @staticmethod |
| def _normalize(kernel_matrix, selfk1, selfk2): |
| normalize = torch.sqrt(torch.matmul(selfk1, torch.transpose(selfk2, 0, 1))) |
| kernel_matrix = kernel_matrix / normalize |
| return kernel_matrix |
|
|
| def _exponentiate(self, kernel_matrix, selfk1, selfk2, sigma2=None): |
| if sigma2 is None: |
| sigma2 = self.sigma2 |
| if self.normalize: |
| |
| selfk = 2.0 |
| else: |
| k1 = selfk1.size()[0] |
| k2 = selfk2.size()[0] |
| selfk = (selfk1 * selfk1).repeat(1, k2) + torch.transpose( |
| selfk2 * selfk2, 0, 1 |
| ).repeat(k1, 1) |
| return torch.exp(-(selfk - 2 * kernel_matrix) / (2 * sigma2)) |
|
|
| @staticmethod |
| def _compute_trajectory_length_normalizer(len1, len2): |
| k1 = len1.size()[0] |
| k2 = len2.size()[0] |
| y1 = len1.reshape(-1, 1) |
| y1 = y1.repeat(1, k2) |
| y2 = len2.repeat(k1, 1) |
| return 1.0 / torch.min(y1, y2) |
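# Illustrative usage of StlKernel (a minimal sketch; sample sizes reduced from the defaults):
#   mu = BaseMeasure(device="cpu")
#   kernel = StlKernel(mu, varn=2, samples=1000, points=100)
#   sampler = StlGenerator()
#   phi1, phi2 = sampler.sample(nvars=2), sampler.sample(nvars=2)
#   k12 = kernel.compute(phi1, phi2)                                               # scalar kernel value
#   gram = kernel.compute_bag(sampler.bag_sample(10, 2), return_robustness=False)  # 10 x 10 matrix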
|
|
| class GramMatrix: |
| def __init__(self, kernel, formulae, store_robustness=True, sample=False, sampler=None, bag_size=None): |
| self.kernel = kernel |
| self.formulae_list = formulae |
| |
| |
| |
| self.store_robustness = store_robustness |
| self.dim = len(self.formulae_list) if not bag_size else int(bag_size) |
| self.sample = sample |
| if self.sample: |
| self.t = 0.99 if self.kernel.boolean else 0.85 |
| self.sampler = sampler |
| self._compute_gram_matrix() |
|
|
| def _compute_gram_matrix(self): |
| if self.sample: |
| gram = torch.zeros(self.dim, self.dim) |
| rhos = torch.zeros((self.dim, self.kernel.samples), device=self.kernel.traj_measure.device) if \ |
| not self.kernel.integrate_time else torch.zeros((self.dim, self.kernel.samples, self.kernel.points), |
| device=self.kernel.traj_measure.device) |
| lengths = torch.zeros(self.dim) if self.kernel.integrate_time else np.zeros(self.dim) |
| kernels = torch.zeros((self.dim, 1), device=self.kernel.traj_measure.device) |
| phis = [self.sampler.sample(nvars=self.kernel.varn)] |
| gram[0, :1], rhos[0], kernels[0, :], lengths[0] = self.kernel.compute_bag(phis, return_robustness=True) |
| while len(phis) < self.dim: |
| i = len(phis) |
| phi = self.sampler.sample(nvars=self.kernel.varn) |
| gram[i, :i], rhos[i], kernels[i, :], lengths[i] = self.kernel.compute_one_from_robustness( |
| phi, rhos[:i, :], kernels[:i, :], lengths[:i], return_robustness=True) |
| if torch.sum(gram[i, :i + 1] >= self.t) < 3: |
| phis.append(phi) |
| gram[:i, i] = gram[i, :i] |
| gram[i, i] = kernels[i, :] |
|
|
| self.formulae_list = phis |
| self.gram = gram.cpu() |
| self.robustness = rhos if self.store_robustness else None |
| self.self_kernels = kernels if self.store_robustness else None |
| self.robustness_lengths = lengths if self.store_robustness else None |
| else: |
| if self.store_robustness: |
| k_matrix, rhos, selfk, len0 = self.kernel.compute_bag( |
| self.formulae_list, return_robustness=True |
| ) |
| self.gram = k_matrix |
| self.robustness = rhos |
| self.self_kernels = selfk |
| self.robustness_lengths = len0 |
| else: |
| self.gram = self.kernel.compute_bag( |
| self.formulae_list, return_robustness=False |
| ) |
| self.robustness = None |
| self.self_kernels = None |
| self.robustness_lengths = None |
|
|
| def compute_kernel_vector(self, phi): |
| if self.store_robustness: |
| return self.kernel.compute_one_from_robustness( |
| phi, self.robustness, self.self_kernels, self.robustness_lengths |
| ) |
| else: |
| return self.kernel.compute_one_bag(phi, self.formulae_list) |
|
|
| def compute_bag_kernel_vector(self, phis, generate_phis=False, bag_size=None): |
| if generate_phis: |
| gram_test = torch.zeros(bag_size, self.dim) |
| rhos_test = torch.zeros((bag_size, self.kernel.samples), device=self.kernel.traj_measure.device) if \ |
| not self.kernel.integrate_time else torch.zeros((bag_size, self.kernel.samples, self.kernel.points), |
| device=self.kernel.traj_measure.device) |
| lengths_test = torch.zeros(bag_size) if self.kernel.integrate_time else np.zeros(bag_size) |
| kernels_test = torch.zeros((bag_size, 1), device=self.kernel.traj_measure.device) |
| phi_test = [] |
| while len(phi_test) < bag_size: |
| i = len(phi_test) |
| phi = self.sampler.sample(nvars=self.kernel.varn) |
| if self.store_robustness: |
| gram_test[i, :], rhos_test[i], kernels_test[i, :], lengths_test[i] = \ |
| self.kernel.compute_one_from_robustness(phi, self.robustness, self.self_kernels, |
| self.robustness_lengths, return_robustness=True) |
| else: |
| gram_test[i, :], rhos_test[i], _, kernels_test[i, :], _, lengths_test[i], _ = \ |
| self.kernel.compute_one_bag(phi, self.formulae_list, return_robustness=True) |
| if not ((rhos_test[i] > 0).all() or (rhos_test[i] < 0).all()): |
| phi_test.append(phi) |
| return phi_test, gram_test.cpu() |
| else: |
| if self.store_robustness: |
| return self.kernel.compute_bag_from_robustness( |
| phis, self.robustness, self.self_kernels, self.robustness_lengths |
| ) |
| else: |
| return self.kernel.compute_bag_bag(phis, self.formulae_list) |
|
|
| def invert_regularized(self, alpha): |
| regularizer = abs(pow(10, alpha)) * torch.eye(self.dim) |
| return torch.inverse(self.gram + regularizer) |
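# Illustrative usage of GramMatrix (names as defined above; parameter values are assumptions):
#   mu = BaseMeasure()
#   kernel = StlKernel(mu, varn=2, samples=1000)
#   sampler = StlGenerator(leaf_prob=0.4)
#   gram = GramMatrix(kernel, sampler.bag_sample(bag_size=50, nvars=2))
#   k_vec = gram.compute_kernel_vector(sampler.sample(nvars=2))  # kernel of a new formula vs. the bag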
|
|
| |
|
|
def anchorGeneration(diff_init: bool = False,
                     embed_dim: int = 30,
                     n_vars: int = 3,
                     leaf_prob: float = 0.4,
                     cosine_similarity_threshold: float = 0.8
                     ) -> str:
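"""
Builds (and pickles) a set of `embed_dim` anchor STL formulae over `n_vars` variables.

If `diff_init` is True, candidate anchors are kept only when the cosine similarity between their
robustness vectors (computed on trajectories sampled from BaseMeasure) and those of the anchors
selected so far stays below `cosine_similarity_threshold`; otherwise the anchors are sampled
independently. Returns the base filename (without the '.pickle' extension) of the saved anchor set.
"""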
| |
| |
| sampler = StlGenerator(leaf_prob) |
| |
| |
| if diff_init: |
| |
| |
| diff_anchor_set = [sampler.sample(nvars=n_vars)] |
|
|
| device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") |
| mu = BaseMeasure(device=device) |
|
|
| |
| signals = mu.sample(samples=10000, varn=n_vars) |
|
|
| |
| anchor_rob_vectors = torch.cat([phi.quantitative(signals, normalize=True).unsqueeze(0) for phi in diff_anchor_set], 0) |
|
|
| while len(diff_anchor_set) < embed_dim: |
| |
| candidate_anchors = sampler.bag_sample(embed_dim - len(diff_anchor_set), nvars = n_vars) |
| |
| |
| candidate_robs = torch.cat([phi.quantitative(signals, normalize=True).unsqueeze(0) for phi in candidate_anchors], 0) |
| |
| |
cos_simil = torch.tril(F.normalize(candidate_robs, dim=1) @ F.normalize(anchor_rob_vectors, dim=1).t(), diagonal=-1)
|
|
| |
| |
| similar_idx = [torch.where(cos_simil[r, :] > cosine_similarity_threshold)[0].tolist() for r in range(cos_simil.shape[0])] |
| |
| |
| keep_idx = list(set(np.arange(len(candidate_anchors)).tolist()).difference(set([i for sublist in similar_idx for i in sublist]))) |
| |
| diff_anchor_set += [copy.deepcopy(candidate_anchors[i]) for i in keep_idx] |
| |
| |
| keep_idx_tensor = torch.tensor(keep_idx, device=candidate_robs.device) |
| |
| |
| selected_robs = torch.index_select(candidate_robs, 0, keep_idx_tensor) |
| |
| |
| anchor_rob_vectors = torch.cat([anchor_rob_vectors, copy.deepcopy(selected_robs)], dim=0) |
|
|
| anchor_set = diff_anchor_set[:embed_dim] |
| |
| else: |
| anchor_set = sampler.bag_sample(bag_size=embed_dim, nvars=n_vars) |
|
|
filename = f"anchor_set_{'diff' if diff_init else 'no_diff'}_{embed_dim}_dim"
| dump_pickle(filename, anchor_set) |
| return filename |
|
|
| |
|
|
| class STLTokenizer(PreTrainedTokenizer): |
| """ |
| A custom tokenizer class that extends `PreTrainedTokenizer` to handle a specific vocabulary and tokenization process. |
| |
| This tokenizer can load a vocabulary from a JSON file, tokenize text, convert tokens to IDs, |
| and handle padding and special tokens. |
| """ |
|
|
| def __init__(self, vocab_path: str, unk_token: str = "unk", pad_token: str = "pad", |
| bos_token: str = "/s", eos_token: str = "s", model_max_length = 512, *args, **kwargs): |
| """ |
| Initializes the STLTokenizer with a given vocabulary and special tokens. |
| |
| Args: |
| vocab_path (str): The path to the JSON file containing the vocabulary. |
| unk_token (str, optional): The token used for unknown words. Defaults to "unk". |
| pad_token (str, optional): The token used for padding. Defaults to "pad". |
| bos_token (str, optional): The token used for the beginning of a sequence. Defaults to "/s". |
| eos_token (str, optional): The token used for the end of a sequence. Defaults to "s". |
| """ |
| self.vocab = load_json(vocab_path) |
| self.unk_token = unk_token |
| self.pad_token = pad_token |
| self.bos_token = bos_token |
| self.eos_token = eos_token |
| self.model_max_length = model_max_length |
| self.id_to_token = {v: k for k, v in self.vocab.items()} |
| super().__init__(unk_token=unk_token, pad_token=pad_token, bos_token=bos_token, eos_token=eos_token, |
| model_max_length=model_max_length, *args, **kwargs) |
|
|
| @property |
| def vocab_size(self) -> int: |
| """ |
| Returns the size of the vocabulary. |
| |
| Returns: |
| int: The number of tokens in the vocabulary. |
| """ |
| return len(self.vocab) |
|
|
def prepad_sequence(self, sequence, space_token=' ', new_space_token='@', undo=False):
"""
Replaces spaces in the input sequence with a placeholder token (and vice versa).

Args:
sequence (str): The input sequence.
space_token (str): The character treated as a space. Defaults to ' '.
new_space_token (str): The placeholder that replaces spaces. Defaults to '@'.
undo (bool): If True, replace the placeholder token with spaces instead. Defaults to False.

Returns:
str: The sequence with spaces or placeholder tokens replaced.
"""
| if undo: |
| return sequence.replace(new_space_token, space_token) |
| else: |
| return sequence.replace(space_token, new_space_token) |
|
|
| def add_bos_eos(self, sequence: str) -> str: |
| """ |
| Aggiunge i token BOS all'inizio e EOS alla fine della sequenza. |
| |
| Args: |
| sequence (str): La sequenza di input. |
| |
| Returns: |
| str: La sequenza con i token BOS ed EOS. |
| """ |
| return f'{self.bos_token} {sequence} {self.eos_token}' |
|
|
| def tokenize(self, text: str) -> List[str]: |
| """ |
| Tokenizes the input text into a list of tokens. |
| |
The method preprocesses the input text by replacing spaces with the placeholder token and then
greedily matches the longest possible vocabulary entry at each position.
| |
| Args: |
| text (str): The input text to be tokenized. |
| |
| Returns: |
| List[str]: A list of tokens representing the tokenized text. |
| """ |
| text = self.add_bos_eos(text) |
| text = self.prepad_sequence(text) |
| |
| tokens = [] |
| i = 0 |
| while i < len(text): |
| best_match = None |
| for j in range(len(text), i, -1): |
| subtoken = text[i:j] |
| if subtoken in self.vocab: |
| best_match = subtoken |
| break |
| if best_match: |
| tokens.append(best_match) |
| i += len(best_match) |
| else: |
| tokens.append(self.unk_token) |
| i += 1 |
| return tokens |
|
|
| def convert_tokens_to_ids(self, tokens: List[str]) -> List[int]: |
| """ |
| Converts a list of tokens into a list of token IDs. |
| |
| Args: |
| tokens (List[str]): A list of tokens to be converted into IDs. |
| |
| Returns: |
| List[int]: A list of corresponding token IDs. |
| """ |
| return [self.vocab.get(token, self.vocab[self.unk_token]) for token in tokens] |
|
|
| def convert_ids_to_tokens(self, ids: List[int]) -> List[str]: |
| """ |
| Converts a list of token IDs into a list of tokens. |
| |
| Args: |
| ids (List[int]): A list of token IDs to be converted into tokens. |
| |
| Returns: |
| List[str]: A list of corresponding tokens. |
| """ |
| return [self.id_to_token.get(i, self.unk_token) for i in ids] |
|
|
| def encode(self, sequence: str) -> List[int]: |
| """ |
| Encodes a string sequence into a list of token IDs. |
| |
| This method tokenizes the input sequence using the `tokenize` method, |
| and then converts the resulting tokens into their corresponding token IDs |
| using the `convert_tokens_to_ids` method. |
| |
| Args: |
| sequence (str): The input sequence (text) to be encoded. |
| |
| Returns: |
| List[int]: A list of token IDs corresponding to the input sequence. |
| """ |
| splitted_sequence = self.tokenize(sequence) |
| return self.convert_tokens_to_ids(splitted_sequence) |
|
|
| def postpad_sequence(self, sequence, pad_token_id): |
| """ |
Pads the sequence with `pad_token_id` up to a total length of `model_max_length` - 1.
"""
num_extra_elements = self.model_max_length - len(sequence) - 1
| if num_extra_elements > 0: |
| sequence.extend([pad_token_id] * num_extra_elements) |
| return sequence |
|
|
| def decode(self, token_ids: List[int]) -> str: |
| """ |
| Decodes a list of token IDs into a string of text. |
| |
The method converts the IDs to tokens, joins them into a string, and restores the original
spaces that were replaced by the space placeholder during tokenization.

Args:
token_ids (List[int]): A list of token IDs to be decoded.
| |
| Returns: |
| str: The decoded string. |
| """ |
| tokens = self.convert_ids_to_tokens(token_ids) |
| decoded = "".join(tokens) |
| return self.prepad_sequence(decoded, undo=True) |
|
|
| def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: |
| """ |
| Saves the tokenizer's vocabulary to a file. |
Useful mainly when the vocabulary must be reconstructed rather than supplied up front
(not the case here; kept as a hook for future improvements, e.g. with SentencePiece).
| |
| This method saves the vocabulary to a JSON file in the specified directory. |
| |
| Args: |
| save_directory (str): The directory where the vocabulary file will be saved. |
| filename_prefix (Optional[str]): An optional prefix for the filename. |
| |
| Returns: |
| Tuple[str]: A tuple containing the path to the saved vocabulary file. |
| """ |
| vocab_file = f"{save_directory}/{filename_prefix + '-' if filename_prefix else ''}vocab.json" |
| with open(vocab_file, "w", encoding="utf-8") as f: |
| json.dump(self.vocab, f, indent=2, ensure_ascii=False) |
| return (vocab_file,) |
|
|
| def get_vocab(self) -> dict: |
| """ |
| Retrieves the vocabulary used by the tokenizer. |
| |
| Returns: |
| dict: The vocabulary as a dictionary. |
| """ |
| return self.vocab |
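# Illustrative usage of the tokenizer (the vocabulary path and formula string are assumptions):
#   tokenizer = STLTokenizer("vocab.json")
#   ids = tokenizer.encode("always ( x_0 <= 0.5 )")
#   text = tokenizer.decode(ids)  # recovers the string, framed by the BOS/EOS tokens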
|
|
| class STLSinusoidalPositionalEmbedding(nn.Embedding): |
| """This module produces sinusoidal positional embeddings of any length.""" |
|
|
| def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None: |
| super().__init__(num_positions, embedding_dim) |
| self.weight = self._init_weight(self.weight) |
|
|
| @staticmethod |
| def _init_weight(out: nn.Parameter) -> nn.Parameter: |
| """ |
| Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in |
| the 2nd half of the vector. [dim // 2:] |
| """ |
| n_pos, dim = out.shape |
| position_enc = np.array( |
| [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)] |
| ) |
| out.requires_grad = False |
| sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1 |
| out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) |
| out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) |
| out.detach_() |
| return out |
|
|
| @torch.no_grad() |
| def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor: |
| """`input_ids_shape` is expected to be [bsz x seqlen].""" |
| bsz, seq_len = input_ids_shape[:2] |
| positions = torch.arange( |
| past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device |
| ) |
| return super().forward(positions) |
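# Illustrative only: the embedding is indexed by position, not by token id, e.g.
#   pos_emb = STLSinusoidalPositionalEmbedding(num_positions=512, embedding_dim=256)
#   pos = pos_emb(torch.Size([4, 16]))  # shape (16, 256), shared across the batch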
|
|
| class STLAttention(nn.Module): |
| """ Multi-Head Attention as depicted from 'Attention is all you need' """ |
|
|
| def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, |
| is_decoder: bool = False, bias: bool = False, is_causal: bool = False): |
| |
| super().__init__() |
| self.embed_dim = embed_dim |
| self.num_heads = num_heads |
| self.dropout = dropout |
| self.head_dim = embed_dim // num_heads |
assert (self.head_dim * num_heads) == self.embed_dim, "embed_dim must be divisible by num_heads"
| self.scaling = self.head_dim ** -0.5 |
| self.is_decoder = is_decoder |
| self.is_causal = is_causal |
|
|
| |
| self.W_k = nn.Linear(embed_dim, embed_dim, bias = bias) |
| self.W_q = nn.Linear(embed_dim, embed_dim, bias = bias) |
| self.W_v = nn.Linear(embed_dim, embed_dim, bias = bias) |
|
|
| |
| self.W_o = nn.Linear(embed_dim, embed_dim, bias = bias) |
|
|
|
|
| def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int): |
| return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() |
| |
| |
| def forward(self, |
| hidden_states: torch.Tensor, |
| key_value_states: Optional[torch.Tensor] = None, |
| past_key_value: Optional[Tuple[torch.Tensor]] = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| layer_head_mask: Optional[torch.Tensor] = None, |
| output_attentions: bool = False |
| ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
|
|
| is_cross_attention = key_value_states is not None |
|
|
| batch_size, tgt_len, embed_dim = hidden_states.size() |
|
|
| |
| query = self.W_q(hidden_states) * self.scaling |
|
|
| if (is_cross_attention and past_key_value is not None and past_key_value[0].shape[2] == key_value_states.shape[1]): |
| key = past_key_value[0] |
| value = past_key_value[1] |
| elif is_cross_attention: |
| key = self._shape(self.W_k(key_value_states), -1, batch_size) |
| value = self._shape(self.W_v(key_value_states), -1, batch_size) |
| elif past_key_value is not None: |
| key = self._shape(self.W_k(hidden_states), -1, batch_size) |
| value = self._shape(self.W_v(hidden_states), -1, batch_size) |
| key = torch.cat([past_key_value[0], key], dim=2) |
| value = torch.cat([past_key_value[1], value], dim=2) |
| else: |
| key = self._shape(self.W_k(hidden_states), -1, batch_size) |
| value = self._shape(self.W_v(hidden_states), -1, batch_size) |
|
|
| if self.is_decoder: |
| past_key_value = (key, value) |
| |
| proj_shape = (batch_size * self.num_heads, -1, self.head_dim) |
|
|
| query = self._shape(query, tgt_len, batch_size).view(*proj_shape) |
| key = key.reshape(*proj_shape) |
| value = value.reshape(*proj_shape) |
|
|
| src_len = key.size(1) |
|
|
| |
| |
|
|
| |
| |
|
|
| |
| attn_weights = torch.bmm(query, key.transpose(1, 2)) |
|
|
| if attention_mask is not None: |
| attn_weights = attn_weights.view(batch_size, self.num_heads, tgt_len, src_len) + attention_mask |
| attn_weights = attn_weights.view(batch_size * self.num_heads, tgt_len, src_len) |
| |
| |
| attn_weights = F.softmax(attn_weights, dim=-1) |
|
|
| |
| |
| |
|
|
| attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training) |
|
|
| |
| attn_output = torch.bmm(attn_probs, value) |
|
|
| |
|
|
| attn_output = attn_output.view(batch_size, self.num_heads, tgt_len, self.head_dim) |
| attn_output = attn_output.transpose(1, 2) |
|
|
| attn_output = attn_output.reshape(batch_size, tgt_len, self.embed_dim) |
| attn_output = self.W_o(attn_output) |
|
|
| return attn_output, None, past_key_value |
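# Shape sketch (illustrative): with hidden_states of shape (batch, tgt_len, embed_dim), the layer
# returns attn_output of the same shape; cross-attention is selected by passing key_value_states
# (the encoder hidden states), and past_key_value caches the projected keys/values for decoding.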
|
|
| |
|
|
class STLEncoder:
| def __init__(self, |
| embed_dim: int, |
| anchor_filename: Optional[str] = None, |
| n_vars: int = 3): |
| |
| self.n_vars = n_vars |
| self.embed_dim = embed_dim |
| self.anchorset_filename = anchor_filename |
| self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") |
| self.mu = BaseMeasure(device=self.device) |
| self.kernel = StlKernel(self.mu, varn=self.n_vars) |
|
|
| if anchor_filename is None: |
anchor_filename = anchorGeneration(diff_init=True, embed_dim=self.embed_dim, n_vars=self.n_vars)
anchor_filename += '.pickle'
|
|
| |
| anchor_set = load_pickle(anchor_filename) |
| if len(anchor_set) != self.embed_dim: |
| raise ValueError("The anchor set and the embedding dimension do not match!") |
|
|
| self.anchor_set = anchor_set |
|
|
def compute_embeddings(self, formulae: list):
return self.kernel.compute_bag_bag(formulae, self.anchor_set)
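# Illustrative usage of STLEncoder (a minimal sketch; the anchor-set pickle is generated on first use):
#   encoder = STLEncoder(embed_dim=30, n_vars=3)
#   phis = StlGenerator().bag_sample(bag_size=5, nvars=3)
#   emb = encoder.compute_embeddings(phis)  # kernel embeddings of shape (5, 30)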
|
|
| class STLModel(PreTrainedModel): |
| config_class = STLConfig |
| base_model_prefix = "model" |
| supports_gradient_checkpointing = True |
|
|
| |
| def _init_weights(self, module: Union[nn.Linear, nn.Embedding, STLSinusoidalPositionalEmbedding]): |
| std = self.config.init_std |
| if isinstance(module, nn.Linear): |
| module.weight.data.normal_(mean=0.0, std=std) |
| if module.bias is not None: |
| module.bias.data.zero_() |
| elif isinstance(module, STLSinusoidalPositionalEmbedding): |
| pass |
| elif isinstance(module, nn.Embedding): |
| module.weight.data.normal_(mean=0.0, std=std) |
| if module.padding_idx is not None: |
| module.weight.data[module.padding_idx].zero_() |
|
|
| @property |
| def dummy_inputs(self): |
| pad_token = self.config.pad_token_id |
| input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) |
| dummy_inputs = { |
| "attention_mask": input_ids.ne(pad_token), |
| "input_ids": input_ids, |
| "decoder_input_ids": input_ids, |
| } |
| return dummy_inputs |
|
|
| class STLDecoderBlock(nn.Module): |
| |
| def __init__(self, embed_dim: int, |
| num_decoder_attention_heads: int, |
| num_decoder_ffn_dim: int, |
| dropout: float = 0.0, |
| attention_dropout: float = 0.0, |
| activation_dropout: float = 0.0, |
| ): |
| |
| super().__init__() |
| |
| self.embed_dim = embed_dim |
|
|
| |
| self.self_attn = STLAttention( |
| embed_dim=self.embed_dim, |
| num_heads=num_decoder_attention_heads, |
| dropout=dropout, |
| is_decoder=True, |
| is_causal=True, |
| ) |
| self.dropout = dropout |
| self.activation_fn = nn.functional.gelu |
| self.activation_dropout = activation_dropout |
| self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) |
|
|
| |
| self.encoder_attn = STLAttention( |
| self.embed_dim, |
| num_decoder_attention_heads, |
| dropout=attention_dropout, |
| is_decoder=True, |
| ) |
| self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) |
|
|
| |
| self.fc1 = nn.Linear(self.embed_dim, num_decoder_ffn_dim) |
| self.fc2 = nn.Linear(num_decoder_ffn_dim, self.embed_dim) |
| self.final_layer_norm = nn.LayerNorm(self.embed_dim) |
|
|
|
|
| def forward( |
| self, |
| hidden_states: torch.Tensor, |
| attention_mask: Optional[torch.Tensor] = None, |
| encoder_hidden_states: Optional[torch.Tensor] = None, |
| encoder_attention_mask: Optional[torch.Tensor] = None, |
| layer_head_mask: Optional[torch.Tensor] = None, |
| cross_attn_layer_head_mask: Optional[torch.Tensor] = None, |
| past_key_value: Optional[Tuple[torch.Tensor]] = None, |
| output_attentions: Optional[bool] = False, |
| use_cache: Optional[bool] = True, |
| ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: |
| """ |
| Args: |
| hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` |
| attention_mask (`torch.FloatTensor`): attention mask of size |
| `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. |
| encoder_hidden_states (`torch.FloatTensor`): |
| cross attention input to the layer of shape `(batch, seq_len, embed_dim)` |
| encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size |
| `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. |
| layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size |
| `(encoder_attention_heads,)`. |
| cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of |
| size `(decoder_attention_heads,)`. |
| past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states |
| output_attentions (`bool`, *optional*): |
| Whether or not to return the attentions tensors of all attention layers. See `attentions` under |
| returned tensors for more detail. |
| """ |
| |
| |
| |
| |
|
|
| |
| residual = hidden_states |
|
|
| |
| |
| self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None |
|
|
| |
| |
| hidden_states, self_attn_weights, present_key_value = self.self_attn.forward( |
| hidden_states=hidden_states, |
| past_key_value=self_attn_past_key_value, |
| attention_mask=attention_mask, |
| layer_head_mask=layer_head_mask, |
| output_attentions=output_attentions, |
| ) |
| hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) |
|
|
| |
| hidden_states = residual + hidden_states |
|
|
| |
| hidden_states = self.self_attn_layer_norm(hidden_states) |
|
|
| |
|
|
| |
|
|
| |
| cross_attn_present_key_value = None |
| cross_attn_weights = None |
|
|
| |
| if encoder_hidden_states is not None: |
|
|
| |
| residual = hidden_states |
|
|
| |
| cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None |
|
|
| |
| hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn.forward( |
| hidden_states=hidden_states, |
| key_value_states=encoder_hidden_states, |
| attention_mask=encoder_attention_mask, |
| layer_head_mask=cross_attn_layer_head_mask, |
| past_key_value=cross_attn_past_key_value, |
| output_attentions=output_attentions, |
| ) |
| hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) |
|
|
| |
| hidden_states = residual + hidden_states |
|
|
| |
| hidden_states = self.encoder_attn_layer_norm(hidden_states) |
|
|
| |
| present_key_value = present_key_value + cross_attn_present_key_value |
|
|
| |
|
|
| |
|
|
| |
| residual = hidden_states |
|
|
| |
| hidden_states = self.activation_fn(self.fc1(hidden_states)) |
| hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) |
| hidden_states = self.fc2(hidden_states) |
| hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) |
|
|
| |
| hidden_states = residual + hidden_states |
|
|
| |
| hidden_states = self.final_layer_norm(hidden_states) |
|
|
| outputs = (hidden_states,) |
|
|
| if output_attentions: |
| outputs += (self_attn_weights, cross_attn_weights) |
|
|
| if use_cache: |
| outputs += (present_key_value,) |
|
|
| return outputs |
|
|
| class STLDecoder(STLModel): |
| def __init__(self, config): |
| super().__init__(config) |
|
|
| |
| embed_dim = config.d_model |
| num_decoder_attention_heads = config.decoder_attention_heads |
| num_decoder_ffn_dim = config.decoder_ffn_dim |
| max_position_embeddings = config.max_position_embeddings |
| decoder_vocab_size = config.vocab_size |
| pad_token_id = config.pad_token_id |
| num_decoder_layers = config.decoder_layers |
| scale_embedding = config.scale_embedding |
| dropout = config.dropout |
| attention_dropout = config.attention_dropout |
| activation_dropout = config.activation_dropout |
| decoder_layerdrop = config.decoder_layerdrop |
| |
| self.dropout = dropout |
| self.layerdrop = decoder_layerdrop |
| self.padding_idx = pad_token_id |
| self.max_target_positions = max_position_embeddings |
| self.embed_scale = math.sqrt(embed_dim) if scale_embedding else 1.0 |
|
|
| |
| self.embed_tokens = nn.Embedding(decoder_vocab_size, embed_dim, self.padding_idx) |
| |
| |
| self.embed_positions = STLSinusoidalPositionalEmbedding( |
| max_position_embeddings, embed_dim, self.padding_idx |
| ) |
| |
| |
| self.layers = nn.ModuleList([STLDecoderBlock(embed_dim, num_decoder_attention_heads, |
| num_decoder_ffn_dim, dropout, |
| attention_dropout, activation_dropout) |
| for _ in range(num_decoder_layers)]) |
|
|
| self.gradient_checkpointing = False |
| self.post_init() |
|
|
| def forward( |
| self, |
| input_ids: torch.LongTensor = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| encoder_hidden_states: Optional[torch.FloatTensor] = None, |
| encoder_attention_mask: Optional[torch.Tensor] = None, |
| head_mask: Optional[torch.Tensor] = None, |
| cross_attn_head_mask: Optional[torch.Tensor] = None, |
| past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, |
| inputs_embeds: Optional[torch.FloatTensor] = None, |
| use_cache: Optional[bool] = None, |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: |
| |
| output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| output_hidden_states = ( |
| output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
| ) |
| use_cache = use_cache if use_cache is not None else self.config.use_cache |
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
| |
| if input_ids is not None and inputs_embeds is not None: |
| raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") |
| elif input_ids is not None: |
| input_shape = input_ids.size() |
| input_ids = input_ids.view(-1, input_shape[-1]) |
| elif inputs_embeds is not None: |
| input_shape = inputs_embeds.size()[:-1] |
| else: |
| raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") |
|
|
| |
| past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 |
|
|
| if inputs_embeds is None: |
| inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale |
|
|
| attention_mask = _prepare_4d_causal_attention_mask( |
| attention_mask, input_shape, inputs_embeds, past_key_values_length |
| ) |
|
|
| |
| if encoder_hidden_states is not None and encoder_attention_mask is not None: |
| |
| encoder_attention_mask = _prepare_4d_attention_mask( |
| encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] |
| ) |
|
|
| |
| positions = self.embed_positions(input_shape, past_key_values_length) |
|
|
| hidden_states = inputs_embeds + positions |
|
|
| hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) |
|
|
| if self.gradient_checkpointing and self.training: |
| if use_cache: |
| logger.warning_once( |
| "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." |
| ) |
| use_cache = False |
|
|
| |
| all_hidden_states = () if output_hidden_states else None |
| all_self_attns = () if output_attentions else None |
| all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None |
| next_decoder_cache = () if use_cache else None |
|
|
| for idx, decoder_layer in enumerate(self.layers): |
| |
| if output_hidden_states: |
| all_hidden_states += (hidden_states,) |
| if self.training: |
| dropout_probability = torch.rand([]) |
| if dropout_probability < self.layerdrop: |
| continue |
|
|
| past_key_value = past_key_values[idx] if past_key_values is not None else None |
|
|
| if self.gradient_checkpointing and self.training: |
| layer_outputs = self._gradient_checkpointing_func( |
| decoder_layer.__call__, |
| hidden_states, |
| attention_mask, |
| encoder_hidden_states, |
| encoder_attention_mask, |
| head_mask[idx] if head_mask is not None else None, |
| cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, |
| None, |
| output_attentions, |
| use_cache, |
| ) |
| else: |
| layer_outputs = decoder_layer( |
| hidden_states, |
| attention_mask=attention_mask, |
| encoder_hidden_states=encoder_hidden_states, |
| layer_head_mask=(head_mask[idx] if head_mask is not None else None), |
| cross_attn_layer_head_mask=( |
| cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None |
| ), |
| past_key_value=past_key_value, |
| output_attentions=output_attentions, |
| use_cache=use_cache, |
| ) |
| hidden_states = layer_outputs[0] |
|
|
| if use_cache: |
| next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) |
|
|
| if output_attentions: |
| all_self_attns += (layer_outputs[1],) |
|
|
| if encoder_hidden_states is not None: |
| all_cross_attentions += (layer_outputs[2],) |
|
|
| |
| if output_hidden_states: |
| all_hidden_states += (hidden_states,) |
|
|
| next_cache = next_decoder_cache if use_cache else None |
| if not return_dict: |
| return tuple( |
| v |
| for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] |
| if v is not None |
| ) |
| return BaseModelOutputWithPastAndCrossAttentions( |
| last_hidden_state=hidden_states, |
| past_key_values=next_cache, |
| hidden_states=all_hidden_states, |
| attentions=all_self_attns, |
| cross_attentions=all_cross_attentions, |
| ) |
|
|
| |
|
|
| class STLForCausalLM(STLModel, GenerationMixin): |
| _tied_weights_keys = ["lm_head.weight"] |
|
|
| def __init__(self, config): |
| config = copy.deepcopy(config) |
| config.is_decoder = True |
| config.is_encoder_decoder = False |
| |
| super().__init__(config) |
| self.model = STLDecoder(config) |
|
|
| self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) |
|
|
| |
| self.post_init() |
|
|
| def get_input_embeddings(self): |
| return self.model.embed_tokens |
|
|
| def set_input_embeddings(self, value): |
| self.model.embed_tokens = value |
|
|
| def get_output_embeddings(self): |
| return self.lm_head |
|
|
| def set_output_embeddings(self, new_embeddings): |
| self.lm_head = new_embeddings |
|
|
| def set_decoder(self, decoder): |
| self.model = decoder |
|
|
| def get_decoder(self): |
| return self.model |
|
|
| def forward( |
| self, |
| input_ids: torch.LongTensor = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| encoder_hidden_states: Optional[torch.FloatTensor] = None, |
| encoder_attention_mask: Optional[torch.FloatTensor] = None, |
| head_mask: Optional[torch.Tensor] = None, |
| cross_attn_head_mask: Optional[torch.Tensor] = None, |
| past_key_values: Optional[List[torch.FloatTensor]] = None, |
| inputs_embeds: Optional[torch.FloatTensor] = None, |
| labels: Optional[torch.LongTensor] = None, |
| use_cache: Optional[bool] = None, |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: |
|
|
| output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| output_hidden_states = ( |
| output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
| ) |
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
| |
| outputs = self.model( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| encoder_hidden_states=encoder_hidden_states, |
| encoder_attention_mask=encoder_attention_mask, |
| head_mask=head_mask, |
| cross_attn_head_mask=cross_attn_head_mask, |
| past_key_values=past_key_values, |
| inputs_embeds=inputs_embeds, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| ) |
|
|
| logits = self.lm_head(outputs[0]) |
|
|
| loss = None |
| if labels is not None: |
| labels = labels.to(logits.device) |
loss_fct = nn.CrossEntropyLoss()
| loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) |
|
|
| if not return_dict: |
| output = (logits,) + outputs[1:] |
| return (loss,) + output if loss is not None else output |
|
|
| return CausalLMOutputWithCrossAttentions( |
| loss=loss, |
| logits=logits, |
| past_key_values=outputs.past_key_values, |
| hidden_states=outputs.hidden_states, |
| attentions=outputs.attentions, |
| cross_attentions=outputs.cross_attentions, |
| ) |
|
|
| @staticmethod |
| def _reorder_cache(past_key_values, beam_idx): |
| reordered_past = () |
| for layer_past in past_key_values: |
| reordered_past += ( |
| tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), |
| ) |
| return reordered_past |
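# Illustrative decoding sketch (the STLConfig defaults and vocabulary path are assumptions):
#   config = STLConfig()
#   model = STLForCausalLM(config)
#   tokenizer = STLTokenizer("vocab.json")
#   input_ids = torch.tensor([tokenizer.encode("always ( x_0 <= 0.5 )")])
#   generated = model.generate(input_ids, max_length=64)
#   print(tokenizer.decode(generated[0].tolist()))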
|
|