rgp230 commited on
Commit
833435e
·
1 Parent(s): 8c89aac

fix(downgrade_transformers): Downgrade transformers version to bypass the errors

Browse files
src/graph/__pycache__/state_vector_nodes.cpython-312.pyc CHANGED
Binary files a/src/graph/__pycache__/state_vector_nodes.cpython-312.pyc and b/src/graph/__pycache__/state_vector_nodes.cpython-312.pyc differ
 
src/graph/state_vector_nodes.py CHANGED
@@ -48,8 +48,9 @@ class question_model:
48
  #print(state.get('seed_question').lower())
49
  predict_input = self.tokenizer(
50
  text=state.get('seed_question').lower(),
 
51
  truncation=True,
52
- padding=True,
53
  return_tensors="pt")
54
  #print(predict_input)
55
  with torch.no_grad():
 
48
  #print(state.get('seed_question').lower())
49
  predict_input = self.tokenizer(
50
  text=state.get('seed_question').lower(),
51
+ max_length=512,
52
  truncation=True,
53
+ padding='max_length',
54
  return_tensors="pt")
55
  #print(predict_input)
56
  with torch.no_grad():
src/streamlit_app.py CHANGED
@@ -15,6 +15,8 @@ import re
15
  import os
16
  import torch
17
  device=torch.get_default_device()
 
 
18
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
19
  class StreamlitConfigUI:
20
 
@@ -82,6 +84,9 @@ class LoadStreamlitUI:
82
  self.user_controls['UN SDG Country']= st.selectbox("Choose a country (start typing to search):",options=[""] + self.unsdg_countries)
83
  if self.user_controls['selected_usecase']=='DeepRishSearch':
84
  self.user_controls["TAVILY_API_KEY"] = st.session_state["TAVILY_API_KEY"]=st.text_input("Tavily API Key",type="password")
 
 
 
85
  return self.user_controls
86
  def display_result_on_ui( state_graph,mode="Question Refining Mode:"):
87
  if mode =="Question Refining Mode:":
@@ -99,8 +104,8 @@ if __name__=='__main__':
99
  user_input=ui.load_streamlit_ui()
100
  LLM_Selection=ModelSelection(user_input)
101
  if user_input["GENAI_API_KEY"]:llm=LLM_Selection.setup_llm_model()
102
- loaded_tokenizer = AutoTokenizer.from_pretrained('src/train_bert/topic_classifier_model')
103
- loaded_model = AutoModelForSequenceClassification.from_pretrained('src/train_bert/topic_classifier_model',device_map=device)
104
  df_keys=pd.read_csv('src/train_bert/training_data/Keyword_Patterns.csv')
105
 
106
  if not user_input:
 
15
  import os
16
  import torch
17
  device=torch.get_default_device()
18
+ torch.classes.__path__ = []
19
+
20
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
21
  class StreamlitConfigUI:
22
 
 
84
  self.user_controls['UN SDG Country']= st.selectbox("Choose a country (start typing to search):",options=[""] + self.unsdg_countries)
85
  if self.user_controls['selected_usecase']=='DeepRishSearch':
86
  self.user_controls["TAVILY_API_KEY"] = st.session_state["TAVILY_API_KEY"]=st.text_input("Tavily API Key",type="password")
87
+ if not self.user_controls["TAVILY_API_KEY"]:
88
+ st.warning("⚠️ Please enter a Tavily API key to proceed. Don't have? refer : https://www.tavily.com/")
89
+
90
  return self.user_controls
91
  def display_result_on_ui( state_graph,mode="Question Refining Mode:"):
92
  if mode =="Question Refining Mode:":
 
104
  user_input=ui.load_streamlit_ui()
105
  LLM_Selection=ModelSelection(user_input)
106
  if user_input["GENAI_API_KEY"]:llm=LLM_Selection.setup_llm_model()
107
+ loaded_tokenizer = AutoTokenizer.from_pretrained('src/train_bert/topic_classifier_model_test')
108
+ loaded_model = AutoModelForSequenceClassification.from_pretrained('src/train_bert/topic_classifier_model_test',device_map=device)#.to_empty(device=device)
109
  df_keys=pd.read_csv('src/train_bert/training_data/Keyword_Patterns.csv')
110
 
111
  if not user_input: