import random

import torch
import torch.nn as nn

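# BraLM: an edge-parameterized language model. Each directed edge
# (source -> target) in the vocabulary graph owns its own weight matrix and
# bias; a hidden "energy" vector is propagated along a path edge by edge, and
# training pushes the gold edge's energy norm above sampled negative edges'.
# (High-level summary inferred from the code below.)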
class BraLM(nn.Module):
    def __init__(self, hidden_size, use_ds=False, zero_freq_edges=None, vocab=None):
        super().__init__()
        self.hidden_size = hidden_size
        self.activation = nn.GELU()
        # Learnable per-position mixing weights over cached step outputs
        # (softmax-normalized in forward/decode); 512 is the maximum length.
        self.positions = nn.Parameter(torch.ones(1, 512, 1))
        self.device = None

        # No tied weights; attribute kept for Hugging Face-style bookkeeping.
        self._tied_weights_keys = []

        self.use_ds = use_ds  # if True, run the edge matmul in half precision
        self.zero_freq_edges = zero_freq_edges
        self.vocab = vocab

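    # Allocate one (hidden x hidden) weight matrix and bias per vocabulary
    # edge. Edges listed in zero_freq_edges (never observed in the data)
    # all share parameter row 0 rather than receiving their own row.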
    def prepare_network(self, vocab):
        self.weight_indices = {}
        self.shared_param_idx = 0

        current_idx = 1
        for s_idx, s in enumerate(vocab.edge_dict):
            for t_idx, t in enumerate(vocab.edge_dict[s]):
                if self.zero_freq_edges is not None and t in self.zero_freq_edges[s]:
                    self.weight_indices[(s_idx, t_idx)] = self.shared_param_idx
                else:
                    self.weight_indices[(s_idx, t_idx)] = current_idx
                    current_idx += 1

        # uniform_ overwrites storage in place, so allocate with torch.empty.
        self.weights = nn.Parameter(torch.empty(current_idx, self.hidden_size, self.hidden_size).uniform_(-0.5, 0.5))
        self.biases = nn.Parameter(torch.empty(current_idx, 1, self.hidden_size).uniform_(-0.5, 0.5))

        self.node_bias = nn.Parameter(torch.empty(len(vocab.edge_dict), 1, self.hidden_size).uniform_(-0.5, 0.5))

    def to_device(self, device):
        # Move parameter storage in place, keeping the Parameter objects.
        self.weights.data = self.weights.data.to(device)
        self.biases.data = self.biases.data.to(device)
        self.node_bias.data = self.node_bias.data.to(device)
        self.positions.data = self.positions.data.to(device)
        self.device = device

    @staticmethod
    def _reshape12(x):
        # Merge the leading dimensions: (a, b, m, n) -> (a*b, m, n).
        return x.reshape(-1, x.size(-2), x.size(-1))

    def get_positional_encoding(self, seq_len, d_model):
        # Standard sinusoidal encoding: PE(pos, 2i) = sin(pos / 10000^(2i/d)).
        # Note the negative exponent; without it the frequencies explode.
        position = torch.arange(0, seq_len).reshape(-1, 1)
        div_term = 10000.0 ** (-torch.arange(0, d_model, 2) / d_model)
        position_encoding = torch.zeros(seq_len, d_model)
        position_encoding[:, 0::2] = torch.sin(position * div_term)
        position_encoding[:, 1::2] = torch.cos(position * div_term)
        return position_encoding.unsqueeze(0).to(self.device)

    def get_initial_tensor(self, batch_size, d, pe):
        # Start from a uniform energy vector plus the first source node's bias.
        energy_tensor = torch.ones(batch_size, 1, self.hidden_size) / self.hidden_size
        energy_tensor = energy_tensor.to(self.device)

        d = d.to(self.device)
        node_bias = self.node_bias[d[:, 0, 0]]  # source node of the gold first edge
        energy_tensor = self.activation(energy_tensor + node_bias + pe[:, 0])
        return energy_tensor

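    # Training forward pass. neighbor_ids: LongTensor of shape
    # (batch, seq_len, n_candidates, 2), where along the candidate axis
    # index 0 is the gold (source, target) edge and the rest are negatives.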
    def forward(self, neighbor_ids):
        batch_size = neighbor_ids.size(0)
        loss = 0.0
        criterion = nn.CrossEntropyLoss()
        pe = self.get_positional_encoding(512, self.hidden_size)

        for i in range(neighbor_ids.size(1)):
            d = neighbor_ids[:, i]  # (batch, n_candidates, 2) edges for step i

            if i == 0:
                energy_tensor = self.get_initial_tensor(batch_size, d, pe)
            else:
                # Mix all cached step outputs with learned position weights.
                energy_tensor = (energy_cache * self.positions[:, :i, :].softmax(1)).sum(1, keepdim=True)

            # Look up one parameter row per candidate edge; unseen edges fall
            # back to the shared zero-frequency row.
            src_idx = d[..., 0]
            tgt_idx = d[..., 1]
            param_indices = torch.tensor(
                [self.weight_indices.get((s.item(), t.item()), self.shared_param_idx)
                 for s, t in zip(src_idx.reshape(-1), tgt_idx.reshape(-1))],
                device=self.device).reshape(batch_size, -1)

            w = self.weights[param_indices]  # (batch, n_candidates, h, h)
            b = self.biases[param_indices]   # (batch, n_candidates, 1, h)

            # Broadcast the current energy over every candidate edge.
            expand_energy_tensor = self._reshape12(energy_tensor.unsqueeze(1).repeat(1, w.size(1), 1, 1))
            if self.use_ds:
                expand_energy_tensor = expand_energy_tensor.half()
            nxt_energy_tensor = self.activation(
                expand_energy_tensor.bmm(self._reshape12(w)) + self._reshape12(b) + pe[:, i + 1])
            output_tensor = nxt_energy_tensor.reshape(batch_size, -1, nxt_energy_tensor.size(-2), nxt_energy_tensor.size(-1))

            # Cache the gold edge's output (candidate 0) for later steps.
            if i == 0:
                energy_cache = output_tensor[:, 0]
            else:
                energy_cache = torch.cat([energy_cache, output_tensor[:, 0]], dim=1)

            # The gold edge (index 0) should have the largest energy norm.
            energy = output_tensor.norm(2, (-2, -1))
            label = torch.zeros(batch_size, dtype=torch.long, device=self.device)
            loss += criterion(energy, label)

        return loss / neighbor_ids.size(1)

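    # Generation: starting from a prefix of (source, target) edge pairs,
    # repeatedly score every outgoing edge of the current node by its energy
    # norm and take the argmax (or sample with do_sample=True).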
    def decode(self, start, vocab, max_new_tokens=16, do_sample=False, temperature=1):
        ret = []
        pe = self.get_positional_encoding(512, self.hidden_size)

        # Feed the prefix through the network edge by edge.
        for i, pair in enumerate(start):
            if i == 0:
                energy_tensor = self.get_initial_tensor(
                    batch_size=1, d=torch.tensor([[pair]], device=self.device), pe=pe).squeeze(0)
            else:
                energy_tensor = (energy_cache * self.positions[:, :i, :].softmax(1)).sum(1, keepdim=True).squeeze(0)

            param_idx = self.weight_indices.get((pair[0], pair[1]), self.shared_param_idx)
            w = self.weights[param_idx].to(self.device)
            b = self.biases[param_idx].to(self.device)

            # Position index i + 1 mirrors forward(), where the output of
            # step i receives pe[:, i + 1].
            energy_tensor = self.activation(energy_tensor.mm(w) + b + pe.squeeze(0)[i + 1])
            if i == 0:
                energy_cache = energy_tensor.unsqueeze(0)
            else:
                energy_cache = torch.cat([energy_cache, energy_tensor.unsqueeze(0)], dim=1)
            ret += [pair]

        x = pair[1]  # current node: target of the last prefix edge
        prev_i = len(start)

        for i in range(max_new_tokens):
            # Score every outgoing edge of the current node.
            candidates = vocab(vocab.get_neighbor_of_node(x, -1))

            param_indices = torch.tensor(
                [self.weight_indices.get((x, t[1]), self.shared_param_idx) for t in candidates],
                device=self.device)

            all_w = self.weights[param_indices].to(self.device)
            all_b = self.biases[param_indices].to(self.device)

            curr_i = prev_i + i
            energy_tensor = (energy_cache * self.positions[:, :curr_i, :].softmax(1)).sum(1, keepdim=True)
            expand_energy_tensor = self._reshape12(energy_tensor.unsqueeze(1).repeat(1, all_w.size(0), 1, 1))

            nxt_energy_tensor = self.activation(
                expand_energy_tensor.bmm(self._reshape12(all_w)) + self._reshape12(all_b) + pe[:, curr_i + 1].unsqueeze(0))
            output_tensor = nxt_energy_tensor.reshape(1, -1, nxt_energy_tensor.size(-2), nxt_energy_tensor.size(-1))

            energy = output_tensor.norm(2, (-2, -1)).squeeze()

            # Temperature must rescale the scores before the softmax; dividing
            # the probabilities afterwards would leave sampling unchanged.
            if temperature > 0 and temperature != 1:
                energy = energy / temperature
            probs = torch.softmax(energy, dim=-1)
            if do_sample:
                index = torch.multinomial(probs, 1).item()
            else:
                index = probs.argmax(-1).item()

            y = candidates[index][-1]
            ret += [(x, y)]

            energy_tensor = output_tensor[0, index]
            x = y
            energy_cache = torch.cat([energy_cache, energy_tensor.unsqueeze(0)], dim=1)

        return ret

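# Vocabulary over a directed graph: nodes are symbols and edges are "s->t"
# strings encoded as integer index pairs.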
class Vocab:
    def __init__(self, node_dict, nodeindex_dict, edge_dict, edge_decode_dict):
        self.node_dict = node_dict
        self.nodeindex_dict = nodeindex_dict
        self.edge_dict = edge_dict
        self.edge_decode_dict = edge_decode_dict

    def __call__(self, x):
        # Encode an "s->t" string (or a nested list of them) into index pairs.
        if isinstance(x, list):
            return [self.__call__(_) for _ in x]
        else:
            return self.fetch(x)

    def fetch(self, x):
        s, t = x.split("->")
        # Unknown edges fall back to the dedicated ""->"" entry.
        return self.edge_dict[s][t] if s in self.edge_dict and t in self.edge_dict[s] else self.edge_dict[""][""]

    @classmethod
    def from_node_dict(cls, dictname):
        # Build a fully connected graph from a {symbol: index} dict.
        node_dict = dict()
        nodeindex_dict = dict()
        edge_dict = dict()
        edge_decode_dict = dict()
        for s in dictname:
            node_dict[s] = dictname[s]
            nodeindex_dict[dictname[s]] = s
            edge_dict[s] = {}
            for t in dictname:
                edge_dict[s][t] = (dictname[s], dictname[t])
                edge_decode_dict[(dictname[s], dictname[t])] = "->".join([s, t])
        return cls(node_dict, nodeindex_dict, edge_dict, edge_decode_dict)

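    # Build the graph from a text file with one "s->t" edge per line.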
    @classmethod
    def from_edge(cls, filename):
        edge_dict = dict()
        edge_dict[""] = {}
        edge_dict[""][""] = (0, 0)  # fallback entry for unknown edges
        edge_decode_dict = dict()
        with open(filename) as f:
            for line in f:
                s, t = line.strip().split("->")
                if s not in edge_dict:
                    # First edge out of s: assign the next source index.
                    i = len(edge_dict)
                    j = 0
                    edge_dict[s] = dict()
                else:
                    # Reuse s's source index; j counts its outgoing edges.
                    i = edge_dict[s][list(edge_dict[s].keys())[0]][0]
                    j = len(edge_dict[s])
                edge_dict[s][t] = (i, j)
                edge_decode_dict[(i, j)] = "->".join([s, t])
        # node_dict and nodeindex_dict are unavailable in this construction.
        return cls(None, None, edge_dict, edge_decode_dict)

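    # Negative-sampling helper: return up to k edges that share the source of
    # `key` but end at a different target. With frequency_dict, negatives are
    # drawn from its first k+1 targets (assumed to be ordered by frequency).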
    def get_neighbor_of_edge(self, key, k, frequency_dict=None):
        s, t = key.split("->")
        _s = s if s in self.edge_dict else ""

        if frequency_dict:
            frequency_lst = list(frequency_dict[_s].keys())
            t_lst = [x for x in frequency_lst[:k + 1] if x != t][:k]
            ret = ["->".join([_s, _t]) for _t in t_lst]
            random.shuffle(ret)
            return ret
        else:
            ret = ["->".join([_s, _t]) for _t in self.edge_dict[_s].keys() if _t != t]
            random.shuffle(ret)
            return ret[:k] if k != -1 else ret

    def get_neighbor_of_node(self, key, k):
        # All outgoing edges of node `key` except the self-loop.
        s = self.nodeindex_dict[key]
        ret = ["->".join([s, _t]) for _t in self.edge_dict[s].keys() if _t != s]
        random.shuffle(ret)
        return ret[:k] if k != -1 else ret

    def get_neighbor_of_edge_broadcast(self, key, edges, k=100):
        # Sample one shared list of negative targets from `key`'s source,
        # then pair those targets with the source of each edge in `edges`.
        s, t = key.split("->")
        _ret = [_t for _t in self.edge_dict[s].keys() if _t != t]
        random.shuffle(_ret)
        ret = []
        for edge in edges:
            s, t = edge.split("->")
            ret += [["->".join([s, _t]) for _t in _ret[:k]]]
        return ret

    @staticmethod
    def to_path(tokens):
        # Turn a node sequence [a, b, c] into edge strings ["a->b", "b->c"].
        path = []
        for left, right in zip(tokens[:-1], tokens[1:]):
            path.append("->".join([left, right]))
        return path

    def get_edge_of_node(self, key):
        return list(self.edge_dict[key].values())

    def decode(self, x):
        # Map an index pair back to its "s->t" string.
        return self.edge_decode_dict[x]
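
# Minimal smoke test. Illustrative only: the three-node vocabulary, hidden
# size, and candidate layout below are made-up examples, not the original
# training setup.
if __name__ == "__main__":
    vocab = Vocab.from_node_dict({"a": 0, "b": 1, "c": 2})
    model = BraLM(hidden_size=16)
    model.prepare_network(vocab)
    model.to_device(torch.device("cpu"))

    # One batch, two steps; at each step the gold edge comes first, followed
    # by one negative candidate: shape (batch, seq, n_candidates, 2).
    neighbor_ids = torch.tensor([[[(0, 1), (0, 2)],
                                  [(1, 2), (1, 0)]]])
    loss = model(neighbor_ids)
    loss.backward()
    print("loss:", loss.item())

    # Greedy decoding from the prefix edge a->b.
    path = model.decode(start=[(0, 1)], vocab=vocab, max_new_tokens=3)
    print([vocab.decode(p) for p in path])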