TextRetrieval / app.py
PierreHanna's picture
Update app.py
8a1dc24
raw
history blame
4.49 kB
import tempfile
import gradio as gr
import os
import tensorflow_hub as hub
import tensorflow as tf
import tensorflow_text as text
import sys
import numpy as np
import faiss
import csv
import datetime
import joblib
from huggingface_hub import hf_hub_download
# ---------------------------------------------------------------------------
# Start-up: fetch the private assets from the Hub and load them in memory.
# Repo id, file names and access token are all read from environment
# variables, so a missing variable raises KeyError at import time.
# ---------------------------------------------------------------------------

# Text-encoder model file.
encoder_text_path = hf_hub_download(
    repo_id=os.environ['REPO_ID'],
    repo_type="space",
    filename=os.environ['ENCODER_TEXT'],
    use_auth_token=os.environ['TOKEN'],
)
print("DEBUG ", encoder_text_path)

# NO GPU
# NOTE(review): tensorflow is already imported at the top of the file when
# these variables are set; TF_CPP_MIN_LOG_LEVEL in particular may need to be
# set before the import to have any effect - confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# Hide the repo name
python_path = hf_hub_download(
    repo_id=os.environ['REPO_ID'],
    repo_type="space",
    filename=os.environ['MODEL_FILE'],
    use_auth_token=os.environ['TOKEN'],
)
print(python_path)
os.system('ls -la')

# Make the private directory importable before importing the helper module.
# NOTE(review): get_models, get_catalog, get_durl, get_predict, get_url and
# TOP are not defined in this file; presumably they come from this star
# import - confirm against the `models` module.
sys.path.append(os.environ['PRIVATE_DIR'])
from models import *
preprocess_model, model = get_models()

# FAISS index file and the file holding the names of the indexed items.
index_path = hf_hub_download(
    repo_id=os.environ['REPO_ID'],
    repo_type="space",
    filename=os.environ['INDEX'],
    use_auth_token=os.environ['TOKEN'],
)
indexnames_path = hf_hub_download(
    repo_id=os.environ['REPO_ID'],
    repo_type="space",
    filename=os.environ['INDEX_NAMES'],
    use_auth_token=os.environ['TOKEN'],
)

# The catalog used to be downloaded like the other assets; it is now
# obtained through get_catalog():
# catalog_path = hf_hub_download(repo_id=os.environ['REPO_ID'], repo_type="space", filename=os.environ['CATALOG'],
#                                use_auth_token=os.environ['TOKEN'])
catalog_path = get_catalog()
url_dict = get_durl(catalog_path)

# SECURITY NOTE: joblib.load unpickles the file, which can execute arbitrary
# code. Keep this only as long as the file comes from a trusted repository.
# The `with` block closes the file handle, which the previous
# `joblib.load(open(...))` form left open.
with open(indexnames_path, 'rb') as names_file:
    audio_names = joblib.load(names_file)

index = faiss.read_index(index_path)
encoder_text = tf.keras.models.load_model(encoder_text_path)
def process(prompt, lang):
    """Return the tracks that best match a free-text description.

    The prompt is embedded with the text encoder, L2-normalised in place and
    searched against the FAISS index; the best hits are resolved through
    ``get_url`` and returned for display.

    Args:
        prompt: Text typed by the user in the "Input" textbox.
        lang: Value of the "Language" radio button. Currently unused; kept
            because the Gradio interface passes one argument per input.

    Returns:
        A list of exactly five entries, one per "Track N" audio output.
        Each entry is whatever ``get_url`` returns for the hit, or ``None``
        when the search produced fewer than five usable results.
    """
    # Must stay equal to the number of gr.Audio components in `outputs`.
    n_tracks = 5

    now = datetime.datetime.now()
    print()
    print('*************')
    print("Current Time: ", str(now))
    print("Text input : ", prompt)
    print('*************')
    print()

    # Embed text
    embed_query = get_predict(encoder_text, prompt, preprocess_model, model)
    faiss.normalize_L2(embed_query)
    print(" text embed computed.")

    # distance computing
    D, I = index.search(embed_query, TOP)

    # output : top N audio file names
    print(I)
    print(D)
    print("----")

    urls = []
    for idx, dist in zip(I[0], D[0]):
        # FAISS pads the label array with -1 when it finds fewer results than
        # requested; such an entry would otherwise index audio_names[-1].
        if idx < 0:
            continue
        print(audio_names[idx], " with distance ", dist)
        # Resolve each URL once and reuse it for both the log and the result.
        url = get_url(idx, audio_names, url_dict)
        print(" url : ", url)
        urls.append(url)

    # Always hand Gradio exactly one value per audio output, padding with
    # None instead of raising IndexError when fewer hits are available.
    top_urls = urls[:n_tracks]
    top_urls.extend([None] * (n_tracks - len(top_urls)))
    return top_urls
# Input widgets, in the order of the arguments of process(prompt, lang):
# a free-text description and a language selector (only "en" is offered).
inputs = [
    gr.Textbox(max_lines=2, value="type your description", label="Input"),
    gr.Radio(value="en", choices=["en"], label="Language"),
]
# Example rows shown under the interface: each row is [prompt, language],
# matching the two input widgets. Every example uses the "en" language.
poc_examples = [
    [example_prompt, "en"]
    for example_prompt in (
        # disabled: [["I love learning machine learning"],["autre"]]
        "Mysterious filmscore with Arabic influenced instruments",
        "Let's go on a magical adventure with wizzards, dragons and castles",
        "Creepy piano opening evolves and speeds up into a cinematic orchestral piece",
        "Halloween rock with creepy organ",
        "Rhythmic electro dance track for sport, motivation and sweating",
        "soundtrack for an action movie from the eighties in a retro synth wave style",
        "Choral female singing is rhythmically accompanied in a church with medieval instruments",
        "Christmas",
        "love romantic with piano, strings and vocals",
        "Electronic soundscapes for chilling and relaxing",
        "Minimal, emotional, melancholic piano",
        # disabled: "Minimal, happy, joyful piano"
        "A calm and romantic acoustic guitar melody",
        "horror suspense piano",
        "Big Band",
        "90 eurodance beat",
    )
]
# Also hide these texts so that the user cannot display things....
# One audio player per returned track, labelled "Track 1" to "Track 5".
outputs = [gr.Audio(label=f"Track {rank}") for rank in range(1, 6)]

# Wire process() to the widgets and start the app; example outputs are not
# pre-computed (cache_examples=False).
demo1 = gr.Interface(
    fn=process,
    inputs=inputs,
    outputs=outputs,
    examples=poc_examples,
    cache_examples=False,
)
demo1.launch(debug=True)
# Password-protected variant, kept for reference:
#demo1.launch(debug=True, enable_queue = False, auth=(os.environ['DEMO_LOGIN'], os.environ['DEMO_PWD']),auth_message = "Contact Simbals to get login/pwd")#, share=True)