"""Streamlit front-end for a multimodal (text + image) RAG chatbot.

Flow: the user submits text and/or an image; relevant context is retrieved
from a FAISS index via multimodal embeddings, and a LLaVA pipeline generates
the assistant's answer. Heavy components are cached once per server process.
"""

import streamlit as st
import os
import base64
import torch
from PIL import Image
from utils.retriever import FAISSRetriever
from utils.embedder import MultiModalEmbedder
from utils.memory import ChatMemory
from utils.model_loader import load_llava_model
from transformers import TextStreamer


# Initialize components with caching (one instance per server process)
@st.cache_resource
def load_components():
    """Build and cache the heavyweight components.

    Returns:
        tuple: (MultiModalEmbedder, FAISSRetriever, LLaVA pipeline).
    """
    embedder = MultiModalEmbedder()
    retriever = FAISSRetriever()
    llava_pipe = load_llava_model()
    return embedder, retriever, llava_pipe


def main():
    """Render the chat UI and handle one request/response turn per rerun."""
    st.title("MultiModal RAG Chatbot 🤖🖼️")

    # Initialize session state
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "memory" not in st.session_state:
        st.session_state.memory = ChatMemory()

    # Sidebar for document upload
    with st.sidebar:
        st.header("Knowledge Base")
        # TODO(review): uploaded_files is collected but never embedded or
        # added to the FAISS index — wire this into the retriever/embedder.
        uploaded_files = st.file_uploader(
            "Upload documents/images",
            type=["pdf", "jpg", "png", "jpeg"],
            accept_multiple_files=True,
        )

    # Chat input
    user_input = st.chat_input("Ask something or upload an image...")
    uploaded_image = st.file_uploader(
        "Upload image", type=["jpg", "png", "jpeg"], key="img_upload"
    )

    # Display chat history (replayed on every rerun)
    for msg in st.session_state.messages:
        with st.chat_message(msg["role"]):
            if msg["type"] == "text":
                st.markdown(msg["content"])
            elif msg["type"] == "image":
                st.image(msg["content"])

    # Process inputs
    if user_input or uploaded_image:
        embedder, retriever, llava_pipe = load_components()

        # Handle image upload
        image = None
        if uploaded_image:
            image = Image.open(uploaded_image).convert("RGB")
            with st.chat_message("user"):
                # use_container_width replaces the deprecated use_column_width
                st.image(image, caption="Uploaded Image", use_container_width=True)
            st.session_state.messages.append({
                "role": "user",
                "type": "image",
                "content": image
            })

        # BUGFIX: the user's text turn was displayed nowhere and never
        # persisted, so it disappeared from the history on the next rerun
        # (image turns WERE persisted — this makes the two paths consistent).
        if user_input:
            with st.chat_message("user"):
                st.markdown(user_input)
            st.session_state.messages.append({
                "role": "user",
                "type": "text",
                "content": user_input
            })

        # Generate response
        with st.spinner("Thinking..."):
            # Retrieve context from the FAISS index
            if image:
                image_emb = embedder.embed_image(image)
                text_emb = embedder.embed_text(user_input) if user_input else None
                context = retriever.search(image_emb, text_emb)
            else:
                context = retriever.search(text_emb=embedder.embed_text(user_input))

            # Generate LLM response
            prompt = f"CONTEXT: {context}\n\nQUERY: {user_input or 'Explain this image'}"
            response = llava_pipe(
                prompt,
                image=image,
                max_new_tokens=512,
                # BUGFIX: TextStreamer requires a tokenizer as its first
                # argument; TextStreamer() with no args raises TypeError.
                # NOTE(review): the streamer prints tokens to stdout, not to
                # the Streamlit page — confirm this is the intended behavior.
                streamer=TextStreamer(llava_pipe.tokenizer),
                return_full_text=False,
            )[0]['generated_text']

        # Update memory and display
        # NOTE(review): user_input is None on image-only turns — confirm
        # ChatMemory.update tolerates a None query.
        st.session_state.memory.update(user_input, response)
        with st.chat_message("assistant"):
            st.markdown(response)
        st.session_state.messages.append({
            "role": "assistant",
            "type": "text",
            "content": response
        })


if __name__ == "__main__":
    main()