# TextRetrieval_A / app.py
# Last commit: eb392be "Update app.py" (verified) by PierreHanna — 5.23 kB.
import tempfile
import gradio as gr
import os
import tensorflow as tf
import sys
import numpy as np
import csv
import datetime, time
import joblib
from huggingface_hub import hf_hub_download
# NO GPU: hide every CUDA device so TensorFlow runs on CPU only, and raise
# TensorFlow's C++ log threshold to '3' to silence its native logging.
# NOTE(review): tensorflow is imported above, before these variables are set.
# CUDA_VISIBLE_DEVICES generally still applies as long as no GPU has been
# initialised yet, but anything logged during `import tensorflow` itself is
# not filtered — consider moving both assignments above the import.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
# NOTE(review): not referenced anywhere in this file; the search depth passed
# to get_distance() is TOP (presumably defined in `models`) — TODO confirm
# whether this constant is still needed or was meant to replace TOP.
max_results = 100
# Number of ranked results process() returns; also sizes the output
# components of the interface.
max_output = 5
# Hide the repo name: the Space id, file name and access token are read from
# environment variables (presumably Space secrets) so none of them appear in
# this public source.
# NOTE(review): newer huggingface_hub releases deprecate `use_auth_token` in
# favour of `token`; switch once the pinned version supports it.
python_path = hf_hub_download(repo_id=os.environ['REPO_ID'], repo_type="space", filename=os.environ['MODEL_FILE'],
use_auth_token=os.environ['TOKEN'])
print(python_path)
# Make PRIVATE_DIR importable so the private `models` module can be loaded.
# Presumably the file downloaded above is (or lives in) PRIVATE_DIR — TODO confirm.
sys.path.append(os.environ['PRIVATE_DIR'])
# Star import from the private `models` module. It presumably supplies every
# helper used below and in process() that this file does not define:
# get_models, get_durl_myma, get_audio_names_pickle, get_index, get_predict,
# do_normalize, get_distance, get_url_myma and the constant TOP.
from models import *
# Load all retrieval resources once at start-up so every query reuses them.
preprocess_model, model = get_models()
url_dict = get_durl_myma()
# Previous loader, kept for reference; the pickled variant below is used.
#audio_names = get_audio_names()
audio_names = get_audio_names_pickle()
index = get_index()
# The text encoder is loaded from a path relative to the current working
# directory (presumably a SavedModel directory) instead of get_encoder_text().
#encoder_text = get_encoder_text()
encoder_text = tf.keras.models.load_model("encoder_text_retrievaltext_bmg_221022_54_clean")
def _write_results_csv(prompt, results):
    """Write ranked search results to a CSV file in a fresh temporary directory.

    The file name is derived from ``prompt``, but only alphanumerics, ``-`` and
    ``_`` are kept (everything else becomes ``_``) and it is truncated. This
    stops user input from putting path separators or ``..`` into the path. Each
    call gets its own temporary directory, so concurrent requests with the same
    prompt cannot overwrite each other's file.

    Args:
        prompt: The user's query text (untrusted).
        results: Sequence of ``(track_name, distance, url)`` tuples in rank order.

    Returns:
        The path of the written CSV file.
    """
    stem = "".join(c if c.isalnum() or c in "-_" else "_" for c in prompt)
    stem = stem[:64].strip("_") or "query"
    path = os.path.join(tempfile.mkdtemp(), f"{stem}_results_text.csv")
    # The csv module requires newline="" on the file object; without it,
    # rows get an extra blank line on Windows.
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["rank", "track", "distance", "url"])
        for rank, (name, distance, url) in enumerate(results, start=1):
            writer.writerow([rank, name, distance, url])
    return path


def process(prompt, lang):
    """Run a text-to-track retrieval query and format the results for Gradio.

    The prompt is embedded with ``get_predict`` using the models loaded at
    start-up. The embedding is passed to ``do_normalize``, and the nearest
    entries are looked up in ``index`` with ``get_distance``.

    Args:
        prompt: Free-text description typed by the user.
        lang: Language chosen in the UI. It is not used, because only "en" is
            offered; it is kept to match the Gradio inputs.

    Returns:
        A list of ``1 + 2 * max_output`` values in the order of the
        interface's output components:
        - the path of a CSV file listing the results;
        - for each rank, the track name (file name up to the first ``.``) and
          the value returned by ``get_url_myma``.

        Missing ranks are padded with ``""`` / ``None``, so the length always
        matches the number of output components.
    """
    now = datetime.datetime.now()
    print()
    print('*************')
    print("Current Time: ", str(now))
    print("Text input : ", prompt)
    print('*************')
    print()
    start = time.time()

    embed_query = get_predict(encoder_text, prompt, preprocess_model, model)
    print("Embed time : ", time.time() - start)
    # The return value is ignored, so do_normalize is assumed to normalise
    # embed_query in place — TODO confirm.
    do_normalize(embed_query)
    D, I = get_distance(index, embed_query, TOP)
    print("Search + Embed time : ", time.time() - start)

    results = []
    for idx, distance in zip(I[0][:max_output], D[0][:max_output]):
        # A negative id is never a valid position in audio_names; Python
        # indexing would silently wrap it to the end of the list. Some search
        # backends use -1 for "no result", so skip it.
        if idx < 0:
            continue
        name = audio_names[idx].split('.')[0]
        results.append((name, distance, get_url_myma(idx, audio_names, url_dict)))

    out = [_write_results_csv(prompt, results)]
    for name, _, url in results:
        out.extend((name, url))
    # Pad so Gradio always receives exactly one value per output component,
    # even when the search returned fewer than max_output hits.
    out.extend(("", None) * (max_output - len(results)))
    print("Total time : ", time.time() - start)
    return out
# Gradio input components: the free-text query and the query language (only
# English is offered). The hint is a placeholder, not a pre-filled value, so
# it is never submitted as the query itself.
inputs = [
    gr.Textbox(label="Input", placeholder="type your description", max_lines=2),
    gr.Radio(label="Language", choices=["en"], value="en"),
]
# Example queries listed under the interface; each is paired with the only
# supported language, "en", to match the two input components.
_EXAMPLE_PROMPTS = [
    "Mysterious filmscore with Arabic influenced instruments",
    "Let's go on a magical adventure with wizards, dragons and castles",
    "Creepy piano opening evolves and speeds up into a cinematic orchestral piece",
    "Chilled electronic",
    "Relax piano",
    "Halloween rock with creepy organ",
    "Rhythmic electro dance track for sport, motivation and sweating",
    "soundtrack for an action movie from the eighties in a retro synth wave style",
    "Choral female singing is rhythmically accompanied in a church with medieval instruments",
    "Christmas",
    "love romantic with piano, strings and vocals",
    "Electronic soundscapes for chilling and relaxing",
    "Minimal, emotional, melancholic piano",
    "A calm and romantic acoustic guitar melody",
    "horror suspense piano",
    "Big Band",
    "90 eurodance beat",
]
poc_examples = [[example, "en"] for example in _EXAMPLE_PROMPTS]
# Gradio output components, in the order process() returns its values:
# - the results CSV file;
# - one (track name, audio player) pair per rank.
# Ranks are labelled from 1 ("top1" ... "topN") rather than from 0.
outputs = [gr.File()]
for rank in range(1, max_output + 1):
    outputs.append(gr.Textbox(label=f"top{rank} track name", show_label=True))
    outputs.append(gr.Audio(label=f"top{rank}", show_label=False))
# Wire process() to the input/output components and launch the web app.
# Examples are not cached and are shown 20 per page.
demo1 = gr.Interface(
    fn=process,
    inputs=inputs,
    outputs=outputs,
    examples=poc_examples,
    examples_per_page=20,
    cache_examples=False,
)
demo1.launch(debug=False)