Spaces:
Sleeping
Sleeping
| import os | |
| import openai | |
| from dotenv import load_dotenv | |
| _ = load_dotenv() # read local .env file | |
| import gradio as gr | |
| from langchain_chroma import Chroma | |
| from langchain.chains import ConversationalRetrievalChain | |
| from langchain_openai import OpenAIEmbeddings, ChatOpenAI | |
# Custom class to handle API routing for different models
class ChatOpenRouter(ChatOpenAI):
    """ChatOpenAI pointed at OpenRouter's OpenAI-compatible endpoint.

    Lets the app call non-OpenAI models (Llama, Gemini, Claude) through
    OpenRouter while reusing the ChatOpenAI client machinery.
    """
    # NOTE(review): these look like pydantic field re-declarations on the
    # ChatOpenAI model — confirm they are required by the installed
    # langchain_openai version before removing.
    openai_api_base: str
    openai_api_key: str
    model_name: str

    def __init__(self,
                 model_name: str,
                 openai_api_key: str = None,
                 openai_api_base: str = "https://openrouter.ai/api/v1",
                 **kwargs):
        # Fall back to the OPENROUTER_API_KEY environment variable
        # (loaded from .env at module import) when no key is passed.
        openai_api_key = openai_api_key or os.getenv('OPENROUTER_API_KEY')
        super().__init__(openai_api_base=openai_api_base,
                         openai_api_key=openai_api_key,
                         model_name=model_name, **kwargs)
# Initialize embedding function here.
# Shared by every Chroma store opened in cbfs below; must match the
# embeddings the persisted databases were built with.
embedding_function = OpenAIEmbeddings()
# Updated cbfs class with dynamic database and model selection
class cbfs:
    """Conversational retrieval chatbot over a persisted Chroma store.

    Builds a ConversationalRetrievalChain for the chosen document database
    and LLM, and tracks the running chat history plus the (speaker, text)
    panels shown in the Gradio Chatbot component.
    """

    # UI display name -> model id served directly by OpenAI.
    _OPENAI_MODELS = {
        "GPT-4": "gpt-4-1106-preview",
        "GPT-3.5": "gpt-3.5-turbo-0125",
    }
    # UI display name -> model id routed through OpenRouter.
    # "Llama-3 70B" was previously only reachable via the else-fallback
    # even though the UI offers it; it is now mapped explicitly.
    _OPENROUTER_MODELS = {
        "Llama-3 70B": "meta-llama/llama-3-70b-instruct",
        "Llama-3 8B": "meta-llama/llama-3-8b-instruct",
        "Gemini-1.5 Pro": "google/gemini-pro-1.5",
        "Claude 3 Sonnet": "anthropic/claude-3-sonnet",
        "Claude 3.5 Sonnet": "anthropic/claude-3.5-sonnet",
    }
    # Used when model_name matches neither table.
    _DEFAULT_OPENROUTER_MODEL = "meta-llama/llama-3-70b-instruct"

    def __init__(self, persist_directory, model_name):
        """Open the Chroma store at persist_directory and wire up the chain.

        persist_directory: path of a previously-built Chroma database.
        model_name: UI display name of the LLM to use (see tables above).
        """
        self.chat_history = []   # list of (question, answer) tuples fed back to the chain
        self.answer = ""         # last answer text
        self.db_query = ""       # last generated (rephrased) question
        self.db_response = []    # last retrieved source documents
        self.panels = []         # [speaker, text] pairs for the Chatbot display
        # Initialize Chroma and the retriever with the chosen database.
        db = Chroma(persist_directory=persist_directory, embedding_function=embedding_function)
        retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 3})
        # Select the LLM from the display name; unknown names fall back to
        # the default OpenRouter model.
        if model_name in self._OPENAI_MODELS:
            chosen_llm = ChatOpenAI(model_name=self._OPENAI_MODELS[model_name], temperature=0)
        else:
            model_id = self._OPENROUTER_MODELS.get(model_name, self._DEFAULT_OPENROUTER_MODEL)
            chosen_llm = ChatOpenRouter(model_name=model_id, temperature=0)
        self.qa = ConversationalRetrievalChain.from_llm(
            llm=chosen_llm,
            retriever=retriever,
            return_source_documents=True,
            return_generated_question=True,
        )

    def convchain(self, query):
        """Run one conversational turn and return the updated panels."""
        if not query:
            # Empty submission: keep the existing conversation on screen
            # (the previous code returned a blank exchange, wiping the display).
            return self.panels
        result = self.qa.invoke({"question": query, "chat_history": self.chat_history})
        self.chat_history.append((query, result["answer"]))
        self.db_query = result["generated_question"]
        self.db_response = result["source_documents"]
        self.answer = result['answer']
        self.panels.append(["User", query])  # Ensure this is a list of two strings
        self.panels.append(["ChatBot", self.answer])  # Ensure this is a list of two strings
        return self.panels

    def clr_history(self):
        """Forget the conversation and return the (now empty) panels."""
        self.chat_history = []
        self.panels = []
        return self.panels  # Clear the chatbot display
# Create Gradio interface functions
def initialize_cbfs(db_choice, model_choice):
    """Build a fresh cbfs for the selected database and model.

    Returns None when db_choice is not a known document set, which the
    chat handlers treat as "no document selected yet".
    """
    persist_dirs = {
        "Governance Documents": 'docs/chroma_eg/',
        "Faculty Handbook": 'docs/chroma_hb/',
    }
    directory = persist_dirs.get(db_choice)
    if directory is None:
        return None
    return cbfs(persist_directory=directory, model_name=model_choice)
def chat_history(query, db_choice, model_choice, cb):
    """Handle one chat submission from the UI.

    Returns (chatbot_panels, input_text); the empty string clears the
    input box. db_choice/model_choice are accepted to match the Gradio
    wiring but are unused here — cb already encodes both selections.
    """
    if cb is not None:
        return cb.convchain(query), ""  # Clear input box by returning empty string
    # No cbfs yet: the user has not picked a document, so remind them.
    reminder = "Please select a document from the dropdown menu before submitting your query."
    return [("ChatBot", reminder)], ""
def clear_history(cb):
    """Reset the conversation (if one exists) and blank the display and input box."""
    if cb is not None:
        cb.clr_history()
    # Both branches hand back an empty chat display and an empty input box.
    return [], ""
# Create Gradio UI layout.
# NOTE: this runs at import time and demo.launch() starts the server.
with gr.Blocks() as demo:
    # Full-width image at the top
    with gr.Row():
        gr.Image("isu_logo.jpg", elem_id="full_width_image", show_label=False)
    # Full-width text below the image
    with gr.Row():
        gr.Markdown("<h1 style='text-align: center; font-size: 3.5em;'>Department of Economics</h1>")
        gr.Markdown("# Faculty Policies & Rules ChatBot")
    # Selection row: document database, LLM, and a history-reset button.
    with gr.Row():
        db_choice = gr.Dropdown(["Governance Documents", "Faculty Handbook"], label="Select Document", scale=1)
        model_choice = gr.Dropdown(["GPT-3.5", "GPT-4", "Llama-3 70B", "Llama-3 8B", "Gemini-1.5 Pro", "Claude 3 Sonnet", "Claude 3.5 Sonnet"],
                                   label="Select Model", scale=1, value = "Llama-3 70B")
        button_clearhistory = gr.Button("Clear History", scale=1)
    with gr.Row():
        inp = gr.Textbox(placeholder="Enter text here…", scale=8)
        button_submit = gr.Button("Submit", scale=1)
    output = gr.Chatbot()
    # Initialize cbfs instance as session state from the dropdowns' initial
    # values. db_choice has no initial value, so this presumably starts as
    # None until the user picks a document — TODO confirm.
    cbfs_instance = gr.State(initialize_cbfs(db_choice.value, model_choice.value))
    # Update cbfs_instance and clear chat history when the dropdown values change
    def update_cbfs_and_clear_history(db_choice, model_choice):
        # Rebuild the retrieval chain for the new document/model pair and
        # wipe the stored history so conversations don't leak across setups.
        new_cbfs = initialize_cbfs(db_choice, model_choice)
        if new_cbfs:
            new_cbfs.clr_history()
        return new_cbfs, [], ""  # Clear the chatbot display and input box
    # Both dropdowns trigger the same rebuild-and-clear handler.
    db_choice.change(
        fn=update_cbfs_and_clear_history,
        inputs=[db_choice, model_choice],
        outputs=[cbfs_instance, output, inp]
    )
    model_choice.change(
        fn=update_cbfs_and_clear_history,
        inputs=[db_choice, model_choice],
        outputs=[cbfs_instance, output, inp]
    )
    # Define interactions for both submit button and Enter key
    inp.submit(fn=chat_history, inputs=[inp, db_choice, model_choice, cbfs_instance], outputs=[output, inp])
    button_submit.click(fn=chat_history, inputs=[inp, db_choice, model_choice, cbfs_instance], outputs=[output, inp])
    button_clearhistory.click(fn=clear_history, inputs=cbfs_instance, outputs=[output, inp])
# Launch the Gradio app
demo.launch()