import configparser
import logging
import os
import re
from typing import List, Optional

import altair as alt
import pandas as pd
import streamlit as st
import torch
from langchain_core.messages import AnyMessage, AIMessage, SystemMessage, HumanMessage, AIMessageChunk
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from graph.graph_builder import BuildGraphOptions
from graph.state_vector_nodes import question_model, research_model
from llm.llm_setup import ModelSelection
from state.state import StateVector
from streamlitui.constants import unsdg_countries

logger = logging.getLogger(__name__)

# Default torch device; passed as ``device_map`` when the topic classifier is loaded.
device = torch.get_default_device()
# NOTE(review): presumably a workaround for Streamlit's file watcher failing
# while introspecting ``torch.classes``; confirm before removing.
torch.classes.__path__ = []
# Disable Hugging Face tokenizers' internal parallelism (avoids its
# fork-related warning/deadlock in multi-threaded servers).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Paths (relative to the working directory) of the fine-tuned topic
# classifier and its keyword-pattern table.
TOPIC_CLASSIFIER_DIR = "src/train_bert/topic_classifier_model"
KEYWORD_PATTERNS_CSV = "src/train_bert/training_data/Keyword_Patterns.csv"

# Use-case names as they appear in the USE_CASE_OPTIONS config entry.
ASK_SMART_USECASE = "AskSmart SDG Assistant"
DEEP_SEARCH_USECASE = "DeepRishSearch"

INTRO_MESSAGE = (
    "Hello, I am an assistant designed to help you learn about the 17 UN SDG "
    "goals listed here: https://sdgs.un.org/goals. "
    "You can ask me about any of the goals or specific topics, and I will "
    "provide you with information and resources related to them. "
    "I can also help you create a question related to the SDGs based on your input. "
    "Please provide me with a topic or question related to the SDGs and select "
    "a country, and I will do my best to assist you."
)


class StreamlitConfigUI:
    """
    A Streamlit UI class that uses ConfigParser to load settings from an INI file.

    All settings are read from the ``[DEFAULT]`` section.
    """

    def __init__(self, config_file: str = "src/streamlitui/uiconfigfile.ini"):
        """
        Initialize the UI class with configuration file.

        Args:
            config_file (str): Path to the INI configuration file

        Raises:
            FileNotFoundError: If the configuration file could not be read.
        """
        self.config_file = config_file
        self.config = configparser.ConfigParser()
        # ConfigParser.read() silently skips unreadable files and returns the
        # list of files it parsed; fail fast instead of crashing later on a
        # missing key with an obscure NoneType error.
        if not self.config.read(config_file, encoding="utf-8"):
            raise FileNotFoundError(
                f"UI configuration file not found or unreadable: {config_file}"
            )

    def _require(self, key: str) -> str:
        """Return the ``[DEFAULT]`` value for *key*, raising KeyError if it is absent."""
        value = self.config["DEFAULT"].get(key)
        if value is None:
            raise KeyError(
                f"'{key}' is missing from the [DEFAULT] section of {self.config_file}"
            )
        return value

    def _require_list(self, key: str) -> List[str]:
        """Return the comma-separated ``[DEFAULT]`` value for *key* as stripped, non-empty items."""
        return [item.strip() for item in self._require(key).split(",") if item.strip()]

    def get_llm_options(self) -> List[str]:
        """Return the selectable LLM names (``LLM_OPTIONS``)."""
        return self._require_list("LLM_OPTIONS")

    def get_page_title(self) -> str:
        """Return the page title (``PAGE_TITLE``)."""
        return self._require("PAGE_TITLE")

    def get_usecase_options(self) -> List[str]:
        """Return the selectable use-case names (``USE_CASE_OPTIONS``)."""
        return self._require_list("USE_CASE_OPTIONS")


class LoadStreamlitUI:
    """Render the page header and sidebar controls and collect the user's selections."""

    def __init__(self):
        self.config = StreamlitConfigUI()
        self.user_controls = {}
        self.unsdg_countries = unsdg_countries

    def filter_countries(self, query: str) -> List[str]:
        """
        Filter countries based on the query string.

        Args:
            query (str): The search query

        Returns:
            List[str]: Filtered list of countries (all countries when query is empty)
        """
        if not query:
            return self.unsdg_countries

        query_lower = query.lower()
        return [
            country for country in self.unsdg_countries
            if query_lower in country.lower()
        ]

    def load_streamlit_ui(self):
        """
        Render the page header and sidebar widgets.

        Returns:
            dict: The user's selections under the keys "selected_llm",
            "GENAI_API_KEY", "selected_usecase", "UN SDG Country" and, only
            when the DeepRishSearch use case is selected, "TAVILY_API_KEY".
            The API keys are also mirrored into ``st.session_state``.
        """
        page_title = self.config.get_page_title()
        # set_page_config must be the first Streamlit command of the run.
        st.set_page_config(page_title="🇺🇳 " + page_title, layout="wide")
        st.header("🇺🇳 " + page_title)

        with st.sidebar:
            # Get options from config
            llm_options = self.config.get_llm_options()
            usecase_options = self.config.get_usecase_options()

            # LLM selection
            self.user_controls["selected_llm"] = st.selectbox("Select LLM", llm_options)
            self.user_controls["GENAI_API_KEY"] = st.session_state["GENAI_API_KEY"] = st.text_input(
                "API Key", type="password"
            )
            if not self.user_controls["GENAI_API_KEY"]:
                st.warning("⚠️ Please enter a Gemini or Open AI ChatGPT API key to proceed. Don't have? refer : https://platform.openai.com/api-keys or https://aistudio.google.com/welcome?gclsrc=aw.ds&gad_source=1&gad_campaignid=21521909442 ")

            self.user_controls["selected_usecase"] = st.selectbox("Select Usecases", usecase_options)
            self.user_controls['UN SDG Country'] = st.selectbox(
                "Choose a country (start typing to search):",
                options=[""] + self.unsdg_countries,
            )

            if self.user_controls['selected_usecase'] == DEEP_SEARCH_USECASE:
                self.user_controls["TAVILY_API_KEY"] = st.session_state["TAVILY_API_KEY"] = st.text_input(
                    "Tavily API Key", type="password"
                )
                if not self.user_controls["TAVILY_API_KEY"]:
                    st.warning("⚠️ Please enter a Tavily API key to proceed. Don't have? refer : https://www.tavily.com/")

        return self.user_controls


def display_result_on_ui(state_graph, mode="Question Refining Mode:", user_message: Optional[str] = None):
    """
    Stream each node output of *state_graph* for *user_message* into the chat UI.

    Args:
        state_graph: A compiled graph exposing ``stream(input)`` that yields
            ``{node_name: {"messages": ...}}`` events.
        mode (str): Only "Question Refining Mode:" is handled; other modes do nothing.
        user_message (str | None): The user's prompt. When omitted, falls back
            to a module-level ``user_message`` (what the original implementation
            read) for backward compatibility.
    """
    if mode != "Question Refining Mode:":
        return
    if user_message is None:
        user_message = globals().get("user_message")
    if not user_message:
        return

    # Show the user's prompt once, not once per node output.
    with st.chat_message("user"):
        st.write(user_message)

    for event in state_graph.stream({'messages': ("user", user_message)}):
        logger.debug("graph event: %s", event)
        for value in event.values():
            messages = value["messages"]
            # A node may return a single message or a list of messages;
            # display the most recent one.
            latest = messages[-1] if isinstance(messages, list) else messages
            with st.chat_message("SDG AI Assistant"):
                st.write(latest.content)


@st.cache_resource(show_spinner="Loading topic classifier...")
def _load_topic_classifier(model_dir: str = TOPIC_CLASSIFIER_DIR):
    """
    Load the topic-classifier tokenizer and model once per server process.

    Streamlit reruns this script on every interaction; caching avoids
    re-reading the weights from disk each time.
    NOTE(review): the cached model is shared across sessions; this assumes
    downstream code does not mutate it.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir, device_map=device)
    return tokenizer, model


@st.cache_data(show_spinner=False)
def _load_keyword_patterns(csv_path: str = KEYWORD_PATTERNS_CSV) -> pd.DataFrame:
    """Read the keyword-pattern table once; st.cache_data returns a fresh copy per call."""
    return pd.read_csv(csv_path)


def _message_text(message) -> str:
    """
    Return the text of a LangChain message (chunk).

    ``content`` may be a plain string or a list of content parts (strings,
    or dicts such as ``{"type": "text", "text": ...}``); non-text parts are skipped.
    """
    content = message.content
    if isinstance(content, str):
        return content
    parts = []
    for part in content:
        if isinstance(part, str):
            parts.append(part)
        elif isinstance(part, dict) and part.get("type", "text") == "text":
            parts.append(part.get("text", ""))
    return "".join(parts)


def _stream_graph_response(graph, initial_input) -> str:
    """
    Stream AI message chunks from *graph* into one growing chat element.

    Returns:
        str: The full accumulated response text.
    """
    message_placeholder = st.empty()
    accumulated_content = ""
    # stream_mode="messages" yields (message_chunk, metadata) tuples.
    for message, _metadata in graph.stream(initial_input, stream_mode='messages'):
        if isinstance(message, AIMessage):
            accumulated_content += _message_text(message)
            message_placeholder.write(accumulated_content)
    return accumulated_content


def _build_graph_for_usecase(user_input, user_message, llm, loaded_tokenizer, loaded_model, df_keys):
    """
    Build the graph and its initial input for the selected use case.

    Returns:
        tuple | None: ``(graph, initial_input)``, or None for an unrecognized use case.
    """
    usecase = user_input['selected_usecase']
    if usecase not in (ASK_SMART_USECASE, DEEP_SEARCH_USECASE):
        return None

    smart_questions = question_model(
        loaded_tokenizer=loaded_tokenizer,
        loaded_model=loaded_model,
        llm=llm,
        df_keys=df_keys,
    )
    builder = BuildGraphOptions(smart_questions)
    seed = {'country': user_input['UN SDG Country'], 'seed_question': user_message}

    if usecase == ASK_SMART_USECASE:
        return builder.build_question_graph(), seed

    logger.debug("Running the DeepRishSearch research graph")
    research = research_model(
        llm=smart_questions.genai_model,
        tavily_api_key=user_input['TAVILY_API_KEY'],
    )
    return builder.build_research_graph(research), StateVector(seed)


if __name__ == '__main__':
    ui = LoadStreamlitUI()
    user_input = ui.load_streamlit_ui()
    # Validate before anything reads from user_input.
    if not user_input:
        st.error("Error: Failed to load user input from the UI.")
        st.stop()

    LLM_Selection = ModelSelection(user_input)
    # Without an API key there is no LLM: keep the chat box visible but refuse
    # to run a query (previously this raised NameError on `llm`).
    llm = LLM_Selection.setup_llm_model() if user_input["GENAI_API_KEY"] else None
    loaded_tokenizer, loaded_model = _load_topic_classifier()
    df_keys = _load_keyword_patterns()

    # Input prompt
    user_message = st.chat_input("Ask me a question or give me keywords to kick off a query about a UN SDG Goal:")

    if user_message:
        if not user_input['UN SDG Country']:
            st.warning("Please choose a country in the sidebar before asking a question.")
        elif llm is None:
            st.error("Please enter a Gemini or Open AI ChatGPT API key in the sidebar to proceed.")
        elif (user_input['selected_usecase'] == DEEP_SEARCH_USECASE
              and not user_input.get('TAVILY_API_KEY')):
            st.error("Please enter a Tavily API key in the sidebar to use DeepRishSearch.")
        else:
            built = _build_graph_for_usecase(
                user_input, user_message, llm, loaded_tokenizer, loaded_model, df_keys
            )
            if built is not None:
                graph, initial_input = built
                with st.chat_message("assistant"):
                    st.write(INTRO_MESSAGE)
                with st.chat_message("user"):
                    st.write(user_message)
                with st.chat_message("assistant"):
                    _stream_graph_response(graph, initial_input)