|
|
import os |
|
|
import shutil |
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
|
from langchain_community.document_loaders import DirectoryLoader |
|
|
from langchain_openai import OpenAIEmbeddings |
|
|
from langchain.vectorstores.chroma import Chroma |
|
|
from langchain_openai import ChatOpenAI |
|
|
from langchain.prompts import ChatPromptTemplate |
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
# Resolve paths relative to this script so the app works no matter which
# working directory it is launched from.
script_directory = os.path.dirname(os.path.abspath(__file__))

# Directory containing the source PDF documents to index.
DATA_PATH = os.path.join(script_directory, "pdfs")

# Directory where the Chroma vector store is persisted.
# NOTE(review): unlike DATA_PATH this is relative to the CWD — confirm intended.
CHROMA_PATH = "chroma"

# The original unconditional `os.environ[...] = os.getenv(...)` raised
# TypeError when OPENAI_API_KEY was unset (environ values must be strings,
# never None). Re-exporting the key is redundant anyway, so only do it when
# the variable actually exists.
_api_key = os.getenv("OPENAI_API_KEY")
if _api_key is not None:
    os.environ["OPENAI_API_KEY"] = _api_key
|
|
|
|
|
# Prompt sent to the chat model: the retrieved chunks are substituted into
# {context} and the user's query into {question}. The wording instructs the
# model to answer from the supplied context only.
PROMPT_TEMPLATE = """


Answer the question based only on the following context:


{context}


---


Answer the question based on the above context: {question}


"""
|
|
|
|
|
def load_documents():
    """Load every PDF found directly under DATA_PATH as LangChain documents."""
    return DirectoryLoader(DATA_PATH, glob="*.pdf").load()
|
|
|
|
|
def split_text(documents):
    """Split *documents* into overlapping chunks suitable for embedding.

    Chunks are 300 characters with 100 characters of overlap; each chunk's
    start offset is recorded in its metadata (add_start_index).
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=300,
        chunk_overlap=100,
        length_function=len,
        add_start_index=True,
    )
    pieces = splitter.split_documents(documents)
    print(f"Split {len(documents)} documents into {len(pieces)} chunks.")
    return pieces
|
|
|
|
|
def save_to_chroma(chunks):
    """Persist *chunks* into a fresh Chroma vector store at CHROMA_PATH.

    Any pre-existing store is deleted first, so the index always reflects
    the current set of documents.
    """
    if os.path.exists(CHROMA_PATH):
        shutil.rmtree(CHROMA_PATH)

    db = Chroma.from_documents(
        chunks,
        OpenAIEmbeddings(),
        persist_directory=CHROMA_PATH,
    )
    db.persist()
    print(f"Saved {len(chunks)} chunks to {CHROMA_PATH}.")
|
|
|
|
|
|
|
|
def get_response(query_text):
    """Answer *query_text* with retrieval-augmented generation.

    Retrieves the 4 most relevant chunks from the persisted Chroma store,
    asks ChatOpenAI to answer using only that context, and returns the
    answer plus the deduplicated source file names. When nothing relevant
    enough is found (best relevance score below 0.7), returns an
    explanatory message instead.
    """
    embedding_function = OpenAIEmbeddings()
    db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embedding_function)

    results = db.similarity_search_with_relevance_scores(query_text, k=4)
    # Guard: no hits, or the best hit is not relevant enough to trust.
    if not results or results[0][1] < 0.7:
        # Fix: original printed an f-string with no placeholders and
        # implicitly returned None, which rendered as an empty Gradio
        # response; return the message so the user sees it.
        message = "Unable to find matching results."
        print(message)
        return message

    # Fix: the original computed context_text twice; duplicate removed.
    context_text = "\n\n---\n\n".join(doc.page_content for doc, _score in results)

    prompt_template = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)
    prompt = prompt_template.format(context=context_text, question=query_text)

    model = ChatOpenAI()
    response_text = model.predict(prompt)

    # Deduplicate sources while preserving retrieval order.
    sources = list(dict.fromkeys(doc.metadata.get("source", None) for doc, _score in results))
    return f"Response: {response_text}\nSources: {sources}"
|
|
|
|
|
def prepare():
    """Build the vector store: load PDFs, chunk them, and index in Chroma."""
    save_to_chroma(split_text(load_documents()))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI: a single textbox in, plain text out.
iface = gr.Interface(
    fn=get_response,
    inputs=gr.components.Textbox(lines=7, label="Enter your text"),
    outputs="text",
    title="UK Insurance Law AI Tool",
)

# Fix: guard heavy side effects behind the standard entry-point check so
# importing this module no longer rebuilds the entire vector store (which
# deletes and re-embeds everything) or starts a web server.
if __name__ == "__main__":
    prepare()
    iface.launch()
|
|
|