rajeshlion committed on
Commit
428d7b8
·
verified ·
1 Parent(s): 83af442

Create app_gpu2.py

Browse files
Files changed (1) hide show
  1. app_gpu2.py +150 -0
app_gpu2.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import openai
3
+ from dotenv import load_dotenv
4
+ _ = load_dotenv() # read local .env file
5
+
6
+ import gradio as gr
7
+ from langchain_chroma import Chroma
8
+ from langchain.chains import ConversationalRetrievalChain
9
+ from langchain_openai import OpenAIEmbeddings, ChatOpenAI
10
+ from spaces import GPU # Import GPU decorator
11
+
12
# Custom class to route OpenAI-compatible chat requests through OpenRouter,
# which serves non-OpenAI models (Llama, Gemini, Claude) behind one endpoint.
class ChatOpenRouter(ChatOpenAI):
    """ChatOpenAI subclass pointed at the OpenRouter API.

    Re-declares the connection fields so they are explicit attributes of
    this subclass (pydantic model fields on ChatOpenAI).
    """

    openai_api_base: str
    openai_api_key: str
    model_name: str

    def __init__(self,
                 model_name: str,
                 openai_api_key: str | None = None,  # fix: default None needs an Optional annotation
                 openai_api_base: str = "https://openrouter.ai/api/v1",
                 **kwargs):
        """
        Args:
            model_name: OpenRouter model identifier, e.g. "meta-llama/llama-3-8b-instruct".
            openai_api_key: API key; falls back to the OPENROUTER_API_KEY env var.
            openai_api_base: OpenRouter's OpenAI-compatible endpoint.
            **kwargs: Forwarded unchanged to ChatOpenAI (temperature, etc.).
        """
        # Fall back to the environment so callers don't have to thread the key through.
        openai_api_key = openai_api_key or os.getenv('OPENROUTER_API_KEY')
        super().__init__(openai_api_base=openai_api_base,
                         openai_api_key=openai_api_key,
                         model_name=model_name,
                         **kwargs)
21
+
22
# Shared embedding model used by every Chroma vector store below.
# OpenAIEmbeddings reads OPENAI_API_KEY from the environment (loaded via .env above).
embedding_function = OpenAIEmbeddings()
24
+
25
# Updated cbfs class with dynamic database and model selection
@GPU  # NOTE(review): spaces.GPU normally decorates functions — confirm it is valid on a class.
class cbfs:
    """Retrieval-augmented chatbot over a Chroma store with a selectable LLM.

    Holds per-session conversation state (chat history, display panels) and a
    ConversationalRetrievalChain built from the chosen database and model.
    """

    # OpenAI-hosted models are called directly via ChatOpenAI.
    _OPENAI_MODELS = {
        "GPT-4": "gpt-4-1106-preview",
        "GPT-3.5": "gpt-3.5-turbo-0125",
    }
    # Everything else is routed through OpenRouter. Fix: "Llama-3 70B" (the UI's
    # default dropdown value) previously had no explicit branch and was handled
    # only by the silent else-fallback; it is now listed explicitly.
    _OPENROUTER_MODELS = {
        "Llama-3 70B": "meta-llama/llama-3-70b-instruct",
        "Llama-3 8B": "meta-llama/llama-3-8b-instruct",
        "Gemini-1.5 Pro": "google/gemini-pro-1.5",
        "Claude 3 Sonnet": "anthropic/claude-3-sonnet",
        "Claude 3.5 Sonnet": "anthropic/claude-3.5-sonnet",
    }

    def __init__(self, persist_directory, model_name):
        """
        Args:
            persist_directory: Path of the persisted Chroma database to query.
            model_name: Human-readable model label from the UI dropdown;
                unknown labels fall back to Llama-3 70B via OpenRouter
                (same behavior as the original else branch).
        """
        self.chat_history = []  # (question, answer) tuples fed back into the chain
        self.answer = ""
        self.db_query = ""      # last question as rewritten by the chain
        self.db_response = []   # source documents returned for the last answer
        self.panels = []        # [speaker, text] pairs shown in the Gradio Chatbot
        # Initialize Chroma and the retriever over the chosen database.
        db = Chroma(persist_directory=persist_directory, embedding_function=embedding_function)
        retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 3})

        # Select model dynamically via the lookup tables above.
        if model_name in self._OPENAI_MODELS:
            chosen_llm = ChatOpenAI(model_name=self._OPENAI_MODELS[model_name], temperature=0)
        else:
            router_model = self._OPENROUTER_MODELS.get(
                model_name, "meta-llama/llama-3-70b-instruct")  # default model
            chosen_llm = ChatOpenRouter(model_name=router_model, temperature=0)

        self.qa = ConversationalRetrievalChain.from_llm(
            llm=chosen_llm,
            retriever=retriever,
            return_source_documents=True,
            return_generated_question=True,
        )

    def convchain(self, query):
        """Run one conversational turn and return the accumulated chat panels."""
        if not query:
            # Empty submission: show an empty exchange rather than calling the chain.
            return [("User", ""), ("ChatBot", "")]
        result = self.qa.invoke({"question": query, "chat_history": self.chat_history})
        self.chat_history.append((query, result["answer"]))
        self.db_query = result["generated_question"]
        self.db_response = result["source_documents"]
        self.answer = result['answer']
        self.panels.append(["User", query])           # list of two strings, as Gradio expects
        self.panels.append(["ChatBot", self.answer])  # list of two strings, as Gradio expects
        return self.panels

    def clr_history(self):
        """Reset conversation state; return the (now empty) panel list."""
        self.chat_history = []
        self.panels = []
        return self.panels  # Clear the chatbot display
78
+
79
+
80
# Create Gradio interface functions
cbfs_instances = {}  # Cache of cbfs objects keyed by (db_choice, model_choice)

def initialize_cbfs(db_choice, model_choice):
    """Return the cached cbfs for this (database, model) pair, creating it on first use.

    Returns:
        The cbfs instance, or None when db_choice names no known document set
        (e.g. the dropdown has not been selected yet), so callers can prompt
        the user instead of crashing.
    """
    key = (db_choice, model_choice)
    if key not in cbfs_instances:
        if db_choice == "Governance Documents":
            cbfs_instances[key] = cbfs(persist_directory='docs/chroma_eg/', model_name=model_choice)
        elif db_choice == "Faculty Handbook":
            cbfs_instances[key] = cbfs(persist_directory='docs/chroma_hb/', model_name=model_choice)
    # Bug fix: the original returned cbfs_instances[key] unconditionally, which
    # raised KeyError for an unknown/unselected db_choice; .get() yields None,
    # which chat_history already checks for.
    return cbfs_instances.get(key)
92
+
93
def chat_history(query, db_choice, model_choice):
    """Handle a chat submission from the Submit button or the Enter key.

    Returns:
        A (chatbot_panels, textbox_value) pair; the empty string clears the
        Gradio input box. When no document database can be resolved, the
        panels contain a reminder to pick a document first.
    """
    cb = initialize_cbfs(db_choice, model_choice)
    if cb:
        # Normal path: run the turn and clear the input box.
        return cb.convchain(query), ""
    # No document selected yet — prompt the user instead of querying.
    reminder = ("ChatBot",
                "Please select a document from the dropdown menu before submitting your query.")
    return [reminder], ""
100
+
101
def clear_history(db_choice, model_choice):
    """Clear the conversation for the current (database, model) selection.

    Also wired to the dropdowns' .change events, which can fire before a
    valid document is chosen.

    Returns:
        ([], "") to empty the Chatbot display and the input textbox.
    """
    cb = initialize_cbfs(db_choice, model_choice)
    # Bug fix: initialize_cbfs yields None until a valid document is selected;
    # the original called cb.clr_history() unconditionally and raised
    # AttributeError in that case.
    if cb is not None:
        cb.clr_history()
    return [], ""
105
+
106
+
107
# Create Gradio UI layout
with gr.Blocks() as demo:
    # Full-width image at the top
    with gr.Row():
        gr.Image("isu_logo.jpg", elem_id="full_width_image", show_label=False)

    # Full-width text below the image
    with gr.Row():
        gr.Markdown("<h1 style='text-align: center; font-size: 3.5em;'>Department of Economics</h1>")

    gr.Markdown("# Faculty Policies & Rules ChatBot")

    # Document/model pickers plus the history-reset button, on one row.
    with gr.Row():
        db_choice = gr.Dropdown(["Governance Documents", "Faculty Handbook"], label="Select Document", scale=1)
        model_choice = gr.Dropdown(["GPT-3.5", "GPT-4", "Llama-3 70B", "Llama-3 8B", "Gemini-1.5 Pro", "Claude 3 Sonnet", "Claude 3.5 Sonnet"],
                                   label="Select Model", scale=1, value="Llama-3 70B")
        button_clearhistory = gr.Button("Clear History", scale=1)

    # Query input and submit button.
    with gr.Row():
        inp = gr.Textbox(placeholder="Enter text here…", scale=8)
        button_submit = gr.Button("Submit", scale=1)

    output = gr.Chatbot()

    # Switching the document or model clears the chat so answers from
    # different databases/models are never mixed in one transcript.
    db_choice.change(
        fn=clear_history,
        inputs=[db_choice, model_choice],
        outputs=[output, inp]
    )

    model_choice.change(
        fn=clear_history,
        inputs=[db_choice, model_choice],
        outputs=[output, inp]
    )

    # Define interactions for both submit button and Enter key
    inp.submit(fn=chat_history, inputs=[inp, db_choice, model_choice], outputs=[output, inp])
    button_submit.click(fn=chat_history, inputs=[inp, db_choice, model_choice], outputs=[output, inp])
    button_clearhistory.click(fn=clear_history, inputs=[db_choice, model_choice], outputs=[output, inp])

# Launch the Gradio app
demo.launch()