update
Browse files- modeling.py +704 -8
modeling.py
CHANGED
|
@@ -33,18 +33,13 @@ from transformers.modeling_outputs import (
|
|
| 33 |
)
|
| 34 |
|
| 35 |
from configuration import STLConfig
|
| 36 |
-
# from handcoded_tokenizer import STLTokenizer
|
| 37 |
from nltk.translate.bleu_score import sentence_bleu
|
| 38 |
from stl import *
|
| 39 |
import networkx as nx
|
| 40 |
-
# import phis_generator_depth
|
| 41 |
from datasets import load_dataset
|
| 42 |
|
| 43 |
-
|
| 44 |
-
from
|
| 45 |
-
from traj_measure import BaseMeasure
|
| 46 |
-
from kernel import StlKernel
|
| 47 |
-
from anchor_set_generation import anchorGeneration
|
| 48 |
|
| 49 |
import re
|
| 50 |
import json
|
|
@@ -54,6 +49,105 @@ from transformers.utils import logging
|
|
| 54 |
|
| 55 |
logger = logging.get_logger(__name__)
|
| 56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
def load_json(path: str) -> Union[Dict, List]:
|
| 59 |
"""
|
|
@@ -68,6 +162,607 @@ def load_json(path: str) -> Union[Dict, List]:
|
|
| 68 |
with open(path, "r") as f:
|
| 69 |
return json.load(f)
|
| 70 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
class STLTokenizer(PreTrainedTokenizer):
|
| 73 |
"""
|
|
@@ -404,6 +1099,7 @@ class STLAttention(nn.Module):
|
|
| 404 |
|
| 405 |
return attn_output, None, past_key_value
|
| 406 |
|
|
|
|
| 407 |
|
| 408 |
class STLEncoder():
|
| 409 |
def __init__(self,
|
|
@@ -808,7 +1504,7 @@ class STLDecoder(STLModel):
|
|
| 808 |
cross_attentions=all_cross_attentions,
|
| 809 |
)
|
| 810 |
|
| 811 |
-
|
| 812 |
|
| 813 |
class STLForCausalLM(STLModel, GenerationMixin):
|
| 814 |
_tied_weights_keys = ["lm_head.weight"]
|
|
|
|
| 33 |
)
|
| 34 |
|
| 35 |
from configuration import STLConfig
|
|
|
|
| 36 |
from nltk.translate.bleu_score import sentence_bleu
|
| 37 |
from stl import *
|
| 38 |
import networkx as nx
|
|
|
|
| 39 |
from datasets import load_dataset
|
| 40 |
|
| 41 |
+
|
| 42 |
+
# from anchor_set_generation import anchorGeneration
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
import re
|
| 45 |
import json
|
|
|
|
| 49 |
|
| 50 |
logger = logging.get_logger(__name__)
|
| 51 |
|
| 52 |
+
#### utils ####
|
| 53 |
+
|
| 54 |
+
def load_pickle(path):
    """Deserialize and return the object stored in the pickle file at *path*.

    Parameters
    ----------
    path : str
        Filesystem path of an existing pickle file.

    Returns
    -------
    The unpickled Python object.
    """
    with open(path, 'rb') as handle:
        loaded_obj = pickle.load(handle)
    return loaded_obj
|
| 58 |
+
|
| 59 |
+
def dump_pickle(name, thing):
    """Serialize *thing* to the file ``<name>.pickle``.

    Parameters
    ----------
    name : str
        Output path WITHOUT the ``.pickle`` suffix (it is appended here).
    thing : object
        Any picklable Python object.
    """
    target = name + '.pickle'
    with open(target, 'wb') as handle:
        pickle.dump(thing, handle)
|
| 62 |
+
|
| 63 |
+
def from_string_to_formula(st):
|
| 64 |
+
root_arity = 2 if st.startswith('(') else 1
|
| 65 |
+
st_split = st.split()
|
| 66 |
+
if root_arity <= 1:
|
| 67 |
+
root_op_str = copy.deepcopy(st_split[0])
|
| 68 |
+
if root_op_str.startswith('x'):
|
| 69 |
+
atom_sign = True if st_split[1] == '<=' else False
|
| 70 |
+
root_phi = Atom(var_index=int(st_split[0][2]), lte=atom_sign, threshold=float(st_split[2]))
|
| 71 |
+
return root_phi
|
| 72 |
+
else:
|
| 73 |
+
assert (root_op_str.startswith('not') or root_op_str.startswith('eventually')
|
| 74 |
+
or root_op_str.startswith('always'))
|
| 75 |
+
current_st = copy.deepcopy(st_split[2:-1])
|
| 76 |
+
if root_op_str == 'not':
|
| 77 |
+
root_phi = Not(child=from_string_to_formula(' '.join(current_st)))
|
| 78 |
+
elif root_op_str.startswith('eventually'):
|
| 79 |
+
unbound, right_unbound, left_time_bound, right_time_bound = set_time_thresholds(root_op_str)
|
| 80 |
+
root_phi = Eventually(child=from_string_to_formula(' '.join(current_st)), unbound=unbound,
|
| 81 |
+
right_unbound=right_unbound, left_time_bound=left_time_bound,
|
| 82 |
+
right_time_bound=right_time_bound)
|
| 83 |
+
else:
|
| 84 |
+
unbound, right_unbound, left_time_bound, right_time_bound = set_time_thresholds(root_op_str)
|
| 85 |
+
root_phi = Globally(child=from_string_to_formula(' '.join(current_st)), unbound=unbound,
|
| 86 |
+
right_unbound=right_unbound, left_time_bound=left_time_bound,
|
| 87 |
+
right_time_bound=right_time_bound)
|
| 88 |
+
else:
|
| 89 |
+
# 1 - delete everything which is contained in other sets of parenthesis (if any)
|
| 90 |
+
current_st = copy.deepcopy(st_split[1:-1])
|
| 91 |
+
if '(' in current_st:
|
| 92 |
+
par_queue = deque()
|
| 93 |
+
par_idx_list = []
|
| 94 |
+
for i, sub in enumerate(current_st):
|
| 95 |
+
if sub == '(':
|
| 96 |
+
par_queue.append(i)
|
| 97 |
+
elif sub == ')':
|
| 98 |
+
par_idx_list.append(tuple([par_queue.pop(), i]))
|
| 99 |
+
# open_par_idx, close_par_idx = [current_st.index(p) for p in ['(', ')']]
|
| 100 |
+
# union of parentheses range --> from these we may extract the substrings to be the children!!!
|
| 101 |
+
children_range = []
|
| 102 |
+
for begin, end in sorted(par_idx_list):
|
| 103 |
+
if children_range and children_range[-1][1] >= begin - 1:
|
| 104 |
+
children_range[-1][1] = max(children_range[-1][1], end)
|
| 105 |
+
else:
|
| 106 |
+
children_range.append([begin, end])
|
| 107 |
+
n_children = len(children_range)
|
| 108 |
+
assert (n_children in [1, 2])
|
| 109 |
+
if n_children == 1:
|
| 110 |
+
# one of the children is a variable --> need to individuate it
|
| 111 |
+
var_child_idx = 1 if children_range[0][0] <= 1 else 0 # 0 is left child, 1 is right child
|
| 112 |
+
if children_range[0][0] != 0 and current_st[children_range[0][0] - 1][0:2] in ['no', 'ev', 'al']:
|
| 113 |
+
children_range[0][0] -= 1
|
| 114 |
+
left_child_str = current_st[:3] if var_child_idx == 0 else \
|
| 115 |
+
current_st[children_range[0][0]:children_range[0][1] + 1]
|
| 116 |
+
right_child_str = current_st[-3:] if var_child_idx == 1 else \
|
| 117 |
+
current_st[children_range[0][0]:children_range[0][1] + 1]
|
| 118 |
+
root_op_str = current_st[children_range[0][1] + 1] if var_child_idx == 1 else \
|
| 119 |
+
current_st[children_range[0][0] - 1]
|
| 120 |
+
assert (root_op_str[:2] in ['an', 'or', 'un'])
|
| 121 |
+
else:
|
| 122 |
+
if children_range[0][0] != 0 and current_st[children_range[0][0] - 1][0:2] in ['no', 'ev', 'al']:
|
| 123 |
+
children_range[0][0] -= 1
|
| 124 |
+
if current_st[children_range[1][0] - 1][0:2] in ['no', 'ev', 'al']:
|
| 125 |
+
children_range[1][0] -= 1
|
| 126 |
+
# if there are two children, with parentheses, the element in the middle is the root
|
| 127 |
+
root_op_str = current_st[children_range[0][1] + 1]
|
| 128 |
+
assert (root_op_str[:2] in ['an', 'or', 'un'])
|
| 129 |
+
left_child_str = current_st[children_range[0][0]:children_range[0][1] + 1]
|
| 130 |
+
right_child_str = current_st[children_range[1][0]:children_range[1][1] + 1]
|
| 131 |
+
else:
|
| 132 |
+
# no parentheses means that both children are variables
|
| 133 |
+
left_child_str = current_st[:3]
|
| 134 |
+
right_child_str = current_st[-3:]
|
| 135 |
+
root_op_str = current_st[3]
|
| 136 |
+
left_child_str = ' '.join(left_child_str)
|
| 137 |
+
right_child_str = ' '.join(right_child_str)
|
| 138 |
+
if root_op_str == 'and':
|
| 139 |
+
root_phi = And(left_child=from_string_to_formula(left_child_str),
|
| 140 |
+
right_child=from_string_to_formula(right_child_str))
|
| 141 |
+
elif root_op_str == 'or':
|
| 142 |
+
root_phi = Or(left_child=from_string_to_formula(left_child_str),
|
| 143 |
+
right_child=from_string_to_formula(right_child_str))
|
| 144 |
+
else:
|
| 145 |
+
unbound, right_unbound, left_time_bound, right_time_bound = set_time_thresholds(root_op_str)
|
| 146 |
+
root_phi = Until(left_child=from_string_to_formula(left_child_str),
|
| 147 |
+
right_child=from_string_to_formula(right_child_str),
|
| 148 |
+
unbound=unbound, right_unbound=right_unbound, left_time_bound=left_time_bound,
|
| 149 |
+
right_time_bound=right_time_bound)
|
| 150 |
+
return root_phi
|
| 151 |
|
| 152 |
def load_json(path: str) -> Union[Dict, List]:
|
| 153 |
"""
|
|
|
|
| 162 |
with open(path, "r") as f:
|
| 163 |
return json.load(f)
|
| 164 |
|
| 165 |
+
#### phis_generator ####
|
| 166 |
+
|
| 167 |
+
class StlGenerator:
    """Random generator of STL formulae, sampled as syntax trees.

    Internal nodes are drawn from ``node_types`` with probabilities
    ``inner_node_prob``; leaves are atomic predicates with normally distributed
    thresholds.  Uses the module-level names ``stl`` (formula classes) and
    ``rnd`` — presumably ``numpy.random``; verify against the file's imports.
    """

    def __init__(
        self,
        leaf_prob: float = 0.3,
        inner_node_prob: list = None,
        threshold_mean: float = 0.0,
        threshold_sd: float = 1.0,
        unbound_prob: float = 0.1,
        right_unbound_prob: float = 0.2,
        time_bound_max_range: float = 20,
        adaptive_unbound_temporal_ops: bool = True,
        max_timespan: int = 100,
    ):
        """
        leaf_prob
            probability of generating a leaf (always zero for root)
        node_types = ["not", "and", "or", "always", "eventually", "until"]
            Inner node types
        inner_node_prob
            probability vector for the different types of internal nodes
        threshold_mean
        threshold_sd
            mean and std for the normal distribution of the thresholds of atoms
        unbound_prob
            probability of a temporal operator to have a time bound of the type [0, infty]
        right_unbound_prob
            probability of a temporal operator to have a right-open bound [a, infty]
        time_bound_max_range
            maximum value of time span of a temporal operator (i.e. max value of t in [0,t])
        adaptive_unbound_temporal_ops
            if true, unbounded temporal operators are computed from current point to the end of the signal, otherwise
            they are evaluated only at time zero.
        max_timespan
            maximum time depth of a formula.
        """

        # Address the mutability of default arguments
        if inner_node_prob is None:
            inner_node_prob = [0.166, 0.166, 0.166, 0.17, 0.166, 0.166]

        self.leaf_prob = leaf_prob
        self.inner_node_prob = inner_node_prob
        self.threshold_mean = threshold_mean
        self.threshold_sd = threshold_sd
        self.unbound_prob = unbound_prob
        self.right_unbound_prob = right_unbound_prob
        self.time_bound_max_range = time_bound_max_range
        self.adaptive_unbound_temporal_ops = adaptive_unbound_temporal_ops
        self.node_types = ["not", "and", "or", "always", "eventually", "until"]
        self.max_timespan = max_timespan

    def sample(self, nvars):
        """
        Samples a random formula with distribution defined in class instance parameters

        Parameters
        ----------
        nvars : number of variables of input signals
            how many variables the formula is expected to consider.

        Returns
        -------
        TYPE
            A random formula.

        """
        # The root is always an internal node (leaf probability does not apply).
        return self._sample_internal_node(nvars)

    def bag_sample(self, bag_size, nvars):
        """
        Samples a bag of bag_size formulae

        Parameters
        ----------
        bag_size : INT
            number of formulae.
        nvars : INT
            number of vars in formulae.

        Returns
        -------
        a list of formulae.

        """
        formulae = []
        for _ in range(bag_size):
            phi = self.sample(nvars)
            formulae.append(phi)
        return formulae

    def _sample_internal_node(self, nvars):
        """Sample one internal node of the chosen type, retrying until its time depth fits."""
        # Declare & dummy-assign "idiom"
        node: Union[None, Node]
        node = None
        # choose node type
        nodetype = rnd.choice(self.node_types, p=self.inner_node_prob)
        # NOTE(review): the node type is drawn ONCE and reused on every retry;
        # only the subtrees/bounds are resampled until time_depth() < max_timespan.
        while True:
            if nodetype == "not":
                n = self._sample_node(nvars)
                node = stl.Not(n)
            elif nodetype == "and":
                n1 = self._sample_node(nvars)
                n2 = self._sample_node(nvars)
                node = stl.And(n1, n2)
            elif nodetype == "or":
                n1 = self._sample_node(nvars)
                n2 = self._sample_node(nvars)
                node = stl.Or(n1, n2)
            elif nodetype == "always":
                n = self._sample_node(nvars)
                unbound, right_unbound, left_time_bound, right_time_bound = self._get_temporal_parameters()
                node = stl.Globally(
                    n, unbound, right_unbound, left_time_bound, right_time_bound, self.adaptive_unbound_temporal_ops
                )
            elif nodetype == "eventually":
                n = self._sample_node(nvars)
                unbound, right_unbound, left_time_bound, right_time_bound = self._get_temporal_parameters()
                node = stl.Eventually(
                    n, unbound, right_unbound, left_time_bound, right_time_bound, self.adaptive_unbound_temporal_ops
                )
            elif nodetype == "until":
                n1 = self._sample_node(nvars)
                n2 = self._sample_node(nvars)
                unbound, right_unbound, left_time_bound, right_time_bound = self._get_temporal_parameters()
                node = stl.Until(
                    n1, n2, unbound, right_unbound, left_time_bound, right_time_bound
                )

            if (node is not None) and (node.time_depth() < self.max_timespan):
                return node

    def _sample_node(self, nvars):
        """Sample either a leaf atom (with prob. leaf_prob) or a recursive internal node."""
        if rnd.rand() < self.leaf_prob:
            # sample a leaf
            var, thr, lte = self._get_atom(nvars)
            return stl.Atom(var, thr, lte)
        else:
            return self._sample_internal_node(nvars)

    def _get_temporal_parameters(self):
        """Return (unbound, right_unbound, left_time_bound, right_time_bound) for a temporal op."""
        if rnd.rand() < self.unbound_prob:
            # Fully unbounded: [0, infty].
            return True, False, 0, 0
        elif rnd.rand() < self.right_unbound_prob:
            # Right-open: [a, infty].
            return False, True, rnd.randint(self.time_bound_max_range), 1
        else:
            # Closed interval [a, b] with a <= b.
            left_bound = rnd.randint(self.time_bound_max_range)
            return False, False, left_bound, rnd.randint(left_bound, self.time_bound_max_range) + 1

    def _get_atom(self, nvars):
        """Return (variable_index, threshold, lte) for a random atomic predicate."""
        variable = rnd.randint(nvars)
        lte = rnd.rand() > 0.5
        threshold = rnd.normal(self.threshold_mean, self.threshold_sd)
        return variable, threshold, lte
|
| 318 |
+
|
| 319 |
+
#### traj_measure ####
|
| 320 |
+
|
| 321 |
+
class Measure:
    """Abstract base class for trajectory measures.

    Concrete measures (e.g. ``BaseMeasure``) override :meth:`sample` to draw
    random signals; the base implementation does nothing and returns ``None``.
    """

    def sample(self, samples=100000, varn=2, points=100):
        """Draw *samples* trajectories of *varn* variables over *points* time steps.

        Must be overridden by subclasses; this base version is a no-op.
        """
|
| 325 |
+
|
| 326 |
+
class BaseMeasure(Measure):
    """Measure over piecewise-linear random trajectories.

    Trajectories are built from uniformly-sorted increments, a Bernoulli chain
    of derivative signs, and a squared-Gaussian total variation; the initial
    state is normally distributed.  All tensors are allocated on ``device``.
    """

    def __init__(
        self, mu0=0.0, sigma0=1.0, mu1=0.0, sigma1=1.0, q=0.1, q0=0.5, device="cpu"
    ):
        """

        Parameters
        ----------
        mu0 : mean of normal distribution of initial state, optional
            The default is 0.0.
        sigma0 : standard deviation of normal distribution of initial state, optional
            The default is 1.0.
        mu1 : DOUBLE, optional
            mean of normal distribution of total variation. The default is 0.0.
        sigma1 : standard deviation of normal distribution of total variation, optional
            The default is 1.0.
        q : DOUBLE, optional
            probability of change of sign in derivative. The default is 0.1.
        q0 : DOUBLE, optional
            probability of initial sign of derivative. The default is 0.5.
        device : 'cpu' or 'cuda', optional
            device on which to run the algorithm. The default is 'cpu'.

        Returns
        -------
        None.

        """
        self.mu0 = mu0
        self.sigma0 = sigma0
        self.mu1 = mu1
        self.sigma1 = sigma1
        self.q = q
        self.q0 = q0
        self.device = device

    def sample(self, samples=100000, varn=2, points=100):
        """
        Samples a set of trajectories from the basic measure space, with parameters
        passed to the sampler

        Parameters
        ----------
        points : INT, optional
            number of points per trajectory, including initial one. The default is 100.
        samples : INT, optional
            number of trajectories. The default is 100000.
        varn : INT, optional
            number of variables per trajectory. The default is 2.


        Returns
        -------
        signal : samples x varn x points double pytorch tensor
            The sampled signals.

        """
        if self.device == "cuda" and not torch.cuda.is_available():
            raise RuntimeError("GPU card or CUDA library not available!")

        # generate unif RN
        signal = torch.rand(samples, varn, points, device=self.device)
        # first point is special - set to zero for the moment, and set one point to 1
        signal[:, :, 0] = 0.0
        signal[:, :, -1] = 1.0
        # sorting each trajectory
        signal, _ = torch.sort(signal, 2)
        # computing increments and storing them in points 1 to end
        signal[:, :, 1:] = signal[:, :, 1:] - signal[:, :, :-1]
        # generate initial state, according to a normal distribution
        # FIX: allocate the noise on self.device (previously torch.randn defaulted
        # to CPU, forcing a host-side draw and an implicit cross-device copy when
        # device="cuda"; every other allocation in this method is on self.device).
        signal[:, :, 0] = self.mu0 + self.sigma0 * torch.randn(
            signal[:, :, 0].size(), device=self.device
        )

        # sampling change signs from bernoulli in -1, 1
        # (each entry keeps the previous sign with probability 1 - q)
        derivs = (1 - self.q) * torch.ones(samples, varn, points, device=self.device)
        derivs = 2 * torch.bernoulli(derivs) - 1
        # sampling initial derivative
        derivs[:, :, 0] = self.q0
        derivs[:, :, 0] = 2 * torch.bernoulli(derivs[:, :, 0]) - 1
        # taking the cumulative product along axis 2
        derivs = torch.cumprod(derivs, 2)

        # sampling total variation
        totvar = torch.pow(
            self.mu1 + self.sigma1 * torch.randn(samples, varn, 1, device=self.device),
            2,
        )
        # multiplying total variation and derivatives and making initial point non-invasive
        derivs = derivs * totvar
        derivs[:, :, 0] = 1.0

        # computing trajectories by multiplying and then doing a cumulative sum
        signal = signal * derivs
        signal = torch.cumsum(signal, 2)
        return signal
|
| 420 |
+
|
| 421 |
+
#### kernel ####
|
| 422 |
+
|
| 423 |
+
realnum = Union[float, int]
|
| 424 |
+
|
| 425 |
+
class StlKernel:
    """Kernel between STL formulae computed from robustness over sampled signals.

    The kernel of two formulae is the (normalized, optionally exponentiated)
    inner product of their robustness vectors, evaluated on a shared bank of
    trajectories drawn from ``measure`` (or supplied via ``signals``).
    Formula objects must expose ``boolean(...)`` / ``quantitative(...)``.
    """

    def __init__(
        self,
        measure,
        normalize=True,
        exp_kernel=True,
        sigma2=0.2,  # 0.5 works better; it was initially 0.2
        integrate_time=False,
        samples=100000,
        varn=2,
        points=100,
        boolean=False,
        signals=None,
    ):
        # measure: trajectory measure used to sample signals (see BaseMeasure).
        self.traj_measure = measure
        # exp_kernel: apply the Gaussian/exponential transform at the end.
        self.exp_kernel = exp_kernel
        # normalize: divide by the sqrt of the self-kernels (cosine-style).
        self.normalize = normalize
        self.sigma2 = sigma2
        self.samples = samples
        self.varn = varn
        self.points = points
        # integrate_time: use robustness at all time points, not just t=0.
        self.integrate_time = integrate_time
        if signals is not None:
            self.signals = signals
        else:
            self.signals = measure.sample(points=points, samples=samples, varn=varn)
        # boolean: use boolean satisfaction (mapped to +/-1) instead of robustness.
        self.boolean = boolean

    def compute(self, phi1, phi2):
        """Kernel value between two single formulae."""
        return self.compute_one_one(phi1, phi2)

    def compute_one_one(self, phi1, phi2):
        """Wrap both formulae in singleton bags and return the scalar kernel."""
        phis1: list = [phi1]
        phis2: list = [phi2]
        ker = self.compute_bag_bag(phis1, phis2)
        return ker[0, 0]

    def compute_bag(self, phis, return_robustness=True):
        """Gram matrix of a bag of formulae against itself (optionally with robustness)."""
        if self.integrate_time:
            rhos, selfk, len0 = self._compute_robustness_time(phis)
            kernel_matrix = self._compute_kernel_time(
                rhos, rhos, selfk, selfk, len0, len0
            )
        else:
            rhos, selfk = self._compute_robustness_no_time(phis)
            kernel_matrix = self._compute_kernel_no_time(rhos, rhos, selfk, selfk)
            # len0 is only meaningful in the time-integrated variant.
            len0 = None
        if return_robustness:
            return kernel_matrix.cpu(), rhos, selfk, len0
        else:
            return kernel_matrix.cpu()

    def compute_one_bag(self, phi1, phis2, return_robustness=False):
        """Kernel row of a single formula against a bag."""
        phis1: list = [phi1]
        return self.compute_bag_bag(phis1, phis2, return_robustness)

    def compute_bag_bag(self, phis1, phis2, return_robustness=False):
        """Cross kernel matrix between two bags of formulae."""
        if self.integrate_time:
            rhos1, selfk1, len1 = self._compute_robustness_time(phis1)
            rhos2, selfk2, len2 = self._compute_robustness_time(phis2)
            kernel_matrix = self._compute_kernel_time(
                rhos1, rhos2, selfk1, selfk2, len1, len2
            )
        else:
            rhos1, selfk1 = self._compute_robustness_no_time(phis1)
            rhos2, selfk2 = self._compute_robustness_no_time(phis2)
            len1, len2 = [None, None]
            kernel_matrix = self._compute_kernel_no_time(rhos1, rhos2, selfk1, selfk2)
        if return_robustness:
            return kernel_matrix.cpu(), rhos1, rhos2, selfk1, selfk2, len1, len2
        else:
            return kernel_matrix.cpu()

    def compute_one_from_robustness(self, phi, rhos, rho_self, lengths=None, return_robustness=False):
        """Kernel row of one formula against precomputed robustness vectors."""
        phis: list = [phi]
        return self.compute_bag_from_robustness(phis, rhos, rho_self, lengths, return_robustness)

    def compute_bag_from_robustness(self, phis, rhos, rho_self, lengths=None, return_robustness=False):
        """Kernel of a bag of formulae against precomputed robustness (avoids recomputation)."""
        if self.integrate_time:
            rhos1, selfk1, len1 = self._compute_robustness_time(phis)
            kernel_matrix = self._compute_kernel_time(
                rhos1, rhos, selfk1, rho_self, len1, lengths
            )
        else:
            rhos1, selfk1 = self._compute_robustness_no_time(phis)
            len1 = None
            kernel_matrix = self._compute_kernel_no_time(rhos1, rhos, selfk1, rho_self)
        if return_robustness:
            return kernel_matrix.cpu(), rhos1, selfk1, len1
        else:
            return kernel_matrix.cpu()

    def _compute_robustness_time(self, phis):
        """Robustness of each formula at ALL time points; returns (rhos, self_kernels, lengths)."""
        n = self.samples
        p = self.points
        k = len(phis)
        # NOTE(review): rhos are kept on CPU here (unlike the no-time variant,
        # which uses the measure's device) — presumably to bound GPU memory for
        # the k x n x p tensor; confirm before changing.
        rhos = torch.zeros((k, n, p), device="cpu")
        lengths = torch.zeros(k)
        self_kernels = torch.zeros((k, 1))
        for i, phi in enumerate(phis):
            if self.boolean:
                # Map boolean satisfaction {0, 1} to {-1, +1}.
                rho = phi.boolean(self.signals, evaluate_at_all_times=True).float()
                rho[rho == 0.0] = -1.0
            else:
                rho = phi.quantitative(self.signals, evaluate_at_all_times=True)
            # Formulae with temporal depth yield fewer valid time points than p.
            actual_p = rho.size()[2]
            rho = rho.reshape(n, actual_p).cpu()
            rhos[i, :, :actual_p] = rho
            lengths[i] = actual_p
            # Self inner product <rho, rho>, averaged over samples and time.
            self_kernels[i] = torch.tensordot(
                rho.reshape(1, n, -1), rho.reshape(1, n, -1), dims=[[1, 2], [1, 2]]
            ) / (actual_p * n)
        return rhos, self_kernels, lengths

    def _compute_robustness_no_time(self, phis):
        """Robustness of each formula at time zero only; returns (rhos, self_kernels)."""
        n = self.samples
        k = len(phis)
        rhos = torch.zeros((k, n), device=self.traj_measure.device)
        self_kernels = torch.zeros((k, 1), device=self.traj_measure.device)
        for i, phi in enumerate(phis):
            if self.boolean:
                # Map boolean satisfaction {0, 1} to {-1, +1}.
                rho = phi.boolean(self.signals, evaluate_at_all_times=False).float()
                rho[rho == 0.0] = -1.0
            else:
                rho = phi.quantitative(self.signals, evaluate_at_all_times=False)
            self_kernels[i] = rho.dot(rho) / n
            rhos[i, :] = rho
        return rhos, self_kernels

    def _compute_kernel_time(self, rhos1, rhos2, selfk1, selfk2, len1, len2):
        """Kernel matrix from time-integrated robustness, normalized by trajectory length."""
        kernel_matrix = torch.tensordot(rhos1, rhos2, [[1, 2], [1, 2]])
        length_normalizer = self._compute_trajectory_length_normalizer(len1, len2)
        kernel_matrix = kernel_matrix * length_normalizer / self.samples
        if self.normalize:
            kernel_matrix = self._normalize(kernel_matrix, selfk1, selfk2)
        if self.exp_kernel:
            kernel_matrix = self._exponentiate(kernel_matrix, selfk1, selfk2)
        return kernel_matrix

    def _compute_kernel_no_time(self, rhos1, rhos2, selfk1, selfk2):
        """Kernel matrix from time-zero robustness."""
        kernel_matrix = torch.tensordot(rhos1, rhos2, [[1], [1]])
        kernel_matrix = kernel_matrix / self.samples
        if self.normalize:
            kernel_matrix = self._normalize(kernel_matrix, selfk1, selfk2)
        if self.exp_kernel:
            kernel_matrix = self._exponentiate(kernel_matrix, selfk1, selfk2)
        return kernel_matrix

    @staticmethod
    def _normalize(kernel_matrix, selfk1, selfk2):
        """Cosine-style normalization: K_ij / sqrt(K_ii * K_jj)."""
        normalize = torch.sqrt(torch.matmul(selfk1, torch.transpose(selfk2, 0, 1)))
        kernel_matrix = kernel_matrix / normalize
        return kernel_matrix

    def _exponentiate(self, kernel_matrix, selfk1, selfk2, sigma2=None):
        """Gaussian transform exp(-(||x||^2 + ||y||^2 - 2<x,y>) / (2*sigma2))."""
        if sigma2 is None:
            sigma2 = self.sigma2
        if self.normalize:
            # selfk is (1.0^2 + 1.0^2)
            selfk = 2.0
        else:
            k1 = selfk1.size()[0]
            k2 = selfk2.size()[0]
            selfk = (selfk1 * selfk1).repeat(1, k2) + torch.transpose(
                selfk2 * selfk2, 0, 1
            ).repeat(k1, 1)
        return torch.exp(-(selfk - 2 * kernel_matrix) / (2 * sigma2))

    @staticmethod
    def _compute_trajectory_length_normalizer(len1, len2):
        """Pairwise 1 / min(len_i, len_j) matrix for the time-integrated kernel."""
        k1 = len1.size()[0]
        k2 = len2.size()[0]
        y1 = len1.reshape(-1, 1)
        y1 = y1.repeat(1, k2)
        y2 = len2.repeat(k1, 1)
        return 1.0 / torch.min(y1, y2)
|
| 601 |
+
|
| 602 |
+
class GramMatrix:
    """Gram matrix of an STL kernel over a bag of formulae.

    Wraps a kernel object and a list of STL formulae, computes the pairwise
    kernel (Gram) matrix at construction time, and optionally caches the
    per-formula robustness vectors and self-kernels so later kernel
    evaluations against new formulae can reuse them.
    """

    def __init__(self, kernel, formulae, store_robustness=True, sample=False, sampler=None, bag_size=None):
        # kernel: object exposing compute_bag / compute_one_from_robustness /
        # compute_one_bag / compute_bag_from_robustness / compute_bag_bag,
        # plus attributes samples, points, varn, boolean, integrate_time,
        # traj_measure (with .device).
        self.kernel = kernel
        self.formulae_list = formulae
        # if kernel is computed from robustness at time zero only,
        # we store the robustness for each formula and each sample
        # to speed up computation later
        self.store_robustness = store_robustness
        # Matrix dimension: number of given formulae, unless an explicit
        # bag_size overrides it (used together with sample=True).
        self.dim = len(self.formulae_list) if not bag_size else int(bag_size)
        self.sample = sample  # whether to generate formulae in a controlled manner
        if self.sample:
            # Similarity acceptance threshold: tighter for boolean semantics.
            self.t = 0.99 if self.kernel.boolean else 0.85
        self.sampler = sampler  # stl formulae generator
        # Eagerly fills self.gram (and robustness caches) — may be expensive.
        self._compute_gram_matrix()

    def _compute_gram_matrix(self):
        """Fill self.gram (and caches) either by rejection-sampling formulae
        (self.sample) or directly from the provided formulae list."""
        if self.sample:
            gram = torch.zeros(self.dim, self.dim)
            # Robustness cache: (dim, samples) when integrating over time,
            # else (dim, samples, points).
            # NOTE(review): the branch polarity ("not integrate_time" -> 2-D)
            # looks inverted w.r.t. the comment above — confirm against the
            # kernel's compute_bag return shapes.
            rhos = torch.zeros((self.dim, self.kernel.samples), device=self.kernel.traj_measure.device) if \
                not self.kernel.integrate_time else torch.zeros((self.dim, self.kernel.samples, self.kernel.points),
                                                                device=self.kernel.traj_measure.device)
            # NOTE(review): lengths is a torch tensor in one branch and a
            # numpy array in the other — presumably deliberate, but verify
            # downstream consumers handle both.
            lengths = torch.zeros(self.dim) if self.kernel.integrate_time else np.zeros(self.dim)
            kernels = torch.zeros((self.dim, 1), device=self.kernel.traj_measure.device)
            # Seed the set with one random formula and its self-kernel.
            phis = [self.sampler.sample(nvars=self.kernel.varn)]
            gram[0, :1], rhos[0], kernels[0, :], lengths[0] = self.kernel.compute_bag(phis, return_robustness=True)
            while len(phis) < self.dim:
                i = len(phis)
                phi = self.sampler.sample(nvars=self.kernel.varn)
                # Kernel of the candidate against all accepted formulae,
                # reusing their cached robustness vectors.
                gram[i, :i], rhos[i], kernels[i, :], lengths[i] = self.kernel.compute_one_from_robustness(
                    phi, rhos[:i, :], kernels[:i, :], lengths[:i], return_robustness=True)
                # Accept the candidate only if it is highly similar
                # (kernel >= self.t) to fewer than 3 of the current formulae;
                # rejected candidates are simply overwritten next iteration.
                if torch.sum(gram[i, :i + 1] >= self.t) < 3:
                    phis.append(phi)
                    # Mirror the new row into the column (symmetry) and put
                    # the self-kernel on the diagonal.
                    gram[:i, i] = gram[i, :i]
                    gram[i, i] = kernels[i, :]

            self.formulae_list = phis
            self.gram = gram.cpu()
            self.robustness = rhos if self.store_robustness else None
            self.self_kernels = kernels if self.store_robustness else None
            self.robustness_lengths = lengths if self.store_robustness else None
        else:
            if self.store_robustness:
                # One bulk call returns the Gram matrix plus all caches.
                k_matrix, rhos, selfk, len0 = self.kernel.compute_bag(
                    self.formulae_list, return_robustness=True
                )
                self.gram = k_matrix
                self.robustness = rhos
                self.self_kernels = selfk
                self.robustness_lengths = len0
            else:
                self.gram = self.kernel.compute_bag(
                    self.formulae_list, return_robustness=False
                )
                self.robustness = None
                self.self_kernels = None
                self.robustness_lengths = None

    def compute_kernel_vector(self, phi):
        """Kernel values of a single formula against every stored formula,
        using the cached robustness when available."""
        if self.store_robustness:
            return self.kernel.compute_one_from_robustness(
                phi, self.robustness, self.self_kernels, self.robustness_lengths
            )
        else:
            return self.kernel.compute_one_bag(phi, self.formulae_list)

    def compute_bag_kernel_vector(self, phis, generate_phis=False, bag_size=None):
        """Kernel values of a bag of formulae against the stored formulae.

        When generate_phis is True, `phis` is ignored and `bag_size` test
        formulae are rejection-sampled instead: a candidate is kept only if
        its robustness vector changes sign across the sampled signals (i.e.
        the formula is neither trivially true nor trivially false).
        Returns (phi_test, gram_test) in that mode, otherwise the kernel
        matrix for `phis`.
        """
        if generate_phis:
            gram_test = torch.zeros(bag_size, self.dim)  # self.dim, bag_size
            # Same shape convention (and same NOTE(review) caveat) as in
            # _compute_gram_matrix.
            rhos_test = torch.zeros((bag_size, self.kernel.samples), device=self.kernel.traj_measure.device) if \
                not self.kernel.integrate_time else torch.zeros((bag_size, self.kernel.samples, self.kernel.points),
                                                                device=self.kernel.traj_measure.device)
            lengths_test = torch.zeros(bag_size) if self.kernel.integrate_time else np.zeros(bag_size)
            kernels_test = torch.zeros((bag_size, 1), device=self.kernel.traj_measure.device)
            phi_test = []
            while len(phi_test) < bag_size:
                i = len(phi_test)
                phi = self.sampler.sample(nvars=self.kernel.varn)
                if self.store_robustness:
                    gram_test[i, :], rhos_test[i], kernels_test[i, :], lengths_test[i] = \
                        self.kernel.compute_one_from_robustness(phi, self.robustness, self.self_kernels,
                                                                self.robustness_lengths, return_robustness=True)
                else:
                    gram_test[i, :], rhos_test[i], _, kernels_test[i, :], _, lengths_test[i], _ = \
                        self.kernel.compute_one_bag(phi, self.formulae_list, return_robustness=True)
                # Keep only non-trivial formulae: robustness must take both
                # signs over the sampled signals.
                if not ((rhos_test[i] > 0).all() or (rhos_test[i] < 0).all()):
                    phi_test.append(phi)
            return phi_test, gram_test.cpu()
        else:
            if self.store_robustness:
                return self.kernel.compute_bag_from_robustness(
                    phis, self.robustness, self.self_kernels, self.robustness_lengths
                )
            else:
                return self.kernel.compute_bag_bag(phis, self.formulae_list)

    def invert_regularized(self, alpha):
        """Invert the Gram matrix with Tikhonov regularization 10**alpha * I.

        `alpha` is a (typically negative) exponent; abs() guards against a
        negative regularizer if pow returns one.
        """
        regularizer = abs(pow(10, alpha)) * torch.eye(self.dim)
        return torch.inverse(self.gram + regularizer)
|
| 700 |
+
|
| 701 |
+
#### anchor_generation ####
|
| 702 |
+
|
| 703 |
+
def anchorGeneration(diff_init = False,  # to control whether we want formulae to be semantically different by construction
                     embed_dim: int = 30,  # embedding dimension, aka number of generated formulae in the anchor set
                     n_vars: int = 3,  # dimension of the input signal (3D in this case)
                     leaf_prob: float = 0.4,  # complexity of the generated formula
                     cosine_similarity_threshold: float = 0.8  # if two formulae cosine similarity exceeds this, then discard one of the two
                     ) -> str:
    """Generate an anchor set of random STL formulae and pickle it to disk.

    With diff_init=True the set is built incrementally, rejecting candidate
    formulae whose robustness profile is too similar (cosine similarity above
    `cosine_similarity_threshold`) to the formulae already accepted; otherwise
    the set is sampled in one shot with no diversity filtering.

    Returns:
        str: the filename (without extension) the anchor set was pickled to.
    """

    # initialize STL formula generator
    sampler = StlGenerator(leaf_prob)

    # effective anchor set generation
    if diff_init:

        # initialize the anchor set with a randomly sampled formula
        diff_anchor_set = [sampler.sample(nvars=n_vars)]

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        mu = BaseMeasure(device=device)

        # generate a set of random signals acting as a test bench for the formulae
        signals = mu.sample(samples=10000, varn=n_vars)

        # compute the robustness vector of each formula currently in the anchor set
        anchor_rob_vectors = torch.cat([phi.quantitative(signals, normalize=True).unsqueeze(0) for phi in diff_anchor_set], 0)

        while len(diff_anchor_set) < embed_dim:
            # sample the 'remaining' formulae to reach the desired number of `embed_dim` formulae:
            candidate_anchors = sampler.bag_sample(embed_dim - len(diff_anchor_set), nvars = n_vars)

            # compute robustness of candidate anchor formulae on the same signals as previous anchor set
            candidate_robs = torch.cat([phi.quantitative(signals, normalize=True).unsqueeze(0) for phi in candidate_anchors], 0)

            # compute cosine similarity between current anchor set and candidate new formulae
            # NOTE(review): `normalize` presumably comes from a star-import
            # (e.g. torch.nn.functional.normalize) — verify. Also, tril with
            # diagonal=-1 zeroes entries with column index >= row index, which
            # only fully makes sense for a square matrix; for rectangular
            # candidate x anchor comparisons some similarities are masked out
            # — confirm this is intended.
            cos_simil = torch.tril(normalize(candidate_robs) @ normalize(anchor_rob_vectors).t(), diagonal=-1)

            # check which formulae are similar (i.e. cosine similarity greater than threshold) w.r.t. current anchors
            # NOTE: ask Gaia whether negative cosine similarities should be
            # handled via absolute value or not!
            similar_idx = [torch.where(cos_simil[r, :] > cosine_similarity_threshold)[0].tolist() for r in range(cos_simil.shape[0])]

            # keep only the candidates that are semantically distant from all anchors
            keep_idx = list(set(np.arange(len(candidate_anchors)).tolist()).difference(set([i for sublist in similar_idx for i in sublist])))

            diff_anchor_set += [copy.deepcopy(candidate_anchors[i]) for i in keep_idx]

            # Convert keep_idx to a tensor on the same device as candidate_robs
            keep_idx_tensor = torch.tensor(keep_idx, device=candidate_robs.device)

            # Use index_select to pick the relevant rows
            selected_robs = torch.index_select(candidate_robs, 0, keep_idx_tensor)

            # Concatenate on the same device
            anchor_rob_vectors = torch.cat([anchor_rob_vectors, copy.deepcopy(selected_robs)], dim=0)

        anchor_set = diff_anchor_set[:embed_dim]

    else:
        anchor_set = sampler.bag_sample(bag_size=embed_dim, nvars=n_vars)

    # NOTE(review): the filename says 'no_diff' even when diff_init=True —
    # confirm whether the two modes should be written to distinct files.
    filename = f'anchor_set_no_diff_{embed_dim}_dim'
    dump_pickle(filename, anchor_set)
    return filename
|
| 764 |
+
|
| 765 |
+
####
|
| 766 |
|
| 767 |
class STLTokenizer(PreTrainedTokenizer):
|
| 768 |
"""
|
|
|
|
| 1099 |
|
| 1100 |
return attn_output, None, past_key_value
|
| 1101 |
|
| 1102 |
+
####
|
| 1103 |
|
| 1104 |
class STLEncoder():
|
| 1105 |
def __init__(self,
|
|
|
|
| 1504 |
cross_attentions=all_cross_attentions,
|
| 1505 |
)
|
| 1506 |
|
| 1507 |
+
####
|
| 1508 |
|
| 1509 |
class STLForCausalLM(STLModel, GenerationMixin):
|
| 1510 |
_tied_weights_keys = ["lm_head.weight"]
|