# TextRetrieval / app.py
# Hugging Face Space by PierreHanna (commit f65a9bb)
import tempfile
import gradio as gr
import os
import tensorflow_hub as hub
import tensorflow as tf
import tensorflow_text as text
import sys
import numpy as np
import faiss
import csv
import datetime
# Run on CPU only: hide all CUDA devices from TensorFlow.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
# Suppress TensorFlow C++ log output (level 3 = most verbose filtering).
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
def make_bert_preprocess_model(sentence_features, tfhub_handle_preprocess, seq_length=128):
    """Build a Keras model mapping raw string features to packed BERT inputs.

    Args:
        sentence_features: iterable of feature names; one scalar string
            ``tf.keras.layers.Input`` is created per name.
        tfhub_handle_preprocess: TF-Hub handle of the BERT preprocessing model.
        seq_length: packed sequence length (BERT supports at most 512 tokens).

    Returns:
        A ``tf.keras.Model`` taking one string tensor per feature and returning
        the packed BERT input dict produced by ``bert_pack_inputs``.
    """
    input_segments = [
        tf.keras.layers.Input(shape=(), dtype=tf.string, name=ft)
        for ft in sentence_features]

    # Tokenize each feature with the preprocessing model's tokenizer.
    bert_preprocess = hub.load(tfhub_handle_preprocess)
    tokenizer = hub.KerasLayer(bert_preprocess.tokenize, name='tokenizer')
    tokenized_segments = [tokenizer(seg) for seg in input_segments]

    # Pack (truncate/pad + special tokens) the tokenized segments to a fixed
    # seq_length. (Removed a dead alias of the tokenized list present before.)
    packer = hub.KerasLayer(bert_preprocess.bert_pack_inputs,
                            arguments=dict(seq_length=seq_length),
                            name='packer')
    model_inputs = packer(tokenized_segments)
    return tf.keras.Model(input_segments, model_inputs)
from models import *
def process(prompt, lang):
    """Retrieve the audio tracks best matching a free-text description.

    Args:
        prompt: user-supplied text describing the desired audio.
        lang: UI language code (only "en" is offered; not used by the search).

    Returns:
        A list of 5 audio URLs, closest match first.

    Raises:
        KeyError: if a retrieved track name has no URL in ``bmg_clean.csv``.
    """
    import joblib  # only needed here, for the track-names index

    # BERT encoder used to embed the text prompt.
    tfhub_handle_encoder = 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1'
    tfhub_handle_preprocess = 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3'
    MAX_LENGTH = 130  # packed sequence length; BERT hard limit is 512
    TOP = 10          # nearest neighbours fetched from the FAISS index

    # NOTE(review): the preprocess model, encoder and indexes are reloaded on
    # every request; hoisting them to module level would speed up queries.
    bert_preprocess_model = make_bert_preprocess_model(
        ['my_input'], tfhub_handle_preprocess, seq_length=MAX_LENGTH)
    bert_model = hub.KerasLayer(tfhub_handle_encoder)

    now = datetime.datetime.now()
    print()
    print('*************')
    print("Current Time: ", str(now))
    print("Text input : ", prompt)
    print('*************')
    print()

    # Prompt -> BERT pooled representation.
    prompt = [prompt]
    text_preprocessed = bert_preprocess_model([np.array(prompt)])
    embed_prompt = bert_model(text_preprocessed)
    print(" text representation computed.")

    # Project into the shared retrieval embedding space, then L2-normalise the
    # query in place for the FAISS search.
    encoder_text = tf.keras.models.load_model('encoder_text_retrievaltext_bmg_221022_54.h5')
    embed_query = encoder_text.predict(embed_prompt["pooled_output"])
    faiss.normalize_L2(embed_query)
    print(" text embed computed.")

    # Nearest-neighbour search over the precomputed audio catalog.
    index = faiss.read_index("BMG_221022.index")
    D, I = index.search(embed_query, TOP)

    # Row index -> track name. Pass the filename directly so joblib manages
    # the file handle (the previous open(...) was never closed).
    audio_names = joblib.load('BMG_221022_names.index')

    # Track name (URL basename minus its 4-char extension, e.g. ".mp3") -> URL.
    url_dict = {}
    with open("bmg_clean.csv") as csv_file:
        for row in csv.reader(csv_file, delimiter=';'):
            basename = row[2].split('/')[-1]
            url_dict[basename[:-4]] = row[2]

    print(I)
    print(D)
    print("----")
    for rank in range(len(I[0])):
        print(audio_names[I[0][rank]], " with distance ", D[0][rank])
        print(" url : ", url_dict[audio_names[I[0][rank]]])

    # The Gradio UI exposes exactly 5 audio players.
    return [url_dict[audio_names[I[0][rank]]] for rank in range(5)]
# Gradio UI: a free-text prompt plus a language selector in, five audio
# players out (one per retrieved track).
inputs = [gr.Textbox(label="Input", value="type your description", max_lines=2),
gr.Radio(label="Language", choices=["en"], value="en")]
# Example prompts shown under the interface; not cached, so each click runs
# a fresh retrieval.
poc_examples = [#[["I love learning machine learning"],["autre"]]
["Mysterious filmscore with Arabic influenced instruments","en"],
["Let's go on a magical adventure with wizzards, dragons and castles","en"],
["Creepy piano opening evolves and speeds up into a cinematic orchestral piece","en"],
["Halloween rock with creepy organ","en"],
["Rhythmic electro dance track for sport, motivation and sweating","en"],
["soundtrack for an action movie from the eighties in a retro synth wave style","en"],
["Choral female singing is rhythmically accompanied in a church with medieval instruments","en"],
["Christmas","en"],
["love romantic with piano, strings and vocals","en"],
["Electronic soundscapes for chilling and relaxing","en"],
["Minimal, emotional, melancholic piano","en"],
# ["Minimal, happy, joyful piano","en"],
["A calm and romantic acoustic guitar melody","en"],
["horror suspense piano","en"],
["Big Band","en"],
["90 eurodance beat","en"],
]
outputs = [gr.Audio(label="Track 1"), gr.Audio(label="Track 2"), gr.Audio(label="Track 3"), gr.Audio(label="Track 4"), gr.Audio(label="Track 5")]
demo1 = gr.Interface(fn=process, inputs=inputs, outputs=outputs, examples=poc_examples, cache_examples=False)
# Blocking launch; debug=True streams server logs to stdout.
demo1.launch(debug=True)
# Alternative launch with basic auth pulled from environment variables.
#demo1.launch(debug=True, enable_queue = False, auth=(os.environ['DEMO_LOGIN'], os.environ['DEMO_PWD']),auth_message = "Contact Simbals to get login/pwd")#, share=True)