PierreHanna commited on
Commit
3b76b78
Β·
1 Parent(s): 34e71ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -11,7 +11,7 @@ import csv
11
  import datetime
12
 
13
  from huggingface_hub import hf_hub_download
14
- file_path = hf_hub_download(repo_id="PierreHanna/TextRetrieval", repo_type="space", filename="encoder_text_retrievaltext_bmg_221022_54.h5",
15
  use_auth_token=os.environ['TOKEN'])
16
 
17
  #use_auth_token="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")
@@ -42,6 +42,8 @@ def make_bert_preprocess_model(sentence_features, tfhub_handle_preprocess, seq_l
42
  model_inputs = packer(truncated_segments)
43
  return tf.keras.Model(input_segments, model_inputs)
44
 
 
 
45
  from models import *
46
 
47
  def process(prompt, lang):
@@ -76,8 +78,7 @@ def process(prompt, lang):
76
 
77
  # Embed text
78
  #from models import *
79
- #encoder_text = tf.keras.models.load_model('encoder_text_retrievaltext_bmg_221022_54.h5')
80
- encoder_text = tf.keras.models.load_model(file_path)
81
  embed_query = encoder_text.predict(embed_prompt["pooled_output"])
82
  faiss.normalize_L2(embed_query)
83
  print(" text embed computed.")
 
11
  import datetime
12
 
13
  from huggingface_hub import hf_hub_download
14
+ encoder_text_path = hf_hub_download(repo_id="PierreHanna/TextRetrieval", repo_type="space", filename="encoder_text_retrievaltext_bmg_221022_54.h5",
15
  use_auth_token=os.environ['TOKEN'])
16
 
17
  #use_auth_token="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")
 
42
  model_inputs = packer(truncated_segments)
43
  return tf.keras.Model(input_segments, model_inputs)
44
 
45
+ python_path = hf_hub_download(repo_id="PierreHanna/TextRetrieval", repo_type="space", filename="models.py",
46
+ use_auth_token=os.environ['TOKEN'])
47
  from models import *
48
 
49
  def process(prompt, lang):
 
78
 
79
  # Embed text
80
  #from models import *
81
+ encoder_text = tf.keras.models.load_model(encoder_text_path)
 
82
  embed_query = encoder_text.predict(embed_prompt["pooled_output"])
83
  faiss.normalize_L2(embed_query)
84
  print(" text embed computed.")