# TextRetrieval / models.py
# (Hugging Face page-scrape residue preserved as comments below)
# PierreHanna's picture
# Upload 6 files
# bd29f40
# raw
# history blame
# 11.1 kB
import tensorflow as tf
from tensorflow.keras.layers import BatchNormalization, Concatenate
from tensorflow.keras.layers import Lambda, Flatten, Dense
from tensorflow.keras.initializers import glorot_uniform, RandomNormal, Zeros, HeNormal, Constant
from tensorflow.keras.layers import Input, Subtract, Dense, Lambda, Dropout,LeakyReLU, ReLU, PReLU, Attention
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, Conv1D, ZeroPadding2D, Activation, Input, concatenate, ConvLSTM2D, LSTM
from tensorflow.keras.layers import AveragePooling1D, MaxPooling1D, GlobalMaxPooling1D, GlobalMaxPooling2D, TimeDistributed, GlobalAveragePooling1D
from tensorflow.keras.layers import MaxPooling2D, AveragePooling2D, GlobalAveragePooling2D, UpSampling1D, Reshape
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, Conv1D, ZeroPadding2D, Activation, Multiply, Add, MaxPool1D, Permute
from keras import backend as K
import tensorflow_addons as tfa
import numpy as np
MARGIN = 0.5  # hinge margin used by triplet_loss_new
DIM_OUT = 1024  # embedding width; y_pred is sliced into DIM_OUT-wide chunks by every loss below
def triplet_loss_new(y_true, y_pred):
    """Triplet margin loss on packed (anchor | positive | negative) embeddings.

    `y_pred` carries three DIM_OUT-wide embeddings side by side; `y_true`
    is unused (Keras loss signature requires it).  Returns the batch-summed
    hinge loss: max(d(a, p) - d(a, n) + MARGIN, 0) with squared-L2 distances.
    """
    anchor = y_pred[:, :DIM_OUT]
    pos = y_pred[:, DIM_OUT:2 * DIM_OUT]
    neg = y_pred[:, 2 * DIM_OUT:]
    dist_pos = K.sum(K.square(anchor - pos), axis=-1)
    dist_neg = K.sum(K.square(anchor - neg), axis=-1)
    return K.sum(K.maximum(dist_pos - dist_neg + MARGIN, 0), axis=0)
# Define the contrastive loss function, NT_Xent (Tensorflow version)
def nt_xent_loss_4(y_true, y_pred, tau=0.07):
    """NT-Xent / InfoNCE contrastive loss (SimCLR, https://arxiv.org/abs/2002.05709).

    `y_pred` packs the two augmented views side by side: columns
    [0, DIM_OUT) are view 1 and [DIM_OUT, 2*DIM_OUT) are view 2.
    `y_true` is ignored.  Returns a scalar tensor.
    """
    view_a = y_pred[:, :DIM_OUT]
    view_b = y_pred[:, DIM_OUT:2 * DIM_OUT]
    # Stack both views, then L2-normalise every row so dot products
    # below are cosine similarities.
    embeddings = tf.math.l2_normalize(tf.concat([view_a, view_b], axis=0), axis=1)
    n_total = tf.shape(embeddings)[0]
    # Positive-pair mask: (i, i + N) and (i + N, i) are 1, everything else 0.
    half_ones = tf.ones((n_total // 2, ))
    labels = tf.experimental.numpy.diagflat(half_ones, n_total // 2) + \
             tf.experimental.numpy.diagflat(half_ones, -n_total // 2)
    # Full pairwise similarity matrix.
    sim = embeddings @ tf.transpose(embeddings)
    # Mask self-similarity with -1 (after /tau that is ~-14, so its
    # softmax contribution is negligible).
    sim = tf.linalg.set_diag(sim, -tf.ones((n_total, )))
    # Temperature-scaled cross-entropy against the positive-pair mask.
    logits = sim / tau
    entropy = tf.multiply(-labels, tf.nn.log_softmax(logits, axis=1))
    return tf.reduce_mean(tf.reduce_sum(entropy, axis=1))
# Define the contrastive loss function, NT_Xent (Tensorflow version)
def nt_xent_loss_3(y_true, y_pred, tau=0.07):
    """NT-Xent contrastive loss, sample-by-sample loop version.

    Tensorflow port of the per-pair formulation from the SimCLR paper
    (https://arxiv.org/pdf/2002.05709.pdf).  `y_pred` packs the two views
    side by side: columns [0, DIM_OUT) and [DIM_OUT, 2*DIM_OUT).
    `y_true` is ignored; `tau` is the temperature.  Returns a scalar.

    NOTE(review): the Python loop reads `zi.shape[0]`, so this only works
    in eager mode with a statically-known batch size — confirm it is never
    wrapped in tf.function with a dynamic batch.

    Fix vs. original: the CosineSimilarity loss object is stateless, so it
    is now created once before the loop instead of on every iteration.
    """
    zi = y_pred[:, :DIM_OUT]
    zj = y_pred[:, DIM_OUT:2*DIM_OUT]
    z = tf.cast(tf.concat((zi, zj), 0), dtype=tf.float32)
    # Hoisted loop-invariant: one loss object serves every iteration.
    cosine_sim = tf.keras.losses.CosineSimilarity(axis=-1, reduction=tf.keras.losses.Reduction.NONE)
    loss = 0
    for k in range(zi.shape[0]):
        # Positive pair indices: row i of view 1 matches row j of view 2.
        i = k
        j = k + zi.shape[0]
        # Keras CosineSimilarity returns the NEGATED similarity, hence
        # the leading minus signs throughout.
        sim = tf.squeeze(- cosine_sim(tf.reshape(z[i], (1, -1)), tf.reshape(z[j], (1, -1))))
        numerator = tf.math.exp(sim / tau)
        # Denominator: compare i (resp. j) to every sample except itself.
        sim_ik = - cosine_sim(tf.reshape(z[i], (1, -1)), z[tf.range(z.shape[0]) != i])
        sim_jk = - cosine_sim(tf.reshape(z[j], (1, -1)), z[tf.range(z.shape[0]) != j])
        denominator_ik = tf.reduce_sum(tf.math.exp(sim_ik / tau))
        denominator_jk = tf.reduce_sum(tf.math.exp(sim_jk / tau))
        # Symmetric losses for (i, j) and (j, i).
        loss_ij = - tf.math.log(numerator / denominator_ik)
        loss_ji = - tf.math.log(numerator / denominator_jk)
        loss += loss_ij + loss_ji
    # Average over all 2N samples.
    loss /= z.shape[0]
    return loss
def nt_xent_loss_2(y_true, y_pred, temperature=0.07):
    """Symmetric NT-Xent / InfoNCE loss (vectorised version).

    `y_pred` packs the two views side by side: columns [0, DIM_OUT) and
    [DIM_OUT, 2*DIM_OUT).  `y_true` is ignored.  Returns a per-sample loss
    vector (Keras reduces it to a scalar).

    Fix vs. original: a `SparseCategoricalAccuracy` metric was instantiated
    and updated on every call but its result was never read — pure dead
    work, and creating metric variables inside a loss breaks tf.function
    tracing.  Removed; the returned value is unchanged.
    """
    projections_1 = y_pred[:, :DIM_OUT]
    projections_2 = y_pred[:, DIM_OUT:2*DIM_OUT]
    # Cosine similarity: dot product of the l2-normalized feature vectors.
    projections_1 = tf.math.l2_normalize(projections_1, axis=1)
    projections_2 = tf.math.l2_normalize(projections_2, axis=1)
    similarities = (
        tf.matmul(projections_1, projections_2, transpose_b=True) / temperature
    )
    # Row i of view 1 should match row i of view 2, so the target class
    # for row i is simply i.
    batch_size = tf.shape(projections_1)[0]
    contrastive_labels = tf.range(batch_size)
    # The temperature-scaled similarities are used as logits for
    # cross-entropy; the loss is symmetrised over both directions.
    loss_1_2 = tf.keras.losses.sparse_categorical_crossentropy(
        contrastive_labels, similarities, from_logits=True
    )
    loss_2_1 = tf.keras.losses.sparse_categorical_crossentropy(
        contrastive_labels, tf.transpose(similarities), from_logits=True
    )
    return (loss_1_2 + loss_2_1) / 2
#def contrastive_loss(xi, xj, tau=1, normalize=False):
################# ERROR IN THIS VERSION ???
def nt_xent_loss(y_true, y_pred, tau=0.07, normalize=False):
    """NT-Xent loss ported from M. Diephuis' PyTorch SimCLR implementation
    (https://github.com/mdiephuis/SimCLR/).

    `y_pred` packs the two matched views side by side: columns [0, DIM_OUT)
    and [DIM_OUT, 2*DIM_OUT), for a batch of 2N images made of N matching
    pairs.  `tau` is the temperature; with `normalize=True` similarities
    are divided by the product of the vector norms (cosine similarity),
    otherwise raw dot products are used.

    Fixes vs. original: the `normalize=True` branches called PyTorch APIs
    on TF tensors (`.unsqueeze`, `.clamp`, `dim=` kwarg) and used
    `l2_normalize` where vector *norms* were needed — they raised at run
    time.  Rewritten with `tf.norm`/`tf.maximum`, mirroring the PyTorch
    source.  The default `normalize=False` path is numerically unchanged.

    NOTE(review): `norm_sum` subtracts exp(1/tau) per row, i.e. it assumes
    the diagonal self-similarity is exactly 1 — only true for unit-norm
    inputs; with normalize=False this is approximate (likely the "ERROR"
    the original author flagged).
    """
    xi = y_pred[:, :DIM_OUT]
    xj = y_pred[:, DIM_OUT:2*DIM_OUT]
    x = tf.keras.backend.concatenate((xi, xj), axis=0)
    # Raw pairwise dot products of all 2N embeddings.
    sim_mat = tf.keras.backend.dot(x, tf.keras.backend.transpose(x))
    if normalize:
        # Cosine similarity: divide by the outer product of row norms,
        # clamped away from zero for numerical safety.
        norms = tf.norm(x, axis=1, keepdims=True)  # shape (2N, 1)
        sim_mat_denom = tf.keras.backend.dot(norms, tf.keras.backend.transpose(norms))
        sim_mat = sim_mat / tf.maximum(sim_mat_denom, 1e-16)
    sim_mat = tf.keras.backend.exp(sim_mat / tau)
    if normalize:
        sim_match_denom = tf.norm(xi, axis=1) * tf.norm(xj, axis=1)
        sim_match = tf.keras.backend.exp(
            tf.keras.backend.sum(xi * xj, axis=-1) / sim_match_denom / tau)
    else:
        sim_match = tf.keras.backend.exp(tf.keras.backend.sum(xi * xj, axis=-1) / tau)
    # Each positive-pair term appears once per direction (i->j and j->i).
    sim_match = tf.keras.backend.concatenate((sim_match, sim_match), axis=0)
    # exp of the assumed self-similarity on the diagonal, removed from the
    # denominator so a sample is not contrasted against itself.
    norm_sum = tf.keras.backend.exp(tf.keras.backend.ones(tf.keras.backend.shape(x)[0]) / tau)
    return tf.math.reduce_mean(
        -tf.keras.backend.log(sim_match / (tf.keras.backend.sum(sim_mat, axis=-1) - norm_sum)))
def create_encoder_model_audio(in_shape, dim, final_activ):
    """Build the audio-branch encoder (currently a plain MLP projection head)."""
    # Alternative backbone kept around for experiments:
    # return create_encoder_model_resnet_byte_1d(in_shape)
    return create_encoder_model_mlp(in_shape, dim, final_activ=final_activ)
def create_encoder_model_text(in_shape, dim, final_activ):
    """Build the text-branch encoder (currently a plain MLP projection head)."""
    # Alternative backbone kept around for experiments:
    # return create_encoder_model_resnet_byte_1d(in_shape)
    return create_encoder_model_mlp(in_shape, dim, final_activ=final_activ)
######### RESNET 1D
def residual_block_byte_1d(x, filters, activation="relu"):
    """1-D bottleneck residual block.

    The input is first projected to `filters` channels by a 1x1
    convolution; that projection doubles as the skip connection that is
    added back after the conv/BN stack.
    """
    shortcut = Conv1D(filters, 1, padding="same")(x)
    out = BatchNormalization()(shortcut)
    out = Activation(activation)(out)
    out = Conv1D(filters, 3, padding="same")(out)
    out = BatchNormalization()(out)
    out = Conv1D(filters, 1, padding="same")(out)
    out = BatchNormalization()(out)
    out = Add()([out, shortcut])
    return Activation(activation)(out)
def create_encoder_model_resnet_byte_1d(input_shape):
    """ResNet-style 1-D CNN encoder emitting an L2-normalised DIM_OUT embedding."""
    inputs = Input(shape=input_shape)
    net = Conv1D(32, 7, strides=2, padding="same")(inputs)
    net = MaxPooling1D(pool_size=3, strides=2)(net)
    # ResNet-34-like stage layout: (number of blocks, filters) per stage.
    for n_blocks, n_filters in ((3, 32), (4, 64), (6, 128), (3, 256)):
        for _ in range(n_blocks):
            net = residual_block_byte_1d(net, n_filters)
    net = AveragePooling1D(pool_size=3, strides=3)(net)
    net = GlobalAveragePooling1D()(net)
    # Two-layer projection head.
    net = Dense(DIM_OUT, activation="relu")(net)
    net = Dense(DIM_OUT, activation='sigmoid')(net)
    net = BatchNormalization()(net)
    # Unit-norm embeddings, so downstream dot products are cosine similarities.
    net = Lambda(lambda t: K.l2_normalize(t, axis=-1))(net)
    return Model(inputs=inputs, outputs=net)
# simple MLP
def create_encoder_model_mlp(input_shape, size1, final_activ=None):
    """Single hidden-layer MLP encoder with an L2-normalised DIM_OUT output.

    Args:
        input_shape: shape of the input features (without the batch axis).
        size1: width of the hidden layer.
        final_activ: activation of the output Dense layer (None = linear).

    Returns:
        A compiled-free tf.keras Model; prints its summary as a side effect.
    """
    inputs = Input(shape=input_shape)
    hidden = Dense(size1, activation="relu")(inputs)
    hidden = Dropout(0.1)(hidden)
    out = Dense(DIM_OUT, activation=final_activ)(hidden)
    out = Dropout(0.1)(out)
    # Unit-norm embeddings, so downstream dot products are cosine similarities.
    out = Lambda(lambda t: K.l2_normalize(t, axis=-1))(out)
    model = Model(inputs=inputs, outputs=out)
    model.summary()
    return model