sunny333 commited on
Commit
568cd7b
·
1 Parent(s): 731be0b

initial commit

Browse files
RAG_MLM/__init__.py ADDED
File without changes
RAG_MLM/app.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import random
4
+ import embedder as eb
5
+ import utility as ut
6
+ import ragMLM as rag
7
+ from PIL import Image
8
+ import base64
9
+ from io import BytesIO
10
#----image utility-----
def plt_img_base64(img_base64):
    """Turn a base64-encoded image string into a PIL Image object."""
    raw_bytes = base64.b64decode(img_base64)   # base64 text -> raw bytes
    buffer = BytesIO(raw_bytes)                # bytes -> file-like object
    return Image.open(buffer)                  # file-like -> PIL image
20
+
21
# Dummy text generation function
def generate_text(input_text):
    """Placeholder generator: echo the input prefixed with 'Echo: '."""
    return "Echo: {}".format(input_text)
24
+
25
# Dummy multiple images generation
def generate_images(n,imgList):
    """Return n random 200x200 placeholder images followed by the decoded
    base64 images from imgList (PIL Image objects)."""
    images = []
    # n solid-colour placeholders with random RGB values (demo filler).
    for _ in range(n):
        img = Image.new('RGB', (200, 200), color=(random.randint(0,255), random.randint(0,255), random.randint(0,255)))
        images.append(img)
    # Decode each base64 string retrieved as context into a PIL image.
    for item in imgList:
        img = plt_img_base64(item)
        images.append(img)
    return images
35
+
36
+
37
# The function Gradio will call
def process_input(user_input):
    """Run the multimodal RAG chain for user_input.

    Returns (answer text, cleaned source text, list of PIL images).
    """
    #------calling llm------
    #docs = eb.retriever_multi_vector.invoke(user_input, limit=5)
    #r = ut.split_image_text_types(docs)
    response = rag.multimodal_rag_w_sources.invoke({'input': user_input})
    # Retrieved text chunks used as context, cleaned up for display.
    text_sources = response['context']['texts']
    text_sources = ut.beautify_output(text_sources)
    text_answer = response['answer']
    #text_answer = ut.beautify_output(text_answer)
    # Base64-encoded images retrieved as context.
    img_sources = response['context']['images']
    #---------end-----------

    #text_response = generate_text(user_input)
    # One random placeholder image is prepended before the retrieved images.
    image_responses = generate_images(1,img_sources)
    return text_answer,text_sources, image_responses
53
# Define Gradio interface
# Single-page UI: one query box in; answer text, context and image gallery out.
iface = gr.Interface(
    fn=process_input,
    inputs=gr.Textbox(lines=2, placeholder="Enter your query here..."),
    outputs=[
        gr.Textbox(label="Response Text"),
        gr.Textbox(label="Context"),
        gr.Gallery(label="Response Images", columns=[3], height="auto")
    ],
    title="Text to Text + Multiple Images Demo",
    description="Enter a query and get text plus multiple images!"
)

iface.launch()
RAG_MLM/differentiator.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import pickle
import htmltabletomd
import os

# Raw Unstructured elements produced by extractor.py (data.pkl).
# NOTE: pickle.load is only acceptable because this artifact is generated
# locally by our own pipeline; never unpickle untrusted files.
data = ""
with open('data.pkl', 'rb') as f:
    data = pickle.load(f)
7
+
8
+
9
def differentiate_table_text():
    """Split extracted elements into text chunks and tables, convert table
    HTML to Markdown, and pickle both lists for the summarisation stage.

    Reads the module-level `data` list; writes RAG_MLM/docs.pkl and
    RAG_MLM/table.pkl.
    """
    docs = []
    tables = []
    for doc in data:
        # Unstructured tags each element with a category string.
        if doc.metadata['category'] == 'Table':
            tables.append(doc)
        elif doc.metadata['category'] == 'CompositeElement':
            docs.append(doc)
    # Replace each table's raw content with a Markdown rendering of its HTML;
    # Markdown embeds better for retrieval than raw HTML.
    for table in tables:
        table.page_content = htmltabletomd.convert_table(table.metadata['text_as_html'])
    print(f"length of docs {len(docs)}, length of tables {len(tables)}")

    with open('RAG_MLM/docs.pkl', 'wb') as f:
        pickle.dump(docs, f)

    with open('RAG_MLM/table.pkl', 'wb') as f:
        pickle.dump(tables, f)
26
+
27
# call this for differentiator
# Skip the split when the output artifact already exists (idempotent re-runs).
file_path="RAG_MLM/docs.pkl"
if os.path.exists(file_path):
    print(f"✅ File '{file_path}' found")
else:
    print(">>>>>>> generating: differentiating tables and text")
    differentiate_table_text()
34
+
RAG_MLM/embedder.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import uuid
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain_community.storage import RedisStore
from langchain_community.utilities.redis import get_client
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings
import pickle
import redis
import os

from dotenv import load_dotenv

# Load environment variables from a local .env file (OPENAI_API_KEY, etc.).
load_dotenv()

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
# SECURITY fix: never print the raw key (it ends up in logs);
# only report whether it is configured.
print("OPENAI_API_KEY configured:", bool(OPENAI_API_KEY))
openai_embed_model = OpenAIEmbeddings(model='text-embedding-3-small')
20
+
21
#-----remote redis------
# SECURITY fix: the Upstash host and password were hard-coded here and are
# now part of the repository history -- rotate those credentials.
# The connection string must come from the environment instead, e.g.
#   REDIS_URL="rediss://:<password>@<host>:6379"
# (Also removed a dead `redis.Redis(...)` client that was created and then
# immediately overwritten by `redis.from_url`.)
redis_url = os.getenv("REDIS_URL")
if not redis_url:
    raise RuntimeError("REDIS_URL environment variable is not set")
r = redis.from_url(redis_url)
redis_store = RedisStore(client=r)
#------

#-----local redis------
#client = get_client('redis://localhost:6379')
#redis_store = RedisStore(client=client)
#-----------------------------------
37
+
38
+
39
#-------pickle loading-----------
# NOTE: pickle is only safe on these locally generated artifacts.
def _load_pickle(path):
    """Load one pickled artifact produced by the earlier pipeline stages."""
    with open(path, 'rb') as fh:
        return pickle.load(fh)

text_summaries = _load_pickle('RAG_MLM/text_summaries.pkl')
text_docs = _load_pickle('RAG_MLM/docs.pkl')

table_summaries = _load_pickle('RAG_MLM/table_summaries.pkl')
table_docs = _load_pickle('RAG_MLM/table.pkl')

image_summaries = _load_pickle('RAG_MLM/image_summaries.pkl')
imgs_base64 = _load_pickle('RAG_MLM/img_base64_list.pkl')
#--------------------------
61
+
62
def create_multi_vector_retriever(
    docstore, vectorstore, text_summaries, texts, table_summaries, tables,
    image_summaries, images
):
    """
    Create retriever that indexes summaries, but returns raw images or texts.

    The vectorstore indexes the short summaries (one Document per summary,
    linked by a uuid under `doc_id`); the docstore holds the raw content
    keyed by the same uuid, so retrieval returns the original element.
    Any of the three (texts/tables/images) groups may be empty.
    """
    id_key = "doc_id"

    # Create the multi-vector retriever
    retriever = MultiVectorRetriever(
        vectorstore=vectorstore,
        docstore=docstore,
        id_key=id_key,
    )

    # Helper function to add documents to the vectorstore and docstore
    def add_documents(retriever, doc_summaries, doc_contents):
        # One shared uuid per (summary, raw content) pair.
        doc_ids = [str(uuid.uuid4()) for _ in doc_contents]
        summary_docs = [
            Document(page_content=s, metadata={id_key: doc_ids[i]})
            for i, s in enumerate(doc_summaries)
        ]
        retriever.vectorstore.add_documents(summary_docs)
        # Store plain strings in the docstore (unwrap Document objects).
        raw_contents = [doc.page_content if isinstance(doc, Document) else doc for doc in doc_contents]

        retriever.docstore.mset(list(zip(doc_ids, raw_contents)))

    # Add texts, tables, and images
    # Check that text_summaries is not empty before adding
    if text_summaries:
        add_documents(retriever, text_summaries, texts)

    # Check that table_summaries is not empty before adding
    if table_summaries:
        add_documents(retriever, table_summaries, tables)

    # Check that image_summaries is not empty before adding
    if image_summaries:
        add_documents(retriever, image_summaries, images)
    return retriever
103
+
104
# Vector store for the summary embeddings, using cosine distance.
# NOTE(review): no persist_directory is configured, so this collection is
# presumably ephemeral and re-embedded on every process start -- confirm.
chroma_db = Chroma(
    collection_name="mm_rag",
    embedding_function=openai_embed_model,
    collection_metadata={"hnsw:space": "cosine"},
)


# Create retriever
# Builds the multi-vector retriever at import time from the pickled
# summaries/raw elements loaded above; Redis holds the raw content.
retriever_multi_vector = create_multi_vector_retriever(
    redis_store, chroma_db,
    text_summaries, text_docs,
    table_summaries, table_docs,
    image_summaries, imgs_base64,
)
120
+
121
#------utility------

from PIL import Image
import base64
from io import BytesIO

def plt_img_base64(img_base64):
    """Decode a base64-encoded string and return it as a PIL Image.

    Bug fix: the original decoded the image but never returned (or displayed)
    it, so every caller received None.
    """
    # Decode the base64 string into raw bytes.
    img_data = base64.b64decode(img_base64)
    # Wrap the bytes in an in-memory buffer PIL can read from.
    img_buffer = BytesIO(img_data)
    # Open the image using PIL.
    img = Image.open(img_buffer)
    return img

# Check retrieval-----uncomment to check diretly ----
#query = "tell me about free body diagram"
#docs = retriever_multi_vector.invoke(query, limit=5)
# We get 3 relevant docs
#print(">>>>>>>>",len(docs))
#print(docs[0])
RAG_MLM/extractor.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from langchain_community.document_loaders import UnstructuredPDFLoader
import os
import pickle

# Source PDF for the whole pipeline.
doc = 'data/filteredData.pdf'
6
+
7
def extractor_text_image_table():
    """Parse the source PDF with Unstructured (hi_res strategy), extracting
    text chunks, embedded images and tables, and pickle the resulting
    elements to data.pkl for the differentiator stage."""
    loader = UnstructuredPDFLoader(file_path=doc,
                                   strategy='hi_res',
                                   extract_images_in_pdf=True,
                                   infer_table_structure=True,
                                   # section-based chunking
                                   chunking_strategy="by_title",
                                   max_characters=4000,  # max size of chunks
                                   new_after_n_chars=4000,  # preferred size of chunks
                                   # smaller chunks < 2000 chars will be combined into a larger chunk
                                   combine_text_under_n_chars=2000,
                                   mode='elements',
                                   image_output_dir_path='./figures')
    data = loader.load()
    print_retrived_data(data)
    with open('data.pkl', 'wb') as f:
        pickle.dump(data, f)
24
+
25
+
26
def print_retrived_data(data):
    """Log the category of every extracted element (debug aid).

    NOTE(review): the name keeps the original 'retrived' spelling because
    other modules call it by this name.
    """
    categories = [element.metadata['category'] for element in data]
    print(">>>>>>>>>>>>>>data retrived>>>>>>>>")
    print(categories)
    print(">>>>>>>>>>>>>>end -- data retrived>>>>>>>>")
30
+
31
+
32
# call this to extract images
# Extraction with hi_res OCR is slow, so it is skipped when data.pkl exists.
file_path="data.pkl"
if os.path.exists(file_path):
    print(f"✅ File '{file_path}' found")
else:
    print(">>>>>>>> generating: extracting text images tables >>>>>")
    extractor_text_image_table()
RAG_MLM/main.yaml ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ extractor
2
+ |
3
+ differentiator
4
+ |
5
+ summary
6
+ |
7
+ embedder
8
+ |
9
+ ragMLM
10
+ |
11
+ app
RAG_MLM/ragMLM.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from operator import itemgetter
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_core.messages import HumanMessage
from . import utility as ut
from . import embedder as ed
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI
import os
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
# SECURITY fix: confirm presence only; never print the secret itself.
print("OPENAI_API_KEY configured:", bool(OPENAI_API_KEY))

# Deterministic (temperature=0) GPT-4o chat model used by the RAG chain.
chatgpt = ChatOpenAI(model_name='gpt-4o', temperature=0)
18
def multimodal_prompt_function(data_dict):
    """
    Create a multimodal prompt with both text and image context.
    This function formats the provided context from `data_dict`, which contains
    text, tables, and base64-encoded images. It joins the text (with table) portions
    and prepares the image(s) in a base64-encoded format to be included in a
    message.
    The formatted text and images (context) along with the user question are used to
    construct a prompt for GPT-4o.

    Expects data_dict to have keys 'context' (with 'texts' and 'images' lists)
    and 'question'. Returns a single-element list with one HumanMessage whose
    content mixes image_url and text parts (OpenAI multimodal format).
    """
    formatted_texts = "\n".join(data_dict["context"]["texts"])
    messages = []

    # Adding image(s) to the messages if present
    if data_dict["context"]["images"]:
        for image in data_dict["context"]["images"]:
            # NOTE(review): assumes every image is JPEG-encoded -- confirm
            # against what the extractor writes to ./figures.
            image_message = {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{image}"},
            }
            messages.append(image_message)

    # Adding the text for analysis
    text_message = {
        "type": "text",
        "text": (
            f"""You are an analyst tasked with understanding detailed information
            and trends from text documents,
            data tables, and charts and graphs in images.
            You will be given context information below which will be a mix of
            text, tables, and images usually of charts or graphs.
            Use this information to provide answers related to the user
            question.
            Do not make up answers, use the provided context documents below and
            answer the question to the best of your ability.

            User question:
            {data_dict['question']}

            Context documents:
            {formatted_texts}

            Answer:
            """
        ),
    }
    messages.append(text_message)
    return [HumanMessage(content=messages)]
66
+
67
# Create RAG chain
# context/question dict -> multimodal prompt -> GPT-4o -> plain answer string.
multimodal_rag = (
    {
        "context": itemgetter('context'),
        "question": itemgetter('input'),
    }
    |
    RunnableLambda(multimodal_prompt_function)
    |
    chatgpt
    |
    StrOutputParser()
)

# Pass input query to retriever and get context document elements,
# then split the retrieved docs into base64 images vs texts/tables.
retrieve_docs = (itemgetter('input')
                 |
                 ed.retriever_multi_vector
                 |
                 RunnableLambda(ut.split_image_text_types))

# Below, we chain `.assign` calls. This takes a dict and successively
# adds keys-- "context" and "answer"-- where the value for each key
# is determined by a Runnable (function or chain executing at runtime).
# This helps in having the retrieved context along with the answer generated by GPT-4o
multimodal_rag_w_sources = (RunnablePassthrough.assign(context=retrieve_docs)
                            .assign(answer=multimodal_rag)
                            )


#------ direct testing-------
#response = multimodal_rag_w_sources.invoke({'input': query})
RAG_MLM/summary.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_core.runnables import RunnablePassthrough
import base64
import os
from langchain_core.messages import HumanMessage
import pickle
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
# SECURITY fix: confirm presence only; never print the secret itself.
# (Also removed duplicate `import os` / `from langchain_openai import ChatOpenAI`.)
print("OPENAI_API_KEY configured:", bool(OPENAI_API_KEY))

# Deterministic GPT-4o model shared by the summarisation chains.
chatgpt = ChatOpenAI(model_name='gpt-4o', temperature=0)
docs = []
tables = []

# Artifacts produced by differentiator.py.
with open('RAG_MLM/docs.pkl', 'rb') as f:
    docs = pickle.load(f)

with open('RAG_MLM/table.pkl', 'rb') as f:
    tables = pickle.load(f)
28
+
29
def summarize_all():
    """Summarise every text chunk and table with GPT-4o for retrieval.

    Uses the module-level `docs`/`tables` lists; returns
    (text_summaries, table_summaries) as lists of strings.
    """
    # Prompt
    prompt_text = """
    You are an assistant tasked with summarizing tables and text particularly for semantic retrieval.
    These summaries will be embedded and used to retrieve the raw text or table elements
    Give a detailed summary of the table or text below that is well optimized for retrieval.
    For any tables also add in a one line description of what the table is about besides the summary.
    Do not add additional words like Summary: etc.
    Table or text chunk:
    {element}
    """
    prompt = ChatPromptTemplate.from_template(prompt_text)

    # Summary chain
    summarize_chain = (
        {"element": RunnablePassthrough()}
        |
        prompt
        |
        chatgpt
        |
        StrOutputParser()  # extracts response as text
    )

    # Initialize empty summaries
    text_summaries = []
    table_summaries = []

    text_docs = [doc.page_content for doc in docs]
    table_docs = [table.page_content for table in tables]

    # Batch with limited concurrency to stay under API rate limits.
    text_summaries = summarize_chain.batch(text_docs, {"max_concurrency": 5})
    table_summaries = summarize_chain.batch(table_docs, {"max_concurrency": 5})
    # NOTE(review): debug print; raises IndexError if fewer than 2 text docs.
    print(text_summaries[1])
    return text_summaries,table_summaries
64
+
65
def encode_image(image_path):
    """Read the file at *image_path* and return its base64 encoding as str."""
    with open(image_path, "rb") as image_file:
        raw = image_file.read()
    return base64.b64encode(raw).decode("utf-8")
69
# create a function to summarize the image by passing a prompt to GPT-4o
def image_summarize(img_base64, prompt):
    """Make image summary.

    Sends the base64 image (as a data URL) plus the text prompt to GPT-4o
    and returns the model's reply string.
    """
    chat = ChatOpenAI(model="gpt-4o", temperature=0)
    msg = chat.invoke(
        [
            HumanMessage(
                content=[
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        "image_url": {"url":
                                      f"data:image/jpeg;base64,{img_base64}"},
                    },
                ]
            )
        ]
    )
    return msg.content
88
+
89
def generate_img_summaries(path):
    """
    Generate summaries and base64 encoded strings for images.

    path: Path to a directory of .jpg files extracted by Unstructured.
    Returns (img_base64_list, image_summaries), index-aligned lists.
    """
    # Store base64 encoded images
    img_base64_list = []
    # Store image summaries
    image_summaries = []

    # Prompt
    prompt = """You are an assistant tasked with summarizing images for retrieval.
    Remember these images could potentially contain graphs, charts or
    tables also.
    These summaries will be embedded and used to retrieve the raw image
    for question answering.
    Give a detailed summary of the image that is well optimized for
    retrieval.
    Do not add additional words like Summary: etc.
    """

    # Apply to images (sorted for a deterministic order).
    for img_file in sorted(os.listdir(path)):
        if img_file.endswith(".jpg"):
            img_path = os.path.join(path, img_file)
            base64_image = encode_image(img_path)
            img_base64_list.append(base64_image)
            # One GPT-4o call per image.
            image_summaries.append(image_summarize(base64_image, prompt))
    return img_base64_list, image_summaries
118
+
119
def save_summary():
    """Summarise images, texts and tables and pickle all artifacts that
    embedder.py expects (img_base64_list, image_summaries, text_summaries,
    table_summaries)."""
    path = './figures'
    img_base64_list, image_summaries = generate_img_summaries(path)
    with open('RAG_MLM/img_base64_list.pkl', 'wb') as f:
        pickle.dump(img_base64_list, f)
    with open('RAG_MLM/image_summaries.pkl', 'wb') as f:
        pickle.dump(image_summaries, f)
    text_summaries,table_summaries = summarize_all()
    with open('RAG_MLM/text_summaries.pkl', 'wb') as f:
        pickle.dump(text_summaries, f)
    with open('RAG_MLM/table_summaries.pkl', 'wb') as f:
        pickle.dump(table_summaries, f)


# call this to save summary----
# Summarisation costs API calls, so skip it when the artifact exists.
file_path="RAG_MLM/text_summaries.pkl"
if os.path.exists(file_path):
    print(f"✅ File '{file_path}' found")
else:
    save_summary()
RAG_MLM/utility.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import base64
3
+ from langchain_core.documents import Document
4
+
5
# helps in detecting base64 encoded strings
def looks_like_base64(sb):
    """Return True when *sb* consists solely of base64-alphabet characters,
    optionally ending in up to two '=' padding characters."""
    return bool(re.match("^[A-Za-z0-9+/]+[=]{0,2}$", sb))
9
+
10
# helps in checking if the base64 encoded image is actually an image
def is_image_data(b64data):
    """
    Check whether *b64data* decodes to bytes starting with a known image
    file signature (magic number). Returns False on any decode error.
    """
    # Magic-number prefixes for the accepted formats.
    # NOTE(review): b"RIFF" is the generic RIFF container header, so this
    # also matches e.g. WAV/AVI payloads, not only WebP.
    signatures = (
        (b"\xff\xd8\xff", "jpg"),
        (b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a", "png"),
        (b"\x47\x49\x46\x38", "gif"),
        (b"\x52\x49\x46\x46", "webp"),
    )
    try:
        header = base64.b64decode(b64data)[:8]  # first 8 decoded bytes
    except Exception:
        return False
    return any(header.startswith(sig) for sig, _fmt in signatures)
29
+
30
# returns a dictionary separating images and text (with table) elements
def split_image_text_types(docs):
    """
    Split base64-encoded images and texts (with tables).

    Returns {"images": [...base64 strings...], "texts": [...strings...]}.

    NOTE(review): assumes each doc (or doc.page_content) is *bytes* --
    e.g. raw values returned from the Redis docstore; calling .decode on a
    plain str would raise AttributeError. Confirm against the retriever.
    """
    b64_images = []
    texts = []
    for doc in docs:
        # Check if the document is of type Document and extract page_content if so
        if isinstance(doc, Document):
            doc = doc.page_content.decode('utf-8')
        else:
            doc = doc.decode('utf-8')
        # Only strings that both look like base64 AND decode to image magic
        # bytes are treated as images; everything else is text/table context.
        if looks_like_base64(doc) and is_image_data(doc):
            b64_images.append(doc)
        else:
            texts.append(doc)
    return {"images": b64_images, "texts": texts}
48
+
49
def beautify_output(text_list):
    """Join retrieved text chunks and normalise them into readable,
    paragraph-separated sentences."""
    # Merge all chunks into one working string.
    combined = " ".join(text_list)

    # Clean-up passes, applied in order: (pattern, replacement).
    passes = (
        (r'\[\|\<\s*\d*\s*', ''),       # strip "[|<" artifacts and trailing digits
        (r'\n+', '\n\n'),               # collapse newline runs
        (r'\s+', ' '),                  # collapse all whitespace to single spaces
        (r'([.!?])\s*', r'\1\n\n'),     # paragraph break after sentence enders
        (r'(\n\n)\d+(\n\n)', r'\1'),    # drop stray standalone numbers
    )
    for pattern, repl in passes:
        combined = re.sub(pattern, repl, combined)

    # Strip leading/trailing whitespace left by the substitutions.
    return combined.strip()
app.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import random
4
+ from RAG_MLM import extractor as ex
5
+ from RAG_MLM import differentiator as dif
6
+ from RAG_MLM import summary as sm
7
+ from RAG_MLM import embedder as eb
8
+ from RAG_MLM import ragMLM as rag
9
+ from RAG_MLM import utility as ut
10
+ from PIL import Image
11
+ import base64
12
+ from io import BytesIO
13
+ import os
14
+
15
#----image utility-----
def plt_img_base64(img_base64):
    """Decode a base64-encoded string into a PIL Image.

    NOTE(review): duplicated in RAG_MLM/app.py and RAG_MLM/embedder.py;
    consider consolidating in RAG_MLM/utility.py.
    """
    # Decode the base64 string
    img_data = base64.b64decode(img_base64)
    # Create a BytesIO object
    img_buffer = BytesIO(img_data)
    # Open the image using PIL
    img = Image.open(img_buffer)
    return img
25
+
26
# Dummy text generation function
def generate_text(input_text):
    """Placeholder generator: echo the input prefixed with 'Echo: '."""
    return "Echo: {}".format(input_text)
29
+
30
# Dummy multiple images generation
def generate_images(n,imgList):
    """Return n random 200x200 placeholder images followed by the decoded
    base64 images from imgList (PIL Image objects)."""
    images = []
    # n solid-colour placeholders with random RGB values (demo filler).
    for _ in range(n):
        img = Image.new('RGB', (200, 200), color=(random.randint(0,255), random.randint(0,255), random.randint(0,255)))
        images.append(img)
    # Decode each base64 string retrieved as context into a PIL image.
    for item in imgList:
        img = plt_img_base64(item)
        images.append(img)
    return images
40
+
41
+
42
# Main processing function
# NOTE(review): dead code -- this placeholder is shadowed by the second
# `process_input` defined further down in this file before the UI is built.
def process_input(query):
    """Placeholder handler returning canned text/context/images."""
    response_text = f"Processed: {query}"
    context = "This is some dummy context."
    images = [["https://via.placeholder.com/150", "https://via.placeholder.com/150"]]
    return response_text, context, images
48
+
49
# Wrapper that runs the full offline pipeline: extract -> differentiate -> summarize.
def utility_function_wrapper(input_text=None):
    """Regenerate the pickled artifacts consumed by the RAG chain.

    input_text: accepted for Gradio wiring but currently unused; it now
    defaults to None so the handler also works when no input is passed.
    Returns a status string for the UI.
    """
    ex.extractor_text_image_table()
    # Bug fix: the original referenced the function without calling it
    # (missing parentheses), so the differentiation step silently never ran.
    dif.differentiate_table_text()
    sm.save_summary()
    # Typo fix: "sucess" -> "success".
    return "success:- generated files"
55
+
56
# Store the OpenAI API key for this process.
def save_api_key(api_key):
    """Persist the user-supplied API key into the process environment.

    SECURITY fix: the key is no longer printed to stdout/logs in clear text.
    Returns a status string for the UI.
    """
    os.environ["OPENAI_API_KEY"] = api_key
    return "✅ API Key saved in environment successfully!"
62
+
63
# Remove the OpenAI API key from this process's environment.
def clear_api_key():
    """Drop OPENAI_API_KEY from os.environ and report what happened."""
    if os.environ.pop("OPENAI_API_KEY", None) is not None:
        return "❌ API Key cleared from environment!"
    return "⚠️ No API Key found to clear."
70
# The function Gradio will call
def process_input(user_input):
    """Run the multimodal RAG chain for user_input.

    Returns (answer text, cleaned source text, list of PIL images).
    """
    #------calling llm------
    #docs = eb.retriever_multi_vector.invoke(user_input, limit=5)
    #r = ut.split_image_text_types(docs)
    response = rag.multimodal_rag_w_sources.invoke({'input': user_input})
    # Retrieved text chunks used as context, cleaned up for display.
    text_sources = response['context']['texts']
    text_sources = ut.beautify_output(text_sources)
    text_answer = response['answer']
    #text_answer = ut.beautify_output(text_answer)
    # Base64-encoded images retrieved as context.
    img_sources = response['context']['images']
    #---------end-----------

    #text_response = generate_text(user_input)
    # One random placeholder image is prepended before the retrieved images.
    image_responses = generate_images(1,img_sources)
    return text_answer,text_sources, image_responses
86
# Define Gradio interface
# Main UI: three tabs -- the RAG query app, offline pipeline utilities,
# and API-key configuration.
with gr.Blocks() as iface:
    with gr.Tab("Main App"):
        input_query = gr.Textbox(lines=2, placeholder="Enter your query here...")
        submit_button = gr.Button("Submit Query")
        response_text = gr.Textbox(label="Response Text")
        context = gr.Textbox(label="Context")
        gallery = gr.Gallery(label="Response Images", columns=[3], height="auto")

        submit_button.click(
            process_input,
            inputs=input_query,
            outputs=[response_text, context, gallery]
        )

    with gr.Tab("Utility Functions"):
        utility_input = gr.Textbox(lines=2, placeholder="Enter input for utility...")
        utility_button = gr.Button("Run Utility Function")
        utility_output = gr.Textbox(label="Utility Function Output")

        utility_button.click(
            utility_function_wrapper,
            # Bug fix: `inputs` was commented out while the handler takes a
            # positional argument, so clicking raised a TypeError at runtime.
            inputs=utility_input,
            outputs=utility_output
        )

    with gr.Tab("API Key Config"):
        api_key_input = gr.Textbox(type="password", placeholder="Enter your API key securely...")
        api_key_button = gr.Button("Save API Key")
        clear_api_key_button = gr.Button("Clear API Key")
        api_key_output = gr.Textbox(label="API Key Save Status")

        api_key_button.click(
            save_api_key,
            inputs=api_key_input,
            outputs=api_key_output
        )

        clear_api_key_button.click(
            clear_api_key,
            inputs=[],
            outputs=api_key_output
        )

# Launch
iface.launch()

#------
133
+
134
+ #------
135
+
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
langchain
langchain-openai
langchain-chroma
langchain-community
langchain-experimental
htmltabletomd
pdf2image
pillow
python-dotenv

unstructured[all-docs]
pdfminer
# install OCR dependencies for unstructured
pytesseract
# NOTE: poppler-utils is a system package (apt), not a pip package -- verify
# how it is installed in the deployment image.
poppler-utils
redis
gradio
resistest.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
import redis
import os

# Redis connectivity smoke test.
# SECURITY fix: the Upstash host and password were hard-coded here (and
# remain in git history -- rotate them). Supply the connection string via:
#   REDIS_URL="rediss://:<password>@<host>:6379"
redis_url = os.getenv("REDIS_URL")
if not redis_url:
    raise RuntimeError("REDIS_URL environment variable is not set")
r = redis.from_url(redis_url)

# Round-trip one key to prove the connection works.
r.set('foo', 'bar')
print(r.get('foo'))