PierreHanna commited on
Commit
e628964
Β·
1 Parent(s): ab53756

Update models.py

Browse files
Files changed (1) hide show
  1. models.py +84 -0
models.py CHANGED
@@ -296,3 +296,87 @@ def create_encoder_model_mlp(input_shape, size1, final_activ=None):
296
 
297
  return model
298
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
296
 
297
  return model
298
 
299
+ def make_bert_preprocess_model(sentence_features, tfhub_handle_preprocess, seq_length=128):
300
+ """Returns Model mapping string features to BERT inputs.
301
+ """
302
+
303
+ input_segments = [
304
+ tf.keras.layers.Input(shape=(), dtype=tf.string, name=ft)
305
+ for ft in sentence_features]
306
+
307
+ bert_preprocess = hub.load(tfhub_handle_preprocess)
308
+ tokenizer = hub.KerasLayer(bert_preprocess.tokenize, name='tokenizer')
309
+ segments = [tokenizer(s) for s in input_segments]
310
+
311
+ truncated_segments = segments
312
+ packer = hub.KerasLayer(bert_preprocess.bert_pack_inputs,
313
+ arguments=dict(seq_length=seq_length),
314
+ name='packer')
315
+ model_inputs = packer(truncated_segments)
316
+ return tf.keras.Model(input_segments, model_inputs)
317
+
318
+ def process(prompt, lang):
319
+
320
+ # Getting prompt user
321
+ #prompt = input("Audio Search - enter text : ")
322
+ #print(prompt)
323
+
324
+ # prompt embedding
325
+ bert_model_name = 'small_bert/bert_en_uncased_L-4_H-512_A-8'
326
+ tfhub_handle_encoder = 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1'
327
+ tfhub_handle_preprocess = 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3'
328
+
329
+ MAX_LENGTH = 130 # MAX de 512 !!! TENSORFLOW !!!
330
+ TOP = 10
331
+
332
+
333
+ bert_preprocess_model = make_bert_preprocess_model(['my_input'], tfhub_handle_preprocess, seq_length = MAX_LENGTH)
334
+ bert_model = hub.KerasLayer(tfhub_handle_encoder)
335
+
336
+ now = datetime.datetime.now()
337
+ print()
338
+ print('*************')
339
+ print("Current Time: ", str(now))
340
+ print("Text input : ", prompt)
341
+ print('*************')
342
+ print()
343
+ prompt=[prompt]
344
+ text_preprocessed = bert_preprocess_model([np.array(prompt)])
345
+ embed_prompt = bert_model(text_preprocessed)
346
+ print(" text representation computed.")
347
+
348
+ # Embed text
349
+ #from models import *
350
+ encoder_text = tf.keras.models.load_model(encoder_text_path)
351
+ embed_query = encoder_text.predict(embed_prompt["pooled_output"])
352
+ faiss.normalize_L2(embed_query)
353
+ print(" text embed computed.")
354
+
355
+ # load embed audio catalog
356
+ index = faiss.read_index("BMG_221022.index")
357
+
358
+ # distance computing
359
+ D, I = index.search(embed_query, TOP)
360
+
361
+ # names index
362
+ import joblib
363
+ audio_names = joblib.load(open('BMG_221022_names.index', 'rb'))
364
+
365
+ #url
366
+ url_dict={}
367
+ with open("bmg_clean.csv") as csv_file:
368
+ csv_reader = csv.reader(csv_file, delimiter=';')
369
+ for row in csv_reader:
370
+ f = row[2].split('/')[-1]
371
+ url_dict[f.split('/')[-1][:-4]] = row[2]
372
+
373
+ # output : top N audio file names
374
+ print(I)
375
+ print(D)
376
+ print("----")
377
+ for i in range(len(I[0])):
378
+ print(audio_names[I[0][i]], " with distance ", D[0][i])
379
+ print(" url : ", url_dict[audio_names[I[0][i]]])
380
+
381
+
382
+ return [url_dict[audio_names[I[0][0]]], url_dict[audio_names[I[0][1]]], url_dict[audio_names[I[0][2]]], url_dict[audio_names[I[0][3]]], url_dict[audio_names[I[0][4]]]]