# Damiyan_AI / app.py
# Keisuke Yamanaka - CNC
# DBใƒ‡ใƒชใƒผใƒˆใฎๅ‡ฆ็†ใ‚’่ฆ‹็›ดใ—ใฆใฟใŸใ€‚ๆœ€ๅˆใ‹ใ‚‰DBใฎใ‚คใƒณใƒ‡ใƒƒใ‚ฏใ‚นใชใใฆใ‚‚ๅ‡ฆ็†ใŒ้€šใ‚‹ใ‚ˆใ†ใซใ—ใŸ
# (Reworked the DB-delete handling; processing now succeeds even when the DB index
#  does not exist from the start.)
# commit: f597d52
import gradio as gr
from huggingface_hub import InferenceClient
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
#from langchain.document_loaders import UnstructuredExcelLoader
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.document_loaders import UnstructuredURLLoader
from langchain_text_splitters import CharacterTextSplitter
import glob
import base64
import os
from os.path import split
import time
from langchain_core.messages import HumanMessage
from unstructured.partition.pdf import partition_pdf
import uuid
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import InMemoryStore
#from langchain_chroma import Chroma
from langchain_pinecone import PineconeVectorStore
from pinecone import Pinecone, ServerlessSpec
from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings
import io
import re
import glob
#from IPython.display import HTML, display
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from PIL import Image
"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
class Damiyan_AI:
    """Multimodal RAG chatbot.

    Extracts text, tables and images from PDFs under ./Doc, summarizes each
    modality with OpenAI models, indexes the summaries in Pinecone through a
    MultiVectorRetriever (summaries are embedded; raw content is returned),
    and answers chat questions via a multimodal RAG chain.
    """

    def __init__(self):
        """Connect to Pinecone, (re)build the index, and assemble the RAG chain."""
        print("Initialing CLASS:Damiyan_AI")
        # Enable allocation tracing from interpreter start.
        # BUGFIX: was misspelled 'PYTHINTRACEMALLOC'; the variable CPython
        # actually honors is PYTHONTRACEMALLOC.
        os.environ['PYTHONTRACEMALLOC'] = '1'
        # OPENAI_API_KEY and PINECONE_API_KEY are expected in the environment.
        self.PINECONE_INDEX = "damiyan-ai"
        self.PINECONE_ENV = "gcp-starter"
        # Flipped to True by initialize_vectorstore() when the index had to be
        # (re)created and therefore must be repopulated from ./Doc.
        self.add_files = False
        self.bot = self.load_QAAI()

    def load_QAAI(self):
        """Build (and, when needed, populate) the retriever; return the RAG chain."""
        # Accumulators across all documents.
        text_summaries = []
        texts = []
        table_summaries = []
        tables = []
        img_base64_list = []   # base64-encoded raw images
        image_summaries = []   # LLM summaries of those images
        vectorstore = self.initialize_vectorstore(index_name=self.PINECONE_INDEX)
        if self.add_files:
            print("Start to load documents")
            fullpathes = glob.glob('./Doc/*')
            for i, fullpath in enumerate(fullpathes):
                print(f'{i+1}/{len(fullpathes)}:{fullpath}')
                (doc_text_summaries, doc_texts,
                 doc_table_summaries, doc_tables,
                 doc_image_summaries, doc_img_base64) = self.load_documents(fullpath)
                text_summaries += doc_text_summaries
                texts += doc_texts
                table_summaries += doc_table_summaries
                tables += doc_tables
                # BUGFIX: these two accumulators were swapped — base64 blobs
                # ended up indexed as "summaries" while the textual summaries
                # were returned to the chain as "images".
                image_summaries += doc_image_summaries
                img_base64_list += doc_img_base64
        # Create retriever that indexes summaries but returns raw content.
        self.retriever_multi_vector_img = self.create_multi_vector_retriever(
            vectorstore,
            text_summaries,
            texts,
            table_summaries,
            tables,
            image_summaries,
            img_base64_list,
        )
        return self.multi_modal_rag_chain(self.retriever_multi_vector_img)

    def load_documents(self, fullpath):
        """Extract, chunk and summarize one PDF.

        fullpath: path to the PDF file.
        Returns (text_summaries, texts, table_summaries, tables,
        image_summaries, img_base64_list).
        """
        fpath, fname = split(fullpath)
        fpath += '/'
        print('Get elements')
        raw_pdf_elements = self.extract_pdf_elements(fpath, fname)
        print('Get text, tables')
        texts, tables = self.categorize_elements(raw_pdf_elements)
        # Re-chunk the concatenated text to a ~4k-token window for the LLM.
        print('Optional: Enforce a specific token size for texts')
        text_splitter = CharacterTextSplitter.from_tiktoken_encoder(
            chunk_size=4000, chunk_overlap=0
        )
        joined_texts = " ".join(texts)
        texts_4k_token = text_splitter.split_text(joined_texts)
        print('Get text, table summaries')
        text_summaries, table_summaries = self.generate_text_summaries(
            texts_4k_token, tables, summarize_texts=True
        )
        print('Image summaries')
        # partition_pdf dumps extracted images as .jpg files under fpath.
        img_base64_list, image_summaries = self.generate_img_summaries(fpath)
        return (text_summaries, texts, table_summaries, tables,
                image_summaries, img_base64_list)

    def extract_pdf_elements(self, path, fname):
        """
        Extract images, tables, and chunk text from a PDF file.
        path: File path, which is used to dump images (.jpg)
        fname: File name
        """
        return partition_pdf(
            filename=path + fname,
            extract_images_in_pdf=True,
            infer_table_structure=True,
            chunking_strategy="by_title",
            max_characters=4000,
            new_after_n_chars=3800,
            combine_text_under_n_chars=2000,
            image_output_dir_path=path,
        )

    def categorize_elements(self, raw_pdf_elements):
        """
        Categorize extracted elements from a PDF into tables and texts.
        raw_pdf_elements: List of unstructured.documents.elements
        Returns (texts, tables) as lists of str.
        """
        tables = []
        texts = []
        for element in raw_pdf_elements:
            # Match on the type's qualified name so the unstructured element
            # classes don't have to be imported here.
            type_name = str(type(element))
            if "unstructured.documents.elements.Table" in type_name:
                tables.append(str(element))
            elif "unstructured.documents.elements.CompositeElement" in type_name:
                texts.append(str(element))
        return texts, tables

    def generate_text_summaries(self, texts, tables, summarize_texts=False):
        """
        Summarize text elements
        texts: List of str
        tables: List of str
        summarize_texts: Bool to summarize texts
        """
        # Japanese retrieval-optimized summarization prompt (runtime string,
        # kept verbatim).
        prompt_text = """ใ‚ใชใŸใฏๆ–‡ๆ›ธใ‚„่กจใ‚’่ฆ็ด„ใ™ใ‚‹ใ‚ฟใ‚นใ‚ฏใŒไธŽใˆใ‚‰ใ‚ŒใŸใ‚ขใ‚ทใ‚นใ‚ฟใƒณใƒˆใงใ™ \
่ฆ็ด„ใฏๅŸ‹ใ‚่พผใพใ‚Œใ€AIใŒๅ›ž็ญ”ใ™ใ‚‹้š›ใฎๅ‚่€ƒ่ณ‡ๆ–™ใจใ—ใฆไฝฟใ‚ใ‚Œใพใ™ใ€‚ \
AIใฎๅ‚่€ƒ่ณ‡ๆ–™ใจใ—ใฆๆœ€้ฉใชๅฝขใง่ฆ็ด„ใ—ใฆใใ ใ•ใ„ใ€‚่กจใ‹ๆ–‡็ซ : {element} """
        prompt = ChatPromptTemplate.from_template(prompt_text)
        model = ChatOpenAI(temperature=0, model="gpt-4o-mini")
        summarize_chain = {"element": lambda x: x} | prompt | model | StrOutputParser()
        text_summaries = []
        table_summaries = []
        if texts and summarize_texts:
            text_summaries = summarize_chain.batch(texts, {"max_concurrency": 5})
        elif texts:
            # Summarization not requested: index the raw texts themselves.
            text_summaries = texts
        if tables:
            table_summaries = summarize_chain.batch(tables, {"max_concurrency": 5})
        # Debug output of everything that will be embedded.
        for summary in text_summaries:
            print('text_summaries')
            print(summary)
        for summary in table_summaries:
            print('table_summaries')
            print(summary)
        return text_summaries, table_summaries

    def encode_image(self, image_path):
        """Return the file's contents as a base64 string (utf-8 decoded)."""
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")

    def image_summarize(self, img_base64, prompt):
        """Summarize one base64-encoded image with the vision model."""
        # BUGFIX: `self` was being passed as the first positional argument to
        # ChatOpenAI(), which is not a valid constructor argument and raised
        # at call time.
        chat = ChatOpenAI(model="gpt-4o-mini", max_tokens=1024)
        msg = chat.invoke(
            [
                HumanMessage(
                    content=[
                        {"type": "text", "text": prompt},
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"},
                        },
                    ]
                )
            ]
        )
        return msg.content

    def generate_img_summaries(self, path):
        """
        Generate summaries and base64 encoded strings for images
        path: Path to list of .jpg files extracted by Unstructured
        Returns (img_base64_list, image_summaries).
        """
        img_base64_list = []
        image_summaries = []
        # Japanese image-summarization prompt (runtime string, kept verbatim).
        prompt = """ใ‚ใชใŸใฏ็”ปๅƒใ‚’่ฆ็ด„ใ™ใ‚‹ใ‚ฟใ‚นใ‚ฏใŒไธŽใˆใ‚‰ใ‚ŒใŸใ‚ขใ‚ทใ‚นใ‚ฟใƒณใƒˆใงใ™ใ€‚ \
่ฆ็ด„ใฏๅŸ‹ใ‚่พผใพใ‚Œใ€AIใฎๅ›ž็ญ”ใฎๅ‚่€ƒๆƒ…ๅ ฑใจใ—ใฆไฝฟใ‚ใ‚Œใพใ™ใ€‚. \
ๅ‚่€ƒ่ณ‡ๆ–™ใจใ—ใฆๆœ€้ฉใช่ฆ็ด„ใ‚’ไฝœใฃใฆใใ ใ•ใ„."""
        # Sorted so summaries line up deterministically with their images.
        for img_file in sorted(os.listdir(path)):
            if img_file.endswith(".jpg"):
                img_path = os.path.join(path, img_file)
                base64_image = self.encode_image(img_path)
                img_base64_list.append(base64_image)
                image_summaries.append(self.image_summarize(base64_image, prompt))
        for summary in image_summaries:
            print('image_summarie')
            print(summary)
        return img_base64_list, image_summaries

    def create_multi_vector_retriever(
        self, vectorstore, text_summaries, texts, table_summaries, tables, image_summaries, images
    ):
        """
        Create retriever that indexes summaries, but returns raw images or texts
        """
        # Storage layer for the raw content; summaries go to the vectorstore.
        store = InMemoryStore()
        id_key = "doc_id"
        retriever = MultiVectorRetriever(
            vectorstore=vectorstore,
            docstore=store,
            id_key=id_key,
        )

        def add_documents(retriever, doc_summaries, doc_contents):
            # Embed the summaries; store the raw contents under matching ids.
            print('add_documentts---->>>')
            print(doc_summaries)
            doc_ids = [str(uuid.uuid4()) for _ in doc_contents]
            summary_docs = [
                Document(page_content=s, metadata={id_key: doc_ids[i]})
                for i, s in enumerate(doc_summaries)
            ]
            retriever.vectorstore.add_documents(summary_docs)
            retriever.docstore.mset(list(zip(doc_ids, doc_contents)))

        # Only (re)populate when the index was just created; each modality is
        # added only if there is anything to add.
        if self.add_files:
            if text_summaries:
                add_documents(retriever, text_summaries, texts)
            if table_summaries:
                add_documents(retriever, table_summaries, tables)
            if image_summaries:
                add_documents(retriever, image_summaries, images)
        return retriever

    def looks_like_base64(self, sb):
        """Check if the string looks like base64"""
        return re.match("^[A-Za-z0-9+/]+[=]{0,2}$", sb) is not None

    def is_image_data(self, b64data):
        """
        Check if the base64 data is an image by looking at the start of the data
        """
        # Magic-number signatures for the formats we may have extracted.
        image_signatures = {
            b"\xff\xd8\xff": "jpg",
            b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a": "png",
            b"\x47\x49\x46\x38": "gif",
            b"\x52\x49\x46\x46": "webp",
        }
        try:
            header = base64.b64decode(b64data)[:8]  # Decode and get the first 8 bytes
            return any(header.startswith(sig) for sig in image_signatures)
        except Exception:
            # Undecodable input is simply "not an image".
            return False

    def resize_base64_image(self, base64_string, size=(128, 128)):
        """
        Resize an image encoded as a Base64 string
        """
        img_data = base64.b64decode(base64_string)
        img = Image.open(io.BytesIO(img_data))
        resized_img = img.resize(size, Image.LANCZOS)
        buffered = io.BytesIO()
        # Preserve the original format when re-encoding.
        resized_img.save(buffered, format=img.format)
        return base64.b64encode(buffered.getvalue()).decode("utf-8")

    def split_image_text_types(self, docs):
        """
        Split base64-encoded images and texts
        Returns {"images": [...], "texts": [...]}.
        """
        b64_images = []
        texts = []
        for doc in docs:
            # Unwrap Document objects to their raw page content.
            if isinstance(doc, Document):
                doc = doc.page_content
            if self.looks_like_base64(doc) and self.is_image_data(doc):
                doc = self.resize_base64_image(doc, size=(1300, 600))
                b64_images.append(doc)
            else:
                texts.append(doc)
        return {"images": b64_images, "texts": texts}

    def img_prompt_func(self, data_dict):
        """
        Build the multimodal HumanMessage from retrieved context + question.
        """
        formatted_texts = "\n".join(data_dict["context"]["texts"])
        messages = []
        # Attach each retrieved image first.
        for image in data_dict["context"]["images"]:
            messages.append({
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{image}"},
            })
        # Japanese system/task instructions (runtime strings, kept verbatim).
        text_message = {
            "type": "text",
            "text": (
                "ใ‚ใชใŸใฏใƒญใƒœใƒƒใƒˆใ‚ณใƒณใƒ†ใ‚นใƒˆใฎไธปๅ‚ฌ่€…ใงใ™ใ€‚ใƒญใƒœใƒƒใƒˆๆŠ€่ก“ใ‚„ใ‚ณใƒณใƒ†ใ‚นใƒˆใซ้–ขใ™ใ‚‹่ณช้–€ใซ็ญ”ใˆใฆใใ ใ•ใ„ใ€‚\n"
                "ใ‚ใชใŸใซใฏใ€ใƒ†ใ‚ญใ‚นใƒˆใ‚„่กจใ€็”ปๅƒใŒไธŽใˆใ‚‰ใ‚Œใพใ™ใ€‚\n"
                "ไธŽใˆใ‚‰ใ‚ŒใŸๆƒ…ๅ ฑใ‚’ไฝฟใฃใฆใ€ใƒฆใƒผใ‚ถใ‹ใ‚‰ใฎ่ณช้–€ใซ็ญ”ใˆใฆใใ ใ•ใ„ใ€‚\n"
                "่ณช้–€ใจใƒ†ใ‚ญใ‚นใƒˆใ€่กจใซ่กจ่จ˜ใฎๆบใ‚ŒใŒใ‚ใ‚‹ๅ ดๅˆใฏใ€่กจ่จ˜ใฎๆบใ‚ŒใŒใ‚ใ‚‹ใ“ใจใ‚’ๆณจ่จ˜ใ—ใฆใใ ใ•ใ„ใ€‚\n"
                f"ใƒฆใƒผใ‚ถใ‹ใ‚‰ใฎ่ณช้–€: {data_dict['question']}\n\n"
                "ใƒ†ใ‚ญใ‚นใƒˆใ‚„่กจ:\n"
                f"{formatted_texts}"
            ),
        }
        messages.append(text_message)
        return [HumanMessage(content=messages)]

    def multi_modal_rag_chain(self, retriever):
        """
        Multi-modal RAG chain: retrieve -> split into text/images -> prompt -> LLM.
        """
        model = ChatOpenAI(temperature=0, model="gpt-4o-mini", max_tokens=1024)
        chain = (
            {
                "context": retriever | RunnableLambda(self.split_image_text_types),
                "question": RunnablePassthrough(),
            }
            | RunnableLambda(self.img_prompt_func)
            | model
            | StrOutputParser()
        )
        return chain

    def initialize_vectorstore(self, index_name):
        """Connect to Pinecone, drop any stale index, and return a fresh vectorstore.

        Sets self.add_files = True when the index had to be created, signalling
        load_QAAI() to repopulate it from ./Doc.
        """
        model_name = "text-embedding-3-small"
        embeddings = self.load_embedding_model(model_name=model_name)
        print(f'loading vectorstore:{index_name}')
        self.pc = Pinecone(api_key=os.environ.get("PINECONE_API_KEY"))
        # Intentional: the index is rebuilt from scratch on every startup
        # (see file header note); delete_documents() is a no-op when the index
        # does not exist yet.
        self.delete_documents()
        existing_indexes = [index_info["name"] for index_info in self.pc.list_indexes()]
        if index_name not in existing_indexes:
            print(f'Index:{self.PINECONE_INDEX} is not found....')
            print(f'Creating new Index:{self.PINECONE_INDEX}')
            self.add_files = True
            # 1536 dims matches text-embedding-3-small's default output size.
            self.pc.create_index(
                name=self.PINECONE_INDEX,
                dimension=1536,
                metric="cosine",
                spec=ServerlessSpec(cloud="aws", region="us-east-1"),
            )
            while not self.pc.describe_index(self.PINECONE_INDEX).status["ready"]:
                time.sleep(1)
            print(f'Created new Index:{self.PINECONE_INDEX}')
            self.show_index()
        index = self.pc.Index(self.PINECONE_INDEX)
        return PineconeVectorStore(index=index, embedding=embeddings)

    def delete_documents(self):
        """Drop the Pinecone index if it exists (forgets all indexed documents)."""
        existing_indexes = [index_info["name"] for index_info in self.pc.list_indexes()]
        if self.PINECONE_INDEX in existing_indexes:
            print(f'delete documents.....')
            self.pc.delete_index(self.PINECONE_INDEX)
            # BUGFIX: was `self.show_index` without parentheses — a no-op
            # attribute access instead of the intended diagnostic call.
            self.show_index()

    def load_embedding_model(self, model_name):
        """Return an OpenAIEmbeddings instance for the given model name."""
        print(f'loading embedding model:{model_name}')
        return OpenAIEmbeddings(
            model=model_name,
        )

    def show_index(self):
        """Print index stats once the index reports ready."""
        print(f'detail of Index:{self.PINECONE_INDEX}')
        index = self.pc.Index(self.PINECONE_INDEX)
        while not self.pc.describe_index(self.PINECONE_INDEX).status["ready"]:
            time.sleep(1)
        print(index.describe_index_stats())

    def echo(self, message, history):
        """Gradio chat handler: a couple of fixed commands, otherwise the RAG chain."""
        if message == "Who are you?":
            ans = "็งใฏใƒ€ใƒŸใƒคใƒณAIใงใ™ใ€‚ใƒฌใ‚นใ‚ณใƒณใซ้–ขใ™ใ‚‹่ณช้–€ใซ็ญ”ใˆใพใ™"
        elif message == 'ใƒใƒซใ‚น':
            # "ใƒใƒซใ‚น" (Halth/reset command): wipe the Pinecone index.
            self.delete_documents()
            ans = 'ใƒฌใ‚นใ‚ณใƒณใซ้–ขใ™ใ‚‹ใ“ใจใ‚’ๅฟ˜ใ‚Œใพใ—ใŸ'
        else:
            ans = self.bot.invoke(message)
        return ans
if __name__ == "__main__":
print("start")
damiyan = Damiyan_AI()
demo = gr.ChatInterface(fn=damiyan.echo, examples=["Who are you?"], title="MELDAS AI")
demo.launch()