# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import os
from io import BytesIO

import fitz
import requests
import streamlit as st
import torch
from PIL import Image
from transformers import AutoProcessor, MllamaForConditionalGeneration


@st.cache_resource
def initialize_vlm():
    """Load the Vision-Language Model (VLM) used for image description.

    Cached via ``st.cache_resource`` so the (large) model is loaded at most
    once per Streamlit server process.

    Returns:
        tuple: ``(vlm_model, vlm_processor)`` for
        ``meta-llama/Llama-3.2-11B-Vision-Instruct``.
    """
    model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
    vlm_model = MllamaForConditionalGeneration.from_pretrained(
        model_id, device_map="auto", torch_dtype=torch.float16
    )
    vlm_processor = AutoProcessor.from_pretrained(model_id)
    return vlm_model, vlm_processor


def get_b64_image_from_content(image_content):
    """Convert raw image bytes to a base64-encoded JPEG string.

    Args:
        image_content: Raw image bytes in any format Pillow can open.

    Returns:
        str: Base64-encoded JPEG representation of the image.
    """
    img = Image.open(BytesIO(image_content))
    # JPEG cannot store alpha/palette modes, so normalize to RGB first.
    if img.mode != 'RGB':
        img = img.convert('RGB')
    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def is_graph(image_content):
    """Return True if the image appears to be a graph, plot, chart, or table.

    Uses a keyword search over the VLM's free-text description of the image.
    """
    res = describe_image(image_content)
    return any(keyword in res.lower() for keyword in ["graph", "plot", "chart", "table"])


def process_graph(image_content, llm):
    """Turn a graph image into a plain-English explanation.

    The image is first linearized into a data table by the Deplot service,
    then the supplied LLM rewrites that table as prose.

    Args:
        image_content: Raw image bytes of the graph.
        llm: An LLM client exposing a ``complete(prompt)`` method whose
            result carries the generated text in ``.text``.

    Returns:
        str: Natural-language explanation of the chart.
    """
    deplot_description = process_graph_deplot(image_content)
    response = llm.complete(
        "Your responsibility is to explain charts. You are an expert in describing "
        "the responses of linearized tables into plain English text for LLMs to use. "
        "Explain the following linearized table. " + deplot_description
    )
    return response.text


def describe_image(image_content):
    """Generate a description of an image using the multimodal VLM.

    Args:
        image_content: Raw image bytes in any format Pillow can open.

    Returns:
        str: The model's description of the image.
    """
    vlm_model, vlm_processor = initialize_vlm()
    image = Image.open(BytesIO(image_content))

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "Describe what you see in this image"}
            ]
        }
    ]
    prompt = vlm_processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = vlm_processor(text=prompt, images=image, return_tensors="pt").to(vlm_model.device)

    # Fix: run generation without gradient tracking, and decode only the newly
    # generated tokens -- output[0] also contains the prompt tokens, which
    # previously leaked the chat template into the returned description.
    with torch.inference_mode():
        output = vlm_model.generate(**inputs, max_new_tokens=1024)
    generated_tokens = output[0][inputs["input_ids"].shape[-1]:]
    return vlm_processor.decode(generated_tokens, skip_special_tokens=True)


def process_graph_deplot(image_content):
    """Linearize a graph image into a data table via NVIDIA's Deplot API.

    Args:
        image_content: Raw image bytes of the graph.

    Returns:
        str: The linearized data table produced by Deplot.

    Raises:
        ValueError: If the ``NVIDIA_API_KEY`` environment variable is not set.
        requests.HTTPError: If the API call fails.
    """
    invoke_url = "https://ai.api.nvidia.com/v1/vlm/google/deplot"
    image_b64 = get_b64_image_from_content(image_content)
    api_key = os.getenv("NVIDIA_API_KEY")
    if not api_key:
        raise ValueError("NVIDIA API Key is not set. Please set the NVIDIA_API_KEY environment variable.")

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Accept": "application/json"
    }
    payload = {
        "messages": [
            {
                "role": "user",
                # Fix: image_b64 was computed but never sent -- the f-string had
                # no placeholder. The NVIDIA NIM VLM endpoints expect the image
                # inline as an HTML <img> data-URI in the message content.
                "content": f'Generate underlying data table of the figure below: <img src="data:image/png;base64,{image_b64}" />'
            }
        ],
        "max_tokens": 1024,
        "temperature": 0.20,
        "top_p": 0.20,
        "stream": False
    }

    response = requests.post(invoke_url, headers=headers, json=payload)
    # Surface HTTP errors explicitly instead of a cryptic KeyError on "choices".
    response.raise_for_status()
    return response.json()["choices"][0]['message']['content']


def extract_text_around_item(text_blocks, bbox, page_height, threshold_percentage=0.1):
    """Extract the nearest text above and below a bounding box on a page.

    Args:
        text_blocks: PyMuPDF text blocks ``(x0, y0, x1, y1, text, ...)``.
        bbox: ``fitz.Rect`` of the item (image/table) to find context for.
        page_height: Height of the page, used to scale the vertical threshold.
        threshold_percentage: Fraction of page height / bbox width used as the
            proximity tolerance.

    Returns:
        tuple[str, str]: ``(before_text, after_text)`` -- either may be "".
    """
    before_text, after_text = "", ""
    vertical_threshold_distance = page_height * threshold_percentage
    horizontal_threshold_distance = bbox.width * threshold_percentage

    for block in text_blocks:
        block_bbox = fitz.Rect(block[:4])
        vertical_distance = min(abs(block_bbox.y1 - bbox.y0), abs(block_bbox.y0 - bbox.y1))
        # Signed horizontal gap: positive when the block overlaps the bbox
        # horizontally, negative when they are separated. Fix: the previous
        # max(0, ...) clamp made the ">= -threshold" comparison vacuously true,
        # so horizontal alignment was never actually enforced.
        horizontal_overlap = min(block_bbox.x1, bbox.x1) - max(block_bbox.x0, bbox.x0)

        if vertical_distance <= vertical_threshold_distance and horizontal_overlap >= -horizontal_threshold_distance:
            if block_bbox.y1 < bbox.y0 and not before_text:
                before_text = block[4]
            elif block_bbox.y0 > bbox.y1 and not after_text:
                after_text = block[4]
                # Blocks are in reading order, so the first block below is enough.
                break

    return before_text, after_text


def process_text_blocks(text_blocks, char_count_threshold=500):
    """Group consecutive text blocks up to a character-count threshold.

    Args:
        text_blocks: PyMuPDF blocks ``(x0, y0, x1, y1, text, block_no, block_type)``.
        char_count_threshold: Maximum characters accumulated per group.

    Returns:
        list[tuple]: ``(first_block, concatenated_text)`` pairs, one per group.
    """
    current_group = []
    grouped_blocks = []
    current_char_count = 0

    for block in text_blocks:
        if block[-1] == 0:  # block_type 0 == text (images are skipped)
            block_text = block[4]
            block_char_count = len(block_text)

            if current_char_count + block_char_count <= char_count_threshold:
                current_group.append(block)
                current_char_count += block_char_count
            else:
                # Flush the full group and start a new one with this block.
                if current_group:
                    grouped_content = "\n".join([b[4] for b in current_group])
                    grouped_blocks.append((current_group[0], grouped_content))
                current_group = [block]
                current_char_count = block_char_count

    # Append the last group
    if current_group:
        grouped_content = "\n".join([b[4] for b in current_group])
        grouped_blocks.append((current_group[0], grouped_content))

    return grouped_blocks


def save_uploaded_file(uploaded_file):
    """Save a Streamlit-uploaded file to a temporary directory on disk.

    Args:
        uploaded_file: Streamlit ``UploadedFile`` (has ``.name`` and ``.read()``).

    Returns:
        str: Filesystem path where the file was written.
    """
    temp_dir = os.path.join(os.getcwd(), "vectorstore", "ppt_references", "tmp")
    os.makedirs(temp_dir, exist_ok=True)
    temp_file_path = os.path.join(temp_dir, uploaded_file.name)

    with open(temp_file_path, "wb") as temp_file:
        temp_file.write(uploaded_file.read())

    return temp_file_path