# Agents_DeepSearch / src/streamlit_app.py
# Author: rgp230
# Last commit: fix(downgrade_transformers): Downgrade transformers version to
# bypass the errors (e036ad8)
import configparser
import altair as alt
import streamlit as st
from typing import List, Optional
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from langchain_core.messages import AnyMessage, AIMessage,SystemMessage, HumanMessage,AIMessageChunk
from streamlitui.constants import unsdg_countries
from llm.llm_setup import ModelSelection
import pandas as pd
from state.state import StateVector
from graph.state_vector_nodes import question_model,research_model
from graph.graph_builder import BuildGraphOptions
import re
import os
import torch
device=torch.get_default_device()
torch.classes.__path__ = []
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class StreamlitConfigUI:
    """
    A Streamlit UI class that uses ConfigParser to load settings from an INI file.

    All options live in the [DEFAULT] section:
        PAGE_TITLE       -- page/header title string
        LLM_OPTIONS      -- comma-separated LLM choices
        USE_CASE_OPTIONS -- comma-separated use-case choices
    """
    def __init__(self, config_file: str = "src/streamlitui/uiconfigfile.ini"):
        """
        Initialize the UI class with configuration file.

        Args:
            config_file (str): Path to the INI configuration file.
                Note: ConfigParser.read() silently ignores a missing file;
                the getters below then fall back to empty values.
        """
        self.config_file = config_file
        self.config = configparser.ConfigParser()
        self.config.read(config_file)

    def _split_option(self, key: str) -> List[str]:
        """Return a comma-separated [DEFAULT] option as a list of trimmed,
        non-empty strings.  Empty list when the option is missing.

        Fixes the original behavior where a missing option returned None and
        crashed with AttributeError on .split(), and where spaces after commas
        leaked into the UI choices.
        """
        raw = self.config["DEFAULT"].get(key, "")
        return [item.strip() for item in raw.split(",") if item.strip()]

    def get_llm_options(self) -> List[str]:
        """LLM choices for the sidebar selectbox."""
        return self._split_option("LLM_OPTIONS")

    def get_page_title(self) -> str:
        """Page title; empty string when PAGE_TITLE is missing (previously None)."""
        return self.config["DEFAULT"].get("PAGE_TITLE", "")

    def get_usecase_options(self) -> List[str]:
        """Use-case choices for the sidebar selectbox."""
        return self._split_option("USE_CASE_OPTIONS")
class LoadStreamlitUI:
    """Render the page header and sidebar controls, and collect the user's
    selections into a dict consumed by the __main__ script below."""
    def __init__(self):
        # INI-backed settings (page title, LLM options, use-case options).
        self.config=StreamlitConfigUI()
        # Filled by load_streamlit_ui(); keys: selected_llm, GENAI_API_KEY,
        # selected_usecase, 'UN SDG Country', and (deep-search only) TAVILY_API_KEY.
        self.user_controls={}
        # Country list imported from streamlitui.constants.
        self.unsdg_countries=unsdg_countries
    def filter_countries(self, query: str) -> List[str]:
        """
        Filter countries based on the query string.

        Args:
            query (str): The search query
        Returns:
            List[str]: Countries containing the query as a case-insensitive
            substring; the full list when the query is empty.
        """
        # NOTE(review): this helper is not called by load_streamlit_ui() —
        # the country selectbox below shows the full list. Confirm whether
        # this is dead code or used elsewhere.
        if not query:
            return self.unsdg_countries
        query_lower = query.lower()
        return [
            country for country in self.unsdg_countries
            if query_lower in country.lower()
        ]
    def load_streamlit_ui(self):
        """Draw the page config/header and the sidebar widgets; return the
        populated ``user_controls`` dict. Widget order matters in Streamlit,
        so the statement order here is part of the behavior."""
        st.set_page_config(page_title= "🇺🇳 " + self.config.get_page_title(), layout="wide")
        st.header("🇺🇳 " + self.config.get_page_title())
        with st.sidebar:
            # Get options from config
            llm_options = self.config.get_llm_options()
            usecase_options = self.config.get_usecase_options()
            # LLM selection
            self.user_controls["selected_llm"] = st.selectbox("Select LLM", llm_options)
            # Chained assignment also mirrors the key into st.session_state
            # so it survives Streamlit reruns.
            self.user_controls["GENAI_API_KEY"] = st.session_state["GENAI_API_KEY"]=st.text_input("API Key",type="password")
            if not self.user_controls["GENAI_API_KEY"]:
                st.warning("⚠️ Please enter a Gemini or Open AI ChatGPT API key to proceed. Don't have? refer : https://platform.openai.com/api-keys or https://aistudio.google.com/welcome?gclsrc=aw.ds&gad_source=1&gad_campaignid=21521909442 ")
            self.user_controls["selected_usecase"]=st.selectbox("Select Usecases",usecase_options)
            self.user_controls['UN SDG Country']= st.selectbox("Choose a country (start typing to search):",options=[""] + self.unsdg_countries)
            # Tavily key is only requested for the deep-research use case.
            # NOTE(review): 'DeepRishSearch' spelling must exactly match
            # USE_CASE_OPTIONS in the INI and the check in __main__.
            if self.user_controls['selected_usecase']=='DeepRishSearch':
                self.user_controls["TAVILY_API_KEY"] = st.session_state["TAVILY_API_KEY"]=st.text_input("Tavily API Key",type="password")
                if not self.user_controls["TAVILY_API_KEY"]:
                    st.warning("⚠️ Please enter a Tavily API key to proceed. Don't have? refer : https://www.tavily.com/")
        return self.user_controls
def display_result_on_ui(state_graph, mode="Question Refining Mode:", user_message=None):
    """Stream a graph's events to the Streamlit chat UI.

    Args:
        state_graph: compiled graph exposing ``.stream()`` (LangGraph-style).
        mode: rendering mode; only "Question Refining Mode:" is handled.
        user_message: the user's prompt. Defaults to the module-level
            ``user_message`` global for backward compatibility with the
            original implementation, which read it implicitly.
    """
    if user_message is None:
        # Backward-compat: the original read the __main__ global directly.
        user_message = globals().get("user_message")
    if mode == "Question Refining Mode:":
        # BUG FIX: the original iterated the undefined global `graph`
        # instead of the `state_graph` parameter, raising NameError unless a
        # global of that name happened to exist.
        for event in state_graph.stream({'messages': ("user", user_message)}):
            print(event.values())
            for value in event.values():
                print(value['messages'])
                with st.chat_message("user"):
                    st.write(user_message)
                with st.chat_message("SDG AI Assistant"):
                    st.write(value["messages"].content)
if __name__=='__main__':
    # --- UI bootstrap: draw sidebar, collect user selections ---------------
    ui=LoadStreamlitUI()
    user_input=ui.load_streamlit_ui()
    LLM_Selection=ModelSelection(user_input)
    # NOTE(review): when no API key is entered, `llm` is never bound, yet it
    # is passed to question_model() below — that path raises NameError.
    # Confirm question_model is only reachable with a key, or guard earlier.
    if user_input["GENAI_API_KEY"]:llm=LLM_Selection.setup_llm_model()
    # Local fine-tuned BERT topic classifier (tokenizer + weights) plus the
    # keyword-pattern table consumed by the question model.
    loaded_tokenizer = AutoTokenizer.from_pretrained('src/train_bert/topic_classifier_model')
    loaded_model = AutoModelForSequenceClassification.from_pretrained('src/train_bert/topic_classifier_model',device_map=device)#.to_empty(device=device)
    df_keys=pd.read_csv('src/train_bert/training_data/Keyword_Patterns.csv')
    if not user_input:
        st.error("Error: Failed to load user input from the UI.")
    #Input prompt
    user_message = st.chat_input("Ask me a question or give me keywords to kick off a query about a UN SDG Goal:")
    # A graph run needs both a prompt and a selected country.
    if user_message and user_input['UN SDG Country']:
        state=StateVector(country=user_input['UN SDG Country'], seed_question=user_message, messages=[])
        if user_input['selected_usecase']=='AskSmart SDG Assistant':
            # Question-refining graph: classifier + LLM, no web search.
            SmartQuestions=question_model(#StateVector=state,
                loaded_tokenizer=loaded_tokenizer,
                loaded_model=loaded_model, llm=llm, df_keys=df_keys)
            builder=BuildGraphOptions(SmartQuestions)
            graph=builder.build_question_graph()
            with st.chat_message("assistant"):
                # NOTE(review): this intro text is duplicated verbatim in the
                # DeepRishSearch branch — consider hoisting to one constant.
                intro="Hello, I am an assistant designed to help you learn about the 17 UN SDG goals listed here: https://sdgs.un.org/goals.\
You can ask me about any of the goals or specific topics, and I will provide you with information and resources related to them.\
I can also help you create a question related to the SDGs based on your input.\
Please provide me with a topic or question related to the SDGs and select a country, and I will do my best to assist you." #,
                st.write(intro)
            initial_input = {'country': user_input['UN SDG Country'], 'seed_question': user_message}
            with st.chat_message("user"):st.write(user_message)
            with st.chat_message("assistant"):
                message_placeholder = st.empty()
                accumulated_content = ""
                # Token-level streaming: append each AIMessage chunk and
                # re-render the placeholder so the answer grows live.
                for chunk in graph.stream(initial_input, stream_mode='messages'):
                    message, meta=chunk
                    if isinstance(message, AIMessage):
                        accumulated_content += message.content
                        message_placeholder.write(accumulated_content)
        elif user_input['selected_usecase']=='DeepRishSearch':
            print("Deep Rishsearch")
            # Deep-research graph: question model feeding a Tavily-backed
            # research model.
            SmartQuestions=question_model(#StateVector=state,
                loaded_tokenizer=loaded_tokenizer,
                loaded_model=loaded_model, llm=llm, df_keys=df_keys)
            builder=BuildGraphOptions(SmartQuestions)
            ResearchModel=research_model(llm=SmartQuestions.genai_model, tavily_api_key=user_input['TAVILY_API_KEY'])
            graph=builder.build_research_graph(ResearchModel)
            with st.chat_message("assistant"):
                intro="Hello, I am an assistant designed to help you learn about the 17 UN SDG goals listed here: https://sdgs.un.org/goals.\
You can ask me about any of the goals or specific topics, and I will provide you with information and resources related to them.\
I can also help you create a question related to the SDGs based on your input.\
Please provide me with a topic or question related to the SDGs and select a country, and I will do my best to assist you." #,
                st.write(intro)
            with st.chat_message("user"):st.write(user_message)
            # NOTE(review): here StateVector is called with a single
            # positional dict, while the AskSmart branch passes a plain dict
            # and the top-level `state` uses keyword args — confirm
            # StateVector accepts all three call shapes.
            initial_input = StateVector({'country': user_input['UN SDG Country'], 'seed_question': user_message})
            with st.chat_message("assistant"):
                message_placeholder = st.empty()
                accumulated_content = ""
                # Same live-streaming render loop as the AskSmart branch.
                for chunk in graph.stream(initial_input, stream_mode='messages'):
                    message, meta=chunk
                    if isinstance(message, AIMessage):
                        accumulated_content += message.content
                        message_placeholder.write(accumulated_content)