Spaces:
Sleeping
Sleeping
| from llama_index.llms.openai import OpenAI | |
| from llama_index.core import load_index_from_storage, get_response_synthesizer | |
| import matplotlib.pyplot as plt | |
| import os | |
| from PIL import Image | |
| from llama_index.core import PromptTemplate | |
| from awsfunctions import download_files_from_s3, check_file_exists_in_s3 | |
| import tempfile, shutil | |
| import streamlit as st | |
@st.cache_resource()
def get_image_from_s3(image_path):
    """Download one image from S3 and return it as an in-memory PIL image.

    The object is fetched with ``download_files_from_s3`` into a throwaway
    temporary directory and opened from ``<temp_dir>/<image_path>``. Its pixel
    data is read into memory before the directory is removed, so the returned
    image no longer depends on the deleted file.

    Decorated with ``st.cache_resource`` so repeated requests for the same
    path reuse the already-downloaded image instead of fetching it again.

    Args:
        image_path: S3 key of the image. Also used as the relative path under
            the temporary directory. Assumes ``download_files_from_s3`` mirrors
            the key there (TODO confirm against awsfunctions).

    Returns:
        A fully loaded ``PIL.Image.Image``.
    """
    # TemporaryDirectory removes the directory even if the download or the
    # decode below raises; the previous mkdtemp/rmtree pair leaked it then.
    with tempfile.TemporaryDirectory() as temp_dir:
        download_files_from_s3(temp_dir, [image_path])
        image = Image.open(os.path.join(temp_dir, image_path))
        # Image.open is lazy: force the decode now, while the file still
        # exists, so the image stays usable after the directory is deleted.
        image.load()
    return image
def plot_images(image_paths):
    """Display each image in *image_paths* that exists in S3 in the Streamlit app.

    Paths whose object is not found by ``check_file_exists_in_s3`` are skipped
    without error. Images are rendered in input order with one ``st.image``
    call each.

    Args:
        image_paths: Iterable of S3 keys pointing at image files.
    """
    # No matplotlib figure is created here: images are rendered by st.image,
    # and an unused plt.figure() per call would accumulate open figures
    # across Streamlit reruns.
    for img_path in image_paths:
        if not check_file_exists_in_s3(img_path):
            continue
        st.image(get_image_from_s3(img_path))
def retrieve_and_query(query, retriever_engine, image_score_threshold=0.25):
    """Retrieve context for *query*, synthesize an answer, and collect image hits.

    Args:
        query: The user's question.
        retriever_engine: Object exposing ``retrieve(query)``. Its results are
            passed to the response synthesizer as ``nodes`` and inspected for
            ``metadata`` and ``score`` (presumably llama_index
            ``NodeWithScore`` objects; confirm against the caller).
        image_score_threshold: An image node's score must be strictly greater
            than this for its path to be returned. Defaults to 0.25.

    Returns:
        A ``(response, retrieved_image_path_list)`` tuple. ``response`` is the
        synthesizer's result. The list holds the ``file_path`` metadata of
        JPEG/PNG nodes scoring above the threshold, in retrieval order.
    """
    retrieval_results = retriever_engine.retrieve(query)

    qa_tmpl_str = (
        "Context information is below.\n"
        "---------------------\n"
        "{context_str}\n"
        "---------------------\n"
        "Given the context information , "
        "answer the query in detail.\n"
        "Query: {query_str}\n"
        "Answer: "
    )
    qa_tmpl = PromptTemplate(qa_tmpl_str)
    llm = OpenAI(model="gpt-4o-mini", temperature=0)
    response_synthesizer = get_response_synthesizer(
        response_mode="refine", text_qa_template=qa_tmpl, llm=llm
    )
    response = response_synthesizer.synthesize(query, nodes=retrieval_results)

    image_types = ("image/jpeg", "image/png")
    retrieved_image_path_list = []
    for node in retrieval_results:
        metadata = node.metadata
        # .get(): nodes without file_type/file_path metadata are simply not
        # images we can show, rather than a KeyError for the whole query.
        if metadata.get("file_type") not in image_types:
            continue
        # A retriever may leave score as None; comparing None > float would
        # raise TypeError, so treat an unscored node as below the threshold.
        if node.score is None or node.score <= image_score_threshold:
            continue
        file_path = metadata.get("file_path")
        if file_path is not None:
            retrieved_image_path_list.append(file_path)
    return response, retrieved_image_path_list