import html
import os
import random
from tempfile import NamedTemporaryFile, mkdtemp

import streamlit as st
from PIL import Image
from langchain.agents import initialize_agent, AgentType
from langchain.memory import ConversationBufferWindowMemory
from langchain_google_genai import ChatGoogleGenerativeAI

from function import bounding_box
from function import ImageCaptionTools, ObjectDetectionTool
from htmlTemplate import css, bot_template, user_template


# Root folder under which every browser session gets its own sub-directory
# for uploaded images.
DIR = './temp'
os.makedirs(DIR, exist_ok=True)

# Streamlit re-executes this whole script on every widget interaction, so the
# per-session directory is only created the first time this session runs.
# mkdtemp() atomically creates a directory whose name did not exist before,
# so two sessions can never share a folder (and therefore can never delete
# each other's uploads via delete_temp_files()). A random number plus an
# exists() check gives no such guarantee.
if "dirpath" not in st.session_state:
    DIR_PATH = mkdtemp(dir=DIR)
    st.session_state.dirpath = DIR_PATH
def delete_temp_files(dirpath=None):
    """Delete every regular file directly inside a temp directory.

    Sub-directories (and anything inside them) are left untouched.

    Args:
        dirpath: Directory to clear. Defaults to ``st.session_state.dirpath``,
            the per-session upload directory.
    """
    if dirpath is None:
        dirpath = st.session_state.dirpath
    try:
        entries = os.listdir(dirpath)
    except FileNotFoundError:
        # The directory was removed externally; there is nothing to delete,
        # and crashing here would take the whole page down.
        return
    for filename in entries:
        file_path = os.path.join(dirpath, filename)
        if os.path.isfile(file_path):
            os.unlink(file_path)


def agent_init():
    """Build the conversational ReAct agent that answers questions about an image.

    The agent is wired with:
      * the ``ImageCaptionTools`` and ``ObjectDetectionTool`` tools,
      * a ``gemini-pro`` chat model,
      * a windowed conversation memory (``k=5``, messages returned as message
        objects) stored under the ``chat_history`` key,
    and is limited to ``max_iterations=5`` with verbose logging enabled.

    Returns:
        The agent object produced by ``initialize_agent``.
    """
    toolset = [ImageCaptionTools(), ObjectDetectionTool()]
    chat_model = ChatGoogleGenerativeAI(model="gemini-pro")
    window_memory = ConversationBufferWindowMemory(
        memory_key='chat_history',
        k=5,
        return_messages=True,
    )
    return initialize_agent(
        agent=AgentType.CONVERSATIONAL_REACT_DESCRIPTION,
        llm=chat_model,
        tools=toolset,
        max_iterations=5,
        verbose=True,
        memory=window_memory,
    )


def _get_agent():
    """Return this session's agent, building it on first use."""
    # Streamlit reruns the script on every widget interaction. Caching the
    # agent in session_state keeps its conversation memory alive between
    # questions instead of starting from an empty history every time.
    if "agent" not in st.session_state:
        st.session_state.agent = agent_init()
    return st.session_state.agent


def main():
    """Render the "Chat with an Image" Streamlit page.

    Left column: upload an image. "Process Image" saves it into the session's
    temp directory and shows the ``bounding_box`` result in an expander.

    Right column: type a question. "Ask Question" sends it, together with the
    saved image path, to the agent. Both the question and the answer are
    rendered with the ``htmlTemplate`` chat templates.
    """
    st.set_page_config(
        page_title="Chat with an Image",
        page_icon="🖼️",
        layout="wide"
    )
    st.write(css, unsafe_allow_html=True)
    st.title("Chat with an Image 🖼️")

    # False on the session's first run, True on every later rerun
    # (not read anywhere else in this file's visible code).
    if 'reloaded' not in st.session_state:
        st.session_state.reloaded = False
    else:
        st.session_state.reloaded = True

    if "image_processed" not in st.session_state:
        st.session_state.image_processed = None

    if "result_bounding" not in st.session_state:
        st.session_state.result_bounding = None

    col1, col2 = st.columns([1, 1])
    with col1:
        image_upload = st.file_uploader(label="Please Upload Your Image", type=['jpg', 'png', 'jpeg'])
        if not image_upload:
            st.warning("Please upload your image")
        else:
            st.image(
                image_upload,
                use_column_width=True
            )
        click_process = st.button("Process Image", disabled=not image_upload)
        if click_process:
            # Keep only the newest upload in the session directory.
            delete_temp_files()
            with NamedTemporaryFile(dir=st.session_state.dirpath, delete=False) as f:
                f.write(image_upload.getbuffer())
                st.session_state.image_path = f.name
            st.session_state.image_processed = True
            # A new image starts a new conversation: drop the cached agent so
            # its memory does not keep referring to the previous image path.
            if "agent" in st.session_state:
                del st.session_state["agent"]

        # Compute bounding boxes once per processed image, and again whenever
        # a new image is processed.
        if (st.session_state.image_processed and st.session_state.result_bounding is None) or click_process:
            with st.spinner("Please Wait"):
                result_bounding = bounding_box(st.session_state.image_path)
                st.session_state.result_bounding = result_bounding

        if st.session_state.result_bounding is not None:
            with st.expander("Show Image (Bounding Box)"):
                st.image(st.session_state.result_bounding)

    with col2:
        user_question = st.text_area("Ask About your image",
                                     disabled=not st.session_state.image_processed,
                                     max_chars=150)
        click_ask = st.button("Ask Question", disabled=not st.session_state.image_processed)
        if click_ask:
            if not (user_question and user_question.strip()):
                st.warning("Please type a question first")
            else:
                # Both the user's text and the model's answer are untrusted and
                # are spliced into raw HTML rendered with unsafe_allow_html, so
                # escape them to stop markup/script injection into the page.
                st.write(user_template.replace("{{MSG}}", html.escape(user_question)), unsafe_allow_html=True)
                with st.spinner("Doraemon Searching for Answer🔎"):
                    chat_history = _get_agent().invoke({"input": f"{user_question}, this is the image path: {st.session_state.image_path}"})
                    response = chat_history['output']
                st.write(bot_template.replace("{{MSG}}", html.escape(response)), unsafe_allow_html=True)


# Render the page only when this file is executed as the main script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()