| | import os |
| | import glob |
| | import warnings |
| | import gradio as gr |
| |
|
| | from langchain_community.vectorstores import Chroma |
| | from langchain_community.document_loaders.csv_loader import CSVLoader |
| | from langchain_community.document_loaders import Docx2txtLoader, TextLoader, PyPDFLoader |
| |
|
| | from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter, TokenTextSplitter |
| | from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings |
| | from huggingface_hub import snapshot_download, upload_folder |
| |
|
| | from langchain.tools import tool |
| | from langchain.agents import create_agent |
| | from langchain.agents.middleware import dynamic_prompt, ModelRequest |
| |
|
# Pull the two private CGIAR datasets from the Hugging Face Hub into local
# folders: "refs" (contact/reference CSV) and "docs" (support documents).
# NOTE(review): requires HF_TOKEN in the environment — confirm deployment
# always sets it, otherwise these downloads fail at import time.
snapshot_download(repo_id="CGIAR/weai-refs",
                  repo_type="dataset",
                  token=os.getenv('HF_TOKEN'),
                  local_dir='./refs'
                  )

snapshot_download(repo_id="CGIAR/weai-docs",
                  repo_type="dataset",
                  token=os.getenv('HF_TOKEN'),
                  local_dir='./docs'
                  )
# Silence noisy library warnings and keep Weights & Biases from activating.
warnings.filterwarnings('ignore')
os.environ["WANDB_DISABLED"] = "true"
| |
|
# Chat model served through the Hugging Face Inference Endpoint.
repo_id = "meta-llama/Llama-3.3-70B-Instruct"

model = HuggingFaceEndpoint(
    task='conversational',
    repo_id = repo_id,
    temperature = 0.5,
    huggingfacehub_api_token=os.getenv('HF_TOKEN'),
    max_new_tokens = 1500,
    server_kwargs={"bill_to":"cgiar"}  # inference costs billed to the CGIAR org
)

# LangChain chat wrapper around the raw endpoint.
chat_llm = ChatHuggingFace(llm=model, verbose=True)

# Sentence-embedding model used for both vector stores below.
model_name = "sentence-transformers/all-mpnet-base-v2"
model_kwargs = {"device": "cuda"}

# NOTE(review): model_kwargs is defined above but never passed in, so the
# embeddings run on the default device — pass model_kwargs=model_kwargs here
# if GPU embedding was intended.
embeddings = HuggingFaceEmbeddings(model_name=model_name)
| |
|
def docs_return(directory_path, flag):
    """Load every .docx, .pdf and .txt file found directly in *directory_path*.

    Parameters
    ----------
    directory_path : str
        Directory to scan (non-recursive). A trailing separator is optional.
    flag : int
        0 -> return the list of loaded LangChain ``Document`` objects;
        anything else -> return the first page of each file joined into one
        '\\n\\n'-separated string.

    Returns
    -------
    list | str
        Documents (``flag == 0``) or a single joined text blob.
    """
    # os.path.join fixes the old string concatenation, which silently matched
    # nothing when directory_path lacked a trailing separator.
    docx_file_paths = glob.glob(os.path.join(directory_path, '*.docx'))
    pdf_file_paths = glob.glob(os.path.join(directory_path, '*.pdf'))
    txt_file_paths = glob.glob(os.path.join(directory_path, '*.txt'))

    all_doc, all_doc2 = [], []

    for path in docx_file_paths:
        documents = Docx2txtLoader(path).load()
        all_doc.extend(documents)
        # Guard: an empty file yields no documents; indexing [0] used to crash.
        if documents:
            all_doc2.append(str(documents[0].page_content))

    for path in pdf_file_paths:
        # lazy_load keeps memory bounded for large PDFs; extract_images OCRs
        # embedded images as well.
        documents = list(PyPDFLoader(path, extract_images=True).lazy_load())
        all_doc.extend(documents)
        if documents:
            all_doc2.append(str(documents[0].page_content))

    for path in txt_file_paths:
        documents = TextLoader(path).load()
        all_doc.extend(documents)
        if documents:
            all_doc2.append(str(documents[0].page_content))

    return all_doc if flag == 0 else '\n\n'.join(all_doc2)
| |
|
def get_text_splitter(splitter_type='character',
                      chunk_size=500,
                      chunk_overlap=30,
                      separator="\n",
                      max_tokens=1000):
    """Build a LangChain text splitter of the requested flavour.

    splitter_type selects the implementation: 'character' (uses chunk_size,
    chunk_overlap and separator), 'recursive' (chunk_size, chunk_overlap) or
    'token' (max_tokens, chunk_overlap).  Raises ValueError for anything else.
    """
    # Guard-clause chain: only the selected splitter class is ever constructed.
    if splitter_type == 'token':
        return TokenTextSplitter(chunk_size=max_tokens, chunk_overlap=chunk_overlap)
    if splitter_type == 'recursive':
        return RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    if splitter_type == 'character':
        return CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap, separator=separator)
    raise ValueError("Unsupported splitter type. Choose from 'character', 'recursive', or 'token'.")
| |
|
# Ingestion parameters for the support-document corpus.
splitter_type='character'
chunk_size=1500
chunk_overlap=30
separator="\n"
max_tokens=1000
docs_path = "./docs/"

# Load every document under ./docs as LangChain Document objects (flag 0).
all_doc = docs_return(docs_path, 0)

text_splitter = get_text_splitter(splitter_type=splitter_type,
                                  chunk_size=chunk_size,
                                  chunk_overlap=chunk_overlap,
                                  separator=separator,
                                  max_tokens=max_tokens)

docs = text_splitter.split_documents(documents=all_doc)

# Vector store over the split support documents.
docs_vector_db = Chroma.from_documents(docs, embeddings, persist_directory="chroma_data")

REFS_CSV_PATH = "./refs/WEAI reference list - Sheet1.csv"
REFS_CHROMA_PATH = "./chroma_data"
# NOTE(review): the refs store below persists into the same "chroma_data"
# directory as docs_vector_db above — confirm the two stores are meant to
# share one persistence directory.

# One row per reference; the description column is used as the source field.
loader = CSVLoader(file_path=REFS_CSV_PATH,
                   source_column="Description (what it contains and what it's useful for)")
refs = loader.load()

refs_vector_db = Chroma.from_documents(
    refs, embeddings, persist_directory=REFS_CHROMA_PATH
)
| |
|
@dynamic_prompt
def ref_context(request: ModelRequest) -> str:
    """Build a system prompt embedding contact/reference entries relevant to the latest user message."""
    last_query = request.state["messages"][-1].text
    # BUG FIX: the previous version interpolated the retriever OBJECT into the
    # prompt (the query was never executed and last_query was unused), so the
    # model saw no actual context. Retrieve for the query and join the text.
    retrieved = refs_vector_db.as_retriever(search_kwargs={"k": 10}).invoke(last_query)
    ref_content = "\n\n".join(doc.page_content for doc in retrieved)

    system_message = (
        """Your job is to use relevant links and email addresses to
    direct users to in order to reach and contact the WEAI team. Do not use links
    or contacts not provided in the context.If you don't know
    an answer, say you don't know. Do not state that you are referring to the
    provided context and respond as if you were in charge of the WEAI helpdesk."""
        f"\n\n{ref_content}"
    )

    return system_message
| |
|
# Agent whose sole job is routing users to WEAI contacts and resource links.
contact_agent = create_agent(chat_llm, tools=[], middleware=[ref_context])


@tool("contact", description="refer users to WEAI team using links and contact details")
def call_contact_agent(query: str):
    """Run the contact agent on *query* and return its final answer text."""
    conversation = {"messages": [{"role": "user", "content": query}]}
    outcome = contact_agent.invoke(conversation)
    return outcome["messages"][-1].content
| | |
| |
|
@dynamic_prompt
def doc_context(request: ModelRequest) -> str:
    """Build a system prompt embedding support-document chunks relevant to the latest user message."""
    last_query = request.state["messages"][-1].text
    # BUG FIX: the previous version interpolated the retriever OBJECT into the
    # prompt (the query was never executed and last_query was unused), so the
    # model saw no actual context. Retrieve for the query and join the text.
    retrieved = docs_vector_db.as_retriever(search_kwargs={"k": 10}).invoke(last_query)
    doc_content = "\n\n".join(doc.page_content for doc in retrieved)

    system_message = (
        """You are a user support agent helping with queries related to the Women's Empowerment in Agriculture Index (WEAI).
    Use the following context to answer questions.
    Be as detailed as possible, but don't make up any information that's not from the context and where possible reference related studies and resources
    from the context you have.
    Use complete paper or article details such as authors, title, publication date, and webpage link.
    Do not use publication information not provided in the context and do not combine publication information to make up details.
    Use complete information as referenced in the context. Do not overexplain concepts that already have a resource or reference,
    first try to point users to existing resources, tools, or references and only add a brief explanation if necessary.
    Focus first on WEAI resources before recommending resources from the general IFPRI website.
    If you don't know an answer, say you don't know. Be concise but thorough in your response and try not to exceed the output token limit.
    Do not state that you are referring to the provided context and respond
    as if you were in charge of the WEAI helpdesk. """
        f"\n\n{doc_content}"
    )

    return system_message
| |
|
# Agent that answers questions from the WEAI document corpus.
support_agent = create_agent(chat_llm, tools=[], middleware=[doc_context])


@tool("support", description="respond to user queries using context in WEAI docs")
def call_support_agent(query: str):
    """Run the support agent on *query* and return its final answer text."""
    conversation = {"messages": [{"role": "user", "content": query}]}
    outcome = support_agent.invoke(conversation)
    return outcome["messages"][-1].content
| | |
# System prompt for the top-level orchestrating agent: answer via the
# "support" tool first, then fill in contacts/links via the "contact" tool.
support_instructions = """
You are in charge of the WEAI helpdesk.
Your job is to answer user queries using provided context and references
and refer users to WEAI personnel as well as relevant resource links where necessary.

Steps:
1. Use the support tool to answer queries to the best of your knowledge.
2. If no contact information or links are provided in the response, use the
contact tool to add all relevant contact and resource information to the response.
3. Return only a complete response with included contact and resource information.
"""
| |
|
# Top-level orchestrating agent: can call the support and contact tools and
# follows support_instructions to compose the final helpdesk answer.
response_agent = create_agent(model=chat_llm,
                              tools=[call_contact_agent, call_support_agent],
                              system_prompt=support_instructions,
                              )
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
with gr.Blocks() as demo:
    with gr.Sidebar():
        gr.LoginButton()
    gr.Markdown("# WEAI-bot")
    # type='messages' stores history as [{"role": ..., "content": ...}] dicts.
    chatbot = gr.Chatbot(type='messages',
                         allow_tags=True)
    msg = gr.Textbox()
    clear = gr.ClearButton([msg, chatbot])

    def _message_text(message):
        """Return the plain text of a chat message whose 'content' may be a str or a list of content parts."""
        content = message['content']
        return content if isinstance(content, str) else content[0]["text"]

    def handle_undo(history, undo_data: gr.UndoData):
        # Drop the undone message and put its text back into the textbox.
        # BUG FIX: messages appended by support_agent_fn carry plain-string
        # content, so the old ['content'][0]["text"] lookup raised on them.
        return history[:undo_data.index], _message_text(history[undo_data.index])

    def handle_retry(history, retry_data: gr.RetryData):
        # Re-run the model on the message being retried.
        # BUG FIX: `yield from support_agent_fn(...)` streamed the ("", history)
        # tuple as two chatbot updates, briefly replacing the chat with "".
        new_history = history[:retry_data.index]
        previous_prompt = _message_text(history[retry_data.index])
        _, updated_history = support_agent_fn(previous_prompt, new_history)
        return updated_history

    def support_agent_fn(message, history):
        # BUG FIX: invoke the orchestrating response_agent (which can call the
        # support and contact tools per support_instructions) rather than the
        # bare support_agent; response_agent was previously dead code.
        result = response_agent.invoke({"messages": [{"role": "user", "content": message}]})

        response = result['messages'][-1].content
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": response})

        # Clear the textbox and push the updated conversation.
        return "", history

    def handle_like(data: gr.LikeData):
        # Reactions are only logged for now; no feedback store is wired up.
        if data.liked:
            print("You upvoted this response: ", data.value)
        else:
            print("You downvoted this response: ", data.value)

    def handle_edit(history, edit_data: gr.EditData):
        # NOTE(review): slicing at edit_data.index makes new_history[-1] the
        # message BEFORE the edited one — confirm this off-by-one is intended.
        new_history = history[:edit_data.index]
        # Store edited text as a plain string, consistent with how
        # support_agent_fn builds messages (and with _message_text above).
        new_history[-1]['content'] = edit_data.value
        return new_history

    msg.submit(support_agent_fn, [msg, chatbot], [msg, chatbot])

    chatbot.undo(handle_undo, chatbot, [chatbot, msg])
    chatbot.retry(handle_retry, chatbot, chatbot)
    chatbot.like(handle_like, None, None)
    chatbot.edit(handle_edit, chatbot, chatbot)
| |
|
if __name__ == "__main__":
    # Launch the Gradio app when the module is run as a script.
    demo.launch()
| |
|