PierreHanna commited on
Commit
fe3e9c5
Β·
1 Parent(s): 28b1346

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -11
app.py CHANGED
@@ -13,8 +13,7 @@ import joblib
13
 
14
  from huggingface_hub import hf_hub_download
15
 
16
- # Cacher le nom du repo
17
- encoder_text_path = hf_hub_download(repo_id="PierreHanna/TextRetrieval", repo_type="space", filename=os.environ['ENCODER_TEXT'],
18
  use_auth_token=os.environ['TOKEN'])
19
  print("DEBUG ", encoder_text_path)
20
  # NO GPU
@@ -22,18 +21,18 @@ os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
22
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
23
 
24
  # Cacher le nom du repo
25
- python_path = hf_hub_download(repo_id="PierreHanna/TextRetrieval", repo_type="space", filename="models.py", # cacher le nom du fichier !
26
  use_auth_token=os.environ['TOKEN'])
27
  print(python_path)
28
  os.system('ls -la')
29
  sys.path.append(os.environ['PRIVATE_DIR'])
30
  from models import *
31
  preprocess_model, model = get_models()
32
- index_path = hf_hub_download(repo_id="PierreHanna/TextRetrieval", repo_type="space", filename=os.environ['INDEX'],
33
  use_auth_token=os.environ['TOKEN'])
34
- indexnames_path = hf_hub_download(repo_id="PierreHanna/TextRetrieval", repo_type="space", filename=os.environ['INDEX_NAMES'],
35
  use_auth_token=os.environ['TOKEN']) #########
36
- catalog_path = hf_hub_download(repo_id="PierreHanna/TextRetrieval", repo_type="space", filename=os.environ['CATALOG'],
37
  use_auth_token=os.environ['TOKEN']) ###############
38
  url_dict=get_durl(catalog_path) ############
39
  audio_names = joblib.load(open(indexnames_path, 'rb')) ############
@@ -48,13 +47,11 @@ def process(prompt, lang):
48
  print("Text input : ", prompt)
49
  print('*************')
50
  print()
51
- prompt=[prompt]
52
- text_preprocessed = preprocess_model([np.array(prompt)])
53
- embed_prompt = model(text_preprocessed)
54
- print(" text representation computed.")
55
 
56
  # Embed text
57
- embed_query = encoder_text.predict(embed_prompt["pooled_output"]) #######
 
 
58
  faiss.normalize_L2(embed_query)
59
  print(" text embed computed.")
60
 
 
13
 
14
  from huggingface_hub import hf_hub_download
15
 
16
+ encoder_text_path = hf_hub_download(repo_id=os.environ['REPO_ID'], repo_type="space", filename=os.environ['ENCODER_TEXT'],
 
17
  use_auth_token=os.environ['TOKEN'])
18
  print("DEBUG ", encoder_text_path)
19
  # NO GPU
 
21
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
22
 
23
  # Cacher le nom du repo
24
+ python_path = hf_hub_download(repo_id=os.environ['REPO_ID'], repo_type="space", filename=os.environ['MODEL_FILE'],
25
  use_auth_token=os.environ['TOKEN'])
26
  print(python_path)
27
  os.system('ls -la')
28
  sys.path.append(os.environ['PRIVATE_DIR'])
29
  from models import *
30
  preprocess_model, model = get_models()
31
+ index_path = hf_hub_download(repo_id=os.environ['REPO_ID'], repo_type="space", filename=os.environ['INDEX'],
32
  use_auth_token=os.environ['TOKEN'])
33
+ indexnames_path = hf_hub_download(repo_id=os.environ['REPO_ID'], repo_type="space", filename=os.environ['INDEX_NAMES'],
34
  use_auth_token=os.environ['TOKEN']) #########
35
+ catalog_path = hf_hub_download(repo_id=os.environ['REPO_ID'], repo_type="space", filename=os.environ['CATALOG'],
36
  use_auth_token=os.environ['TOKEN']) ###############
37
  url_dict=get_durl(catalog_path) ############
38
  audio_names = joblib.load(open(indexnames_path, 'rb')) ############
 
47
  print("Text input : ", prompt)
48
  print('*************')
49
  print()
 
 
 
 
50
 
51
  # Embed text
52
+ #embed_query = encoder_text.predict(embed_prompt["pooled_output"]) #######
53
+ embed_query = get_predict(encoder_text, prompt, preprocess_model, model)
54
+
55
  faiss.normalize_L2(embed_query)
56
  print(" text embed computed.")
57