saracandu commited on
Commit
21c7f66
·
verified ·
1 Parent(s): 8d3eaaf

Update modeling.py

Browse files
Files changed (1) hide show
  1. modeling.py +675 -124
modeling.py CHANGED
@@ -32,12 +32,597 @@ from transformers.modeling_outputs import (
32
  Seq2SeqModelOutput,
33
  )
34
 
35
- from configuration import STLConfig
36
  from nltk.translate.bleu_score import sentence_bleu
37
- from stl import *
38
  import networkx as nx
39
  from datasets import load_dataset
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  # from anchor_set_generation import anchorGeneration
43
 
@@ -152,10 +737,9 @@ def from_string_to_formula(st):
152
  def load_json(path: str) -> Union[Dict, List]:
153
  """
154
  Load a JSON file from the given path.
155
-
156
  Args:
157
  path (str): The path to the JSON file to be loaded.
158
-
159
  Returns:
160
  Union[Dict, List]: The parsed content of the JSON file, which could be a dictionary or a list.
161
  """
@@ -216,35 +800,28 @@ class StlGenerator:
216
  def sample(self, nvars):
217
  """
218
  Samples a random formula with distribution defined in class instance parameters
219
-
220
  Parameters
221
  ----------
222
  nvars : number of variables of input signals
223
  how many variables the formula is expected to consider.
224
-
225
  Returns
226
  -------
227
  TYPE
228
  A random formula.
229
-
230
  """
231
  return self._sample_internal_node(nvars)
232
-
233
  def bag_sample(self, bag_size, nvars):
234
  """
235
  Samples a bag of bag_size formulae
236
-
237
  Parameters
238
  ----------
239
  bag_size : INT
240
  number of formulae.
241
  nvars : INT
242
  number of vars in formulae.
243
-
244
  Returns
245
  -------
246
  a list of formulae.
247
-
248
  """
249
  formulae = []
250
  for _ in range(bag_size):
@@ -261,32 +838,32 @@ class StlGenerator:
261
  while True:
262
  if nodetype == "not":
263
  n = self._sample_node(nvars)
264
- node = stl.Not(n)
265
  elif nodetype == "and":
266
  n1 = self._sample_node(nvars)
267
  n2 = self._sample_node(nvars)
268
- node = stl.And(n1, n2)
269
  elif nodetype == "or":
270
  n1 = self._sample_node(nvars)
271
  n2 = self._sample_node(nvars)
272
- node = stl.Or(n1, n2)
273
  elif nodetype == "always":
274
  n = self._sample_node(nvars)
275
  unbound, right_unbound, left_time_bound, right_time_bound = self._get_temporal_parameters()
276
- node = stl.Globally(
277
  n, unbound, right_unbound, left_time_bound, right_time_bound, self.adaptive_unbound_temporal_ops
278
  )
279
  elif nodetype == "eventually":
280
  n = self._sample_node(nvars)
281
  unbound, right_unbound, left_time_bound, right_time_bound = self._get_temporal_parameters()
282
- node = stl.Eventually(
283
  n, unbound, right_unbound, left_time_bound, right_time_bound, self.adaptive_unbound_temporal_ops
284
  )
285
  elif nodetype == "until":
286
  n1 = self._sample_node(nvars)
287
  n2 = self._sample_node(nvars)
288
  unbound, right_unbound, left_time_bound, right_time_bound = self._get_temporal_parameters()
289
- node = stl.Until(
290
  n1, n2, unbound, right_unbound, left_time_bound, right_time_bound
291
  )
292
 
@@ -297,7 +874,7 @@ class StlGenerator:
297
  if rnd.rand() < self.leaf_prob:
298
  # sample a leaf
299
  var, thr, lte = self._get_atom(nvars)
300
- return stl.Atom(var, thr, lte)
301
  else:
302
  return self._sample_internal_node(nvars)
303
 
@@ -328,7 +905,6 @@ class BaseMeasure(Measure):
328
  self, mu0=0.0, sigma0=1.0, mu1=0.0, sigma1=1.0, q=0.1, q0=0.5, device="cpu"
329
  ):
330
  """
331
-
332
  Parameters
333
  ----------
334
  mu0 : mean of normal distribution of initial state, optional
@@ -345,11 +921,9 @@ class BaseMeasure(Measure):
345
  probability of initial sign of derivative. The default is 0.5.
346
  device : 'cpu' or 'cuda', optional
347
  device on which to run the algorithm. The default is 'cpu'.
348
-
349
  Returns
350
  -------
351
  None.
352
-
353
  """
354
  self.mu0 = mu0
355
  self.sigma0 = sigma0
@@ -363,7 +937,6 @@ class BaseMeasure(Measure):
363
  """
364
  Samples a set of trajectories from the basic measure space, with parameters
365
  passed to the sampler
366
-
367
  Parameters
368
  ----------
369
  points : INT, optional
@@ -372,13 +945,10 @@ class BaseMeasure(Measure):
372
  number of trajectories. The default is 100000.
373
  varn : INT, optional
374
  number of variables per trajectory. The default is 2.
375
-
376
-
377
  Returns
378
  -------
379
  signal : samples x varn x points double pytorch tensor
380
  The sampled signals.
381
-
382
  """
383
  if self.device == "cuda" and not torch.cuda.is_available():
384
  raise RuntimeError("GPU card or CUDA library not available!")
@@ -513,8 +1083,6 @@ class StlKernel:
513
  return kernel_matrix.cpu(), rhos1, selfk1, len1
514
  else:
515
  return kernel_matrix.cpu()
516
-
517
- def _compute_robustness_time(self, phis):
518
  n = self.samples
519
  p = self.points
520
  k = len(phis)
@@ -576,6 +1144,12 @@ class StlKernel:
576
  kernel_matrix = kernel_matrix / normalize
577
  return kernel_matrix
578
 
 
 
 
 
 
 
579
  def _exponentiate(self, kernel_matrix, selfk1, selfk2, sigma2=None):
580
  if sigma2 is None:
581
  sigma2 = self.sigma2
@@ -706,13 +1280,13 @@ def anchorGeneration(diff_init = False, # to control whether we want formulae to
706
  leaf_prob: float = 0.4, # complexity of the generated formula
707
  cosine_similarity_threshold: float = 0.8 # if two formulae cosine similarity exceeds 0.9, then discard one of the two
708
  ) -> str:
709
-
710
  # initialize STL formula generator
711
  sampler = StlGenerator(leaf_prob)
712
-
713
  # effective anchor set generation
714
  if diff_init:
715
-
716
  # initialize the anchor set with a randomly sampled formula
717
  diff_anchor_set = [sampler.sample(nvars=n_vars)]
718
 
@@ -728,35 +1302,35 @@ def anchorGeneration(diff_init = False, # to control whether we want formulae to
728
  while len(diff_anchor_set) < embed_dim:
729
  # sample the 'remaining' formulae to reach the desired number of `embed_dim` formulae:
730
  candidate_anchors = sampler.bag_sample(embed_dim - len(diff_anchor_set), nvars = n_vars)
731
-
732
  # compute robustness of candidate anchor formulae on the same signals as previous anchor set
733
  candidate_robs = torch.cat([phi.quantitative(signals, normalize=True).unsqueeze(0) for phi in candidate_anchors], 0)
734
-
735
  # compute cosine similarity between current anchor set and candidate new formulae
736
  cos_simil = torch.tril(normalize(candidate_robs) @ normalize(anchor_rob_vectors).t(), diagonal=-1)
737
 
738
  # check which formulae are similar (i.e. greater cosine similarity then threshold) w.r.t. current anchors
739
  # NOTA: chiedere a gaia se cosine similarities negative vanno ammazzate con un valore assoluto o meno!
740
  similar_idx = [torch.where(cos_simil[r, :] > cosine_similarity_threshold)[0].tolist() for r in range(cos_simil.shape[0])]
741
-
742
  # keep only those who are semantically distant
743
  keep_idx = list(set(np.arange(len(candidate_anchors)).tolist()).difference(set([i for sublist in similar_idx for i in sublist])))
744
-
745
  diff_anchor_set += [copy.deepcopy(candidate_anchors[i]) for i in keep_idx]
746
-
747
  # Convert keep_idx to a tensor on the same device as candidate_robs
748
  keep_idx_tensor = torch.tensor(keep_idx, device=candidate_robs.device)
749
-
750
  # Use index_select to pick the relevant rows
751
  selected_robs = torch.index_select(candidate_robs, 0, keep_idx_tensor)
752
-
753
  # Concatenate on the same device
754
  anchor_rob_vectors = torch.cat([anchor_rob_vectors, copy.deepcopy(selected_robs)], dim=0)
755
 
756
  anchor_set = diff_anchor_set[:embed_dim]
757
-
758
  else:
759
- anchor_set = sampler.bag_sample(bag_size=embed_dim, nvars=n_vars)
760
 
761
  filename = f'anchor_set_no_diff_{embed_dim}_dim'
762
  dump_pickle(filename, anchor_set)
@@ -764,19 +1338,16 @@ def anchorGeneration(diff_init = False, # to control whether we want formulae to
764
 
765
  ####
766
 
767
- class STLTokenizer(PreTrainedTokenizer):
768
  """
769
  A custom tokenizer class that extends `PreTrainedTokenizer` to handle a specific vocabulary and tokenization process.
770
-
771
- This tokenizer can load a vocabulary from a JSON file, tokenize text, convert tokens to IDs,
772
  and handle padding and special tokens.
773
  """
774
 
775
- def __init__(self, vocab_path: str, unk_token: str = "unk", pad_token: str = "pad",
776
  bos_token: str = "/s", eos_token: str = "s", model_max_length = 512, *args, **kwargs):
777
  """
778
  Initializes the STLTokenizer with a given vocabulary and special tokens.
779
-
780
  Args:
781
  vocab_path (str): The path to the JSON file containing the vocabulary.
782
  unk_token (str, optional): The token used for unknown words. Defaults to "unk".
@@ -791,14 +1362,13 @@ class STLTokenizer(PreTrainedTokenizer):
791
  self.eos_token = eos_token
792
  self.model_max_length = model_max_length
793
  self.id_to_token = {v: k for k, v in self.vocab.items()} # Reverse mapping
794
- super().__init__(unk_token=unk_token, pad_token=pad_token, bos_token=bos_token, eos_token=eos_token,
795
  model_max_length=model_max_length, *args, **kwargs)
796
 
797
  @property
798
  def vocab_size(self) -> int:
799
  """
800
  Returns the size of the vocabulary.
801
-
802
  Returns:
803
  int: The number of tokens in the vocabulary.
804
  """
@@ -807,11 +1377,9 @@ class STLTokenizer(PreTrainedTokenizer):
807
  def prepad_sequence(self, sequence, space_token = ' ', new_space_token = '@', undo = False):
808
  """
809
  Replaces spaces in the input sequence with a specified token.
810
-
811
  Args:
812
  sequence (str): The input sequence.
813
  undo (bool): If True, replace the padding token with spaces. Defaults to False, which pads the spaces.
814
-
815
  Returns:
816
  str: The preprocessed sequence with spaces or padding tokens replaced.
817
  """
@@ -823,10 +1391,8 @@ class STLTokenizer(PreTrainedTokenizer):
823
  def add_bos_eos(self, sequence: str) -> str:
824
  """
825
  Aggiunge i token BOS all'inizio e EOS alla fine della sequenza.
826
-
827
  Args:
828
  sequence (str): La sequenza di input.
829
-
830
  Returns:
831
  str: La sequenza con i token BOS ed EOS.
832
  """
@@ -835,19 +1401,15 @@ class STLTokenizer(PreTrainedTokenizer):
835
  def tokenize(self, text: str) -> List[str]:
836
  """
837
  Tokenizes the input text into a list of tokens.
838
-
839
- The method preprocesses the input text by replacing spaces with padding tokens and then tries to
840
  find the longest possible match for each substring in the vocabulary.
841
-
842
  Args:
843
  text (str): The input text to be tokenized.
844
-
845
  Returns:
846
  List[str]: A list of tokens representing the tokenized text.
847
  """
848
  text = self.add_bos_eos(text)
849
  text = self.prepad_sequence(text)
850
-
851
  tokens = []
852
  i = 0
853
  while i < len(text):
@@ -868,10 +1430,8 @@ class STLTokenizer(PreTrainedTokenizer):
868
  def convert_tokens_to_ids(self, tokens: List[str]) -> List[int]:
869
  """
870
  Converts a list of tokens into a list of token IDs.
871
-
872
  Args:
873
  tokens (List[str]): A list of tokens to be converted into IDs.
874
-
875
  Returns:
876
  List[int]: A list of corresponding token IDs.
877
  """
@@ -880,10 +1440,8 @@ class STLTokenizer(PreTrainedTokenizer):
880
  def convert_ids_to_tokens(self, ids: List[int]) -> List[str]:
881
  """
882
  Converts a list of token IDs into a list of tokens.
883
-
884
  Args:
885
  ids (List[int]): A list of token IDs to be converted into tokens.
886
-
887
  Returns:
888
  List[str]: A list of corresponding tokens.
889
  """
@@ -892,14 +1450,14 @@ class STLTokenizer(PreTrainedTokenizer):
892
  def encode(self, sequence: str) -> List[int]:
893
  """
894
  Encodes a string sequence into a list of token IDs.
895
-
896
- This method tokenizes the input sequence using the `tokenize` method,
897
- and then converts the resulting tokens into their corresponding token IDs
898
  using the `convert_tokens_to_ids` method.
899
-
900
  Args:
901
  sequence (str): The input sequence (text) to be encoded.
902
-
903
  Returns:
904
  List[int]: A list of token IDs corresponding to the input sequence.
905
  """
@@ -908,8 +1466,8 @@ class STLTokenizer(PreTrainedTokenizer):
908
 
909
  def postpad_sequence(self, sequence, pad_token_id):
910
  """
911
- Fills the sequence up to max_length padding elements
912
- """
913
  num_extra_elements = self.model_max_length - len(sequence) -1
914
  if num_extra_elements > 0:
915
  sequence.extend([pad_token_id] * num_extra_elements)
@@ -918,14 +1476,11 @@ class STLTokenizer(PreTrainedTokenizer):
918
  def decode(self, token_ids: List[int]) -> str:
919
  """
920
  Decodes a list of token IDs into a string of text.
921
-
922
- The method converts the IDs to tokens and joins them to form a string.
923
  It also restores the original spaces or padding tokens if `undo` is True.
924
-
925
  Args:
926
  token_ids (List[int]): A list of token IDs to be decoded.
927
  skip_special_tokens (bool, optional): Whether to skip special tokens during decoding. Defaults to False.
928
-
929
  Returns:
930
  str: The decoded string.
931
  """
@@ -935,16 +1490,13 @@ class STLTokenizer(PreTrainedTokenizer):
935
 
936
  def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
937
  """
938
- Saves the tokenizer's vocabulary to a file.
939
- Useful only when the vocabulary has to be retrieved and is not given
940
  (thus this is not the case: here to further improvements with sentencepiece).
941
-
942
- This method saves the vocabulary to a JSON file in the specified directory.
943
-
944
  Args:
945
  save_directory (str): The directory where the vocabulary file will be saved.
946
  filename_prefix (Optional[str]): An optional prefix for the filename.
947
-
948
  Returns:
949
  Tuple[str]: A tuple containing the path to the saved vocabulary file.
950
  """
@@ -956,7 +1508,6 @@ class STLTokenizer(PreTrainedTokenizer):
956
  def get_vocab(self) -> dict:
957
  """
958
  Retrieves the vocabulary used by the tokenizer.
959
-
960
  Returns:
961
  dict: The vocabulary as a dictionary.
962
  """
@@ -985,7 +1536,6 @@ class STLSinusoidalPositionalEmbedding(nn.Embedding):
985
  out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
986
  out.detach_()
987
  return out
988
-
989
  @torch.no_grad()
990
  def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor:
991
  """`input_ids_shape` is expected to be [bsz x seqlen]."""
@@ -998,40 +1548,39 @@ class STLSinusoidalPositionalEmbedding(nn.Embedding):
998
  class STLAttention(nn.Module):
999
  """ Multi-Head Attention as depicted from 'Attention is all you need' """
1000
 
1001
- def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0,
1002
  is_decoder: bool = False, bias: bool = False, is_causal: bool = False):
1003
-
1004
  super().__init__()
1005
  self.embed_dim = embed_dim # overall embedding dimension -> to be divided between multiple heads
1006
  self.num_heads = num_heads
1007
  self.dropout = dropout
1008
  self.head_dim = embed_dim // num_heads
1009
- assert (self.head_dim * num_heads) == self.embed_dim
1010
  self.scaling = self.head_dim ** -0.5 # used to normalize values when projected using `W_` matrices
1011
  self.is_decoder = is_decoder
1012
  self.is_causal = is_causal
1013
 
1014
- # 'roleplaying' matrices
1015
- self.W_k = nn.Linear(embed_dim, embed_dim, bias = bias)
1016
  self.W_q = nn.Linear(embed_dim, embed_dim, bias = bias)
1017
  self.W_v = nn.Linear(embed_dim, embed_dim, bias = bias)
1018
 
1019
  # to project the heads' outputs into a single vector
1020
- self.W_o = nn.Linear(embed_dim, embed_dim, bias = bias)
1021
 
1022
 
1023
  def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
1024
  return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
1025
-
1026
-
1027
- def forward(self,
1028
  hidden_states: torch.Tensor, # previous values, passed to the multi-head attn layer
1029
  key_value_states: Optional[torch.Tensor] = None, # different key, value items (used in cross-attn)
1030
- past_key_value: Optional[Tuple[torch.Tensor]] = None, # stores the key and values of previous steps
1031
  attention_mask: Optional[torch.Tensor] = None, # masks non-allowed items (padded or future ones)
1032
  layer_head_mask: Optional[torch.Tensor] = None, # used to de-activate specific attn heads
1033
  output_attentions: bool = False # flag to control the output of the attn values,
1034
- **kwargs,
1035
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
1036
 
1037
  is_cross_attention = key_value_states is not None # cross-attn if key_value_states is not None
@@ -1055,19 +1604,18 @@ class STLAttention(nn.Module):
1055
  else:
1056
  key = self._shape(self.W_k(hidden_states), -1, batch_size)
1057
  value = self._shape(self.W_v(hidden_states), -1, batch_size)
1058
-
1059
  if self.is_decoder:
1060
  past_key_value = (key, value)
1061
-
1062
  proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
1063
 
1064
- query = self._shape(query, tgt_len, batch_size).view(*proj_shape)
1065
  key = key.reshape(*proj_shape)
1066
  value = value.reshape(*proj_shape)
1067
 
1068
  src_len = key.size(1)
1069
 
1070
-
1071
  ######################################################################################################
1072
 
1073
  # 'traditional' attention computation
@@ -1079,7 +1627,7 @@ class STLAttention(nn.Module):
1079
  if attention_mask is not None:
1080
  attn_weights = attn_weights.view(batch_size, self.num_heads, tgt_len, src_len) + attention_mask
1081
  attn_weights = attn_weights.view(batch_size * self.num_heads, tgt_len, src_len)
1082
-
1083
  # Normalize values on the `key` axis (dim=-1)
1084
  attn_weights = F.softmax(attn_weights, dim=-1)
1085
 
@@ -1098,18 +1646,18 @@ class STLAttention(nn.Module):
1098
  attn_output = attn_output.transpose(1, 2)
1099
 
1100
  attn_output = attn_output.reshape(batch_size, tgt_len, self.embed_dim)
1101
- attn_output = self.W_o(attn_output)
1102
 
1103
  return attn_output, None, past_key_value
1104
 
1105
  ####
1106
 
1107
  class STLEncoder():
1108
- def __init__(self,
1109
  embed_dim: int,
1110
  anchor_filename: Optional[str] = None,
1111
  n_vars: int = 3):
1112
-
1113
  self.n_vars = n_vars # passaglielo in input
1114
  self.embed_dim = embed_dim
1115
  self.anchorset_filename = anchor_filename
@@ -1117,8 +1665,8 @@ class STLEncoder():
1117
  self.mu = BaseMeasure(device=self.device)
1118
  self.kernel = StlKernel(self.mu, varn=self.n_vars)
1119
 
1120
- if anchor_filename is None:
1121
- anchor_filename = anchorGeneration(diff_init = True, embed_dim = self.embed_dim, n_vars = self.n_vars)
1122
  anchor_filename+='.pickle'
1123
 
1124
  # TO DO: check on the dimensions of the anchor set and the `embed_dim` and `n_vars` values
@@ -1132,8 +1680,8 @@ class STLEncoder():
1132
  return self.kernel.compute_bag_bag(formula, self.anchor_set)
1133
 
1134
  class STLModel(PreTrainedModel):
1135
- config_class = STLConfig
1136
- base_model_prefix = "model"
1137
  supports_gradient_checkpointing = True
1138
 
1139
  # initializes the weights of `nn.Linear`, `nn.Embedding` and `STLSinusoidalPositionalEmbedding`
@@ -1162,22 +1710,22 @@ class STLModel(PreTrainedModel):
1162
  return dummy_inputs
1163
 
1164
  class STLDecoderBlock(nn.Module):
1165
-
1166
- def __init__(self, embed_dim: int,
1167
  num_decoder_attention_heads: int,
1168
  num_decoder_ffn_dim: int,
1169
  dropout: float = 0.0,
1170
  attention_dropout: float = 0.0,
1171
  activation_dropout: float = 0.0,
1172
  ):
1173
-
1174
  super().__init__()
1175
-
1176
  self.embed_dim = embed_dim
1177
 
1178
- # first block
1179
  self.self_attn = STLAttention(
1180
- embed_dim=self.embed_dim,
1181
  num_heads=num_decoder_attention_heads,
1182
  dropout=dropout,
1183
  is_decoder=True, # not used, debugging purposes
@@ -1234,26 +1782,26 @@ class STLDecoderBlock(nn.Module):
1234
  Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1235
  returned tensors for more detail.
1236
  """
1237
-
1238
  ###################################################################
1239
-
1240
- # BLOCK 1: processing what has been previously generated
1241
 
1242
  # previous state is stored into an auxiliary variable `residual`
1243
  residual = hidden_states
1244
 
1245
- # tries to exploit previous K, V values if there are any
1246
  # (practically picks up to the first 2 values stored in `past_key_value` vector)
1247
  self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
1248
 
1249
  # masked MHSA on the already generated sequence
1250
- # invokes `forward` method to transform the original vector accordingly
1251
  hidden_states, self_attn_weights, present_key_value = self.self_attn.forward(
1252
  hidden_states=hidden_states, # Q
1253
  past_key_value=self_attn_past_key_value, # K, V
1254
  attention_mask=attention_mask, # passed as input of the decoder layer
1255
- layer_head_mask=layer_head_mask, # to deactivate certain attn layers
1256
- output_attentions=output_attentions,
1257
  )
1258
  hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
1259
 
@@ -1268,7 +1816,7 @@ class STLDecoderBlock(nn.Module):
1268
  # BLOCK 2: cross-attn between already generated input and previous information (from the encoder)
1269
 
1270
  # initialize K, Q, attn_weights for this new attn operation
1271
- cross_attn_present_key_value = None
1272
  cross_attn_weights = None
1273
 
1274
  # the important condition is that the encoder carries some information
@@ -1346,7 +1894,7 @@ class STLDecoder(STLModel):
1346
  attention_dropout = config.attention_dropout
1347
  activation_dropout = config.activation_dropout
1348
  decoder_layerdrop = config.decoder_layerdrop
1349
-
1350
  self.dropout = dropout
1351
  self.layerdrop = decoder_layerdrop
1352
  self.padding_idx = pad_token_id
@@ -1355,16 +1903,16 @@ class STLDecoder(STLModel):
1355
 
1356
  # Initialize the input embedding (if not passed already)
1357
  self.embed_tokens = nn.Embedding(decoder_vocab_size, embed_dim, self.padding_idx)
1358
-
1359
  # Initialize positional embedding also
1360
  self.embed_positions = STLSinusoidalPositionalEmbedding(
1361
  max_position_embeddings, embed_dim, self.padding_idx
1362
  )
1363
-
1364
  # Initialize decoder layers (of a prespecified number)
1365
- self.layers = nn.ModuleList([STLDecoderBlock(embed_dim, num_decoder_attention_heads,
1366
- num_decoder_ffn_dim, dropout,
1367
- attention_dropout, activation_dropout)
1368
  for _ in range(num_decoder_layers)])
1369
 
1370
  self.gradient_checkpointing = False
@@ -1386,7 +1934,7 @@ class STLDecoder(STLModel):
1386
  return_dict: Optional[bool] = None,
1387
  **kwargs,
1388
  ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
1389
-
1390
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1391
  output_hidden_states = (
1392
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1509,7 +2057,7 @@ class STLDecoder(STLModel):
1509
  cross_attentions=all_cross_attentions,
1510
  )
1511
 
1512
- ####
1513
 
1514
  class STLForCausalLM(STLModel, GenerationMixin):
1515
  _tied_weights_keys = ["lm_head.weight"]
@@ -1518,7 +2066,7 @@ class STLForCausalLM(STLModel, GenerationMixin):
1518
  config = copy.deepcopy(config)
1519
  config.is_decoder = True
1520
  config.is_encoder_decoder = False
1521
-
1522
  super().__init__(config)
1523
  self.model = STLDecoder(config)
1524
 
@@ -1615,3 +2163,6 @@ class STLForCausalLM(STLModel, GenerationMixin):
1615
  tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1616
  )
1617
  return reordered_past
 
 
 
 
32
  Seq2SeqModelOutput,
33
  )
34
 
35
+ from .configuration import STLConfig
36
  from nltk.translate.bleu_score import sentence_bleu
37
+ # from stl import *
38
  import networkx as nx
39
  from datasets import load_dataset
40
 
41
+ ### from custom_typing.py
42
+
43
+ realnum = Union[float, int]
44
+
45
+
46
+ ### from stl.py
47
+
48
+ # For tensor functions
49
+ import torch
50
+ from torch import Tensor
51
+ import torch.nn.functional as F
52
+
53
+
54
+ def eventually(x: Tensor, time_span: int) -> Tensor:
55
+ """
56
+ STL operator 'eventually' in 1D.
57
+
58
+ Parameters
59
+ ----------
60
+ x: torch.Tensor
61
+ Signal
62
+ time_span: any numeric type
63
+ Timespan duration
64
+
65
+ Returns
66
+ -------
67
+ torch.Tensor
68
+ A tensor containing the result of the operation.
69
+ """
70
+ return F.max_pool1d(x, kernel_size=time_span, stride=1)
71
+
72
+ class Node:
73
+ """Abstract node class for STL semantics tree."""
74
+
75
+ def __init__(self) -> None:
76
+ # Must be overloaded.
77
+ pass
78
+
79
+ def __str__(self) -> str:
80
+ # Must be overloaded.
81
+ pass
82
+
83
+ def boolean(self, x: Tensor, evaluate_at_all_times: bool = False) -> Tensor:
84
+ """
85
+ Evaluates the boolean semantics at the node.
86
+
87
+ Parameters
88
+ ----------
89
+ x : torch.Tensor, of size N_samples x N_vars x N_sampling_points
90
+ The input signals, stored as a batch tensor with trhee dimensions.
91
+ evaluate_at_all_times: bool
92
+ Whether to evaluate the semantics at all times (True) or
93
+ just at t=0 (False).
94
+
95
+ Returns
96
+ -------
97
+ torch.Tensor
98
+ A tensor with the boolean semantics for the node.
99
+ """
100
+ z: Tensor = self._boolean(x)
101
+ if evaluate_at_all_times:
102
+ return z
103
+ else:
104
+ return self._extract_semantics_at_time_zero(z)
105
+
106
+ def quantitative(
107
+ self,
108
+ x: Tensor,
109
+ normalize: bool = False,
110
+ evaluate_at_all_times: bool = False,
111
+ ) -> Tensor:
112
+ """
113
+ Evaluates the quantitative semantics at the node.
114
+
115
+ Parameters
116
+ ----------
117
+ x : torch.Tensor, of size N_samples x N_vars x N_sampling_points
118
+ The input signals, stored as a batch tensor with three dimensions.
119
+ normalize: bool
120
+ Whether the measure of robustness if normalized (True) or
121
+ not (False). Currently not in use.
122
+ evaluate_at_all_times: bool
123
+ Whether to evaluate the semantics at all times (True) or
124
+ just at t=0 (False).
125
+
126
+ Returns
127
+ -------
128
+ torch.Tensor
129
+ A tensor with the quantitative semantics for the node.
130
+ """
131
+ z: Tensor = self._quantitative(x, normalize)
132
+ if evaluate_at_all_times:
133
+ return z
134
+ else:
135
+ return self._extract_semantics_at_time_zero(z)
136
+
137
+ def set_normalizing_flag(self, value: bool = True) -> None:
138
+ """
139
+ Setter for the 'normalization of robustness of the formula' flag.
140
+ Currently not in use.
141
+ """
142
+
143
+ def time_depth(self) -> int:
144
+ """Returns time depth of bounded temporal operators only."""
145
+ # Must be overloaded.
146
+
147
+ def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
148
+ """Private method equivalent to public one for inner call."""
149
+ # Must be overloaded.
150
+
151
+ def _boolean(self, x: Tensor) -> Tensor:
152
+ """Private method equivalent to public one for inner call."""
153
+ # Must be overloaded.
154
+
155
+ @staticmethod
156
+ def _extract_semantics_at_time_zero(x: Tensor) -> Tensor:
157
+ """Extrapolates the vector of truth values at time zero"""
158
+ return torch.reshape(x[:, 0, 0], (-1,))
159
+
160
+
161
+ class Atom(Node):
162
+ """Atomic formula node; for now of the form X<=t or X>=t"""
163
+
164
+ def __init__(self, var_index: int, threshold: realnum, lte: bool = False) -> None:
165
+ super().__init__()
166
+ self.var_index: int = var_index
167
+ self.threshold: realnum = threshold
168
+ self.lte: bool = lte
169
+
170
+ def __str__(self) -> str:
171
+ s: str = (
172
+ "x_"
173
+ + str(self.var_index)
174
+ + (" <= " if self.lte else " >= ")
175
+ + str(round(self.threshold, 4))
176
+ )
177
+ return s
178
+
179
+ def time_depth(self) -> int:
180
+ return 0
181
+
182
+ def _boolean(self, x: Tensor) -> Tensor:
183
+ # extract tensor of the same dimension as data, but with only one variable
184
+ xj: Tensor = x[:, self.var_index, :]
185
+ xj: Tensor = xj.view(xj.size()[0], 1, -1)
186
+ if self.lte:
187
+ z: Tensor = torch.le(xj, self.threshold)
188
+ else:
189
+ z: Tensor = torch.ge(xj, self.threshold)
190
+ return z
191
+
192
+ def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
193
+ # extract tensor of the same dimension as data, but with only one variable
194
+ xj: Tensor = x[:, self.var_index, :]
195
+ xj: Tensor = xj.view(xj.size()[0], 1, -1)
196
+ if self.lte:
197
+ z: Tensor = -xj + self.threshold
198
+ else:
199
+ z: Tensor = xj - self.threshold
200
+ if normalize:
201
+ z: Tensor = torch.tanh(z)
202
+ return z
203
+
204
+ class Not(Node):
205
+ """Negation node."""
206
+
207
+ def __init__(self, child: Node) -> None:
208
+ super().__init__()
209
+ self.child: Node = child
210
+
211
+ def __str__(self) -> str:
212
+ s: str = "not ( " + self.child.__str__() + " )"
213
+ return s
214
+
215
+ def time_depth(self) -> int:
216
+ return self.child.time_depth()
217
+
218
+ def _boolean(self, x: Tensor) -> Tensor:
219
+ z: Tensor = ~self.child._boolean(x)
220
+ return z
221
+
222
+ def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
223
+ z: Tensor = -self.child._quantitative(x, normalize)
224
+ return z
225
+
226
+
227
+ class And(Node):
228
+ """Conjunction node."""
229
+
230
+ def __init__(self, left_child: Node, right_child: Node) -> None:
231
+ super().__init__()
232
+ self.left_child: Node = left_child
233
+ self.right_child: Node = right_child
234
+
235
+ def __str__(self) -> str:
236
+ s: str = (
237
+ "( "
238
+ + self.left_child.__str__()
239
+ + " and "
240
+ + self.right_child.__str__()
241
+ + " )"
242
+ )
243
+ return s
244
+
245
+ def time_depth(self) -> int:
246
+ return max(self.left_child.time_depth(), self.right_child.time_depth())
247
+
248
+ def _boolean(self, x: Tensor) -> Tensor:
249
+ z1: Tensor = self.left_child._boolean(x)
250
+ z2: Tensor = self.right_child._boolean(x)
251
+ size: int = min(z1.size()[2], z2.size()[2])
252
+ z1: Tensor = z1[:, :, :size]
253
+ z2: Tensor = z2[:, :, :size]
254
+ z: Tensor = torch.logical_and(z1, z2)
255
+ return z
256
+
257
+ def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
258
+ z1: Tensor = self.left_child._quantitative(x, normalize)
259
+ z2: Tensor = self.right_child._quantitative(x, normalize)
260
+ size: int = min(z1.size()[2], z2.size()[2])
261
+ z1: Tensor = z1[:, :, :size]
262
+ z2: Tensor = z2[:, :, :size]
263
+ z: Tensor = torch.min(z1, z2)
264
+ return z
265
+
266
class Not(Node):
    """Negation node.

    NOTE(review): this file contains an earlier, behaviorally identical
    definition of `Not`; in Python the later definition (this one) is the
    one in effect. Consider removing the duplicate at the file level.
    """

    def __init__(self, child: Node) -> None:
        super().__init__()
        self.child: Node = child

    def __str__(self) -> str:
        rendered = "not ( " + self.child.__str__() + " )"
        return rendered

    def time_depth(self) -> int:
        # Same temporal depth as the negated subformula.
        return self.child.time_depth()

    def _boolean(self, x: Tensor) -> Tensor:
        inverted: Tensor = ~self.child._boolean(x)
        return inverted

    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        flipped: Tensor = -self.child._quantitative(x, normalize)
        return flipped
287
+
288
+
289
class And(Node):
    """Conjunction node.

    NOTE(review): this file contains an earlier, behaviorally identical
    definition of `And`; in Python the later definition (this one) is the
    one in effect. Consider removing the duplicate at the file level.
    """

    def __init__(self, left_child: Node, right_child: Node) -> None:
        super().__init__()
        self.left_child: Node = left_child
        self.right_child: Node = right_child

    def __str__(self) -> str:
        parts = ["( ", self.left_child.__str__(), " and ", self.right_child.__str__(), " )"]
        return "".join(parts)

    def time_depth(self) -> int:
        return max(self.left_child.time_depth(), self.right_child.time_depth())

    def _boolean(self, x: Tensor) -> Tensor:
        z1: Tensor = self.left_child._boolean(x)
        z2: Tensor = self.right_child._boolean(x)
        # Align the two traces on the shorter time span (time axis is dim 2).
        common = min(z1.size()[2], z2.size()[2])
        z1 = z1[:, :, :common]
        z2 = z2[:, :, :common]
        return torch.logical_and(z1, z2)

    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        z1: Tensor = self.left_child._quantitative(x, normalize)
        z2: Tensor = self.right_child._quantitative(x, normalize)
        common = min(z1.size()[2], z2.size()[2])
        z1 = z1[:, :, :common]
        z2 = z2[:, :, :common]
        # Pointwise minimum implements conjunction in the quantitative semantics.
        return torch.min(z1, z2)
326
+ return z
327
+
328
class Or(Node):
    """Disjunction of two STL subformulae."""

    def __init__(self, left_child: Node, right_child: Node) -> None:
        super().__init__()
        self.left_child: Node = left_child
        self.right_child: Node = right_child

    def __str__(self) -> str:
        return (
            "( "
            + self.left_child.__str__()
            + " or "
            + self.right_child.__str__()
            + " )"
        )

    def time_depth(self) -> int:
        # The horizon of a disjunction is set by the deeper operand.
        return max(self.left_child.time_depth(), self.right_child.time_depth())

    def _boolean(self, x: Tensor) -> Tensor:
        left = self.left_child._boolean(x)
        right = self.right_child._boolean(x)
        # Truncate both traces to the common time span (time axis is dim 2).
        horizon = min(left.size()[2], right.size()[2])
        return torch.logical_or(left[:, :, :horizon], right[:, :, :horizon])

    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        left = self.left_child._quantitative(x, normalize)
        right = self.right_child._quantitative(x, normalize)
        horizon = min(left.size()[2], right.size()[2])
        # Robustness of a disjunction is the pointwise maximum.
        return torch.max(left[:, :, :horizon], right[:, :, :horizon])
366
+
367
+
368
class Globally(Node):
    """Globally ("always") temporal operator node.

    Holds iff the child formula holds at every time step of the interval
    `[left_time_bound, right_time_bound]` (or at every future step when
    unbounded).

    Args:
        child: subformula the operator is applied to.
        unbound: if True the operator is fully unbounded (no interval printed).
        right_unbound: if True the interval is `[left_time_bound, inf)`.
        left_time_bound: inclusive lower bound of the interval.
        right_time_bound: inclusive upper bound of the interval (stored
            internally as an exclusive bound, hence the +1 below).
        adapt_unbound: when unbounded, compute a running minimum over suffixes
            (one value per time step) instead of a single global minimum.

    Raises:
        ValueError: if the interval is bounded and empty/inverted.
    """

    def __init__(
        self,
        child: Node,
        unbound: bool = False,
        right_unbound: bool = False,
        left_time_bound: int = 0,
        right_time_bound: int = 1,
        adapt_unbound: bool = True,
    ) -> None:
        super().__init__()
        self.child: Node = child
        self.unbound: bool = unbound
        self.right_unbound: bool = right_unbound
        self.left_time_bound: int = left_time_bound
        # Exclusive upper bound internally (mirrors Eventually/Until).
        self.right_time_bound: int = right_time_bound + 1
        self.adapt_unbound: bool = adapt_unbound

        # Fix: validate the temporal interval exactly as Eventually and Until
        # do; previously an inverted interval was silently accepted here.
        if (self.unbound is False) and (self.right_unbound is False) and \
                (self.right_time_bound <= self.left_time_bound):
            raise ValueError("Temporal thresholds are incorrect: right parameter is higher than left parameter")

    def __str__(self) -> str:
        # NOTE(review): this prints the internal (already +1'd) right bound,
        # matching the sibling operators' __str__ — confirm round-trip with the
        # formula parser before changing.
        s_left = "[" + str(self.left_time_bound) + ","
        s_right = str(self.right_time_bound) if not self.right_unbound else "inf"
        s0: str = s_left + s_right + "]" if not self.unbound else ""
        s: str = "always" + s0 + " ( " + self.child.__str__() + " )"
        return s

    def time_depth(self) -> int:
        """Number of time steps this operator needs beyond its child."""
        if self.unbound:
            return self.child.time_depth()
        elif self.right_unbound:
            return self.child.time_depth() + self.left_time_bound
        else:
            return self.child.time_depth() + self.right_time_bound - 1

    def _boolean(self, x: Tensor) -> Tensor:
        # Shift the trace so the interval starts at t=0 (time axis is dim 2).
        z1: Tensor = self.child._boolean(x[:, :, self.left_time_bound:])
        if self.unbound or self.right_unbound:
            if self.adapt_unbound:
                # Running minimum over suffixes: flip, cummin, flip back.
                z: Tensor
                _: Tensor
                z, _ = torch.cummin(torch.flip(z1, [2]), dim=2)
                z: Tensor = torch.flip(z, [2])
            else:
                z: Tensor
                _: Tensor
                z, _ = torch.min(z1, 2, keepdim=True)
        else:
            # Bounded case via duality: always phi == not eventually (not phi).
            z: Tensor = torch.ge(1.0 - eventually((~z1).double(), self.right_time_bound - self.left_time_bound), 0.5)
        return z

    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        z1: Tensor = self.child._quantitative(x[:, :, self.left_time_bound:], normalize)
        if self.unbound or self.right_unbound:
            if self.adapt_unbound:
                z: Tensor
                _: Tensor
                z, _ = torch.cummin(torch.flip(z1, [2]), dim=2)
                z: Tensor = torch.flip(z, [2])
            else:
                z: Tensor
                _: Tensor
                z, _ = torch.min(z1, 2, keepdim=True)
        else:
            # Quantitative duality: rob(always phi) = -rob(eventually -phi).
            z: Tensor = -eventually(-z1, self.right_time_bound - self.left_time_bound)
        return z
437
+
438
+
439
+
440
class Eventually(Node):
    """Eventually ("finally") temporal operator node.

    Holds iff the child formula holds at some time step of the interval
    `[left_time_bound, right_time_bound]` (or at some future step when
    unbounded).
    """

    def __init__(
        self,
        child: Node,
        unbound: bool = False,
        right_unbound: bool = False,
        left_time_bound: int = 0,
        right_time_bound: int = 1,
        adapt_unbound: bool = True,
    ) -> None:
        super().__init__()
        self.child: Node = child
        self.unbound: bool = unbound
        self.right_unbound: bool = right_unbound
        self.left_time_bound: int = left_time_bound
        # Exclusive upper bound internally, hence the +1.
        self.right_time_bound: int = right_time_bound + 1
        self.adapt_unbound: bool = adapt_unbound

        # Reject empty/inverted bounded intervals up front.
        if (self.unbound is False) and (self.right_unbound is False) \
                and (self.right_time_bound <= self.left_time_bound):
            raise ValueError("Temporal thresholds are incorrect: right parameter is higher than left parameter")

    def __str__(self) -> str:
        if self.unbound:
            interval = ""
        else:
            upper = "inf" if self.right_unbound else str(self.right_time_bound)
            interval = "[" + str(self.left_time_bound) + "," + upper + "]"
        return "eventually" + interval + " ( " + self.child.__str__() + " )"

    def time_depth(self) -> int:
        # Time steps needed beyond the child's own horizon.
        if self.unbound:
            return self.child.time_depth()
        if self.right_unbound:
            return self.child.time_depth() + self.left_time_bound
        return self.child.time_depth() + self.right_time_bound - 1

    def _boolean(self, x: Tensor) -> Tensor:
        # Shift the trace so the interval starts at t=0 (time axis is dim 2).
        child_vals: Tensor = self.child._boolean(x[:, :, self.left_time_bound:])
        if not (self.unbound or self.right_unbound):
            # Bounded case: sliding-window disjunction via the `eventually` helper.
            window = self.right_time_bound - self.left_time_bound
            return torch.ge(eventually(child_vals.double(), window), 0.5)
        if self.adapt_unbound:
            # Running maximum over suffixes: flip, cummax, flip back.
            out, _ = torch.cummax(torch.flip(child_vals, [2]), dim=2)
            return torch.flip(out, [2])
        out, _ = torch.max(child_vals, 2, keepdim=True)
        return out

    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        child_vals: Tensor = self.child._quantitative(x[:, :, self.left_time_bound:], normalize)
        if not (self.unbound or self.right_unbound):
            window = self.right_time_bound - self.left_time_bound
            return eventually(child_vals, window)
        if self.adapt_unbound:
            out, _ = torch.cummax(torch.flip(child_vals, [2]), dim=2)
            return torch.flip(out, [2])
        out, _ = torch.max(child_vals, 2, keepdim=True)
        return out
512
+
513
class Until(Node):
    """Until temporal operator node.

    `left until[a,b] right` holds when `right` eventually holds within the
    interval and `left` holds at every step before that.

    The unbounded case is computed directly with triangular-matrix algebra
    over the time axis; the bounded cases are rewritten into a combination of
    And/Globally/Eventually/(unbounded) Until nodes and delegated to them.
    """

    def __init__(
        self,
        left_child: Node,
        right_child: Node,
        unbound: bool = False,
        right_unbound: bool = False,
        left_time_bound: int = 0,
        right_time_bound: int = 1,
    ) -> None:
        super().__init__()
        self.left_child: Node = left_child
        self.right_child: Node = right_child
        self.unbound: bool = unbound
        self.right_unbound: bool = right_unbound
        self.left_time_bound: int = left_time_bound
        # Exclusive upper bound internally, hence the +1.
        self.right_time_bound: int = right_time_bound + 1

        # Reject empty/inverted bounded intervals up front.
        if (self.unbound is False) and (self.right_unbound is False) and \
                (self.right_time_bound <= self.left_time_bound):
            raise ValueError("Temporal thresholds are incorrect: right parameter is higher than left parameter")

    def __str__(self) -> str:
        s_left = "[" + str(self.left_time_bound) + ","
        s_right = str(self.right_time_bound) if not self.right_unbound else "inf"
        s0: str = s_left + s_right + "]" if not self.unbound else ""
        s: str = "( " + self.left_child.__str__() + " until" + s0 + " " + self.right_child.__str__() + " )"
        return s

    def time_depth(self) -> int:
        # Horizon combines both children's depths plus the interval extent.
        sum_children_depth: int = self.left_child.time_depth() + self.right_child.time_depth()
        if self.unbound:
            return sum_children_depth
        elif self.right_unbound:
            return sum_children_depth + self.left_time_bound
        else:
            return sum_children_depth + self.right_time_bound - 1

    def _boolean(self, x: Tensor) -> Tensor:
        """Boolean satisfaction trace; x is sliced on dim 2 (time) downstream."""
        if self.unbound:
            z1: Tensor = self.left_child._boolean(x)
            z2: Tensor = self.right_child._boolean(x)
            # Align both traces on the shorter common time span.
            size: int = min(z1.size()[2], z2.size()[2])
            z1: Tensor = z1[:, :, :size]
            z2: Tensor = z2[:, :, :size]
            # Build a (start-time x end-time) matrix per signal: row t holds the
            # child's values on the suffix starting at t (tril masks the past).
            z1_rep = torch.repeat_interleave(z1.unsqueeze(2), z1.unsqueeze(2).shape[-1], 2)
            z1_tril = torch.tril(z1_rep.transpose(2, 3), diagonal=-1)
            z1_triu = torch.triu(z1_rep)
            # Cumulative min along each row = "left holds continuously up to t'".
            z1_def = torch.cummin(z1_tril + z1_triu, dim=3)[0]

            z2_rep = torch.repeat_interleave(z2.unsqueeze(2), z2.unsqueeze(2).shape[-1], 2)
            z2_tril = torch.tril(z2_rep.transpose(2, 3), diagonal=-1)
            z2_triu = torch.triu(z2_rep)
            z2_def = z2_tril + z2_triu
            # max over t' of min(left-prefix, right-at-t') gives the until value.
            z: Tensor = torch.max(torch.min(torch.cat([z1_def.unsqueeze(-1), z2_def.unsqueeze(-1)], dim=-1), dim=-1)[0],
                                  dim=-1)[0]
        elif self.right_unbound:
            # Rewrite: G[0,a] left  and  F[a,inf) right  and  F[a,inf)(left U right).
            timed_until: Node = And(Globally(self.left_child, left_time_bound=0, right_time_bound=self.left_time_bound),
                                    And(Eventually(self.right_child, right_unbound=True,
                                                   left_time_bound=self.left_time_bound),
                                        Eventually(Until(self.left_child, self.right_child, unbound=True),
                                                   left_time_bound=self.left_time_bound, right_unbound=True)))
            z: Tensor = timed_until._boolean(x)
        else:
            # Rewrite: G[0,a] left  and  F[a,b] right  and  F[a,inf)(left U right).
            timed_until: Node = And(Globally(self.left_child, left_time_bound=0, right_time_bound=self.left_time_bound),
                                    And(Eventually(self.right_child, left_time_bound=self.left_time_bound,
                                                   right_time_bound=self.right_time_bound - 1),
                                        Eventually(Until(self.left_child, self.right_child, unbound=True),
                                                   left_time_bound=self.left_time_bound, right_unbound=True)))
            z: Tensor = timed_until._boolean(x)
        return z

    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        """Robustness trace; mirrors _boolean with min/max over real values."""
        if self.unbound:
            z1: Tensor = self.left_child._quantitative(x, normalize)
            z2: Tensor = self.right_child._quantitative(x, normalize)
            size: int = min(z1.size()[2], z2.size()[2])
            z1: Tensor = z1[:, :, :size]
            z2: Tensor = z2[:, :, :size]

            # The matrix formulation used in _boolean is kept below for
            # reference; the per-suffix loop is equivalent.
            # z1_rep = torch.repeat_interleave(z1.unsqueeze(2), z1.unsqueeze(2).shape[-1], 2)
            # z1_tril = torch.tril(z1_rep.transpose(2, 3), diagonal=-1)
            # z1_triu = torch.triu(z1_rep)
            # z1_def = torch.cummin(z1_tril + z1_triu, dim=3)[0]

            # z2_rep = torch.repeat_interleave(z2.unsqueeze(2), z2.unsqueeze(2).shape[-1], 2)
            # z2_tril = torch.tril(z2_rep.transpose(2, 3), diagonal=-1)
            # z2_triu = torch.triu(z2_rep)
            # z2_def = z2_tril + z2_triu
            # z: Tensor = torch.max(torch.min(torch.cat([z1_def.unsqueeze(-1), z2_def.unsqueeze(-1)], dim=-1), dim=-1)[0],
            #                       dim=-1)[0]
            # For each start time t: max over t' >= t of
            # min(running-min of left on [t, t'], right at t').
            z: Tensor = torch.cat([torch.max(torch.min(
                torch.cat([torch.cummin(z1[:, :, t:].unsqueeze(-1), dim=2)[0], z2[:, :, t:].unsqueeze(-1)], dim=-1),
                dim=-1)[0], dim=2, keepdim=True)[0] for t in range(size)], dim=2)
        elif self.right_unbound:
            # Same rewrite as in _boolean, evaluated quantitatively.
            timed_until: Node = And(Globally(self.left_child, left_time_bound=0, right_time_bound=self.left_time_bound),
                                    And(Eventually(self.right_child, right_unbound=True,
                                                   left_time_bound=self.left_time_bound),
                                        Eventually(Until(self.left_child, self.right_child, unbound=True),
                                                   left_time_bound=self.left_time_bound, right_unbound=True)))
            z: Tensor = timed_until._quantitative(x, normalize=normalize)
        else:
            timed_until: Node = And(Globally(self.left_child, left_time_bound=0, right_time_bound=self.left_time_bound),
                                    And(Eventually(self.right_child, left_time_bound=self.left_time_bound,
                                                   right_time_bound=self.right_time_bound-1),
                                        Eventually(Until(self.left_child, self.right_child, unbound=True),
                                                   left_time_bound=self.left_time_bound, right_unbound=True)))
            z: Tensor = timed_until._quantitative(x, normalize=normalize)
        return z
626
 
627
  # from anchor_set_generation import anchorGeneration
628
 
 
737
  def load_json(path: str) -> Union[Dict, List]:
738
  """
739
  Load a JSON file from the given path.
 
740
  Args:
741
  path (str): The path to the JSON file to be loaded.
742
+
743
  Returns:
744
  Union[Dict, List]: The parsed content of the JSON file, which could be a dictionary or a list.
745
  """
 
800
  def sample(self, nvars):
801
  """
802
  Samples a random formula with distribution defined in class instance parameters
 
803
  Parameters
804
  ----------
805
  nvars : number of variables of input signals
806
  how many variables the formula is expected to consider.
 
807
  Returns
808
  -------
809
  TYPE
810
  A random formula.
 
811
  """
812
  return self._sample_internal_node(nvars)
 
813
  def bag_sample(self, bag_size, nvars):
814
  """
815
  Samples a bag of bag_size formulae
 
816
  Parameters
817
  ----------
818
  bag_size : INT
819
  number of formulae.
820
  nvars : INT
821
  number of vars in formulae.
 
822
  Returns
823
  -------
824
  a list of formulae.
 
825
  """
826
  formulae = []
827
  for _ in range(bag_size):
 
838
  while True:
839
  if nodetype == "not":
840
  n = self._sample_node(nvars)
841
+ node = Not(n)
842
  elif nodetype == "and":
843
  n1 = self._sample_node(nvars)
844
  n2 = self._sample_node(nvars)
845
+ node = And(n1, n2)
846
  elif nodetype == "or":
847
  n1 = self._sample_node(nvars)
848
  n2 = self._sample_node(nvars)
849
+ node = Or(n1, n2)
850
  elif nodetype == "always":
851
  n = self._sample_node(nvars)
852
  unbound, right_unbound, left_time_bound, right_time_bound = self._get_temporal_parameters()
853
+ node = Globally(
854
  n, unbound, right_unbound, left_time_bound, right_time_bound, self.adaptive_unbound_temporal_ops
855
  )
856
  elif nodetype == "eventually":
857
  n = self._sample_node(nvars)
858
  unbound, right_unbound, left_time_bound, right_time_bound = self._get_temporal_parameters()
859
+ node = Eventually(
860
  n, unbound, right_unbound, left_time_bound, right_time_bound, self.adaptive_unbound_temporal_ops
861
  )
862
  elif nodetype == "until":
863
  n1 = self._sample_node(nvars)
864
  n2 = self._sample_node(nvars)
865
  unbound, right_unbound, left_time_bound, right_time_bound = self._get_temporal_parameters()
866
+ node = Until(
867
  n1, n2, unbound, right_unbound, left_time_bound, right_time_bound
868
  )
869
 
 
874
  if rnd.rand() < self.leaf_prob:
875
  # sample a leaf
876
  var, thr, lte = self._get_atom(nvars)
877
+ return Atom(var, thr, lte)
878
  else:
879
  return self._sample_internal_node(nvars)
880
 
 
905
  self, mu0=0.0, sigma0=1.0, mu1=0.0, sigma1=1.0, q=0.1, q0=0.5, device="cpu"
906
  ):
907
  """
 
908
  Parameters
909
  ----------
910
  mu0 : mean of normal distribution of initial state, optional
 
921
  probability of initial sign of derivative. The default is 0.5.
922
  device : 'cpu' or 'cuda', optional
923
  device on which to run the algorithm. The default is 'cpu'.
 
924
  Returns
925
  -------
926
  None.
 
927
  """
928
  self.mu0 = mu0
929
  self.sigma0 = sigma0
 
937
  """
938
  Samples a set of trajectories from the basic measure space, with parameters
939
  passed to the sampler
 
940
  Parameters
941
  ----------
942
  points : INT, optional
 
945
  number of trajectories. The default is 100000.
946
  varn : INT, optional
947
  number of variables per trajectory. The default is 2.
 
 
948
  Returns
949
  -------
950
  signal : samples x varn x points double pytorch tensor
951
  The sampled signals.
 
952
  """
953
  if self.device == "cuda" and not torch.cuda.is_available():
954
  raise RuntimeError("GPU card or CUDA library not available!")
 
1083
  return kernel_matrix.cpu(), rhos1, selfk1, len1
1084
  else:
1085
  return kernel_matrix.cpu()
 
 
1086
  n = self.samples
1087
  p = self.points
1088
  k = len(phis)
 
1144
  kernel_matrix = kernel_matrix / normalize
1145
  return kernel_matrix
1146
 
1147
+ @staticmethod
1148
+ def _normalize(kernel_matrix, selfk1, selfk2):
1149
+ normalize = torch.sqrt(torch.matmul(selfk1, torch.transpose(selfk2, 0, 1)))
1150
+ kernel_matrix = kernel_matrix / normalize
1151
+ return kernel_matrix
1152
+
1153
  def _exponentiate(self, kernel_matrix, selfk1, selfk2, sigma2=None):
1154
  if sigma2 is None:
1155
  sigma2 = self.sigma2
 
1280
  leaf_prob: float = 0.4, # complexity of the generated formula
1281
  cosine_similarity_threshold: float = 0.8 # if two formulae cosine similarity exceeds 0.9, then discard one of the two
1282
  ) -> str:
1283
+
1284
  # initialize STL formula generator
1285
  sampler = StlGenerator(leaf_prob)
1286
+
1287
  # effective anchor set generation
1288
  if diff_init:
1289
+
1290
  # initialize the anchor set with a randomly sampled formula
1291
  diff_anchor_set = [sampler.sample(nvars=n_vars)]
1292
 
 
1302
  while len(diff_anchor_set) < embed_dim:
1303
  # sample the 'remaining' formulae to reach the desired number of `embed_dim` formulae:
1304
  candidate_anchors = sampler.bag_sample(embed_dim - len(diff_anchor_set), nvars = n_vars)
1305
+
1306
  # compute robustness of candidate anchor formulae on the same signals as previous anchor set
1307
  candidate_robs = torch.cat([phi.quantitative(signals, normalize=True).unsqueeze(0) for phi in candidate_anchors], 0)
1308
+
1309
  # compute cosine similarity between current anchor set and candidate new formulae
1310
  cos_simil = torch.tril(normalize(candidate_robs) @ normalize(anchor_rob_vectors).t(), diagonal=-1)
1311
 
1312
  # check which formulae are similar (i.e. greater cosine similarity then threshold) w.r.t. current anchors
1313
  # NOTA: chiedere a gaia se cosine similarities negative vanno ammazzate con un valore assoluto o meno!
1314
  similar_idx = [torch.where(cos_simil[r, :] > cosine_similarity_threshold)[0].tolist() for r in range(cos_simil.shape[0])]
1315
+
1316
  # keep only those who are semantically distant
1317
  keep_idx = list(set(np.arange(len(candidate_anchors)).tolist()).difference(set([i for sublist in similar_idx for i in sublist])))
1318
+
1319
  diff_anchor_set += [copy.deepcopy(candidate_anchors[i]) for i in keep_idx]
1320
+
1321
  # Convert keep_idx to a tensor on the same device as candidate_robs
1322
  keep_idx_tensor = torch.tensor(keep_idx, device=candidate_robs.device)
1323
+
1324
  # Use index_select to pick the relevant rows
1325
  selected_robs = torch.index_select(candidate_robs, 0, keep_idx_tensor)
1326
+
1327
  # Concatenate on the same device
1328
  anchor_rob_vectors = torch.cat([anchor_rob_vectors, copy.deepcopy(selected_robs)], dim=0)
1329
 
1330
  anchor_set = diff_anchor_set[:embed_dim]
1331
+
1332
  else:
1333
+ anchor_set = sampler.bag_sample(bag_size=embed_dim, nvars=n_vars)
1334
 
1335
  filename = f'anchor_set_no_diff_{embed_dim}_dim'
1336
  dump_pickle(filename, anchor_set)
 
1338
 
1339
  ####
1340
 
 
1341
  """
1342
  A custom tokenizer class that extends `PreTrainedTokenizer` to handle a specific vocabulary and tokenization process.
1343
+ This tokenizer can load a vocabulary from a JSON file, tokenize text, convert tokens to IDs,
 
1344
  and handle padding and special tokens.
1345
  """
1346
 
1347
+ def __init__(self, vocab_path: str, unk_token: str = "unk", pad_token: str = "pad",
1348
  bos_token: str = "/s", eos_token: str = "s", model_max_length = 512, *args, **kwargs):
1349
  """
1350
  Initializes the STLTokenizer with a given vocabulary and special tokens.
 
1351
  Args:
1352
  vocab_path (str): The path to the JSON file containing the vocabulary.
1353
  unk_token (str, optional): The token used for unknown words. Defaults to "unk".
 
1362
  self.eos_token = eos_token
1363
  self.model_max_length = model_max_length
1364
  self.id_to_token = {v: k for k, v in self.vocab.items()} # Reverse mapping
1365
+ super().__init__(unk_token=unk_token, pad_token=pad_token, bos_token=bos_token, eos_token=eos_token,
1366
  model_max_length=model_max_length, *args, **kwargs)
1367
 
1368
  @property
1369
  def vocab_size(self) -> int:
1370
  """
1371
  Returns the size of the vocabulary.
 
1372
  Returns:
1373
  int: The number of tokens in the vocabulary.
1374
  """
 
1377
  def prepad_sequence(self, sequence, space_token = ' ', new_space_token = '@', undo = False):
1378
  """
1379
  Replaces spaces in the input sequence with a specified token.
 
1380
  Args:
1381
  sequence (str): The input sequence.
1382
  undo (bool): If True, replace the padding token with spaces. Defaults to False, which pads the spaces.
 
1383
  Returns:
1384
  str: The preprocessed sequence with spaces or padding tokens replaced.
1385
  """
 
1391
  def add_bos_eos(self, sequence: str) -> str:
1392
  """
1393
  Aggiunge i token BOS all'inizio e EOS alla fine della sequenza.
 
1394
  Args:
1395
  sequence (str): La sequenza di input.
 
1396
  Returns:
1397
  str: La sequenza con i token BOS ed EOS.
1398
  """
 
1401
  def tokenize(self, text: str) -> List[str]:
1402
  """
1403
  Tokenizes the input text into a list of tokens.
1404
+ The method preprocesses the input text by replacing spaces with padding tokens and then tries to
 
1405
  find the longest possible match for each substring in the vocabulary.
 
1406
  Args:
1407
  text (str): The input text to be tokenized.
 
1408
  Returns:
1409
  List[str]: A list of tokens representing the tokenized text.
1410
  """
1411
  text = self.add_bos_eos(text)
1412
  text = self.prepad_sequence(text)
 
1413
  tokens = []
1414
  i = 0
1415
  while i < len(text):
 
1430
  def convert_tokens_to_ids(self, tokens: List[str]) -> List[int]:
1431
  """
1432
  Converts a list of tokens into a list of token IDs.
 
1433
  Args:
1434
  tokens (List[str]): A list of tokens to be converted into IDs.
 
1435
  Returns:
1436
  List[int]: A list of corresponding token IDs.
1437
  """
 
1440
  def convert_ids_to_tokens(self, ids: List[int]) -> List[str]:
1441
  """
1442
  Converts a list of token IDs into a list of tokens.
 
1443
  Args:
1444
  ids (List[int]): A list of token IDs to be converted into tokens.
 
1445
  Returns:
1446
  List[str]: A list of corresponding tokens.
1447
  """
 
1450
  def encode(self, sequence: str) -> List[int]:
1451
  """
1452
  Encodes a string sequence into a list of token IDs.
1453
+
1454
+ This method tokenizes the input sequence using the `tokenize` method,
1455
+ and then converts the resulting tokens into their corresponding token IDs
1456
  using the `convert_tokens_to_ids` method.
1457
+
1458
  Args:
1459
  sequence (str): The input sequence (text) to be encoded.
1460
+
1461
  Returns:
1462
  List[int]: A list of token IDs corresponding to the input sequence.
1463
  """
 
1466
 
1467
  def postpad_sequence(self, sequence, pad_token_id):
1468
  """
1469
+ Fills the sequence up to max_length padding elements
1470
+ """
1471
  num_extra_elements = self.model_max_length - len(sequence) -1
1472
  if num_extra_elements > 0:
1473
  sequence.extend([pad_token_id] * num_extra_elements)
 
1476
  def decode(self, token_ids: List[int]) -> str:
1477
  """
1478
  Decodes a list of token IDs into a string of text.
1479
+ The method converts the IDs to tokens and joins them to form a string.
 
1480
  It also restores the original spaces or padding tokens if `undo` is True.
 
1481
  Args:
1482
  token_ids (List[int]): A list of token IDs to be decoded.
1483
  skip_special_tokens (bool, optional): Whether to skip special tokens during decoding. Defaults to False.
 
1484
  Returns:
1485
  str: The decoded string.
1486
  """
 
1490
 
1491
  def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
1492
  """
1493
+ Saves the tokenizer's vocabulary to a file.
1494
+ Useful only when the vocabulary has to be retrieved and is not given
1495
  (thus this is not the case: here to further improvements with sentencepiece).
1496
+ This method saves the vocabulary to a JSON file in the specified directory.
 
 
1497
  Args:
1498
  save_directory (str): The directory where the vocabulary file will be saved.
1499
  filename_prefix (Optional[str]): An optional prefix for the filename.
 
1500
  Returns:
1501
  Tuple[str]: A tuple containing the path to the saved vocabulary file.
1502
  """
 
1508
  def get_vocab(self) -> dict:
1509
  """
1510
  Retrieves the vocabulary used by the tokenizer.
 
1511
  Returns:
1512
  dict: The vocabulary as a dictionary.
1513
  """
 
1536
  out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
1537
  out.detach_()
1538
  return out
 
1539
  @torch.no_grad()
1540
  def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor:
1541
  """`input_ids_shape` is expected to be [bsz x seqlen]."""
 
1548
  class STLAttention(nn.Module):
1549
  """ Multi-Head Attention as depicted from 'Attention is all you need' """
1550
 
1551
+ def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0,
1552
  is_decoder: bool = False, bias: bool = False, is_causal: bool = False):
1553
+
1554
  super().__init__()
1555
  self.embed_dim = embed_dim # overall embedding dimension -> to be divided between multiple heads
1556
  self.num_heads = num_heads
1557
  self.dropout = dropout
1558
  self.head_dim = embed_dim // num_heads
1559
+ assert (self.head_dim * num_heads) == self.embed_dim
1560
  self.scaling = self.head_dim ** -0.5 # used to normalize values when projected using `W_` matrices
1561
  self.is_decoder = is_decoder
1562
  self.is_causal = is_causal
1563
 
1564
+ # 'roleplaying' matrices
1565
+ self.W_k = nn.Linear(embed_dim, embed_dim, bias = bias)
1566
  self.W_q = nn.Linear(embed_dim, embed_dim, bias = bias)
1567
  self.W_v = nn.Linear(embed_dim, embed_dim, bias = bias)
1568
 
1569
  # to project the heads' outputs into a single vector
1570
+ self.W_o = nn.Linear(embed_dim, embed_dim, bias = bias)
1571
 
1572
 
1573
  def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
1574
  return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
1575
+
1576
+
1577
+ def forward(self,
1578
  hidden_states: torch.Tensor, # previous values, passed to the multi-head attn layer
1579
  key_value_states: Optional[torch.Tensor] = None, # different key, value items (used in cross-attn)
1580
+ past_key_value: Optional[Tuple[torch.Tensor]] = None, # stores the key and values of previous steps
1581
  attention_mask: Optional[torch.Tensor] = None, # masks non-allowed items (padded or future ones)
1582
  layer_head_mask: Optional[torch.Tensor] = None, # used to de-activate specific attn heads
1583
  output_attentions: bool = False # flag to control the output of the attn values,
 
1584
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
1585
 
1586
  is_cross_attention = key_value_states is not None # cross-attn if key_value_states is not None
 
1604
  else:
1605
  key = self._shape(self.W_k(hidden_states), -1, batch_size)
1606
  value = self._shape(self.W_v(hidden_states), -1, batch_size)
 
1607
  if self.is_decoder:
1608
  past_key_value = (key, value)
1609
+
1610
  proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
1611
 
1612
+ query = self._shape(query, tgt_len, batch_size).view(*proj_shape)
1613
  key = key.reshape(*proj_shape)
1614
  value = value.reshape(*proj_shape)
1615
 
1616
  src_len = key.size(1)
1617
 
1618
+
1619
  ######################################################################################################
1620
 
1621
  # 'traditional' attention computation
 
1627
  if attention_mask is not None:
1628
  attn_weights = attn_weights.view(batch_size, self.num_heads, tgt_len, src_len) + attention_mask
1629
  attn_weights = attn_weights.view(batch_size * self.num_heads, tgt_len, src_len)
1630
+
1631
  # Normalize values on the `key` axis (dim=-1)
1632
  attn_weights = F.softmax(attn_weights, dim=-1)
1633
 
 
1646
  attn_output = attn_output.transpose(1, 2)
1647
 
1648
  attn_output = attn_output.reshape(batch_size, tgt_len, self.embed_dim)
1649
+ attn_output = self.W_o(attn_output)
1650
 
1651
  return attn_output, None, past_key_value
1652
 
1653
  ####
1654
 
1655
  class STLEncoder():
1656
+ def __init__(self,
1657
  embed_dim: int,
1658
  anchor_filename: Optional[str] = None,
1659
  n_vars: int = 3):
1660
+
1661
  self.n_vars = n_vars # passaglielo in input
1662
  self.embed_dim = embed_dim
1663
  self.anchorset_filename = anchor_filename
 
1665
  self.mu = BaseMeasure(device=self.device)
1666
  self.kernel = StlKernel(self.mu, varn=self.n_vars)
1667
 
1668
+ if anchor_filename is None:
1669
+ anchor_filename = anchorGeneration(diff_init = True, embed_dim = self.embed_dim, n_vars = self.n_vars)
1670
  anchor_filename+='.pickle'
1671
 
1672
  # TO DO: check on the dimensions of the anchor set and the `embed_dim` and `n_vars` values
 
1680
  return self.kernel.compute_bag_bag(formula, self.anchor_set)
1681
 
1682
  class STLModel(PreTrainedModel):
1683
+ config_class = STLConfig
1684
+ base_model_prefix = "model"
1685
  supports_gradient_checkpointing = True
1686
 
1687
  # initializes the weights of `nn.Linear`, `nn.Embedding` and `STLSinusoidalPositionalEmbedding`
 
1710
  return dummy_inputs
1711
 
1712
  class STLDecoderBlock(nn.Module):
1713
+
1714
+ def __init__(self, embed_dim: int,
1715
  num_decoder_attention_heads: int,
1716
  num_decoder_ffn_dim: int,
1717
  dropout: float = 0.0,
1718
  attention_dropout: float = 0.0,
1719
  activation_dropout: float = 0.0,
1720
  ):
1721
+
1722
  super().__init__()
1723
+
1724
  self.embed_dim = embed_dim
1725
 
1726
+ # first block
1727
  self.self_attn = STLAttention(
1728
+ embed_dim=self.embed_dim,
1729
  num_heads=num_decoder_attention_heads,
1730
  dropout=dropout,
1731
  is_decoder=True, # not used, debugging purposes
 
1782
  Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1783
  returned tensors for more detail.
1784
  """
1785
+
1786
  ###################################################################
1787
+
1788
+ # BLOCK 1: processing what has been previously generated
1789
 
1790
  # previous state is stored into an auxiliary variable `residual`
1791
  residual = hidden_states
1792
 
1793
+ # tries to exploit previous K, V values if there are any
1794
  # (practically picks up to the first 2 values stored in `past_key_value` vector)
1795
  self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
1796
 
1797
  # masked MHSA on the already generated sequence
1798
+ # invokes `forward` method to transform the original vector accordingly
1799
  hidden_states, self_attn_weights, present_key_value = self.self_attn.forward(
1800
  hidden_states=hidden_states, # Q
1801
  past_key_value=self_attn_past_key_value, # K, V
1802
  attention_mask=attention_mask, # passed as input of the decoder layer
1803
+ layer_head_mask=layer_head_mask, # to deactivate certain attn layers
1804
+ output_attentions=output_attentions,
1805
  )
1806
  hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
1807
 
 
1816
  # BLOCK 2: cross-attn between already generated input and previous information (from the encoder)
1817
 
1818
  # initialize K, Q, attn_weights for this new attn operation
1819
+ cross_attn_present_key_value = None
1820
  cross_attn_weights = None
1821
 
1822
  # the important condition is that the encoder carries some information
 
1894
  attention_dropout = config.attention_dropout
1895
  activation_dropout = config.activation_dropout
1896
  decoder_layerdrop = config.decoder_layerdrop
1897
+
1898
  self.dropout = dropout
1899
  self.layerdrop = decoder_layerdrop
1900
  self.padding_idx = pad_token_id
 
1903
 
1904
  # Initialize the input embedding (if not passed already)
1905
  self.embed_tokens = nn.Embedding(decoder_vocab_size, embed_dim, self.padding_idx)
1906
+
1907
  # Initialize positional embedding also
1908
  self.embed_positions = STLSinusoidalPositionalEmbedding(
1909
  max_position_embeddings, embed_dim, self.padding_idx
1910
  )
1911
+
1912
  # Initialize decoder layers (of a prespecified number)
1913
+ self.layers = nn.ModuleList([STLDecoderBlock(embed_dim, num_decoder_attention_heads,
1914
+ num_decoder_ffn_dim, dropout,
1915
+ attention_dropout, activation_dropout)
1916
  for _ in range(num_decoder_layers)])
1917
 
1918
  self.gradient_checkpointing = False
 
1934
  return_dict: Optional[bool] = None,
1935
  **kwargs,
1936
  ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
1937
+
1938
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1939
  output_hidden_states = (
1940
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
2057
  cross_attentions=all_cross_attentions,
2058
  )
2059
 
2060
+ ####
2061
 
2062
  class STLForCausalLM(STLModel, GenerationMixin):
2063
  _tied_weights_keys = ["lm_head.weight"]
 
2066
  config = copy.deepcopy(config)
2067
  config.is_decoder = True
2068
  config.is_encoder_decoder = False
2069
+
2070
  super().__init__(config)
2071
  self.model = STLDecoder(config)
2072
 
 
2163
  tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
2164
  )
2165
  return reordered_past
2166
+
2167
+
2168
+