Hugging Face Spaces app — the deployed Space is currently failing with a runtime error (see fixes in the code below).
| import streamlit as st | |
| import os | |
| import base64 | |
| import torch | |
| from PIL import Image | |
| from utils.retriever import FAISSRetriever | |
| from utils.embedder import MultiModalEmbedder | |
| from utils.memory import ChatMemory | |
| from utils.model_loader import load_llava_model | |
| from transformers import TextStreamer | |
# Initialize components with caching
@st.cache_resource
def load_components():
    """Load and cache the embedder, FAISS retriever, and LLaVA pipeline.

    The original comment promised caching but no ``st.cache_resource``
    decorator was applied, so every Streamlit rerun (i.e. every chat
    interaction) reloaded the models from scratch. With the decorator the
    heavy objects are constructed once per process and reused.

    Returns:
        tuple: (MultiModalEmbedder, FAISSRetriever, LLaVA pipeline).
    """
    embedder = MultiModalEmbedder()
    retriever = FAISSRetriever()
    llava_pipe = load_llava_model()
    return embedder, retriever, llava_pipe
def main():
    """Streamlit entry point: render the multimodal RAG chat UI.

    Accepts a text question and/or an uploaded image, retrieves context
    from the FAISS index, and generates an answer with the LLaVA pipeline.
    Chat turns are persisted in ``st.session_state`` so they survive reruns.
    """
    st.title("MultiModal RAG Chatbot 🤖🖼️")

    # Initialize session state on first run.
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "memory" not in st.session_state:
        st.session_state.memory = ChatMemory()

    # Sidebar for document upload.
    with st.sidebar:
        st.header("Knowledge Base")
        # NOTE(review): uploaded_files is collected but never indexed into
        # the retriever anywhere in this file — TODO wire it into the FAISS
        # ingestion pipeline or remove the uploader.
        uploaded_files = st.file_uploader(
            "Upload documents/images",
            type=["pdf", "jpg", "png", "jpeg"],
            accept_multiple_files=True
        )

    # Chat input.
    user_input = st.chat_input("Ask something or upload an image...")
    uploaded_image = st.file_uploader("Upload image", type=["jpg", "png", "jpeg"], key="img_upload")

    # Replay chat history (runs on every rerun).
    for msg in st.session_state.messages:
        with st.chat_message(msg["role"]):
            if msg["type"] == "text":
                st.markdown(msg["content"])
            elif msg["type"] == "image":
                st.image(msg["content"])

    # Process inputs.
    if user_input or uploaded_image:
        embedder, retriever, llava_pipe = load_components()

        # Handle image upload.
        image = None
        if uploaded_image:
            image = Image.open(uploaded_image).convert("RGB")
            with st.chat_message("user"):
                # use_container_width replaces the deprecated use_column_width.
                st.image(image, caption="Uploaded Image", use_container_width=True)
            st.session_state.messages.append({
                "role": "user",
                "type": "image",
                "content": image
            })

        # BUG FIX: the user's text turn was never shown nor stored, so text
        # questions disappeared from the chat history on the next rerun.
        if user_input:
            with st.chat_message("user"):
                st.markdown(user_input)
            st.session_state.messages.append({
                "role": "user",
                "type": "text",
                "content": user_input
            })

        # Generate response.
        with st.spinner("Thinking..."):
            # Retrieve context: image (+ optional text) search, or text-only.
            if image:
                image_emb = embedder.embed_image(image)
                text_emb = embedder.embed_text(user_input) if user_input else None
                context = retriever.search(image_emb, text_emb)
            else:
                context = retriever.search(text_emb=embedder.embed_text(user_input))

            # BUG FIX: ``TextStreamer()`` was called without the required
            # tokenizer argument, raising TypeError at generation time (the
            # reported runtime error). The stdout streamer is useless inside
            # Streamlit anyway — the full text is taken from the pipeline's
            # return value — so it is simply removed.
            prompt = f"CONTEXT: {context}\n\nQUERY: {user_input or 'Explain this image'}"
            response = llava_pipe(
                prompt,
                image=image,
                max_new_tokens=512,
                return_full_text=False
            )[0]['generated_text']

        # Update memory and display the assistant turn.
        st.session_state.memory.update(user_input, response)
        with st.chat_message("assistant"):
            st.markdown(response)
        st.session_state.messages.append({
            "role": "assistant",
            "type": "text",
            "content": response
        })
| if __name__ == "__main__": | |
| main() |