|
|
import ast |
|
|
import copy |
|
|
import math |
|
|
import pickle |
|
|
import os |
|
|
from collections import deque |
|
|
from typing import List, Optional, Tuple, Union |
|
|
|
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import torch |
|
|
import torch.utils.checkpoint |
|
|
from torch import nn |
|
|
import torch.nn.functional as F |
|
|
from torch.utils.data import Dataset |
|
|
|
|
|
from transformers.modeling_utils import PreTrainedModel |
|
|
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask |
|
|
from transformers.generation import GenerationMixin |
|
|
from transformers.utils import ( |
|
|
add_end_docstrings, |
|
|
add_start_docstrings, |
|
|
add_start_docstrings_to_model_forward, |
|
|
logging, |
|
|
replace_return_docstrings, |
|
|
) |
|
|
from transformers.modeling_outputs import ( |
|
|
BaseModelOutput, |
|
|
BaseModelOutputWithPastAndCrossAttentions, |
|
|
CausalLMOutputWithCrossAttentions, |
|
|
Seq2SeqLMOutput, |
|
|
Seq2SeqModelOutput, |
|
|
) |
|
|
|
|
|
from .configuration_stldec import STLConfig |
|
|
from nltk.translate.bleu_score import sentence_bleu |
|
|
|
|
|
import networkx as nx |
|
|
from datasets import load_dataset |
|
|
|
|
|
|
|
|
|
|
|
realnum = Union[float, int] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
from torch import Tensor |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
def eventually(x: Tensor, time_span: int) -> Tensor:
    """
    Sliding-window maximum used as the 1D STL 'eventually' operator.

    Parameters
    ----------
    x: torch.Tensor
        Signal, forwarded unchanged to ``F.max_pool1d``.
    time_span: int
        Window width, used as the pooling kernel size.

    Returns
    -------
    torch.Tensor
        The signal max-pooled with stride 1 and no padding argument,
        so the pooled axis shrinks by ``time_span - 1``.
    """
    window_max: Tensor = F.max_pool1d(x, kernel_size=time_span, stride=1)
    return window_max
|
|
|
|
|
class Node:
    """Abstract node class for STL semantics tree.

    Concrete nodes implement ``__str__``, ``time_depth``, ``_boolean`` and
    ``_quantitative``. The public ``boolean`` / ``quantitative`` wrappers
    only decide whether the whole result or just its value at time zero is
    returned.
    """

    def __init__(self) -> None:
        pass

    def __str__(self) -> str:
        # Abstract: fail loudly instead of returning None (which made str()
        # raise an unrelated TypeError).
        raise NotImplementedError(
            f"{type(self).__name__} must implement __str__"
        )

    def boolean(self, x: Tensor, evaluate_at_all_times: bool = False) -> Tensor:
        """
        Evaluates the boolean semantics at the node.

        Parameters
        ----------
        x : torch.Tensor, of size N_samples x N_vars x N_sampling_points
            The input signals, stored as a batch tensor with three dimensions.
        evaluate_at_all_times: bool
            Whether to evaluate the semantics at all times (True) or
            just at t=0 (False).

        Returns
        -------
        torch.Tensor
            A tensor with the boolean semantics for the node.
        """
        z: Tensor = self._boolean(x)
        if evaluate_at_all_times:
            return z
        return self._extract_semantics_at_time_zero(z)

    def quantitative(
        self,
        x: Tensor,
        normalize: bool = False,
        evaluate_at_all_times: bool = False,
    ) -> Tensor:
        """
        Evaluates the quantitative semantics at the node.

        Parameters
        ----------
        x : torch.Tensor, of size N_samples x N_vars x N_sampling_points
            The input signals, stored as a batch tensor with three dimensions.
        normalize: bool
            Whether the measure of robustness is normalized (True) or
            not (False). It is forwarded as-is to ``_quantitative``.
        evaluate_at_all_times: bool
            Whether to evaluate the semantics at all times (True) or
            just at t=0 (False).

        Returns
        -------
        torch.Tensor
            A tensor with the quantitative semantics for the node.
        """
        z: Tensor = self._quantitative(x, normalize)
        if evaluate_at_all_times:
            return z
        return self._extract_semantics_at_time_zero(z)

    def set_normalizing_flag(self, value: bool = True) -> None:
        """
        Setter for the 'normalization of robustness of the formula' flag.
        Currently not in use: the call is a no-op.
        """

    def time_depth(self) -> int:
        """Returns time depth of bounded temporal operators only."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement time_depth"
        )

    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        """Private method equivalent to public one for inner call."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement _quantitative"
        )

    def _boolean(self, x: Tensor) -> Tensor:
        """Private method equivalent to public one for inner call."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement _boolean"
        )

    @staticmethod
    def _extract_semantics_at_time_zero(x: Tensor) -> Tensor:
        """Extracts the vector of values at time zero (entry ``[:, 0, 0]``), flattened."""
        return torch.reshape(x[:, 0, 0], (-1,))
|
|
|
|
|
|
|
|
class Atom(Node):
    """Leaf node holding an atomic predicate, X <= t or X >= t."""

    def __init__(self, var_index: int, threshold: realnum, lte: bool = False) -> None:
        super().__init__()
        self.var_index: int = var_index
        self.threshold: realnum = threshold
        self.lte: bool = lte

    def __str__(self) -> str:
        # Rendered as "x_<index> <= <thr>" or "x_<index> >= <thr>",
        # with the threshold rounded to 4 decimals.
        comparator = " <= " if self.lte else " >= "
        pieces = ["x_", str(self.var_index), comparator, str(round(self.threshold, 4))]
        return "".join(pieces)

    def time_depth(self) -> int:
        # An atom contains no temporal operator.
        return 0

    def _variable_trace(self, x: Tensor) -> Tensor:
        """Selects variable ``var_index`` and restores a size-1 middle axis."""
        picked: Tensor = x[:, self.var_index, :]
        return picked.view(picked.size()[0], 1, -1)

    def _boolean(self, x: Tensor) -> Tensor:
        trace: Tensor = self._variable_trace(x)
        if self.lte:
            return torch.le(trace, self.threshold)
        return torch.ge(trace, self.threshold)

    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        trace: Tensor = self._variable_trace(x)
        # Signed distance from the threshold, positive when satisfied.
        if self.lte:
            margin: Tensor = -trace + self.threshold
        else:
            margin = trace - self.threshold
        # Optional squashing of the margin into (-1, 1).
        return torch.tanh(margin) if normalize else margin
|
|
|
|
|
# NOTE(review): `class Not(Node)` is defined a second time further down in
# this module; the later definition rebinds the name.
class Not(Node):
    """Negation node."""

    def __init__(self, child: Node) -> None:
        super().__init__()
        self.child: Node = child

    def __str__(self) -> str:
        inner: str = self.child.__str__()
        return "not ( " + inner + " )"

    def time_depth(self) -> int:
        # Negation adds no temporal horizon of its own.
        depth: int = self.child.time_depth()
        return depth

    def _boolean(self, x: Tensor) -> Tensor:
        child_truth: Tensor = self.child._boolean(x)
        return ~child_truth

    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        child_robustness: Tensor = self.child._quantitative(x, normalize)
        return -child_robustness
|
|
|
|
|
|
|
|
# NOTE(review): `class And(Node)` is defined a second time further down in
# this module; the later definition rebinds the name.
class And(Node):
    """Conjunction node."""

    def __init__(self, left_child: Node, right_child: Node) -> None:
        super().__init__()
        self.left_child: Node = left_child
        self.right_child: Node = right_child

    def __str__(self) -> str:
        parts = ["( ", self.left_child.__str__(), " and ", self.right_child.__str__(), " )"]
        return "".join(parts)

    def time_depth(self) -> int:
        left_depth = self.left_child.time_depth()
        right_depth = self.right_child.time_depth()
        return max(left_depth, right_depth)

    @staticmethod
    def _crop_to_common_length(a: Tensor, b: Tensor):
        """Truncates both tensors along axis 2 to the shorter of the two."""
        common: int = min(a.size()[2], b.size()[2])
        return a[:, :, :common], b[:, :, :common]

    def _boolean(self, x: Tensor) -> Tensor:
        lhs, rhs = self._crop_to_common_length(
            self.left_child._boolean(x), self.right_child._boolean(x)
        )
        return torch.logical_and(lhs, rhs)

    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        lhs, rhs = self._crop_to_common_length(
            self.left_child._quantitative(x, normalize),
            self.right_child._quantitative(x, normalize),
        )
        # Element-wise minimum of the two robustness signals.
        return torch.min(lhs, rhs)
|
|
|
|
|
# NOTE(review): this is the second `class Not(Node)` in this module; it
# rebinds the name defined by the earlier, identical-looking class.
class Not(Node):
    """Negation node: flips the boolean value and the sign of the robustness."""

    def __init__(self, child: Node) -> None:
        super().__init__()
        self.child: Node = child

    def __str__(self) -> str:
        return "".join(("not ( ", self.child.__str__(), " )"))

    def time_depth(self) -> int:
        # Same horizon as the negated sub-formula.
        return self.child.time_depth()

    def _boolean(self, x: Tensor) -> Tensor:
        inner: Tensor = self.child._boolean(x)
        negated: Tensor = ~inner
        return negated

    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        inner: Tensor = self.child._quantitative(x, normalize)
        negated: Tensor = -inner
        return negated
|
|
|
|
|
|
|
|
# NOTE(review): this is the second `class And(Node)` in this module; it
# rebinds the name defined by the earlier, identical-looking class.
class And(Node):
    """Conjunction node: logical AND / element-wise minimum of the two children."""

    def __init__(self, left_child: Node, right_child: Node) -> None:
        super().__init__()
        self.left_child: Node = left_child
        self.right_child: Node = right_child

    def __str__(self) -> str:
        lhs: str = self.left_child.__str__()
        rhs: str = self.right_child.__str__()
        return "( " + lhs + " and " + rhs + " )"

    def time_depth(self) -> int:
        depths = (self.left_child.time_depth(), self.right_child.time_depth())
        return max(depths)

    def _boolean(self, x: Tensor) -> Tensor:
        left_val: Tensor = self.left_child._boolean(x)
        right_val: Tensor = self.right_child._boolean(x)
        # Children may return different lengths along axis 2: keep the
        # common prefix only.
        keep: int = min(left_val.size()[2], right_val.size()[2])
        return torch.logical_and(left_val[:, :, :keep], right_val[:, :, :keep])

    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        left_val: Tensor = self.left_child._quantitative(x, normalize)
        right_val: Tensor = self.right_child._quantitative(x, normalize)
        keep: int = min(left_val.size()[2], right_val.size()[2])
        return torch.min(left_val[:, :, :keep], right_val[:, :, :keep])
|
|
|
|
|
class Or(Node):
    """Disjunction node: logical OR / element-wise maximum of the two children."""

    def __init__(self, left_child: Node, right_child: Node) -> None:
        super().__init__()
        self.left_child: Node = left_child
        self.right_child: Node = right_child

    def __str__(self) -> str:
        lhs: str = self.left_child.__str__()
        rhs: str = self.right_child.__str__()
        return "( " + lhs + " or " + rhs + " )"

    def time_depth(self) -> int:
        left_depth = self.left_child.time_depth()
        right_depth = self.right_child.time_depth()
        return max(left_depth, right_depth)

    @staticmethod
    def _common_prefix(a: Tensor, b: Tensor):
        """Cuts both tensors along axis 2 to the length of the shorter one."""
        keep: int = min(a.size()[2], b.size()[2])
        return a[:, :, :keep], b[:, :, :keep]

    def _boolean(self, x: Tensor) -> Tensor:
        lhs, rhs = self._common_prefix(
            self.left_child._boolean(x), self.right_child._boolean(x)
        )
        return torch.logical_or(lhs, rhs)

    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        lhs, rhs = self._common_prefix(
            self.left_child._quantitative(x, normalize),
            self.right_child._quantitative(x, normalize),
        )
        return torch.max(lhs, rhs)
|
|
|
|
|
|
|
|
class Globally(Node):
    """Globally ('always') node.

    ``right_time_bound`` is stored incremented by one, so that
    ``right_time_bound - left_time_bound`` is the window length used for the
    bounded operator.
    """

    def __init__(
        self,
        child: Node,
        unbound: bool = False,
        right_unbound: bool = False,
        left_time_bound: int = 0,
        right_time_bound: int = 1,
        adapt_unbound: bool = True,
    ) -> None:
        """
        Parameters
        ----------
        child : Node
            Sub-formula the operator is applied to.
        unbound : bool
            If True the operator has no time interval.
        right_unbound : bool
            If True the interval is [left_time_bound, inf].
        left_time_bound, right_time_bound : int
            Interval bounds; the right one is stored as ``right_time_bound + 1``.
        adapt_unbound : bool
            For unbounded operators, selects the running (cumulative) minimum
            over the reversed signal (True) or a single minimum over the whole
            time axis (False).

        Raises
        ------
        ValueError
            If the operator is bounded and the right bound is lower than the
            left bound (same rule enforced by ``Eventually`` and ``Until``).
        """
        super().__init__()
        self.child: Node = child
        self.unbound: bool = unbound
        self.right_unbound: bool = right_unbound
        self.left_time_bound: int = left_time_bound
        self.right_time_bound: int = right_time_bound + 1
        self.adapt_unbound: bool = adapt_unbound

        # Fail at construction time rather than later inside max_pool1d with
        # a non-positive kernel size.
        if (self.unbound is False) and (self.right_unbound is False) and \
                (self.right_time_bound <= self.left_time_bound):
            raise ValueError("Temporal thresholds are incorrect: right parameter is lower than left parameter")

    def __str__(self) -> str:
        s_left = "[" + str(self.left_time_bound) + ","
        s_right = str(self.right_time_bound) if not self.right_unbound else "inf"
        s0: str = s_left + s_right + "]" if not self.unbound else ""
        s: str = "always" + s0 + " ( " + self.child.__str__() + " )"
        return s

    def time_depth(self) -> int:
        if self.unbound:
            return self.child.time_depth()
        elif self.right_unbound:
            return self.child.time_depth() + self.left_time_bound
        else:
            # right_time_bound was stored with +1, hence the -1 here.
            return self.child.time_depth() + self.right_time_bound - 1

    def _boolean(self, x: Tensor) -> Tensor:
        # The child is evaluated on the signal shifted by the left bound.
        z1: Tensor = self.child._boolean(x[:, :, self.left_time_bound:])
        if self.unbound or self.right_unbound:
            if self.adapt_unbound:
                # Running minimum from the end of the signal backwards.
                z, _ = torch.cummin(torch.flip(z1, [2]), dim=2)
                z = torch.flip(z, [2])
            else:
                z, _ = torch.min(z1, 2, keepdim=True)
        else:
            # always(phi) == not eventually(not phi), computed through max-pooling.
            z = torch.ge(1.0 - eventually((~z1).double(), self.right_time_bound - self.left_time_bound), 0.5)
        return z

    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        z1: Tensor = self.child._quantitative(x[:, :, self.left_time_bound:], normalize)
        if self.unbound or self.right_unbound:
            if self.adapt_unbound:
                z, _ = torch.cummin(torch.flip(z1, [2]), dim=2)
                z = torch.flip(z, [2])
            else:
                z, _ = torch.min(z1, 2, keepdim=True)
        else:
            # Sliding-window minimum expressed as a negated sliding-window maximum.
            z = -eventually(-z1, self.right_time_bound - self.left_time_bound)
        return z
|
|
|
|
|
|
|
|
|
|
|
class Eventually(Node):
    """Eventually node.

    ``right_time_bound`` is stored incremented by one, so that
    ``right_time_bound - left_time_bound`` is the window length used for the
    bounded operator.
    """

    def __init__(
        self,
        child: Node,
        unbound: bool = False,
        right_unbound: bool = False,
        left_time_bound: int = 0,
        right_time_bound: int = 1,
        adapt_unbound: bool = True,
    ) -> None:
        """
        Parameters
        ----------
        child : Node
            Sub-formula the operator is applied to.
        unbound : bool
            If True the operator has no time interval.
        right_unbound : bool
            If True the interval is [left_time_bound, inf].
        left_time_bound, right_time_bound : int
            Interval bounds; the right one is stored as ``right_time_bound + 1``.
        adapt_unbound : bool
            For unbounded operators, selects the running (cumulative) maximum
            over the reversed signal (True) or a single maximum over the whole
            time axis (False).

        Raises
        ------
        ValueError
            If the operator is bounded and the right bound is lower than the
            left bound.
        """
        super().__init__()
        self.child: Node = child
        self.unbound: bool = unbound
        self.right_unbound: bool = right_unbound
        self.left_time_bound: int = left_time_bound
        self.right_time_bound: int = right_time_bound + 1
        self.adapt_unbound: bool = adapt_unbound

        # The check fires when right < left (the stored right bound already
        # includes the +1), so the message must say "lower", not "higher".
        if (self.unbound is False) and (self.right_unbound is False) and \
                (self.right_time_bound <= self.left_time_bound):
            raise ValueError("Temporal thresholds are incorrect: right parameter is lower than left parameter")

    def __str__(self) -> str:
        s_left = "[" + str(self.left_time_bound) + ","
        s_right = str(self.right_time_bound) if not self.right_unbound else "inf"
        s0: str = s_left + s_right + "]" if not self.unbound else ""
        s: str = "eventually" + s0 + " ( " + self.child.__str__() + " )"
        return s

    def time_depth(self) -> int:
        if self.unbound:
            return self.child.time_depth()
        elif self.right_unbound:
            return self.child.time_depth() + self.left_time_bound
        else:
            # right_time_bound was stored with +1, hence the -1 here.
            return self.child.time_depth() + self.right_time_bound - 1

    def _boolean(self, x: Tensor) -> Tensor:
        # The child is evaluated on the signal shifted by the left bound.
        z1: Tensor = self.child._boolean(x[:, :, self.left_time_bound:])
        if self.unbound or self.right_unbound:
            if self.adapt_unbound:
                # Running maximum from the end of the signal backwards.
                z, _ = torch.cummax(torch.flip(z1, [2]), dim=2)
                z = torch.flip(z, [2])
            else:
                z, _ = torch.max(z1, 2, keepdim=True)
        else:
            # Max-pool the 0/1 signal, then threshold back to booleans.
            z = torch.ge(eventually(z1.double(), self.right_time_bound - self.left_time_bound), 0.5)
        return z

    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        z1: Tensor = self.child._quantitative(x[:, :, self.left_time_bound:], normalize)
        if self.unbound or self.right_unbound:
            if self.adapt_unbound:
                z, _ = torch.cummax(torch.flip(z1, [2]), dim=2)
                z = torch.flip(z, [2])
            else:
                z, _ = torch.max(z1, 2, keepdim=True)
        else:
            z = eventually(z1, self.right_time_bound - self.left_time_bound)
        return z
|
|
|
|
|
class Until(Node):
    """Until node.

    The unbounded operator is computed directly; the right-unbounded and
    bounded variants are rewritten as a conjunction of ``Globally``,
    ``Eventually`` and an unbounded ``Until`` (see ``_timed_until``).
    ``right_time_bound`` is stored incremented by one.
    """

    def __init__(
        self,
        left_child: Node,
        right_child: Node,
        unbound: bool = False,
        right_unbound: bool = False,
        left_time_bound: int = 0,
        right_time_bound: int = 1,
    ) -> None:
        """
        Raises
        ------
        ValueError
            If the operator is bounded and the right bound is lower than the
            left bound.
        """
        super().__init__()
        self.left_child: Node = left_child
        self.right_child: Node = right_child
        self.unbound: bool = unbound
        self.right_unbound: bool = right_unbound
        self.left_time_bound: int = left_time_bound
        self.right_time_bound: int = right_time_bound + 1

        # The check fires when right < left (the stored right bound already
        # includes the +1), so the message must say "lower", not "higher".
        if (self.unbound is False) and (self.right_unbound is False) and \
                (self.right_time_bound <= self.left_time_bound):
            raise ValueError("Temporal thresholds are incorrect: right parameter is lower than left parameter")

    def __str__(self) -> str:
        s_left = "[" + str(self.left_time_bound) + ","
        s_right = str(self.right_time_bound) if not self.right_unbound else "inf"
        s0: str = s_left + s_right + "]" if not self.unbound else ""
        s: str = "( " + self.left_child.__str__() + " until" + s0 + " " + self.right_child.__str__() + " )"
        return s

    def time_depth(self) -> int:
        sum_children_depth: int = self.left_child.time_depth() + self.right_child.time_depth()
        if self.unbound:
            return sum_children_depth
        elif self.right_unbound:
            return sum_children_depth + self.left_time_bound
        else:
            # right_time_bound was stored with +1, hence the -1 here.
            return sum_children_depth + self.right_time_bound - 1

    def _timed_until(self) -> Node:
        """Builds the equivalent formula used for the time-bounded variants.

        The result is
        ``always[0, a](left) and (eventually[I](right) and eventually[a, inf](left until right))``
        where ``a`` is the left bound and ``I`` is ``[a, inf]`` when the
        operator is right-unbounded and ``[a, b]`` otherwise.
        """
        if self.right_unbound:
            reach: Node = Eventually(self.right_child, right_unbound=True,
                                     left_time_bound=self.left_time_bound)
        else:
            # Undo the +1 applied in __init__; Eventually re-applies it.
            reach = Eventually(self.right_child, left_time_bound=self.left_time_bound,
                               right_time_bound=self.right_time_bound - 1)
        hold: Node = Globally(self.left_child, left_time_bound=0, right_time_bound=self.left_time_bound)
        shifted_until: Node = Eventually(Until(self.left_child, self.right_child, unbound=True),
                                         left_time_bound=self.left_time_bound, right_unbound=True)
        return And(hold, And(reach, shifted_until))

    @staticmethod
    def _pairwise_time_matrix(z: Tensor) -> Tensor:
        """Expands a signal into the square (time x time) layout used by the unbounded boolean until."""
        z_rep = torch.repeat_interleave(z.unsqueeze(2), z.unsqueeze(2).shape[-1], 2)
        z_tril = torch.tril(z_rep.transpose(2, 3), diagonal=-1)
        z_triu = torch.triu(z_rep)
        return z_tril + z_triu

    def _boolean(self, x: Tensor) -> Tensor:
        if not self.unbound:
            return self._timed_until()._boolean(x)

        z1: Tensor = self.left_child._boolean(x)
        z2: Tensor = self.right_child._boolean(x)
        # Children may return different lengths along axis 2: keep the common prefix.
        size: int = min(z1.size()[2], z2.size()[2])
        z1 = z1[:, :, :size]
        z2 = z2[:, :, :size]

        z1_def = torch.cummin(self._pairwise_time_matrix(z1), dim=3)[0]
        z2_def = self._pairwise_time_matrix(z2)
        # min over the (left, right) pair, then max over the last axis.
        paired = torch.cat([z1_def.unsqueeze(-1), z2_def.unsqueeze(-1)], dim=-1)
        z: Tensor = torch.max(torch.min(paired, dim=-1)[0], dim=-1)[0]
        return z

    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        if not self.unbound:
            return self._timed_until()._quantitative(x, normalize=normalize)

        z1: Tensor = self.left_child._quantitative(x, normalize)
        z2: Tensor = self.right_child._quantitative(x, normalize)
        size: int = min(z1.size()[2], z2.size()[2])
        z1 = z1[:, :, :size]
        z2 = z2[:, :, :size]

        # For every start time t: running minimum of the left robustness from
        # t onwards, paired with the right robustness, min over the pair and
        # max over the remaining time axis.
        z: Tensor = torch.cat([torch.max(torch.min(
            torch.cat([torch.cummin(z1[:, :, t:].unsqueeze(-1), dim=2)[0], z2[:, :, t:].unsqueeze(-1)], dim=-1),
            dim=-1)[0], dim=2, keepdim=True)[0] for t in range(size)], dim=2)
        return z
|
|
|
|
|
|
|
|
|
|
|
import re |
|
|
import json |
|
|
from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
from transformers import PreTrainedTokenizer |
|
|
from transformers.utils import logging |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
def load_pickle(path):
    """Read the file at ``path`` in binary mode and return the unpickled object.

    NOTE(review): ``pickle.load`` can execute arbitrary code while
    deserialising; only use this on files from a trusted source.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)
|
|
|
|
|
def dump_pickle(name, thing):
    """Pickle ``thing`` into the file ``name + '.pickle'`` (opened in binary write mode)."""
    target = name + '.pickle'
    with open(target, 'wb') as handle:
        pickle.dump(thing, handle)
|
|
|
|
|
def from_string_to_formula(st):
    """Parse the textual form of an STL formula back into a tree of nodes.

    The accepted syntax is the whitespace-separated one produced by the
    ``__str__`` methods of the node classes: atoms ``x_<i> <= <thr>`` /
    ``x_<i> >= <thr>``, unary operators ``not ( ... )``,
    ``eventually... ( ... )``, ``always... ( ... )`` and parenthesised
    binary operators ``( <left> and|or|until... <right> )``.

    Parameters
    ----------
    st : str
        Formula string. A string starting with ``(`` is treated as a binary
        operator, anything else as an atom or a unary operator.

    Returns
    -------
    Node
        Root of the reconstructed formula.

    Notes
    -----
    Time intervals of temporal operators are decoded by
    ``set_time_thresholds``, defined elsewhere in this module.
    Structural expectations are checked with ``assert``.
    """
    root_arity = 2 if st.startswith('(') else 1
    st_split = st.split()
    if root_arity <= 1:
        root_op_str = st_split[0]
        if root_op_str.startswith('x'):
            atom_sign = st_split[1] == '<='
            # Take every character after "x_": reading a single character
            # would turn e.g. "x_12" into variable 1.
            root_phi = Atom(var_index=int(st_split[0][2:]), lte=atom_sign, threshold=float(st_split[2]))
            return root_phi
        else:
            assert (root_op_str.startswith('not') or root_op_str.startswith('eventually')
                    or root_op_str.startswith('always'))
            # Drop the operator token and the enclosing "(" ... ")".
            current_st = st_split[2:-1]
            if root_op_str == 'not':
                root_phi = Not(child=from_string_to_formula(' '.join(current_st)))
            elif root_op_str.startswith('eventually'):
                unbound, right_unbound, left_time_bound, right_time_bound = set_time_thresholds(root_op_str)
                root_phi = Eventually(child=from_string_to_formula(' '.join(current_st)), unbound=unbound,
                                      right_unbound=right_unbound, left_time_bound=left_time_bound,
                                      right_time_bound=right_time_bound)
            else:
                unbound, right_unbound, left_time_bound, right_time_bound = set_time_thresholds(root_op_str)
                root_phi = Globally(child=from_string_to_formula(' '.join(current_st)), unbound=unbound,
                                    right_unbound=right_unbound, left_time_bound=left_time_bound,
                                    right_time_bound=right_time_bound)
    else:
        # Strip the outermost parentheses of the binary formula.
        current_st = st_split[1:-1]
        if '(' in current_st:
            # Match every "(" with its ")" (token indices into current_st).
            par_queue = deque()
            par_idx_list = []
            for i, sub in enumerate(current_st):
                if sub == '(':
                    par_queue.append(i)
                elif sub == ')':
                    par_idx_list.append(tuple([par_queue.pop(), i]))

            # Merge nested / adjacent parenthesis ranges: each merged range
            # is one parenthesised child of the root operator.
            children_range = []
            for begin, end in sorted(par_idx_list):
                if children_range and children_range[-1][1] >= begin - 1:
                    children_range[-1][1] = max(children_range[-1][1], end)
                else:
                    children_range.append([begin, end])
            n_children = len(children_range)
            assert (n_children in [1, 2])
            if n_children == 1:
                # Only one parenthesised child: the other one is an atom
                # (three tokens). 0 means the atom is the left child, 1 the right.
                var_child_idx = 1 if children_range[0][0] <= 1 else 0
                # Include a unary operator token preceding the "(" in the child.
                if children_range[0][0] != 0 and current_st[children_range[0][0] - 1][0:2] in ['no', 'ev', 'al']:
                    children_range[0][0] -= 1
                left_child_str = current_st[:3] if var_child_idx == 0 else \
                    current_st[children_range[0][0]:children_range[0][1] + 1]
                right_child_str = current_st[-3:] if var_child_idx == 1 else \
                    current_st[children_range[0][0]:children_range[0][1] + 1]
                root_op_str = current_st[children_range[0][1] + 1] if var_child_idx == 1 else \
                    current_st[children_range[0][0] - 1]
                assert (root_op_str[:2] in ['an', 'or', 'un'])
            else:
                if children_range[0][0] != 0 and current_st[children_range[0][0] - 1][0:2] in ['no', 'ev', 'al']:
                    children_range[0][0] -= 1
                if current_st[children_range[1][0] - 1][0:2] in ['no', 'ev', 'al']:
                    children_range[1][0] -= 1
                # With two parenthesised children the root operator is the
                # token right after the first child.
                root_op_str = current_st[children_range[0][1] + 1]
                assert (root_op_str[:2] in ['an', 'or', 'un'])
                left_child_str = current_st[children_range[0][0]:children_range[0][1] + 1]
                right_child_str = current_st[children_range[1][0]:children_range[1][1] + 1]
        else:
            # No parentheses: both children are atoms of three tokens each.
            left_child_str = current_st[:3]
            right_child_str = current_st[-3:]
            root_op_str = current_st[3]
        left_child_str = ' '.join(left_child_str)
        right_child_str = ' '.join(right_child_str)
        if root_op_str == 'and':
            root_phi = And(left_child=from_string_to_formula(left_child_str),
                           right_child=from_string_to_formula(right_child_str))
        elif root_op_str == 'or':
            root_phi = Or(left_child=from_string_to_formula(left_child_str),
                          right_child=from_string_to_formula(right_child_str))
        else:
            unbound, right_unbound, left_time_bound, right_time_bound = set_time_thresholds(root_op_str)
            root_phi = Until(left_child=from_string_to_formula(left_child_str),
                             right_child=from_string_to_formula(right_child_str),
                             unbound=unbound, right_unbound=right_unbound, left_time_bound=left_time_bound,
                             right_time_bound=right_time_bound)
    return root_phi
|
|
|
|
|
def load_json(path: str) -> Union[Dict, List]:
    """
    Parse the JSON document stored at ``path``.

    Args:
        path (str): Location of the JSON file; it is opened in text read mode.

    Returns:
        Union[Dict, List]: Whatever ``json.load`` produces for the file,
        typically a dictionary or a list.
    """
    with open(path, "r") as stream:
        parsed = json.load(stream)
    return parsed
|
|
|
|
|
|
|
|
|
|
|
class StlGenerator:
    """Random sampler of STL formula trees.

    All random draws go through ``np.random`` (the module-level NumPy
    generator), so seeding with ``np.random.seed`` makes sampling
    reproducible.
    """

    def __init__(
        self,
        leaf_prob: float = 0.3,
        inner_node_prob: list = None,
        threshold_mean: float = 0.0,
        threshold_sd: float = 1.0,
        unbound_prob: float = 0.1,
        right_unbound_prob: float = 0.2,
        time_bound_max_range: float = 20,
        adaptive_unbound_temporal_ops: bool = True,
        max_timespan: int = 100,
    ):
        """
        leaf_prob
            probability of generating a leaf (always zero for root)
        node_types = ["not", "and", "or", "always", "eventually", "until"]
            Inner node types
        inner_node_prob
            probability vector for the different types of internal nodes,
            in the order of node_types
        threshold_mean
        threshold_sd
            mean and std for the normal distribution of the thresholds of atoms
        unbound_prob
            probability of a temporal operator to have a time bound of the type [0,infty]
        right_unbound_prob
            probability, once the operator is not unbound, of an interval of
            the type [a,infty]
        time_bound_max_range
            maximum value of time span of a temporal operator (i.e. max value of t in [0,t])
        adaptive_unbound_temporal_ops
            if true, unbounded temporal operators are computed from current point to the end of the signal, otherwise
            they are evaluated only at time zero.
        max_timespan
            maximum time depth of a formula.
        """
        if inner_node_prob is None:
            inner_node_prob = [0.166, 0.166, 0.166, 0.17, 0.166, 0.166]

        self.leaf_prob = leaf_prob
        self.inner_node_prob = inner_node_prob
        self.threshold_mean = threshold_mean
        self.threshold_sd = threshold_sd
        self.unbound_prob = unbound_prob
        self.right_unbound_prob = right_unbound_prob
        self.time_bound_max_range = time_bound_max_range
        self.adaptive_unbound_temporal_ops = adaptive_unbound_temporal_ops
        self.node_types = ["not", "and", "or", "always", "eventually", "until"]
        self.max_timespan = max_timespan

    def sample(self, nvars):
        """
        Samples a random formula with distribution defined in class instance parameters

        Parameters
        ----------
        nvars : number of variables of input signals
            how many variables the formula is expected to consider.

        Returns
        -------
        Node
            A random formula; the root is always an internal node.
        """
        return self._sample_internal_node(nvars)

    def bag_sample(self, bag_size, nvars):
        """
        Samples a bag of bag_size formulae

        Parameters
        ----------
        bag_size : INT
            number of formulae.
        nvars : INT
            number of vars in formulae.

        Returns
        -------
        a list of formulae.
        """
        return [self.sample(nvars) for _ in range(bag_size)]

    def _sample_internal_node(self, nvars):
        """Samples an operator node; children are re-drawn until the time depth is below max_timespan."""
        node = None
        # The operator type is drawn once; only the children (and temporal
        # parameters) are re-sampled when the depth check fails.
        nodetype = np.random.choice(self.node_types, p=self.inner_node_prob)
        while True:
            if nodetype == "not":
                n = self._sample_node(nvars)
                node = Not(n)
            elif nodetype == "and":
                n1 = self._sample_node(nvars)
                n2 = self._sample_node(nvars)
                node = And(n1, n2)
            elif nodetype == "or":
                n1 = self._sample_node(nvars)
                n2 = self._sample_node(nvars)
                node = Or(n1, n2)
            elif nodetype == "always":
                n = self._sample_node(nvars)
                unbound, right_unbound, left_time_bound, right_time_bound = self._get_temporal_parameters()
                node = Globally(
                    n, unbound, right_unbound, left_time_bound, right_time_bound, self.adaptive_unbound_temporal_ops
                )
            elif nodetype == "eventually":
                n = self._sample_node(nvars)
                unbound, right_unbound, left_time_bound, right_time_bound = self._get_temporal_parameters()
                node = Eventually(
                    n, unbound, right_unbound, left_time_bound, right_time_bound, self.adaptive_unbound_temporal_ops
                )
            elif nodetype == "until":
                n1 = self._sample_node(nvars)
                n2 = self._sample_node(nvars)
                unbound, right_unbound, left_time_bound, right_time_bound = self._get_temporal_parameters()
                node = Until(
                    n1, n2, unbound, right_unbound, left_time_bound, right_time_bound
                )

            if (node is not None) and (node.time_depth() < self.max_timespan):
                return node

    def _sample_node(self, nvars):
        """Samples an atom with probability leaf_prob, otherwise an internal node."""
        if np.random.rand() < self.leaf_prob:
            var, thr, lte = self._get_atom(nvars)
            return Atom(var, thr, lte)
        else:
            return self._sample_internal_node(nvars)

    def _get_temporal_parameters(self):
        """Returns (unbound, right_unbound, left_time_bound, right_time_bound) for a temporal operator."""
        if np.random.rand() < self.unbound_prob:
            return True, False, 0, 0
        elif np.random.rand() < self.right_unbound_prob:
            return False, True, np.random.randint(self.time_bound_max_range), 1
        else:
            left_bound = np.random.randint(self.time_bound_max_range)
            # randint excludes its upper end, so the right bound lies in
            # [left_bound + 1, time_bound_max_range].
            return False, False, left_bound, np.random.randint(left_bound, self.time_bound_max_range) + 1

    def _get_atom(self, nvars):
        """Returns (variable index, threshold, lte flag) for a random atom."""
        variable = np.random.randint(nvars)
        lte = np.random.rand() > 0.5
        threshold = np.random.normal(self.threshold_mean, self.threshold_sd)
        return variable, threshold, lte
|
|
|
|
|
|
|
|
|
|
|
class Measure:
    """Base class for trajectory measures; ``sample`` is a placeholder here."""

    def sample(self, samples=100000, varn=2, points=100):
        """Placeholder sampler: does nothing and returns None."""
        return None
|
|
|
|
|
class BaseMeasure(Measure):
    """Sampler of random piecewise trajectories built from random increments and random signs."""

    def __init__(
        self, mu0=0.0, sigma0=1.0, mu1=0.0, sigma1=1.0, q=0.1, q0=0.5, device="cpu"
    ):
        """
        Parameters
        ----------
        mu0 : mean of normal distribution of initial state, optional
            The default is 0.0.
        sigma0 : standard deviation of normal distribution of initial state, optional
            The default is 1.0.
        mu1 : DOUBLE, optional
            mean of normal distribution of total variation. The default is 0.0.
        sigma1 : standard deviation of normal distribution of total variation, optional
            The default is 1.0.
        q : DOUBLE, optional
            probability of change of sign in derivative. The default is 0.1.
        q0 : DOUBLE, optional
            probability of initial sign of derivative. The default is 0.5.
        device : 'cpu' or 'cuda', optional
            device on which to run the algorithm. The default is 'cpu'.

        Returns
        -------
        None.
        """
        self.mu0, self.sigma0 = mu0, sigma0
        self.mu1, self.sigma1 = mu1, sigma1
        self.q, self.q0 = q, q0
        self.device = device

    def sample(self, samples=100000, varn=2, points=100):
        """
        Samples a set of trajectories from the basic measure space, with
        the parameters stored on the instance.

        Parameters
        ----------
        points : INT, optional
            number of points per trajectory, including initial one. The default is 100.
        samples : INT, optional
            number of trajectories. The default is 100000.
        varn : INT, optional
            number of variables per trajectory. The default is 2.

        Returns
        -------
        signal : samples x varn x points pytorch tensor
            The sampled signals.

        Raises
        ------
        RuntimeError
            If the device is "cuda" but CUDA is not available.
        """
        cuda_requested = self.device == "cuda"
        if cuda_requested and not torch.cuda.is_available():
            raise RuntimeError("GPU card or CUDA library not available!")

        # Uniform draws, pinned to 0 and 1 at the two ends and sorted along
        # time: consecutive differences are then non-negative increments
        # that sum to one.
        traj = torch.rand(samples, varn, points, device=self.device)
        traj[:, :, 0] = 0.0
        traj[:, :, -1] = 1.0
        traj, _ = torch.sort(traj, 2)
        increments = traj[:, :, 1:] - traj[:, :, :-1]
        traj[:, :, 1:] = increments

        # Initial state ~ N(mu0, sigma0).
        # NOTE(review): this randn call passes no device argument, unlike
        # the other draws in this method — confirm this is intended.
        start = self.mu0 + self.sigma0 * torch.randn(traj[:, :, 0].size())
        traj[:, :, 0] = start

        # +1 / -1 factors (+1 with probability 1 - q); the first column is
        # the initial sign, drawn with probability q0. The cumulative
        # product turns the factors into the sign of each increment.
        signs = (1 - self.q) * torch.ones(samples, varn, points, device=self.device)
        signs = 2 * torch.bernoulli(signs) - 1
        signs[:, :, 0] = self.q0
        signs[:, :, 0] = 2 * torch.bernoulli(signs[:, :, 0]) - 1
        signs = torch.cumprod(signs, 2)

        # Total variation: square of a N(mu1, sigma1) draw, one value per
        # (trajectory, variable) pair.
        total_variation = torch.pow(
            self.mu1 + self.sigma1 * torch.randn(samples, varn, 1, device=self.device),
            2,
        )
        scale = signs * total_variation
        # The initial state must not be rescaled.
        scale[:, :, 0] = 1.0

        # Integrate the signed, scaled increments.
        traj = traj * scale
        return torch.cumsum(traj, 2)
|
|
|
|
|
|
|
|
|
|
|
# Type alias for a real-valued scalar (float or int).
realnum = Union[float, int]
|
|
|
|
|
class StlKernel:
    """Kernel between STL formulae based on their robustness on sampled signals.

    Every formula is evaluated on the same set of signals. The kernel of two
    formulae is the average product of their robustness values, optionally
    normalised by the self-kernels and passed through an exponential.

    Formula objects are expected to expose
    ``quantitative(signals, evaluate_at_all_times=...)`` and
    ``boolean(signals, evaluate_at_all_times=...)``.
    """

    def __init__(
        self,
        measure,
        normalize=True,
        exp_kernel=True,
        sigma2=0.2,
        integrate_time=False,
        samples=100000,
        varn=2,
        points=100,
        boolean=False,
        signals=None,
    ):
        """
        Parameters
        ----------
        measure : object
            Trajectory measure; must expose ``device`` and, when ``signals``
            is not given, ``sample(points=..., samples=..., varn=...)``.
        normalize : bool, optional
            Divide the kernel by the square root of the product of the
            self-kernels. The default is True.
        exp_kernel : bool, optional
            Apply the exponential transform. The default is True.
        sigma2 : float, optional
            Bandwidth used by the exponential transform. The default is 0.2.
        integrate_time : bool, optional
            Evaluate robustness at all time points instead of only one value
            per signal. The default is False.
        samples, varn, points : int, optional
            Number of signals, variables per signal and points per signal.
        boolean : bool, optional
            Use boolean satisfaction (mapped to +1/-1) instead of quantitative
            robustness. The default is False.
        signals : torch.Tensor, optional
            Pre-sampled signals; when None they are drawn from ``measure``.
        """
        self.traj_measure = measure
        self.exp_kernel = exp_kernel
        self.normalize = normalize
        self.sigma2 = sigma2
        self.samples = samples
        self.varn = varn
        self.points = points
        self.integrate_time = integrate_time
        # Reuse caller-supplied signals when given; otherwise sample them once.
        if signals is not None:
            self.signals = signals
        else:
            self.signals = measure.sample(points=points, samples=samples, varn=varn)
        self.boolean = boolean

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #
    def compute(self, phi1, phi2):
        """Return the kernel value between two formulae."""
        return self.compute_one_one(phi1, phi2)

    def compute_one_one(self, phi1, phi2):
        """Return the kernel value between two formulae as a 0-dim tensor."""
        ker = self.compute_bag_bag([phi1], [phi2])
        return ker[0, 0]

    def compute_bag(self, phis, return_robustness=True):
        """Kernel matrix of a bag of formulae against itself.

        Returns the matrix (on CPU), followed by ``(rhos, self_kernels,
        lengths)`` when ``return_robustness`` is True. ``lengths`` is None
        unless time integration is enabled.
        """
        rhos, selfk, len0 = self._robustness(phis)
        kernel_matrix = self._kernel(rhos, rhos, selfk, selfk, len0, len0)
        if return_robustness:
            return kernel_matrix.cpu(), rhos, selfk, len0
        return kernel_matrix.cpu()

    def compute_one_bag(self, phi1, phis2, return_robustness=False):
        """Kernel row of a single formula against a bag of formulae."""
        return self.compute_bag_bag([phi1], phis2, return_robustness)

    def compute_bag_bag(self, phis1, phis2, return_robustness=False):
        """Kernel matrix between two bags of formulae.

        Returns the matrix (on CPU), followed by ``(rhos1, rhos2, selfk1,
        selfk2, len1, len2)`` when ``return_robustness`` is True.
        """
        rhos1, selfk1, len1 = self._robustness(phis1)
        rhos2, selfk2, len2 = self._robustness(phis2)
        kernel_matrix = self._kernel(rhos1, rhos2, selfk1, selfk2, len1, len2)
        if return_robustness:
            return kernel_matrix.cpu(), rhos1, rhos2, selfk1, selfk2, len1, len2
        return kernel_matrix.cpu()

    def compute_one_from_robustness(self, phi, rhos, rho_self, lengths=None, return_robustness=False):
        """Kernel row of one formula against precomputed robustness data."""
        return self.compute_bag_from_robustness([phi], rhos, rho_self, lengths, return_robustness)

    def compute_bag_from_robustness(self, phis, rhos, rho_self, lengths=None, return_robustness=False):
        """Kernel matrix of new formulae against precomputed robustness data.

        ``rhos``, ``rho_self`` and ``lengths`` are the values previously
        returned for the reference bag. Returns the matrix (on CPU), followed
        by the robustness data of ``phis`` when ``return_robustness`` is True.
        """
        rhos1, selfk1, len1 = self._robustness(phis)
        kernel_matrix = self._kernel(rhos1, rhos, selfk1, rho_self, len1, lengths)
        if return_robustness:
            return kernel_matrix.cpu(), rhos1, selfk1, len1
        return kernel_matrix.cpu()

    # ------------------------------------------------------------------ #
    # Internals
    # ------------------------------------------------------------------ #
    def _robustness(self, phis):
        """Return ``(rhos, self_kernels, lengths)`` for the configured mode."""
        if self.integrate_time:
            return self._compute_robustness_time(phis)
        rhos, selfk = self._compute_robustness_no_time(phis)
        return rhos, selfk, None

    def _kernel(self, rhos1, rhos2, selfk1, selfk2, len1, len2):
        """Dispatch to the time-integrated or plain kernel computation."""
        if self.integrate_time:
            return self._compute_kernel_time(rhos1, rhos2, selfk1, selfk2, len1, len2)
        return self._compute_kernel_no_time(rhos1, rhos2, selfk1, selfk2)

    def _evaluate(self, phi, evaluate_at_all_times):
        """Evaluate one formula on the stored signals (boolean as +1/-1)."""
        if self.boolean:
            rho = phi.boolean(self.signals, evaluate_at_all_times=evaluate_at_all_times).float()
            # Map "not satisfied" from 0 to -1 so products carry a sign.
            rho[rho == 0.0] = -1.0
            return rho
        return phi.quantitative(self.signals, evaluate_at_all_times=evaluate_at_all_times)

    def _compute_robustness_time(self, phis):
        """Robustness of each formula at every time point.

        Returns ``(rhos, self_kernels, lengths)``: ``rhos`` is a zero-padded
        CPU tensor of shape ``(k, samples, points)``, ``lengths`` holds the
        number of time points actually produced by each formula and
        ``self_kernels`` (shape ``(k, 1)``) the mean squared robustness.
        """
        n = self.samples
        p = self.points
        k = len(phis)
        rhos = torch.zeros((k, n, p), device="cpu")
        lengths = torch.zeros(k)
        self_kernels = torch.zeros((k, 1))
        for i, phi in enumerate(phis):
            rho = self._evaluate(phi, evaluate_at_all_times=True)
            # A formula may yield fewer than `points` time steps; the rest of
            # its row stays zero.
            actual_p = rho.size()[2]
            rho = rho.reshape(n, actual_p).cpu()
            rhos[i, :, :actual_p] = rho
            lengths[i] = actual_p
            self_kernels[i] = torch.tensordot(
                rho.reshape(1, n, -1), rho.reshape(1, n, -1), dims=[[1, 2], [1, 2]]
            ) / (actual_p * n)
        return rhos, self_kernels, lengths

    def _compute_robustness_no_time(self, phis):
        """Robustness of each formula with one value per signal.

        Returns ``(rhos, self_kernels)`` of shapes ``(k, samples)`` and
        ``(k, 1)``, allocated on the measure's device.
        """
        n = self.samples
        k = len(phis)
        rhos = torch.zeros((k, n), device=self.traj_measure.device)
        self_kernels = torch.zeros((k, 1), device=self.traj_measure.device)
        for i, phi in enumerate(phis):
            rho = self._evaluate(phi, evaluate_at_all_times=False)
            self_kernels[i] = rho.dot(rho) / n
            rhos[i, :] = rho
        return rhos, self_kernels

    def _compute_kernel_time(self, rhos1, rhos2, selfk1, selfk2, len1, len2):
        """Kernel matrix from time-resolved robustness tensors."""
        kernel_matrix = torch.tensordot(rhos1, rhos2, [[1, 2], [1, 2]])
        length_normalizer = self._compute_trajectory_length_normalizer(len1, len2)
        kernel_matrix = kernel_matrix * length_normalizer / self.samples
        return self._postprocess(kernel_matrix, selfk1, selfk2)

    def _compute_kernel_no_time(self, rhos1, rhos2, selfk1, selfk2):
        """Kernel matrix from per-signal robustness vectors."""
        kernel_matrix = torch.tensordot(rhos1, rhos2, [[1], [1]])
        kernel_matrix = kernel_matrix / self.samples
        return self._postprocess(kernel_matrix, selfk1, selfk2)

    def _postprocess(self, kernel_matrix, selfk1, selfk2):
        """Apply the optional normalisation and exponential transform."""
        if self.normalize:
            kernel_matrix = self._normalize(kernel_matrix, selfk1, selfk2)
        if self.exp_kernel:
            kernel_matrix = self._exponentiate(kernel_matrix, selfk1, selfk2)
        return kernel_matrix

    @staticmethod
    def _normalize(kernel_matrix, selfk1, selfk2):
        """Divide each entry by the square root of the two self-kernels' product."""
        normalize = torch.sqrt(torch.matmul(selfk1, torch.transpose(selfk2, 0, 1)))
        kernel_matrix = kernel_matrix / normalize
        return kernel_matrix

    def _exponentiate(self, kernel_matrix, selfk1, selfk2, sigma2=None):
        """Return ``exp(-(selfk - 2 * kernel_matrix) / (2 * sigma2))``.

        ``sigma2`` defaults to ``self.sigma2``.
        """
        if sigma2 is None:
            sigma2 = self.sigma2
        if self.normalize:
            # After normalisation both self-kernels equal 1, so their sum is 2.
            selfk = 2.0
        else:
            k1 = selfk1.size()[0]
            k2 = selfk2.size()[0]
            # NOTE(review): the self-kernels are squared here; kept exactly as
            # in the original implementation — confirm this is intended.
            selfk = (selfk1 * selfk1).repeat(1, k2) + torch.transpose(
                selfk2 * selfk2, 0, 1
            ).repeat(k1, 1)
        return torch.exp(-(selfk - 2 * kernel_matrix) / (2 * sigma2))

    @staticmethod
    def _compute_trajectory_length_normalizer(len1, len2):
        """Return the ``(k1, k2)`` matrix of ``1 / min(len1[i], len2[j])``."""
        k1 = len1.size()[0]
        k2 = len2.size()[0]
        y1 = len1.reshape(-1, 1)
        y1 = y1.repeat(1, k2)
        y2 = len2.repeat(k1, 1)
        return 1.0 / torch.min(y1, y2)
|
|
|
|
|
class GramMatrix:
    """Gram matrix of a kernel over a bag of formulae.

    The bag is either given explicitly (``formulae``) or built by drawing
    formulae from ``sampler`` (``sample=True``). When ``store_robustness`` is
    True the robustness data of the bag is cached so that later kernel
    vectors can be computed without re-evaluating the bag.
    """

    def __init__(self, kernel, formulae, store_robustness=True, sample=False, sampler=None, bag_size=None):
        """
        Parameters
        ----------
        kernel : object
            Kernel exposing the ``compute_*`` methods used below.
        formulae : list
            Formulae of the bag (replaced by the sampled ones if ``sample``).
        store_robustness : bool, optional
            Cache robustness, self-kernels and lengths. The default is True.
        sample : bool, optional
            Build the bag by sampling. The default is False.
        sampler : object, optional
            Formula sampler exposing ``sample(nvars=...)``.
        bag_size : int, optional
            Dimension of the matrix; defaults to ``len(formulae)``.
        """
        self.kernel = kernel
        self.formulae_list = formulae
        self.store_robustness = store_robustness
        if bag_size:
            self.dim = int(bag_size)
        else:
            self.dim = len(self.formulae_list)
        self.sample = sample
        if self.sample:
            # Similarity threshold used when accepting sampled formulae.
            self.t = 0.99 if self.kernel.boolean else 0.85
        self.sampler = sampler
        self._compute_gram_matrix()

    def _empty_buffers(self, rows):
        """Allocate zeroed ``(rhos, lengths, self_kernels)`` for ``rows`` formulae."""
        ker = self.kernel
        if ker.integrate_time:
            rhos = torch.zeros((rows, ker.samples, ker.points), device=ker.traj_measure.device)
            lengths = torch.zeros(rows)
        else:
            rhos = torch.zeros((rows, ker.samples), device=ker.traj_measure.device)
            # A NumPy array is used here, as in the original code (the kernel
            # returns None as length in this mode).
            lengths = np.zeros(rows)
        self_kernels = torch.zeros((rows, 1), device=ker.traj_measure.device)
        return rhos, lengths, self_kernels

    def _compute_gram_matrix(self):
        """Fill ``gram`` and the cached robustness attributes."""
        if self.sample:
            self._build_by_sampling()
            return

        if self.store_robustness:
            matrix, rhos, selfk, len0 = self.kernel.compute_bag(
                self.formulae_list, return_robustness=True
            )
            self.gram = matrix
            self.robustness = rhos
            self.self_kernels = selfk
            self.robustness_lengths = len0
            return

        self.gram = self.kernel.compute_bag(
            self.formulae_list, return_robustness=False
        )
        self.robustness = None
        self.self_kernels = None
        self.robustness_lengths = None

    def _build_by_sampling(self):
        """Grow the bag one sampled formula at a time.

        A candidate is accepted only if fewer than three entries of its
        kernel row reach the threshold ``self.t``.
        """
        size = self.dim
        gram = torch.zeros(size, size)
        rhos, lengths, kernels = self._empty_buffers(size)

        accepted = [self.sampler.sample(nvars=self.kernel.varn)]
        gram[0, :1], rhos[0], kernels[0, :], lengths[0] = self.kernel.compute_bag(
            accepted, return_robustness=True)

        while len(accepted) < size:
            row = len(accepted)
            candidate = self.sampler.sample(nvars=self.kernel.varn)
            gram[row, :row], rhos[row], kernels[row, :], lengths[row] = \
                self.kernel.compute_one_from_robustness(
                    candidate, rhos[:row, :], kernels[:row, :], lengths[:row], return_robustness=True)
            if torch.sum(gram[row, :row + 1] >= self.t) >= 3:
                # Rejected: row `row` is overwritten by the next candidate.
                continue
            accepted.append(candidate)
            gram[:row, row] = gram[row, :row]
            gram[row, row] = kernels[row, :]

        keep = self.store_robustness
        self.formulae_list = accepted
        self.gram = gram.cpu()
        self.robustness = rhos if keep else None
        self.self_kernels = kernels if keep else None
        self.robustness_lengths = lengths if keep else None

    def compute_kernel_vector(self, phi):
        """Kernel values of ``phi`` against every formula of the bag."""
        if not self.store_robustness:
            return self.kernel.compute_one_bag(phi, self.formulae_list)
        return self.kernel.compute_one_from_robustness(
            phi, self.robustness, self.self_kernels, self.robustness_lengths
        )

    def compute_bag_kernel_vector(self, phis, generate_phis=False, bag_size=None):
        """Kernel values of several formulae against the bag.

        With ``generate_phis=False`` the given ``phis`` are evaluated and the
        kernel matrix is returned. With ``generate_phis=True``, ``bag_size``
        formulae are drawn from the sampler (skipping those whose robustness
        is strictly positive everywhere or strictly negative everywhere) and
        ``(formulae, kernel_matrix)`` is returned.
        """
        if not generate_phis:
            if self.store_robustness:
                return self.kernel.compute_bag_from_robustness(
                    phis, self.robustness, self.self_kernels, self.robustness_lengths
                )
            return self.kernel.compute_bag_bag(phis, self.formulae_list)

        gram_test = torch.zeros(bag_size, self.dim)
        rhos_test, lengths_test, kernels_test = self._empty_buffers(bag_size)
        phi_test = []
        while len(phi_test) < bag_size:
            row = len(phi_test)
            candidate = self.sampler.sample(nvars=self.kernel.varn)
            if self.store_robustness:
                gram_test[row, :], rhos_test[row], kernels_test[row, :], lengths_test[row] = \
                    self.kernel.compute_one_from_robustness(candidate, self.robustness, self.self_kernels,
                                                            self.robustness_lengths, return_robustness=True)
            else:
                gram_test[row, :], rhos_test[row], _, kernels_test[row, :], _, lengths_test[row], _ = \
                    self.kernel.compute_one_bag(candidate, self.formulae_list, return_robustness=True)
            sign_is_constant = (rhos_test[row] > 0).all() or (rhos_test[row] < 0).all()
            if not sign_is_constant:
                phi_test.append(candidate)
        return phi_test, gram_test.cpu()

    def invert_regularized(self, alpha):
        """Return the inverse of ``gram + |10 ** alpha| * I``."""
        ridge = abs(10 ** alpha) * torch.eye(self.dim)
        return torch.inverse(self.gram + ridge)
|
|
|
|
|
|
|
|
|
|
|
def anchorGeneration(diff_init = False,
                     embed_dim: int = 30,
                     n_vars: int = 3,
                     leaf_prob: float = 0.4,
                     cosine_similarity_threshold: float = 0.8
                     ) -> str:
    """Generate a set of anchor formulae and store it with ``dump_pickle``.

    Args:
        diff_init: If True, grow the set incrementally and filter candidates
            by the cosine similarity of their robustness vectors; otherwise
            draw ``embed_dim`` formulae in a single batch.
        embed_dim: Number of anchor formulae to produce.
        n_vars: Number of signal variables used by the sampled formulae.
        leaf_prob: Value passed to ``StlGenerator``.
        cosine_similarity_threshold: Similarity above which an index is
            discarded in the ``diff_init`` mode.

    Returns:
        str: The name handed to ``dump_pickle`` (no extension is added here).
    """
    generator = StlGenerator(leaf_prob)

    if not diff_init:
        anchor_set = generator.bag_sample(bag_size=embed_dim, nvars=n_vars)
    else:
        diff_anchor_set = [generator.sample(nvars=n_vars)]

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        mu = BaseMeasure(device=device)
        # One shared batch of signals on which every formula is evaluated.
        signals = mu.sample(samples=10000, varn=n_vars)

        anchor_rob_vectors = torch.cat(
            [phi.quantitative(signals, normalize=True).unsqueeze(0) for phi in diff_anchor_set], 0)

        while len(diff_anchor_set) < embed_dim:
            missing = embed_dim - len(diff_anchor_set)
            candidate_anchors = generator.bag_sample(missing, nvars = n_vars)
            candidate_robs = torch.cat(
                [phi.quantitative(signals, normalize=True).unsqueeze(0) for phi in candidate_anchors], 0)

            # Rows index candidates, columns index the anchors collected so far.
            cos_simil = torch.tril(normalize(candidate_robs) @ normalize(anchor_rob_vectors).t(), diagonal=-1)

            # NOTE(review): the indices gathered below are *column* (anchor)
            # indices, yet they are removed from the *candidate* index set,
            # and `tril` hides every pair with column >= row. Behaviour is
            # kept as-is; confirm whether this filtering is what is intended.
            too_similar = torch.nonzero(cos_simil > cosine_similarity_threshold)[:, 1].tolist()
            keep_idx = list(set(range(len(candidate_anchors))).difference(set(too_similar)))

            diff_anchor_set += [copy.deepcopy(candidate_anchors[i]) for i in keep_idx]

            keep_idx_tensor = torch.tensor(keep_idx, device=candidate_robs.device)
            selected_robs = torch.index_select(candidate_robs, 0, keep_idx_tensor)
            anchor_rob_vectors = torch.cat([anchor_rob_vectors, copy.deepcopy(selected_robs)], dim=0)

        anchor_set = diff_anchor_set[:embed_dim]

    # The same file name is used for both modes.
    filename = f'anchor_set_no_diff_{embed_dim}_dim'
    dump_pickle(filename, anchor_set)
    return filename
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
A custom tokenizer class that extends `PreTrainedTokenizer` to handle a specific vocabulary and tokenization process. |
|
|
This tokenizer can load a vocabulary from a JSON file, tokenize text, convert tokens to IDs, |
|
|
and handle padding and special tokens. |
|
|
""" |
|
|
|
|
|
def __init__(self, vocab_path: str, unk_token: str = "unk", pad_token: str = "pad",
             bos_token: str = "/s", eos_token: str = "s", model_max_length = 512, *args, **kwargs):
    """
    Set up the tokenizer from a vocabulary file and its special tokens.

    Args:
        vocab_path (str): Path of the JSON vocabulary, read with `load_json`.
        unk_token (str, optional): Token used for unknown input. Defaults to "unk".
        pad_token (str, optional): Token used for padding. Defaults to "pad".
        bos_token (str, optional): Beginning-of-sequence token. Defaults to "/s".
        eos_token (str, optional): End-of-sequence token. Defaults to "s".
        model_max_length (int, optional): Maximum sequence length. Defaults to 512.
        *args, **kwargs: Forwarded unchanged to the parent initializer.
    """
    special_tokens = {
        "unk_token": unk_token,
        "pad_token": pad_token,
        "bos_token": bos_token,
        "eos_token": eos_token,
    }

    # The vocabulary and the special tokens are set before the parent
    # initializer runs, exactly as in the original ordering.
    self.vocab = load_json(vocab_path)
    for attribute, token in special_tokens.items():
        setattr(self, attribute, token)
    self.model_max_length = model_max_length
    # Reverse mapping (id -> token) used when decoding.
    self.id_to_token = {token_id: token for token, token_id in self.vocab.items()}

    super().__init__(*args, **special_tokens, model_max_length=model_max_length, **kwargs)
|
|
|
|
|
@property
def vocab_size(self) -> int:
    """
    Number of entries in the vocabulary.

    Returns:
        int: The length of the token-to-id mapping `self.vocab`.
    """
    vocabulary = self.vocab
    return len(vocabulary)
|
|
|
|
|
def prepad_sequence(self, sequence, space_token = ' ', new_space_token = '@', undo = False):
    """
    Swap the space marker of a sequence.

    Args:
        sequence (str): The input sequence.
        space_token (str): The character(s) regarded as a space. Defaults to ' '.
        new_space_token (str): The placeholder written instead of a space. Defaults to '@'.
        undo (bool): If True, turn placeholders back into spaces; otherwise
            (the default) turn spaces into placeholders.

    Returns:
        str: The sequence with every occurrence replaced.
    """
    if undo:
        old, new = new_space_token, space_token
    else:
        old, new = space_token, new_space_token
    return sequence.replace(old, new)
|
|
|
|
|
def add_bos_eos(self, sequence: str) -> str:
    """
    Wrap a sequence with the BOS token at the start and the EOS token at the end.

    Args:
        sequence (str): The input sequence.

    Returns:
        str: The BOS token, the sequence and the EOS token, separated by single spaces.
    """
    return "{} {} {}".format(self.bos_token, sequence, self.eos_token)
|
|
|
|
|
def tokenize(self, text: str) -> List[str]:
    """
    Split the input text into vocabulary tokens by greedy longest match.

    The text is first wrapped with the BOS/EOS tokens and its spaces are
    replaced by the space placeholder. Then, at every position, the longest
    substring found in the vocabulary is emitted; if none matches, the
    unknown token is emitted and a single character is skipped.

    Args:
        text (str): The input text to be tokenized.

    Returns:
        List[str]: The tokens representing the text.
    """
    text = self.prepad_sequence(self.add_bos_eos(text))

    # No match can be longer than the longest vocabulary entry, so the search
    # window is capped instead of trying every suffix of the remaining text.
    longest_entry = max(map(len, self.vocab), default=0)

    tokens = []
    position, end = 0, len(text)
    while position < end:
        window_end = min(end, position + longest_entry)
        for stop in range(window_end, position, -1):
            piece = text[position:stop]
            if piece in self.vocab:
                tokens.append(piece)
                position = stop
                break
        else:
            # Nothing in the vocabulary starts here: emit the unknown token
            # and move on by one character.
            tokens.append(self.unk_token)
            position += 1
    return tokens
|
|
|
|
|
def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
    """
    Convert a token, or a list of tokens, into vocabulary ids.

    A single string is treated as ONE token and yields one id (it is not
    iterated character by character). Tokens missing from the vocabulary map
    to the id of the unknown token.

    Args:
        tokens (Union[str, List[str]]): One token or a list of tokens.

    Returns:
        Union[int, List[int]]: One id for a single token, otherwise a list of ids.

    Raises:
        KeyError: If the unknown token itself is not in the vocabulary.
    """
    vocab = self.vocab
    if isinstance(tokens, str):
        return vocab.get(tokens, vocab[self.unk_token])
    return [vocab.get(token, vocab[self.unk_token]) for token in tokens]
|
|
|
|
|
def convert_ids_to_tokens(self, ids: Union[int, List[int]]) -> Union[str, List[str]]:
    """
    Convert a token id, or a sequence of ids, into tokens.

    Ids may be Python ints, NumPy integers or elements of an integer tensor.
    Ids without a vocabulary entry map to the unknown token.

    Args:
        ids (Union[int, List[int]]): One id or an iterable of ids.

    Returns:
        Union[str, List[str]]: One token for a single id, otherwise a list of tokens.
    """
    def as_key(raw):
        # Tensor elements hash by identity and would never match the integer
        # keys of `id_to_token`; `__index__` turns any integer-like value into
        # a plain int and leaves everything else untouched.
        try:
            return raw.__index__()
        except (AttributeError, TypeError):
            return raw

    if isinstance(ids, int):
        return self.id_to_token.get(ids, self.unk_token)
    return [self.id_to_token.get(as_key(i), self.unk_token) for i in ids]
|
|
|
|
|
def encode(self, sequence: str) -> List[int]:
    """
    Turn a string into a list of token ids.

    The sequence is split with `tokenize` and the resulting tokens are mapped
    to ids with `convert_tokens_to_ids`.

    Args:
        sequence (str): The text to encode.

    Returns:
        List[int]: The ids of the tokens of the sequence.
    """
    return self.convert_tokens_to_ids(self.tokenize(sequence))
|
|
|
|
|
def postpad_sequence(self, sequence, pad_token_id):
    """
    Pad a sequence of ids, in place, with the padding id.

    Padding is appended until the sequence holds `self.model_max_length - 1`
    elements; sequences that are already that long are left untouched.

    Args:
        sequence (list): The ids to pad; modified in place.
        pad_token_id: The id appended as padding.

    Returns:
        list: The same `sequence` object.
    """
    missing = self.model_max_length - len(sequence) - 1
    if missing <= 0:
        return sequence
    sequence.extend(pad_token_id for _ in range(missing))
    return sequence
|
|
|
|
|
def decode(self, token_ids: List[int], skip_special_tokens: bool = False, **kwargs) -> str:
    """
    Decode a list of token ids into a string.

    The ids are converted to tokens, the tokens are concatenated and the
    space placeholder is turned back into spaces.

    Args:
        token_ids (List[int]): The token ids to decode.
        skip_special_tokens (bool, optional): If True, drop the unknown,
            padding, BOS and EOS tokens before joining. Defaults to False.
        **kwargs: Accepted and ignored, so that generic callers which pass
            extra decoding options (e.g. `clean_up_tokenization_spaces`) do
            not fail.

    Returns:
        str: The decoded string.
    """
    tokens = self.convert_ids_to_tokens(token_ids)
    if skip_special_tokens:
        special = {self.unk_token, self.pad_token, self.bos_token, self.eos_token}
        tokens = [token for token in tokens if token not in special]
    return self.prepad_sequence("".join(tokens), undo=True)
|
|
|
|
|
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
    """
    Write the tokenizer's vocabulary to a JSON file.

    The file is named `vocab.json`, optionally preceded by
    `<filename_prefix>-`, and is created inside `save_directory`.

    Args:
        save_directory (str): The directory where the vocabulary file is written.
        filename_prefix (Optional[str]): An optional prefix for the file name.

    Returns:
        Tuple[str]: A one-element tuple holding the path of the written file.
    """
    prefix = f"{filename_prefix}-" if filename_prefix else ""
    target_path = f"{save_directory}/{prefix}vocab.json"
    with open(target_path, "w", encoding="utf-8") as handle:
        json.dump(self.vocab, handle, indent=2, ensure_ascii=False)
    return (target_path,)
|
|
|
|
|
def get_vocab(self) -> dict:
    """
    Give access to the vocabulary used by the tokenizer.

    Returns:
        dict: The token-to-id mapping itself (not a copy).
    """
    vocabulary = self.vocab
    return vocabulary
|
|
|
|
|
class STLSinusoidalPositionalEmbedding(nn.Embedding):
    """Fixed (non-trainable) sinusoidal positional embeddings."""

    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
        # `padding_idx` is accepted for signature compatibility but not used.
        super().__init__(num_positions, embedding_dim)
        self.weight = self._init_weight(self.weight)

    @staticmethod
    def _init_weight(out: nn.Parameter) -> nn.Parameter:
        """
        Fill `out` in place with sinusoidal features and freeze it.

        Features are not interleaved: the sine values occupy the first half of
        each row and the cosine values the second half (`[dim // 2:]` for an
        even dimension).
        """
        n_pos, dim = out.shape
        # One divisor per feature column; columns 2k and 2k + 1 share it.
        divisors = [np.power(10000, 2 * (col // 2) / dim) for col in range(dim)]
        angles = np.array([[pos / div for div in divisors] for pos in range(n_pos)])

        out.requires_grad = False  # must be frozen before the in-place writes
        half = (dim + 1) // 2  # the sine block gets the extra column when dim is odd
        out[:, :half] = torch.from_numpy(np.sin(angles[:, 0::2])).float()
        out[:, half:] = torch.from_numpy(np.cos(angles[:, 1::2])).float()
        out.detach_()
        return out

    @torch.no_grad()
    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor:
        """Return the embeddings of positions `past_key_values_length ... + seq_len - 1`.

        `input_ids_shape` is expected to be [bsz x seqlen].
        """
        _, seq_len = input_ids_shape[:2]
        first = past_key_values_length
        positions = torch.arange(first, first + seq_len, dtype=torch.long, device=self.weight.device)
        return super().forward(positions)
|
|
|
|
|
class STLAttention(nn.Module):
    """Multi-head attention as described in 'Attention is all you need'.

    Supports self-attention and cross-attention, and returns the key/value
    projections for caching when configured as a decoder layer.
    """

    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0,
                 is_decoder: bool = False, bias: bool = False, is_causal: bool = False):
        """
        Args:
            embed_dim (int): Model dimension; must be divisible by `num_heads`.
            num_heads (int): Number of attention heads.
            dropout (float): Dropout probability applied to the attention weights.
            is_decoder (bool): If True, `forward` returns the key/value states as cache.
            bias (bool): Whether the linear projections have a bias term.
            is_causal (bool): Stored on the module; not used by `forward`.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert (self.head_dim * num_heads) == self.embed_dim
        self.scaling = self.head_dim ** -0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal

        # Input projections (creation order kept for seeded initialisation).
        self.W_k = nn.Linear(embed_dim, embed_dim, bias = bias)
        self.W_q = nn.Linear(embed_dim, embed_dim, bias = bias)
        self.W_v = nn.Linear(embed_dim, embed_dim, bias = bias)

        # Output projection.
        self.W_o = nn.Linear(embed_dim, embed_dim, bias = bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
        """Reshape `(batch, seq, embed)` into `(batch, heads, seq, head_dim)`."""
        return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(self,
                hidden_states: torch.Tensor,
                key_value_states: Optional[torch.Tensor] = None,
                past_key_value: Optional[Tuple[torch.Tensor]] = None,
                attention_mask: Optional[torch.Tensor] = None,
                layer_head_mask: Optional[torch.Tensor] = None,
                output_attentions: bool = False
                ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Args:
            hidden_states: Queries, of shape `(batch, tgt_len, embed_dim)`.
            key_value_states: Source of keys/values for cross-attention; when
                None, self-attention over `hidden_states` is performed.
            past_key_value: Cached `(key, value)` states, each of shape
                `(batch, heads, seq, head_dim)`.
            attention_mask: Additive mask broadcastable to
                `(batch, heads, tgt_len, src_len)`.
            layer_head_mask: Multiplicative mask of shape `(num_heads,)`
                applied to the attention probabilities of each head.
            output_attentions: If True, also return the attention probabilities.

        Returns:
            Tuple of the attention output `(batch, tgt_len, embed_dim)`, the
            attention probabilities `(batch, heads, tgt_len, src_len)` (None
            unless `output_attentions`), and the key/value cache (refreshed
            only when the module is a decoder; otherwise the value passed in).

        Raises:
            ValueError: If `layer_head_mask` does not have shape `(num_heads,)`.
        """
        is_cross_attention = key_value_states is not None
        batch_size, tgt_len, _ = hidden_states.size()

        # Scaling the queries once is equivalent to scaling the scores.
        query = self.W_q(hidden_states) * self.scaling

        if (is_cross_attention and past_key_value is not None
                and past_key_value[0].shape[2] == key_value_states.shape[1]):
            # Cross-attention with a valid cache: reuse the cached projections.
            key = past_key_value[0]
            value = past_key_value[1]
        elif is_cross_attention:
            key = self._shape(self.W_k(key_value_states), -1, batch_size)
            value = self._shape(self.W_v(key_value_states), -1, batch_size)
        elif past_key_value is not None:
            # Incremental self-attention: append the new states to the cache.
            key = self._shape(self.W_k(hidden_states), -1, batch_size)
            value = self._shape(self.W_v(hidden_states), -1, batch_size)
            key = torch.cat([past_key_value[0], key], dim=2)
            value = torch.cat([past_key_value[1], value], dim=2)
        else:
            key = self._shape(self.W_k(hidden_states), -1, batch_size)
            value = self._shape(self.W_v(hidden_states), -1, batch_size)

        if self.is_decoder:
            past_key_value = (key, value)

        # Fold the heads into the batch dimension for the batched matmuls.
        proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
        query = self._shape(query, tgt_len, batch_size).view(*proj_shape)
        key = key.reshape(*proj_shape)
        value = value.reshape(*proj_shape)
        src_len = key.size(1)

        attn_weights = torch.bmm(query, key.transpose(1, 2))

        if attention_mask is not None:
            attn_weights = attn_weights.view(batch_size, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(batch_size * self.num_heads, tgt_len, src_len)

        attn_weights = F.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, "
                    f"but is {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
                batch_size, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(batch_size * self.num_heads, tgt_len, src_len)

        # Returned probabilities are taken before dropout.
        attn_weights_reshaped = None
        if output_attentions:
            attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, tgt_len, src_len)

        attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_output = torch.bmm(attn_probs, value)

        # Unfold the heads and merge them back into the embedding dimension.
        attn_output = attn_output.view(batch_size, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(batch_size, tgt_len, self.embed_dim)
        attn_output = self.W_o(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
|
|
|
|
|
|
|
|
|
|
|
class STLEncoder:
    """Embeds STL formulae as their kernel values against a set of anchor formulae."""

    def __init__(self,
                 embed_dim: int,
                 anchor_filename: Optional[str] = None,
                 n_vars: int = 3):
        """
        Args:
            embed_dim (int): Embedding size; must equal the number of anchors.
            anchor_filename (Optional[str]): File read with `load_pickle` to
                get the anchor set. When None, a new set is generated with
                `anchorGeneration` and read back from `<returned name>.pickle`.
            n_vars (int): Number of signal variables.

        Raises:
            ValueError: If the loaded anchor set does not hold `embed_dim` formulae.
        """
        self.n_vars = n_vars
        self.embed_dim = embed_dim
        # Records the argument as given (it stays None when anchors are generated).
        self.anchorset_filename = anchor_filename
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.mu = BaseMeasure(device=self.device)
        self.kernel = StlKernel(self.mu, varn=self.n_vars)

        if anchor_filename is None:
            generated_name = anchorGeneration(diff_init = True, embed_dim = self.embed_dim, n_vars = self.n_vars)
            anchor_filename = generated_name + '.pickle'

        anchors = load_pickle(anchor_filename)
        if len(anchors) != self.embed_dim:
            raise ValueError("The anchor set and the embedding dimension do not match!")
        self.anchor_set = anchors

    def compute_embeddings(self, formula: List[str]):
        """Return the kernel matrix between `formula` (a bag of formulae) and the anchors."""
        return self.kernel.compute_bag_bag(formula, self.anchor_set)
|
|
|
|
|
class STLModel(PreTrainedModel):
    """Base class wiring the STL models into `PreTrainedModel`: config class,
    weight initialisation and dummy inputs."""

    config_class = STLConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True

    def _init_weights(self, module: Union[nn.Linear, nn.Embedding, STLSinusoidalPositionalEmbedding]):
        """Initialise one sub-module.

        Linear and embedding weights are drawn from a normal distribution
        with standard deviation `config.init_std`; linear biases and the
        padding row of embeddings are zeroed. Sinusoidal positional
        embeddings are left untouched.
        """
        std = self.config.init_std

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        # Checked before nn.Embedding because it is a subclass of it and its
        # fixed weights must not be re-initialised.
        if isinstance(module, STLSinusoidalPositionalEmbedding):
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @property
    def dummy_inputs(self):
        """Small fixed batch of inputs, with the second row ending in a padding id."""
        pad_id = self.config.pad_token_id
        token_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_id]], device=self.device)
        return {
            "attention_mask": token_ids.ne(pad_id),
            "input_ids": token_ids,
            "decoder_input_ids": token_ids,
        }
|
|
|
|
|
class STLDecoderBlock(nn.Module):
    """One decoder layer: causal self-attention, optional cross-attention, feed-forward.

    Every sub-layer is followed by dropout, a residual addition and a
    ``LayerNorm`` (post-norm ordering).

    Args:
        embed_dim: Width of the hidden states.
        num_decoder_attention_heads: Head count handed to both attention modules.
        num_decoder_ffn_dim: Inner width of the feed-forward network.
        dropout: Probability applied to the output of each sub-layer.
        attention_dropout: Value handed to both attention modules as their
            ``dropout`` argument.
        activation_dropout: Probability applied right after the feed-forward
            activation.
    """

    def __init__(
        self,
        embed_dim: int,
        num_decoder_attention_heads: int,
        num_decoder_ffn_dim: int,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        activation_dropout: float = 0.0,
    ):
        super().__init__()

        self.embed_dim = embed_dim

        # Causal self-attention over the decoder's own sequence.
        # It receives `attention_dropout`, the same rate the cross-attention
        # module below gets; previously `dropout` was passed here, so the
        # `attention_dropout` argument had no effect on self-attention.
        self.self_attn = STLAttention(
            embed_dim=self.embed_dim,
            num_heads=num_decoder_attention_heads,
            dropout=attention_dropout,
            is_decoder=True,
            is_causal=True,
        )
        self.dropout = dropout
        self.activation_fn = nn.functional.gelu
        self.activation_dropout = activation_dropout
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)

        # Cross-attention over `encoder_hidden_states`.
        self.encoder_attn = STLAttention(
            self.embed_dim,
            num_decoder_attention_heads,
            dropout=attention_dropout,
            is_decoder=True,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)

        # Position-wise feed-forward network.
        self.fc1 = nn.Linear(self.embed_dim, num_decoder_ffn_dim)
        self.fc2 = nn.Linear(num_decoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = True,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            encoder_hidden_states (`torch.FloatTensor`):
                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`. When `None`, the
                cross-attention sub-layer is skipped entirely.
            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for self-attention heads in a given layer of size
                `(decoder_attention_heads,)`.
            cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
                size `(decoder_attention_heads,)`.
            past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states. The first
                two entries are handed to self-attention, the last two to cross-attention.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                Whether to append the updated key/value cache to the returned tuple.
            **kwargs: accepted for call-site compatibility and ignored.

        Returns:
            `tuple`: `(hidden_states,)`, followed by `(self_attn_weights, cross_attn_weights)` when
            `output_attentions` is truthy, followed by `(present_key_value,)` when `use_cache` is truthy.
        """
        # ---- Self-attention sub-layer ----
        residual = hidden_states

        # Self-attention cached key/values occupy positions 0 and 1 of the cache tuple.
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None

        # Call the module instance (not `.forward`) so registered hooks run.
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=self_attn_past_key_value,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # ---- Cross-attention sub-layer (only when encoder states are supplied) ----
        cross_attn_present_key_value = None
        cross_attn_weights = None

        if encoder_hidden_states is not None:
            residual = hidden_states

            # Cross-attention cached key/values occupy the last two positions of the cache tuple.
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None

            hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=cross_attn_past_key_value,
                output_attentions=output_attentions,
            )
            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
            hidden_states = residual + hidden_states
            hidden_states = self.encoder_attn_layer_norm(hidden_states)

            # Concatenate the cross-attention cache after the self-attention cache.
            present_key_value = present_key_value + cross_attn_present_key_value

        # ---- Feed-forward sub-layer ----
        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
|
|
|
|
|
class STLDecoder(STLModel):
    """Stack of `STLDecoderBlock` layers on top of token and positional embeddings.

    Args:
        config: Model configuration. The constructor reads `d_model`,
            `decoder_attention_heads`, `decoder_ffn_dim`,
            `max_position_embeddings`, `vocab_size`, `pad_token_id`,
            `decoder_layers`, `scale_embedding`, `dropout`,
            `attention_dropout`, `activation_dropout` and `decoder_layerdrop`.
    """

    def __init__(self, config):
        super().__init__(config)

        # Pull every hyper-parameter from the config once.
        embed_dim = config.d_model
        num_decoder_attention_heads = config.decoder_attention_heads
        num_decoder_ffn_dim = config.decoder_ffn_dim
        max_position_embeddings = config.max_position_embeddings
        decoder_vocab_size = config.vocab_size
        pad_token_id = config.pad_token_id
        num_decoder_layers = config.decoder_layers
        scale_embedding = config.scale_embedding
        dropout = config.dropout
        attention_dropout = config.attention_dropout
        activation_dropout = config.activation_dropout
        decoder_layerdrop = config.decoder_layerdrop

        self.dropout = dropout
        self.layerdrop = decoder_layerdrop
        self.padding_idx = pad_token_id
        self.max_target_positions = max_position_embeddings
        # Token embeddings are multiplied by sqrt(d_model) only when requested.
        self.embed_scale = math.sqrt(embed_dim) if scale_embedding else 1.0

        # Token embedding table.
        self.embed_tokens = nn.Embedding(decoder_vocab_size, embed_dim, self.padding_idx)

        # Positional embedding module.
        self.embed_positions = STLSinusoidalPositionalEmbedding(
            max_position_embeddings, embed_dim, self.padding_idx
        )

        # The decoder layers themselves.
        self.layers = nn.ModuleList(
            [
                STLDecoderBlock(
                    embed_dim,
                    num_decoder_attention_heads,
                    num_decoder_ffn_dim,
                    dropout,
                    attention_dropout,
                    activation_dropout,
                )
                for _ in range(num_decoder_layers)
            ]
        )

        self.gradient_checkpointing = False
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        """Run the decoder stack.

        Args:
            input_ids: Token ids; reshaped to `(-1, seq_len)`. Mutually
                exclusive with `inputs_embeds`.
            attention_mask: Optional mask over the decoder sequence; expanded
                to a 4D causal mask before being handed to the layers.
            encoder_hidden_states: Optional states the layers cross-attend to.
            encoder_attention_mask: Optional mask over `encoder_hidden_states`;
                expanded to 4D and forwarded to every layer's cross-attention.
                Only used when `encoder_hidden_states` is also given.
            head_mask: Optional per-layer self-attention head mask, indexed by
                layer number.
            cross_attn_head_mask: Optional per-layer cross-attention head mask,
                indexed by layer number.
            past_key_values: Optional per-layer caches. The cached length is
                read from `past_key_values[0][0].shape[2]`.
            inputs_embeds: Pre-computed input embeddings used instead of
                looking up `input_ids`.
            use_cache: Whether to collect and return per-layer caches.
                Defaults to `config.use_cache`.
            output_attentions: Whether to return attention weights. Defaults
                to `config.output_attentions`.
            output_hidden_states: Whether to return the hidden states entering
                each layer plus the final ones. Defaults to
                `config.output_hidden_states`.
            return_dict: Return a `BaseModelOutputWithPastAndCrossAttentions`
                instead of a plain tuple. Defaults to `config.use_return_dict`.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            A `BaseModelOutputWithPastAndCrossAttentions`, or, when
            `return_dict` is falsy, a tuple of the non-`None` values among
            (last hidden state, cache, hidden states, self-attentions,
            cross-attentions) in that order.

        Raises:
            ValueError: If both or neither of `input_ids` and `inputs_embeds`
                are provided.
        """
        # Fall back to config defaults for every unset flag.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Exactly one of `input_ids` / `inputs_embeds` must be supplied.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        # Number of positions already held in the cache.
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask, input_shape, inputs_embeds, past_key_values_length
        )

        # Expand the encoder mask to 4D so it can be applied inside cross-attention.
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            encoder_attention_mask = _prepare_4d_attention_mask(
                encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            )

        positions = self.embed_positions(input_shape, past_key_values_length)

        hidden_states = inputs_embeds + positions

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # Accumulators for the optional outputs.
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            # LayerDrop: while training, skip this layer with probability `self.layerdrop`.
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:
                    continue

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                # Positional arguments follow the layer's forward signature;
                # the cache slot is passed as None.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    head_mask[idx] if head_mask is not None else None,
                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                    None,
                    output_attentions,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    # Previously omitted here, so the prepared encoder mask was
                    # never applied outside the checkpointing branch.
                    encoder_attention_mask=encoder_attention_mask,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    cross_attn_layer_head_mask=(
                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
                    ),
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )
            hidden_states = layer_outputs[0]

            if use_cache:
                # The cache sits after the two attention entries when those are returned.
                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        # Record the output of the last layer as well.
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )
|
|
|
|
|
|
|
|
|
|
|
class STLForCausalLM(STLModel, GenerationMixin):
    """`STLDecoder` wrapped with a bias-free linear language-modelling head.

    The constructor works on a deep copy of the given config and forces
    `is_decoder=True` and `is_encoder_decoder=False` on that copy, leaving the
    caller's config object untouched.
    """

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        # Mutate a private copy only.
        config = copy.deepcopy(config)
        config.is_decoder = True
        config.is_encoder_decoder = False

        super().__init__(config)
        self.model = STLDecoder(config)

        # Projects decoder hidden states onto the vocabulary.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

    def get_input_embeddings(self):
        """Return the decoder's token embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the decoder's token embedding module with `value`."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the language-modelling head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the language-modelling head with `new_embeddings`."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the wrapped decoder with `decoder`."""
        self.model = decoder

    def get_decoder(self):
        """Return the wrapped decoder."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        """Decode, project to vocabulary logits and optionally compute the loss.

        All arguments except `labels` are forwarded unchanged to the wrapped
        decoder, as are any extra `**kwargs`.

        Args:
            labels: Optional target ids. When given, a cross-entropy loss is
                computed between the flattened logits and the flattened
                labels. No shifting is done here: the logits at position *i*
                are scored against the label at position *i*.

        Returns:
            A `CausalLMOutputWithCrossAttentions`, or, when `return_dict` is
            falsy, the tuple `(logits, *decoder_outputs[1:])` with the loss
            prepended if it was computed.
        """
        # Resolve unset flags from the config.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )

        # Index 0 is the last hidden state for both the tuple and the dict-style output.
        logits = self.lm_head(outputs[0])

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            flat_logits = logits.view(-1, self.config.vocab_size)
            flat_labels = labels.view(-1)
            loss = nn.CrossEntropyLoss()(flat_logits, flat_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            if loss is not None:
                return (loss,) + output
            return output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Select rows `beam_idx` along dim 0 of every cached tensor.

        `beam_idx` is moved to each cached tensor's device before indexing.
        Returns a tuple holding one tuple of re-indexed tensors per layer.
        """
        return tuple(
            tuple(state.index_select(0, beam_idx.to(state.device)) for state in layer_past)
            for layer_past in past_key_values
        )
|
|
|
|
|
|