| | import copy |
| | import os |
| | from pathlib import Path |
| | from typing import Union, Any |
| |
|
| | from grobid_client.grobid_client import GrobidClient |
| | from langchain.chains import create_extraction_chain |
| | from langchain.chains.question_answering import load_qa_chain |
| | from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate |
| | from langchain.retrievers import MultiQueryRetriever |
| | from langchain.text_splitter import RecursiveCharacterTextSplitter |
| | from langchain.vectorstores import Chroma |
| | from tqdm import tqdm |
| |
|
| | from grobid_processors import GrobidProcessor |
| |
|
| |
|
class DocumentQAEngine:
    """Question answering over per-document Chroma vector stores.

    Each document is identified by an id (the upper-case MD5 of the PDF when
    the store is built by :meth:`create_embeddings`) and owns one Chroma store.
    Text is extracted from PDFs through a ``GrobidProcessor`` and answers are
    produced by a LangChain QA chain built around ``llm``.
    """

    # Scalar defaults only. The mutable registries are created per instance in
    # ``__init__`` so that two engines never share loaded vector stores.
    llm = None
    qa_chain_type = None
    embedding_function = None

    def __init__(self, llm, embedding_function, qa_chain_type="stuff", embeddings_root_path=None,
                 grobid_url=None):
        """Build the QA chain and optionally load persisted embeddings.

        :param llm: language model handed to ``load_qa_chain`` and used for parsing.
        :param embedding_function: embedding function handed to every Chroma store.
        :param qa_chain_type: ``chain_type`` forwarded to ``load_qa_chain``.
        :param embeddings_root_path: directory with one sub-directory per document;
            created when missing, otherwise its stores are loaded.
        :param grobid_url: URL of the GROBID server; when falsy no processor is
            created and :meth:`get_text_from_document` cannot be used.
        """
        self.embedding_function = embedding_function
        self.llm = llm
        self.qa_chain_type = qa_chain_type
        self.chain = load_qa_chain(llm, chain_type=qa_chain_type)

        # Per-instance registries: document id -> store, id <-> file name.
        self.embeddings_dict = {}
        self.embeddings_map_from_md5 = {}
        self.embeddings_map_to_md5 = {}

        # Always defined, so later code can test for None instead of failing
        # on a missing attribute.
        self.embeddings_root_path = embeddings_root_path
        if embeddings_root_path is not None:
            if not os.path.exists(embeddings_root_path):
                os.makedirs(embeddings_root_path)
            else:
                self.load_embeddings(embeddings_root_path)

        self.grobid_url = grobid_url
        self.grobid_processor = None
        if grobid_url:
            grobid_client = GrobidClient(
                grobid_server=self.grobid_url,
                batch_size=1000,
                coordinates=["p"],
                sleep_time=5,
                timeout=60,
                check_server=True
            )
            self.grobid_processor = GrobidProcessor(grobid_client)

    def load_embeddings(self, embeddings_root_path: Union[str, Path]) -> None:
        """
        Load the embeddings assuming they are all persisted and stored in a single directory.
        The root path of the embeddings contains one data store for each document, each in
        its own subdirectory named after the document id.
        """
        # The context manager closes the directory iterator deterministically;
        # the entries stay usable afterwards (name/path are plain attributes).
        with os.scandir(embeddings_root_path) as entries:
            embeddings_directories = [entry for entry in entries if entry.is_dir()]

        if not embeddings_directories:
            print("No available embeddings")
            return

        for embedding_document_dir in embeddings_directories:
            doc_id = embedding_document_dir.name
            self.embeddings_dict[doc_id] = Chroma(persist_directory=embedding_document_dir.path,
                                                  embedding_function=self.embedding_function)

            # An empty "<name>.storage_filename" marker records the original file name.
            filename_list = list(Path(embedding_document_dir).glob('*.storage_filename'))
            if filename_list:
                filename = filename_list[0].name.replace(".storage_filename", "")
                self.embeddings_map_from_md5[doc_id] = filename
                self.embeddings_map_to_md5[filename] = doc_id

        print("Embedding loaded: ", len(self.embeddings_dict))

    def get_loaded_embeddings_ids(self):
        """Return the ids of all documents whose store is currently loaded."""
        return list(self.embeddings_dict.keys())

    def get_md5_from_filename(self, filename):
        """Return the document id registered for ``filename`` (``KeyError`` if unknown)."""
        return self.embeddings_map_to_md5[filename]

    def get_filename_from_md5(self, md5):
        """Return the file name registered for document id ``md5`` (``KeyError`` if unknown)."""
        return self.embeddings_map_from_md5[md5]

    def query_document(self, query: str, doc_id, output_parser=None, context_size=4, extraction_schema=None,
                       verbose=False) -> (
            Any, str):
        """Answer ``query`` on one document and optionally structure the answer.

        :param query: the question.
        :param doc_id: id of a loaded document.
        :param output_parser: when given, the answer is converted through :meth:`_parse_json`.
        :param context_size: number of passages retrieved as context.
        :param extraction_schema: when given (and no ``output_parser``), the answer is
            run through ``create_extraction_chain``.
        :param verbose: print the query and the raw answer.
        :return: ``(parsed, raw_answer)``; ``parsed`` is ``None`` when no parsing was
            requested or when parsing failed.
        """
        if verbose:
            print(query)

        response = self._run_query(doc_id, query, context_size=context_size)
        # Unwrap only real mappings: on a plain string, ``in`` is a substring
        # test and indexing with a str key would raise TypeError.
        if isinstance(response, dict) and 'output_text' in response:
            response = response['output_text']

        if verbose:
            print(doc_id, "->", response)

        if output_parser:
            try:
                return self._parse_json(response, output_parser), response
            except Exception as oe:
                # Best effort: the raw answer is still returned to the caller.
                print("Failing to parse the response", oe)
                return None, response
        elif extraction_schema:
            try:
                chain = create_extraction_chain(extraction_schema, self.llm)
                parsed = chain.run(response)
                return parsed, response
            except Exception as oe:
                print("Failing to parse the response", oe)
                return None, response
        else:
            return None, response

    def query_storage(self, query: str, doc_id, context_size=4):
        """Return the text of the ``context_size`` passages retrieved for ``query``."""
        documents = self._get_context(doc_id, query, context_size)
        return [doc.page_content for doc in documents]

    def _parse_json(self, response, output_parser):
        """Ask the LLM to rewrite ``response`` as JSON and parse it with ``output_parser``."""
        system_message = "You are an useful assistant expert in materials science, physics, and chemistry " \
                         "that can process text and transform it to JSON."
        human_message = """Transform the text between three double quotes in JSON.\n\n\n\n
        {format_instructions}\n\nText: \"\"\"{text}\"\"\""""

        system_message_prompt = SystemMessagePromptTemplate.from_template(system_message)
        human_message_prompt = HumanMessagePromptTemplate.from_template(human_message)

        prompt_template = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])

        results = self.llm(
            prompt_template.format_prompt(
                text=response,
                format_instructions=output_parser.get_format_instructions()
            ).to_messages()
        )
        return output_parser.parse(results.content)

    def _run_query(self, doc_id, query, context_size=4):
        """Retrieve the context for ``query`` and run the QA chain on it."""
        relevant_documents = self._get_context(doc_id, query, context_size)
        return self.chain.run(input_documents=relevant_documents, question=query)

    def _get_context(self, doc_id, query, context_size=4):
        """Return the ``context_size`` documents retrieved for ``query`` (``KeyError`` on unknown id)."""
        db = self.embeddings_dict[doc_id]
        retriever = db.as_retriever(search_kwargs={"k": context_size})
        return retriever.get_relevant_documents(query)

    def get_all_context_by_document(self, doc_id):
        """Return every stored text chunk of document ``doc_id``."""
        db = self.embeddings_dict[doc_id]
        docs = db.get()
        return docs['documents']

    def _get_context_multiquery(self, doc_id, query, context_size=4):
        """Same as :meth:`_get_context` but retrieving through a ``MultiQueryRetriever``."""
        retriever = self.embeddings_dict[doc_id].as_retriever(search_kwargs={"k": context_size})
        multi_query_retriever = MultiQueryRetriever.from_llm(retriever=retriever, llm=self.llm)
        return multi_query_retriever.get_relevant_documents(query)

    def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, verbose=False):
        """Extract the text of a PDF through GROBID and split it into chunks.

        :param pdf_file_path: path of the PDF.
        :param chunk_size: negative keeps the passages returned by the processor as
            chunks; otherwise the joined text is re-split in chunks of this size
            (measured by the tiktoken-based splitter).
        :param perc_overlap: overlap between chunks as a fraction of ``chunk_size``.
        :param verbose: print progress information.
        :return: ``(texts, metadatas, ids)``.
        :raises AttributeError: when the engine was created without ``grobid_url``.
        """
        if getattr(self, "grobid_processor", None) is None:
            # Same exception type as the missing attribute used to produce,
            # with an actionable message.
            raise AttributeError("No GROBID processor available: create the engine with a grobid_url")

        if verbose:
            print("File", pdf_file_path)
        filename = Path(pdf_file_path).stem
        structure = self.grobid_processor.process_structure(pdf_file_path)

        biblio = structure['biblio']
        biblio['filename'] = filename.replace(" ", "_")

        if verbose:
            # Print the hash of the document (presumably set by the processor),
            # not the ``hash`` builtin.
            print("Generating embeddings for:", biblio.get("hash"), ", filename: ", filename)

        texts = []
        metadatas = []
        ids = []
        if chunk_size < 0:
            for passage in structure['passages']:
                # Empty passages are dropped from all three lists together so
                # that they stay aligned.
                if len(str.strip(passage['text'])) > 0:
                    texts.append(passage['text'])

                    biblio_copy = copy.copy(biblio)
                    biblio_copy['type'] = passage['type']
                    biblio_copy['section'] = passage['section']
                    biblio_copy['subSection'] = passage['subSection']
                    metadatas.append(biblio_copy)

                    ids.append(passage['passage_id'])
        else:
            document_text = " ".join([passage['text'] for passage in structure['passages']])

            text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                chunk_size=chunk_size,
                chunk_overlap=chunk_size * perc_overlap
            )
            texts = text_splitter.split_text(document_text)
            metadatas = [biblio for _ in range(len(texts))]
            ids = list(range(len(texts)))

        return texts, metadatas, ids

    def create_memory_embeddings(self, pdf_path, doc_id=None, chunk_size=500, perc_overlap=0.1):
        """Build an in-memory store for one PDF and register it.

        :param pdf_path: path of the PDF.
        :param doc_id: id to register the store under; when falsy the ``hash``
            entry of the first chunk's metadata is used.
        :return: the id the store was registered under.

        Note: this resets ``embeddings_root_path`` to ``None``.
        """
        texts, metadata, ids = self.get_text_from_document(pdf_path, chunk_size=chunk_size,
                                                           perc_overlap=perc_overlap)
        doc_hash = doc_id if doc_id else metadata[0]['hash']

        self.embeddings_dict[doc_hash] = Chroma.from_texts(texts, embedding=self.embedding_function,
                                                           metadatas=metadata)
        self.embeddings_root_path = None

        return doc_hash

    def create_embeddings(self, pdfs_dir_path: Path):
        """Build and persist one store per PDF found (recursively) under ``pdfs_dir_path``.

        Stores are written to ``<embeddings_root_path>/<MD5>``; documents whose
        directory already exists are skipped, as are documents without text.

        :raises ValueError: when no ``embeddings_root_path`` is configured.
        """
        root_path = getattr(self, "embeddings_root_path", None)
        if root_path is None:
            raise ValueError("embeddings_root_path is not set: persisted embeddings cannot be created")

        input_files = []
        for root, _dirs, files in os.walk(pdfs_dir_path, followlinks=False):
            for file_ in files:
                if file_.lower().endswith(".pdf"):
                    input_files.append(os.path.join(root, file_))

        for input_file in tqdm(input_files, total=len(input_files), unit='document',
                               desc="Grobid + embeddings processing"):

            md5 = self.calculate_md5(input_file)
            data_path = os.path.join(root_path, md5)

            if os.path.exists(data_path):
                print(data_path, "exists. Skipping it ")
                continue

            texts, metadata, ids = self.get_text_from_document(input_file, chunk_size=500, perc_overlap=0.1)
            if not texts:
                # One unreadable document must not abort the whole batch.
                print(input_file, "has no extractable text. Skipping it ")
                continue
            filename = metadata[0]['filename']

            vector_db_document = Chroma.from_texts(texts,
                                                   metadatas=metadata,
                                                   embedding=self.embedding_function,
                                                   persist_directory=data_path)
            vector_db_document.persist()

            # Empty marker file: its name carries the original file name (see load_embeddings).
            with open(os.path.join(data_path, filename + ".storage_filename"), 'w') as fo:
                fo.write("")

    @staticmethod
    def calculate_md5(input_file: Union[Path, str]):
        """Return the upper-case hexadecimal MD5 digest of the content of ``input_file``."""
        import hashlib
        md5_hash = hashlib.md5()
        with open(input_file, 'rb') as fi:
            # Hash in 1 MiB chunks so large PDFs are never fully held in memory.
            for chunk in iter(lambda: fi.read(1 << 20), b""):
                md5_hash.update(chunk)
        return md5_hash.hexdigest().upper()
| |
|