# app.py — MultiModal RAG Chatbot (author: Advait3009, commit 2f62b14)
import streamlit as st
import os
import base64
import torch
from PIL import Image
from utils.retriever import FAISSRetriever
from utils.embedder import MultiModalEmbedder
from utils.memory import ChatMemory
from utils.model_loader import load_llava_model
from transformers import TextStreamer
# Heavy pipeline objects are built once per server process and cached.
@st.cache_resource
def load_components():
    """Construct and cache the embedder, FAISS retriever, and LLaVA pipeline.

    Returns:
        tuple: (MultiModalEmbedder, FAISSRetriever, llava pipeline)
    """
    return (
        MultiModalEmbedder(),
        FAISSRetriever(),
        load_llava_model(),
    )
def main():
    """Streamlit entry point: a chat UI over a multimodal (text + image) RAG pipeline.

    Flow: initialize session state -> render sidebar uploader -> render chat
    history -> on new input, retrieve context via FAISS and generate a
    response with the LLaVA pipeline.
    """
    st.title("MultiModal RAG Chatbot 🤖🖼️")

    # Initialize per-session state: message list for rendering, ChatMemory
    # for conversational context.
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "memory" not in st.session_state:
        st.session_state.memory = ChatMemory()

    # Sidebar for document upload
    with st.sidebar:
        st.header("Knowledge Base")
        uploaded_files = st.file_uploader(
            "Upload documents/images",
            type=["pdf", "jpg", "png", "jpeg"],
            accept_multiple_files=True,
        )
        # NOTE(review): uploaded_files is never ingested into the retriever
        # anywhere in this file — confirm indexing happens elsewhere or wire
        # it up; otherwise this uploader is dead UI.

    # Chat input
    user_input = st.chat_input("Ask something or upload an image...")
    uploaded_image = st.file_uploader("Upload image", type=["jpg", "png", "jpeg"], key="img_upload")

    # Display chat history (text and image turns)
    for msg in st.session_state.messages:
        with st.chat_message(msg["role"]):
            if msg["type"] == "text":
                st.markdown(msg["content"])
            elif msg["type"] == "image":
                st.image(msg["content"])

    # Process inputs
    if user_input or uploaded_image:
        embedder, retriever, llava_pipe = load_components()

        # Handle image upload
        image = None
        if uploaded_image:
            image = Image.open(uploaded_image).convert("RGB")
            with st.chat_message("user"):
                # use_container_width replaces the deprecated use_column_width
                st.image(image, caption="Uploaded Image", use_container_width=True)
            st.session_state.messages.append({
                "role": "user",
                "type": "image",
                "content": image,
            })

        # BUGFIX: the user's text turn was previously never displayed nor
        # stored, so chat history silently dropped every text query.
        if user_input:
            with st.chat_message("user"):
                st.markdown(user_input)
            st.session_state.messages.append({
                "role": "user",
                "type": "text",
                "content": user_input,
            })

        # Generate response
        with st.spinner("Thinking..."):
            # Retrieve context: multimodal search when an image is present,
            # text-only search otherwise.
            if image:
                image_emb = embedder.embed_image(image)
                text_emb = embedder.embed_text(user_input) if user_input else None
                context = retriever.search(image_emb, text_emb)
            else:
                context = retriever.search(text_emb=embedder.embed_text(user_input))

            # Generate LLM response from retrieved context + query (or a
            # default prompt when only an image was supplied).
            prompt = f"CONTEXT: {context}\n\nQUERY: {user_input or 'Explain this image'}"
            # NOTE(review): transformers' TextStreamer requires a tokenizer
            # argument (TextStreamer(tokenizer)) — confirm load_llava_model's
            # pipeline tolerates a bare TextStreamer(), otherwise this raises.
            response = llava_pipe(
                prompt,
                image=image,
                max_new_tokens=512,
                streamer=TextStreamer(),
                return_full_text=False,
            )[0]['generated_text']

        # Update memory and display the assistant turn
        st.session_state.memory.update(user_input, response)
        with st.chat_message("assistant"):
            st.markdown(response)
        st.session_state.messages.append({
            "role": "assistant",
            "type": "text",
            "content": response,
        })


if __name__ == "__main__":
    main()