# stldec_random_1024_pca / modeling_stldec.py
# Uploaded by saracandu as "Upload STLForCausalLM" (Hub revision f8515bb, verified).
import ast
import copy
import math
import pickle
import os
from collections import deque
from typing import List, Optional, Tuple, Union
import numpy as np
import pandas as pd
import torch
import torch.utils.checkpoint
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from transformers.generation import GenerationMixin
from transformers.utils import (
add_end_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from transformers.modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
)
from .configuration_stldec import STLConfig
from nltk.translate.bleu_score import sentence_bleu
# from stl import *
import networkx as nx
from datasets import load_dataset
### from custom_typing.py
# Type alias: a "real number" may be either a float or an int.
realnum = Union[float, int]
### from stl.py
# For tensor functions
import torch
from torch import Tensor
import torch.nn.functional as F
def eventually(x: Tensor, time_span: int) -> Tensor:
    """
    STL operator 'eventually' in 1D.

    Computes a sliding maximum over windows of length ``time_span``
    (stride 1), i.e. whether/how-much the property holds at some point
    within each window.

    Parameters
    ----------
    x: torch.Tensor
        Signal, shaped (batch, channels, time).
    time_span: int
        Timespan duration (window length of the pooling).

    Returns
    -------
    torch.Tensor
        A tensor containing the result of the operation.
    """
    pooled = F.max_pool1d(x, kernel_size=time_span, stride=1)
    return pooled
class Node:
    """Abstract node class for STL semantics tree.

    Concrete subclasses (e.g. Atom, Not, And, Or, Globally, Eventually,
    Until) implement the boolean and quantitative (robustness) semantics
    over batched signals of shape N_samples x N_vars x N_sampling_points.
    """
    def __init__(self) -> None:
        # Must be overloaded.
        pass
    def __str__(self) -> str:
        # Must be overloaded.
        # NOTE(review): without an override this implicitly returns None,
        # so str(node) would raise TypeError — subclasses must override it.
        pass
    def boolean(self, x: Tensor, evaluate_at_all_times: bool = False) -> Tensor:
        """
        Evaluates the boolean semantics at the node.
        Parameters
        ----------
        x : torch.Tensor, of size N_samples x N_vars x N_sampling_points
            The input signals, stored as a batch tensor with three dimensions.
        evaluate_at_all_times: bool
            Whether to evaluate the semantics at all times (True) or
            just at t=0 (False).
        Returns
        -------
        torch.Tensor
            A tensor with the boolean semantics for the node.
        """
        z: Tensor = self._boolean(x)
        if evaluate_at_all_times:
            return z
        else:
            return self._extract_semantics_at_time_zero(z)
    def quantitative(
        self,
        x: Tensor,
        normalize: bool = False,
        evaluate_at_all_times: bool = False,
    ) -> Tensor:
        """
        Evaluates the quantitative semantics at the node.
        Parameters
        ----------
        x : torch.Tensor, of size N_samples x N_vars x N_sampling_points
            The input signals, stored as a batch tensor with three dimensions.
        normalize: bool
            Whether the measure of robustness is normalized (True) or
            not (False). Currently not in use.
        evaluate_at_all_times: bool
            Whether to evaluate the semantics at all times (True) or
            just at t=0 (False).
        Returns
        -------
        torch.Tensor
            A tensor with the quantitative semantics for the node.
        """
        z: Tensor = self._quantitative(x, normalize)
        if evaluate_at_all_times:
            return z
        else:
            return self._extract_semantics_at_time_zero(z)
    def set_normalizing_flag(self, value: bool = True) -> None:
        """
        Setter for the 'normalization of robustness of the formula' flag.
        Currently not in use.
        """
    def time_depth(self) -> int:
        """Returns time depth of bounded temporal operators only."""
        # Must be overloaded.
    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        """Private method equivalent to public one for inner call."""
        # Must be overloaded.
    def _boolean(self, x: Tensor) -> Tensor:
        """Private method equivalent to public one for inner call."""
        # Must be overloaded.
    @staticmethod
    def _extract_semantics_at_time_zero(x: Tensor) -> Tensor:
        """Extrapolates the vector of truth values at time zero"""
        # Takes the first variable/first time-point entry of each sample and
        # flattens to a 1-D tensor of length N_samples.
        return torch.reshape(x[:, 0, 0], (-1,))
class Atom(Node):
    """Atomic formula node: an inequality of the form x_i <= c or x_i >= c."""
    def __init__(self, var_index: int, threshold: realnum, lte: bool = False) -> None:
        super().__init__()
        self.var_index: int = var_index
        self.threshold: realnum = threshold
        self.lte: bool = lte

    def __str__(self) -> str:
        comparison = " <= " if self.lte else " >= "
        return "x_" + str(self.var_index) + comparison + str(round(self.threshold, 4))

    def time_depth(self) -> int:
        # An atom inspects a single time instant.
        return 0

    def _boolean(self, x: Tensor) -> Tensor:
        # Keep only the monitored variable, reshaped to (N_samples, 1, T).
        signal = x[:, self.var_index, :]
        signal = signal.view(signal.size(0), 1, -1)
        comparator = torch.le if self.lte else torch.ge
        return comparator(signal, self.threshold)

    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        # Keep only the monitored variable, reshaped to (N_samples, 1, T).
        signal = x[:, self.var_index, :]
        signal = signal.view(signal.size(0), 1, -1)
        # Signed distance from the threshold (positive when satisfied).
        robustness = self.threshold - signal if self.lte else signal - self.threshold
        if normalize:
            robustness = torch.tanh(robustness)
        return robustness
class Not(Node):
    """Negation node: logical/arithmetic complement of its child."""
    def __init__(self, child: Node) -> None:
        super().__init__()
        self.child: Node = child

    def __str__(self) -> str:
        return "not ( " + str(self.child) + " )"

    def time_depth(self) -> int:
        # Negation does not change the temporal horizon.
        return self.child.time_depth()

    def _boolean(self, x: Tensor) -> Tensor:
        return torch.logical_not(self.child._boolean(x))

    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        # Robustness of "not phi" is the negated robustness of phi.
        return self.child._quantitative(x, normalize).neg()
class And(Node):
    """Conjunction node: both children must hold."""
    def __init__(self, left_child: Node, right_child: Node) -> None:
        super().__init__()
        self.left_child: Node = left_child
        self.right_child: Node = right_child

    def __str__(self) -> str:
        return "( " + str(self.left_child) + " and " + str(self.right_child) + " )"

    def time_depth(self) -> int:
        return max(self.left_child.time_depth(), self.right_child.time_depth())

    def _boolean(self, x: Tensor) -> Tensor:
        left = self.left_child._boolean(x)
        right = self.right_child._boolean(x)
        # Children may have different horizons; align on the shorter one.
        horizon = min(left.size(2), right.size(2))
        return torch.logical_and(left[:, :, :horizon], right[:, :, :horizon])

    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        left = self.left_child._quantitative(x, normalize)
        right = self.right_child._quantitative(x, normalize)
        # Children may have different horizons; align on the shorter one.
        horizon = min(left.size(2), right.size(2))
        # Robustness of a conjunction is the pointwise minimum.
        return torch.min(left[:, :, :horizon], right[:, :, :horizon])
# NOTE(review): this is an exact duplicate of the `Not` class defined earlier
# in this file; being later, this definition is the one that takes effect.
class Not(Node):
    """Negation node."""
    def __init__(self, child: Node) -> None:
        super().__init__()
        self.child: Node = child
    def __str__(self) -> str:
        s: str = "not ( " + self.child.__str__() + " )"
        return s
    def time_depth(self) -> int:
        # Negation does not change the temporal horizon.
        return self.child.time_depth()
    def _boolean(self, x: Tensor) -> Tensor:
        # Elementwise complement of the child's boolean semantics.
        z: Tensor = ~self.child._boolean(x)
        return z
    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        # Robustness of "not phi" is the negated robustness of phi.
        z: Tensor = -self.child._quantitative(x, normalize)
        return z
# NOTE(review): this is an exact duplicate of the `And` class defined earlier
# in this file; being later, this definition is the one that takes effect.
class And(Node):
    """Conjunction node."""
    def __init__(self, left_child: Node, right_child: Node) -> None:
        super().__init__()
        self.left_child: Node = left_child
        self.right_child: Node = right_child
    def __str__(self) -> str:
        s: str = (
            "( "
            + self.left_child.__str__()
            + " and "
            + self.right_child.__str__()
            + " )"
        )
        return s
    def time_depth(self) -> int:
        return max(self.left_child.time_depth(), self.right_child.time_depth())
    def _boolean(self, x: Tensor) -> Tensor:
        z1: Tensor = self.left_child._boolean(x)
        z2: Tensor = self.right_child._boolean(x)
        # Children may have different horizons; align on the shorter one.
        size: int = min(z1.size()[2], z2.size()[2])
        z1: Tensor = z1[:, :, :size]
        z2: Tensor = z2[:, :, :size]
        z: Tensor = torch.logical_and(z1, z2)
        return z
    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        z1: Tensor = self.left_child._quantitative(x, normalize)
        z2: Tensor = self.right_child._quantitative(x, normalize)
        # Children may have different horizons; align on the shorter one.
        size: int = min(z1.size()[2], z2.size()[2])
        z1: Tensor = z1[:, :, :size]
        z2: Tensor = z2[:, :, :size]
        # Robustness of a conjunction is the pointwise minimum.
        z: Tensor = torch.min(z1, z2)
        return z
class Or(Node):
    """Disjunction node: at least one child must hold."""
    def __init__(self, left_child: Node, right_child: Node) -> None:
        super().__init__()
        self.left_child: Node = left_child
        self.right_child: Node = right_child

    def __str__(self) -> str:
        return "( " + str(self.left_child) + " or " + str(self.right_child) + " )"

    def time_depth(self) -> int:
        return max(self.left_child.time_depth(), self.right_child.time_depth())

    def _boolean(self, x: Tensor) -> Tensor:
        left = self.left_child._boolean(x)
        right = self.right_child._boolean(x)
        # Children may have different horizons; align on the shorter one.
        horizon = min(left.size(2), right.size(2))
        return torch.logical_or(left[:, :, :horizon], right[:, :, :horizon])

    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        left = self.left_child._quantitative(x, normalize)
        right = self.right_child._quantitative(x, normalize)
        # Children may have different horizons; align on the shorter one.
        horizon = min(left.size(2), right.size(2))
        # Robustness of a disjunction is the pointwise maximum.
        return torch.max(left[:, :, :horizon], right[:, :, :horizon])
class Globally(Node):
    """Globally ('always') temporal operator node.

    Evaluates whether the child formula holds at every time step of the
    interval [left_time_bound, right_time_bound] (or up to the end of the
    signal when unbounded).
    """
    def __init__(
        self,
        child: Node,
        unbound: bool = False,
        right_unbound: bool = False,
        left_time_bound: int = 0,
        right_time_bound: int = 1,
        adapt_unbound: bool = True,
    ) -> None:
        super().__init__()
        self.child: Node = child
        self.unbound: bool = unbound
        self.right_unbound: bool = right_unbound
        self.left_time_bound: int = left_time_bound
        # Stored as an exclusive upper bound (input bound + 1).
        self.right_time_bound: int = right_time_bound + 1
        self.adapt_unbound: bool = adapt_unbound
        # Consistency fix: validate the time interval exactly as the sibling
        # Eventually and Until nodes already do, so an empty/inverted interval
        # fails fast instead of producing a zero-size pooling window later.
        if (self.unbound is False) and (self.right_unbound is False) and \
                (self.right_time_bound <= self.left_time_bound):
            raise ValueError("Temporal thresholds are incorrect: right parameter is higher than left parameter")

    def __str__(self) -> str:
        s_left = "[" + str(self.left_time_bound) + ","
        s_right = str(self.right_time_bound) if not self.right_unbound else "inf"
        s0: str = s_left + s_right + "]" if not self.unbound else ""
        s: str = "always" + s0 + " ( " + self.child.__str__() + " )"
        return s

    def time_depth(self) -> int:
        """Temporal horizon needed to evaluate this node."""
        if self.unbound:
            return self.child.time_depth()
        elif self.right_unbound:
            return self.child.time_depth() + self.left_time_bound
        else:
            return self.child.time_depth() + self.right_time_bound - 1

    def _boolean(self, x: Tensor) -> Tensor:
        # Shift the signal so the child is evaluated from left_time_bound on.
        z1: Tensor = self.child._boolean(x[:, :, self.left_time_bound:])
        if self.unbound or self.right_unbound:
            if self.adapt_unbound:
                # "Always from t to end": running minimum taken backwards in time.
                z: Tensor
                _: Tensor
                z, _ = torch.cummin(torch.flip(z1, [2]), dim=2)
                z: Tensor = torch.flip(z, [2])
            else:
                # Single truth value: minimum over the whole horizon.
                z: Tensor
                _: Tensor
                z, _ = torch.min(z1, 2, keepdim=True)
        else:
            # G phi == not F (not phi); computed via max-pooling on the complement.
            z: Tensor = torch.ge(1.0 - eventually((~z1).double(), self.right_time_bound - self.left_time_bound), 0.5)
        return z

    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        # Shift the signal so the child is evaluated from left_time_bound on.
        z1: Tensor = self.child._quantitative(x[:, :, self.left_time_bound:], normalize)
        if self.unbound or self.right_unbound:
            if self.adapt_unbound:
                # Running minimum taken backwards in time.
                z: Tensor
                _: Tensor
                z, _ = torch.cummin(torch.flip(z1, [2]), dim=2)
                z: Tensor = torch.flip(z, [2])
            else:
                z: Tensor
                _: Tensor
                z, _ = torch.min(z1, 2, keepdim=True)
        else:
            # min over a window == -max over the negated window.
            z: Tensor = -eventually(-z1, self.right_time_bound - self.left_time_bound)
        return z
class Eventually(Node):
    """Eventually temporal operator node.

    Evaluates whether the child formula holds at some time step of the
    interval [left_time_bound, right_time_bound] (or up to the end of the
    signal when unbounded).
    """
    def __init__(
        self,
        child: Node,
        unbound: bool = False,
        right_unbound: bool = False,
        left_time_bound: int = 0,
        right_time_bound: int = 1,
        adapt_unbound: bool = True,
    ) -> None:
        super().__init__()
        self.child: Node = child
        self.unbound: bool = unbound
        self.right_unbound: bool = right_unbound
        self.left_time_bound: int = left_time_bound
        # Stored as an exclusive upper bound (input bound + 1).
        self.right_time_bound: int = right_time_bound + 1
        self.adapt_unbound: bool = adapt_unbound
        # Reject empty/inverted time intervals for the bounded case.
        if (self.unbound is False) and (self.right_unbound is False) and \
                (self.right_time_bound <= self.left_time_bound):
            raise ValueError("Temporal thresholds are incorrect: right parameter is higher than left parameter")
    def __str__(self) -> str:
        s_left = "[" + str(self.left_time_bound) + ","
        s_right = str(self.right_time_bound) if not self.right_unbound else "inf"
        s0: str = s_left + s_right + "]" if not self.unbound else ""
        s: str = "eventually" + s0 + " ( " + self.child.__str__() + " )"
        return s
    def time_depth(self) -> int:
        """Temporal horizon needed to evaluate this node."""
        if self.unbound:
            return self.child.time_depth()
        elif self.right_unbound:
            return self.child.time_depth() + self.left_time_bound
        else:
            return self.child.time_depth() + self.right_time_bound - 1
    def _boolean(self, x: Tensor) -> Tensor:
        # Shift the signal so the child is evaluated from left_time_bound on.
        z1: Tensor = self.child._boolean(x[:, :, self.left_time_bound:])
        if self.unbound or self.right_unbound:
            if self.adapt_unbound:
                # "Eventually from t to end": running maximum taken backwards in time.
                z: Tensor
                _: Tensor
                z, _ = torch.cummax(torch.flip(z1, [2]), dim=2)
                z: Tensor = torch.flip(z, [2])
            else:
                # Single truth value: maximum over the whole horizon.
                z: Tensor
                _: Tensor
                z, _ = torch.max(z1, 2, keepdim=True)
        else:
            # Bounded case: sliding max-pool over the window, thresholded back to bool.
            z: Tensor = torch.ge(eventually(z1.double(), self.right_time_bound - self.left_time_bound), 0.5)
        return z
    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        # Shift the signal so the child is evaluated from left_time_bound on.
        z1: Tensor = self.child._quantitative(x[:, :, self.left_time_bound:], normalize)
        if self.unbound or self.right_unbound:
            if self.adapt_unbound:
                # Running maximum taken backwards in time.
                z: Tensor
                _: Tensor
                z, _ = torch.cummax(torch.flip(z1, [2]), dim=2)
                z: Tensor = torch.flip(z, [2])
            else:
                z: Tensor
                _: Tensor
                z, _ = torch.max(z1, 2, keepdim=True)
        else:
            # Bounded case: sliding max-pool over the window.
            z: Tensor = eventually(z1, self.right_time_bound - self.left_time_bound)
        return z
class Until(Node):
    """Until temporal operator node.

    "phi1 until phi2": phi2 must eventually hold within the interval, and
    phi1 must hold at every step before that. Bounded variants are rewritten
    into combinations of And/Globally/Eventually over the unbound case.
    """
    def __init__(
        self,
        left_child: Node,
        right_child: Node,
        unbound: bool = False,
        right_unbound: bool = False,
        left_time_bound: int = 0,
        right_time_bound: int = 1,
    ) -> None:
        super().__init__()
        self.left_child: Node = left_child
        self.right_child: Node = right_child
        self.unbound: bool = unbound
        self.right_unbound: bool = right_unbound
        self.left_time_bound: int = left_time_bound
        # Stored as an exclusive upper bound (input bound + 1).
        self.right_time_bound: int = right_time_bound + 1
        # Reject empty/inverted time intervals for the bounded case.
        if (self.unbound is False) and (self.right_unbound is False) and \
                (self.right_time_bound <= self.left_time_bound):
            raise ValueError("Temporal thresholds are incorrect: right parameter is higher than left parameter")
    def __str__(self) -> str:
        s_left = "[" + str(self.left_time_bound) + ","
        s_right = str(self.right_time_bound) if not self.right_unbound else "inf"
        s0: str = s_left + s_right + "]" if not self.unbound else ""
        s: str = "( " + self.left_child.__str__() + " until" + s0 + " " + self.right_child.__str__() + " )"
        return s
    def time_depth(self) -> int:
        """Temporal horizon needed to evaluate this node."""
        sum_children_depth: int = self.left_child.time_depth() + self.right_child.time_depth()
        if self.unbound:
            return sum_children_depth
        elif self.right_unbound:
            return sum_children_depth + self.left_time_bound
        else:
            return sum_children_depth + self.right_time_bound - 1
    def _boolean(self, x: Tensor) -> Tensor:
        if self.unbound:
            z1: Tensor = self.left_child._boolean(x)
            z2: Tensor = self.right_child._boolean(x)
            size: int = min(z1.size()[2], z2.size()[2])
            z1: Tensor = z1[:, :, :size]
            z2: Tensor = z2[:, :, :size]
            # Build a (time x time) matrix per sample: entry (t, t') holds the
            # prefix-min of z1 from t to t' (via tril/triu masking + cummin)
            # paired with z2 at t'; then max over t' gives "until" at each t.
            z1_rep = torch.repeat_interleave(z1.unsqueeze(2), z1.unsqueeze(2).shape[-1], 2)
            z1_tril = torch.tril(z1_rep.transpose(2, 3), diagonal=-1)
            z1_triu = torch.triu(z1_rep)
            z1_def = torch.cummin(z1_tril + z1_triu, dim=3)[0]
            z2_rep = torch.repeat_interleave(z2.unsqueeze(2), z2.unsqueeze(2).shape[-1], 2)
            z2_tril = torch.tril(z2_rep.transpose(2, 3), diagonal=-1)
            z2_triu = torch.triu(z2_rep)
            z2_def = z2_tril + z2_triu
            z: Tensor = torch.max(torch.min(torch.cat([z1_def.unsqueeze(-1), z2_def.unsqueeze(-1)], dim=-1), dim=-1)[0],
                                  dim=-1)[0]
        elif self.right_unbound:
            # Rewrite: G_[0,a] phi1 and F_[a,inf] phi2 and F_[a,inf] (phi1 U phi2).
            timed_until: Node = And(Globally(self.left_child, left_time_bound=0, right_time_bound=self.left_time_bound),
                                    And(Eventually(self.right_child, right_unbound=True,
                                                   left_time_bound=self.left_time_bound),
                                        Eventually(Until(self.left_child, self.right_child, unbound=True),
                                                   left_time_bound=self.left_time_bound, right_unbound=True)))
            z: Tensor = timed_until._boolean(x)
        else:
            # Rewrite of the fully bounded case via the same decomposition.
            timed_until: Node = And(Globally(self.left_child, left_time_bound=0, right_time_bound=self.left_time_bound),
                                    And(Eventually(self.right_child, left_time_bound=self.left_time_bound,
                                                   right_time_bound=self.right_time_bound - 1),
                                        Eventually(Until(self.left_child, self.right_child, unbound=True),
                                                   left_time_bound=self.left_time_bound, right_unbound=True)))
            z: Tensor = timed_until._boolean(x)
        return z
    def _quantitative(self, x: Tensor, normalize: bool = False) -> Tensor:
        if self.unbound:
            z1: Tensor = self.left_child._quantitative(x, normalize)
            z2: Tensor = self.right_child._quantitative(x, normalize)
            size: int = min(z1.size()[2], z2.size()[2])
            z1: Tensor = z1[:, :, :size]
            z2: Tensor = z2[:, :, :size]
            # z1_rep = torch.repeat_interleave(z1.unsqueeze(2), z1.unsqueeze(2).shape[-1], 2)
            # z1_tril = torch.tril(z1_rep.transpose(2, 3), diagonal=-1)
            # z1_triu = torch.triu(z1_rep)
            # z1_def = torch.cummin(z1_tril + z1_triu, dim=3)[0]
            # z2_rep = torch.repeat_interleave(z2.unsqueeze(2), z2.unsqueeze(2).shape[-1], 2)
            # z2_tril = torch.tril(z2_rep.transpose(2, 3), diagonal=-1)
            # z2_triu = torch.triu(z2_rep)
            # z2_def = z2_tril + z2_triu
            # z: Tensor = torch.max(torch.min(torch.cat([z1_def.unsqueeze(-1), z2_def.unsqueeze(-1)], dim=-1), dim=-1)[0],
            #                       dim=-1)[0]
            # For each start time t: max over t' >= t of min(prefix-min of z1, z2[t']).
            z: Tensor = torch.cat([torch.max(torch.min(
                torch.cat([torch.cummin(z1[:, :, t:].unsqueeze(-1), dim=2)[0], z2[:, :, t:].unsqueeze(-1)], dim=-1),
                dim=-1)[0], dim=2, keepdim=True)[0] for t in range(size)], dim=2)
        elif self.right_unbound:
            # Same decomposition as in _boolean.
            timed_until: Node = And(Globally(self.left_child, left_time_bound=0, right_time_bound=self.left_time_bound),
                                    And(Eventually(self.right_child, right_unbound=True,
                                                   left_time_bound=self.left_time_bound),
                                        Eventually(Until(self.left_child, self.right_child, unbound=True),
                                                   left_time_bound=self.left_time_bound, right_unbound=True)))
            z: Tensor = timed_until._quantitative(x, normalize=normalize)
        else:
            # Same decomposition as in _boolean.
            timed_until: Node = And(Globally(self.left_child, left_time_bound=0, right_time_bound=self.left_time_bound),
                                    And(Eventually(self.right_child, left_time_bound=self.left_time_bound,
                                                   right_time_bound=self.right_time_bound-1),
                                        Eventually(Until(self.left_child, self.right_child, unbound=True),
                                                   left_time_bound=self.left_time_bound, right_unbound=True)))
            z: Tensor = timed_until._quantitative(x, normalize=normalize)
        return z
# from anchor_set_generation import anchorGeneration
import re
import json
from typing import Any, Dict, List, Optional, Tuple, Union
from transformers import PreTrainedTokenizer
from transformers.utils import logging
logger = logging.get_logger(__name__)
#### utils ####
def load_pickle(path):
    """Deserialize and return the object stored in the pickle file at *path*."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)
def dump_pickle(name, thing):
    """Serialize *thing* to the file '<name>.pickle' (extension appended)."""
    target = name + '.pickle'
    with open(target, 'wb') as handle:
        pickle.dump(thing, handle)
def from_string_to_formula(st):
    """Parse the string rendering of an STL formula (as produced by the Node
    classes' __str__) back into a formula tree.

    Recursively splits *st* on whitespace: a leading '(' signals a binary root
    (and/or/until); otherwise the root is an atom or a unary operator
    (not/eventually/always).

    NOTE(review): relies on set_time_thresholds (defined elsewhere in this
    module) to decode time bounds from operator tokens like 'always[0,5]'.
    Atom parsing takes st_split[0][2] as the variable index, so it assumes
    single-digit indices ('x_0'..'x_9') — TODO confirm for >10 variables.
    """
    root_arity = 2 if st.startswith('(') else 1
    st_split = st.split()
    if root_arity <= 1:
        root_op_str = copy.deepcopy(st_split[0])
        if root_op_str.startswith('x'):
            # Atom of the form "x_<i> <= <c>" or "x_<i> >= <c>".
            atom_sign = True if st_split[1] == '<=' else False
            root_phi = Atom(var_index=int(st_split[0][2]), lte=atom_sign, threshold=float(st_split[2]))
            return root_phi
        else:
            assert (root_op_str.startswith('not') or root_op_str.startswith('eventually')
                    or root_op_str.startswith('always'))
            # Strip the operator token and the surrounding "( ... )".
            current_st = copy.deepcopy(st_split[2:-1])
            if root_op_str == 'not':
                root_phi = Not(child=from_string_to_formula(' '.join(current_st)))
            elif root_op_str.startswith('eventually'):
                unbound, right_unbound, left_time_bound, right_time_bound = set_time_thresholds(root_op_str)
                root_phi = Eventually(child=from_string_to_formula(' '.join(current_st)), unbound=unbound,
                                      right_unbound=right_unbound, left_time_bound=left_time_bound,
                                      right_time_bound=right_time_bound)
            else:
                unbound, right_unbound, left_time_bound, right_time_bound = set_time_thresholds(root_op_str)
                root_phi = Globally(child=from_string_to_formula(' '.join(current_st)), unbound=unbound,
                                    right_unbound=right_unbound, left_time_bound=left_time_bound,
                                    right_time_bound=right_time_bound)
    else:
        # 1 - delete everything which is contained in other sets of parenthesis (if any)
        current_st = copy.deepcopy(st_split[1:-1])
        if '(' in current_st:
            # Match parentheses with a stack to find the children sub-spans.
            par_queue = deque()
            par_idx_list = []
            for i, sub in enumerate(current_st):
                if sub == '(':
                    par_queue.append(i)
                elif sub == ')':
                    par_idx_list.append(tuple([par_queue.pop(), i]))
            # open_par_idx, close_par_idx = [current_st.index(p) for p in ['(', ')']]
            # union of parentheses range --> from these we may extract the substrings to be the children!!!
            children_range = []
            for begin, end in sorted(par_idx_list):
                # Merge overlapping/adjacent parenthesized spans into one child span.
                if children_range and children_range[-1][1] >= begin - 1:
                    children_range[-1][1] = max(children_range[-1][1], end)
                else:
                    children_range.append([begin, end])
            n_children = len(children_range)
            assert (n_children in [1, 2])
            if n_children == 1:
                # one of the children is a variable --> need to individuate it
                var_child_idx = 1 if children_range[0][0] <= 1 else 0  # 0 is left child, 1 is right child
                # Extend the span to include a leading unary operator token
                # ('not'/'eventually'/'always' — matched by 2-char prefix).
                if children_range[0][0] != 0 and current_st[children_range[0][0] - 1][0:2] in ['no', 'ev', 'al']:
                    children_range[0][0] -= 1
                left_child_str = current_st[:3] if var_child_idx == 0 else \
                    current_st[children_range[0][0]:children_range[0][1] + 1]
                right_child_str = current_st[-3:] if var_child_idx == 1 else \
                    current_st[children_range[0][0]:children_range[0][1] + 1]
                root_op_str = current_st[children_range[0][1] + 1] if var_child_idx == 1 else \
                    current_st[children_range[0][0] - 1]
                assert (root_op_str[:2] in ['an', 'or', 'un'])
            else:
                # Extend both children spans to include leading unary operators.
                if children_range[0][0] != 0 and current_st[children_range[0][0] - 1][0:2] in ['no', 'ev', 'al']:
                    children_range[0][0] -= 1
                if current_st[children_range[1][0] - 1][0:2] in ['no', 'ev', 'al']:
                    children_range[1][0] -= 1
                # if there are two children, with parentheses, the element in the middle is the root
                root_op_str = current_st[children_range[0][1] + 1]
                assert (root_op_str[:2] in ['an', 'or', 'un'])
                left_child_str = current_st[children_range[0][0]:children_range[0][1] + 1]
                right_child_str = current_st[children_range[1][0]:children_range[1][1] + 1]
        else:
            # no parentheses means that both children are variables
            left_child_str = current_st[:3]
            right_child_str = current_st[-3:]
            root_op_str = current_st[3]
        left_child_str = ' '.join(left_child_str)
        right_child_str = ' '.join(right_child_str)
        if root_op_str == 'and':
            root_phi = And(left_child=from_string_to_formula(left_child_str),
                           right_child=from_string_to_formula(right_child_str))
        elif root_op_str == 'or':
            root_phi = Or(left_child=from_string_to_formula(left_child_str),
                          right_child=from_string_to_formula(right_child_str))
        else:
            unbound, right_unbound, left_time_bound, right_time_bound = set_time_thresholds(root_op_str)
            root_phi = Until(left_child=from_string_to_formula(left_child_str),
                             right_child=from_string_to_formula(right_child_str),
                             unbound=unbound, right_unbound=right_unbound, left_time_bound=left_time_bound,
                             right_time_bound=right_time_bound)
    return root_phi
def load_json(path: str) -> Union[Dict, List]:
    """
    Load a JSON file from the given path.

    Args:
        path (str): The path to the JSON file to be loaded.

    Returns:
        Union[Dict, List]: The parsed content of the JSON file, which could be
        a dictionary or a list.
    """
    with open(path, "r") as handle:
        content = json.load(handle)
    return content
#### phis_generator ####
class StlGenerator:
    """Random sampler of STL formulae.

    Draws formula trees whose node types, atom thresholds and temporal bounds
    follow the distributions configured in the constructor.

    Bug fix: the original implementation called an undefined name ``rnd``;
    all random draws now go through ``np.random`` (numpy is imported at the
    top of this file), which provides the exact same API (choice with ``p``,
    rand, randint, normal).
    """

    def __init__(
        self,
        leaf_prob: float = 0.3,
        inner_node_prob: list = None,
        threshold_mean: float = 0.0,
        threshold_sd: float = 1.0,
        unbound_prob: float = 0.1,
        right_unbound_prob: float = 0.2,
        time_bound_max_range: float = 20,
        adaptive_unbound_temporal_ops: bool = True,
        max_timespan: int = 100,
    ):
        """
        leaf_prob
            probability of generating a leaf (always zero for root)
        inner_node_prob
            probability vector for the different types of internal nodes,
            ordered as ["not", "and", "or", "always", "eventually", "until"]
        threshold_mean, threshold_sd
            mean and std for the normal distribution of the thresholds of atoms
        unbound_prob
            probability of a temporal operator to have a time bound of the type [0, infty]
        right_unbound_prob
            probability of a temporal operator to have a bound of the type [t, infty]
        time_bound_max_range
            maximum value of time span of a temporal operator (i.e. max value of t in [0,t])
        adaptive_unbound_temporal_ops
            if true, unbounded temporal operators are computed from the current point
            to the end of the signal, otherwise they are evaluated only at time zero.
        max_timespan
            maximum time depth of a formula.
        """
        # Address the mutability of default arguments
        if inner_node_prob is None:
            inner_node_prob = [0.166, 0.166, 0.166, 0.17, 0.166, 0.166]
        self.leaf_prob = leaf_prob
        self.inner_node_prob = inner_node_prob
        self.threshold_mean = threshold_mean
        self.threshold_sd = threshold_sd
        self.unbound_prob = unbound_prob
        self.right_unbound_prob = right_unbound_prob
        self.time_bound_max_range = time_bound_max_range
        self.adaptive_unbound_temporal_ops = adaptive_unbound_temporal_ops
        self.node_types = ["not", "and", "or", "always", "eventually", "until"]
        self.max_timespan = max_timespan

    def sample(self, nvars):
        """
        Samples a random formula with distribution defined in class instance parameters

        Parameters
        ----------
        nvars : int
            how many variables the formula is expected to consider.

        Returns
        -------
        Node
            A random formula (the root is always an internal node, never a leaf).
        """
        return self._sample_internal_node(nvars)

    def bag_sample(self, bag_size, nvars):
        """
        Samples a bag of bag_size formulae

        Parameters
        ----------
        bag_size : int
            number of formulae.
        nvars : int
            number of vars in formulae.

        Returns
        -------
        list
            a list of formulae.
        """
        return [self.sample(nvars) for _ in range(bag_size)]

    def _sample_internal_node(self, nvars):
        """Sample an internal node; retry until its time depth fits max_timespan."""
        node: Union[None, Node] = None
        # The operator type is drawn once; a rejected (too deep) draw re-samples
        # fresh children of the SAME type, matching the original behaviour.
        nodetype = np.random.choice(self.node_types, p=self.inner_node_prob)
        while True:
            if nodetype == "not":
                node = Not(self._sample_node(nvars))
            elif nodetype == "and":
                n1 = self._sample_node(nvars)
                n2 = self._sample_node(nvars)
                node = And(n1, n2)
            elif nodetype == "or":
                n1 = self._sample_node(nvars)
                n2 = self._sample_node(nvars)
                node = Or(n1, n2)
            elif nodetype == "always":
                n = self._sample_node(nvars)
                unbound, right_unbound, left_time_bound, right_time_bound = self._get_temporal_parameters()
                node = Globally(
                    n, unbound, right_unbound, left_time_bound, right_time_bound, self.adaptive_unbound_temporal_ops
                )
            elif nodetype == "eventually":
                n = self._sample_node(nvars)
                unbound, right_unbound, left_time_bound, right_time_bound = self._get_temporal_parameters()
                node = Eventually(
                    n, unbound, right_unbound, left_time_bound, right_time_bound, self.adaptive_unbound_temporal_ops
                )
            elif nodetype == "until":
                n1 = self._sample_node(nvars)
                n2 = self._sample_node(nvars)
                unbound, right_unbound, left_time_bound, right_time_bound = self._get_temporal_parameters()
                node = Until(
                    n1, n2, unbound, right_unbound, left_time_bound, right_time_bound
                )
            if (node is not None) and (node.time_depth() < self.max_timespan):
                return node

    def _sample_node(self, nvars):
        """Sample either a leaf atom (with prob leaf_prob) or an internal node."""
        if np.random.rand() < self.leaf_prob:
            var, thr, lte = self._get_atom(nvars)
            return Atom(var, thr, lte)
        return self._sample_internal_node(nvars)

    def _get_temporal_parameters(self):
        """Draw (unbound, right_unbound, left_time_bound, right_time_bound)."""
        if np.random.rand() < self.unbound_prob:
            return True, False, 0, 0
        if np.random.rand() < self.right_unbound_prob:
            return False, True, np.random.randint(self.time_bound_max_range), 1
        left_bound = np.random.randint(self.time_bound_max_range)
        return False, False, left_bound, np.random.randint(left_bound, self.time_bound_max_range) + 1

    def _get_atom(self, nvars):
        """Draw an atom spec: (variable index, threshold, lte flag)."""
        variable = np.random.randint(nvars)
        lte = np.random.rand() > 0.5
        threshold = np.random.normal(self.threshold_mean, self.threshold_sd)
        return variable, threshold, lte
#### traj_measure ####
class Measure:
    """Abstract base class for trajectory measures (sampling distributions)."""

    def sample(self, samples=100000, varn=2, points=100):
        """Draw trajectories; concrete measures must override this."""
        # Intentionally a no-op in the base class (returns None).
        pass
class BaseMeasure(Measure):
    """Measure over piecewise-linear trajectories: a random initial state,
    random total variation, and random sign changes of the derivative."""

    def __init__(
        self, mu0=0.0, sigma0=1.0, mu1=0.0, sigma1=1.0, q=0.1, q0=0.5, device="cpu"
    ):
        """
        Parameters
        ----------
        mu0 : float, optional
            Mean of the normal distribution of the initial state. Default 0.0.
        sigma0 : float, optional
            Std of the normal distribution of the initial state. Default 1.0.
        mu1 : float, optional
            Mean of the normal distribution of the total variation. Default 0.0.
        sigma1 : float, optional
            Std of the normal distribution of the total variation. Default 1.0.
        q : float, optional
            Probability of change of sign in the derivative. Default 0.1.
        q0 : float, optional
            Probability of the initial sign of the derivative. Default 0.5.
        device : 'cpu' or 'cuda', optional
            Device on which to run the algorithm. Default 'cpu'.
        """
        self.mu0 = mu0
        self.sigma0 = sigma0
        self.mu1 = mu1
        self.sigma1 = sigma1
        self.q = q
        self.q0 = q0
        self.device = device

    def sample(self, samples=100000, varn=2, points=100):
        """
        Samples a set of trajectories from the basic measure space, with
        parameters passed to the sampler.

        Parameters
        ----------
        samples : int, optional
            number of trajectories. The default is 100000.
        varn : int, optional
            number of variables per trajectory. The default is 2.
        points : int, optional
            number of points per trajectory, including the initial one.
            The default is 100.

        Returns
        -------
        signal : samples x varn x points double pytorch tensor
            The sampled signals.

        Raises
        ------
        RuntimeError
            If device is 'cuda' but CUDA is not available.
        """
        if self.device == "cuda" and not torch.cuda.is_available():
            raise RuntimeError("GPU card or CUDA library not available!")
        # generate uniform random numbers
        signal = torch.rand(samples, varn, points, device=self.device)
        # first point is special - set to zero for the moment, and set one point to 1
        signal[:, :, 0] = 0.0
        signal[:, :, -1] = 1.0
        # sorting each trajectory
        signal, _ = torch.sort(signal, 2)
        # computing increments and storing them in points 1 to end
        signal[:, :, 1:] = signal[:, :, 1:] - signal[:, :, :-1]
        # generate initial state, according to a normal distribution.
        # Bug fix: this randn previously ran on the default device, causing a
        # device-mismatch error whenever self.device == "cuda".
        signal[:, :, 0] = self.mu0 + self.sigma0 * torch.randn(
            signal[:, :, 0].size(), device=self.device
        )
        # sampling change signs from bernoulli in -1, 1
        derivs = (1 - self.q) * torch.ones(samples, varn, points, device=self.device)
        derivs = 2 * torch.bernoulli(derivs) - 1
        # sampling initial derivative
        derivs[:, :, 0] = self.q0
        derivs[:, :, 0] = 2 * torch.bernoulli(derivs[:, :, 0]) - 1
        # taking the cumulative product along axis 2
        derivs = torch.cumprod(derivs, 2)
        # sampling total variation
        totvar = torch.pow(
            self.mu1 + self.sigma1 * torch.randn(samples, varn, 1, device=self.device),
            2,
        )
        # multiplying total variation and derivatives and making initial point non-invasive
        derivs = derivs * totvar
        derivs[:, :, 0] = 1.0
        # computing trajectories by multiplying and then doing a cumulative sum
        signal = signal * derivs
        signal = torch.cumsum(signal, 2)
        return signal
#### kernel ####
# NOTE(review): re-declares the `realnum` alias already defined near the top
# of this file (harmless, but redundant).
realnum = Union[float, int]
class StlKernel:
    def __init__(
        self,
        measure,
        normalize=True,
        exp_kernel=True,
        sigma2=0.2,  # 0.5 works better; it was initially 0.2
        integrate_time=False,
        samples=100000,
        varn=2,
        points=100,
        boolean=False,
        signals=None,
    ):
        """Kernel between STL formulae computed over signals drawn from *measure*.

        Parameters: measure is a trajectory measure (e.g. BaseMeasure) used to
        sample evaluation signals unless *signals* is given explicitly;
        normalize/exp_kernel/sigma2 configure the kernel computation
        (presumably a Gaussian/exponential kernel with bandwidth sigma2 —
        TODO confirm in the _compute_kernel_* methods, not visible here);
        integrate_time selects the time-integrated variant; boolean switches
        robustness to boolean semantics.
        """
        self.traj_measure = measure
        self.exp_kernel = exp_kernel
        self.normalize = normalize
        self.sigma2 = sigma2
        self.samples = samples
        self.varn = varn
        self.points = points
        self.integrate_time = integrate_time
        # Sample the shared evaluation signals once, unless caller provides them.
        if signals is not None:
            self.signals = signals
        else:
            self.signals = measure.sample(points=points, samples=samples, varn=varn)
        self.boolean = boolean
def compute(self, phi1, phi2):
return self.compute_one_one(phi1, phi2)
def compute_one_one(self, phi1, phi2):
phis1: list = [phi1]
phis2: list = [phi2]
ker = self.compute_bag_bag(phis1, phis2)
return ker[0, 0]
def compute_bag(self, phis, return_robustness=True):
if self.integrate_time:
rhos, selfk, len0 = self._compute_robustness_time(phis)
kernel_matrix = self._compute_kernel_time(
rhos, rhos, selfk, selfk, len0, len0
)
else:
rhos, selfk = self._compute_robustness_no_time(phis)
kernel_matrix = self._compute_kernel_no_time(rhos, rhos, selfk, selfk)
len0 = None
if return_robustness:
return kernel_matrix.cpu(), rhos, selfk, len0
else:
return kernel_matrix.cpu()
def compute_one_bag(self, phi1, phis2, return_robustness=False):
phis1: list = [phi1]
return self.compute_bag_bag(phis1, phis2, return_robustness)
def compute_bag_bag(self, phis1, phis2, return_robustness=False):
if self.integrate_time:
rhos1, selfk1, len1 = self._compute_robustness_time(phis1)
rhos2, selfk2, len2 = self._compute_robustness_time(phis2)
kernel_matrix = self._compute_kernel_time(
rhos1, rhos2, selfk1, selfk2, len1, len2
)
else:
rhos1, selfk1 = self._compute_robustness_no_time(phis1)
rhos2, selfk2 = self._compute_robustness_no_time(phis2)
len1, len2 = [None, None]
kernel_matrix = self._compute_kernel_no_time(rhos1, rhos2, selfk1, selfk2)
if return_robustness:
return kernel_matrix.cpu(), rhos1, rhos2, selfk1, selfk2, len1, len2
else:
return kernel_matrix.cpu()
def compute_one_from_robustness(self, phi, rhos, rho_self, lengths=None, return_robustness=False):
phis: list = [phi]
return self.compute_bag_from_robustness(phis, rhos, rho_self, lengths, return_robustness)
def compute_bag_from_robustness(self, phis, rhos, rho_self, lengths=None, return_robustness=False):
if self.integrate_time:
rhos1, selfk1, len1 = self._compute_robustness_time(phis)
kernel_matrix = self._compute_kernel_time(
rhos1, rhos, selfk1, rho_self, len1, lengths
)
else:
rhos1, selfk1 = self._compute_robustness_no_time(phis)
len1 = None
kernel_matrix = self._compute_kernel_no_time(rhos1, rhos, selfk1, rho_self)
if return_robustness:
return kernel_matrix.cpu(), rhos1, selfk1, len1
else:
return kernel_matrix.cpu()
n = self.samples
p = self.points
k = len(phis)
rhos = torch.zeros((k, n, p), device="cpu")
lengths = torch.zeros(k)
self_kernels = torch.zeros((k, 1))
for i, phi in enumerate(phis):
if self.boolean:
rho = phi.boolean(self.signals, evaluate_at_all_times=True).float()
rho[rho == 0.0] = -1.0
else:
rho = phi.quantitative(self.signals, evaluate_at_all_times=True)
actual_p = rho.size()[2]
rho = rho.reshape(n, actual_p).cpu()
rhos[i, :, :actual_p] = rho
lengths[i] = actual_p
self_kernels[i] = torch.tensordot(
rho.reshape(1, n, -1), rho.reshape(1, n, -1), dims=[[1, 2], [1, 2]]
) / (actual_p * n)
return rhos, self_kernels, lengths
def _compute_robustness_no_time(self, phis):
n = self.samples
k = len(phis)
rhos = torch.zeros((k, n), device=self.traj_measure.device)
self_kernels = torch.zeros((k, 1), device=self.traj_measure.device)
for i, phi in enumerate(phis):
if self.boolean:
rho = phi.boolean(self.signals, evaluate_at_all_times=False).float()
rho[rho == 0.0] = -1.0
else:
rho = phi.quantitative(self.signals, evaluate_at_all_times=False)
self_kernels[i] = rho.dot(rho) / n
rhos[i, :] = rho
return rhos, self_kernels
def _compute_kernel_time(self, rhos1, rhos2, selfk1, selfk2, len1, len2):
kernel_matrix = torch.tensordot(rhos1, rhos2, [[1, 2], [1, 2]])
length_normalizer = self._compute_trajectory_length_normalizer(len1, len2)
kernel_matrix = kernel_matrix * length_normalizer / self.samples
if self.normalize:
kernel_matrix = self._normalize(kernel_matrix, selfk1, selfk2)
if self.exp_kernel:
kernel_matrix = self._exponentiate(kernel_matrix, selfk1, selfk2)
return kernel_matrix
def _compute_kernel_no_time(self, rhos1, rhos2, selfk1, selfk2):
kernel_matrix = torch.tensordot(rhos1, rhos2, [[1], [1]])
kernel_matrix = kernel_matrix / self.samples
if self.normalize:
kernel_matrix = self._normalize(kernel_matrix, selfk1, selfk2)
if self.exp_kernel:
kernel_matrix = self._exponentiate(kernel_matrix, selfk1, selfk2)
return kernel_matrix
@staticmethod
def _normalize(kernel_matrix, selfk1, selfk2):
normalize = torch.sqrt(torch.matmul(selfk1, torch.transpose(selfk2, 0, 1)))
kernel_matrix = kernel_matrix / normalize
return kernel_matrix
@staticmethod
def _normalize(kernel_matrix, selfk1, selfk2):
normalize = torch.sqrt(torch.matmul(selfk1, torch.transpose(selfk2, 0, 1)))
kernel_matrix = kernel_matrix / normalize
return kernel_matrix
def _exponentiate(self, kernel_matrix, selfk1, selfk2, sigma2=None):
if sigma2 is None:
sigma2 = self.sigma2
if self.normalize:
# selfk is (1.0^2 + 1.0^2)
selfk = 2.0
else:
k1 = selfk1.size()[0]
k2 = selfk2.size()[0]
selfk = (selfk1 * selfk1).repeat(1, k2) + torch.transpose(
selfk2 * selfk2, 0, 1
).repeat(k1, 1)
return torch.exp(-(selfk - 2 * kernel_matrix) / (2 * sigma2))
@staticmethod
def _compute_trajectory_length_normalizer(len1, len2):
k1 = len1.size()[0]
k2 = len2.size()[0]
y1 = len1.reshape(-1, 1)
y1 = y1.repeat(1, k2)
y2 = len2.repeat(k1, 1)
return 1.0 / torch.min(y1, y2)
class GramMatrix:
    """Gram (kernel) matrix over a bag of STL formulae.

    Either computes the matrix for a given list of formulae, or — when
    ``sample=True`` — grows the bag one formula at a time via rejection
    sampling, discarding candidates that are too kernel-similar to the
    formulae already kept.  Robustness vectors may be cached
    (``store_robustness``) to speed up later kernel evaluations against
    this bag.
    """
    def __init__(self, kernel, formulae, store_robustness=True, sample=False, sampler=None, bag_size=None):
        # kernel: an StlKernel instance used for all evaluations
        self.kernel = kernel
        self.formulae_list = formulae
        # if kernel is computed from robustness at time zero only,
        # we store the robustness for each formula and each sample
        # to speed up computation later
        self.store_robustness = store_robustness
        # bag_size, when given, overrides the length of the formulae list
        self.dim = len(self.formulae_list) if not bag_size else int(bag_size)
        self.sample = sample  # whether to generate formulae in a controlled manner
        if self.sample:
            # similarity threshold for rejecting near-duplicates
            # (stricter under boolean semantics)
            self.t = 0.99 if self.kernel.boolean else 0.85
            self.sampler = sampler  # stl formulae generator
        self._compute_gram_matrix()

    def _compute_gram_matrix(self):
        """Fill ``self.gram`` (and, optionally, the cached robustness data)."""
        if self.sample:
            gram = torch.zeros(self.dim, self.dim)
            rhos = torch.zeros((self.dim, self.kernel.samples), device=self.kernel.traj_measure.device) if \
                not self.kernel.integrate_time else torch.zeros((self.dim, self.kernel.samples, self.kernel.points),
                                                                device=self.kernel.traj_measure.device)
            # NOTE(review): when integrate_time is False, compute_bag returns
            # len0=None, so the `lengths[0] = ...` assignment below would fail;
            # sampling mode presumably expects integrate_time=True — confirm.
            lengths = torch.zeros(self.dim) if self.kernel.integrate_time else np.zeros(self.dim)
            kernels = torch.zeros((self.dim, 1), device=self.kernel.traj_measure.device)
            # seed the bag with one random formula
            phis = [self.sampler.sample(nvars=self.kernel.varn)]
            gram[0, :1], rhos[0], kernels[0, :], lengths[0] = self.kernel.compute_bag(phis, return_robustness=True)
            while len(phis) < self.dim:
                i = len(phis)
                phi = self.sampler.sample(nvars=self.kernel.varn)
                # kernel row of the candidate against the formulae kept so far
                gram[i, :i], rhos[i], kernels[i, :], lengths[i] = self.kernel.compute_one_from_robustness(
                    phi, rhos[:i, :], kernels[:i, :], lengths[:i], return_robustness=True)
                # keep the candidate only if it is highly similar to fewer
                # than 3 of the formulae already in the bag
                if torch.sum(gram[i, :i + 1] >= self.t) < 3:
                    phis.append(phi)
                    gram[:i, i] = gram[i, :i]  # mirror row to keep symmetry
                    gram[i, i] = kernels[i, :]
            self.formulae_list = phis
            self.gram = gram.cpu()
            self.robustness = rhos if self.store_robustness else None
            self.self_kernels = kernels if self.store_robustness else None
            self.robustness_lengths = lengths if self.store_robustness else None
        else:
            if self.store_robustness:
                k_matrix, rhos, selfk, len0 = self.kernel.compute_bag(
                    self.formulae_list, return_robustness=True
                )
                self.gram = k_matrix
                self.robustness = rhos
                self.self_kernels = selfk
                self.robustness_lengths = len0
            else:
                self.gram = self.kernel.compute_bag(
                    self.formulae_list, return_robustness=False
                )
                self.robustness = None
                self.self_kernels = None
                self.robustness_lengths = None

    def compute_kernel_vector(self, phi):
        """Kernel vector of a single formula against this bag, reusing the
        cached robustness data when available."""
        if self.store_robustness:
            return self.kernel.compute_one_from_robustness(
                phi, self.robustness, self.self_kernels, self.robustness_lengths
            )
        else:
            return self.kernel.compute_one_bag(phi, self.formulae_list)

    def compute_bag_kernel_vector(self, phis, generate_phis=False, bag_size=None):
        """Kernel matrix of a test bag against this bag.

        With ``generate_phis=True``, ignores `phis` and instead samples
        `bag_size` fresh formulae, rejecting trivially satisfied/violated
        ones (robustness of constant sign on all samples).
        """
        if generate_phis:
            gram_test = torch.zeros(bag_size, self.dim)  # self.dim, bag_size
            rhos_test = torch.zeros((bag_size, self.kernel.samples), device=self.kernel.traj_measure.device) if \
                not self.kernel.integrate_time else torch.zeros((bag_size, self.kernel.samples, self.kernel.points),
                                                                device=self.kernel.traj_measure.device)
            lengths_test = torch.zeros(bag_size) if self.kernel.integrate_time else np.zeros(bag_size)
            kernels_test = torch.zeros((bag_size, 1), device=self.kernel.traj_measure.device)
            phi_test = []
            while len(phi_test) < bag_size:
                i = len(phi_test)
                phi = self.sampler.sample(nvars=self.kernel.varn)
                if self.store_robustness:
                    gram_test[i, :], rhos_test[i], kernels_test[i, :], lengths_test[i] = \
                        self.kernel.compute_one_from_robustness(phi, self.robustness, self.self_kernels,
                                                                self.robustness_lengths, return_robustness=True)
                else:
                    gram_test[i, :], rhos_test[i], _, kernels_test[i, :], _, lengths_test[i], _ = \
                        self.kernel.compute_one_bag(phi, self.formulae_list, return_robustness=True)
                # keep only formulae that are neither always satisfied nor always violated
                if not ((rhos_test[i] > 0).all() or (rhos_test[i] < 0).all()):
                    phi_test.append(phi)
            return phi_test, gram_test.cpu()
        else:
            if self.store_robustness:
                return self.kernel.compute_bag_from_robustness(
                    phis, self.robustness, self.self_kernels, self.robustness_lengths
                )
            else:
                return self.kernel.compute_bag_bag(phis, self.formulae_list)

    def invert_regularized(self, alpha):
        """Tikhonov-regularized inverse: (G + 10^alpha * I)^-1."""
        regularizer = abs(pow(10, alpha)) * torch.eye(self.dim)
        return torch.inverse(self.gram + regularizer)
#### anchor_generation ####
def anchorGeneration(diff_init = False,  # to control whether we want formulae to be semantically different by construction
                     embed_dim: int = 30,  # embedding dimension, aka number of generated formulae in the anchor set
                     n_vars: int = 3,  # dimension of the input signal (3D in this case)
                     leaf_prob: float = 0.4,  # complexity of the generated formula
                     cosine_similarity_threshold: float = 0.8  # if two formulae cosine similarity exceeds this, discard one of the two
                     ) -> str:
    """Generate an anchor set of `embed_dim` STL formulae, pickle it, and
    return the pickle filename (without extension; `dump_pickle` presumably
    adds '.pickle' — callers such as STLEncoder append it when loading).

    BUGFIX: originally the `diff_init=True` branch never dumped the anchor
    set nor returned a filename (the dump/return lived only in the `else`
    branch), so callers relying on the returned filename crashed.  Both
    branches now persist their anchor set and return its filename.
    """
    # initialize STL formula generator
    sampler = StlGenerator(leaf_prob)
    if diff_init:
        # initialize the anchor set with a randomly sampled formula
        diff_anchor_set = [sampler.sample(nvars=n_vars)]
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        mu = BaseMeasure(device=device)
        # random signals used as a tester for semantic (dis)similarity
        signals = mu.sample(samples=10000, varn=n_vars)
        # robustness of the initial anchor formulae on the test signals
        anchor_rob_vectors = torch.cat([phi.quantitative(signals, normalize=True).unsqueeze(0) for phi in diff_anchor_set], 0)
        while len(diff_anchor_set) < embed_dim:
            # sample the 'remaining' formulae to reach `embed_dim` in total
            candidate_anchors = sampler.bag_sample(embed_dim - len(diff_anchor_set), nvars=n_vars)
            # robustness of the candidates on the same signals as the current anchors
            candidate_robs = torch.cat([phi.quantitative(signals, normalize=True).unsqueeze(0) for phi in candidate_anchors], 0)
            # cosine similarity between current anchors and candidates
            # NOTE(review): tril on a rectangular candidate-by-anchor matrix also
            # zeroes some genuine comparisons — verify this is intended.
            cos_simil = torch.tril(normalize(candidate_robs) @ normalize(anchor_rob_vectors).t(), diagonal=-1)
            # TODO(review): decide whether negative similarities should be
            # compared in absolute value.
            similar_idx = [torch.where(cos_simil[r, :] > cosine_similarity_threshold)[0].tolist() for r in range(cos_simil.shape[0])]
            # keep only candidates that are semantically distant from all anchors
            discard = set(i for sublist in similar_idx for i in sublist)
            keep_idx = list(set(range(len(candidate_anchors))).difference(discard))
            diff_anchor_set += [copy.deepcopy(candidate_anchors[i]) for i in keep_idx]
            # select the corresponding robustness rows on the right device
            keep_idx_tensor = torch.tensor(keep_idx, device=candidate_robs.device)
            selected_robs = torch.index_select(candidate_robs, 0, keep_idx_tensor)
            anchor_rob_vectors = torch.cat([anchor_rob_vectors, copy.deepcopy(selected_robs)], dim=0)
        anchor_set = diff_anchor_set[:embed_dim]
        filename = f'anchor_set_diff_{embed_dim}_dim'
    else:
        anchor_set = sampler.bag_sample(bag_size=embed_dim, nvars=n_vars)
        filename = f'anchor_set_no_diff_{embed_dim}_dim'
    dump_pickle(filename, anchor_set)
    return filename
####
"""
A custom tokenizer class that extends `PreTrainedTokenizer` to handle a specific vocabulary and tokenization process.
This tokenizer can load a vocabulary from a JSON file, tokenize text, convert tokens to IDs,
and handle padding and special tokens.
"""
def __init__(self, vocab_path: str, unk_token: str = "unk", pad_token: str = "pad",
             bos_token: str = "/s", eos_token: str = "s", model_max_length = 512, *args, **kwargs):
    """
    Build the tokenizer from a JSON vocabulary file.

    Args:
        vocab_path (str): Path to the JSON file holding the token-to-id vocabulary.
        unk_token (str, optional): Token for unknown symbols. Defaults to "unk".
        pad_token (str, optional): Token used for padding. Defaults to "pad".
        bos_token (str, optional): Beginning-of-sequence token. Defaults to "/s".
        eos_token (str, optional): End-of-sequence token. Defaults to "s".
        model_max_length (int, optional): Maximum sequence length. Defaults to 512.
    """
    # vocabulary and reverse mapping must exist before the parent
    # constructor runs (it may query vocab_size / special tokens)
    self.vocab = load_json(vocab_path)
    self.id_to_token = {idx: tok for tok, idx in self.vocab.items()}
    self.unk_token = unk_token
    self.pad_token = pad_token
    self.bos_token = bos_token
    self.eos_token = eos_token
    self.model_max_length = model_max_length
    super().__init__(unk_token=unk_token, pad_token=pad_token, bos_token=bos_token, eos_token=eos_token,
                     model_max_length=model_max_length, *args, **kwargs)
@property
def vocab_size(self) -> int:
    """Number of entries in the tokenizer vocabulary."""
    return len(self.vocab)
def prepad_sequence(self, sequence, space_token = ' ', new_space_token = '@', undo = False):
    """
    Swap spaces for a placeholder token, or restore them.

    Args:
        sequence (str): The input sequence.
        space_token (str): Character treated as a space. Defaults to ' '.
        new_space_token (str): Placeholder replacing spaces. Defaults to '@'.
        undo (bool): If True, map placeholders back to spaces instead.

    Returns:
        str: The sequence with the substitution applied.
    """
    src, dst = (new_space_token, space_token) if undo else (space_token, new_space_token)
    return sequence.replace(src, dst)
def add_bos_eos(self, sequence: str) -> str:
    """
    Wrap the sequence with BOS at the start and EOS at the end.

    Args:
        sequence (str): The input sequence.

    Returns:
        str: The sequence with BOS and EOS tokens, space-separated.
    """
    return ' '.join((self.bos_token, sequence, self.eos_token))
def tokenize(self, text: str) -> List[str]:
    """
    Tokenize the input text via greedy longest-match against the vocabulary.

    The text is first wrapped with BOS/EOS and its spaces replaced by the
    placeholder token; then, at each position, the longest substring present
    in the vocabulary is emitted (or the unk token if none matches).

    Args:
        text (str): The input text to be tokenized.

    Returns:
        List[str]: The list of tokens.
    """
    text = self.prepad_sequence(self.add_bos_eos(text))
    tokens: List[str] = []
    pos, length = 0, len(text)
    while pos < length:
        match = None
        # scan candidate substrings from longest to shortest
        for end in range(length, pos, -1):
            candidate = text[pos:end]
            if candidate in self.vocab:
                match = candidate
                break
        if match is None:
            # no vocabulary entry starts here: emit unk, advance one char
            tokens.append(self.unk_token)
            pos += 1
        else:
            tokens.append(match)
            pos += len(match)
    return tokens
def convert_tokens_to_ids(self, tokens: List[str]) -> List[int]:
    """
    Map each token to its id, falling back to the unk id for unknown tokens.

    Args:
        tokens (List[str]): Tokens to convert.

    Returns:
        List[int]: The corresponding token ids.
    """
    ids = []
    for tok in tokens:
        # unk id is looked up lazily, matching the original behavior
        ids.append(self.vocab[tok] if tok in self.vocab else self.vocab[self.unk_token])
    return ids
def convert_ids_to_tokens(self, ids: List[int]) -> List[str]:
    """
    Map each token id back to its token string.

    Args:
        ids (List[int]): Token ids to convert.

    Returns:
        List[str]: The corresponding tokens (unk token for unknown ids).
    """
    out = []
    for idx in ids:
        out.append(self.id_to_token.get(idx, self.unk_token))
    return out
def encode(self, sequence: str) -> List[int]:
    """
    Encode a string into token ids: tokenize, then map tokens to ids.

    Args:
        sequence (str): The input text to encode.

    Returns:
        List[int]: Token ids for the input sequence.
    """
    return self.convert_tokens_to_ids(self.tokenize(sequence))
def postpad_sequence(self, sequence, pad_token_id):
    """
    Extend `sequence` in place with pad ids up to model_max_length - 1 slots.

    NOTE(review): the -1 presumably reserves one position (e.g. for EOS) —
    confirm against the training pipeline.
    """
    shortfall = self.model_max_length - len(sequence) - 1
    sequence.extend([pad_token_id] * max(0, shortfall))
    return sequence
def decode(self, token_ids: List[int]) -> str:
    """
    Decode a list of token ids back into text.

    Ids are mapped to tokens, concatenated, and the space placeholder is
    restored to real spaces.

    Args:
        token_ids (List[int]): Token ids to decode.

    Returns:
        str: The decoded string.
    """
    joined = "".join(self.convert_ids_to_tokens(token_ids))
    return self.prepad_sequence(joined, undo=True)
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
    """
    Save the tokenizer's vocabulary to a JSON file.

    Useful only when the vocabulary has to be retrieved and is not given
    (kept for future improvements, e.g. with sentencepiece).

    Args:
        save_directory (str): Directory where the vocabulary file is written.
        filename_prefix (Optional[str]): Optional prefix for the filename.

    Returns:
        Tuple[str]: A one-tuple with the path of the saved vocabulary file.
    """
    # Local imports: `json` is not visibly imported at this module's top
    # level, so importing here guarantees availability (no-op otherwise).
    import json
    import os
    prefix = f"{filename_prefix}-" if filename_prefix else ""
    # os.path.join is platform-safe, unlike manual '/' concatenation
    vocab_file = os.path.join(save_directory, f"{prefix}vocab.json")
    with open(vocab_file, "w", encoding="utf-8") as f:
        json.dump(self.vocab, f, indent=2, ensure_ascii=False)
    return (vocab_file,)
def get_vocab(self) -> dict:
    """
    Return the token-to-id vocabulary mapping used by this tokenizer.

    Returns:
        dict: The vocabulary (the live dict, not a copy).
    """
    return self.vocab
class STLSinusoidalPositionalEmbedding(nn.Embedding):
    """Fixed (non-learned) sinusoidal positional embeddings of any length."""

    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
        super().__init__(num_positions, embedding_dim)
        self.weight = self._init_weight(self.weight)

    @staticmethod
    def _init_weight(out: nn.Parameter) -> nn.Parameter:
        """
        Fill `out` with sinusoids. As in XLM's create_sinusoidal_embeddings,
        features are NOT interleaved: sin occupies the first half of the
        vector, cos the second half [dim // 2:].
        """
        n_pos, dim = out.shape
        angles = np.array(
            [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
        )
        # disable grad before the in-place writes (avoids an error on pytorch>=1.8)
        out.requires_grad = False
        half = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
        out[:, :half] = torch.FloatTensor(np.sin(angles[:, 0::2]))
        out[:, half:] = torch.FloatTensor(np.cos(angles[:, 1::2]))
        out.detach_()
        return out

    @torch.no_grad()
    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor:
        """`input_ids_shape` is expected to be [bsz x seqlen]; returns the
        embeddings of the next `seqlen` positions after the cached prefix."""
        _, seq_len = input_ids_shape[:2]
        start = past_key_values_length
        positions = torch.arange(start, start + seq_len, dtype=torch.long, device=self.weight.device)
        return super().forward(positions)
class STLAttention(nn.Module):
    """ Multi-Head Attention as depicted from 'Attention is all you need'

    Supports self-attention, cross-attention (via `key_value_states`), and
    incremental decoding with cached key/value tensors (`past_key_value`).
    """
    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0,
                 is_decoder: bool = False, bias: bool = False, is_causal: bool = False):
        """
        Args:
            embed_dim: total embedding dimension, split evenly across heads.
            num_heads: number of attention heads (must divide embed_dim).
            dropout: dropout probability applied to attention probabilities.
            is_decoder: when True, (key, value) are cached and returned.
            bias: whether the four projection layers use a bias term.
            is_causal: stored but unused here — causality comes from `attention_mask`.
        """
        super().__init__()
        self.embed_dim = embed_dim  # overall embedding dimension -> to be divided between multiple heads
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        # embed_dim must split evenly across heads
        assert (self.head_dim * num_heads) == self.embed_dim
        self.scaling = self.head_dim ** -0.5  # used to normalize values when projected using `W_` matrices
        self.is_decoder = is_decoder
        self.is_causal = is_causal
        # 'roleplaying' matrices
        self.W_k = nn.Linear(embed_dim, embed_dim, bias = bias)
        self.W_q = nn.Linear(embed_dim, embed_dim, bias = bias)
        self.W_v = nn.Linear(embed_dim, embed_dim, bias = bias)
        # to project the heads' outputs into a single vector
        self.W_o = nn.Linear(embed_dim, embed_dim, bias = bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
        # (bsz, seq, embed) -> (bsz, heads, seq, head_dim)
        return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(self,
                hidden_states: torch.Tensor,  # previous values, passed to the multi-head attn layer
                key_value_states: Optional[torch.Tensor] = None,  # different key, value items (used in cross-attn)
                past_key_value: Optional[Tuple[torch.Tensor]] = None,  # stores the key and values of previous steps
                attention_mask: Optional[torch.Tensor] = None,  # masks non-allowed items (padded or future ones)
                layer_head_mask: Optional[torch.Tensor] = None,  # used to de-activate specific attn heads
                output_attentions: bool = False  # flag to control the output of the attn values,
                ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Run one attention pass.

        Returns (attn_output, attn_weights, past_key_value).
        NOTE(review): attn_weights is always None here, even when
        `output_attentions=True`, and `layer_head_mask` is ignored (the
        masking code below is commented out) — confirm this is intended.
        """
        is_cross_attention = key_value_states is not None  # cross-attn if key_value_states is not None
        batch_size, tgt_len, embed_dim = hidden_states.size()
        # Project the current input in the `query` role:
        query = self.W_q(hidden_states) * self.scaling
        # reuse the cached cross-attn K/V only if they match the current encoder sequence length
        if (is_cross_attention and past_key_value is not None and past_key_value[0].shape[2] == key_value_states.shape[1]):
            key = past_key_value[0]
            value = past_key_value[1]
        elif is_cross_attention:
            key = self._shape(self.W_k(key_value_states), -1, batch_size)
            value = self._shape(self.W_v(key_value_states), -1, batch_size)
        elif past_key_value is not None:
            # incremental self-attn: append the new step's K/V to the cache
            key = self._shape(self.W_k(hidden_states), -1, batch_size)
            value = self._shape(self.W_v(hidden_states), -1, batch_size)
            key = torch.cat([past_key_value[0], key], dim=2)
            value = torch.cat([past_key_value[1], value], dim=2)
        else:
            key = self._shape(self.W_k(hidden_states), -1, batch_size)
            value = self._shape(self.W_v(hidden_states), -1, batch_size)
        if self.is_decoder:
            past_key_value = (key, value)
        # flatten heads into the batch dimension for bmm
        proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
        query = self._shape(query, tgt_len, batch_size).view(*proj_shape)
        key = key.reshape(*proj_shape)
        value = value.reshape(*proj_shape)
        src_len = key.size(1)
        ######################################################################################################
        # 'traditional' attention computation
        # i.e. softmax(Q*K^T / sqrt(d_model) + self_attn_mask) * V
        # Batch-wise matrix multiplication between `query` and (TRANSPOSED) `key`
        attn_weights = torch.bmm(query, key.transpose(1, 2))
        if attention_mask is not None:
            # mask is additive (very large negative values on disallowed positions)
            attn_weights = attn_weights.view(batch_size, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(batch_size * self.num_heads, tgt_len, src_len)
        # Normalize values on the `key` axis (dim=-1)
        attn_weights = F.softmax(attn_weights, dim=-1)
        # if layer_head_mask is not None:
        #     attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(batch_size, self.num_heads, tgt_len, src_len)
        #     attn_weights = attn_weights.view(batch_size * self.num_heads, tgt_len, src_len)
        attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
        # Batch-wise matrix multiplication between the resulting probs and the value
        attn_output = torch.bmm(attn_probs, value)
        ######################################################################################################
        # merge heads back: (bsz*heads, tgt, head_dim) -> (bsz, tgt, embed)
        attn_output = attn_output.view(batch_size, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(batch_size, tgt_len, self.embed_dim)
        attn_output = self.W_o(attn_output)
        return attn_output, None, past_key_value
####
class STLEncoder():
    """Embeds STL formulae as vectors of kernel evaluations against a fixed
    set of `embed_dim` anchor formulae (loaded from, or generated to, a
    pickle file)."""
    def __init__(self,
                 embed_dim: int,
                 anchor_filename: Optional[str] = None,
                 n_vars: int = 3):
        self.n_vars = n_vars  # passed in as input
        self.embed_dim = embed_dim
        self.anchorset_filename = anchor_filename
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.mu = BaseMeasure(device=self.device)
        self.kernel = StlKernel(self.mu, varn=self.n_vars)
        # no precomputed anchor set: generate one and load it back from disk
        # (anchorGeneration returns the filename without extension)
        if anchor_filename is None:
            anchor_filename = anchorGeneration(diff_init = True, embed_dim = self.embed_dim, n_vars = self.n_vars)
            anchor_filename+='.pickle'
        # TO DO: check on the dimensions of the anchor set and the `embed_dim` and `n_vars` values
        anchor_set = load_pickle(anchor_filename)
        if len(anchor_set) != self.embed_dim:
            raise ValueError("The anchor set and the embedding dimension do not match!")
        self.anchor_set = anchor_set

    def compute_embeddings(self, formula: List[str]):
        """Kernel vector(s) of `formula` against every anchor formula."""
        return self.kernel.compute_bag_bag(formula, self.anchor_set)
class STLModel(PreTrainedModel):
    """Base class wiring STLConfig into HF's PreTrainedModel; provides the
    weight initialization shared by all STL submodules."""
    config_class = STLConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True

    # initializes the weights of `nn.Linear`, `nn.Embedding` and `STLSinusoidalPositionalEmbedding`
    def _init_weights(self, module: Union[nn.Linear, nn.Embedding, STLSinusoidalPositionalEmbedding]):
        """Initialize weights with N(0, config.init_std); zero biases and the
        padding embedding row."""
        std = self.config.init_std
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, STLSinusoidalPositionalEmbedding):
            # sinusoidal positions are fixed and must NOT be re-initialized;
            # this branch must precede the nn.Embedding one because
            # STLSinusoidalPositionalEmbedding subclasses nn.Embedding
            pass
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @property
    def dummy_inputs(self):
        """Small fixed batch used by HF utilities for tracing/smoke tests."""
        pad_token = self.config.pad_token_id
        input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
        dummy_inputs = {
            "attention_mask": input_ids.ne(pad_token),
            "input_ids": input_ids,
            "decoder_input_ids": input_ids,
        }
        return dummy_inputs
class STLDecoderBlock(nn.Module):
    """One transformer decoder layer: masked self-attention, cross-attention
    over the encoder output, and a feed-forward network — each with dropout,
    residual connection and (post-)LayerNorm."""
    def __init__(self, embed_dim: int,
                 num_decoder_attention_heads: int,
                 num_decoder_ffn_dim: int,
                 dropout: float = 0.0,
                 attention_dropout: float = 0.0,
                 activation_dropout: float = 0.0,
                 ):
        """
        Args:
            embed_dim: model/hidden dimension of the layer.
            num_decoder_attention_heads: heads for both attention blocks.
            num_decoder_ffn_dim: inner dimension of the feed-forward block.
            dropout: dropout after each sub-block.
            attention_dropout: dropout inside the cross-attention weights.
            activation_dropout: dropout after the FFN activation.
        """
        super().__init__()
        self.embed_dim = embed_dim
        # first block
        self.self_attn = STLAttention(
            embed_dim=self.embed_dim,
            num_heads=num_decoder_attention_heads,
            dropout=dropout,
            is_decoder=True,  # not used, debugging purposes
            is_causal=True,  # not used, debugging purposes
        )
        self.dropout = dropout
        self.activation_fn = nn.functional.gelu
        self.activation_dropout = activation_dropout
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        # second block
        self.encoder_attn = STLAttention(
            self.embed_dim,
            num_decoder_attention_heads,
            dropout=attention_dropout,
            is_decoder=True,  # not used, debugging purposes
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        # third block
        self.fc1 = nn.Linear(self.embed_dim, num_decoder_ffn_dim)
        self.fc2 = nn.Linear(num_decoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = True,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            encoder_hidden_states (`torch.FloatTensor`):
                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
                size `(decoder_attention_heads,)`.
            past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        ###################################################################
        # BLOCK 1: processing what has been previously generated
        # previous state is stored into an auxiliary variable `residual`
        residual = hidden_states
        # tries to exploit previous K, V values if there are any
        # (practically picks up to the first 2 values stored in `past_key_value` vector)
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        # masked MHSA on the already generated sequence
        # invokes `forward` method to transform the original vector accordingly
        hidden_states, self_attn_weights, present_key_value = self.self_attn.forward(
            hidden_states=hidden_states,  # Q
            past_key_value=self_attn_past_key_value,  # K, V
            attention_mask=attention_mask,  # passed as input of the decoder layer
            layer_head_mask=layer_head_mask,  # to deactivate certain attn layers
            output_attentions=output_attentions,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        # residual connection
        hidden_states = residual + hidden_states
        # normalization
        hidden_states = self.self_attn_layer_norm(hidden_states)
        ###################################################################
        # BLOCK 2: cross-attn between already generated input and previous information (from the encoder)
        # initialize K, Q, attn_weights for this new attn operation
        cross_attn_present_key_value = None
        cross_attn_weights = None
        # the important condition is that the encoder carries some information
        if encoder_hidden_states is not None:
            # previous state is stored into an auxiliary variable `residual`
            residual = hidden_states
            # cross_attn cached key/values tuple is at positions 3, 4 of PAST_key_value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            # MHSA in cross-attn
            hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn.forward(
                hidden_states=hidden_states,  # Q = generated output
                key_value_states=encoder_hidden_states,  # K, V = encoder memory (used only in the 1st step when `use_cache = True`)
                attention_mask=encoder_attention_mask,  # just pads some elements (not causal this time!)
                layer_head_mask=cross_attn_layer_head_mask,  # again to mask certain heads
                past_key_value=cross_attn_past_key_value,  # K, V = encoder CACHED memory (used from the 2nd step on when `use_cache = True`)
                output_attentions=output_attentions,
            )
            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
            # residual connection
            hidden_states = residual + hidden_states
            # normalization
            hidden_states = self.encoder_attn_layer_norm(hidden_states)
            # add cross-attn to positions 3, 4 of PRESENT_key_value tuple
            present_key_value = present_key_value + cross_attn_present_key_value
        ###################################################################
        # BLOCK 3: FFNN (transforming some merged generated output - encoder information)
        # previous state is stored into an auxiliary variable `residual`
        residual = hidden_states
        # FFNN - core
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        # residual connection
        hidden_states = residual + hidden_states
        # normalization
        hidden_states = self.final_layer_norm(hidden_states)
        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)
        if use_cache:  # otherwise it picks K and V each time
            outputs += (present_key_value,)
        return outputs
class STLDecoder(STLModel):
    """
    Transformer decoder: a stack of ``config.decoder_layers`` ``STLDecoderBlock``
    layers on top of token embeddings plus fixed sinusoidal position embeddings.

    Supports:
      * optional cross-attention over an encoder memory (``encoder_hidden_states``),
      * incremental decoding via ``past_key_values`` / ``use_cache``,
      * LayerDrop (https://arxiv.org/abs/1909.11556),
      * gradient checkpointing during training.
    """

    def __init__(self, config):
        super().__init__(config)
        # Pull hyper-parameters out of the config object.
        embed_dim = config.d_model
        num_decoder_attention_heads = config.decoder_attention_heads
        num_decoder_ffn_dim = config.decoder_ffn_dim
        max_position_embeddings = config.max_position_embeddings
        decoder_vocab_size = config.vocab_size
        pad_token_id = config.pad_token_id
        num_decoder_layers = config.decoder_layers
        scale_embedding = config.scale_embedding
        dropout = config.dropout
        attention_dropout = config.attention_dropout
        activation_dropout = config.activation_dropout
        decoder_layerdrop = config.decoder_layerdrop
        self.dropout = dropout
        self.layerdrop = decoder_layerdrop
        self.padding_idx = pad_token_id
        self.max_target_positions = max_position_embeddings
        # Scale token embeddings by sqrt(d_model) when requested (as in "Attention Is All You Need").
        self.embed_scale = math.sqrt(embed_dim) if scale_embedding else 1.0
        # Input token embedding (the padding_idx row is kept at zero by nn.Embedding).
        self.embed_tokens = nn.Embedding(decoder_vocab_size, embed_dim, self.padding_idx)
        # Fixed (non-learned) sinusoidal positional embedding.
        self.embed_positions = STLSinusoidalPositionalEmbedding(
            max_position_embeddings, embed_dim, self.padding_idx
        )
        # Decoder layer stack.
        self.layers = nn.ModuleList([STLDecoderBlock(embed_dim, num_decoder_attention_heads,
                                                     num_decoder_ffn_dim, dropout,
                                                     attention_dropout, activation_dropout)
                                     for _ in range(num_decoder_layers)])
        self.gradient_checkpointing = False
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        """
        Run the decoder stack.

        Parameters
        ----------
        input_ids : torch.LongTensor of shape (batch, tgt_len), optional
            Target token ids. Mutually exclusive with `inputs_embeds`.
        attention_mask : torch.Tensor, optional
            Padding mask over the target sequence; combined with the causal mask.
        encoder_hidden_states : torch.FloatTensor, optional
            Encoder memory used as K/V in cross-attention. If None, layers run
            self-attention only.
        encoder_attention_mask : torch.Tensor, optional
            Padding mask over the encoder memory.
        head_mask / cross_attn_head_mask : torch.Tensor, optional
            Per-layer head masks for self- and cross-attention.
        past_key_values : tuple, optional
            Cached per-layer (self K, self V, cross K, cross V) tensors for
            incremental decoding.
        inputs_embeds : torch.FloatTensor, optional
            Pre-computed input embeddings, bypassing `self.embed_tokens`.
        use_cache / output_attentions / output_hidden_states / return_dict : bool, optional
            Fall back to the corresponding `self.config` values when None.

        Returns
        -------
        `BaseModelOutputWithPastAndCrossAttentions` or tuple
            Last hidden state plus the optional cache / hidden states / attentions.

        Raises
        ------
        ValueError
            If both or neither of `input_ids` and `inputs_embeds` are given.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
        # Length of the already-cached prefix (0 on the first step).
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
        # Build the combined causal + padding mask, expanded to 4D.
        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask, input_shape, inputs_embeds, past_key_values_length
        )
        # expand encoder attention mask: [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            encoder_attention_mask = _prepare_4d_attention_mask(
                encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            )
        # embed positions (offset by the cached prefix length)
        positions = self.embed_positions(input_shape, past_key_values_length)
        hidden_states = inputs_embeds + positions
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False
        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
        next_decoder_cache = () if use_cache else None
        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            # LayerDrop: randomly skip whole layers during training.
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:
                    continue
            past_key_value = past_key_values[idx] if past_key_values is not None else None
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    head_mask[idx] if head_mask is not None else None,
                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                    None,
                    output_attentions,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    # BUGFIX: the encoder padding mask was prepared above but never
                    # forwarded in this branch, so cross-attention attended to
                    # encoder PAD positions. Pass it, as the checkpointing branch does.
                    encoder_attention_mask=encoder_attention_mask,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    cross_attn_layer_head_mask=(
                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
                    ),
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )
            hidden_states = layer_outputs[0]
            if use_cache:
                # Present K/V tuple sits after the attention tensors when those are returned.
                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
            if output_attentions:
                all_self_attns += (layer_outputs[1],)
                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)
        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )
####
class STLForCausalLM(STLModel, GenerationMixin):
    """
    Decoder-only causal language model: an `STLDecoder` backbone topped with a
    bias-free linear LM head projecting hidden states to vocabulary logits.
    """

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        # Work on a private copy so the caller's config is left untouched,
        # and force decoder-only mode.
        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        super().__init__(decoder_config)
        self.model = STLDecoder(decoder_config)
        self.lm_head = nn.Linear(decoder_config.hidden_size, decoder_config.vocab_size, bias=False)
        # Initialize weights and apply final processing.
        self.post_init()

    def get_input_embeddings(self):
        """Return the decoder's token embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the decoder's token embedding module."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the LM head (output projection)."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head (output projection)."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Swap in a different decoder backbone."""
        self.model = decoder

    def get_decoder(self):
        """Return the decoder backbone."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,  # input sequence
        attention_mask: Optional[torch.Tensor] = None,  # masked MHSA + padding
        encoder_hidden_states: Optional[torch.FloatTensor] = None,  # embedding
        encoder_attention_mask: Optional[torch.FloatTensor] = None,  # MHSA + padding
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,  # output sequence
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        """
        Run the decoder, project to vocabulary logits, and optionally compute a
        token-level cross-entropy loss against `labels` (no shifting is applied
        here; callers are expected to align `labels` themselves).
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict
        # Backbone pass: yields (dec_features, layer_state, dec_hidden, dec_attn).
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )
        logits = self.lm_head(outputs[0])
        loss = None
        if labels is not None:
            flat_targets = labels.to(logits.device).view(-1)
            loss = nn.CrossEntropyLoss()(logits.view(-1, self.config.vocab_size), flat_targets)
        if not return_dict:
            tail = (logits,) + outputs[1:]
            return tail if loss is None else (loss,) + tail
        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder each layer's cached K/V tensors along the batch dim for beam search."""
        return tuple(
            tuple(state.index_select(0, beam_idx.to(state.device)) for state in layer_past)
            for layer_past in past_key_values
        )