Spaces:
Build error
Build error
PeteBleackley
committed on
Commit
·
8f1745b
1
Parent(s):
df051eb
GlobalAttentionPoolingHead layer
Browse files
.ipynb_checkpoints/Model visualisation-checkpoint.ipynb
DELETED
|
@@ -1,6 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"cells": [],
|
| 3 |
-
"metadata": {},
|
| 4 |
-
"nbformat": 4,
|
| 5 |
-
"nbformat_minor": 5
|
| 6 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qarac/models/layers/GlobalAttentionPoolingHead.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
Created on Tue Sep 5 07:32:55 2023
|
| 5 |
+
|
| 6 |
+
@author: peter
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import keras
|
| 10 |
+
import tensorflow
|
| 11 |
+
|
| 12 |
+
class GlobalAttentionPoolingHead(keras.layers.Layer):
    """Pools a ragged batch of token vectors into one vector per sample.

    Tokens are weighted by the attention of their (locally projected)
    vectors against a globally projected summary of the whole sequence.
    """

    def __init__(self):
        """Create the head; the projection weights are created lazily in build()."""
        super(GlobalAttentionPoolingHead, self).__init__()
        # (width, width) projection matrices, created in build() once the
        # embedding width is known.
        self.global_projection = None
        self.local_projection = None

    def build(self, input_shape):
        """Create the two square projection matrices.

        Parameters
        ----------
        input_shape : tuple
            Input shape; the last dimension is the embedding width.
        """
        width = input_shape[-1]
        self.global_projection = self.add_weight('global projection',
                                                 shape=(width, width))
        self.local_projection = self.add_weight('local projection',
                                                shape=(width, width))
        # BUG FIX: the original set ``self.build = True``, which clobbers
        # this very method with a bool. Keras expects the ``built`` flag.
        self.built = True

    @tensorflow.function
    def project(self, X):
        """Project token vectors through the local projection matrix."""
        return tensorflow.tensordot(X, self.local_projection, axes=1)

    def attention_function(self, gp):
        """Return a tf.function that scores local projections against the
        precomputed global projection ``gp`` (closure over ``gp``)."""
        @tensorflow.function
        def inner(lp):
            return tensorflow.tensordot(lp, gp, axes=1)
        return inner

    def call(self, X, training=None):
        """Pool the ragged input ``X`` to one vector per sample.

        Parameters
        ----------
        X : tensorflow.RaggedTensor
            Assumed (batch, ragged tokens, width) — TODO confirm with callers.
        training : bool, optional
            Unused; present for Keras API compatibility.

        Returns
        -------
        tensorflow.Tensor
            Attention-weighted sum over the token axis.
        """
        # BUG FIX: tensordot takes its two operands as separate positional
        # arguments; the original passed them as a single list, which raises
        # TypeError (the "Build error" this commit addresses).
        gp = tensorflow.linalg.l2_normalize(
            tensorflow.tensordot(tensorflow.reduce_sum(X, axis=1),
                                 self.global_projection,
                                 axes=1),
            axis=1)
        lp = tensorflow.linalg.l2_normalize(
            tensorflow.ragged.map_flat_values(self.project, X),
            axis=2)
        attention = tensorflow.ragged.map_flat_values(self.attention_function(gp),
                                                      lp)
        return tensorflow.reduce_sum(attention * X, axis=1)
|
qarac/models/layers/HyenaLayer.py
CHANGED
|
@@ -11,22 +11,33 @@ import keras_nlp
|
|
| 11 |
import tensorflow
|
| 12 |
import warnings
|
| 13 |
|
|
|
|
| 14 |
@tensorflow.function
|
| 15 |
def convolve(x,y):
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
|
|
|
| 19 |
fz = fx*fy
|
| 20 |
-
|
|
|
|
| 21 |
|
| 22 |
-
@tensorflow.function
|
| 23 |
-
def fft(x):
|
| 24 |
-
|
| 25 |
|
| 26 |
-
@tensorflow.function
|
| 27 |
-
def ifft(x):
|
| 28 |
-
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
class HyenaLayer(keras.layers.Layer):
|
| 32 |
"""Keras implementation of Hyena layer. Unlike in the original paper,
|
|
@@ -77,24 +88,31 @@ class HyenaLayer(keras.layers.Layer):
|
|
| 77 |
trainable=True)
|
| 78 |
self.filters = self.add_weight(shape=(width,width,self.stages),
|
| 79 |
trainable=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
def call(self,X,training=None):
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
f_flat = tensorflow.tensordot(self.positional_encoding(X).flat_values,
|
| 86 |
-
self.filters,
|
| 87 |
-
axes=1)
|
| 88 |
-
x = tensorflow.RaggedTensor.from_row_lengths(x_flat,X.row_lengths())
|
| 89 |
-
f = tensorflow.RaggedTensor.from_row_lengths(f_flat,X.row_lengths())
|
| 90 |
if self.causal:
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
f = concat(f,tensorflow.zeros_like(f))
|
| 94 |
y = x[:,:,:,0]
|
| 95 |
for i in tensorflow.range(self.stages):
|
| 96 |
y = convolve(y,f[:,:,:,i])*x[:,:,:,i+1]
|
| 97 |
if self.causal:
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
|
|
|
|
|
| 11 |
import tensorflow
|
| 12 |
import warnings
|
| 13 |
|
| 14 |
+
|
| 15 |
@tensorflow.function
def convolve(x, y):
    """FFT-based convolution of ``x`` and ``y``.

    Each batch element is transposed so the real FFT runs over its last
    axis; the spectra are multiplied elementwise and transformed back,
    then transposed again to restore the original layout.
    """
    def per_batch(fn, tensor):
        # Apply fn independently to every element of the leading axis.
        return tensorflow.vectorized_map(fn, tensor)

    spectrum_x = per_batch(tensorflow.signal.rfft,
                           per_batch(tensorflow.transpose, x))
    spectrum_y = per_batch(tensorflow.signal.rfft,
                           per_batch(tensorflow.transpose, y))
    product = spectrum_x * spectrum_y
    return per_batch(tensorflow.transpose,
                     per_batch(tensorflow.signal.irfft, product))
|
| 24 |
|
| 25 |
+
# @tensorflow.function
|
| 26 |
+
# def fft(x):
|
| 27 |
+
# return tensorflow.signal.rfft(tensorflow.transpose(x))
|
| 28 |
|
| 29 |
+
# @tensorflow.function
|
| 30 |
+
# def ifft(x):
|
| 31 |
+
# return tensorflow.transpose(tensorflow.signal.irfft(x))
|
| 32 |
|
| 33 |
+
@tensorflow.function
def pad(x):
    """Zero-pad ``x`` to double length along axis 0.

    Doubling the length before an FFT convolution makes the (otherwise
    circular) convolution behave linearly.
    """
    zeros = tensorflow.zeros_like(x)
    return tensorflow.concat([x, zeros], 0)
|
| 36 |
+
|
| 37 |
+
@tensorflow.function()
def truncate(args):
    """Trim a padded sequence back to its true length.

    ``args`` is a (data, length) pair; returns the first ``length`` rows
    of ``data``. Packed as a single tuple so it can be mapped with
    ``tensorflow.vectorized_map``.
    """
    data, length = args
    return data[:length]
|
| 41 |
|
| 42 |
class HyenaLayer(keras.layers.Layer):
|
| 43 |
"""Keras implementation of Hyena layer. Unlike in the original paper,
|
|
|
|
| 88 |
trainable=True)
|
| 89 |
self.filters = self.add_weight(shape=(width,width,self.stages),
|
| 90 |
trainable=True)
|
| 91 |
+
self.built = True
|
| 92 |
+
|
| 93 |
+
def compute_output_shape(self, input_shape):
    """The layer is shape-preserving: output shape equals input shape.

    BUG FIX: the method was misspelled ``conpute_output_shape``, so the
    Keras ``compute_output_shape`` hook was never found.
    """
    return input_shape

# Backward-compatible alias preserving the original (misspelled) name
# for any existing direct callers.
conpute_output_shape = compute_output_shape
|
| 95 |
+
|
| 96 |
+
@tensorflow.function
def project(self, x):
    """Apply the learned data projection to the token vectors ``x``."""
    projected = tensorflow.tensordot(x, self.data_projection, axes=1)
    return projected
| 99 |
+
|
| 100 |
+
@tensorflow.function
def generate_filters(self, t):
    """Build convolution filters from the positional encodings ``t``."""
    filters = tensorflow.tensordot(t, self.filters, axes=1)
    return filters
| 103 |
|
| 104 |
def call(self, X, training=None):
    """Apply the Hyena operator to the ragged input ``X``.

    Parameters
    ----------
    X : tensorflow.RaggedTensor
        Batch of token sequences — assumed (batch, ragged, width);
        TODO confirm against callers.
    training : bool, optional
        Unused; present for Keras API compatibility.
    """
    x = tensorflow.ragged.map_flat_values(self.project, X)
    f = tensorflow.ragged.map_flat_values(self.generate_filters,
                                          self.positional_encoding(X))
    if self.causal:
        # BUG FIX: tensorflow has no ``vectorize_map``; the correct name is
        # ``vectorized_map`` (used correctly below for truncate). Zero-pad
        # each sequence so the FFT convolution is linear, not circular.
        x = tensorflow.vectorized_map(pad, x)
        f = tensorflow.vectorized_map(pad, f)
    y = x[:, :, :, 0]
    for i in tensorflow.range(self.stages):
        y = convolve(y, f[:, :, :, i]) * x[:, :, :, i+1]
    if self.causal:
        # Trim each padded sequence back to its original length.
        y = tensorflow.vectorized_map(truncate, (y, X.row_lengths()))
    return tensorflow.raw_ops.RaggedTensorToVariant(rt_nested_splits=y.row_splits,
                                                    rt_dense_values=y.flat_values,
                                                    batched_input=True)
|
scripts.py
CHANGED
|
@@ -7,7 +7,9 @@ import qarac.corpora.BNCorpus
|
|
| 7 |
import qarac.corpora.Batcher
|
| 8 |
import qarac.models.qarac_base_model
|
| 9 |
import keras
|
|
|
|
| 10 |
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
|
|
|
|
| 7 |
import qarac.corpora.Batcher
|
| 8 |
import qarac.models.qarac_base_model
|
| 9 |
import keras
|
| 10 |
+
import tensorflow
|
| 11 |
|
| 12 |
+
#tensorflow.debugging.disable_traceback_filtering()
|
| 13 |
|
| 14 |
|
| 15 |
|