Spaces:
Sleeping
Sleeping
| import os | |
| import base64 | |
| import fitz | |
| from io import BytesIO | |
| from PIL import Image | |
| import requests | |
| from llama_index.llms.nvidia import NVIDIA | |
| import streamlit as st | |
| from llama_index.core import Settings | |
| from llama_index.core import VectorStoreIndex, StorageContext | |
| from llama_index.core.node_parser import SentenceSplitter | |
| from llama_index.vector_stores.milvus import MilvusVectorStore | |
| from llama_index.embeddings.nvidia import NVIDIAEmbedding | |
| from pptx import Presentation | |
| import subprocess | |
| from llama_index.core import Document | |
def set_environment_variables():
    """Validate that the NVIDIA API key is available in the environment.

    The key is deliberately not embedded in the source: a credential that is
    committed to a file is readable by anyone who can read the file. Supply it
    through the ``NVIDIA_API_KEY`` environment variable instead. Any key that
    was previously hard-coded here should be revoked.

    Raises:
        ValueError: If ``NVIDIA_API_KEY`` is unset or blank.
    """
    api_key = os.environ.get("NVIDIA_API_KEY", "").strip()
    if not api_key:
        raise ValueError("NVIDIA API Key is not set. Please set the NVIDIA_API_KEY environment variable.")
    # Store the stripped value so stray whitespace never reaches the
    # Authorization header built by the API helpers.
    os.environ["NVIDIA_API_KEY"] = api_key
def get_b64_image_from_content(image_content):
    """Re-encode image bytes as JPEG and return them base64-encoded.

    Args:
        image_content: Raw bytes of an image that Pillow can open.

    Returns:
        The base64 encoding of the JPEG data, decoded to a UTF-8 string.
    """
    picture = Image.open(BytesIO(image_content))
    # Anything that is not already RGB is converted before the JPEG save.
    rgb_picture = picture if picture.mode == 'RGB' else picture.convert('RGB')
    jpeg_buffer = BytesIO()
    rgb_picture.save(jpeg_buffer, format="JPEG")
    encoded = base64.b64encode(jpeg_buffer.getvalue())
    return encoded.decode("utf-8")
def is_graph(image_content):
    """Report whether the image's description mentions a graph-like item.

    The image is described with ``describe_image`` and the lower-cased
    description is searched for the substrings "graph", "plot", "chart" and
    "table".

    Args:
        image_content: Raw image bytes.

    Returns:
        True if any of the keywords occurs in the description, else False.
    """
    description = describe_image(image_content).lower()
    for keyword in ("graph", "plot", "chart", "table"):
        if keyword in description:
            return True
    return False
def process_graph(image_content):
    """Process a graph image and generate a plain-English description.

    The image is first linearized into a table by ``process_graph_deplot``;
    that table is then handed to an LLM with an instruction to explain it.

    Args:
        image_content: Raw image bytes of the chart or table.

    Returns:
        The text of the LLM completion.
    """
    deplot_description = process_graph_deplot(image_content)
    # Use the ``model`` keyword, matching how the same class is constructed in
    # initialize_settings, so the requested model is the one actually selected.
    llm = NVIDIA(model="meta/llama-3.1-70b-instruct")
    response = llm.complete("Your responsibility is to explain charts. You are an expert in describing the responses of linearized tables into plain English text for LLMs to use. Explain the following linearized table. " + deplot_description)
    return response.text
def describe_image(image_content):
    """Generate a description of an image using the NVIDIA NeVA endpoint.

    Args:
        image_content: Raw image bytes.

    Returns:
        The message content of the first choice in the API response.

    Raises:
        ValueError: If ``NVIDIA_API_KEY`` is not set.
        requests.RequestException: If the request times out, cannot be sent,
            or the server answers with an HTTP error status.
    """
    image_b64 = get_b64_image_from_content(image_content)
    invoke_url = "https://ai.api.nvidia.com/v1/vlm/nvidia/neva-22b"
    api_key = os.getenv("NVIDIA_API_KEY")
    if not api_key:
        raise ValueError("NVIDIA API Key is not set. Please set the NVIDIA_API_KEY environment variable.")
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Accept": "application/json",
    }
    payload = {
        "messages": [
            {
                "role": "user",
                # get_b64_image_from_content always re-encodes to JPEG, so the
                # data URI declares image/jpeg to match the actual payload.
                "content": f'Describe what you see in this image. <img src="data:image/jpeg;base64,{image_b64}" />',
            }
        ],
        "max_tokens": 1024,
        "temperature": 0.20,
        "top_p": 0.70,
        "seed": 0,
        "stream": False,
    }
    # A timeout keeps a stalled connection from blocking the app indefinitely.
    response = requests.post(invoke_url, headers=headers, json=payload, timeout=120)
    # Surface HTTP failures as HTTP errors rather than as a KeyError on "choices".
    response.raise_for_status()
    return response.json()["choices"][0]['message']['content']
def process_graph_deplot(image_content):
    """Ask the DePlot endpoint for the data table underlying a figure.

    Args:
        image_content: Raw image bytes of the figure.

    Returns:
        The message content of the first choice in the API response.

    Raises:
        ValueError: If ``NVIDIA_API_KEY`` is not set.
        requests.RequestException: If the request times out, cannot be sent,
            or the server answers with an HTTP error status.
    """
    invoke_url = "https://ai.api.nvidia.com/v1/vlm/google/deplot"
    image_b64 = get_b64_image_from_content(image_content)
    api_key = os.getenv("NVIDIA_API_KEY")
    if not api_key:
        raise ValueError("NVIDIA API Key is not set. Please set the NVIDIA_API_KEY environment variable.")
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Accept": "application/json",
    }
    payload = {
        "messages": [
            {
                "role": "user",
                # get_b64_image_from_content always re-encodes to JPEG, so the
                # data URI declares image/jpeg to match the actual payload.
                "content": f'Generate underlying data table of the figure below: <img src="data:image/jpeg;base64,{image_b64}" />',
            }
        ],
        "max_tokens": 1024,
        "temperature": 0.20,
        "top_p": 0.20,
        "stream": False,
    }
    # A timeout keeps a stalled connection from blocking the app indefinitely.
    response = requests.post(invoke_url, headers=headers, json=payload, timeout=120)
    # Surface HTTP failures as HTTP errors rather than as a KeyError on "choices".
    response.raise_for_status()
    return response.json()["choices"][0]['message']['content']
def extract_text_around_item(text_blocks, bbox, page_height, threshold_percentage=0.1):
    """Pick one text block above and one below a bounding box.

    Blocks are scanned in the order given. A block qualifies when its vertical
    distance to the box is within ``page_height * threshold_percentage``. The
    first qualifying block lying entirely above the box becomes the "before"
    text; the first one lying entirely below becomes the "after" text, at
    which point the scan stops.

    Args:
        text_blocks: Sequence of blocks whose first four items are the block
            rectangle and whose item 4 is the block text.
        bbox: Rectangle of the item (needs ``x0``, ``y0``, ``x1``, ``y1`` and
            ``width``).
        page_height: Height of the page, used to scale the vertical threshold.
        threshold_percentage: Fraction applied to the page height and to the
            box width to derive the two thresholds.

    Returns:
        A ``(before_text, after_text)`` tuple; either may be an empty string.
    """
    max_vertical_gap = page_height * threshold_percentage
    horizontal_slack = bbox.width * threshold_percentage
    text_above = ""
    text_below = ""

    for block in text_blocks:
        rect = fitz.Rect(block[:4])
        gap_to_top = abs(rect.y1 - bbox.y0)
        gap_to_bottom = abs(rect.y0 - bbox.y1)
        nearest_gap = min(gap_to_top, gap_to_bottom)
        overlap_width = max(0, min(rect.x1, bbox.x1) - max(rect.x0, bbox.x0))

        # NOTE(review): overlap_width is clamped to >= 0, so the horizontal
        # half of this test can only fail when the slack is negative.
        close_enough = nearest_gap <= max_vertical_gap and overlap_width >= -horizontal_slack
        if not close_enough:
            continue

        if rect.y1 < bbox.y0 and not text_above:
            text_above = block[4]
        elif rect.y0 > bbox.y1 and not text_below:
            text_below = block[4]
            break

    return text_above, text_below
def process_text_blocks(text_blocks, char_count_threshold=500):
    """Merge consecutive text blocks into groups bounded by a character budget.

    Only blocks whose last item equals 0 are considered. A block joins the
    current group while the group's total text length stays within
    ``char_count_threshold``; otherwise the current group is emitted and the
    block starts a new one. A single block longer than the threshold still
    forms its own group.

    Args:
        text_blocks: Sequence of blocks with the text at index 4 and the
            block type as the last item.
        char_count_threshold: Maximum combined text length of a group.

    Returns:
        A list of ``(first_block, joined_text)`` tuples, where ``joined_text``
        is the group's block texts joined with newlines.
    """

    def as_group(members):
        return members[0], "\n".join(member[4] for member in members)

    groups = []
    pending = []
    pending_chars = 0

    for block in text_blocks:
        if block[-1] != 0:
            # Not a text-type block.
            continue
        size = len(block[4])
        if pending_chars + size <= char_count_threshold:
            pending.append(block)
            pending_chars += size
            continue
        if pending:
            groups.append(as_group(pending))
        pending = [block]
        pending_chars = size

    # Emit whatever is left after the last block.
    if pending:
        groups.append(as_group(pending))
    return groups
def save_uploaded_file(uploaded_file):
    """Save an uploaded file under ``vectorstore/ppt_references/tmp``.

    Args:
        uploaded_file: Object exposing a ``name`` attribute and a ``read()``
            method returning bytes.

    Returns:
        The path of the file that was written.

    Raises:
        ValueError: If the upload's name has no usable file-name component.
    """
    temp_dir = os.path.join(os.getcwd(), "vectorstore", "ppt_references", "tmp")
    os.makedirs(temp_dir, exist_ok=True)
    # The name is supplied by the uploader; keep only its final component so
    # values such as "../../x" or an absolute path cannot escape temp_dir.
    safe_name = os.path.basename(uploaded_file.name.replace("\\", "/"))
    if safe_name in ("", ".", ".."):
        raise ValueError(f"Invalid uploaded file name: {uploaded_file.name!r}")
    temp_file_path = os.path.join(temp_dir, safe_name)
    with open(temp_file_path, "wb") as temp_file:
        temp_file.write(uploaded_file.read())
    return temp_file_path
# Second file of the code (document parsing helpers)
def get_pdf_documents(pdf_file):
    """Process a PDF file and extract text, tables, and images.

    For every page, text blocks in the top and bottom 10% of the page height
    are dropped, the rest are grouped with ``process_text_blocks``, and tables
    and images are extracted with ``parse_all_tables`` / ``parse_all_images``.
    Text groups whose first block intersects a detected table are skipped.

    Args:
        pdf_file: Binary file-like object exposing ``read()`` and ``name``.

    Returns:
        A list of ``Document`` objects (tables, images and text, page by
        page). An empty list is returned when the PDF cannot be opened.
    """
    all_pdf_documents = []
    ongoing_tables = {}
    try:
        f = fitz.open(stream=pdf_file.read(), filetype="pdf")
    except Exception as e:
        print(f"Error opening or processing the PDF file: {e}")
        return []

    # try/finally guarantees the document is closed even if a page fails.
    try:
        for i, page in enumerate(f):
            page_height = page.rect.height
            text_blocks = [
                block for block in page.get_text("blocks", sort=True)
                if block[-1] == 0 and not (block[1] < page_height * 0.1 or block[3] > page_height * 0.9)
            ]
            grouped_text_blocks = process_text_blocks(text_blocks)

            table_docs, table_bboxes, ongoing_tables = parse_all_tables(pdf_file.name, page, i, text_blocks, ongoing_tables)
            all_pdf_documents.extend(table_docs)

            image_docs = parse_all_images(pdf_file.name, page, i, text_blocks)
            all_pdf_documents.extend(image_docs)

            for text_block_ctr, (heading_block, content) in enumerate(grouped_text_blocks, 1):
                heading_bbox = fitz.Rect(heading_block[:4])
                if any(heading_bbox.intersects(table_bbox) for table_bbox in table_bboxes):
                    continue
                # NOTE(review): the "x3" key holds the block's fourth
                # coordinate; it looks like it was meant to be "y2", but the
                # key is kept as-is in case something reads it.
                bbox = {"x1": heading_block[0], "y1": heading_block[1], "x2": heading_block[2], "x3": heading_block[3]}
                # The name's last four characters are dropped (assumes a
                # ".pdf"-style suffix).
                block_id = f"{pdf_file.name[:-4]}-page{i}-block{text_block_ctr}"
                # NOTE(review): ``content`` already begins with the heading
                # block's text, so that text appears twice in the document.
                text_doc = Document(
                    text=f"{heading_block[4]}\n{content}",
                    metadata={
                        **bbox,
                        "type": "text",
                        "page_num": i,
                        "source": block_id,
                    },
                    id_=block_id,
                )
                all_pdf_documents.append(text_doc)
    finally:
        f.close()
    return all_pdf_documents
def parse_all_tables(filename, page, pagenum, text_blocks, ongoing_tables):
    """Extract tables from a PDF page.

    Each table whose header is not external is exported to an Excel file and
    a clipped page image under ``vectorstore/table_references`` and turned
    into a ``Document``. Failures are reported with ``print`` and do not
    propagate.

    Args:
        filename: Name of the source PDF; its last four characters are
            dropped when building the ``source`` metadata.
        page: The PyMuPDF page to search.
        pagenum: Page index used in file names and metadata.
        text_blocks: Text blocks of the page, used to find caption text.
        ongoing_tables: Passed through and returned unchanged.

    Returns:
        A ``(table_docs, table_bboxes, ongoing_tables)`` tuple.
    """
    table_docs = []
    table_bboxes = []
    try:
        tables = list(page.find_tables(horizontal_strategy="lines_strict", vertical_strategy="lines_strict"))
    except Exception as e:
        print(f"Error during table extraction: {e}")
        return table_docs, table_bboxes, ongoing_tables

    tablerefdir = os.path.join(os.getcwd(), "vectorstore/table_references")
    for tab in tables:
        # Handle each table separately so one failure does not discard the
        # remaining tables on the page.
        try:
            if tab.header.external:
                continue
            table_number = len(table_docs) + 1
            pandas_df = tab.to_pandas()
            os.makedirs(tablerefdir, exist_ok=True)
            df_xlsx_path = os.path.join(tablerefdir, f"table{table_number}-page{pagenum}.xlsx")
            pandas_df.to_excel(df_xlsx_path)

            bbox = fitz.Rect(tab.bbox)
            table_bboxes.append(bbox)
            before_text, after_text = extract_text_around_item(text_blocks, bbox, page.rect.height)

            table_img = page.get_pixmap(clip=bbox)
            table_img_path = os.path.join(tablerefdir, f"table{table_number}-page{pagenum}.jpg")
            table_img.save(table_img_path)

            if before_text == "" and after_text == "":
                # No surrounding text: the caption is just the header names,
                # so the (remote) description call is skipped. str() and the
                # None filter guard against non-string header entries.
                caption = " ".join(str(name) for name in tab.header.names if name is not None)
            else:
                description = process_graph(table_img.tobytes())
                caption = before_text.replace("\n", " ") + description + after_text.replace("\n", " ")

            table_metadata = {
                "source": f"{filename[:-4]}-page{pagenum}-table{table_number}",
                "dataframe": df_xlsx_path,
                "image": table_img_path,
                "caption": caption,
                "type": "table",
                "page_num": pagenum,
            }
            # str() keeps the join working when column labels are not strings.
            all_cols = ", ".join(str(column) for column in pandas_df.columns)
            doc = Document(text=f"This is a table with the caption: {caption}\nThe columns are {all_cols}", metadata=table_metadata)
            table_docs.append(doc)
        except Exception as e:
            print(f"Error during table extraction: {e}")
    return table_docs, table_bboxes, ongoing_tables
def parse_all_images(filename, page, pagenum, text_blocks):
    """Extract images from a PDF page.

    Images without an xref, images narrower or shorter than 1/20 of the page,
    and images with no nearby text above or below are skipped. Each remaining
    image is written to ``vectorstore/image_references`` and wrapped in a
    ``Document`` whose caption combines the surrounding text with a chart
    description when ``is_graph`` reports one.

    Args:
        filename: Name of the source PDF; its last four characters are
            dropped when building the ``source`` metadata.
        page: The PyMuPDF page to inspect.
        pagenum: Page index used in file names and metadata.
        text_blocks: Text blocks of the page, used to find caption text.

    Returns:
        A list of ``Document`` objects, one per retained image.
    """
    image_docs = []
    page_rect = page.rect
    imgrefpath = os.path.join(os.getcwd(), "vectorstore/image_references")

    for image_info in page.get_image_info(xrefs=True):
        xref = image_info['xref']
        if xref == 0:
            continue
        img_bbox = fitz.Rect(image_info['bbox'])
        if img_bbox.width < page_rect.width / 20 or img_bbox.height < page_rect.height / 20:
            continue

        # Decide whether the image is kept before extracting it, so skipped
        # images leave no orphan file on disk.
        before_text, after_text = extract_text_around_item(text_blocks, img_bbox, page.rect.height)
        if before_text == "" and after_text == "":
            continue

        extracted_image = page.parent.extract_image(xref)
        image_data = extracted_image["image"]
        # The bytes are stored in their original encoding; name the file after
        # the reported format and fall back to "png" when none is given.
        image_ext = extracted_image.get("ext") or "png"
        os.makedirs(imgrefpath, exist_ok=True)
        image_path = os.path.join(imgrefpath, f"image{xref}-page{pagenum}.{image_ext}")
        with open(image_path, "wb") as img_file:
            img_file.write(image_data)

        image_description = " "
        if is_graph(image_data):
            image_description = process_graph(image_data)
        caption = before_text.replace("\n", " ") + image_description + after_text.replace("\n", " ")
        image_metadata = {
            "source": f"{filename[:-4]}-page{pagenum}-image{xref}",
            "image": image_path,
            "caption": caption,
            "type": "image",
            "page_num": pagenum,
        }
        image_docs.append(Document(text="This is an image with the caption: " + caption, metadata=image_metadata))
    return image_docs
def process_ppt_file(ppt_path):
    """Turn a PowerPoint file into one ``Document`` per slide.

    The deck is converted to PDF, each PDF page is rendered to an image, and
    the slide text and speaker notes are read from the original file. Page
    images and slide texts are paired positionally with ``zip``.

    Args:
        ppt_path: Path of the PowerPoint file.

    Returns:
        A list of ``Document`` objects. The document text holds the slide
        text plus any chart description; the caption metadata additionally
        carries the speaker notes.
    """
    slide_images = convert_pdf_to_images(convert_ppt_to_pdf(ppt_path))
    slide_contents = extract_text_and_notes_from_ppt(ppt_path)

    documents = []
    for (image_path, page_num), (slide_text, notes) in zip(slide_images, slide_contents):
        notes_suffix = "\n\nThe speaker notes for this slide are: " + notes if notes else notes

        with open(image_path, 'rb') as image_file:
            slide_image = image_file.read()

        # A single space separates text and notes when no chart is described.
        chart_description = process_graph(slide_image) if is_graph(slide_image) else " "

        slide_body = slide_text + chart_description
        metadata = {
            "source": f"{os.path.basename(ppt_path)}",
            "image": image_path,
            "caption": slide_body + notes_suffix,
            "type": "image",
            "page_num": page_num,
        }
        documents.append(Document(text="This is a slide with the text: " + slide_body, metadata=metadata))
    return documents
def convert_ppt_to_pdf(ppt_path):
    """Convert a PowerPoint file to PDF using LibreOffice.

    The PDF is placed in ``vectorstore/ppt_references`` and named after the
    presentation with spaces replaced by underscores.

    Args:
        ppt_path: Path of the PowerPoint file.

    Returns:
        The path of the generated PDF.

    Raises:
        subprocess.CalledProcessError: If LibreOffice exits with an error.
        FileNotFoundError: If no PDF was produced at the expected location.
    """
    base_name = os.path.basename(ppt_path)
    stem = os.path.splitext(base_name)[0]
    new_dir_path = os.path.abspath("vectorstore/ppt_references")
    os.makedirs(new_dir_path, exist_ok=True)
    pdf_path = os.path.join(new_dir_path, f"{stem.replace(' ', '_')}.pdf")
    command = ['libreoffice', '--headless', '--convert-to', 'pdf', '--outdir', new_dir_path, ppt_path]
    subprocess.run(command, check=True)

    # LibreOffice names its output after the input file, spaces included, so
    # move it to the underscore name that this function returns.
    produced_path = os.path.join(new_dir_path, f"{stem}.pdf")
    if produced_path != pdf_path and os.path.exists(produced_path):
        os.replace(produced_path, pdf_path)
    if not os.path.exists(pdf_path):
        raise FileNotFoundError(f"LibreOffice did not produce the expected PDF: {pdf_path}")
    return pdf_path
def convert_pdf_to_images(pdf_path):
    """Render every page of a PDF to a PNG file using PyMuPDF.

    Images are written to ``vectorstore/ppt_references`` as
    ``<pdf name with spaces as underscores>_<4-digit page index>.png``.

    Args:
        pdf_path: Path of the PDF file.

    Returns:
        A list of ``(image_path, page_num)`` tuples in page order.
    """
    base_name = os.path.basename(pdf_path)
    pdf_name_without_ext = os.path.splitext(base_name)[0].replace(' ', '_')
    new_dir_path = os.path.join(os.getcwd(), "vectorstore/ppt_references")
    os.makedirs(new_dir_path, exist_ok=True)

    image_paths = []
    doc = fitz.open(pdf_path)
    # try/finally closes the document even when rendering or saving fails.
    try:
        for page_num, page in enumerate(doc):
            pix = page.get_pixmap()
            output_image_path = os.path.join(new_dir_path, f"{pdf_name_without_ext}_{page_num:04d}.png")
            pix.save(output_image_path)
            image_paths.append((output_image_path, page_num))
    finally:
        doc.close()
    return image_paths
def extract_text_and_notes_from_ppt(ppt_path):
    """Extract slide text and speaker notes from a PowerPoint file.

    Args:
        ppt_path: Path of the PowerPoint file.

    Returns:
        A list with one ``(slide_text, notes)`` tuple per slide. ``slide_text``
        is the text of all shapes that have a ``text`` attribute, joined with
        spaces; ``notes`` is an empty string when the slide has no notes or
        they cannot be read.
    """
    prs = Presentation(ppt_path)
    text_and_notes = []
    for slide in prs.slides:
        slide_text = ' '.join(shape.text for shape in slide.shapes if hasattr(shape, "text"))
        notes = ''
        try:
            # has_notes_slide is checked first because accessing notes_slide
            # creates a notes slide when none exists.
            if slide.has_notes_slide:
                notes_frame = slide.notes_slide.notes_text_frame
                if notes_frame is not None:
                    notes = notes_frame.text
        except Exception:
            # Notes are best-effort: a slide with unreadable notes still
            # contributes its text.
            notes = ''
        text_and_notes.append((slide_text, notes))
    return text_and_notes
def load_multimodal_data(files):
    """Load and process multiple uploaded files of mixed types.

    Files are dispatched on their lower-cased extension: images are described
    with ``describe_image``, PDFs go through ``get_pdf_documents``,
    PowerPoint files are saved to disk and passed to ``process_ppt_file``,
    and anything else is decoded as UTF-8 text. A file that fails is reported
    with ``print`` and skipped, so one bad file does not abort the batch.

    Args:
        files: Iterable of objects exposing ``name`` and ``read()``.

    Returns:
        A list of ``Document`` objects from all files that were processed.
    """
    documents = []
    for file in files:
        file_extension = os.path.splitext(file.name.lower())[1]
        if file_extension in ('.png', '.jpg', '.jpeg'):
            try:
                image_text = describe_image(file.read())
            except Exception as e:
                print(f"Error processing image {file.name}: {e}")
                continue
            documents.append(Document(text=image_text, metadata={"source": file.name, "type": "image"}))
        elif file_extension == '.pdf':
            try:
                pdf_documents = get_pdf_documents(file)
                documents.extend(pdf_documents)
            except Exception as e:
                print(f"Error processing PDF {file.name}: {e}")
        elif file_extension in ('.ppt', '.pptx'):
            try:
                ppt_documents = process_ppt_file(save_uploaded_file(file))
                documents.extend(ppt_documents)
            except Exception as e:
                print(f"Error processing PPT {file.name}: {e}")
        else:
            # Unknown extensions are treated as UTF-8 text; binary files that
            # cannot be decoded are skipped rather than crashing the upload.
            try:
                text = file.read().decode("utf-8")
            except UnicodeDecodeError as e:
                print(f"Error processing text file {file.name}: {e}")
                continue
            documents.append(Document(text=text, metadata={"source": file.name, "type": "text"}))
    return documents
def load_data_from_directory(directory):
    """Load and process multiple file types from a directory.

    Only regular files directly inside ``directory`` are processed. Files are
    dispatched on their lower-cased extension: images are described with
    ``describe_image``, PDFs go through ``get_pdf_documents``, PowerPoint
    files through ``process_ppt_file``, and anything else is read as UTF-8
    text. A file that fails is reported with ``print`` and skipped.

    Args:
        directory: Path of the directory to scan.

    Returns:
        A list of ``Document`` objects from all files that were processed.
    """
    documents = []
    for filename in os.listdir(directory):
        filepath = os.path.join(directory, filename)
        # Subdirectories and other non-file entries cannot be opened as files.
        if not os.path.isfile(filepath):
            continue
        file_extension = os.path.splitext(filename.lower())[1]
        print(filename)
        if file_extension in ('.png', '.jpg', '.jpeg'):
            try:
                with open(filepath, "rb") as image_file:
                    image_content = image_file.read()
                image_text = describe_image(image_content)
            except Exception as e:
                print(f"Error processing image {filename}: {e}")
                continue
            doc = Document(text=image_text, metadata={"source": filename, "type": "image"})
            print(doc)
            documents.append(doc)
        elif file_extension == '.pdf':
            with open(filepath, "rb") as pdf_file:
                try:
                    pdf_documents = get_pdf_documents(pdf_file)
                    documents.extend(pdf_documents)
                except Exception as e:
                    print(f"Error processing PDF {filename}: {e}")
        elif file_extension in ('.ppt', '.pptx'):
            try:
                ppt_documents = process_ppt_file(filepath)
                documents.extend(ppt_documents)
                print(ppt_documents)
            except Exception as e:
                print(f"Error processing PPT {filename}: {e}")
        else:
            # Unknown extensions are treated as UTF-8 text; unreadable or
            # undecodable files are skipped rather than aborting the scan.
            try:
                with open(filepath, "r", encoding="utf-8") as text_file:
                    text = text_file.read()
            except (OSError, UnicodeDecodeError) as e:
                print(f"Error processing text file {filename}: {e}")
                continue
            documents.append(Document(text=text, metadata={"source": filename, "type": "text"}))
    return documents
# Third file (Streamlit application)
# Set up the page configuration: use Streamlit's "wide" layout.
# NOTE(review): this call sits at module level, so it runs on import, before main().
st.set_page_config(layout="wide")
# Configure the shared llama-index settings
def initialize_settings():
    """Install the embedding model, LLM and text splitter on ``Settings``.

    The three attributes are assigned in order: the NVIDIA embedding model,
    the NVIDIA LLM, and a sentence splitter with a chunk size of 600.
    """
    embedder = NVIDIAEmbedding(model="nvidia/nv-embedqa-e5-v5", truncate="END")
    Settings.embed_model = embedder

    language_model = NVIDIA(model="meta/llama-3.1-70b-instruct")
    Settings.llm = language_model

    splitter = SentenceSplitter(chunk_size=600)
    Settings.text_splitter = splitter
# Create index from documents
def create_index(documents, host="127.0.0.1", port=19530, dim=1024):
    """Build a vector index over ``documents`` backed by a Milvus store.

    Args:
        documents: The documents to index.
        host: Host of the Milvus server. Defaults to the local machine.
        port: Port of the Milvus server.
        dim: Vector dimension passed to the store.

    Returns:
        The ``VectorStoreIndex`` created from the documents.
    """
    vector_store = MilvusVectorStore(
        host=host,
        port=port,
        dim=dim,
    )
    # For a CPU-only vector store, a local database file can be used instead:
    #   MilvusVectorStore(uri="./milvus_demo.db", dim=1024, overwrite=True)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    return VectorStoreIndex.from_documents(documents, storage_context=storage_context)
# Main function to run the Streamlit app
def main():
    """Render the two-column app: data ingestion on the left, chat on the right.

    The left column lets the user upload files or enter a directory path and
    builds an index from them, storing it in ``st.session_state['index']``
    and resetting ``st.session_state['history']``. The right column is shown
    only once an index exists; it replays the stored history, streams the
    answer to a new query, and offers a button that clears the history.
    """
    set_environment_variables()
    initialize_settings()
    col1, col2 = st.columns([1, 2])

    # Left column: choose an input method and build the index.
    with col1:
        st.title("Multimodal RAG")
        input_method = st.radio("Choose input method:", ("Upload Files", "Enter Directory Path"))
        if input_method == "Upload Files":
            uploaded_files = st.file_uploader("Drag and drop files here", accept_multiple_files=True)
            if uploaded_files and st.button("Process Files"):
                with st.spinner("Processing files..."):
                    documents = load_multimodal_data(uploaded_files)
                    st.session_state['index'] = create_index(documents)
                    # A new index starts a fresh conversation.
                    st.session_state['history'] = []
                    st.success("Files processed and index created!")
        else:
            directory_path = st.text_input("Enter directory path:")
            if directory_path and st.button("Process Directory"):
                if os.path.isdir(directory_path):
                    with st.spinner("Processing directory..."):
                        documents = load_data_from_directory(directory_path)
                        st.session_state['index'] = create_index(documents)
                        # A new index starts a fresh conversation.
                        st.session_state['history'] = []
                        st.success("Directory processed and index created!")
                else:
                    st.error("Invalid directory path. Please enter a valid path.")

    # Right column: chat against the index, once one has been created.
    with col2:
        if 'index' in st.session_state:
            st.title("Chat")
            if 'history' not in st.session_state:
                st.session_state['history'] = []
            query_engine = st.session_state['index'].as_query_engine(similarity_top_k=5, streaming=True)
            user_input = st.chat_input("Enter your query:")
            # Display chat messages
            chat_container = st.container()
            with chat_container:
                for message in st.session_state['history']:
                    with st.chat_message(message["role"]):
                        st.markdown(message["content"])
            if user_input:
                with st.chat_message("user"):
                    st.markdown(user_input)
                st.session_state['history'].append({"role": "user", "content": user_input})
                with st.chat_message("assistant"):
                    message_placeholder = st.empty()
                    full_response = ""
                    response = query_engine.query(user_input)
                    # Re-render the partial answer with a cursor mark after each token.
                    for token in response.response_gen:
                        full_response += token
                        message_placeholder.markdown(full_response + "▌")
                    # Final render without the cursor mark.
                    message_placeholder.markdown(full_response)
                st.session_state['history'].append({"role": "assistant", "content": full_response})
            # Add a clear button
            if st.button("Clear Chat"):
                st.session_state['history'] = []
                st.rerun()
# Run the app only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()