saracandu commited on
Commit
cb9d925
·
verified ·
1 Parent(s): 64eddc2
Files changed (1) hide show
  1. modeling.py +704 -8
modeling.py CHANGED
@@ -33,18 +33,13 @@ from transformers.modeling_outputs import (
33
  )
34
 
35
  from configuration import STLConfig
36
- # from handcoded_tokenizer import STLTokenizer
37
  from nltk.translate.bleu_score import sentence_bleu
38
  from stl import *
39
  import networkx as nx
40
- # import phis_generator_depth
41
  from datasets import load_dataset
42
 
43
- from utils import from_string_to_formula, load_pickle, dump_pickle
44
- from phis_generator import StlGenerator
45
- from traj_measure import BaseMeasure
46
- from kernel import StlKernel
47
- from anchor_set_generation import anchorGeneration
48
 
49
  import re
50
  import json
@@ -54,6 +49,105 @@ from transformers.utils import logging
54
 
55
  logger = logging.get_logger(__name__)
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  def load_json(path: str) -> Union[Dict, List]:
59
  """
@@ -68,6 +162,607 @@ def load_json(path: str) -> Union[Dict, List]:
68
  with open(path, "r") as f:
69
  return json.load(f)
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  class STLTokenizer(PreTrainedTokenizer):
73
  """
@@ -404,6 +1099,7 @@ class STLAttention(nn.Module):
404
 
405
  return attn_output, None, past_key_value
406
 
 
407
 
408
  class STLEncoder():
409
  def __init__(self,
@@ -808,7 +1504,7 @@ class STLDecoder(STLModel):
808
  cross_attentions=all_cross_attentions,
809
  )
810
 
811
-
812
 
813
  class STLForCausalLM(STLModel, GenerationMixin):
814
  _tied_weights_keys = ["lm_head.weight"]
 
33
  )
34
 
35
  from configuration import STLConfig
 
36
  from nltk.translate.bleu_score import sentence_bleu
37
  from stl import *
38
  import networkx as nx
 
39
  from datasets import load_dataset
40
 
41
+
42
+ # from anchor_set_generation import anchorGeneration
 
 
 
43
 
44
  import re
45
  import json
 
49
 
50
  logger = logging.get_logger(__name__)
51
 
52
+ #### utils ####
53
+
54
def load_pickle(path):
    """Load and return the Python object stored in the pickle file at ``path``.

    SECURITY NOTE: ``pickle.load`` can execute arbitrary code while
    deserializing — only call this on trusted files.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)
58
+
59
def dump_pickle(name, thing):
    """Serialize ``thing`` to the file ``<name>.pickle``.

    The ``.pickle`` extension is appended automatically, so ``name`` must be
    the base path *without* extension (mirrors how ``load_pickle`` is called
    with the full path).
    """
    with open(name + '.pickle', 'wb') as f:
        pickle.dump(thing, f)
62
+
63
def from_string_to_formula(st):
    """Parse the string rendering of an STL formula back into a formula tree.

    Assumed grammar (inferred from the index arithmetic below — TODO confirm
    against the printer that produces these strings):
      * binary nodes are fully parenthesised: ``( <left> <op> <right> )``;
      * unary nodes are ``<op> ( <child> )``;
      * atoms are exactly three tokens, e.g. ``x_0 <= 0.5``.
    Relies on ``Atom``/``Not``/``And``/``Or``/``Eventually``/``Globally``/
    ``Until`` and ``set_time_thresholds`` from the star-imported ``stl``
    module. Returns the root node of the parsed formula.
    """
    # A leading '(' marks a binary root (and/or/until); otherwise the root is
    # an atom or a unary operator (not/eventually/always).
    root_arity = 2 if st.startswith('(') else 1
    st_split = st.split()
    if root_arity <= 1:
        root_op_str = copy.deepcopy(st_split[0])
        if root_op_str.startswith('x'):
            # Atom: "x_<i> <= <thr>" or "x_<i> >= <thr>".
            # NOTE(review): st_split[0][2] reads a single character, so this
            # assumes variable indices 0-9 — confirm formulae never exceed
            # ten variables.
            atom_sign = True if st_split[1] == '<=' else False
            root_phi = Atom(var_index=int(st_split[0][2]), lte=atom_sign, threshold=float(st_split[2]))
            return root_phi
        else:
            assert (root_op_str.startswith('not') or root_op_str.startswith('eventually')
                    or root_op_str.startswith('always'))
            # Strip "<op> (" and the trailing ")" and recurse on the child.
            current_st = copy.deepcopy(st_split[2:-1])
            if root_op_str == 'not':
                root_phi = Not(child=from_string_to_formula(' '.join(current_st)))
            elif root_op_str.startswith('eventually'):
                # Temporal bounds (if any) are encoded in the operator token
                # itself and decoded by set_time_thresholds.
                unbound, right_unbound, left_time_bound, right_time_bound = set_time_thresholds(root_op_str)
                root_phi = Eventually(child=from_string_to_formula(' '.join(current_st)), unbound=unbound,
                                      right_unbound=right_unbound, left_time_bound=left_time_bound,
                                      right_time_bound=right_time_bound)
            else:
                unbound, right_unbound, left_time_bound, right_time_bound = set_time_thresholds(root_op_str)
                root_phi = Globally(child=from_string_to_formula(' '.join(current_st)), unbound=unbound,
                                    right_unbound=right_unbound, left_time_bound=left_time_bound,
                                    right_time_bound=right_time_bound)
    else:
        # 1 - delete everything which is contained in other sets of parenthesis (if any)
        current_st = copy.deepcopy(st_split[1:-1])
        if '(' in current_st:
            # Match parentheses with a stack to find each balanced span.
            par_queue = deque()
            par_idx_list = []
            for i, sub in enumerate(current_st):
                if sub == '(':
                    par_queue.append(i)
                elif sub == ')':
                    par_idx_list.append(tuple([par_queue.pop(), i]))
            # open_par_idx, close_par_idx = [current_st.index(p) for p in ['(', ')']]
            # union of parentheses range --> from these we may extract the substrings to be the children!!!
            # Merge overlapping/adjacent parenthesis spans into maximal
            # child ranges.
            children_range = []
            for begin, end in sorted(par_idx_list):
                if children_range and children_range[-1][1] >= begin - 1:
                    children_range[-1][1] = max(children_range[-1][1], end)
                else:
                    children_range.append([begin, end])
            n_children = len(children_range)
            assert (n_children in [1, 2])
            if n_children == 1:
                # one of the children is a variable --> need to individuate it
                var_child_idx = 1 if children_range[0][0] <= 1 else 0  # 0 is left child, 1 is right child
                # A unary operator token ('not'/'eventually'/'always', matched
                # by its first two characters) directly before the span
                # belongs to the child.
                if children_range[0][0] != 0 and current_st[children_range[0][0] - 1][0:2] in ['no', 'ev', 'al']:
                    children_range[0][0] -= 1
                left_child_str = current_st[:3] if var_child_idx == 0 else \
                    current_st[children_range[0][0]:children_range[0][1] + 1]
                right_child_str = current_st[-3:] if var_child_idx == 1 else \
                    current_st[children_range[0][0]:children_range[0][1] + 1]
                root_op_str = current_st[children_range[0][1] + 1] if var_child_idx == 1 else \
                    current_st[children_range[0][0] - 1]
                assert (root_op_str[:2] in ['an', 'or', 'un'])
            else:
                # Absorb a leading unary operator into each child's range.
                if children_range[0][0] != 0 and current_st[children_range[0][0] - 1][0:2] in ['no', 'ev', 'al']:
                    children_range[0][0] -= 1
                if current_st[children_range[1][0] - 1][0:2] in ['no', 'ev', 'al']:
                    children_range[1][0] -= 1
                # if there are two children, with parentheses, the element in the middle is the root
                root_op_str = current_st[children_range[0][1] + 1]
                assert (root_op_str[:2] in ['an', 'or', 'un'])
                left_child_str = current_st[children_range[0][0]:children_range[0][1] + 1]
                right_child_str = current_st[children_range[1][0]:children_range[1][1] + 1]
        else:
            # no parentheses means that both children are variables
            left_child_str = current_st[:3]
            right_child_str = current_st[-3:]
            root_op_str = current_st[3]
        left_child_str = ' '.join(left_child_str)
        right_child_str = ' '.join(right_child_str)
        if root_op_str == 'and':
            root_phi = And(left_child=from_string_to_formula(left_child_str),
                           right_child=from_string_to_formula(right_child_str))
        elif root_op_str == 'or':
            root_phi = Or(left_child=from_string_to_formula(left_child_str),
                          right_child=from_string_to_formula(right_child_str))
        else:
            unbound, right_unbound, left_time_bound, right_time_bound = set_time_thresholds(root_op_str)
            root_phi = Until(left_child=from_string_to_formula(left_child_str),
                             right_child=from_string_to_formula(right_child_str),
                             unbound=unbound, right_unbound=right_unbound, left_time_bound=left_time_bound,
                             right_time_bound=right_time_bound)
    return root_phi
151
 
152
  def load_json(path: str) -> Union[Dict, List]:
153
  """
 
162
  with open(path, "r") as f:
163
  return json.load(f)
164
 
165
+ #### phis_generator ####
166
+
167
class StlGenerator:
    """Random generator of STL formula trees.

    Node-type probabilities, atom-threshold distribution and time-bound
    ranges are fixed at construction time; :meth:`sample` and
    :meth:`bag_sample` then draw random formulae. Uses the node classes from
    the star-imported ``stl`` module and ``rnd`` for randomness (presumably
    ``numpy.random`` given the ``rnd.choice(..., p=...)`` /
    ``rnd.randint(high)`` call signatures — TODO confirm against the module
    imports).
    """

    def __init__(
        self,
        leaf_prob: float = 0.3,
        inner_node_prob: list = None,
        threshold_mean: float = 0.0,
        threshold_sd: float = 1.0,
        unbound_prob: float = 0.1,
        right_unbound_prob: float = 0.2,
        time_bound_max_range: float = 20,
        adaptive_unbound_temporal_ops: bool = True,
        max_timespan: int = 100,
    ):
        """
        leaf_prob
            probability of generating a leaf (always zero for root)
        node_types = ["not", "and", "or", "always", "eventually", "until"]
            Inner node types
        inner_node_prob
            probability vector for the different types of internal nodes
        threshold_mean
        threshold_sd
            mean and std for the normal distribution of the thresholds of atoms
        unbound_prob
            probability of a temporal operator to have a time bound o the type [0,infty]
        time_bound_max_range
            maximum value of time span of a temporal operator (i.e. max value of t in [0,t])
        adaptive_unbound_temporal_ops
            if true, unbounded temporal operators are computed from current point to the end of the signal, otherwise
            they are evaluated only at time zero.
        max_timespan
            maximum time depth of a formula.
        """

        # Address the mutability of default arguments
        if inner_node_prob is None:
            inner_node_prob = [0.166, 0.166, 0.166, 0.17, 0.166, 0.166]

        self.leaf_prob = leaf_prob
        self.inner_node_prob = inner_node_prob
        self.threshold_mean = threshold_mean
        self.threshold_sd = threshold_sd
        self.unbound_prob = unbound_prob
        self.right_unbound_prob = right_unbound_prob
        self.time_bound_max_range = time_bound_max_range
        self.adaptive_unbound_temporal_ops = adaptive_unbound_temporal_ops
        self.node_types = ["not", "and", "or", "always", "eventually", "until"]
        self.max_timespan = max_timespan

    def sample(self, nvars):
        """
        Samples a random formula with distribution defined in class instance parameters

        Parameters
        ----------
        nvars : number of variables of input signals
            how many variables the formula is expected to consider.

        Returns
        -------
        TYPE
            A random formula.

        """
        return self._sample_internal_node(nvars)

    def bag_sample(self, bag_size, nvars):
        """
        Samples a bag of bag_size formulae

        Parameters
        ----------
        bag_size : INT
            number of formulae.
        nvars : INT
            number of vars in formulae.

        Returns
        -------
        a list of formulae.

        """
        formulae = []
        for _ in range(bag_size):
            phi = self.sample(nvars)
            formulae.append(phi)
        return formulae

    def _sample_internal_node(self, nvars):
        # Build one internal node of the chosen type, re-drawing its
        # children until the resulting subtree's time depth fits within
        # self.max_timespan (the node type itself is chosen only once).
        # Declare & dummy-assign "idiom"
        node: Union[None, Node]
        node = None
        # choose node type
        nodetype = rnd.choice(self.node_types, p=self.inner_node_prob)
        while True:
            if nodetype == "not":
                n = self._sample_node(nvars)
                node = stl.Not(n)
            elif nodetype == "and":
                n1 = self._sample_node(nvars)
                n2 = self._sample_node(nvars)
                node = stl.And(n1, n2)
            elif nodetype == "or":
                n1 = self._sample_node(nvars)
                n2 = self._sample_node(nvars)
                node = stl.Or(n1, n2)
            elif nodetype == "always":
                n = self._sample_node(nvars)
                unbound, right_unbound, left_time_bound, right_time_bound = self._get_temporal_parameters()
                node = stl.Globally(
                    n, unbound, right_unbound, left_time_bound, right_time_bound, self.adaptive_unbound_temporal_ops
                )
            elif nodetype == "eventually":
                n = self._sample_node(nvars)
                unbound, right_unbound, left_time_bound, right_time_bound = self._get_temporal_parameters()
                node = stl.Eventually(
                    n, unbound, right_unbound, left_time_bound, right_time_bound, self.adaptive_unbound_temporal_ops
                )
            elif nodetype == "until":
                n1 = self._sample_node(nvars)
                n2 = self._sample_node(nvars)
                unbound, right_unbound, left_time_bound, right_time_bound = self._get_temporal_parameters()
                node = stl.Until(
                    n1, n2, unbound, right_unbound, left_time_bound, right_time_bound
                )

            if (node is not None) and (node.time_depth() < self.max_timespan):
                return node

    def _sample_node(self, nvars):
        # With probability leaf_prob emit an atom, otherwise recurse into
        # another internal node.
        if rnd.rand() < self.leaf_prob:
            # sample a leaf
            var, thr, lte = self._get_atom(nvars)
            return stl.Atom(var, thr, lte)
        else:
            return self._sample_internal_node(nvars)

    def _get_temporal_parameters(self):
        # Returns (unbound, right_unbound, left_time_bound, right_time_bound).
        if rnd.rand() < self.unbound_prob:
            # fully unbounded: [0, infty]
            return True, False, 0, 0
        elif rnd.rand() < self.right_unbound_prob:
            # right-unbounded: [t, infty]
            return False, True, rnd.randint(self.time_bound_max_range), 1
        else:
            # bounded interval [left, right] with right > left
            left_bound = rnd.randint(self.time_bound_max_range)
            return False, False, left_bound, rnd.randint(left_bound, self.time_bound_max_range) + 1

    def _get_atom(self, nvars):
        # Random variable index, normal threshold, random comparison direction.
        variable = rnd.randint(nvars)
        lte = rnd.rand() > 0.5
        threshold = rnd.normal(self.threshold_mean, self.threshold_sd)
        return variable, threshold, lte
318
+
319
+ #### traj_measure ####
320
+
321
class Measure:
    """Abstract base class for trajectory measures (see ``BaseMeasure``)."""

    def sample(self, samples=100000, varn=2, points=100):
        """Sample ``samples`` trajectories of ``varn`` variables over
        ``points`` time points. Must be overridden by subclasses."""
        # Raising is safer than the previous silent `pass`, which made a
        # forgotten override return None and fail far from the call site.
        raise NotImplementedError("Measure.sample must be overridden")
325
+
326
class BaseMeasure(Measure):
    """Measure over piecewise-linear trajectories: a normal initial state,
    a normal total variation, and Bernoulli-driven sign changes of the
    derivative."""

    def __init__(
        self, mu0=0.0, sigma0=1.0, mu1=0.0, sigma1=1.0, q=0.1, q0=0.5, device="cpu"
    ):
        """

        Parameters
        ----------
        mu0 : mean of normal distribution of initial state, optional
            The default is 0.0.
        sigma0 : standard deviation of normal distribution of initial state, optional
            The default is 1.0.
        mu1 : DOUBLE, optional
            mean of normal distribution of total variation. The default is 0.0.
        sigma1 : standard deviation of normal distribution of total variation, optional
            The default is 1.0.
        q : DOUBLE, optional
            probability of change of sign in derivative. The default is 0.1.
        q0 : DOUBLE, optional
            probability of initial sign of derivative. The default is 0.5.
        device : 'cpu' or 'cuda', optional
            device on which to run the algorithm. The default is 'cpu'.

        Returns
        -------
        None.

        """
        self.mu0 = mu0
        self.sigma0 = sigma0
        self.mu1 = mu1
        self.sigma1 = sigma1
        self.q = q
        self.q0 = q0
        self.device = device

    def sample(self, samples=100000, varn=2, points=100):
        """
        Samples a set of trajectories from the basic measure space, with parameters
        passed to the sampler

        Parameters
        ----------
        points : INT, optional
            number of points per trajectory, including initial one. The default is 1000.
        samples : INT, optional
            number of trajectories. The default is 100000.
        varn : INT, optional
            number of variables per trajectory. The default is 2.


        Returns
        -------
        signal : samples x varn x points double pytorch tensor
            The sampled signals.

        """
        if self.device == "cuda" and not torch.cuda.is_available():
            raise RuntimeError("GPU card or CUDA library not available!")

        # generate unif RN
        signal = torch.rand(samples, varn, points, device=self.device)
        # first point is special - set to zero for the moment, and set one point to 1
        signal[:, :, 0] = 0.0
        signal[:, :, -1] = 1.0
        # sorting each trajectory
        signal, _ = torch.sort(signal, 2)
        # computing increments and storing them in points 1 to end
        signal[:, :, 1:] = signal[:, :, 1:] - signal[:, :, :-1]
        # generate initial state, according to a normal distribution
        # FIX: allocate directly on self.device (previously torch.randn had
        # no device argument, so the tensor was created on the CPU default
        # device — inconsistent with every other allocation here and forcing
        # a host->device copy when device="cuda")
        signal[:, :, 0] = self.mu0 + self.sigma0 * torch.randn(
            signal[:, :, 0].size(), device=self.device
        )

        # sampling change signs from bernoulli in -1, 1
        derivs = (1 - self.q) * torch.ones(samples, varn, points, device=self.device)
        derivs = 2 * torch.bernoulli(derivs) - 1
        # sampling initial derivative
        derivs[:, :, 0] = self.q0
        derivs[:, :, 0] = 2 * torch.bernoulli(derivs[:, :, 0]) - 1
        # taking the cumulative product along axis 2
        derivs = torch.cumprod(derivs, 2)

        # sampling total variation
        totvar = torch.pow(
            self.mu1 + self.sigma1 * torch.randn(samples, varn, 1, device=self.device),
            2,
        )
        # multiplying total variation and derivatives and making initial point non-invasive
        derivs = derivs * totvar
        derivs[:, :, 0] = 1.0

        # computing trajectories by multiplying and then doing a cumulative sum
        signal = signal * derivs
        signal = torch.cumsum(signal, 2)
        return signal
420
+
421
+ #### kernel ####
422
+
423
# Type alias for "any real number" (float or int).
realnum = Union[float, int]
424
+
425
class StlKernel:
    """Kernel between STL formulae based on their robustness over a common
    bag of sampled trajectories.

    Each formula is mapped to the vector of its (boolean or quantitative)
    robustness values on ``self.signals``; the kernel is the (optionally
    normalized and/or exponentiated) scaled dot product of those vectors.
    Formula arguments are expected to expose ``boolean``/``quantitative``
    methods (project STL node types).
    """

    def __init__(
        self,
        measure,
        normalize=True,
        exp_kernel=True,
        sigma2=0.2,  # 0.5 works better; it was initially 0.2
        integrate_time=False,
        samples=100000,
        varn=2,
        points=100,
        boolean=False,
        signals=None,
    ):
        # measure: trajectory measure used to sample signals (must provide
        # .sample(points, samples, varn) and a .device attribute).
        self.traj_measure = measure
        self.exp_kernel = exp_kernel
        self.normalize = normalize
        self.sigma2 = sigma2
        self.samples = samples
        self.varn = varn
        self.points = points
        # integrate_time: if True, robustness is evaluated at every time
        # point and the kernel integrates over time; otherwise only time 0.
        self.integrate_time = integrate_time
        if signals is not None:
            self.signals = signals
        else:
            self.signals = measure.sample(points=points, samples=samples, varn=varn)
        # boolean: use boolean satisfaction (+/-1) instead of quantitative
        # robustness.
        self.boolean = boolean

    def compute(self, phi1, phi2):
        """Kernel value between two single formulae (scalar)."""
        return self.compute_one_one(phi1, phi2)

    def compute_one_one(self, phi1, phi2):
        # Wrap both formulae in singleton bags and read the 1x1 result.
        phis1: list = [phi1]
        phis2: list = [phi2]
        ker = self.compute_bag_bag(phis1, phis2)
        return ker[0, 0]

    def compute_bag(self, phis, return_robustness=True):
        """Gram matrix of a bag of formulae against itself; optionally also
        return the robustness vectors, self-kernels, and lengths for reuse."""
        if self.integrate_time:
            rhos, selfk, len0 = self._compute_robustness_time(phis)
            kernel_matrix = self._compute_kernel_time(
                rhos, rhos, selfk, selfk, len0, len0
            )
        else:
            rhos, selfk = self._compute_robustness_no_time(phis)
            kernel_matrix = self._compute_kernel_no_time(rhos, rhos, selfk, selfk)
            len0 = None
        if return_robustness:
            return kernel_matrix.cpu(), rhos, selfk, len0
        else:
            return kernel_matrix.cpu()

    def compute_one_bag(self, phi1, phis2, return_robustness=False):
        """Kernel vector of one formula against a bag (1 x len(phis2))."""
        phis1: list = [phi1]
        return self.compute_bag_bag(phis1, phis2, return_robustness)

    def compute_bag_bag(self, phis1, phis2, return_robustness=False):
        """Cross-kernel matrix between two bags of formulae."""
        if self.integrate_time:
            rhos1, selfk1, len1 = self._compute_robustness_time(phis1)
            rhos2, selfk2, len2 = self._compute_robustness_time(phis2)
            kernel_matrix = self._compute_kernel_time(
                rhos1, rhos2, selfk1, selfk2, len1, len2
            )
        else:
            rhos1, selfk1 = self._compute_robustness_no_time(phis1)
            rhos2, selfk2 = self._compute_robustness_no_time(phis2)
            len1, len2 = [None, None]
            kernel_matrix = self._compute_kernel_no_time(rhos1, rhos2, selfk1, selfk2)
        if return_robustness:
            return kernel_matrix.cpu(), rhos1, rhos2, selfk1, selfk2, len1, len2
        else:
            return kernel_matrix.cpu()

    def compute_one_from_robustness(self, phi, rhos, rho_self, lengths=None, return_robustness=False):
        """Kernel vector of one formula against precomputed robustness data."""
        phis: list = [phi]
        return self.compute_bag_from_robustness(phis, rhos, rho_self, lengths, return_robustness)

    def compute_bag_from_robustness(self, phis, rhos, rho_self, lengths=None, return_robustness=False):
        """Kernel of a bag of formulae against precomputed robustness data
        (avoids recomputing robustness of the second operand)."""
        if self.integrate_time:
            rhos1, selfk1, len1 = self._compute_robustness_time(phis)
            kernel_matrix = self._compute_kernel_time(
                rhos1, rhos, selfk1, rho_self, len1, lengths
            )
        else:
            rhos1, selfk1 = self._compute_robustness_no_time(phis)
            len1 = None
            kernel_matrix = self._compute_kernel_no_time(rhos1, rhos, selfk1, rho_self)
        if return_robustness:
            return kernel_matrix.cpu(), rhos1, selfk1, len1
        else:
            return kernel_matrix.cpu()

    def _compute_robustness_time(self, phis):
        # Robustness evaluated at all time points, zero-padded to `points`;
        # also returns per-formula self-kernels and actual trace lengths.
        n = self.samples
        p = self.points
        k = len(phis)
        rhos = torch.zeros((k, n, p), device="cpu")
        lengths = torch.zeros(k)
        self_kernels = torch.zeros((k, 1))
        for i, phi in enumerate(phis):
            if self.boolean:
                # map boolean {0,1} to {-1,+1}
                rho = phi.boolean(self.signals, evaluate_at_all_times=True).float()
                rho[rho == 0.0] = -1.0
            else:
                rho = phi.quantitative(self.signals, evaluate_at_all_times=True)
            # temporal operators may shorten the evaluated horizon
            actual_p = rho.size()[2]
            rho = rho.reshape(n, actual_p).cpu()
            rhos[i, :, :actual_p] = rho
            lengths[i] = actual_p
            # <rho, rho> averaged over samples and time points
            self_kernels[i] = torch.tensordot(
                rho.reshape(1, n, -1), rho.reshape(1, n, -1), dims=[[1, 2], [1, 2]]
            ) / (actual_p * n)
        return rhos, self_kernels, lengths

    def _compute_robustness_no_time(self, phis):
        # Robustness at time zero only: one value per formula per sample.
        n = self.samples
        k = len(phis)
        rhos = torch.zeros((k, n), device=self.traj_measure.device)
        self_kernels = torch.zeros((k, 1), device=self.traj_measure.device)
        for i, phi in enumerate(phis):
            if self.boolean:
                rho = phi.boolean(self.signals, evaluate_at_all_times=False).float()
                rho[rho == 0.0] = -1.0
            else:
                rho = phi.quantitative(self.signals, evaluate_at_all_times=False)
            self_kernels[i] = rho.dot(rho) / n
            rhos[i, :] = rho
        return rhos, self_kernels

    def _compute_kernel_time(self, rhos1, rhos2, selfk1, selfk2, len1, len2):
        # Time-integrated kernel: dot product over (samples, time), scaled by
        # 1 / (min trace length * samples).
        kernel_matrix = torch.tensordot(rhos1, rhos2, [[1, 2], [1, 2]])
        length_normalizer = self._compute_trajectory_length_normalizer(len1, len2)
        kernel_matrix = kernel_matrix * length_normalizer / self.samples
        if self.normalize:
            kernel_matrix = self._normalize(kernel_matrix, selfk1, selfk2)
        if self.exp_kernel:
            kernel_matrix = self._exponentiate(kernel_matrix, selfk1, selfk2)
        return kernel_matrix

    def _compute_kernel_no_time(self, rhos1, rhos2, selfk1, selfk2):
        # Time-zero kernel: dot product over samples, scaled by 1 / samples.
        kernel_matrix = torch.tensordot(rhos1, rhos2, [[1], [1]])
        kernel_matrix = kernel_matrix / self.samples
        if self.normalize:
            kernel_matrix = self._normalize(kernel_matrix, selfk1, selfk2)
        if self.exp_kernel:
            kernel_matrix = self._exponentiate(kernel_matrix, selfk1, selfk2)
        return kernel_matrix

    @staticmethod
    def _normalize(kernel_matrix, selfk1, selfk2):
        # Cosine-style normalization: K(a,b) / sqrt(K(a,a) * K(b,b)).
        normalize = torch.sqrt(torch.matmul(selfk1, torch.transpose(selfk2, 0, 1)))
        kernel_matrix = kernel_matrix / normalize
        return kernel_matrix

    def _exponentiate(self, kernel_matrix, selfk1, selfk2, sigma2=None):
        # Gaussian (RBF-style) transform of the linear kernel:
        # exp(-(||a||^2 + ||b||^2 - 2<a,b>) / (2*sigma2)).
        if sigma2 is None:
            sigma2 = self.sigma2
        if self.normalize:
            # selfk is (1.0^2 + 1.0^2)
            selfk = 2.0
        else:
            k1 = selfk1.size()[0]
            k2 = selfk2.size()[0]
            selfk = (selfk1 * selfk1).repeat(1, k2) + torch.transpose(
                selfk2 * selfk2, 0, 1
            ).repeat(k1, 1)
        return torch.exp(-(selfk - 2 * kernel_matrix) / (2 * sigma2))

    @staticmethod
    def _compute_trajectory_length_normalizer(len1, len2):
        # Pairwise 1 / min(len_i, len_j) matrix used to average the
        # time-integrated dot products over the shared horizon.
        k1 = len1.size()[0]
        k2 = len2.size()[0]
        y1 = len1.reshape(-1, 1)
        y1 = y1.repeat(1, k2)
        y2 = len2.repeat(k1, 1)
        return 1.0 / torch.min(y1, y2)
601
+
602
class GramMatrix:
    """Gram matrix of an ``StlKernel`` over a bag of formulae.

    With ``sample=True`` the formulae are generated on the fly by
    ``sampler``, rejecting candidates too similar (kernel value above a
    threshold) to the ones already kept; otherwise the provided
    ``formulae`` list is used as-is. Optionally caches robustness vectors
    to speed up later kernel-vector computations.
    """

    def __init__(self, kernel, formulae, store_robustness=True, sample=False, sampler=None, bag_size=None):
        self.kernel = kernel
        self.formulae_list = formulae
        # if kernel is computed from robustness at time zero only,
        # we store the robustness for each formula and each sample
        # to speed up computation later
        self.store_robustness = store_robustness
        # when sampling, bag_size overrides len(formulae) as target size
        self.dim = len(self.formulae_list) if not bag_size else int(bag_size)
        self.sample = sample  # whether to generate formulae in a controlled manner
        if self.sample:
            # rejection threshold on kernel similarity (stricter for boolean)
            self.t = 0.99 if self.kernel.boolean else 0.85
            self.sampler = sampler  # stl formulae generator
        self._compute_gram_matrix()

    def _compute_gram_matrix(self):
        # Fills self.gram (and the robustness caches when requested).
        if self.sample:
            gram = torch.zeros(self.dim, self.dim)
            # robustness cache shape depends on whether time is integrated
            rhos = torch.zeros((self.dim, self.kernel.samples), device=self.kernel.traj_measure.device) if \
                not self.kernel.integrate_time else torch.zeros((self.dim, self.kernel.samples, self.kernel.points),
                                                                device=self.kernel.traj_measure.device)
            lengths = torch.zeros(self.dim) if self.kernel.integrate_time else np.zeros(self.dim)
            kernels = torch.zeros((self.dim, 1), device=self.kernel.traj_measure.device)
            # seed the set with one random formula
            phis = [self.sampler.sample(nvars=self.kernel.varn)]
            gram[0, :1], rhos[0], kernels[0, :], lengths[0] = self.kernel.compute_bag(phis, return_robustness=True)
            while len(phis) < self.dim:
                i = len(phis)
                phi = self.sampler.sample(nvars=self.kernel.varn)
                # kernel row against formulae accepted so far (reuses cache)
                gram[i, :i], rhos[i], kernels[i, :], lengths[i] = self.kernel.compute_one_from_robustness(
                    phi, rhos[:i, :], kernels[:i, :], lengths[:i], return_robustness=True)
                # accept only if similar to fewer than 3 existing formulae
                if torch.sum(gram[i, :i + 1] >= self.t) < 3:
                    phis.append(phi)
                    gram[:i, i] = gram[i, :i]  # mirror: Gram matrix is symmetric
                    gram[i, i] = kernels[i, :]

            self.formulae_list = phis
            self.gram = gram.cpu()
            self.robustness = rhos if self.store_robustness else None
            self.self_kernels = kernels if self.store_robustness else None
            self.robustness_lengths = lengths if self.store_robustness else None
        else:
            if self.store_robustness:
                k_matrix, rhos, selfk, len0 = self.kernel.compute_bag(
                    self.formulae_list, return_robustness=True
                )
                self.gram = k_matrix
                self.robustness = rhos
                self.self_kernels = selfk
                self.robustness_lengths = len0
            else:
                self.gram = self.kernel.compute_bag(
                    self.formulae_list, return_robustness=False
                )
                self.robustness = None
                self.self_kernels = None
                self.robustness_lengths = None

    def compute_kernel_vector(self, phi):
        """Kernel vector of a single formula against this Gram's formulae."""
        if self.store_robustness:
            # fast path: reuse cached robustness of the stored formulae
            return self.kernel.compute_one_from_robustness(
                phi, self.robustness, self.self_kernels, self.robustness_lengths
            )
        else:
            return self.kernel.compute_one_bag(phi, self.formulae_list)

    def compute_bag_kernel_vector(self, phis, generate_phis=False, bag_size=None):
        """Kernel matrix of a bag of formulae against this Gram's formulae.

        With ``generate_phis=True``, ``phis`` is ignored and ``bag_size``
        test formulae are sampled instead, rejecting formulae whose
        robustness has constant sign over all sampled signals (trivially
        satisfied/violated); returns (formulae, kernel matrix).
        """
        if generate_phis:
            gram_test = torch.zeros(bag_size, self.dim)  # self.dim, bag_size
            rhos_test = torch.zeros((bag_size, self.kernel.samples), device=self.kernel.traj_measure.device) if \
                not self.kernel.integrate_time else torch.zeros((bag_size, self.kernel.samples, self.kernel.points),
                                                                device=self.kernel.traj_measure.device)
            lengths_test = torch.zeros(bag_size) if self.kernel.integrate_time else np.zeros(bag_size)
            kernels_test = torch.zeros((bag_size, 1), device=self.kernel.traj_measure.device)
            phi_test = []
            while len(phi_test) < bag_size:
                i = len(phi_test)
                phi = self.sampler.sample(nvars=self.kernel.varn)
                if self.store_robustness:
                    gram_test[i, :], rhos_test[i], kernels_test[i, :], lengths_test[i] = \
                        self.kernel.compute_one_from_robustness(phi, self.robustness, self.self_kernels,
                                                                self.robustness_lengths, return_robustness=True)
                else:
                    gram_test[i, :], rhos_test[i], _, kernels_test[i, :], _, lengths_test[i], _ = \
                        self.kernel.compute_one_bag(phi, self.formulae_list, return_robustness=True)
                # keep only formulae whose robustness changes sign somewhere
                if not ((rhos_test[i] > 0).all() or (rhos_test[i] < 0).all()):
                    phi_test.append(phi)
            return phi_test, gram_test.cpu()
        else:
            if self.store_robustness:
                return self.kernel.compute_bag_from_robustness(
                    phis, self.robustness, self.self_kernels, self.robustness_lengths
                )
            else:
                return self.kernel.compute_bag_bag(phis, self.formulae_list)

    def invert_regularized(self, alpha):
        """Return (G + 10^alpha * I)^-1 — Tikhonov-regularized inverse."""
        regularizer = abs(pow(10, alpha)) * torch.eye(self.dim)
        return torch.inverse(self.gram + regularizer)
700
+
701
+ #### anchor_generation ####
702
+
703
def anchorGeneration(diff_init = False, # to control whether we want formulae to be semantically different by construction
                     embed_dim: int = 30, # embedding dimension, aka number of generated formulae in the anchor set
                     n_vars: int = 3, # dimension of the input signal (3D in this case)
                     leaf_prob: float = 0.4, # complexity of the generated formula
                     cosine_similarity_threshold: float = 0.8 # if two formulae's cosine similarity exceeds this threshold, discard one of the two
                     ) -> str:
    """Generate an anchor set of ``embed_dim`` random STL formulae, pickle it,
    and return the (extension-less) filename it was saved under.

    With ``diff_init=True`` the set is built incrementally, rejecting
    candidates whose robustness vectors are too cosine-similar to already
    accepted anchors; otherwise it is a plain random bag.
    NOTE(review): ``normalize`` below is not defined in this function —
    presumably ``torch.nn.functional.normalize`` imported at module level;
    confirm against the file's imports.
    """

    # initialize STL formula generator
    sampler = StlGenerator(leaf_prob)

    # effective anchor set generation
    if diff_init:

        # initialize the anchor set with a randomly sampled formula
        diff_anchor_set = [sampler.sample(nvars=n_vars)]

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        mu = BaseMeasure(device=device)

        # generates a set of random signals working as a tester for the formulae testing
        signals = mu.sample(samples=10000, varn=n_vars)

        # computes robustness value for the initial set of formulae in the anchor set
        anchor_rob_vectors = torch.cat([phi.quantitative(signals, normalize=True).unsqueeze(0) for phi in diff_anchor_set], 0)

        while len(diff_anchor_set) < embed_dim:
            # sample the 'remaining' formulae to reach the desired number of `embed_dim` formulae:
            candidate_anchors = sampler.bag_sample(embed_dim - len(diff_anchor_set), nvars = n_vars)

            # compute robustness of candidate anchor formulae on the same signals as previous anchor set
            candidate_robs = torch.cat([phi.quantitative(signals, normalize=True).unsqueeze(0) for phi in candidate_anchors], 0)

            # compute cosine similarity between current anchor set and candidate new formulae
            cos_simil = torch.tril(normalize(candidate_robs) @ normalize(anchor_rob_vectors).t(), diagonal=-1)

            # check which formulae are similar (i.e. greater cosine similarity then threshold) w.r.t. current anchors
            # NOTE (original author's note, translated from Italian): ask Gaia whether
            # negative cosine similarities should be taken in absolute value or not!
            similar_idx = [torch.where(cos_simil[r, :] > cosine_similarity_threshold)[0].tolist() for r in range(cos_simil.shape[0])]

            # keep only those who are semantically distant
            keep_idx = list(set(np.arange(len(candidate_anchors)).tolist()).difference(set([i for sublist in similar_idx for i in sublist])))

            diff_anchor_set += [copy.deepcopy(candidate_anchors[i]) for i in keep_idx]

            # Convert keep_idx to a tensor on the same device as candidate_robs
            keep_idx_tensor = torch.tensor(keep_idx, device=candidate_robs.device)

            # Use index_select to pick the relevant rows
            selected_robs = torch.index_select(candidate_robs, 0, keep_idx_tensor)

            # Concatenate on the same device
            anchor_rob_vectors = torch.cat([anchor_rob_vectors, copy.deepcopy(selected_robs)], dim=0)

        anchor_set = diff_anchor_set[:embed_dim]

    else:
        anchor_set = sampler.bag_sample(bag_size=embed_dim, nvars=n_vars)

    # NOTE(review): the filename says 'no_diff' regardless of diff_init —
    # likely misleading when diff_init=True; confirm downstream consumers
    # before renaming.
    filename = f'anchor_set_no_diff_{embed_dim}_dim'
    dump_pickle(filename, anchor_set)
    return filename

####
764
+
765
+ ####
766
 
767
  class STLTokenizer(PreTrainedTokenizer):
768
  """
 
1099
 
1100
  return attn_output, None, past_key_value
1101
 
1102
+ ####
1103
 
1104
  class STLEncoder():
1105
  def __init__(self,
 
1504
  cross_attentions=all_cross_attentions,
1505
  )
1506
 
1507
+ ####
1508
 
1509
  class STLForCausalLM(STLModel, GenerationMixin):
1510
  _tied_weights_keys = ["lm_head.weight"]