saracandu committed (verified)
Commit 4eacfc4 · 1 Parent(s): 21c7f66

Delete modeling.py

Files changed (1)
  1. modeling.py +0 -2168
modeling.py DELETED
@@ -1,2168 +0,0 @@
- import ast
- import copy
- import math
- import pickle
- import os
- from collections import deque
- from typing import List, Optional, Tuple, Union
-
- import numpy as np
- from numpy import random as rnd  # used by StlGenerator below; this import was missing in the original
- import pandas as pd
- import torch
- import torch.utils.checkpoint
- from torch import nn
- import torch.nn.functional as F
- from torch.utils.data import Dataset
-
- from transformers.modeling_utils import PreTrainedModel
- from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
- from transformers.generation import GenerationMixin
- from transformers.utils import (
-     add_end_docstrings,
-     add_start_docstrings,
-     add_start_docstrings_to_model_forward,
-     logging,
-     replace_return_docstrings,
- )
- from transformers.modeling_outputs import (
-     BaseModelOutput,
-     BaseModelOutputWithPastAndCrossAttentions,
-     CausalLMOutputWithCrossAttentions,
-     Seq2SeqLMOutput,
-     Seq2SeqModelOutput,
- )
-
- from .configuration import STLConfig
- from nltk.translate.bleu_score import sentence_bleu
- # from stl import *
- import networkx as nx
- from datasets import load_dataset
-
- ### from custom_typing.py
-
- realnum = Union[float, int]
-
-
- ### from stl.py
-
- # For tensor functions
- from torch import Tensor
-
-
- def eventually(x: Tensor, time_span: int) -> Tensor:
-     """
-     STL operator 'eventually' in 1D.
-
-     Parameters
-     ----------
-     x: torch.Tensor
-         Signal
-     time_span: int
-         Timespan duration
-
-     Returns
-     -------
-     torch.Tensor
-         A tensor containing the result of the operation.
-     """
-     return F.max_pool1d(x, kernel_size=time_span, stride=1)
-
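- # Editor's usage sketch (assumption, not part of the original file): the operator
- # is a running max over sliding windows of length `time_span` along the time axis.
- # >>> x = torch.tensor([[[0., 1., 0., 2., 0.]]])   # batch x vars x time
- # >>> eventually(x, time_span=3)
- # tensor([[[1., 2., 2.]]])
-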
- class Node:
-     """Abstract node class for STL semantics tree."""
-
-     def __init__(self) -> None:
-         # Must be overloaded.
-         pass
-
-     def __str__(self) -> str:
-         # Must be overloaded.
-         pass
-
-     def boolean(self, x: Tensor, evaluate_at_all_times: bool = False) -> Tensor:
-         """
-         Evaluates the boolean semantics at the node.
-
-         Parameters
-         ----------
-         x : torch.Tensor, of size N_samples x N_vars x N_sampling_points
-             The input signals, stored as a batch tensor with three dimensions.
-         evaluate_at_all_times: bool
-             Whether to evaluate the semantics at all times (True) or
-             just at t=0 (False).
-
-         Returns
-         -------
-         torch.Tensor
-             A tensor with the boolean semantics for the node.
-         """
-         z: Tensor = self._boolean(x)
-         if evaluate_at_all_times:
-             return z
-         else:
-             return self._extract_semantics_at_time_zero(z)
-
-     def quantitative(
-         self,
-         x: Tensor,
-         normalize: bool = False,
-         evaluate_at_all_times: bool = False,
-     ) -> Tensor:
-         """
-         Evaluates the quantitative semantics at the node.
-
-         Parameters
-         ----------
-         x : torch.Tensor, of size N_samples x N_vars x N_sampling_points
-             The input signals, stored as a batch tensor with three dimensions.
-         normalize: bool
-             Whether the measure of robustness is normalized (True) or
-             not (False). Currently not in use.
-         evaluate_at_all_times: bool
-             Whether to evaluate the semantics at all times (True) or
-             just at t=0 (False).
-
-         Returns
-         -------
-         torch.Tensor
-             A tensor with the quantitative semantics for the node.
-         """
-         z: Tensor = self._quantitative(x, normalize)
-         if evaluate_at_all_times:
-             return z
-         else:
-             return self._extract_semantics_at_time_zero(z)
-
-     def set_normalizing_flag(self, value: bool = True) -> None:
-         """
-         Setter for the 'normalization of robustness of the formula' flag.
-         Currently not in use.
-         """
-
-     def time_depth(self) -> int:
-         """Returns time depth of bounded temporal operators only."""
-         # Must be overloaded.
-
-     def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
-         """Private method equivalent to public one for inner call."""
-         # Must be overloaded.
-
-     def _boolean(self, x: Tensor) -> Tensor:
-         """Private method equivalent to public one for inner call."""
-         # Must be overloaded.
-
-     @staticmethod
-     def _extract_semantics_at_time_zero(x: Tensor) -> Tensor:
-         """Extracts the vector of truth values at time zero."""
-         return torch.reshape(x[:, 0, 0], (-1,))
-
-
- class Atom(Node):
-     """Atomic formula node; for now of the form X<=t or X>=t"""
-
-     def __init__(self, var_index: int, threshold: realnum, lte: bool = False) -> None:
-         super().__init__()
-         self.var_index: int = var_index
-         self.threshold: realnum = threshold
-         self.lte: bool = lte
-
-     def __str__(self) -> str:
-         s: str = (
-             "x_"
-             + str(self.var_index)
-             + (" <= " if self.lte else " >= ")
-             + str(round(self.threshold, 4))
-         )
-         return s
-
-     def time_depth(self) -> int:
-         return 0
-
-     def _boolean(self, x: Tensor) -> Tensor:
-         # extract tensor of the same dimension as data, but with only one variable
-         xj: Tensor = x[:, self.var_index, :]
-         xj: Tensor = xj.view(xj.size()[0], 1, -1)
-         if self.lte:
-             z: Tensor = torch.le(xj, self.threshold)
-         else:
-             z: Tensor = torch.ge(xj, self.threshold)
-         return z
-
-     def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
-         # extract tensor of the same dimension as data, but with only one variable
-         xj: Tensor = x[:, self.var_index, :]
-         xj: Tensor = xj.view(xj.size()[0], 1, -1)
-         if self.lte:
-             z: Tensor = -xj + self.threshold
-         else:
-             z: Tensor = xj - self.threshold
-         if normalize:
-             z: Tensor = torch.tanh(z)
-         return z
-
-
- class Not(Node):
-     """Negation node."""
-
-     def __init__(self, child: Node) -> None:
-         super().__init__()
-         self.child: Node = child
-
-     def __str__(self) -> str:
-         s: str = "not ( " + self.child.__str__() + " )"
-         return s
-
-     def time_depth(self) -> int:
-         return self.child.time_depth()
-
-     def _boolean(self, x: Tensor) -> Tensor:
-         z: Tensor = ~self.child._boolean(x)
-         return z
-
-     def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
-         z: Tensor = -self.child._quantitative(x, normalize)
-         return z
-
-
- class And(Node):
-     """Conjunction node."""
-
-     def __init__(self, left_child: Node, right_child: Node) -> None:
-         super().__init__()
-         self.left_child: Node = left_child
-         self.right_child: Node = right_child
-
-     def __str__(self) -> str:
-         s: str = (
-             "( "
-             + self.left_child.__str__()
-             + " and "
-             + self.right_child.__str__()
-             + " )"
-         )
-         return s
-
-     def time_depth(self) -> int:
-         return max(self.left_child.time_depth(), self.right_child.time_depth())
-
-     def _boolean(self, x: Tensor) -> Tensor:
-         z1: Tensor = self.left_child._boolean(x)
-         z2: Tensor = self.right_child._boolean(x)
-         size: int = min(z1.size()[2], z2.size()[2])
-         z1: Tensor = z1[:, :, :size]
-         z2: Tensor = z2[:, :, :size]
-         z: Tensor = torch.logical_and(z1, z2)
-         return z
-
-     def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
-         z1: Tensor = self.left_child._quantitative(x, normalize)
-         z2: Tensor = self.right_child._quantitative(x, normalize)
-         size: int = min(z1.size()[2], z2.size()[2])
-         z1: Tensor = z1[:, :, :size]
-         z2: Tensor = z2[:, :, :size]
-         z: Tensor = torch.min(z1, z2)
-         return z
-
-
- class Or(Node):
-     """Disjunction node."""
-
-     def __init__(self, left_child: Node, right_child: Node) -> None:
-         super().__init__()
-         self.left_child: Node = left_child
-         self.right_child: Node = right_child
-
-     def __str__(self) -> str:
-         s: str = (
-             "( "
-             + self.left_child.__str__()
-             + " or "
-             + self.right_child.__str__()
-             + " )"
-         )
-         return s
-
-     def time_depth(self) -> int:
-         return max(self.left_child.time_depth(), self.right_child.time_depth())
-
-     def _boolean(self, x: Tensor) -> Tensor:
-         z1: Tensor = self.left_child._boolean(x)
-         z2: Tensor = self.right_child._boolean(x)
-         size: int = min(z1.size()[2], z2.size()[2])
-         z1: Tensor = z1[:, :, :size]
-         z2: Tensor = z2[:, :, :size]
-         z: Tensor = torch.logical_or(z1, z2)
-         return z
-
-     def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
-         z1: Tensor = self.left_child._quantitative(x, normalize)
-         z2: Tensor = self.right_child._quantitative(x, normalize)
-         size: int = min(z1.size()[2], z2.size()[2])
-         z1: Tensor = z1[:, :, :size]
-         z2: Tensor = z2[:, :, :size]
-         z: Tensor = torch.max(z1, z2)
-         return z
-
-
- class Globally(Node):
-     """Globally node."""
-
-     def __init__(
-         self,
-         child: Node,
-         unbound: bool = False,
-         right_unbound: bool = False,
-         left_time_bound: int = 0,
-         right_time_bound: int = 1,
-         adapt_unbound: bool = True,
-     ) -> None:
-         super().__init__()
-         self.child: Node = child
-         self.unbound: bool = unbound
-         self.right_unbound: bool = right_unbound
-         self.left_time_bound: int = left_time_bound
-         self.right_time_bound: int = right_time_bound + 1
-         self.adapt_unbound: bool = adapt_unbound
-
-         # Bound check added for consistency; it mirrors the one in Eventually and Until.
-         if (self.unbound is False) and (self.right_unbound is False) and \
-                 (self.right_time_bound <= self.left_time_bound):
-             raise ValueError("Temporal thresholds are incorrect: the right bound must be strictly greater than the left bound")
-
-     def __str__(self) -> str:
-         s_left = "[" + str(self.left_time_bound) + ","
-         s_right = str(self.right_time_bound) if not self.right_unbound else "inf"
-         s0: str = s_left + s_right + "]" if not self.unbound else ""
-         s: str = "always" + s0 + " ( " + self.child.__str__() + " )"
-         return s
-
-     def time_depth(self) -> int:
-         if self.unbound:
-             return self.child.time_depth()
-         elif self.right_unbound:
-             return self.child.time_depth() + self.left_time_bound
-         else:
-             # diff = torch.le(torch.tensor([self.left_time_bound]), 0).float()
-             return self.child.time_depth() + self.right_time_bound - 1
-             # (self.right_time_bound - self.left_time_bound + 1) - diff
-
-     def _boolean(self, x: Tensor) -> Tensor:
-         z1: Tensor = self.child._boolean(x[:, :, self.left_time_bound:])  # nested temporal parameters
-         # z1 = z1[:, :, self.left_time_bound:]
-         if self.unbound or self.right_unbound:
-             if self.adapt_unbound:
-                 z: Tensor
-                 _: Tensor
-                 z, _ = torch.cummin(torch.flip(z1, [2]), dim=2)
-                 z: Tensor = torch.flip(z, [2])
-             else:
-                 z: Tensor
-                 _: Tensor
-                 z, _ = torch.min(z1, 2, keepdim=True)
-         else:
-             z: Tensor = torch.ge(1.0 - eventually((~z1).double(), self.right_time_bound - self.left_time_bound), 0.5)
-         return z
-
-     def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
-         z1: Tensor = self.child._quantitative(x[:, :, self.left_time_bound:], normalize)
-         # z1 = z1[:, :, self.left_time_bound:]
-         if self.unbound or self.right_unbound:
-             if self.adapt_unbound:
-                 z: Tensor
-                 _: Tensor
-                 z, _ = torch.cummin(torch.flip(z1, [2]), dim=2)
-                 z: Tensor = torch.flip(z, [2])
-             else:
-                 z: Tensor
-                 _: Tensor
-                 z, _ = torch.min(z1, 2, keepdim=True)
-         else:
-             z: Tensor = -eventually(-z1, self.right_time_bound - self.left_time_bound)
-         return z
-
-
- class Eventually(Node):
-     """Eventually node."""
-
-     def __init__(
-         self,
-         child: Node,
-         unbound: bool = False,
-         right_unbound: bool = False,
-         left_time_bound: int = 0,
-         right_time_bound: int = 1,
-         adapt_unbound: bool = True,
-     ) -> None:
-         super().__init__()
-         self.child: Node = child
-         self.unbound: bool = unbound
-         self.right_unbound: bool = right_unbound
-         self.left_time_bound: int = left_time_bound
-         self.right_time_bound: int = right_time_bound + 1
-         self.adapt_unbound: bool = adapt_unbound
-
-         if (self.unbound is False) and (self.right_unbound is False) and \
-                 (self.right_time_bound <= self.left_time_bound):
-             raise ValueError("Temporal thresholds are incorrect: the right bound must be strictly greater than the left bound")
-
-     def __str__(self) -> str:
-         s_left = "[" + str(self.left_time_bound) + ","
-         s_right = str(self.right_time_bound) if not self.right_unbound else "inf"
-         s0: str = s_left + s_right + "]" if not self.unbound else ""
-         s: str = "eventually" + s0 + " ( " + self.child.__str__() + " )"
-         return s
-
-     def time_depth(self) -> int:
-         if self.unbound:
-             return self.child.time_depth()
-         elif self.right_unbound:
-             return self.child.time_depth() + self.left_time_bound
-         else:
-             # diff = torch.le(torch.tensor([self.left_time_bound]), 0).float()
-             return self.child.time_depth() + self.right_time_bound - 1
-             # (self.right_time_bound - self.left_time_bound + 1) - diff
-
-     def _boolean(self, x: Tensor) -> Tensor:
-         z1: Tensor = self.child._boolean(x[:, :, self.left_time_bound:])
-         if self.unbound or self.right_unbound:
-             if self.adapt_unbound:
-                 z: Tensor
-                 _: Tensor
-                 z, _ = torch.cummax(torch.flip(z1, [2]), dim=2)
-                 z: Tensor = torch.flip(z, [2])
-             else:
-                 z: Tensor
-                 _: Tensor
-                 z, _ = torch.max(z1, 2, keepdim=True)
-         else:
-             z: Tensor = torch.ge(eventually(z1.double(), self.right_time_bound - self.left_time_bound), 0.5)
-         return z
-
-     def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
-         z1: Tensor = self.child._quantitative(x[:, :, self.left_time_bound:], normalize)
-         if self.unbound or self.right_unbound:
-             if self.adapt_unbound:
-                 z: Tensor
-                 _: Tensor
-                 z, _ = torch.cummax(torch.flip(z1, [2]), dim=2)
-                 z: Tensor = torch.flip(z, [2])
-             else:
-                 z: Tensor
-                 _: Tensor
-                 z, _ = torch.max(z1, 2, keepdim=True)
-         else:
-             z: Tensor = eventually(z1, self.right_time_bound - self.left_time_bound)
-         return z
-
-
- class Until(Node):
-     """Until node."""
-
-     def __init__(
-         self,
-         left_child: Node,
-         right_child: Node,
-         unbound: bool = False,
-         right_unbound: bool = False,
-         left_time_bound: int = 0,
-         right_time_bound: int = 1,
-     ) -> None:
-         super().__init__()
-         self.left_child: Node = left_child
-         self.right_child: Node = right_child
-         self.unbound: bool = unbound
-         self.right_unbound: bool = right_unbound
-         self.left_time_bound: int = left_time_bound
-         self.right_time_bound: int = right_time_bound + 1
-
-         if (self.unbound is False) and (self.right_unbound is False) and \
-                 (self.right_time_bound <= self.left_time_bound):
-             raise ValueError("Temporal thresholds are incorrect: the right bound must be strictly greater than the left bound")
-
-     def __str__(self) -> str:
-         s_left = "[" + str(self.left_time_bound) + ","
-         s_right = str(self.right_time_bound) if not self.right_unbound else "inf"
-         s0: str = s_left + s_right + "]" if not self.unbound else ""
-         s: str = "( " + self.left_child.__str__() + " until" + s0 + " " + self.right_child.__str__() + " )"
-         return s
-
-     def time_depth(self) -> int:
-         sum_children_depth: int = self.left_child.time_depth() + self.right_child.time_depth()
-         if self.unbound:
-             return sum_children_depth
-         elif self.right_unbound:
-             return sum_children_depth + self.left_time_bound
-         else:
-             # diff = torch.le(torch.tensor([self.left_time_bound]), 0).float()
-             return sum_children_depth + self.right_time_bound - 1
-             # (self.right_time_bound - self.left_time_bound + 1) - diff
-
-     def _boolean(self, x: Tensor) -> Tensor:
-         if self.unbound:
-             z1: Tensor = self.left_child._boolean(x)
-             z2: Tensor = self.right_child._boolean(x)
-             size: int = min(z1.size()[2], z2.size()[2])
-             z1: Tensor = z1[:, :, :size]
-             z2: Tensor = z2[:, :, :size]
-             z1_rep = torch.repeat_interleave(z1.unsqueeze(2), z1.unsqueeze(2).shape[-1], 2)
-             z1_tril = torch.tril(z1_rep.transpose(2, 3), diagonal=-1)
-             z1_triu = torch.triu(z1_rep)
-             z1_def = torch.cummin(z1_tril + z1_triu, dim=3)[0]
-
-             z2_rep = torch.repeat_interleave(z2.unsqueeze(2), z2.unsqueeze(2).shape[-1], 2)
-             z2_tril = torch.tril(z2_rep.transpose(2, 3), diagonal=-1)
-             z2_triu = torch.triu(z2_rep)
-             z2_def = z2_tril + z2_triu
-             z: Tensor = torch.max(torch.min(torch.cat([z1_def.unsqueeze(-1), z2_def.unsqueeze(-1)], dim=-1), dim=-1)[0],
-                                   dim=-1)[0]
-         elif self.right_unbound:
-             timed_until: Node = And(Globally(self.left_child, left_time_bound=0, right_time_bound=self.left_time_bound),
-                                     And(Eventually(self.right_child, right_unbound=True,
-                                                    left_time_bound=self.left_time_bound),
-                                         Eventually(Until(self.left_child, self.right_child, unbound=True),
-                                                    left_time_bound=self.left_time_bound, right_unbound=True)))
-             z: Tensor = timed_until._boolean(x)
-         else:
-             timed_until: Node = And(Globally(self.left_child, left_time_bound=0, right_time_bound=self.left_time_bound),
-                                     And(Eventually(self.right_child, left_time_bound=self.left_time_bound,
-                                                    right_time_bound=self.right_time_bound - 1),
-                                         Eventually(Until(self.left_child, self.right_child, unbound=True),
-                                                    left_time_bound=self.left_time_bound, right_unbound=True)))
-             z: Tensor = timed_until._boolean(x)
-         return z
-
-     def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
-         if self.unbound:
-             z1: Tensor = self.left_child._quantitative(x, normalize)
-             z2: Tensor = self.right_child._quantitative(x, normalize)
-             size: int = min(z1.size()[2], z2.size()[2])
-             z1: Tensor = z1[:, :, :size]
-             z2: Tensor = z2[:, :, :size]
-
-             # z1_rep = torch.repeat_interleave(z1.unsqueeze(2), z1.unsqueeze(2).shape[-1], 2)
-             # z1_tril = torch.tril(z1_rep.transpose(2, 3), diagonal=-1)
-             # z1_triu = torch.triu(z1_rep)
-             # z1_def = torch.cummin(z1_tril + z1_triu, dim=3)[0]
-
-             # z2_rep = torch.repeat_interleave(z2.unsqueeze(2), z2.unsqueeze(2).shape[-1], 2)
-             # z2_tril = torch.tril(z2_rep.transpose(2, 3), diagonal=-1)
-             # z2_triu = torch.triu(z2_rep)
-             # z2_def = z2_tril + z2_triu
-             # z: Tensor = torch.max(torch.min(torch.cat([z1_def.unsqueeze(-1), z2_def.unsqueeze(-1)], dim=-1), dim=-1)[0],
-             #                       dim=-1)[0]
-             z: Tensor = torch.cat([torch.max(torch.min(
-                 torch.cat([torch.cummin(z1[:, :, t:].unsqueeze(-1), dim=2)[0], z2[:, :, t:].unsqueeze(-1)], dim=-1),
-                 dim=-1)[0], dim=2, keepdim=True)[0] for t in range(size)], dim=2)
-         elif self.right_unbound:
-             timed_until: Node = And(Globally(self.left_child, left_time_bound=0, right_time_bound=self.left_time_bound),
-                                     And(Eventually(self.right_child, right_unbound=True,
-                                                    left_time_bound=self.left_time_bound),
-                                         Eventually(Until(self.left_child, self.right_child, unbound=True),
-                                                    left_time_bound=self.left_time_bound, right_unbound=True)))
-             z: Tensor = timed_until._quantitative(x, normalize=normalize)
-         else:
-             timed_until: Node = And(Globally(self.left_child, left_time_bound=0, right_time_bound=self.left_time_bound),
-                                     And(Eventually(self.right_child, left_time_bound=self.left_time_bound,
-                                                    right_time_bound=self.right_time_bound - 1),
-                                         Eventually(Until(self.left_child, self.right_child, unbound=True),
-                                                    left_time_bound=self.left_time_bound, right_unbound=True)))
-             z: Tensor = timed_until._quantitative(x, normalize=normalize)
-         return z
-
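- # Editor's usage sketch (assumption, not part of the original file):
- # build a formula tree and evaluate its semantics on a batch of signals.
- # >>> phi = Eventually(Atom(var_index=0, threshold=0.5, lte=False),
- # ...                  left_time_bound=0, right_time_bound=5)
- # >>> x = torch.randn(10, 1, 50)        # 10 signals, 1 variable, 50 time points
- # >>> rob = phi.quantitative(x)         # robustness at t=0, shape (10,)
- # >>> sat = phi.boolean(x)              # boolean satisfaction at t=0
-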
-
- # from anchor_set_generation import anchorGeneration
-
- import re
- import json
- from typing import Any, Dict, List, Optional, Tuple, Union
- from transformers import PreTrainedTokenizer
- from transformers.utils import logging
-
- logger = logging.get_logger(__name__)
-
- #### utils ####
-
- def load_pickle(path):
-     with open(path, 'rb') as f:
-         x = pickle.load(f)
-     return x
-
-
- def dump_pickle(name, thing):
-     with open(name + '.pickle', 'wb') as f:
-         pickle.dump(thing, f)
-
- def from_string_to_formula(st):
-     root_arity = 2 if st.startswith('(') else 1
-     st_split = st.split()
-     if root_arity <= 1:
-         root_op_str = copy.deepcopy(st_split[0])
-         if root_op_str.startswith('x'):
-             atom_sign = st_split[1] == '<='
-             # use [2:] so that variable indices with more than one digit parse correctly
-             root_phi = Atom(var_index=int(st_split[0][2:]), lte=atom_sign, threshold=float(st_split[2]))
-             return root_phi
-         else:
-             assert (root_op_str.startswith('not') or root_op_str.startswith('eventually')
-                     or root_op_str.startswith('always'))
-             current_st = copy.deepcopy(st_split[2:-1])
-             if root_op_str == 'not':
-                 root_phi = Not(child=from_string_to_formula(' '.join(current_st)))
-             elif root_op_str.startswith('eventually'):
-                 unbound, right_unbound, left_time_bound, right_time_bound = set_time_thresholds(root_op_str)
-                 root_phi = Eventually(child=from_string_to_formula(' '.join(current_st)), unbound=unbound,
-                                       right_unbound=right_unbound, left_time_bound=left_time_bound,
-                                       right_time_bound=right_time_bound)
-             else:
-                 unbound, right_unbound, left_time_bound, right_time_bound = set_time_thresholds(root_op_str)
-                 root_phi = Globally(child=from_string_to_formula(' '.join(current_st)), unbound=unbound,
-                                     right_unbound=right_unbound, left_time_bound=left_time_bound,
-                                     right_time_bound=right_time_bound)
-     else:
-         # 1 - delete everything which is contained in other sets of parentheses (if any)
-         current_st = copy.deepcopy(st_split[1:-1])
-         if '(' in current_st:
-             par_queue = deque()
-             par_idx_list = []
-             for i, sub in enumerate(current_st):
-                 if sub == '(':
-                     par_queue.append(i)
-                 elif sub == ')':
-                     par_idx_list.append(tuple([par_queue.pop(), i]))
-             # open_par_idx, close_par_idx = [current_st.index(p) for p in ['(', ')']]
-             # union of parentheses ranges --> from these we may extract the substrings that become the children
-             children_range = []
-             for begin, end in sorted(par_idx_list):
-                 if children_range and children_range[-1][1] >= begin - 1:
-                     children_range[-1][1] = max(children_range[-1][1], end)
-                 else:
-                     children_range.append([begin, end])
-             n_children = len(children_range)
-             assert (n_children in [1, 2])
-             if n_children == 1:
-                 # one of the children is a variable --> need to locate it
-                 var_child_idx = 1 if children_range[0][0] <= 1 else 0  # 0 is left child, 1 is right child
-                 if children_range[0][0] != 0 and current_st[children_range[0][0] - 1][0:2] in ['no', 'ev', 'al']:
-                     children_range[0][0] -= 1
-                 left_child_str = current_st[:3] if var_child_idx == 0 else \
-                     current_st[children_range[0][0]:children_range[0][1] + 1]
-                 right_child_str = current_st[-3:] if var_child_idx == 1 else \
-                     current_st[children_range[0][0]:children_range[0][1] + 1]
-                 root_op_str = current_st[children_range[0][1] + 1] if var_child_idx == 1 else \
-                     current_st[children_range[0][0] - 1]
-                 assert (root_op_str[:2] in ['an', 'or', 'un'])
-             else:
-                 if children_range[0][0] != 0 and current_st[children_range[0][0] - 1][0:2] in ['no', 'ev', 'al']:
-                     children_range[0][0] -= 1
-                 if current_st[children_range[1][0] - 1][0:2] in ['no', 'ev', 'al']:
-                     children_range[1][0] -= 1
-                 # if there are two children, with parentheses, the element in the middle is the root
-                 root_op_str = current_st[children_range[0][1] + 1]
-                 assert (root_op_str[:2] in ['an', 'or', 'un'])
-                 left_child_str = current_st[children_range[0][0]:children_range[0][1] + 1]
-                 right_child_str = current_st[children_range[1][0]:children_range[1][1] + 1]
-         else:
-             # no parentheses means that both children are variables
-             left_child_str = current_st[:3]
-             right_child_str = current_st[-3:]
-             root_op_str = current_st[3]
-         left_child_str = ' '.join(left_child_str)
-         right_child_str = ' '.join(right_child_str)
-         if root_op_str == 'and':
-             root_phi = And(left_child=from_string_to_formula(left_child_str),
-                            right_child=from_string_to_formula(right_child_str))
-         elif root_op_str == 'or':
-             root_phi = Or(left_child=from_string_to_formula(left_child_str),
-                           right_child=from_string_to_formula(right_child_str))
-         else:
-             unbound, right_unbound, left_time_bound, right_time_bound = set_time_thresholds(root_op_str)
-             root_phi = Until(left_child=from_string_to_formula(left_child_str),
-                              right_child=from_string_to_formula(right_child_str),
-                              unbound=unbound, right_unbound=right_unbound, left_time_bound=left_time_bound,
-                              right_time_bound=right_time_bound)
-     return root_phi
-
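-
- # `set_time_thresholds` is called above but was never defined in this file.
- # The following is an editor's reconstruction (an assumption, not the original
- # helper): it parses operator strings such as "eventually", "always[2,5]" or
- # "until[0,inf]", i.e. the format produced by the __str__ methods above.
- def set_time_thresholds(op_str):
-     if '[' not in op_str:
-         # no interval annotation: fully unbounded operator
-         return True, False, 0, 0
-     left_str, right_str = op_str[op_str.index('[') + 1:op_str.index(']')].split(',')
-     if right_str == 'inf':
-         return False, True, int(left_str), 1
-     # __str__ prints right_time_bound after the +1 shift applied in __init__,
-     # so subtract 1 to round-trip a printed formula back into constructor arguments.
-     return False, False, int(left_str), int(right_str) - 1
-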
-
- def load_json(path: str) -> Union[Dict, List]:
-     """
-     Load a JSON file from the given path.
-
-     Args:
-         path (str): The path to the JSON file to be loaded.
-
-     Returns:
-         Union[Dict, List]: The parsed content of the JSON file, which could be a dictionary or a list.
-     """
-     with open(path, "r") as f:
-         return json.load(f)
-
- #### phis_generator ####
-
- class StlGenerator:
-     def __init__(
-         self,
-         leaf_prob: float = 0.3,
-         inner_node_prob: list = None,
-         threshold_mean: float = 0.0,
-         threshold_sd: float = 1.0,
-         unbound_prob: float = 0.1,
-         right_unbound_prob: float = 0.2,
-         time_bound_max_range: float = 20,
-         adaptive_unbound_temporal_ops: bool = True,
-         max_timespan: int = 100,
-     ):
-         """
-         leaf_prob
-             probability of generating a leaf (always zero for the root)
-         node_types = ["not", "and", "or", "always", "eventually", "until"]
-             inner node types
-         inner_node_prob
-             probability vector for the different types of internal nodes
-         threshold_mean, threshold_sd
-             mean and std of the normal distribution of the thresholds of atoms
-         unbound_prob
-             probability of a temporal operator having a time bound of the type [0, infty]
-         time_bound_max_range
-             maximum value of the time span of a temporal operator (i.e. max value of t in [0, t])
-         adaptive_unbound_temporal_ops
-             if true, unbounded temporal operators are computed from the current point to the end
-             of the signal; otherwise they are evaluated only at time zero.
-         max_timespan
-             maximum time depth of a formula.
-         """
-
-         # Address the mutability of default arguments
-         if inner_node_prob is None:
-             inner_node_prob = [0.166, 0.166, 0.166, 0.17, 0.166, 0.166]
-
-         self.leaf_prob = leaf_prob
-         self.inner_node_prob = inner_node_prob
-         self.threshold_mean = threshold_mean
-         self.threshold_sd = threshold_sd
-         self.unbound_prob = unbound_prob
-         self.right_unbound_prob = right_unbound_prob
-         self.time_bound_max_range = time_bound_max_range
-         self.adaptive_unbound_temporal_ops = adaptive_unbound_temporal_ops
-         self.node_types = ["not", "and", "or", "always", "eventually", "until"]
-         self.max_timespan = max_timespan
-
-     def sample(self, nvars):
-         """
-         Samples a random formula with the distribution defined by the class instance parameters.
-
-         Parameters
-         ----------
-         nvars : number of variables of the input signals
-             how many variables the formula is expected to consider.
-
-         Returns
-         -------
-         Node
-             A random formula.
-         """
-         return self._sample_internal_node(nvars)
-
-     def bag_sample(self, bag_size, nvars):
-         """
-         Samples a bag of bag_size formulae.
-
-         Parameters
-         ----------
-         bag_size : INT
-             number of formulae.
-         nvars : INT
-             number of vars in formulae.
-
-         Returns
-         -------
-         a list of formulae.
-         """
-         formulae = []
-         for _ in range(bag_size):
-             phi = self.sample(nvars)
-             formulae.append(phi)
-         return formulae
-
-     def _sample_internal_node(self, nvars):
-         # Declare & dummy-assign "idiom"
-         node: Union[None, Node]
-         node = None
-         # choose node type
-         nodetype = rnd.choice(self.node_types, p=self.inner_node_prob)
-         while True:
-             if nodetype == "not":
-                 n = self._sample_node(nvars)
-                 node = Not(n)
-             elif nodetype == "and":
-                 n1 = self._sample_node(nvars)
-                 n2 = self._sample_node(nvars)
-                 node = And(n1, n2)
-             elif nodetype == "or":
-                 n1 = self._sample_node(nvars)
-                 n2 = self._sample_node(nvars)
-                 node = Or(n1, n2)
-             elif nodetype == "always":
-                 n = self._sample_node(nvars)
-                 unbound, right_unbound, left_time_bound, right_time_bound = self._get_temporal_parameters()
-                 node = Globally(
-                     n, unbound, right_unbound, left_time_bound, right_time_bound, self.adaptive_unbound_temporal_ops
-                 )
-             elif nodetype == "eventually":
-                 n = self._sample_node(nvars)
-                 unbound, right_unbound, left_time_bound, right_time_bound = self._get_temporal_parameters()
-                 node = Eventually(
-                     n, unbound, right_unbound, left_time_bound, right_time_bound, self.adaptive_unbound_temporal_ops
-                 )
-             elif nodetype == "until":
-                 n1 = self._sample_node(nvars)
-                 n2 = self._sample_node(nvars)
-                 unbound, right_unbound, left_time_bound, right_time_bound = self._get_temporal_parameters()
-                 node = Until(
-                     n1, n2, unbound, right_unbound, left_time_bound, right_time_bound
-                 )
-
-             if (node is not None) and (node.time_depth() < self.max_timespan):
-                 return node
-
-     def _sample_node(self, nvars):
-         if rnd.rand() < self.leaf_prob:
-             # sample a leaf
-             var, thr, lte = self._get_atom(nvars)
-             return Atom(var, thr, lte)
-         else:
-             return self._sample_internal_node(nvars)
-
-     def _get_temporal_parameters(self):
-         if rnd.rand() < self.unbound_prob:
-             return True, False, 0, 0
-         elif rnd.rand() < self.right_unbound_prob:
-             return False, True, rnd.randint(self.time_bound_max_range), 1
-         else:
-             left_bound = rnd.randint(self.time_bound_max_range)
-             return False, False, left_bound, rnd.randint(left_bound, self.time_bound_max_range) + 1
-
-     def _get_atom(self, nvars):
-         variable = rnd.randint(nvars)
-         lte = rnd.rand() > 0.5
-         threshold = rnd.normal(self.threshold_mean, self.threshold_sd)
-         return variable, threshold, lte
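-
- # Editor's usage sketch (assumption, not part of the original file):
- # >>> sampler = StlGenerator(leaf_prob=0.4)
- # >>> phis = sampler.bag_sample(bag_size=5, nvars=2)
- # >>> [str(phi) for phi in phis]   # human-readable STL formulae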
-
- #### traj_measure ####
-
- class Measure:
-     def sample(self, samples=100000, varn=2, points=100):
-         # Must be overridden
-         pass
-
-
- class BaseMeasure(Measure):
-     def __init__(
-         self, mu0=0.0, sigma0=1.0, mu1=0.0, sigma1=1.0, q=0.1, q0=0.5, device="cpu"
-     ):
-         """
-         Parameters
-         ----------
-         mu0 : mean of the normal distribution of the initial state, optional
-             The default is 0.0.
-         sigma0 : standard deviation of the normal distribution of the initial state, optional
-             The default is 1.0.
-         mu1 : DOUBLE, optional
-             mean of the normal distribution of the total variation. The default is 0.0.
-         sigma1 : standard deviation of the normal distribution of the total variation, optional
-             The default is 1.0.
-         q : DOUBLE, optional
-             probability of a change of sign in the derivative. The default is 0.1.
-         q0 : DOUBLE, optional
-             probability of the initial sign of the derivative. The default is 0.5.
-         device : 'cpu' or 'cuda', optional
-             device on which to run the algorithm. The default is 'cpu'.
-
-         Returns
-         -------
-         None.
-         """
-         self.mu0 = mu0
-         self.sigma0 = sigma0
-         self.mu1 = mu1
-         self.sigma1 = sigma1
-         self.q = q
-         self.q0 = q0
-         self.device = device
-
-     def sample(self, samples=100000, varn=2, points=100):
-         """
-         Samples a set of trajectories from the basic measure space, with parameters
-         passed to the sampler.
-
-         Parameters
-         ----------
-         points : INT, optional
-             number of points per trajectory, including the initial one. The default is 100.
-         samples : INT, optional
-             number of trajectories. The default is 100000.
-         varn : INT, optional
-             number of variables per trajectory. The default is 2.
-
-         Returns
-         -------
-         signal : samples x varn x points double pytorch tensor
-             The sampled signals.
-         """
-         if self.device == "cuda" and not torch.cuda.is_available():
-             raise RuntimeError("GPU card or CUDA library not available!")
-
-         # generate uniform random numbers
-         signal = torch.rand(samples, varn, points, device=self.device)
-         # first point is special - set to zero for the moment, and set one point to 1
-         signal[:, :, 0] = 0.0
-         signal[:, :, -1] = 1.0
-         # sorting each trajectory
-         signal, _ = torch.sort(signal, 2)
-         # computing increments and storing them in points 1 to end
-         signal[:, :, 1:] = signal[:, :, 1:] - signal[:, :, :-1]
-         # generate initial state, according to a normal distribution
-         # (device argument added so this also works when self.device == "cuda")
-         signal[:, :, 0] = self.mu0 + self.sigma0 * torch.randn(signal[:, :, 0].size(), device=self.device)
-
-         # sampling change signs from a Bernoulli in {-1, 1}
-         derivs = (1 - self.q) * torch.ones(samples, varn, points, device=self.device)
-         derivs = 2 * torch.bernoulli(derivs) - 1
-         # sampling initial derivative
-         derivs[:, :, 0] = self.q0
-         derivs[:, :, 0] = 2 * torch.bernoulli(derivs[:, :, 0]) - 1
-         # taking the cumulative product along axis 2
-         derivs = torch.cumprod(derivs, 2)
-
-         # sampling total variation
-         totvar = torch.pow(
-             self.mu1 + self.sigma1 * torch.randn(samples, varn, 1, device=self.device),
-             2,
-         )
-         # multiplying total variation and derivatives and making the initial point non-invasive
-         derivs = derivs * totvar
-         derivs[:, :, 0] = 1.0
-
-         # computing trajectories by multiplying and then doing a cumulative sum
-         signal = signal * derivs
-         signal = torch.cumsum(signal, 2)
-         return signal
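-
- # Editor's usage sketch (assumption, not part of the original file):
- # >>> mu = BaseMeasure(device="cpu")
- # >>> trajs = mu.sample(samples=100, varn=2, points=50)   # shape (100, 2, 50)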
-
- #### kernel ####
-
- class StlKernel:
-     def __init__(
-         self,
-         measure,
-         normalize=True,
-         exp_kernel=True,
-         sigma2=0.2,  # 0.5 works better; it was initially 0.2
-         integrate_time=False,
-         samples=100000,
-         varn=2,
-         points=100,
-         boolean=False,
-         signals=None,
-     ):
-         self.traj_measure = measure
-         self.exp_kernel = exp_kernel
-         self.normalize = normalize
-         self.sigma2 = sigma2
-         self.samples = samples
-         self.varn = varn
-         self.points = points
-         self.integrate_time = integrate_time
-         if signals is not None:
-             self.signals = signals
-         else:
-             self.signals = measure.sample(points=points, samples=samples, varn=varn)
-         self.boolean = boolean
-
-     def compute(self, phi1, phi2):
-         return self.compute_one_one(phi1, phi2)
-
-     def compute_one_one(self, phi1, phi2):
-         phis1: list = [phi1]
-         phis2: list = [phi2]
-         ker = self.compute_bag_bag(phis1, phis2)
-         return ker[0, 0]
-
-     def compute_bag(self, phis, return_robustness=True):
-         if self.integrate_time:
-             rhos, selfk, len0 = self._compute_robustness_time(phis)
-             kernel_matrix = self._compute_kernel_time(
-                 rhos, rhos, selfk, selfk, len0, len0
-             )
-         else:
-             rhos, selfk = self._compute_robustness_no_time(phis)
-             kernel_matrix = self._compute_kernel_no_time(rhos, rhos, selfk, selfk)
-             len0 = None
-         if return_robustness:
-             return kernel_matrix.cpu(), rhos, selfk, len0
-         else:
-             return kernel_matrix.cpu()
-
-     def compute_one_bag(self, phi1, phis2, return_robustness=False):
-         phis1: list = [phi1]
-         return self.compute_bag_bag(phis1, phis2, return_robustness)
-
-     def compute_bag_bag(self, phis1, phis2, return_robustness=False):
-         if self.integrate_time:
-             rhos1, selfk1, len1 = self._compute_robustness_time(phis1)
-             rhos2, selfk2, len2 = self._compute_robustness_time(phis2)
-             kernel_matrix = self._compute_kernel_time(
-                 rhos1, rhos2, selfk1, selfk2, len1, len2
-             )
-         else:
-             rhos1, selfk1 = self._compute_robustness_no_time(phis1)
-             rhos2, selfk2 = self._compute_robustness_no_time(phis2)
-             len1, len2 = [None, None]
-             kernel_matrix = self._compute_kernel_no_time(rhos1, rhos2, selfk1, selfk2)
-         if return_robustness:
-             return kernel_matrix.cpu(), rhos1, rhos2, selfk1, selfk2, len1, len2
-         else:
-             return kernel_matrix.cpu()
-
-     def compute_one_from_robustness(self, phi, rhos, rho_self, lengths=None, return_robustness=False):
-         phis: list = [phi]
-         return self.compute_bag_from_robustness(phis, rhos, rho_self, lengths, return_robustness)
-
-     def compute_bag_from_robustness(self, phis, rhos, rho_self, lengths=None, return_robustness=False):
-         if self.integrate_time:
-             rhos1, selfk1, len1 = self._compute_robustness_time(phis)
-             kernel_matrix = self._compute_kernel_time(
-                 rhos1, rhos, selfk1, rho_self, len1, lengths
-             )
-         else:
-             rhos1, selfk1 = self._compute_robustness_no_time(phis)
-             len1 = None
-             kernel_matrix = self._compute_kernel_no_time(rhos1, rhos, selfk1, rho_self)
-         if return_robustness:
-             return kernel_matrix.cpu(), rhos1, selfk1, len1
-         else:
-             return kernel_matrix.cpu()
-
-     def _compute_robustness_time(self, phis):
-         # The `def` line above is an editor's reconstruction: in the original file this
-         # body sat, unreachable, after the returns of compute_bag_from_robustness, while
-         # _compute_robustness_time was called but never defined.
-         n = self.samples
-         p = self.points
-         k = len(phis)
-         rhos = torch.zeros((k, n, p), device="cpu")
-         lengths = torch.zeros(k)
-         self_kernels = torch.zeros((k, 1))
-         for i, phi in enumerate(phis):
-             if self.boolean:
-                 rho = phi.boolean(self.signals, evaluate_at_all_times=True).float()
-                 rho[rho == 0.0] = -1.0
-             else:
-                 rho = phi.quantitative(self.signals, evaluate_at_all_times=True)
-             actual_p = rho.size()[2]
-             rho = rho.reshape(n, actual_p).cpu()
-             rhos[i, :, :actual_p] = rho
-             lengths[i] = actual_p
-             self_kernels[i] = torch.tensordot(
-                 rho.reshape(1, n, -1), rho.reshape(1, n, -1), dims=[[1, 2], [1, 2]]
-             ) / (actual_p * n)
-         return rhos, self_kernels, lengths
-
-     def _compute_robustness_no_time(self, phis):
-         n = self.samples
-         k = len(phis)
-         rhos = torch.zeros((k, n), device=self.traj_measure.device)
-         self_kernels = torch.zeros((k, 1), device=self.traj_measure.device)
-         for i, phi in enumerate(phis):
-             if self.boolean:
-                 rho = phi.boolean(self.signals, evaluate_at_all_times=False).float()
-                 rho[rho == 0.0] = -1.0
-             else:
-                 rho = phi.quantitative(self.signals, evaluate_at_all_times=False)
-             self_kernels[i] = rho.dot(rho) / n
-             rhos[i, :] = rho
-         return rhos, self_kernels
-
-     def _compute_kernel_time(self, rhos1, rhos2, selfk1, selfk2, len1, len2):
-         kernel_matrix = torch.tensordot(rhos1, rhos2, [[1, 2], [1, 2]])
-         length_normalizer = self._compute_trajectory_length_normalizer(len1, len2)
-         kernel_matrix = kernel_matrix * length_normalizer / self.samples
-         if self.normalize:
-             kernel_matrix = self._normalize(kernel_matrix, selfk1, selfk2)
-         if self.exp_kernel:
-             kernel_matrix = self._exponentiate(kernel_matrix, selfk1, selfk2)
-         return kernel_matrix
-
-     def _compute_kernel_no_time(self, rhos1, rhos2, selfk1, selfk2):
-         kernel_matrix = torch.tensordot(rhos1, rhos2, [[1], [1]])
-         kernel_matrix = kernel_matrix / self.samples
-         if self.normalize:
-             kernel_matrix = self._normalize(kernel_matrix, selfk1, selfk2)
-         if self.exp_kernel:
-             kernel_matrix = self._exponentiate(kernel_matrix, selfk1, selfk2)
-         return kernel_matrix
-
-     @staticmethod
-     def _normalize(kernel_matrix, selfk1, selfk2):
-         normalize = torch.sqrt(torch.matmul(selfk1, torch.transpose(selfk2, 0, 1)))
-         kernel_matrix = kernel_matrix / normalize
-         return kernel_matrix
-
-     def _exponentiate(self, kernel_matrix, selfk1, selfk2, sigma2=None):
-         if sigma2 is None:
-             sigma2 = self.sigma2
-         if self.normalize:
-             # selfk is (1.0^2 + 1.0^2)
-             selfk = 2.0
-         else:
-             k1 = selfk1.size()[0]
-             k2 = selfk2.size()[0]
-             selfk = (selfk1 * selfk1).repeat(1, k2) + torch.transpose(
-                 selfk2 * selfk2, 0, 1
-             ).repeat(k1, 1)
-         return torch.exp(-(selfk - 2 * kernel_matrix) / (2 * sigma2))
-
-     @staticmethod
-     def _compute_trajectory_length_normalizer(len1, len2):
-         k1 = len1.size()[0]
-         k2 = len2.size()[0]
-         y1 = len1.reshape(-1, 1)
-         y1 = y1.repeat(1, k2)
-         y2 = len2.repeat(k1, 1)
-         return 1.0 / torch.min(y1, y2)
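-
- # Editor's usage sketch (assumption, not part of the original file):
- # >>> mu = BaseMeasure()
- # >>> kernel = StlKernel(mu, samples=1000, varn=2, points=100)
- # >>> k = kernel.compute(Atom(0, 0.0, lte=True), Atom(1, 0.5, lte=False))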
-
- class GramMatrix:
-     def __init__(self, kernel, formulae, store_robustness=True, sample=False, sampler=None, bag_size=None):
-         self.kernel = kernel
-         self.formulae_list = formulae
-         # if the kernel is computed from the robustness at time zero only,
-         # we store the robustness for each formula and each sample
-         # to speed up computation later
-         self.store_robustness = store_robustness
-         self.dim = len(self.formulae_list) if not bag_size else int(bag_size)
-         self.sample = sample  # whether to generate formulae in a controlled manner
-         if self.sample:
-             self.t = 0.99 if self.kernel.boolean else 0.85
-         self.sampler = sampler  # stl formulae generator
-         self._compute_gram_matrix()
-
-     def _compute_gram_matrix(self):
-         if self.sample:
-             gram = torch.zeros(self.dim, self.dim)
-             rhos = torch.zeros((self.dim, self.kernel.samples), device=self.kernel.traj_measure.device) if \
-                 not self.kernel.integrate_time else torch.zeros((self.dim, self.kernel.samples, self.kernel.points),
-                                                                 device=self.kernel.traj_measure.device)
-             lengths = torch.zeros(self.dim) if self.kernel.integrate_time else np.zeros(self.dim)
-             kernels = torch.zeros((self.dim, 1), device=self.kernel.traj_measure.device)
-             phis = [self.sampler.sample(nvars=self.kernel.varn)]
-             gram[0, :1], rhos[0], kernels[0, :], lengths[0] = self.kernel.compute_bag(phis, return_robustness=True)
-             while len(phis) < self.dim:
-                 i = len(phis)
-                 phi = self.sampler.sample(nvars=self.kernel.varn)
-                 gram[i, :i], rhos[i], kernels[i, :], lengths[i] = self.kernel.compute_one_from_robustness(
-                     phi, rhos[:i, :], kernels[:i, :], lengths[:i], return_robustness=True)
-                 if torch.sum(gram[i, :i + 1] >= self.t) < 3:
-                     phis.append(phi)
-                     gram[:i, i] = gram[i, :i]
-                     gram[i, i] = kernels[i, :]
-
-             self.formulae_list = phis
-             self.gram = gram.cpu()
-             self.robustness = rhos if self.store_robustness else None
-             self.self_kernels = kernels if self.store_robustness else None
-             self.robustness_lengths = lengths if self.store_robustness else None
-         else:
-             if self.store_robustness:
-                 k_matrix, rhos, selfk, len0 = self.kernel.compute_bag(
-                     self.formulae_list, return_robustness=True
-                 )
-                 self.gram = k_matrix
-                 self.robustness = rhos
-                 self.self_kernels = selfk
-                 self.robustness_lengths = len0
-             else:
-                 self.gram = self.kernel.compute_bag(
-                     self.formulae_list, return_robustness=False
-                 )
-                 self.robustness = None
-                 self.self_kernels = None
-                 self.robustness_lengths = None
-
-     def compute_kernel_vector(self, phi):
-         if self.store_robustness:
-             return self.kernel.compute_one_from_robustness(
-                 phi, self.robustness, self.self_kernels, self.robustness_lengths
-             )
-         else:
-             return self.kernel.compute_one_bag(phi, self.formulae_list)
-
-     def compute_bag_kernel_vector(self, phis, generate_phis=False, bag_size=None):
-         if generate_phis:
-             gram_test = torch.zeros(bag_size, self.dim)  # self.dim, bag_size
-             rhos_test = torch.zeros((bag_size, self.kernel.samples), device=self.kernel.traj_measure.device) if \
-                 not self.kernel.integrate_time else torch.zeros((bag_size, self.kernel.samples, self.kernel.points),
-                                                                 device=self.kernel.traj_measure.device)
-             lengths_test = torch.zeros(bag_size) if self.kernel.integrate_time else np.zeros(bag_size)
-             kernels_test = torch.zeros((bag_size, 1), device=self.kernel.traj_measure.device)
-             phi_test = []
-             while len(phi_test) < bag_size:
-                 i = len(phi_test)
-                 phi = self.sampler.sample(nvars=self.kernel.varn)
-                 if self.store_robustness:
-                     gram_test[i, :], rhos_test[i], kernels_test[i, :], lengths_test[i] = \
-                         self.kernel.compute_one_from_robustness(phi, self.robustness, self.self_kernels,
-                                                                 self.robustness_lengths, return_robustness=True)
-                 else:
-                     gram_test[i, :], rhos_test[i], _, kernels_test[i, :], _, lengths_test[i], _ = \
-                         self.kernel.compute_one_bag(phi, self.formulae_list, return_robustness=True)
-                 if not ((rhos_test[i] > 0).all() or (rhos_test[i] < 0).all()):
-                     phi_test.append(phi)
-             return phi_test, gram_test.cpu()
-         else:
-             if self.store_robustness:
-                 return self.kernel.compute_bag_from_robustness(
-                     phis, self.robustness, self.self_kernels, self.robustness_lengths
-                 )
-             else:
-                 return self.kernel.compute_bag_bag(phis, self.formulae_list)
-
-     def invert_regularized(self, alpha):
-         regularizer = abs(pow(10, alpha)) * torch.eye(self.dim)
-         return torch.inverse(self.gram + regularizer)
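-
- # Editor's usage sketch (assumption; `kernel`, `phis` and `sampler` as in the sketches above):
- # >>> gm = GramMatrix(kernel, formulae=phis, store_robustness=True)
- # >>> K = gm.gram                                   # Gram matrix over the formulae bag
- # >>> k_vec = gm.compute_kernel_vector(sampler.sample(nvars=2))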
-
- #### anchor_generation ####
-
- def anchorGeneration(diff_init=False,  # whether formulae should be semantically different by construction
-                      embed_dim: int = 30,  # embedding dimension, i.e. number of generated formulae in the anchor set
-                      n_vars: int = 3,  # dimension of the input signal (3D in this case)
-                      leaf_prob: float = 0.4,  # complexity of the generated formula
-                      cosine_similarity_threshold: float = 0.8  # if the cosine similarity of two formulae exceeds this threshold, discard one of the two
-                      ) -> str:
-
-     # initialize STL formula generator
-     sampler = StlGenerator(leaf_prob)
-
-     # effective anchor set generation
-     if diff_init:
-
-         # initialize the anchor set with a randomly sampled formula
-         diff_anchor_set = [sampler.sample(nvars=n_vars)]
-
-         device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-         mu = BaseMeasure(device=device)
-
-         # generate a set of random signals acting as a tester for the formulae
-         signals = mu.sample(samples=10000, varn=n_vars)
-
-         # compute robustness values for the initial set of formulae in the anchor set
-         anchor_rob_vectors = torch.cat([phi.quantitative(signals, normalize=True).unsqueeze(0) for phi in diff_anchor_set], 0)
-
-         while len(diff_anchor_set) < embed_dim:
-             # sample the 'remaining' formulae to reach the desired number of `embed_dim` formulae
-             candidate_anchors = sampler.bag_sample(embed_dim - len(diff_anchor_set), nvars=n_vars)
-
-             # compute robustness of candidate anchor formulae on the same signals as the previous anchor set
-             candidate_robs = torch.cat([phi.quantitative(signals, normalize=True).unsqueeze(0) for phi in candidate_anchors], 0)
-
-             # compute cosine similarity between the current anchor set and the candidate new formulae
-             # (F.normalize used here; the original called a bare `normalize`, presumably torch.nn.functional.normalize)
-             cos_simil = torch.tril(F.normalize(candidate_robs) @ F.normalize(anchor_rob_vectors).t(), diagonal=-1)
-
-             # check which formulae are similar (i.e. cosine similarity greater than the threshold) w.r.t. current anchors
-             # NOTE: ask Gaia whether negative cosine similarities should be suppressed with an absolute value or not!
-             similar_idx = [torch.where(cos_simil[r, :] > cosine_similarity_threshold)[0].tolist() for r in range(cos_simil.shape[0])]
-
-             # keep only those which are semantically distant
-             keep_idx = list(set(np.arange(len(candidate_anchors)).tolist()).difference(set([i for sublist in similar_idx for i in sublist])))
-
-             diff_anchor_set += [copy.deepcopy(candidate_anchors[i]) for i in keep_idx]
-
-             # Convert keep_idx to a tensor on the same device as candidate_robs
-             keep_idx_tensor = torch.tensor(keep_idx, device=candidate_robs.device)
-
-             # Use index_select to pick the relevant rows
-             selected_robs = torch.index_select(candidate_robs, 0, keep_idx_tensor)
-
-             # Concatenate on the same device
-             anchor_rob_vectors = torch.cat([anchor_rob_vectors, copy.deepcopy(selected_robs)], dim=0)
-
-         anchor_set = diff_anchor_set[:embed_dim]
-
-     else:
-         anchor_set = sampler.bag_sample(bag_size=embed_dim, nvars=n_vars)
-
-     filename = f'anchor_set_no_diff_{embed_dim}_dim'
-     dump_pickle(filename, anchor_set)
-     return filename
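-
- # Editor's usage sketch (assumption, not part of the original file):
- # >>> path = anchorGeneration(diff_init=True, embed_dim=30, n_vars=3)
- # >>> anchor_set = load_pickle(path + '.pickle')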
-
1339
- ####
1340
-
1341
- """
1342
- A custom tokenizer class that extends `PreTrainedTokenizer` to handle a specific vocabulary and tokenization process.
1343
- This tokenizer can load a vocabulary from a JSON file, tokenize text, convert tokens to IDs,
1344
- and handle padding and special tokens.
1345
- """
1346
-
1347
- def __init__(self, vocab_path: str, unk_token: str = "unk", pad_token: str = "pad",
1348
- bos_token: str = "/s", eos_token: str = "s", model_max_length = 512, *args, **kwargs):
1349
- """
1350
- Initializes the STLTokenizer with a given vocabulary and special tokens.
1351
- Args:
1352
- vocab_path (str): The path to the JSON file containing the vocabulary.
1353
- unk_token (str, optional): The token used for unknown words. Defaults to "unk".
1354
- pad_token (str, optional): The token used for padding. Defaults to "pad".
1355
- bos_token (str, optional): The token used for the beginning of a sequence. Defaults to "/s".
1356
- eos_token (str, optional): The token used for the end of a sequence. Defaults to "s".
1357
- """
1358
- self.vocab = load_json(vocab_path)
1359
- self.unk_token = unk_token
1360
- self.pad_token = pad_token
1361
- self.bos_token = bos_token
1362
- self.eos_token = eos_token
1363
- self.model_max_length = model_max_length
1364
- self.id_to_token = {v: k for k, v in self.vocab.items()} # Reverse mapping
1365
- super().__init__(unk_token=unk_token, pad_token=pad_token, bos_token=bos_token, eos_token=eos_token,
1366
- model_max_length=model_max_length, *args, **kwargs)
1367
-
1368
- @property
1369
- def vocab_size(self) -> int:
1370
- """
1371
- Returns the size of the vocabulary.
1372
- Returns:
1373
- int: The number of tokens in the vocabulary.
1374
- """
1375
- return len(self.vocab)
1376
-
1377
- def prepad_sequence(self, sequence, space_token = ' ', new_space_token = '@', undo = False):
1378
- """
1379
- Replaces spaces in the input sequence with a specified token.
1380
- Args:
1381
- sequence (str): The input sequence.
1382
- undo (bool): If True, replace the padding token with spaces. Defaults to False, which pads the spaces.
1383
- Returns:
1384
- str: The preprocessed sequence with spaces or padding tokens replaced.
1385
- """
1386
- if undo:
1387
- return sequence.replace(new_space_token, space_token)
1388
- else:
1389
- return sequence.replace(space_token, new_space_token)
1390
-
1391
- def add_bos_eos(self, sequence: str) -> str:
1392
- """
1393
- Aggiunge i token BOS all'inizio e EOS alla fine della sequenza.
1394
- Args:
1395
- sequence (str): La sequenza di input.
1396
- Returns:
1397
- str: La sequenza con i token BOS ed EOS.
1398
- """
1399
- return f'{self.bos_token} {sequence} {self.eos_token}'
1400
-
1401
- def tokenize(self, text: str) -> List[str]:
1402
- """
1403
- Tokenizes the input text into a list of tokens.
1404
- The method preprocesses the input text by replacing spaces with padding tokens and then tries to
1405
- find the longest possible match for each substring in the vocabulary.
1406
- Args:
1407
- text (str): The input text to be tokenized.
1408
- Returns:
1409
- List[str]: A list of tokens representing the tokenized text.
1410
- """
1411
- text = self.add_bos_eos(text)
1412
- text = self.prepad_sequence(text)
1413
- tokens = []
1414
- i = 0
1415
- while i < len(text):
1416
- best_match = None
1417
- for j in range(len(text), i, -1): # Try matching substrings of decreasing length
1418
- subtoken = text[i:j]
1419
- if subtoken in self.vocab:
1420
- best_match = subtoken
1421
- break
1422
- if best_match:
1423
- tokens.append(best_match)
1424
- i += len(best_match)
1425
- else:
1426
- tokens.append(self.unk_token)
1427
- i += 1
1428
- return tokens
1429
-
1430
- def convert_tokens_to_ids(self, tokens: List[str]) -> List[int]:
1431
- """
1432
- Converts a list of tokens into a list of token IDs.
1433
- Args:
1434
- tokens (List[str]): A list of tokens to be converted into IDs.
1435
- Returns:
1436
- List[int]: A list of corresponding token IDs.
1437
- """
1438
- return [self.vocab.get(token, self.vocab[self.unk_token]) for token in tokens]
1439
-
1440
- def convert_ids_to_tokens(self, ids: List[int]) -> List[str]:
1441
- """
1442
- Converts a list of token IDs into a list of tokens.
1443
- Args:
1444
- ids (List[int]): A list of token IDs to be converted into tokens.
1445
- Returns:
1446
- List[str]: A list of corresponding tokens.
1447
- """
1448
- return [self.id_to_token.get(i, self.unk_token) for i in ids]
1449
-
1450
- def encode(self, sequence: str) -> List[int]:
1451
- """
1452
- Encodes a string sequence into a list of token IDs.
1453
-
1454
- This method tokenizes the input sequence using the `tokenize` method,
1455
- and then converts the resulting tokens into their corresponding token IDs
1456
- using the `convert_tokens_to_ids` method.
1457
-
1458
- Args:
1459
- sequence (str): The input sequence (text) to be encoded.
1460
-
1461
- Returns:
1462
- List[int]: A list of token IDs corresponding to the input sequence.
1463
- """
1464
- splitted_sequence = self.tokenize(sequence)
1465
- return self.convert_tokens_to_ids(splitted_sequence)
1466
-
1467
- def postpad_sequence(self, sequence, pad_token_id):
1468
- """
1469
- Fills the sequence up to max_length padding elements
1470
- """
1471
- num_extra_elements = self.model_max_length - len(sequence) -1
1472
- if num_extra_elements > 0:
1473
- sequence.extend([pad_token_id] * num_extra_elements)
1474
- return sequence
1475
-
1476
- def decode(self, token_ids: List[int]) -> str:
1477
- """
1478
- Decodes a list of token IDs into a string of text.
1479
- The method converts the IDs to tokens and joins them to form a string.
1480
- It then restores the original spaces by undoing the pre-padding substitution.
1481
- Args:
1482
- token_ids (List[int]): A list of token IDs to be decoded.
1484
- Returns:
1485
- str: The decoded string.
1486
- """
1487
- tokens = self.convert_ids_to_tokens(token_ids)
1488
- decoded = "".join(tokens)
1489
- return self.prepad_sequence(decoded, undo=True)
1490
-
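The decoding path in isolation, with a toy id-to-token map (illustrative only): ids become tokens, the tokens are joined, and the space marker is then undone.

    id_to_token = {0: "<bos>", 4: "G", 3: "@", 6: "x_1", 1: "<eos>"}
    ids = [0, 4, 3, 6, 1]
    decoded = "".join(id_to_token[i] for i in ids)
    print(decoded)                     # <bos>G@x_1<eos>
    print(decoded.replace('@', ' '))   # <bos>G x_1<eos>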
1491
- def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
1492
- """
1493
- Saves the tokenizer's vocabulary to a file.
1494
- Useful only when the vocabulary has to be retrieved and is not provided up front
1495
- (not the case here; kept as a hook for future improvements, e.g. with sentencepiece).
1496
- This method saves the vocabulary to a JSON file in the specified directory.
1497
- Args:
1498
- save_directory (str): The directory where the vocabulary file will be saved.
1499
- filename_prefix (Optional[str]): An optional prefix for the filename.
1500
- Returns:
1501
- Tuple[str]: A tuple containing the path to the saved vocabulary file.
1502
- """
1503
- vocab_file = f"{save_directory}/{filename_prefix + '-' if filename_prefix else ''}vocab.json"
1504
- import json  # local import: `json` is missing from the module-level imports
- with open(vocab_file, "w", encoding="utf-8") as f:
1505
- json.dump(self.vocab, f, indent=2, ensure_ascii=False)
1506
- return (vocab_file,)
1507
-
1508
- def get_vocab(self) -> dict:
1509
- """
1510
- Retrieves the vocabulary used by the tokenizer.
1511
- Returns:
1512
- dict: The vocabulary as a dictionary.
1513
- """
1514
- return self.vocab
1515
-
1516
- class STLSinusoidalPositionalEmbedding(nn.Embedding):
1517
- """This module produces sinusoidal positional embeddings of any length."""
1518
-
1519
- def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
1520
- super().__init__(num_positions, embedding_dim)
1521
- self.weight = self._init_weight(self.weight)
1522
-
1523
- @staticmethod
1524
- def _init_weight(out: nn.Parameter) -> nn.Parameter:
1525
- """
1526
- Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
1527
- the 2nd half of the vector. [dim // 2:]
1528
- """
1529
- n_pos, dim = out.shape
1530
- position_enc = np.array(
1531
- [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
1532
- )
1533
- out.requires_grad = False # set early to avoid an error in pytorch-1.8+
1534
- sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
1535
- out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
1536
- out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
1537
- out.detach_()
1538
- return out
1539
-
- @torch.no_grad()
1540
- def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor:
1541
- """`input_ids_shape` is expected to be [bsz x seqlen]."""
1542
- bsz, seq_len = input_ids_shape[:2]
1543
- positions = torch.arange(
1544
- past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
1545
- )
1546
- return super().forward(positions)
1547
-
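A sanity check of the table layout described above (sin features in the first half of each embedding vector, cos in the second half), runnable once the class is defined:

    import torch

    emb = STLSinusoidalPositionalEmbedding(num_positions=64, embedding_dim=16)
    pos = emb(torch.Size([2, 10]))                       # positions 0..9
    print(pos.shape)                                      # torch.Size([10, 16])
    print(torch.allclose(pos[0, :8], torch.zeros(8)))     # True: sin(0) = 0 fills the first half
    print(torch.allclose(pos[0, 8:], torch.ones(8)))      # True: cos(0) = 1 fills the second half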
1548
- class STLAttention(nn.Module):
1549
- """ Multi-Head Attention as depicted from 'Attention is all you need' """
1550
-
1551
- def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0,
1552
- is_decoder: bool = False, bias: bool = False, is_causal: bool = False):
1553
-
1554
- super().__init__()
1555
- self.embed_dim = embed_dim # overall embedding dimension -> to be divided between multiple heads
1556
- self.num_heads = num_heads
1557
- self.dropout = dropout
1558
- self.head_dim = embed_dim // num_heads
1559
- assert (self.head_dim * num_heads) == self.embed_dim, "embed_dim must be divisible by num_heads"
1560
- self.scaling = self.head_dim ** -0.5 # used to normalize values when projected using `W_` matrices
1561
- self.is_decoder = is_decoder
1562
- self.is_causal = is_causal
1563
-
1564
- # projection matrices for the query / key / value roles
1565
- self.W_k = nn.Linear(embed_dim, embed_dim, bias = bias)
1566
- self.W_q = nn.Linear(embed_dim, embed_dim, bias = bias)
1567
- self.W_v = nn.Linear(embed_dim, embed_dim, bias = bias)
1568
-
1569
- # to project the heads' outputs into a single vector
1570
- self.W_o = nn.Linear(embed_dim, embed_dim, bias = bias)
1571
-
1572
-
1573
- def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
1574
- return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
1575
-
1576
-
1577
- def forward(self,
1578
- hidden_states: torch.Tensor, # previous values, passed to the multi-head attn layer
1579
- key_value_states: Optional[torch.Tensor] = None, # different key, value items (used in cross-attn)
1580
- past_key_value: Optional[Tuple[torch.Tensor]] = None, # stores the key and values of previous steps
1581
- attention_mask: Optional[torch.Tensor] = None, # masks non-allowed items (padded or future ones)
1582
- layer_head_mask: Optional[torch.Tensor] = None, # used to de-activate specific attn heads
1583
- output_attentions: bool = False  # flag to control returning the attention weights
1584
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
1585
-
1586
- is_cross_attention = key_value_states is not None # cross-attn if key_value_states is not None
1587
-
1588
- batch_size, tgt_len, embed_dim = hidden_states.size()
1589
-
1590
- # Project the current input in the `query` role:
1591
- query = self.W_q(hidden_states) * self.scaling
1592
-
1593
- if (is_cross_attention and past_key_value is not None and past_key_value[0].shape[2] == key_value_states.shape[1]):
1594
- key = past_key_value[0]
1595
- value = past_key_value[1]
1596
- elif is_cross_attention:
1597
- key = self._shape(self.W_k(key_value_states), -1, batch_size)
1598
- value = self._shape(self.W_v(key_value_states), -1, batch_size)
1599
- elif past_key_value is not None:
1600
- key = self._shape(self.W_k(hidden_states), -1, batch_size)
1601
- value = self._shape(self.W_v(hidden_states), -1, batch_size)
1602
- key = torch.cat([past_key_value[0], key], dim=2)
1603
- value = torch.cat([past_key_value[1], value], dim=2)
1604
- else:
1605
- key = self._shape(self.W_k(hidden_states), -1, batch_size)
1606
- value = self._shape(self.W_v(hidden_states), -1, batch_size)
1607
- if self.is_decoder:
1608
- past_key_value = (key, value)
1609
-
1610
- proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
1611
-
1612
- query = self._shape(query, tgt_len, batch_size).view(*proj_shape)
1613
- key = key.reshape(*proj_shape)
1614
- value = value.reshape(*proj_shape)
1615
-
1616
- src_len = key.size(1)
1617
-
1618
-
1619
- ######################################################################################################
1620
-
1621
- # 'traditional' attention computation
1622
- # i.e. softmax(Q*K^T / sqrt(d_model) + self_attn_mask) * V
1623
-
1624
- # Batch-wise matrix multiplication between `query` and (TRANSPOSED) `key`
1625
- attn_weights = torch.bmm(query, key.transpose(1, 2))
1626
-
1627
- if attention_mask is not None:
1628
- attn_weights = attn_weights.view(batch_size, self.num_heads, tgt_len, src_len) + attention_mask
1629
- attn_weights = attn_weights.view(batch_size * self.num_heads, tgt_len, src_len)
1630
-
1631
- # Normalize values on the `key` axis (dim=-1)
1632
- attn_weights = F.softmax(attn_weights, dim=-1)
1633
-
1634
- # if layer_head_mask is not None:
1635
- # attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(batch_size, self.num_heads, tgt_len, src_len)
1636
- # attn_weights = attn_weights.view(batch_size * self.num_heads, tgt_len, src_len)
1637
-
1638
- attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
1639
-
1640
- # Batch-wise matrix multiplication between the resulting probs and the value
1641
- attn_output = torch.bmm(attn_probs, value)
1642
-
1643
- ######################################################################################################
1644
-
1645
- attn_output = attn_output.view(batch_size, self.num_heads, tgt_len, self.head_dim)
1646
- attn_output = attn_output.transpose(1, 2)
1647
-
1648
- attn_output = attn_output.reshape(batch_size, tgt_len, self.embed_dim)
1649
- attn_output = self.W_o(attn_output)
1650
-
1651
- return attn_output, None, past_key_value  # attn weights are not returned (see the commented-out block above)
1652
-
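A shape walk-through of the module above in plain self-attention mode (no mask, no cache); with embed_dim=16 split across 4 heads, each head works on head_dim=4:

    import torch

    attn = STLAttention(embed_dim=16, num_heads=4)
    x = torch.randn(2, 5, 16)                  # (batch, seq_len, embed_dim)
    out, weights, past = attn(x)               # Q, K and V are all projected from x
    print(out.shape)                            # torch.Size([2, 5, 16])
    print(weights is None, past is None)        # True True: no weights returned, not a decoder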
1653
- ####
1654
-
1655
- class STLEncoder:
1656
- def __init__(self,
1657
- embed_dim: int,
1658
- anchor_filename: Optional[str] = None,
1659
- n_vars: int = 3):
1660
-
1661
- self.n_vars = n_vars  # passed in as an input argument
1662
- self.embed_dim = embed_dim
1663
- self.anchorset_filename = anchor_filename
1664
- self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
1665
- self.mu = BaseMeasure(device=self.device)  # BaseMeasure/StlKernel come from the external
1666
- self.kernel = StlKernel(self.mu, varn=self.n_vars)  # `stl` package, whose import is commented out above
1667
-
1668
- if anchor_filename is None:
1669
- anchor_filename = anchorGeneration(diff_init=True, embed_dim=self.embed_dim, n_vars=self.n_vars)
1670
- anchor_filename += '.pickle'
1671
-
1672
- # TODO: check that the anchor set dimensions match the `embed_dim` and `n_vars` values
1673
- anchor_set = load_pickle(anchor_filename)
1674
- if len(anchor_set) != self.embed_dim:
1675
- raise ValueError("The anchor set and the embedding dimension do not match!")
1676
-
1677
- self.anchor_set = anchor_set
1678
-
1679
- def compute_embeddings(self, formula: List[str]):
1680
- return self.kernel.compute_bag_bag(formula, self.anchor_set)
1681
-
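Hypothetical usage of the kernel-based encoder. This assumes the external `stl` kernel package (the source of BaseMeasure, StlKernel and the helpers above) is installed, and that 'anchors.pickle' holds an anchor set of exactly `embed_dim` formulae; both names are placeholders:

    # `formulae` is a placeholder for a list of STL formula objects from the same package
    enc = STLEncoder(embed_dim=128, anchor_filename="anchors.pickle", n_vars=3)
    embeddings = enc.compute_embeddings(formulae)
    # one row per input formula, one kernel value per anchor formula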
1682
- class STLModel(PreTrainedModel):
1683
- config_class = STLConfig
1684
- base_model_prefix = "model"
1685
- supports_gradient_checkpointing = True
1686
-
1687
- # initializes the weights of `nn.Linear`, `nn.Embedding` and `STLSinusoidalPositionalEmbedding`
1688
- def _init_weights(self, module: Union[nn.Linear, nn.Embedding, STLSinusoidalPositionalEmbedding]):
1689
- std = self.config.init_std
1690
- if isinstance(module, nn.Linear):
1691
- module.weight.data.normal_(mean=0.0, std=std)
1692
- if module.bias is not None:
1693
- module.bias.data.zero_()
1694
- elif isinstance(module, STLSinusoidalPositionalEmbedding):
1695
- pass
1696
- elif isinstance(module, nn.Embedding):
1697
- module.weight.data.normal_(mean=0.0, std=std)
1698
- if module.padding_idx is not None:
1699
- module.weight.data[module.padding_idx].zero_()
1700
-
1701
- @property
1702
- def dummy_inputs(self):
1703
- pad_token = self.config.pad_token_id
1704
- input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
1705
- dummy_inputs = {
1706
- "attention_mask": input_ids.ne(pad_token),
1707
- "input_ids": input_ids,
1708
- "decoder_input_ids": input_ids,
1709
- }
1710
- return dummy_inputs
1711
-
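The mask construction in `dummy_inputs` is just an elementwise not-equal-to-pad test; for instance, with a hypothetical pad_token_id of 1:

    import torch

    ids = torch.tensor([[0, 8, 12, 2, 1]])
    print(ids.ne(1).long())    # tensor([[1, 1, 1, 1, 0]]) -- the padded slot is masked out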
1712
- class STLDecoderBlock(nn.Module):
1713
-
1714
- def __init__(self, embed_dim: int,
1715
- num_decoder_attention_heads: int,
1716
- num_decoder_ffn_dim: int,
1717
- dropout: float = 0.0,
1718
- attention_dropout: float = 0.0,
1719
- activation_dropout: float = 0.0,
1720
- ):
1721
-
1722
- super().__init__()
1723
-
1724
- self.embed_dim = embed_dim
1725
-
1726
- # first block
1727
- self.self_attn = STLAttention(
1728
- embed_dim=self.embed_dim,
1729
- num_heads=num_decoder_attention_heads,
1730
- dropout=dropout,
1731
- is_decoder=True, # not used, debugging purposes
1732
- is_causal=True, # not used, debugging purposes
1733
- )
1734
- self.dropout = dropout
1735
- self.activation_fn = nn.functional.gelu
1736
- self.activation_dropout = activation_dropout
1737
- self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
1738
-
1739
- # second block
1740
- self.encoder_attn = STLAttention(
1741
- self.embed_dim,
1742
- num_decoder_attention_heads,
1743
- dropout=attention_dropout,
1744
- is_decoder=True, # not used, debugging purposes
1745
- )
1746
- self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
1747
-
1748
- # third block
1749
- self.fc1 = nn.Linear(self.embed_dim, num_decoder_ffn_dim)
1750
- self.fc2 = nn.Linear(num_decoder_ffn_dim, self.embed_dim)
1751
- self.final_layer_norm = nn.LayerNorm(self.embed_dim)
1752
-
1753
-
1754
- def forward(
1755
- self,
1756
- hidden_states: torch.Tensor,
1757
- attention_mask: Optional[torch.Tensor] = None,
1758
- encoder_hidden_states: Optional[torch.Tensor] = None,
1759
- encoder_attention_mask: Optional[torch.Tensor] = None,
1760
- layer_head_mask: Optional[torch.Tensor] = None,
1761
- cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
1762
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
1763
- output_attentions: Optional[bool] = False,
1764
- use_cache: Optional[bool] = True,
1765
- **kwargs,
1766
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
1767
- """
1768
- Args:
1769
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
1770
- attention_mask (`torch.FloatTensor`): attention mask of size
1771
- `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
1772
- encoder_hidden_states (`torch.FloatTensor`):
1773
- cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
1774
- encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
1775
- `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
1776
- layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
1777
- `(encoder_attention_heads,)`.
1778
- cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
1779
- size `(decoder_attention_heads,)`.
1780
- past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
1781
- output_attentions (`bool`, *optional*):
1782
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1783
- returned tensors for more detail.
1784
- """
1785
-
1786
- ###################################################################
1787
-
1788
- # BLOCK 1: processing what has been previously generated
1789
-
1790
- # previous state is stored into an auxiliary variable `residual`
1791
- residual = hidden_states
1792
-
1793
- # tries to exploit previous K, V values if there are any
1794
- # (practically picks up to the first 2 values stored in `past_key_value` vector)
1795
- self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
1796
-
1797
- # masked MHSA on the already generated sequence
1798
- # invokes `forward` method to transform the original vector accordingly
1799
- hidden_states, self_attn_weights, present_key_value = self.self_attn.forward(
1800
- hidden_states=hidden_states, # Q
1801
- past_key_value=self_attn_past_key_value, # K, V
1802
- attention_mask=attention_mask, # passed as input of the decoder layer
1803
- layer_head_mask=layer_head_mask, # to deactivate certain attn layers
1804
- output_attentions=output_attentions,
1805
- )
1806
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
1807
-
1808
- # residual connection
1809
- hidden_states = residual + hidden_states
1810
-
1811
- # normalization
1812
- hidden_states = self.self_attn_layer_norm(hidden_states)
1813
-
1814
- ###################################################################
1815
-
1816
- # BLOCK 2: cross-attn between already generated input and previous information (from the encoder)
1817
-
1818
- # initialize K, Q, attn_weights for this new attn operation
1819
- cross_attn_present_key_value = None
1820
- cross_attn_weights = None
1821
-
1822
- # the important condition is that the encoder carries some information
1823
- if encoder_hidden_states is not None:
1824
-
1825
- # previous state is stored into an auxiliary variable `residual`
1826
- residual = hidden_states
1827
-
1828
- # cross_attn cached key/values tuple is at positions 3, 4 of PAST_key_value tuple
1829
- cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
1830
-
1831
- # MHSA in cross-attn
1832
- hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn.forward(
1833
- hidden_states=hidden_states, # Q = generated output
1834
- key_value_states=encoder_hidden_states, # K, V = encoder memory (used only in the 1st step when `use_cache = True`)
1835
- attention_mask=encoder_attention_mask, # just pads some elements (not causal this time!)
1836
- layer_head_mask=cross_attn_layer_head_mask, # again to mask certain heads
1837
- past_key_value=cross_attn_past_key_value, # K, V = encoder CACHED memory (used from the 2nd step on when `use_cache = True`)
1838
- output_attentions=output_attentions,
1839
- )
1840
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
1841
-
1842
- # residual connection
1843
- hidden_states = residual + hidden_states
1844
-
1845
- # normalization
1846
- hidden_states = self.encoder_attn_layer_norm(hidden_states)
1847
-
1848
- # add cross-attn to positions 3, 4 of PRESENT_key_value tuple
1849
- present_key_value = present_key_value + cross_attn_present_key_value
1850
-
1851
- ###################################################################
1852
-
1853
- # BLOCK 3: FFNN (transforming some merged generated output - encoder information)
1854
-
1855
- # previous state is stored into an auxiliary variable `residual`
1856
- residual = hidden_states
1857
-
1858
- # FFNN - core
1859
- hidden_states = self.activation_fn(self.fc1(hidden_states))
1860
- hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
1861
- hidden_states = self.fc2(hidden_states)
1862
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
1863
-
1864
- # residual connection
1865
- hidden_states = residual + hidden_states
1866
-
1867
- # normalization
1868
- hidden_states = self.final_layer_norm(hidden_states)
1869
-
1870
- outputs = (hidden_states,)
1871
-
1872
- if output_attentions:
1873
- outputs += (self_attn_weights, cross_attn_weights)
1874
-
1875
- if use_cache: # otherwise it picks K and V each time
1876
- outputs += (present_key_value,)
1877
-
1878
- return outputs
1879
-
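One forward pass through the block above with random tensors: self-attention over the target states, cross-attention over (here random) encoder memory, then the feed-forward sub-block. Shapes are illustrative:

    import torch

    block = STLDecoderBlock(embed_dim=16, num_decoder_attention_heads=4, num_decoder_ffn_dim=32)
    tgt = torch.randn(2, 5, 16)       # states generated so far
    memory = torch.randn(2, 7, 16)    # encoder output
    outputs = block(tgt, encoder_hidden_states=memory)
    print(outputs[0].shape)            # torch.Size([2, 5, 16])
    print(len(outputs[-1]))            # 4 cached tensors: (self K, self V, cross K, cross V)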
1880
- class STLDecoder(STLModel):
1881
- def __init__(self, config):
1882
- super().__init__(config)
1883
-
1884
- # Extract from `config` file
1885
- embed_dim = config.d_model
1886
- num_decoder_attention_heads = config.decoder_attention_heads
1887
- num_decoder_ffn_dim = config.decoder_ffn_dim
1888
- max_position_embeddings = config.max_position_embeddings
1889
- decoder_vocab_size = config.vocab_size
1890
- pad_token_id = config.pad_token_id
1891
- num_decoder_layers = config.decoder_layers
1892
- scale_embedding = config.scale_embedding
1893
- dropout = config.dropout
1894
- attention_dropout = config.attention_dropout
1895
- activation_dropout = config.activation_dropout
1896
- decoder_layerdrop = config.decoder_layerdrop
1897
-
1898
- self.dropout = dropout
1899
- self.layerdrop = decoder_layerdrop
1900
- self.padding_idx = pad_token_id
1901
- self.max_target_positions = max_position_embeddings
1902
- self.embed_scale = math.sqrt(embed_dim) if scale_embedding else 1.0
1903
-
1904
- # Initialize the input embedding (if not passed already)
1905
- self.embed_tokens = nn.Embedding(decoder_vocab_size, embed_dim, self.padding_idx)
1906
-
1907
- # Initialize positional embedding also
1908
- self.embed_positions = STLSinusoidalPositionalEmbedding(
1909
- max_position_embeddings, embed_dim, self.padding_idx
1910
- )
1911
-
1912
- # Initialize decoder layers (of a prespecified number)
1913
- self.layers = nn.ModuleList([STLDecoderBlock(embed_dim, num_decoder_attention_heads,
1914
- num_decoder_ffn_dim, dropout,
1915
- attention_dropout, activation_dropout)
1916
- for _ in range(num_decoder_layers)])
1917
-
1918
- self.gradient_checkpointing = False
1919
- self.post_init()
1920
-
1921
- def forward(
1922
- self,
1923
- input_ids: torch.LongTensor = None,
1924
- attention_mask: Optional[torch.Tensor] = None,
1925
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
1926
- encoder_attention_mask: Optional[torch.Tensor] = None,
1927
- head_mask: Optional[torch.Tensor] = None,
1928
- cross_attn_head_mask: Optional[torch.Tensor] = None,
1929
- past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1930
- inputs_embeds: Optional[torch.FloatTensor] = None,
1931
- use_cache: Optional[bool] = None,
1932
- output_attentions: Optional[bool] = None,
1933
- output_hidden_states: Optional[bool] = None,
1934
- return_dict: Optional[bool] = None,
1935
- **kwargs,
1936
- ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
1937
-
1938
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1939
- output_hidden_states = (
1940
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1941
- )
1942
- use_cache = use_cache if use_cache is not None else self.config.use_cache
1943
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1944
-
1945
- # retrieve input_ids and inputs_embeds
1946
- if input_ids is not None and inputs_embeds is not None:
1947
- raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
1948
- elif input_ids is not None:
1949
- input_shape = input_ids.size()
1950
- input_ids = input_ids.view(-1, input_shape[-1])
1951
- elif inputs_embeds is not None:
1952
- input_shape = inputs_embeds.size()[:-1]
1953
- else:
1954
- raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
1955
-
1956
- # past_key_values_length
1957
- past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
1958
-
1959
- if inputs_embeds is None:
1960
- inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
1961
-
1962
- attention_mask = _prepare_4d_causal_attention_mask(
1963
- attention_mask, input_shape, inputs_embeds, past_key_values_length
1964
- )
1965
-
1966
- # expand encoder attention mask
1967
- if encoder_hidden_states is not None and encoder_attention_mask is not None:
1968
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1969
- encoder_attention_mask = _prepare_4d_attention_mask(
1970
- encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
1971
- )
1972
-
1973
- # embed positions
1974
- positions = self.embed_positions(input_shape, past_key_values_length)
1975
-
1976
- hidden_states = inputs_embeds + positions
1977
-
1978
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
1979
-
1980
- if self.gradient_checkpointing and self.training:
1981
- if use_cache:
1982
- logging.get_logger(__name__).warning_once(  # fix: no module-level `logger` is defined; use the imported utils
1983
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1984
- )
1985
- use_cache = False
1986
-
1987
- # decoder layers
1988
- all_hidden_states = () if output_hidden_states else None
1989
- all_self_attns = () if output_attentions else None
1990
- all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
1991
- next_decoder_cache = () if use_cache else None
1992
-
1993
- for idx, decoder_layer in enumerate(self.layers):
1994
- # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
1995
- if output_hidden_states:
1996
- all_hidden_states += (hidden_states,)
1997
- if self.training:
1998
- dropout_probability = torch.rand([])
1999
- if dropout_probability < self.layerdrop:
2000
- continue
2001
-
2002
- past_key_value = past_key_values[idx] if past_key_values is not None else None
2003
-
2004
- if self.gradient_checkpointing and self.training:
2005
- layer_outputs = self._gradient_checkpointing_func(
2006
- decoder_layer.__call__,
2007
- hidden_states,
2008
- attention_mask,
2009
- encoder_hidden_states,
2010
- encoder_attention_mask,
2011
- head_mask[idx] if head_mask is not None else None,
2012
- cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
2013
- None,
2014
- output_attentions,
2015
- use_cache,
2016
- )
2017
- else:
2018
- layer_outputs = decoder_layer(
2019
- hidden_states,
2020
- attention_mask=attention_mask,
2021
- encoder_hidden_states=encoder_hidden_states,
2022
- layer_head_mask=(head_mask[idx] if head_mask is not None else None),
2023
- cross_attn_layer_head_mask=(
2024
- cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
2025
- ),
2026
- past_key_value=past_key_value,
2027
- output_attentions=output_attentions,
2028
- use_cache=use_cache,
2029
- )
2030
- hidden_states = layer_outputs[0]
2031
-
2032
- if use_cache:
2033
- next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
2034
-
2035
- if output_attentions:
2036
- all_self_attns += (layer_outputs[1],)
2037
-
2038
- if encoder_hidden_states is not None:
2039
- all_cross_attentions += (layer_outputs[2],)
2040
-
2041
- # add hidden states from the last decoder layer
2042
- if output_hidden_states:
2043
- all_hidden_states += (hidden_states,)
2044
-
2045
- next_cache = next_decoder_cache if use_cache else None
2046
- if not return_dict:
2047
- return tuple(
2048
- v
2049
- for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
2050
- if v is not None
2051
- )
2052
- return BaseModelOutputWithPastAndCrossAttentions(
2053
- last_hidden_state=hidden_states,
2054
- past_key_values=next_cache,
2055
- hidden_states=all_hidden_states,
2056
- attentions=all_self_attns,
2057
- cross_attentions=all_cross_attentions,
2058
- )
2059
-
2060
- ####
2061
-
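A hypothetical end-to-end pass through the decoder stack. STLConfig's exact signature lives in configuration.py, so the keyword arguments below are assumptions (PretrainedConfig subclasses normally accept them and expose defaults such as use_cache):

    import torch

    cfg = STLConfig(vocab_size=50, d_model=16, decoder_layers=2,
                    decoder_attention_heads=4, decoder_ffn_dim=32,
                    max_position_embeddings=128, pad_token_id=0)
    decoder = STLDecoder(cfg)
    ids = torch.tensor([[4, 7, 9, 2]])
    out = decoder(input_ids=ids)
    print(out.last_hidden_state.shape)   # torch.Size([1, 4, 16])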
2062
- class STLForCausalLM(STLModel, GenerationMixin):
2063
- _tied_weights_keys = ["lm_head.weight"]
2064
-
2065
- def __init__(self, config):
2066
- config = copy.deepcopy(config)
2067
- config.is_decoder = True
2068
- config.is_encoder_decoder = False
2069
-
2070
- super().__init__(config)
2071
- self.model = STLDecoder(config)
2072
-
2073
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
2074
-
2075
- # Initialize weights and apply final processing
2076
- self.post_init()
2077
-
2078
- def get_input_embeddings(self):
2079
- return self.model.embed_tokens
2080
-
2081
- def set_input_embeddings(self, value):
2082
- self.model.embed_tokens = value
2083
-
2084
- def get_output_embeddings(self):
2085
- return self.lm_head
2086
-
2087
- def set_output_embeddings(self, new_embeddings):
2088
- self.lm_head = new_embeddings
2089
-
2090
- def set_decoder(self, decoder):
2091
- self.model = decoder
2092
-
2093
- def get_decoder(self):
2094
- return self.model
2095
-
2096
- def forward(
2097
- self,
2098
- input_ids: torch.LongTensor = None, # input sequence
2099
- attention_mask: Optional[torch.Tensor] = None, # masked MHSA + padding
2100
- encoder_hidden_states: Optional[torch.FloatTensor] = None, # embedding
2101
- encoder_attention_mask: Optional[torch.FloatTensor] = None, # MHSA + padding
2102
- head_mask: Optional[torch.Tensor] = None,
2103
- cross_attn_head_mask: Optional[torch.Tensor] = None,
2104
- past_key_values: Optional[List[torch.FloatTensor]] = None,
2105
- inputs_embeds: Optional[torch.FloatTensor] = None,
2106
- labels: Optional[torch.LongTensor] = None, # output sequence
2107
- use_cache: Optional[bool] = None,
2108
- output_attentions: Optional[bool] = None,
2109
- output_hidden_states: Optional[bool] = None,
2110
- return_dict: Optional[bool] = None,
2111
- **kwargs,
2112
- ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
2113
-
2114
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
2115
- output_hidden_states = (
2116
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
2117
- )
2118
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
2119
-
2120
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
2121
- outputs = self.model(
2122
- input_ids=input_ids,
2123
- attention_mask=attention_mask,
2124
- encoder_hidden_states=encoder_hidden_states,
2125
- encoder_attention_mask=encoder_attention_mask,
2126
- head_mask=head_mask,
2127
- cross_attn_head_mask=cross_attn_head_mask,
2128
- past_key_values=past_key_values,
2129
- inputs_embeds=inputs_embeds,
2130
- use_cache=use_cache,
2131
- output_attentions=output_attentions,
2132
- output_hidden_states=output_hidden_states,
2133
- return_dict=return_dict,
2134
- **kwargs
2135
- )
2136
-
2137
- logits = self.lm_head(outputs[0])
2138
-
2139
- loss = None
2140
- if labels is not None:
2141
- labels = labels.to(logits.device)
2142
- loss_fct = nn.CrossEntropyLoss()  # fix: bare `CrossEntropyLoss` is not imported; go through the `nn` namespace
2143
- loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
2144
-
2145
- if not return_dict:
2146
- output = (logits,) + outputs[1:]
2147
- return (loss,) + output if loss is not None else output
2148
-
2149
- return CausalLMOutputWithCrossAttentions(
2150
- loss=loss,
2151
- logits=logits,
2152
- past_key_values=outputs.past_key_values,
2153
- hidden_states=outputs.hidden_states,
2154
- attentions=outputs.attentions,
2155
- cross_attentions=outputs.cross_attentions,
2156
- )
2157
-
2158
- @staticmethod
2159
- def _reorder_cache(past_key_values, beam_idx):
2160
- reordered_past = ()
2161
- for layer_past in past_key_values:
2162
- reordered_past += (
2163
- tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
2164
- )
2165
- return reordered_past
2166
-
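Finally, a sketch of a training-style call on the causal LM head, reusing the hypothetical `cfg` above and assuming STLConfig aliases `hidden_size` to `d_model` (a common PretrainedConfig convention). Passing `labels` makes `forward` return a cross-entropy loss alongside the logits; note the code computes the loss on unshifted labels.

    import torch

    lm = STLForCausalLM(cfg)
    ids = torch.tensor([[4, 7, 9, 2]])
    out = lm(input_ids=ids, labels=ids)
    print(out.loss)                  # scalar tensor
    print(out.logits.shape)          # torch.Size([1, 4, 50])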
2167
-
2168
-