File size: 8,607 Bytes
713a9e8
3a6d542
c061ced
713a9e8
f232eef
a4d8e6f
c061ced
713a9e8
a4d8e6f
 
 
52be2ce
a4d8e6f
 
ba78ba8
8c89aac
 
833435e
 
ba78ba8
713a9e8
c061ced
713a9e8
 
 
 
2398527
713a9e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c061ced
713a9e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c061ced
713a9e8
 
 
 
c061ced
713a9e8
 
 
 
 
 
 
9905f36
 
833435e
 
 
713a9e8
a4d8e6f
 
 
 
 
 
 
 
 
 
c061ced
713a9e8
 
 
a4d8e6f
 
e036ad8
 
232133b
a4d8e6f
713a9e8
 
a4d8e6f
 
 
 
 
52be2ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9905f36
52be2ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4d8e6f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import configparser
import altair as alt
import streamlit as st
from typing import List, Optional
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from langchain_core.messages import AnyMessage, AIMessage,SystemMessage, HumanMessage,AIMessageChunk

# Project-local modules: UI constants, LLM selection, graph state and builders.
from streamlitui.constants import unsdg_countries
from llm.llm_setup import ModelSelection
import pandas as pd
from state.state import StateVector
from graph.state_vector_nodes import question_model,research_model
from graph.graph_builder import BuildGraphOptions
import re
import os
import torch
# Device used later to place the BERT classifier (see __main__ below).
device=torch.get_default_device()
# NOTE(review): clearing torch.classes.__path__ is presumably a workaround for
# Streamlit's file watcher crashing on torch's lazy class registry — confirm.
torch.classes.__path__ = []

# Silence the HuggingFace tokenizers fork-parallelism warning under Streamlit.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class StreamlitConfigUI:
    """
    A Streamlit UI class that uses ConfigParser to load settings from an INI file.

    All values are read from the [DEFAULT] section. Getters fall back to empty
    values when a key is missing instead of raising, because
    ``ConfigParser.read`` silently ignores a missing file and ``.get`` then
    returns ``None`` (the previous code crashed with AttributeError on
    ``None.split``).
    """

    def __init__(self, config_file: str = "src/streamlitui/uiconfigfile.ini"):
        """
        Initialize the UI class with configuration file.

        Args:
            config_file (str): Path to the INI configuration file
        """
        self.config_file = config_file
        self.config = configparser.ConfigParser()
        # read() does not raise on a missing file; the getters below guard
        # against absent keys so the UI degrades gracefully.
        self.config.read(config_file)

    def get_llm_options(self) -> List[str]:
        """Return LLM_OPTIONS as a list (comma-separated in the INI)."""
        raw = self.config["DEFAULT"].get("LLM_OPTIONS", "")
        return raw.split(",") if raw else []

    def get_page_title(self) -> str:
        """Return PAGE_TITLE, or an empty string when not configured."""
        return self.config["DEFAULT"].get("PAGE_TITLE", "")

    def get_usecase_options(self) -> List[str]:
        """Return USE_CASE_OPTIONS as a list (comma-separated in the INI)."""
        raw = self.config["DEFAULT"].get("USE_CASE_OPTIONS", "")
        return raw.split(",") if raw else []
    
class LoadStreamlitUI:
    """Builds the Streamlit page and sidebar, collecting selections into
    ``self.user_controls`` (returned by :meth:`load_streamlit_ui`)."""

    def __init__(self):
        # INI-backed options; user_controls accumulates widget values per rerun.
        self.config=StreamlitConfigUI()
        self.user_controls={}
        self.unsdg_countries=unsdg_countries
    def filter_countries(self, query: str) -> List[str]:
        """
        Filter countries based on the query string.
        
        Args:
            query (str): The search query
            
        Returns:
            List[str]: Filtered list of countries (case-insensitive substring
            match); the full list when the query is empty.
        """
        if not query:
            return self.unsdg_countries
        
        query_lower = query.lower()
        return [
            country for country in self.unsdg_countries 
            if query_lower in country.lower()
        ]
    
    def load_streamlit_ui(self) -> dict:
        """Render the page header and sidebar widgets; return the collected
        controls dict (selected LLM, API keys, use case, country)."""
        st.set_page_config(page_title= "🇺🇳 " + self.config.get_page_title(), layout="wide")
        st.header("🇺🇳 " + self.config.get_page_title())

        with st.sidebar:
            # Get options from config
            llm_options = self.config.get_llm_options()
            usecase_options = self.config.get_usecase_options()

            # LLM selection
            self.user_controls["selected_llm"] = st.selectbox("Select LLM", llm_options)
            # Chained assignment: mirrors the key into st.session_state as well
            # as user_controls, so other reruns/pages can read it.
            self.user_controls["GENAI_API_KEY"] = st.session_state["GENAI_API_KEY"]=st.text_input("API Key",type="password")
            if not self.user_controls["GENAI_API_KEY"]:
                st.warning("⚠️ Please enter a Gemini or Open AI ChatGPT API key to proceed. Don't have? refer : https://platform.openai.com/api-keys or https://aistudio.google.com/welcome?gclsrc=aw.ds&gad_source=1&gad_campaignid=21521909442 ")
            self.user_controls["selected_usecase"]=st.selectbox("Select Usecases",usecase_options)
            self.user_controls['UN SDG Country']= st.selectbox("Choose a country (start typing to search):",options=[""] + self.unsdg_countries)
            # NOTE(review): 'DeepRishSearch' looks like a typo for 'DeepResearch',
            # but it must match the value in USE_CASE_OPTIONS and in __main__ —
            # do not change one without the others.
            if self.user_controls['selected_usecase']=='DeepRishSearch':
                # Tavily key is only needed for the research use case.
                self.user_controls["TAVILY_API_KEY"] = st.session_state["TAVILY_API_KEY"]=st.text_input("Tavily API Key",type="password") 
                if not self.user_controls["TAVILY_API_KEY"]:
                    st.warning("⚠️ Please enter a Tavily API key to proceed. Don't have? refer : https://www.tavily.com/")

        return self.user_controls
def display_result_on_ui(state_graph, mode="Question Refining Mode:", user_message=None):
    """
    Stream graph events to the Streamlit chat UI.

    Bug fix: the original body ignored the ``state_graph`` parameter and
    streamed from the module-level global ``graph`` instead; it also read the
    global ``user_message`` implicitly. The parameter is now used, and the
    message can be passed explicitly (falling back to the module global set in
    ``__main__`` for backward compatibility).

    Args:
        state_graph: A compiled graph exposing ``.stream(inputs)``.
        mode (str): Only "Question Refining Mode:" renders anything.
        user_message (str | None): The user's prompt; defaults to the module
            global of the same name when omitted.
    """
    if mode != "Question Refining Mode:":
        return
    if user_message is None:
        # Backward-compatible fallback to the global the old code relied on.
        user_message = globals().get("user_message")
    for event in state_graph.stream({'messages': ("user", user_message)}):
        print(event.values())
        for value in event.values():
            print(value['messages'])
            with st.chat_message("user"):
                st.write(user_message)
            with st.chat_message("SDG AI Assistant"):
                st.write(value["messages"].content)

# Introduction shown once per use case. The original duplicated this text in
# both branches via backslash continuations, which leaked 28-space indentation
# runs into the user-visible message — normalized to single spaces here.
_INTRO_MESSAGE = (
    "Hello, I am an assistant designed to help you learn about the 17 UN SDG goals "
    "listed here: https://sdgs.un.org/goals. You can ask me about any of the goals "
    "or specific topics, and I will provide you with information and resources "
    "related to them. I can also help you create a question related to the SDGs "
    "based on your input. Please provide me with a topic or question related to "
    "the SDGs and select a country, and I will do my best to assist you."
)


def _show_intro():
    """Render the assistant's standing introduction in the chat pane."""
    with st.chat_message("assistant"):
        st.write(_INTRO_MESSAGE)


def _stream_assistant_reply(compiled_graph, initial_input):
    """Stream AIMessage chunks from the graph into one growing chat bubble."""
    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        accumulated_content = ""
        for chunk in compiled_graph.stream(initial_input, stream_mode='messages'):
            message, _meta = chunk
            if isinstance(message, AIMessage):
                accumulated_content += message.content
                message_placeholder.write(accumulated_content)


if __name__ == '__main__':
    ui = LoadStreamlitUI()
    user_input = ui.load_streamlit_ui()
    LLM_Selection = ModelSelection(user_input)
    # Bug fix: `llm` was only bound when an API key was present but used
    # unconditionally below, raising NameError without a key. Bind None instead.
    llm = LLM_Selection.setup_llm_model() if user_input["GENAI_API_KEY"] else None
    loaded_tokenizer = AutoTokenizer.from_pretrained('src/train_bert/topic_classifier_model')
    loaded_model = AutoModelForSequenceClassification.from_pretrained(
        'src/train_bert/topic_classifier_model', device_map=device)
    df_keys = pd.read_csv('src/train_bert/training_data/Keyword_Patterns.csv')

    if not user_input:
        st.error("Error: Failed to load user input from the UI.")
    # Input prompt
    user_message = st.chat_input("Ask me a question or give me keywords to kick off a query about a UN SDG Goal:")

    if user_message and user_input['UN SDG Country']:
        state = StateVector(country=user_input['UN SDG Country'], seed_question=user_message, messages=[])
        # Both use cases build the same question model and graph builder;
        # hoisted out of the branches (the bodies were duplicated verbatim).
        SmartQuestions = question_model(loaded_tokenizer=loaded_tokenizer,
                                        loaded_model=loaded_model, llm=llm, df_keys=df_keys)
        builder = BuildGraphOptions(SmartQuestions)

        if user_input['selected_usecase'] == 'AskSmart SDG Assistant':
            graph = builder.build_question_graph()
            _show_intro()
            initial_input = {'country': user_input['UN SDG Country'], 'seed_question': user_message}
            with st.chat_message("user"):
                st.write(user_message)
            _stream_assistant_reply(graph, initial_input)
        elif user_input['selected_usecase'] == 'DeepRishSearch':
            # NOTE(review): 'DeepRishSearch' appears to be a typo for
            # 'DeepResearch' but must match the config's USE_CASE_OPTIONS value.
            print("Deep Rishsearch")
            ResearchModel = research_model(llm=SmartQuestions.genai_model,
                                           tavily_api_key=user_input['TAVILY_API_KEY'])
            graph = builder.build_research_graph(ResearchModel)
            _show_intro()
            with st.chat_message("user"):
                st.write(user_message)
            initial_input = StateVector({'country': user_input['UN SDG Country'], 'seed_question': user_message})
            _stream_assistant_reply(graph, initial_input)