import numpy as np
import tensorflow as tf
from tensorflow.keras import Input
from tensorflow.keras.layers import (
    GRU,
    Activation,
    ActivityRegularization,
    Add,
    BatchNormalization,
    Bidirectional,
    Concatenate,
    Conv1D,
    Dense,
    Dropout,
    Flatten,
    Lambda,
    Layer,
)
from tensorflow.keras.models import Model, load_model


def selector_init(shape, dtype=None):
    """Initializer for Selector weights: 1 for the first input, 0 elsewhere."""
    c = np.zeros(shape)
    c[0] += 1
    return tf.constant(c, dtype=dtype)


class Selector(Layer):
    """Weighted sum of a list of same-shaped inputs. With selector_init the
    (non-trainable) weights start as [1, 0, ...], so only the first input
    passes through until the weights are reassigned."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        for elem in input_shape[1:]:
            assert (
                elem.as_list() == input_shape[0].as_list()
            ), "All inputs must be the same shape."
        self.selectors = self.add_weight(
            name="selectors",
            shape=(len(input_shape),),
            initializer=selector_init,
            trainable=False,
        )
        super().build(input_shape)

    def call(self, x):
        return sum(self.selectors[i] * x[i] for i in range(self.selectors.shape[0]))

    def compute_output_shape(self, input_shape):
        return input_shape[0]

    def get_config(self):
        return super().get_config().copy()
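

# A minimal sketch (not in the original module) of how the selector weights
# can be reassigned to switch which branch feeds the model output, e.g.
# during staged training; `model` and the layer name are assumptions based
# on get_model() below:
#
#     selector = model.get_layer("output_selector")
#     selector.set_weights([np.array([0.0, 0.0, 1.0])])  # route gen_func_out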


class ResidualTuner(Layer):
    """Residual MLP correction: returns inp + f(inp), where f is a small
    dense network (two ReLU hidden layers with batch norm, scalar output)."""

    def __init__(self, hidden_units=100, **kwargs):
        super().__init__(**kwargs)
        self.hidden_units = hidden_units
        self.dense3 = Dense(1)
        self.batchnorm2 = BatchNormalization()
        self.dense2 = Dense(self.hidden_units, activation="relu")
        self.batchnorm1 = BatchNormalization()
        self.dense1 = Dense(self.hidden_units, activation="relu")

    def build(self, input_shape):
        super().build(input_shape)  # Be sure to call this somewhere!

    def call(self, inp):
        x = self.dense1(inp)
        x = self.batchnorm1(x)
        x = self.dense2(x)
        x = self.batchnorm2(x)
        x = self.dense3(x)
        return x + inp

    def compute_output_shape(self, input_shape):
        return input_shape

    def get_config(self):
        config = super().get_config().copy()
        config.update({"hidden_units": self.hidden_units})
        return config


class SumDiff(Layer):
    """Affine difference of summed activations: b + w * (sum(x[0]) - sum(x[1])),
    summing over the position and filter axes. w and b can be frozen."""

    def __init__(self, freeze=False, **kwargs):
        super().__init__(**kwargs)
        self.freeze = freeze

    def build(self, input_shape):
        self.b = self.add_weight(
            name="b",
            shape=(1,),
            initializer=tf.keras.initializers.Zeros(),
            trainable=not self.freeze,
        )
        self.w = self.add_weight(
            name="w",
            shape=(1,),
            initializer=tf.keras.initializers.Ones(),
            trainable=not self.freeze,
        )
        super().build(input_shape)  # Be sure to call this somewhere!

    def call(self, x):
        out = tf.reduce_sum(x[0], axis=(1, 2)) - tf.reduce_sum(x[1], axis=(1, 2))
        return self.b + self.w * tf.reshape(out, shape=(-1, 1))

    def compute_output_shape(self, input_shape):
        # Must be implemented here because the layer changes the output shape.
        assert input_shape and len(input_shape) >= 2
        assert input_shape[-1]
        return (None, 1)

    def set_freeze(self, unfreeze=False):
        # Named set_freeze because the boolean attribute self.freeze assigned
        # in __init__ would shadow a method named freeze, making it
        # uncallable. Recompile the model afterwards for the trainability
        # change to take effect.
        self.freeze = not unfreeze
        self.w.trainable = unfreeze
        self.b.trainable = unfreeze

    def get_config(self):
        config = super().get_config().copy()
        config.update({"freeze": self.freeze})
        return config
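

# Usage sketch for freezing (layer names are the ones from get_model() below):
#
#     model.get_layer("energy_seq").set_freeze()               # freeze w, b
#     model.get_layer("energy_seq").set_freeze(unfreeze=True)  # unfreeze
#     model.compile(optimizer="adam", loss=binary_KL, metrics=[binary_KL])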


def binary_KL(y_true, y_pred):
    # Binary KL divergence: the cross-entropy of y_pred against y_true minus
    # the entropy of y_true, so the loss is exactly 0 when y_pred == y_true.
    # return K.mean(K.binary_crossentropy(y_pred, y_true) - K.binary_crossentropy(y_true, y_true), axis=-1)  # this is for the Ubuntu machine in Courant
    return tf.keras.backend.mean(
        tf.keras.backend.binary_crossentropy(y_true, y_pred)
        - tf.keras.backend.binary_crossentropy(y_true, y_true),
        axis=-1,
    )  # this is for Anaconda or Ubuntu on my PC
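

# Quick sanity check (illustrative, not in the original module): the loss
# vanishes for a perfect prediction and is positive otherwise.
#
#     y = tf.constant([[0.25, 0.75]])
#     binary_KL(y, y)                          # ~0.0
#     binary_KL(y, tf.constant([[0.5, 0.5]]))  # > 0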


def pos_reg(x, adjacency_left_trim=0, adjacency_right_trim=0):
    """Squared-L2 penalty on the positional bias, excluding the trimmed flanks."""
    length = x.shape[0]
    return tf.reduce_sum(
        tf.square(x[adjacency_left_trim : length - adjacency_right_trim])
    )


def adj_reg_fo(x, adjacency_left_trim=0, adjacency_right_trim=0):
    """First-order adjacency (smoothness) penalty: mean ratio of the squared
    first differences to the variance of each filter's positional bias."""
    length = x.shape[0]
    x_trimmed = x[adjacency_left_trim : length - adjacency_right_trim]
    x_norm = x_trimmed - tf.reduce_mean(x_trimmed, axis=0)
    A = tf.reduce_sum((x_norm[:-1] - x_norm[1:]) ** 2, axis=0)
    B = tf.reduce_sum(x_norm ** 2, axis=0)
    return tf.reduce_mean(A / B)


def adj_reg_so(x, adjacency_left_trim=0, adjacency_right_trim=0):
    """Second-order adjacency penalty: like adj_reg_fo, but on the squared
    second differences, penalizing curvature rather than slope."""
    length = x.shape[0]
    x_trimmed = x[adjacency_left_trim : length - adjacency_right_trim]
    x_norm = x_trimmed - tf.reduce_mean(x_trimmed, axis=0)
    diff_1 = x_norm[:-1] - x_norm[1:]
    diff_2 = diff_1[:-1] - diff_1[1:]
    A = tf.reduce_sum(diff_2 ** 2, axis=0)
    B = tf.reduce_sum(x_norm ** 2, axis=0)
    return tf.reduce_mean(A / B)
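

# Illustrative behavior on a toy tensor (values are assumptions, not from the
# original code): a linear ramp has constant first differences, so its second
# differences are zero and adj_reg_so vanishes while adj_reg_fo does not.
#
#     ramp = tf.constant([[0.0], [1.0], [2.0], [3.0]])
#     adj_reg_fo(ramp)  # 0.6
#     adj_reg_so(ramp)  # ~0.0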


class MultiRegularizer(tf.keras.regularizers.Regularizer):
    """Combines the position, first-order, and second-order adjacency
    penalties, each with its own coefficient."""

    def __init__(
        self,
        position_regularization,
        adjacency_regularization_fo,
        adjacency_regularization_so,
        adjacency_left_trim=0,
        adjacency_right_trim=0,
    ):
        self.position_regularization = position_regularization
        self.adjacency_regularization_fo = adjacency_regularization_fo
        self.adjacency_regularization_so = adjacency_regularization_so
        self.adjacency_left_trim = adjacency_left_trim
        self.adjacency_right_trim = adjacency_right_trim

    def __call__(self, x):
        return (
            self.position_regularization
            * pos_reg(x, self.adjacency_left_trim, self.adjacency_right_trim)
            + self.adjacency_regularization_fo
            * adj_reg_fo(x, self.adjacency_left_trim, self.adjacency_right_trim)
            + self.adjacency_regularization_so
            * adj_reg_so(x, self.adjacency_left_trim, self.adjacency_right_trim)
        )

    def get_config(self):
        # The base Regularizer class does not implement get_config, so the
        # config dict is built directly rather than via super().
        return {
            "position_regularization": self.position_regularization,
            "adjacency_regularization_fo": self.adjacency_regularization_fo,
            "adjacency_regularization_so": self.adjacency_regularization_so,
            "adjacency_left_trim": self.adjacency_left_trim,
            "adjacency_right_trim": self.adjacency_right_trim,
        }
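

# The regularizer can also be attached to any 2-D Keras weight directly; a
# hypothetical example (not in the original module), penalizing along the
# kernel's first axis:
#
#     Dense(8, kernel_regularizer=MultiRegularizer(1e-4, 1e-3, 0.0))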


class RegularizedBiasLayer(Layer):
    """Adds a trainable (positions, filters) bias to its input, with the bias
    kernel penalized by a MultiRegularizer."""

    def __init__(
        self,
        position_regularization,
        adjacency_regularization_fo,
        adjacency_regularization_so,
        adjacency_left_trim=0,
        adjacency_right_trim=0,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.position_regularization = position_regularization
        self.adjacency_regularization_fo = adjacency_regularization_fo
        self.adjacency_regularization_so = adjacency_regularization_so
        self.adjacency_left_trim = adjacency_left_trim
        self.adjacency_right_trim = adjacency_right_trim

    def build(self, input_shape):
        regularizer = MultiRegularizer(
            self.position_regularization,
            self.adjacency_regularization_fo,
            self.adjacency_regularization_so,
            self.adjacency_left_trim,
            self.adjacency_right_trim,
        )
        self.kernel = self.add_weight(
            name="kernel",
            shape=(input_shape[1], input_shape[2]),
            initializer="random_normal",
            regularizer=regularizer,
            trainable=True,
        )
        super().build(input_shape)

    def call(self, x):
        return self.kernel + x

    def compute_output_shape(self, input_shape):
        return input_shape

    def get_config(self):
        config = super().get_config().copy()
        config.update(
            {
                "position_regularization": self.position_regularization,
                "adjacency_regularization_fo": self.adjacency_regularization_fo,
                "adjacency_regularization_so": self.adjacency_regularization_so,
                "adjacency_left_trim": self.adjacency_left_trim,
                "adjacency_right_trim": self.adjacency_right_trim,
            }
        )
        return config


def regularized_act(x, act_reg, activation="exponential"):
    """Applies the activation (by name or as a callable) followed by an L1
    activity penalty."""
    if isinstance(activation, str):
        return ActivityRegularization(l1=act_reg)(Activation(activation)(x))
    return ActivityRegularization(l1=act_reg)(activation(x))


def train_model(
    model,
    input_data,
    target_data,
    filename,
    validation_split=0.25,
    epochs=256,
    batch_size=128,
    custom_callbacks=None,
    verbose=1,
):
    """Fits the model with checkpointing and early stopping, both monitoring
    val_binary_KL, and returns the training history."""
    if custom_callbacks is None:
        # Avoid a mutable default argument, which would be shared across calls.
        custom_callbacks = []
    model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=filename,
        verbose=0,
        save_weights_only=False,
        monitor="val_binary_KL",
        mode="min",
        save_best_only=True,
    )
    early_stopping_callback = tf.keras.callbacks.EarlyStopping(
        monitor="val_binary_KL",
        min_delta=0,
        patience=10,
        verbose=1,
        mode="min",
        restore_best_weights=True,
    )
    history = model.fit(
        input_data,
        target_data,
        batch_size=batch_size,
        epochs=epochs,
        verbose=verbose,
        validation_split=validation_split,
        callbacks=[model_checkpoint_callback, early_stopping_callback]
        + custom_callbacks,
    )
    return history
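

# Hypothetical call (array names are placeholders): a nonzero
# validation_split is required, since both callbacks monitor val_binary_KL,
# and the model must be compiled with binary_KL among its metrics, as
# get_model() below does.
#
#     history = train_model(model, [seq, struct, wobble], psi, "best.h5")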


def get_model(
    input_length=90,
    randomized_region=(10, 80),
    num_filters=20,
    num_structure_filters=8,
    filter_width=6,
    structure_filter_width=30,
    dropout_rate=0.01,
    activity_regularization=0.0,
    tune_energy=True,
    position_regularization=2.5e-5,
    adjacency_regularization=0.0,
    adjacency_regularization_so=0.0,
    position_regularization_structure=2.5e-5,
    adjacency_regularization_structure=0.0,
    adjacency_regularization_so_structure=0.0,
    energy_activation="softplus",
):
    ###################
    ## Define layers ##
    ###################
    # Sequence layers
    qc_skip = Conv1D(filters=num_filters, kernel_size=filter_width, name="qc_skip")
    qc_incl = Conv1D(filters=num_filters, kernel_size=filter_width, name="qc_incl")
    # Bias layers
    seq_left_trim = randomized_region[0]
    seq_right_trim = input_length - randomized_region[1]
    position_bias_skip = RegularizedBiasLayer(
        position_regularization,
        adjacency_regularization,
        adjacency_regularization_so,
        adjacency_left_trim=seq_left_trim,
        adjacency_right_trim=seq_right_trim,
        name="position_bias_skip",
    )
    position_bias_incl = RegularizedBiasLayer(
        position_regularization,
        adjacency_regularization,
        adjacency_regularization_so,
        adjacency_left_trim=seq_left_trim,
        adjacency_right_trim=seq_right_trim,
        name="position_bias_incl",
    )
    dropout_skip_seq = Dropout(dropout_rate, name="dropout_skip_seq")
    dropout_incl_seq = Dropout(dropout_rate, name="dropout_incl_seq")
    # Structure layers
    c_skip_struct = Conv1D(
        num_structure_filters,
        structure_filter_width,
        padding="same",
        name="c_skip_struct",
    )
    c_incl_struct = Conv1D(
        num_structure_filters,
        structure_filter_width,
        padding="same",
        name="c_incl_struct",
    )
    position_bias_skip_struct = RegularizedBiasLayer(
        position_regularization_structure,
        adjacency_regularization_structure,
        adjacency_regularization_so_structure,
        name="position_bias_skip_struct",
    )
    position_bias_incl_struct = RegularizedBiasLayer(
        position_regularization_structure,
        adjacency_regularization_structure,
        adjacency_regularization_so_structure,
        name="position_bias_incl_struct",
    )
    dropout_skip_struct = Dropout(dropout_rate, name="dropout_skip_struct")
    dropout_incl_struct = Dropout(dropout_rate, name="dropout_incl_struct")
    # Energy layers
    energy_seq = SumDiff(name="energy_seq", freeze=not tune_energy)
    energy_seq_struct = SumDiff(name="energy_seq_struct", freeze=not tune_energy)
    # Generalized function layer
    gen_func = ResidualTuner(name="gen_func", hidden_units=4)
    # Final activation
    output_activation = Activation("sigmoid", name="output_activation")
    # Additional selectors
    output_selector = Selector(name="output_selector")
    ########################
    ## Define model logic ##
    ########################
    # Inputs
    seq_input = Input(shape=(input_length, 4), name="seq_input")
    struct_input = Input(shape=(input_length, 3), name="struct_input")
    wobble_input = Input(shape=(input_length, 1), name="wobble_input")
    # Sequence processing
    out_simple_skip = qc_skip(seq_input)
    out_simple_incl = qc_incl(seq_input)
    dropout_bias_skip = dropout_skip_seq(position_bias_skip(out_simple_skip))
    dropout_bias_incl = dropout_incl_seq(position_bias_incl(out_simple_incl))
    # Structure processing. The same-padded structure branch has length
    # input_length, while the valid sequence convolutions above have length
    # input_length - filter_width + 1; the [:, 2:-3, :] slice trims 5
    # positions so the two branches align (this hard-codes filter_width=6).
    structure_out_skip = dropout_skip_struct(
        position_bias_skip_struct(
            c_skip_struct(Concatenate()([seq_input, struct_input, wobble_input]))
        )
    )[:, 2:-3, :]
    structure_out_incl = dropout_incl_struct(
        position_bias_incl_struct(
            c_incl_struct(Concatenate()([seq_input, struct_input, wobble_input]))
        )
    )[:, 2:-3, :]
    # Concatenate the sequence and structure branches
    seq_struct_concat_skip = Concatenate()([dropout_bias_skip, structure_out_skip])
    seq_struct_concat_incl = Concatenate()([dropout_bias_incl, structure_out_incl])
    # Energy layers
    energy_seq_out = energy_seq(
        [
            regularized_act(
                dropout_bias_incl,
                activity_regularization,
                activation=energy_activation,
            ),
            regularized_act(
                dropout_bias_skip,
                activity_regularization,
                activation=energy_activation,
            ),
        ]
    )
    energy_seq_struct_out = energy_seq_struct(
        [
            regularized_act(
                seq_struct_concat_incl,
                activity_regularization,
                activation=energy_activation,
            ),
            regularized_act(
                seq_struct_concat_skip,
                activity_regularization,
                activation=energy_activation,
            ),
        ]
    )
    # Generalized function
    gen_func_out = gen_func(energy_seq_struct_out)
    # Model output: the Selector initially passes energy_seq_out through
    out = output_activation(
        output_selector([energy_seq_out, energy_seq_struct_out, gen_func_out])
    )
    # Create model
    model = Model(inputs=[seq_input, struct_input, wobble_input], outputs=out)
    model.compile(optimizer="adam", loss=binary_KL, metrics=[binary_KL])
    return model
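

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not in the original module): builds the
# model and fits it briefly on random data. Array shapes follow the Input
# definitions in get_model(); the sample count, epoch count, and checkpoint
# filename are assumptions.
if __name__ == "__main__":
    n_samples, length = 512, 90
    seq = np.random.rand(n_samples, length, 4)
    struct = np.random.rand(n_samples, length, 3)
    wobble = np.random.rand(n_samples, length, 1)
    psi = np.random.rand(n_samples, 1)  # targets in [0, 1]
    model = get_model(input_length=length)
    train_model(
        model,
        [seq, struct, wobble],
        psi,
        "model_checkpoint.h5",
        epochs=2,
        batch_size=64,
    )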