Spaces: Runtime error
Commit 4ff740e · Parent(s): ac18860
Upload 3 files
- Dockerfile +11 -0
- legalminds.py +164 -0
- requirements.txt +17 -0
Dockerfile
ADDED
@@ -0,0 +1,11 @@
FROM python:3.11
RUN useradd -m -u 1000 user
USER user
ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH
WORKDIR $HOME/app
COPY --chown=user . $HOME/app
COPY ./requirements.txt $HOME/app/requirements.txt
RUN pip install -r requirements.txt
COPY . .
CMD ["chainlit", "run", "legalminds.py", "--port", "7860"]
legalminds.py
ADDED
@@ -0,0 +1,164 @@
import os
from dotenv import load_dotenv
load_dotenv()

import numpy as np
import pandas as pd
import time
from tqdm import tqdm
import warnings
from langchain.chains import RetrievalQA
from langchain.callbacks import StdOutCallbackHandler
import chainlit as cl  # importing chainlit for our app
from chainlit.prompt import Prompt, PromptMessage
from chainlit.playground.providers.openai import ChatOpenAI as ChatOpenAIProvider  # aliased so langchain's ChatOpenAI below does not shadow it
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma, DeepLake
from langchain.embeddings import OpenAIEmbeddings
from langchain.document_loaders.dataframe import DataFrameLoader
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from langchain.llms.openai import OpenAIChat
from langchain.agents.agent_toolkits import create_retriever_tool
from langchain.agents.agent_toolkits import create_conversational_retrieval_agent
from langchain.utilities import SerpAPIWrapper
from langchain.agents import load_tools
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import AgentTokenBufferMemory
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.schema.messages import SystemMessage
from langchain.prompts import MessagesPlaceholder
from langchain.agents import AgentExecutor

warnings.filterwarnings("ignore")
os.environ["LANGCHAIN_WANDB_TRACING"] = "true"

# review_df = pd.read_csv("./data/justice.csv")

# data = review_df

# text_splitter = RecursiveCharacterTextSplitter(
#     chunk_size = 7000,  # the character length of the chunk
#     chunk_overlap = 700,  # the character length of the overlap between chunks
#     length_function = len,  # the length function - in this case, character length (aka the python len() fn.)
# )

# loader = DataFrameLoader(review_df, page_content_column="facts")
# base_docs = loader.load()
# docs = text_splitter.split_documents(base_docs)

embedder = OpenAIEmbeddings()

# This is needed for both the memory and the prompt
memory_key = "history"
# Embed and persist db
persist_directory = "./data/chroma"

vectorstore = Chroma(persist_directory=persist_directory, embedding_function=embedder)
# vectorstore = DeepLake(dataset_path="./legalminds/", embedding=embedder, overwrite=True)
# vectorstore.add_documents(docs)

memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
primary_qa_llm = ChatOpenAI(
    model="gpt-3.5-turbo-16k",
    temperature=0,
)

retriever = vectorstore.as_retriever()
CUSTOM_TOOL_N_DOCS = 3  # number of retrieved docs from deep lake to consider
CUSTOM_TOOL_DOCS_SEPARATOR = "\n\n"  # how to join together the retrieved docs to form a single string

def retrieve_n_docs_tool(query: str) -> str:
    """Searches for relevant documents that may contain the answer to the query."""
    docs = retriever.get_relevant_documents(query)[:CUSTOM_TOOL_N_DOCS]
    texts = [doc.page_content for doc in docs]
    texts_merged = CUSTOM_TOOL_DOCS_SEPARATOR.join(texts)
    return texts_merged


serp_tool = load_tools(["serpapi"])
# print("Serp Tool:", serp_tool[0])

data_tool = create_retriever_tool(
    retriever,
    "retrieve_n_docs_tool",
    "Searches and returns documents regarding the query asked."
)
tools = [data_tool, serp_tool[0]]

# llm = OpenAIChat(model="gpt-3.5-turbo", temperature=0)
llm = ChatOpenAI(temperature=0)

# Replaces the ConversationBufferMemory defined above; this is the memory the agent actually uses.
memory = AgentTokenBufferMemory(memory_key=memory_key, llm=llm)
system_message = SystemMessage(
    content=(
        "Do your best to answer the questions. "
        "Feel free to use any tools available to look up "
        "relevant information, only if necessary"
    )
)

prompt = OpenAIFunctionsAgent.create_prompt(
    system_message=system_message,
    extra_prompt_messages=[MessagesPlaceholder(variable_name=memory_key)]
)
agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)

handler = StdOutCallbackHandler()


@cl.on_chat_start  # marks a function that will be executed at the start of a user session
async def start_chat():

    agent_executor = AgentExecutor(agent=agent, tools=tools, memory=memory, verbose=True,
                                   return_intermediate_steps=True)
    # agent_executor = create_conversational_retrieval_agent(llm, tools, verbose=True)
    # qa_with_sources_chain = RetrievalQA.from_chain_type(
    #     llm=llm,
    #     retriever=retriever,
    #     callbacks=[handler],
    #     return_source_documents=True
    # )

    cl.user_session.set("agent", agent_executor)


@cl.on_message  # marks a function that should be run each time the chatbot receives a message from a user
async def main(message: str):
    agent_executor = cl.user_session.get("agent")

    # prompt = Prompt(
    #     provider=ChatOpenAIProvider.id,
    #     messages=[
    #         PromptMessage(
    #             role="system",
    #             # template=RAQA_PROMPT_TEMPLATE,
    #             # formatted=RAQA_PROMPT_TEMPLATE,
    #         ),
    #         PromptMessage(
    #             role="user",
    #             # template=user_template,
    #             # formatted=user_template.format(input=message),
    #         ),
    #     ],
    #     inputs={"input": message},
    #     # settings=settings,
    # )

    # result = await qa_with_sources_chain.acall({"query": message})  # , callbacks=[cl.AsyncLangchainCallbackHandler()])
    result = agent_executor({"input": message})
    # print("result Dict:", result)

    msg = cl.Message(content=result["output"])
    print("message:", msg)
    print("output message:", msg.content)
    # Update the prompt object with the completion
    # msg.content = result["output"]
    # prompt.completion = msg.content
    # msg.prompt = prompt
    # print("message_content: ", msg.content)
    await msg.send()
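For reference, a minimal ingestion sketch of the step this commit leaves commented out: it would build the persisted Chroma index at ./data/chroma that legalminds.py opens at startup. It assumes a ./data/justice.csv file with a "facts" column, as referenced in the commented-out DataFrameLoader code, and reuses the chunking settings shown there.

import pandas as pd
from langchain.document_loaders.dataframe import DataFrameLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma

# Load the dataset assumed by the commented-out code and wrap each row's "facts" text as a document.
review_df = pd.read_csv("./data/justice.csv")
loader = DataFrameLoader(review_df, page_content_column="facts")
base_docs = loader.load()

# Split into overlapping character chunks, matching the commented-out settings above.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=7000,
    chunk_overlap=700,
    length_function=len,
)
docs = text_splitter.split_documents(base_docs)

# Embed the chunks and write the index to ./data/chroma so the app can reopen it at startup.
vectorstore = Chroma.from_documents(docs, OpenAIEmbeddings(), persist_directory="./data/chroma")
vectorstore.persist()

Running this once, locally or as a build step, would let the Chroma(persist_directory=...) call in legalminds.py find a populated index instead of an empty one.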
requirements.txt
ADDED
@@ -0,0 +1,17 @@
chainlit==0.7.0
numpy==1.25.2
openai==0.27.8
python-dotenv==1.0.0
wandb==0.15.11
chromadb
langchain
tiktoken
pandas
scipy
scikit-learn
ipykernel
matplotlib
plotly
deeplake
google-search-results
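The app reads its credentials via python-dotenv and talks to OpenAI, SerpAPI (through google-search-results), and Weights & Biases (LANGCHAIN_WANDB_TRACING is enabled), so a Space that builds and then hits a runtime error is often just missing a secret. A small pre-flight check along these lines, assuming these three variable names are the ones required, makes that failure explicit:

# Pre-flight sketch (not part of this commit): fail fast if a required secret is absent.
# OPENAI_API_KEY and SERPAPI_API_KEY are used by the LangChain code above; WANDB_API_KEY
# is assumed to be needed because LANGCHAIN_WANDB_TRACING is set to "true".
import os
from dotenv import load_dotenv

load_dotenv()

required = ["OPENAI_API_KEY", "SERPAPI_API_KEY", "WANDB_API_KEY"]
missing = [name for name in required if not os.getenv(name)]
if missing:
    raise RuntimeError(f"Missing environment variables: {', '.join(missing)}")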
|