Spaces:
Sleeping
Sleeping
Keisuke Yamanaka - CNC
DBใใชใผใใฎๅฆ็ใ่ฆ็ดใใฆใฟใใๆๅใใDBใฎใคใณใใใฏในใชใใฆใๅฆ็ใ้ใใใใซใใ
f597d52 | import gradio as gr | |
| from huggingface_hub import InferenceClient | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_openai import ChatOpenAI | |
| #from langchain.document_loaders import UnstructuredExcelLoader | |
| from langchain_community.document_loaders import PyPDFLoader | |
| from langchain_community.document_loaders import UnstructuredURLLoader | |
| from langchain_text_splitters import CharacterTextSplitter | |
| import glob | |
| import base64 | |
| import os | |
| from os.path import split | |
| import time | |
| from langchain_core.messages import HumanMessage | |
| from unstructured.partition.pdf import partition_pdf | |
| import uuid | |
| from langchain.retrievers.multi_vector import MultiVectorRetriever | |
| from langchain.storage import InMemoryStore | |
| #from langchain_chroma import Chroma | |
| from langchain_pinecone import PineconeVectorStore | |
| from pinecone import Pinecone, ServerlessSpec | |
| from langchain_core.documents import Document | |
| from langchain_openai import OpenAIEmbeddings | |
| import io | |
| import re | |
| import glob | |
| #from IPython.display import HTML, display | |
| from langchain_core.runnables import RunnableLambda, RunnablePassthrough | |
| from PIL import Image | |
| """ | |
| For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference | |
| """ | |
# Module-level Hugging Face inference client for the zephyr-7b-beta model.
# It is created at import time; nothing in the visible part of this file
# references it afterwards.
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
| class Damiyan_AI: | |
def __init__(self):
    """Store the Pinecone settings and build the question-answering chain.

    Sets ``PINECONE_INDEX``, ``PINECONE_ENV`` and ``add_files`` on the
    instance, then keeps the chain returned by ``load_QAAI()`` in ``bot``.
    No API key is assigned here; ``initialize_vectorstore`` reads
    ``PINECONE_API_KEY`` from the environment.
    """
    print("Initialing CLASS:Damiyan_AI")
    # Spelling fixed: this was 'PYTHINTRACEMALLOC', which no interpreter
    # reads. The variable is consulted at interpreter start-up, so setting it
    # here only reaches Python child processes, not the running process.
    os.environ["PYTHONTRACEMALLOC"] = "1"
    self.PINECONE_INDEX = "damiyan-ai"
    self.PINECONE_ENV = "gcp-starter"
    # initialize_vectorstore() switches this to True when it has to create
    # the index, which triggers document ingestion in load_QAAI().
    self.add_files = False
    self.bot = self.load_QAAI()
def load_QAAI(self):
    """Build the multi-vector retriever and return the multimodal RAG chain.

    Initialises the vector store. When ``self.add_files`` is true, runs
    ``load_documents`` on every path matched by ``./Doc/*`` and pools the
    per-document results. The retriever is kept on
    ``self.retriever_multi_vector_img``.

    Returns:
        The chain produced by ``multi_modal_rag_chain``.
    """
    text_summaries, texts = [], []
    table_summaries, tables = [], []
    image_summaries, img_base64_list = [], []

    vectorstore = self.initialize_vectorstore(index_name=self.PINECONE_INDEX)

    if self.add_files:
        print("Start to load documents")
        fullpathes = glob.glob('./Doc/*')
        total = len(fullpathes)
        for position, fullpath in enumerate(fullpathes, start=1):
            print(f'{position}/{total}:{fullpath}')
            (
                doc_text_summaries,
                doc_texts,
                doc_table_summaries,
                doc_tables,
                doc_image_summaries,
                doc_img_base64,
            ) = self.load_documents(fullpath)
            text_summaries += doc_text_summaries
            texts += doc_texts
            table_summaries += doc_table_summaries
            tables += doc_tables
            # Fix: these two accumulations used to be crossed, so base64
            # payloads were indexed as summaries and the summaries were
            # stored as the image content.
            image_summaries += doc_image_summaries
            img_base64_list += doc_img_base64

    # Create retriever
    self.retriever_multi_vector_img = self.create_multi_vector_retriever(
        vectorstore,
        text_summaries,
        texts,
        table_summaries,
        tables,
        image_summaries,
        img_base64_list,
    )
    return self.multi_modal_rag_chain(self.retriever_multi_vector_img)
def load_documents(self, fullpath):
    """Extract, chunk and summarise one document.

    Args:
        fullpath: Path of the file to process; it is split into a directory
            and a file name and handed to ``extract_pdf_elements``.

    Returns:
        A 6-tuple ``(text_summaries, text_chunks, table_summaries, tables,
        image_summaries, img_base64_list)``. ``text_chunks`` are the
        re-split chunks that the text summaries were generated from.
    """
    fpath, fname = split(fullpath)
    # extract_pdf_elements() concatenates path + fname, so the directory must
    # end with a separator. A bare file name has an empty directory part;
    # use the current directory instead of ending up with "/".
    fpath = os.path.join(fpath or ".", "")

    # Get elements
    print('Get elements')
    raw_pdf_elements = self.extract_pdf_elements(fpath, fname)

    # Get text, tables
    print('Get text, tables')
    texts, tables = self.categorize_elements(raw_pdf_elements)

    # Optional: Enforce a specific token size for texts
    print('Optional: Enforce a specific token size for texts')
    text_splitter = CharacterTextSplitter.from_tiktoken_encoder(
        chunk_size=4000, chunk_overlap=0
    )
    texts_4k_token = text_splitter.split_text(" ".join(texts))

    # Get text, table summaries
    print('Get text, table summaries')
    text_summaries, table_summaries = self.generate_text_summaries(
        texts_4k_token, tables, summarize_texts=True
    )

    # NOTE(review): images are summarised from the whole directory, which is
    # also where extract_pdf_elements() is told to write them; with several
    # documents in one directory earlier images are presumably summarised
    # again - confirm and consider a per-document image directory.
    print('Image summaries')
    img_base64_list, image_summaries = self.generate_img_summaries(fpath)

    # Fix: return the chunks the summaries were built from. The un-split
    # `texts` list has a different length, which misaligns summary/content
    # pairs (or raises IndexError) in create_multi_vector_retriever().
    return (
        text_summaries,
        texts_4k_token,
        table_summaries,
        tables,
        image_summaries,
        img_base64_list,
    )
| # Extract elements from PDF | |
def extract_pdf_elements(self, path, fname):
    """Run ``partition_pdf`` on ``path + fname`` and return its result.

    Extracts images, tables, and chunked text from a PDF file.

    Args:
        path: Directory prefix (must already end with a separator); also
            passed as the image output directory.
        fname: File name inside ``path``.
    """
    partition_options = {
        "extract_images_in_pdf": True,
        "infer_table_structure": True,
        "chunking_strategy": "by_title",
        "max_characters": 4000,
        "new_after_n_chars": 3800,
        "combine_text_under_n_chars": 2000,
        "image_output_dir_path": path,
    }
    return partition_pdf(filename=path + fname, **partition_options)
| # Categorize elements by type | |
def categorize_elements(self, raw_pdf_elements):
    """Split extracted elements into text strings and table strings.

    An element is classified by looking for a marker inside
    ``str(type(element))``; elements matching neither marker are dropped.

    Args:
        raw_pdf_elements: Iterable of elements returned by the PDF partitioner.

    Returns:
        ``(texts, tables)``, two lists of ``str(element)`` values.
    """
    table_marker = "unstructured.documents.elements.Table"
    text_marker = "unstructured.documents.elements.CompositeElement"
    texts, tables = [], []
    for item in raw_pdf_elements:
        type_repr = str(type(item))
        if table_marker in type_repr:
            tables.append(str(item))
            continue
        if text_marker in type_repr:
            texts.append(str(item))
    return texts, tables
| # Generate summaries of text elements | |
def generate_text_summaries(self, texts, tables, summarize_texts=False):
    """Summarise text chunks and tables with a gpt-4o-mini chain.

    Args:
        texts: List of str.
        tables: List of str.
        summarize_texts: When false, ``texts`` is returned as-is in place of
            text summaries.

    Returns:
        ``(text_summaries, table_summaries)``; each is an empty list when the
        corresponding input is empty.
    """
    # NOTE(review): this literal appears mis-encoded (mojibake) in this copy
    # of the file; it is reproduced unchanged - restore it from the original.
    prompt_text = """ใใชใใฏๆๆธใ่กจใ่ฆ็ดใใใฟในใฏใไธใใใใใขใทในใฟใณใใงใ \
่ฆ็ดใฏๅใ่พผใพใใAIใๅ็ญใใ้ใฎๅ่่ณๆใจใใฆไฝฟใใใพใใ \
AIใฎๅ่่ณๆใจใใฆๆ้ฉใชๅฝขใง่ฆ็ดใใฆใใ ใใใ่กจใๆ็ซ : {element} """
    template = ChatPromptTemplate.from_template(prompt_text)

    # Text summary chain
    llm = ChatOpenAI(temperature=0, model="gpt-4o-mini")
    summarizer = {"element": lambda x: x} | template | llm | StrOutputParser()

    # Texts: summarise on request, otherwise pass them through untouched.
    if not texts:
        text_summaries = []
    elif summarize_texts:
        text_summaries = summarizer.batch(texts, {"max_concurrency": 5})
    else:
        text_summaries = texts

    # Tables are always summarised when present.
    if tables:
        table_summaries = summarizer.batch(tables, {"max_concurrency": 5})
    else:
        table_summaries = []

    for summary in text_summaries:
        print('text_summaries')
        print(summary)
    for summary in table_summaries:
        print('table_summaries')
        print(summary)
    return text_summaries, table_summaries
def encode_image(self, image_path):
    """Read the file at ``image_path`` and return its bytes as a base64 str."""
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
def image_summarize(self, img_base64, prompt):
    """Ask gpt-4o-mini to summarise one image.

    Args:
        img_base64: Base64-encoded image, sent as a ``data:image/jpeg`` URL.
        prompt: Instruction text sent alongside the image.

    Returns:
        The ``content`` of the model's reply.
    """
    # Fix: `self` used to be passed as a positional argument here; the chat
    # model is configured through keyword arguments only.
    chat = ChatOpenAI(model="gpt-4o-mini", max_tokens=1024)
    image_part = {
        "type": "image_url",
        "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"},
    }
    text_part = {"type": "text", "text": prompt}
    msg = chat.invoke([HumanMessage(content=[text_part, image_part])])
    return msg.content
def generate_img_summaries(self, path):
    """Encode and summarise every ``.jpg`` file found directly in ``path``.

    Files are processed in sorted name order.

    Args:
        path: Directory containing the .jpg files extracted by Unstructured.

    Returns:
        ``(img_base64_list, image_summaries)`` as two parallel lists.
    """
    # NOTE(review): this literal appears mis-encoded (mojibake) in this copy
    # of the file; it is reproduced unchanged - restore it from the original.
    prompt = """ใใชใใฏ็ปๅใ่ฆ็ดใใใฟในใฏใไธใใใใใขใทในใฟใณใใงใใ \
่ฆ็ดใฏๅใ่พผใพใใAIใฎๅ็ญใฎๅ่ๆ ๅ ฑใจใใฆไฝฟใใใพใใ. \
ๅ่่ณๆใจใใฆๆ้ฉใช่ฆ็ดใไฝใฃใฆใใ ใใ."""

    img_base64_list = []
    image_summaries = []
    jpg_names = [name for name in sorted(os.listdir(path)) if name.endswith(".jpg")]
    for name in jpg_names:
        encoded = self.encode_image(os.path.join(path, name))
        img_base64_list.append(encoded)
        image_summaries.append(self.image_summarize(encoded, prompt))

    for summary in image_summaries:
        print('image_summarie')
        print(summary)
    return img_base64_list, image_summaries
def create_multi_vector_retriever(
    self, vectorstore, text_summaries, texts, table_summaries, tables, image_summaries, images
):
    """Create a retriever that indexes summaries but returns raw contents.

    Summaries go to ``vectorstore`` and raw contents to a fresh in-memory
    docstore, linked through a generated ``doc_id``. Nothing is added unless
    ``self.add_files`` is true.

    Returns:
        The ``MultiVectorRetriever``.
    """
    id_key = "doc_id"
    retriever = MultiVectorRetriever(
        vectorstore=vectorstore,
        docstore=InMemoryStore(),
        id_key=id_key,
    )

    def add_documents(retriever, doc_summaries, doc_contents):
        # One id per raw content; the summary at position i carries ids[i].
        print('add_documentts---->>>')
        print(doc_summaries)
        ids = [str(uuid.uuid4()) for _ in doc_contents]
        summary_docs = []
        for position, summary in enumerate(doc_summaries):
            summary_docs.append(
                Document(page_content=summary, metadata={id_key: ids[position]})
            )
        retriever.vectorstore.add_documents(summary_docs)
        retriever.docstore.mset(list(zip(ids, doc_contents)))

    if self.add_files == True:
        # Each kind is skipped when it has no summaries.
        pairs = (
            (text_summaries, texts),
            (table_summaries, tables),
            (image_summaries, images),
        )
        for summaries, contents in pairs:
            if summaries:
                add_documents(retriever, summaries, contents)
    return retriever
| # def plt_img_base64(self,img_base64): | |
| # """Disply base64 encoded string as image""" | |
| # # Create an HTML img tag with the base64 string as the source | |
| # image_html = f'<img src="data:image/jpeg;base64,{img_base64}" />' | |
| # # Display the image by rendering the HTML | |
| # display(HTML(image_html)) | |
def looks_like_base64(self, sb):
    """Return True when ``sb`` consists only of base64 alphabet characters.

    Up to two trailing ``=`` padding characters are accepted.
    """
    matched = re.match("^[A-Za-z0-9+/]+[=]{0,2}$", sb)
    if matched is None:
        return False
    return True
def is_image_data(self, b64data):
    """Report whether base64 data decodes to bytes starting with an image signature.

    Only the first 8 decoded bytes are inspected. Any decoding error yields
    ``False``.
    """
    signatures = (
        b"\xff\xd8\xff",  # jpg
        b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a",  # png
        b"\x47\x49\x46\x38",  # gif
        b"\x52\x49\x46\x46",  # webp
    )
    try:
        header = base64.b64decode(b64data)[:8]
    except Exception:
        return False
    return header.startswith(signatures)
def resize_base64_image(self, base64_string, size=(128, 128)):
    """Resize a base64-encoded image and return it base64-encoded again.

    The image is re-saved in the format it was decoded as.

    Args:
        base64_string: Base64-encoded image data.
        size: Target ``(width, height)`` passed to ``Image.resize``.
    """
    source = Image.open(io.BytesIO(base64.b64decode(base64_string)))
    scaled = source.resize(size, Image.LANCZOS)
    out_buffer = io.BytesIO()
    scaled.save(out_buffer, format=source.format)
    return base64.b64encode(out_buffer.getvalue()).decode("utf-8")
def split_image_text_types(self, docs):
    """Separate retrieved items into base64 images and plain texts.

    ``Document`` items are unwrapped to their ``page_content`` first. Items
    recognised as image data are resized to 1300x600 before being collected.

    Returns:
        ``{"images": [...], "texts": [...]}``.
    """
    images, texts = [], []
    for entry in docs:
        payload = entry.page_content if isinstance(entry, Document) else entry
        if self.looks_like_base64(payload) and self.is_image_data(payload):
            images.append(self.resize_base64_image(payload, size=(1300, 600)))
        else:
            texts.append(payload)
    return {"images": images, "texts": texts}
def img_prompt_func(self, data_dict):
    """Build the single multimodal message sent to the model.

    Args:
        data_dict: Mapping with ``"question"`` and ``"context"``; the context
            holds ``"images"`` (base64 strings) and ``"texts"``.

    Returns:
        A one-element list containing a ``HumanMessage`` whose content is the
        image parts followed by one text part.
    """
    context = data_dict["context"]
    formatted_texts = "\n".join(context["texts"])

    # Image parts come first, one per retrieved image.
    messages = [
        {
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{image}"},
        }
        for image in (context["images"] or [])
    ]

    # NOTE(review): the literals below appear mis-encoded (mojibake) in this
    # copy of the file; they are reproduced unchanged - restore them from the
    # original source.
    instruction = (
        "ใใชใใฏใญใใใใณใณใในใใฎไธปๅฌ่ ใงใใใญใใใๆ่กใใณใณใในใใซ้ขใใ่ณช้ใซ็ญใใฆใใ ใใใ\n"
        "ใใชใใซใฏใใใญในใใ่กจใ็ปๅใไธใใใใพใใ\n"
        "ไธใใใใๆ ๅ ฑใไฝฟใฃใฆใใฆใผใถใใใฎ่ณช้ใซ็ญใใฆใใ ใใใ\n"
        "่ณช้ใจใใญในใใ่กจใซ่กจ่จใฎๆบใใใใๅ ดๅใฏใ่กจ่จใฎๆบใใใใใใจใๆณจ่จใใฆใใ ใใใ\n"
        f"ใฆใผใถใใใฎ่ณช้: {data_dict['question']}\n\n"
        "ใใญในใใ่กจ:\n"
        f"{formatted_texts}"
    )
    messages.append({"type": "text", "text": instruction})
    return [HumanMessage(content=messages)]
def multi_modal_rag_chain(self, retriever):
    """Assemble the multimodal RAG chain around ``retriever``.

    The retriever output is split into images/texts and paired with the
    pass-through question, turned into a prompt by ``img_prompt_func``, sent
    to gpt-4o-mini, and parsed to a string.
    """
    # Multi-modal LLM
    llm = ChatOpenAI(temperature=0, model="gpt-4o-mini", max_tokens=1024)

    # RAG pipeline, built stage by stage.
    gather_inputs = {
        "context": retriever | RunnableLambda(self.split_image_text_types),
        "question": RunnablePassthrough(),
    }
    build_prompt = RunnableLambda(self.img_prompt_func)
    return gather_inputs | build_prompt | llm | StrOutputParser()
def initialize_vectorstore(self, index_name):
    """Connect to Pinecone, recreate the index and return a vector store.

    Stores the Pinecone client on ``self.pc``, calls ``delete_documents()``,
    and creates the index when ``index_name`` is not listed afterwards. In
    that case ``self.add_files`` is set to True.

    NOTE(review): the existence check uses ``index_name`` while creation and
    lookup use ``self.PINECONE_INDEX``; the visible caller passes the same
    value for both - confirm before calling with a different name.

    Returns:
        A ``PineconeVectorStore`` bound to the index and embedding model.
    """
    embeddings = self.load_embedding_model(model_name="text-embedding-3-small")
    print(f'loading vectorstore:{index_name}')
    self.pc = Pinecone(api_key=os.environ.get("PINECONE_API_KEY"))

    # NOTE(review): presumably the index is dropped on every start because
    # the raw documents live in an in-memory docstore - confirm. If the
    # deleted index is still listed right afterwards, creation is skipped.
    self.delete_documents()

    listed_names = [info["name"] for info in self.pc.list_indexes()]
    if index_name not in listed_names:
        print(f'Index:{self.PINECONE_INDEX} is not found....')
        print(f'Creating new Index:{self.PINECONE_INDEX}')
        self.add_files = True
        self.pc.create_index(
            name=self.PINECONE_INDEX,
            dimension=1536,
            metric="cosine",
            spec=ServerlessSpec(cloud="aws", region="us-east-1"),
        )
        # Poll once per second until the new index reports ready.
        while True:
            if self.pc.describe_index(self.PINECONE_INDEX).status["ready"]:
                break
            time.sleep(1)
        print(f'Created new Index:{self.PINECONE_INDEX}')
        self.show_index()

    return PineconeVectorStore(
        index=self.pc.Index(self.PINECONE_INDEX),
        embedding=embeddings,
    )
def delete_documents(self):
    """Delete the Pinecone index named ``self.PINECONE_INDEX`` if it exists.

    Does nothing when the index is not listed by the client.
    """
    existing_indexes = [index_info["name"] for index_info in self.pc.list_indexes()]
    if self.PINECONE_INDEX not in existing_indexes:
        return
    print('delete documents.....')
    self.pc.delete_index(self.PINECONE_INDEX)
    # The original ended with a bare `self.show_index` (no call, so a no-op).
    # It is removed rather than called: show_index() queries the index that
    # was just deleted.
def load_embedding_model(self, model_name):
    """Return an ``OpenAIEmbeddings`` instance for ``model_name``."""
    print(f'loading embedding model:{model_name}')
    return OpenAIEmbeddings(model=model_name)
def show_index(self):
    """Wait until the index reports ready, then print its statistics."""
    print(f'detail of Index:{self.PINECONE_INDEX}')
    handle = self.pc.Index(self.PINECONE_INDEX)
    # Poll once per second until the index status says it is ready.
    while True:
        description = self.pc.describe_index(self.PINECONE_INDEX)
        if description.status["ready"]:
            break
        time.sleep(1)
    print(handle.describe_index_stats())
def echo(self, message, history):
    """Chat callback: answer ``message`` and return the reply text.

    Two messages are handled without the chain: "Who are you?" returns a
    fixed introduction, and one fixed keyword deletes the index via
    ``delete_documents()``. Everything else is passed to ``self.bot``.
    ``history`` is accepted but not used.
    """
    # NOTE(review): the non-ASCII literals below appear mis-encoded
    # (mojibake) in this copy of the file; they are reproduced unchanged -
    # restore them from the original source.
    if message == "Who are you?":
        return "็งใฏใใใคใณAIใงใใใฌในใณใณใซ้ขใใ่ณช้ใซ็ญใใพใ"
    if message == 'ใใซใน':
        self.delete_documents()
        return 'ใฌในใณใณใซ้ขใใใใจใๅฟใใพใใ'
    return self.bot.invoke(message)
if __name__ == "__main__":
    # Script entry point: build the assistant, then serve it through a
    # Gradio chat interface.
    print("start")
    damiyan = Damiyan_AI()
    demo = gr.ChatInterface(
        fn=damiyan.echo,
        examples=["Who are you?"],
        title="MELDAS AI",
    )
    demo.launch()