rajeshlion committed on
Commit
e9930be
·
verified ·
1 Parent(s): e929a77

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +167 -0
app.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
import os
import openai
from dotenv import load_dotenv
# Load API keys (e.g. OPENAI_API_KEY, OPENROUTER_API_KEY) from a local .env file.
_ = load_dotenv()  # read local .env file

import gradio as gr
from langchain_chroma import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
11
+
12
# Custom class to handle API routing for different models.
# Routes OpenAI-compatible chat requests through the OpenRouter endpoint so
# non-OpenAI models (Llama, Gemini, Claude) can be driven through the same
# ChatOpenAI interface.
class ChatOpenRouter(ChatOpenAI):
    # Pydantic field declarations; concrete values are supplied through
    # super().__init__ below.
    openai_api_base: str
    openai_api_key: str
    model_name: str

    def __init__(self,
                 model_name: str,
                 # NOTE(review): annotated str but None is accepted — the key
                 # falls back to the OPENROUTER_API_KEY env var below.
                 openai_api_key: str = None,
                 openai_api_base: str = "https://openrouter.ai/api/v1",
                 **kwargs):
        """Create a ChatOpenAI client pointed at the OpenRouter API.

        model_name: OpenRouter model id, e.g. "meta-llama/llama-3-8b-instruct".
        openai_api_key: explicit API key; defaults to OPENROUTER_API_KEY.
        openai_api_base: OpenRouter-compatible endpoint URL.
        """
        # Fall back to the environment when no key is passed explicitly.
        openai_api_key = openai_api_key or os.getenv('OPENROUTER_API_KEY')
        super().__init__(openai_api_base=openai_api_base,
                         openai_api_key=openai_api_key,
                         model_name=model_name, **kwargs)
27
+
28
# Initialize embedding function here.
# Shared OpenAI embeddings instance used by every Chroma store below; it must
# match the embeddings the persisted databases were built with.
embedding_function = OpenAIEmbeddings()
30
+
31
# Updated cbfs class with dynamic database and model selection.
# Wires a persisted Chroma vector store to an LLM-backed conversational
# retrieval chain and keeps the running chat transcript for the Gradio UI.
# (Lowercase class name kept for backward compatibility with callers.)
class cbfs:
    def __init__(self, persist_directory, model_name):
        """Build the retrieval chain for one document database / model pair.

        persist_directory: path of the persisted Chroma database to search.
        model_name: UI display name of the LLM used for answering.
        """
        self.chat_history = []   # [(question, answer), ...] fed back to the chain
        self.answer = ""         # last answer text
        self.db_query = ""       # last rewritten (generated) question
        self.db_response = []    # source documents behind the last answer
        self.panels = []         # transcript in gr.Chatbot format

        # Vector store + top-3 similarity retriever over the chosen database.
        db = Chroma(persist_directory=persist_directory, embedding_function=embedding_function)
        retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 3})

        # Display name -> (constructor, API model id). Fixes a latent bug: the
        # UI's default choice "Llama-3 70B" had no branch in the old if/elif
        # chain and only worked by falling into the else fallback; it is now
        # mapped explicitly.
        model_map = {
            "GPT-4": (ChatOpenAI, "gpt-4-1106-preview"),
            "GPT-3.5": (ChatOpenAI, "gpt-3.5-turbo-0125"),
            "Llama-3 70B": (ChatOpenRouter, "meta-llama/llama-3-70b-instruct"),
            "Llama-3 8B": (ChatOpenRouter, "meta-llama/llama-3-8b-instruct"),
            "Gemini-1.5 Pro": (ChatOpenRouter, "google/gemini-pro-1.5"),
            "Claude 3 Sonnet": (ChatOpenRouter, "anthropic/claude-3-sonnet"),
            "Claude 3.5 Sonnet": (ChatOpenRouter, "anthropic/claude-3.5-sonnet"),
        }
        # Unknown names fall back to Llama-3 70B, matching the old default.
        llm_cls, llm_id = model_map.get(
            model_name, (ChatOpenRouter, "meta-llama/llama-3-70b-instruct"))
        chosen_llm = llm_cls(model_name=llm_id, temperature=0)

        self.qa = ConversationalRetrievalChain.from_llm(
            llm=chosen_llm,
            retriever=retriever,
            return_source_documents=True,
            return_generated_question=True,
        )

    def convchain(self, query):
        """Run one chat turn and return the full transcript for gr.Chatbot."""
        if not query:
            # Nothing to ask: show an empty exchange instead of calling the chain.
            return [("User", ""), ("ChatBot", "")]
        result = self.qa.invoke({"question": query, "chat_history": self.chat_history})
        self.chat_history.append((query, result["answer"]))
        self.db_query = result["generated_question"]
        self.db_response = result["source_documents"]
        self.answer = result["answer"]
        # gr.Chatbot expects a sequence of [speaker, message] pairs.
        self.panels.append(["User", query])
        self.panels.append(["ChatBot", self.answer])
        return self.panels

    def clr_history(self):
        """Forget the conversation; returns the (now empty) transcript."""
        self.chat_history = []
        self.panels = []
        return self.panels  # Clear the chatbot display
84
+
85
# Create Gradio interface functions
def initialize_cbfs(db_choice, model_choice):
    """Initialize cbfs object based on the database and model selection and clear history."""
    # Map each selectable document set to its persisted Chroma directory.
    directories = {
        "Governance Documents": 'docs/chroma_eg/',
        "Faculty Handbook": 'docs/chroma_hb/',
    }
    persist_directory = directories.get(db_choice)
    if persist_directory is None:
        # Unrecognized (or missing) selection: no chain can be built.
        return None
    return cbfs(persist_directory=persist_directory, model_name=model_choice)
94
+
95
def chat_history(query, db_choice, model_choice, cb):
    """Handles chat submissions. Reminds the user to select a document if none is selected."""
    if cb is not None:
        # Delegate to the chain; the empty second value clears the input box.
        return cb.convchain(query), ""
    # No document database selected yet — prompt the user instead of querying.
    return [("ChatBot", "Please select a document from the dropdown menu before submitting your query.")], ""
102
+
103
def clear_history(cb):
    """Reset the stored conversation (if any) and blank the chat display and input box."""
    if cb is not None:
        cb.clr_history()
    # Both branches clear the UI components; no error message is shown.
    return [], ""
110
+
111
# Create Gradio UI layout
with gr.Blocks() as demo:
    # Full-width image at the top
    with gr.Row():
        gr.Image("isu_logo.jpg", elem_id="full_width_image", show_label=False)

    # Full-width text below the image
    with gr.Row():
        gr.Markdown("<h1 style='text-align: center; font-size: 3.5em;'>Department of Economics</h1>")

    gr.Markdown("# Faculty Policies & Rules ChatBot")

    # Selection controls: document database, model, and history reset.
    with gr.Row():
        db_choice = gr.Dropdown(["Governance Documents", "Faculty Handbook"], label="Select Document", scale=1)
        model_choice = gr.Dropdown(["GPT-3.5", "GPT-4", "Llama-3 70B", "Llama-3 8B", "Gemini-1.5 Pro", "Claude 3 Sonnet", "Claude 3.5 Sonnet"],
                                   label="Select Model", scale=1, value = "Llama-3 70B")
        button_clearhistory = gr.Button("Clear History", scale=1)

    # Query input row.
    with gr.Row():
        inp = gr.Textbox(placeholder="Enter text here…", scale=8)
        button_submit = gr.Button("Submit", scale=1)

    output = gr.Chatbot()

    # Initialize cbfs instance
    # NOTE(review): db_choice is created without a value, so db_choice.value is
    # presumably None here and initialize_cbfs returns None; chat_history
    # handles that by prompting the user to pick a document — confirm.
    cbfs_instance = gr.State(initialize_cbfs(db_choice.value, model_choice.value))

    # Update cbfs_instance and clear chat history when the dropdown values change
    def update_cbfs_and_clear_history(db_choice, model_choice):
        # Rebuild the chain for the new document/model pair and wipe the UI.
        new_cbfs = initialize_cbfs(db_choice, model_choice)
        if new_cbfs:
            new_cbfs.clr_history()
        return new_cbfs, [], ""  # Clear the chatbot display and input box

    db_choice.change(
        fn=update_cbfs_and_clear_history,
        inputs=[db_choice, model_choice],
        outputs=[cbfs_instance, output, inp]
    )

    model_choice.change(
        fn=update_cbfs_and_clear_history,
        inputs=[db_choice, model_choice],
        outputs=[cbfs_instance, output, inp]
    )

    # Define interactions for both submit button and Enter key
    inp.submit(fn=chat_history, inputs=[inp, db_choice, model_choice, cbfs_instance], outputs=[output, inp])
    button_submit.click(fn=chat_history, inputs=[inp, db_choice, model_choice, cbfs_instance], outputs=[output, inp])
    button_clearhistory.click(fn=clear_history, inputs=cbfs_instance, outputs=[output, inp])


# Launch the Gradio app
demo.launch()
166
+
167
+