Spaces:
Sleeping
Sleeping
File size: 8,607 Bytes
713a9e8 3a6d542 c061ced 713a9e8 f232eef a4d8e6f c061ced 713a9e8 a4d8e6f 52be2ce a4d8e6f ba78ba8 8c89aac 833435e ba78ba8 713a9e8 c061ced 713a9e8 2398527 713a9e8 c061ced 713a9e8 c061ced 713a9e8 c061ced 713a9e8 9905f36 833435e 713a9e8 a4d8e6f c061ced 713a9e8 a4d8e6f e036ad8 232133b a4d8e6f 713a9e8 a4d8e6f 52be2ce 9905f36 52be2ce a4d8e6f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 | import configparser
import altair as alt
import streamlit as st
from typing import List, Optional
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from langchain_core.messages import AnyMessage, AIMessage,SystemMessage, HumanMessage,AIMessageChunk
from streamlitui.constants import unsdg_countries
from llm.llm_setup import ModelSelection
import pandas as pd
from state.state import StateVector
from graph.state_vector_nodes import question_model,research_model
from graph.graph_builder import BuildGraphOptions
import re
import os
import torch
device=torch.get_default_device()
torch.classes.__path__ = []
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class StreamlitConfigUI:
    """
    A Streamlit UI class that uses ConfigParser to load settings from an INI file.

    All settings are read from the ``[DEFAULT]`` section. List-valued settings
    are comma-separated. Whitespace around each item is stripped and empty items
    are dropped, so ``"A, B,"`` yields ``["A", "B"]``. This keeps the values
    comparable with exact string checks such as ``== 'DeepRishSearch'``.
    """

    def __init__(self, config_file: str = "src/streamlitui/uiconfigfile.ini"):
        """
        Initialize the UI class with configuration file.

        Args:
            config_file (str): Path to the INI configuration file. A missing file
                is not an error (``ConfigParser.read`` silently skips it). The
                getters then return empty values.
        """
        self.config_file = config_file
        self.config = configparser.ConfigParser()
        # Explicit UTF-8 so non-ASCII values (emoji, accents) don't depend on the
        # platform's locale encoding.
        self.config.read(config_file, encoding="utf-8")

    def _get_list(self, key: str) -> List[str]:
        """Return DEFAULT option *key* split on commas, trimmed, with empty items removed."""
        raw = self.config["DEFAULT"].get(key, "")
        return [item.strip() for item in raw.split(",") if item.strip()]

    def get_llm_options(self):
        """Return the configured LLM names (``LLM_OPTIONS``); ``[]`` if not set."""
        return self._get_list("LLM_OPTIONS")

    def get_page_title(self):
        """Return ``PAGE_TITLE``, or ``""`` when it is not configured."""
        return self.config["DEFAULT"].get("PAGE_TITLE", "")

    def get_usecase_options(self):
        """Return the configured use-case names (``USE_CASE_OPTIONS``); ``[]`` if not set."""
        return self._get_list("USE_CASE_OPTIONS")
class LoadStreamlitUI:
    """Render the app header and sidebar controls, and collect the user's choices."""

    def __init__(self):
        self.config = StreamlitConfigUI()
        # Filled by load_streamlit_ui() and returned to the caller.
        self.user_controls = {}
        self.unsdg_countries = unsdg_countries

    def filter_countries(self, query: str) -> List[str]:
        """
        Filter countries based on the query string.

        Matching is a case-insensitive substring test. Leading and trailing
        whitespace in *query* is ignored.

        Args:
            query (str): The search query (``None``/empty/blank means "no filter")

        Returns:
            List[str]: A new list of matching countries, or all countries when
            the query is empty. The shared country list is never returned
            itself, so callers may mutate the result safely.
        """
        needle = (query or "").strip().casefold()
        if not needle:
            return list(self.unsdg_countries)
        return [name for name in self.unsdg_countries if needle in name.casefold()]

    def load_streamlit_ui(self):
        """
        Render the page header and sidebar widgets.

        Returns:
            dict: ``self.user_controls`` with keys "selected_llm",
            "GENAI_API_KEY", "selected_usecase" and "UN SDG Country". The
            "TAVILY_API_KEY" key is added only when "DeepRishSearch" is selected.
        """
        # `or ""` keeps a missing PAGE_TITLE from raising TypeError on concatenation.
        title = "🇺🇳 " + (self.config.get_page_title() or "")
        # Kept as the first Streamlit call of the run; older Streamlit versions require it.
        st.set_page_config(page_title=title, layout="wide")
        st.header(title)
        with st.sidebar:
            llm_options = self.config.get_llm_options()
            usecase_options = self.config.get_usecase_options()

            self.user_controls["selected_llm"] = st.selectbox("Select LLM", llm_options)

            # Also mirrored into session_state (presumably read elsewhere — confirm).
            genai_key = st.text_input("API Key", type="password")
            st.session_state["GENAI_API_KEY"] = genai_key
            self.user_controls["GENAI_API_KEY"] = genai_key
            if not genai_key:
                st.warning("⚠️ Please enter a Gemini or Open AI ChatGPT API key to proceed. Don't have? refer : https://platform.openai.com/api-keys or https://aistudio.google.com/welcome?gclsrc=aw.ds&gad_source=1&gad_campaignid=21521909442 ")

            self.user_controls["selected_usecase"] = st.selectbox("Select Usecases", usecase_options)
            # Leading "" lets the box start with no country chosen. Unpacking works
            # for any sequence type, not only lists.
            self.user_controls['UN SDG Country'] = st.selectbox(
                "Choose a country (start typing to search):",
                options=["", *self.unsdg_countries],
            )

            if self.user_controls['selected_usecase'] == 'DeepRishSearch':
                tavily_key = st.text_input("Tavily API Key", type="password")
                st.session_state["TAVILY_API_KEY"] = tavily_key
                self.user_controls["TAVILY_API_KEY"] = tavily_key
                if not tavily_key:
                    st.warning("⚠️ Please enter a Tavily API key to proceed. Don't have? refer : https://www.tavily.com/")
        return self.user_controls
def display_result_on_ui( state_graph,mode="Question Refining Mode:"):
if mode =="Question Refining Mode:":
for event in graph.stream({'messages':("user",user_message)}):
print(event.values())
for value in event.values():
print(value['messages'])
with st.chat_message("user"):
st.write(user_message)
with st.chat_message("SDG AI Assistant"):
st.write(value["messages"].content)
# Fine-tuned topic classifier and its keyword table (paths relative to the
# directory the app is launched from).
_TOPIC_MODEL_DIR = 'src/train_bert/topic_classifier_model'
_KEYWORD_PATTERNS_CSV = 'src/train_bert/training_data/Keyword_Patterns.csv'

# Greeting shown before each answer. Sentence separators are explicit so the URL
# is never glued to the following word when rendered as markdown.
_INTRO = (
    "Hello, I am an assistant designed to help you learn about the 17 UN SDG goals "
    "listed here: https://sdgs.un.org/goals. "
    "You can ask me about any of the goals or specific topics, and I will provide you "
    "with information and resources related to them. "
    "I can also help you create a question related to the SDGs based on your input. "
    "Please provide me with a topic or question related to the SDGs and select a "
    "country, and I will do my best to assist you."
)

_SUPPORTED_USECASES = ('AskSmart SDG Assistant', 'DeepRishSearch')


def _load_topic_classifier(model_dir=_TOPIC_MODEL_DIR):
    """Load the tokenizer and sequence-classification model from *model_dir* onto ``device``."""
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir, device_map=device)
    return tokenizer, model


def _load_keyword_patterns(csv_path=_KEYWORD_PATTERNS_CSV):
    """Read the keyword-pattern CSV into a DataFrame."""
    return pd.read_csv(csv_path)


def _render_intro():
    """Show the assistant's greeting bubble."""
    with st.chat_message("assistant"):
        st.write(_INTRO)


def _stream_answer(graph, initial_input):
    """
    Run *graph* with ``stream_mode='messages'`` and render AI output
    incrementally in a single assistant bubble.
    """
    with st.chat_message("assistant"):
        placeholder = st.empty()
        accumulated = ""
        # In 'messages' mode each streamed item is a (message, metadata) pair.
        for message, _metadata in graph.stream(initial_input, stream_mode='messages'):
            # Only AI output is displayed; other message types are skipped.
            if isinstance(message, AIMessage):
                accumulated += message.content
                placeholder.write(accumulated)


if __name__ == '__main__':
    ui = LoadStreamlitUI()
    user_input = ui.load_streamlit_ui()
    # Validate before the first use of user_input (previously checked afterwards).
    if not user_input:
        st.error("Error: Failed to load user input from the UI.")
        st.stop()

    LLM_Selection = ModelSelection(user_input)
    if not user_input["GENAI_API_KEY"]:
        # The sidebar already shows a warning. Without a key there is no `llm`,
        # and continuing would fail with NameError once a message is submitted.
        st.stop()
    llm = LLM_Selection.setup_llm_model()

    # Streamlit re-executes this script on every interaction. The caches load the
    # classifier and read the CSV once instead of on every rerun. (Wrapped here,
    # not via decorators, so importing this module doesn't require Streamlit.)
    loaded_tokenizer, loaded_model = st.cache_resource(_load_topic_classifier)()
    df_keys = st.cache_data(_load_keyword_patterns)()

    # Input prompt
    user_message = st.chat_input("Ask me a question or give me keywords to kick off a query about a UN SDG Goal:")
    usecase = user_input['selected_usecase']
    country = user_input['UN SDG Country']

    if user_message and not country:
        # Previously the message was silently ignored.
        st.warning("Please choose a country in the sidebar before asking a question.")
    elif user_message and usecase in _SUPPORTED_USECASES:
        if usecase == 'DeepRishSearch' and not user_input.get('TAVILY_API_KEY'):
            # The sidebar already warns about the missing Tavily key.
            st.stop()

        SmartQuestions = question_model(
            loaded_tokenizer=loaded_tokenizer,
            loaded_model=loaded_model, llm=llm, df_keys=df_keys)
        builder = BuildGraphOptions(SmartQuestions)
        if usecase == 'AskSmart SDG Assistant':
            graph = builder.build_question_graph()
        else:
            ResearchModel = research_model(
                llm=SmartQuestions.genai_model,
                tavily_api_key=user_input['TAVILY_API_KEY'])
            graph = builder.build_research_graph(ResearchModel)

        _render_intro()
        with st.chat_message("user"):
            st.write(user_message)
        # Both graphs receive a plain dict. The research branch used to pass
        # StateVector({...}), which only works if StateVector is a TypedDict;
        # a dict works for any graph state schema.
        _stream_answer(graph, {'country': country, 'seed_question': user_message})
|