rgp230 committed on
Commit
f232eef
·
1 Parent(s): 34f7bc7

fix(remove_tf): Unblock container build by removing tf dependency

Browse files
requirements.txt CHANGED
@@ -9,8 +9,6 @@ langchain_huggingface
9
  langgraph-prebuilt
10
  streamlit
11
  transformers[torch]
12
- tensorflow
13
- tf-keras
14
  langchain_openai
15
  langchain_google_genai
16
  torch
 
9
  langgraph-prebuilt
10
  streamlit
11
  transformers[torch]
 
 
12
  langchain_openai
13
  langchain_google_genai
14
  torch
src/graph/__pycache__/state_vector_nodes.cpython-312.pyc CHANGED
Binary files a/src/graph/__pycache__/state_vector_nodes.cpython-312.pyc and b/src/graph/__pycache__/state_vector_nodes.cpython-312.pyc differ
 
src/graph/state_vector_nodes.py CHANGED
@@ -16,7 +16,7 @@ from langchain_community.tools.tavily_search import TavilySearchResults
16
  import pandas as pd
17
  import torch.nn.functional as F
18
 
19
-
20
  class question_model:
21
  def __init__(self,loaded_tokenizer,loaded_model, llm, df_keys):
22
  #self.state=StateVector
@@ -45,43 +45,49 @@ class question_model:
45
  #print(state)
46
  if not state.get('seed_question') or len(state.get('seed_question').strip())<3:
47
  raise ValueError("Seed question is not set in the state vector.")
48
- predict_input = self.tokenizer.encode(
 
49
  text=state.get('seed_question').lower(),
50
  truncation=True,
51
  padding=True,
52
  return_tensors="pt")
53
- output = self.distilbert_model(predict_input.numpy())[0]
54
- numpy_output=output.numpy()
55
- torch_output=torch.from_numpy(numpy_output)
56
- prediction_value = torch.argmax(torch_output, dim=1).numpy() # All answers
57
- prob_value=F.softmax(torch_output).numpy()[0]
 
 
 
 
 
58
 
59
- #prob_value = F.softmax(output, dim=1).cpu().numpy()[0]
60
- #prediction_value = tf.argmax(output, axis=1).numpy()#All answers
61
- #prob_value=tf.nn.softmax(output).numpy()[0]#Probability of TF output
62
- Topic_Bool=prob_value>0.4
63
- Topics=[]
64
- Keywords={}
65
- for index, key in enumerate(sdg_goals):
66
- if not Topic_Bool[index]:continue
67
- #print(sdg_goals[key])
68
- Topics.append((index+1,sdg_goals[key]))
69
- #print(Topics)
70
- for i,t in Topics:
71
- kw_patterns=self.df_keys[self.df_keys['topic_num']==i]['keywords'].values[0].split(',')
72
- Keywords[t] = re.findall(r'%s' %("|".join(kw_patterns)),state['seed_question'])
73
- if not Keywords[t]:
74
- Keywords[t] = kw_patterns
75
- state['messages'].append(AIMessage(content="Will add keywords for the topic: %s \n" % t ))
76
- state['topic'] = Topics
77
- state['topic_kw'] = Keywords
78
- if not state.get('country'):
79
- state['messages'].append(AIMessage(content="Country is not set. Please provide a country. \n"))
80
- return state
81
- elif not state.get('topic'):
82
- state['messages'].append(AIMessage(content="Missing topic please ask a question about the 17 Sustainable Development Goals. Graph will terminate. \n"))
83
- state['messages'].append(AIMessage(content="Topics are: %s and keywords found: %s.\n Proceeding to prompt creation. \n" \
84
- %(", ".join(Keywords.keys()), ", ".join([kw for kws in Keywords.values() for kw in kws]))))
85
  return state
86
 
87
  def should_continue(self, state:StateVector) -> str:
 
16
  import pandas as pd
17
  import torch.nn.functional as F
18
 
19
+ torch.classes.__path__ = []
20
  class question_model:
21
  def __init__(self,loaded_tokenizer,loaded_model, llm, df_keys):
22
  #self.state=StateVector
 
45
  #print(state)
46
  if not state.get('seed_question') or len(state.get('seed_question').strip())<3:
47
  raise ValueError("Seed question is not set in the state vector.")
48
+ #print(state.get('seed_question').lower())
49
+ predict_input = self.tokenizer(
50
  text=state.get('seed_question').lower(),
51
  truncation=True,
52
  padding=True,
53
  return_tensors="pt")
54
+ #print(predict_input)
55
+ with torch.no_grad():
56
+ logits = self.distilbert_model(**predict_input).logits
57
+ #print(logits)
58
+ #output = self.distilbert_model(predict_input.numpy())[0]
59
+ #print(output)
60
+ #numpy_output=output.numpy()
61
+ #torch_output=torch.from_numpy(numpy_output)
62
+ #prediction_value = torch.argmax(torch_output, dim=1).numpy() # All answers
63
+ prob_value=F.softmax(logits, dim=1).cpu().numpy()[0]
64
 
65
+ #prob_value = F.softmax(output, dim=1).cpu().numpy()[0]
66
+ #prediction_value = tf.argmax(output, axis=1).numpy()#All answers
67
+ #prob_value=tf.nn.softmax(output).numpy()[0]#Probability of TF output
68
+ Topic_Bool=prob_value>0.4
69
+ Topics=[]
70
+ Keywords={}
71
+ for index, key in enumerate(sdg_goals):
72
+ if not Topic_Bool[index]:continue
73
+ #print(sdg_goals[key])
74
+ Topics.append((index+1,sdg_goals[key]))
75
+ #print(Topics)
76
+ for i,t in Topics:
77
+ kw_patterns=self.df_keys[self.df_keys['topic_num']==i]['keywords'].values[0].split(',')
78
+ Keywords[t] = re.findall(r'%s' %("|".join(kw_patterns)),state['seed_question'])
79
+ if not Keywords[t]:
80
+ Keywords[t] = kw_patterns
81
+ state['messages'].append(AIMessage(content="Will add keywords for the topic: %s \n" % t ))
82
+ state['topic'] = Topics
83
+ state['topic_kw'] = Keywords
84
+ if not state.get('country'):
85
+ state['messages'].append(AIMessage(content="Country is not set. Please provide a country. \n"))
86
+ return state
87
+ elif not state.get('topic'):
88
+ state['messages'].append(AIMessage(content="Missing topic please ask a question about the 17 Sustainable Development Goals. Graph will terminate. \n"))
89
+ state['messages'].append(AIMessage(content="Topics are: %s and keywords found: %s.\n Proceeding to prompt creation. \n" \
90
+ %(", ".join(Keywords.keys()), ", ".join([kw for kws in Keywords.values() for kw in kws]))))
91
  return state
92
 
93
  def should_continue(self, state:StateVector) -> str:
src/state/__pycache__/state.cpython-312.pyc CHANGED
Binary files a/src/state/__pycache__/state.cpython-312.pyc and b/src/state/__pycache__/state.cpython-312.pyc differ
 
src/streamlit_app.py CHANGED
@@ -2,10 +2,9 @@ import configparser
2
  import altair as alt
3
  import streamlit as st
4
  from typing import List, Optional
5
- from transformers import DistilBertTokenizerFast, TFDistilBertForSequenceClassification
6
  from langchain_core.messages import AnyMessage, AIMessage,SystemMessage, HumanMessage,AIMessageChunk
7
 
8
-
9
  from streamlitui.constants import unsdg_countries
10
  from llm.llm_setup import ModelSelection
11
  import pandas as pd
@@ -94,8 +93,8 @@ if __name__=='__main__':
94
  user_input=ui.load_streamlit_ui()
95
  LLM_Selection=ModelSelection(user_input)
96
  if user_input["GENAI_API_KEY"]:llm=LLM_Selection.setup_llm_model()
97
- loaded_tokenizer = DistilBertTokenizerFast.from_pretrained('src/train_bert/topic_classifier_model')
98
- loaded_model = TFDistilBertForSequenceClassification.from_pretrained('src/train_bert/topic_classifier_model')
99
  df_keys=pd.read_csv('src/train_bert/training_data/Keyword_Patterns.csv')
100
 
101
  if not user_input:
 
2
  import altair as alt
3
  import streamlit as st
4
  from typing import List, Optional
5
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
  from langchain_core.messages import AnyMessage, AIMessage,SystemMessage, HumanMessage,AIMessageChunk
7
 
 
8
  from streamlitui.constants import unsdg_countries
9
  from llm.llm_setup import ModelSelection
10
  import pandas as pd
 
93
  user_input=ui.load_streamlit_ui()
94
  LLM_Selection=ModelSelection(user_input)
95
  if user_input["GENAI_API_KEY"]:llm=LLM_Selection.setup_llm_model()
96
+ loaded_tokenizer = AutoTokenizer.from_pretrained('src/train_bert/topic_classifier_model')
97
+ loaded_model = AutoModelForSequenceClassification.from_pretrained('src/train_bert/topic_classifier_model')
98
  df_keys=pd.read_csv('src/train_bert/training_data/Keyword_Patterns.csv')
99
 
100
  if not user_input: