Update src/prediction_compile.py

#1
Files changed (1) hide show
  1. src/prediction_compile.py +14 -1
src/prediction_compile.py CHANGED
@@ -5,6 +5,8 @@ import pickle
5
  import joblib
6
  import nltk
7
  import os
 
 
8
  import numpy as np
9
  import pandas as pd
10
  from tensorflow.keras.preprocessing.sequence import pad_sequences
@@ -33,8 +35,19 @@ st.markdown(
33
  def patch_transformers_pickle_compat() -> None:
34
  try:
35
  from transformers.models.bert import modeling_bert
 
 
 
36
  if not hasattr(modeling_bert, "BertSdpaSelfAttention"):
37
  modeling_bert.BertSdpaSelfAttention = modeling_bert.BertSelfAttention
 
 
 
 
 
 
 
 
38
  except Exception:
39
  pass
40
 
@@ -56,7 +69,7 @@ def load_tokenizer_params():
56
 
57
  @st.cache_resource
58
  def load_topic_models():
59
- patch_transformers_pickle_compat() # must run before joblib.load
60
  neg_path = "./src/fastopic_negative_model_10.pkl"
61
  pos_path = "./src/fastopic_positive_model_10.pkl"
62
  neg_model = joblib.load(neg_path)
 
5
  import joblib
6
  import nltk
7
  import os
8
+ import sys
9
+ import types
10
  import numpy as np
11
  import pandas as pd
12
  from tensorflow.keras.preprocessing.sequence import pad_sequences
 
35
  def patch_transformers_pickle_compat() -> None:
36
  try:
37
  from transformers.models.bert import modeling_bert
38
+ from transformers import BertTokenizerFast
39
+
40
+ # Fix old attention symbol
41
  if not hasattr(modeling_bert, "BertSdpaSelfAttention"):
42
  modeling_bert.BertSdpaSelfAttention = modeling_bert.BertSelfAttention
43
+
44
+ # Fix old tokenizer module path
45
+ module_name = "transformers.models.bert.tokenization_bert_fast"
46
+ if module_name not in sys.modules:
47
+ shim = types.ModuleType(module_name)
48
+ shim.BertTokenizerFast = BertTokenizerFast
49
+ sys.modules[module_name] = shim
50
+
51
  except Exception:
52
  pass
53
 
 
69
 
70
  @st.cache_resource
71
  def load_topic_models():
72
+ patch_transformers_pickle_compat()
73
  neg_path = "./src/fastopic_negative_model_10.pkl"
74
  pos_path = "./src/fastopic_positive_model_10.pkl"
75
  neg_model = joblib.load(neg_path)