PierreHanna committed on
Commit
7bfd63b
·
1 Parent(s): b687ee3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -19
app.py CHANGED
@@ -9,6 +9,7 @@ import numpy as np
9
  import faiss
10
  import csv
11
  import datetime
 
12
 
13
  from huggingface_hub import hf_hub_download
14
  encoder_text_path = hf_hub_download(repo_id="PierreHanna/TextRetrieval", repo_type="space", filename=os.environ['ENCODER_TEXT'],
@@ -30,7 +31,31 @@ sys.path.append(os.environ['PRIVATE_DIR'])
30
  from models import *
31
 
32
  preprocess_model, model = get_models()
33
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  def process(prompt, lang):
35
 
36
  now = datetime.datetime.now()
@@ -46,30 +71,13 @@ def process(prompt, lang):
46
  print(" text representation computed.")
47
 
48
  # Embed text
49
- #from models import *
50
- encoder_text = tf.keras.models.load_model(encoder_text_path)
51
  embed_query = encoder_text.predict(embed_prompt["pooled_output"])
52
  faiss.normalize_L2(embed_query)
53
  print(" text embed computed.")
54
 
55
- # load embed audio catalog
56
- index = faiss.read_index("BMG_221022.index")
57
-
58
  # distance computing
59
  D, I = index.search(embed_query, TOP)
60
 
61
- # names index
62
- import joblib
63
- audio_names = joblib.load(open('BMG_221022_names.index', 'rb'))
64
-
65
- #url
66
- url_dict={}
67
- with open("bmg_clean.csv") as csv_file:
68
- csv_reader = csv.reader(csv_file, delimiter=';')
69
- for row in csv_reader:
70
- f = row[2].split('/')[-1]
71
- url_dict[f.split('/')[-1][:-4]] = row[2]
72
-
73
  # output : top N audio file names
74
  print(I)
75
  print(D)
@@ -78,7 +86,6 @@ def process(prompt, lang):
78
  print(audio_names[I[0][i]], " with distance ", D[0][i])
79
  print(" url : ", url_dict[audio_names[I[0][i]]])
80
 
81
-
82
  return [url_dict[audio_names[I[0][0]]], url_dict[audio_names[I[0][1]]], url_dict[audio_names[I[0][2]]], url_dict[audio_names[I[0][3]]], url_dict[audio_names[I[0][4]]]]
83
 
84
  inputs = [gr.Textbox(label="Input", value="type your description", max_lines=2),
 
9
  import faiss
10
  import csv
11
  import datetime
12
+ import joblib
13
 
14
  from huggingface_hub import hf_hub_download
15
  encoder_text_path = hf_hub_download(repo_id="PierreHanna/TextRetrieval", repo_type="space", filename=os.environ['ENCODER_TEXT'],
 
31
  from models import *
32
 
33
  preprocess_model, model = get_models()
34
+
35
+ index_path = hf_hub_download(repo_id="PierreHanna/TextRetrieval", repo_type="space", filename=os.environ['INDEX'],
36
+ use_auth_token=os.environ['TOKEN'])
37
+ indexnames_path = hf_hub_download(repo_id="PierreHanna/TextRetrieval", repo_type="space", filename=os.environ['INDEX_NAMES'],
38
+ use_auth_token=os.environ['TOKEN'])
39
+ catalog_path = hf_hub_download(repo_id="PierreHanna/TextRetrieval", repo_type="space", filename=os.environ['CATALOG'],
40
+ use_auth_token=os.environ['TOKEN'])
41
+
42
+ #url
43
+ url_dict={}
44
+ with open(catalog_path) as csv_file:
45
+ csv_reader = csv.reader(csv_file, delimiter=';')
46
+ for row in csv_reader:
47
+ f = row[2].split('/')[-1]
48
+ url_dict[f.split('/')[-1][:-4]] = row[2]
49
+
50
+ # names index
51
+ audio_names = joblib.load(open(indexnames_path, 'rb'))
52
+
53
+ # load embed audio catalog
54
+ index = faiss.read_index(index_path)
55
+
56
+ encoder_text = tf.keras.models.load_model(encoder_text_path)
57
+
58
+
59
  def process(prompt, lang):
60
 
61
  now = datetime.datetime.now()
 
71
  print(" text representation computed.")
72
 
73
  # Embed text
 
 
74
  embed_query = encoder_text.predict(embed_prompt["pooled_output"])
75
  faiss.normalize_L2(embed_query)
76
  print(" text embed computed.")
77
 
 
 
 
78
  # distance computing
79
  D, I = index.search(embed_query, TOP)
80
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  # output : top N audio file names
82
  print(I)
83
  print(D)
 
86
  print(audio_names[I[0][i]], " with distance ", D[0][i])
87
  print(" url : ", url_dict[audio_names[I[0][i]]])
88
 
 
89
  return [url_dict[audio_names[I[0][0]]], url_dict[audio_names[I[0][1]]], url_dict[audio_names[I[0][2]]], url_dict[audio_names[I[0][3]]], url_dict[audio_names[I[0][4]]]]
90
 
91
  inputs = [gr.Textbox(label="Input", value="type your description", max_lines=2),