# colpali / src/app.py
# Author: Thamaraikannan — commit "Update src/app.py" (04fc0fb, verified)
import streamlit as st
import time
import markdown2
import torch
from qdrant_client import QdrantClient
from qdrant_client import models
from qdrant_client.models import Filter, FieldCondition, Range, MatchValue
from llm import analyze_image_with_query
# -------------------------------
# Page Configuration
# -------------------------------
st.set_page_config(
page_title="AI Document Search & Chat",
page_icon="πŸ”",
layout="wide",
initial_sidebar_state="expanded"
)
# Custom CSS for modern styling
st.markdown("""
<style>
.main-header {
background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
padding: 2rem;
border-radius: 10px;
margin-bottom: 2rem;
text-align: center;
color: white;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.search-container {
background: white;
padding: 2rem;
border-radius: 15px;
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
border: 1px solid #e1e5e9;
margin-bottom: 2rem;
}
.filter-section {
background: #f8f9fa;
padding: 1.5rem;
border-radius: 10px;
border-left: 4px solid #667eea;
}
.result-card {
background: #303336;
padding: 0.8rem 1rem;
border-radius: 10px;
border: 0.5px solid rgba(255, 255, 255, 0.1);
margin-bottom: 0.8rem;
box-shadow: 0 2px 6px rgba(0, 0, 0, 0.05);
transition: transform 0.2s;
}
.result-card:hover {
transform: translateY(-2px);
box-shadow: 0 4px 10px rgba(0, 0, 0, 0.15);
}
.ai-response-card {
background: linear-gradient(135deg, #f8f9fa 0%, #e9ecef 100%);
border: 2px solid #667eea;
border-radius: 15px;
padding: 2rem;
margin: 2rem 0;
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.1);
}
.ai-response-header {
display: flex;
align-items: center;
margin-bottom: 1rem;
color: #667eea;
font-weight: bold;
font-size: 1.1rem;
}
.ai-response-content {
background: #303336;
padding: 1.5rem;
border-radius: 10px;
border-left: 4px solid #667eea;
line-height: 1.7;
font-size: 1rem;
}
.source-section {
background: #f8f9fa;
padding: 1rem;
border-radius: 8px;
margin-top: 1rem;
border: 1px solid #e1e5e9;
}
.score-badge {
background: linear-gradient(45deg, #667eea, #764ba2);
color: white;
padding: 0.3rem 0.8rem;
border-radius: 20px;
font-size: 0.8rem;
font-weight: bold;
display: inline-block;
margin-bottom: 1rem;
}
.metadata-chip {
background: #e3f2fd;
color: #1565c0;
padding: 0.2rem 0.6rem;
border-radius: 15px;
font-size: 0.75rem;
display: inline-block;
margin: 0.2rem;
font-weight: 500;
}
.no-results {
text-align: center;
padding: 3rem;
color: #666;
}
.stats-container {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 1rem;
border-radius: 10px;
margin-bottom: 1rem;
}
.search-input {
border: 2px solid #e1e5e9 !important;
border-radius: 10px !important;
padding: 0.75rem !important;
font-size: 1rem !important;
}
.search-input:focus {
border-color: #667eea !important;
box-shadow: 0 0 0 0.2rem rgba(102, 126, 234, 0.25) !important;
}
.stButton > button {
background: linear-gradient(45deg, #667eea, #764ba2);
color: white;
border: none;
border-radius: 10px;
padding: 0.75rem 2rem;
font-weight: 600;
transition: all 0.3s;
width: 100%;
}
.stButton > button:hover {
transform: translateY(-2px);
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.3);
}
.sidebar-content {
background: #f8f9fa;
padding: 1rem;
border-radius: 10px;
margin-bottom: 1rem;
}
.chat-mode-toggle {
background: linear-gradient(45deg, #28a745, #20c997);
color: white;
border: none;
border-radius: 10px;
padding: 0.5rem 1rem;
font-weight: 600;
margin-bottom: 1rem;
}
</style>
""", unsafe_allow_html=True)
# -------------------------------
# Load models with better caching
# -------------------------------
@st.cache_resource(show_spinner=False)
def initialize_qdrant():
    """Create and cache the Qdrant client used for vector search.

    Returns:
        QdrantClient: client for the configured Qdrant Cloud cluster.
        Cached by Streamlit so the client is constructed once per process.
    """
    import os  # local import so this function is self-contained

    # SECURITY: the cluster URL and API key were previously hard-coded here
    # and committed to source control.  Read them from the environment
    # (e.g. Space secrets) instead; the literals remain only as a
    # backward-compatible fallback and the key should be rotated.
    qdrant_client = QdrantClient(
        url=os.environ.get(
            "QDRANT_URL",
            "https://c75f218b-ecf5-4693-9b49-a2253478cf80.us-east4-0.gcp.cloud.qdrant.io:6333",
        ),
        api_key=os.environ.get(
            "QDRANT_API_KEY",
            "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.ErMufKlEvB6nuEzFTwXZWDvb0IrEC2VUkPpQ1ZaLcXc",
        ),
    )
    return qdrant_client
@st.cache_resource(show_spinner=False)
def initialize_colpali():
    """Load and cache the ColPali embedding model and its processor.

    Returns:
        (model, processor) on success, or (None, None) when the optional
        colpali-engine package is not installed.
    """
    try:
        # colpali-engine is an optional dependency; import lazily so the app
        # can show an install hint instead of crashing at startup.
        from colpali_engine.models import ColPali, ColPaliProcessor
    except ImportError:
        st.error("ColPali not installed. Please install with: pip install colpali-engine")
        return None, None

    target_device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model = ColPali.from_pretrained(
        "vidore/colpali-v1.2",
        torch_dtype=torch.bfloat16,
        device_map=target_device,
    )
    processor = ColPaliProcessor.from_pretrained(
        "vidore/colpaligemma-3b-pt-448-base"
    )
    return model, processor
# Initialize models and services (both functions are cached resources, so
# this is expensive only on the first run of the process).
with st.spinner("🚀 Initializing AI services..."):
    qdrant_client = initialize_qdrant()
    colpali_model, colpali_processor = initialize_colpali()
# Abort rendering entirely if the embedding model is unavailable.
if colpali_model is None or colpali_processor is None:
    st.stop()
# Initialize session state
if "chat_mode" not in st.session_state:
    st.session_state.chat_mode = True  # app currently always runs in chat mode
# Collection name
collection_name = "colpaliWithQdrant"  # Qdrant collection holding page embeddings
# -------------------------------
# Helper Functions
# -------------------------------
def search_documents(query: str, filter_dict: dict) -> list:
    """Search for relevant documents using embedding similarity with metadata filtering.

    Args:
        query: Natural-language search text; embedded with ColPali.
        filter_dict: Mapping of payload field -> {"$eq": value} built by the
            UI.  Only "fiscal_year", "page_number", "company" and "currency"
            are translated into Qdrant conditions here.

    Returns:
        List of dicts with keys "id", "score" and "metadata" (the Qdrant
        payload), or an empty list on any failure (errors are shown in the UI).
    """
    # Guard: models are loaded at module import; bail out if unavailable.
    if colpali_model is None or colpali_processor is None:
        st.error("ColPali model not initialized")
        return []
    try:
        # Generate query embedding (inference only — no gradients needed).
        with torch.no_grad():
            batch_query = colpali_processor.process_queries([query])
            if batch_query is None:
                st.error("ColPali processor failed to process query.")
                return []
            batch_query = batch_query.to(colpali_model.device)
            query_embedding = colpali_model(**batch_query)
            if query_embedding is None:
                st.error("ColPali model failed to generate embedding.")
                return []
        # ColPali yields a multivector (one vector per query token); take the
        # single query in the batch and convert to plain Python lists.
        multivector_query = query_embedding[0].cpu().float().numpy().tolist()
        # NOTE(review): these payload fields are matched as strings, which
        # implies they are stored as strings in Qdrant — confirm against the
        # ingestion pipeline.
        string_fields = ["fiscal_year", "page_number"]
        must_conditions = []
        for key in ["fiscal_year", "page_number", "company", "currency"]:
            if filter_dict.get(key):
                value = filter_dict[key]["$eq"]
                if key in string_fields:
                    value = str(value)
                must_conditions.append(
                    FieldCondition(
                        key=key,
                        match=MatchValue(value=value)
                    )
                )
        # NOTE(review): filter_dict may also carry "doc_type", but it is never
        # turned into a condition here, so that filter is effectively ignored
        # — verify whether this is intentional.
        filter_by_tag = Filter(must=must_conditions)
        # Search in Qdrant with filters
        start_time = time.time()
        try:
            search_result = qdrant_client.query_points(
                collection_name=collection_name,
                query=multivector_query,
                query_filter=filter_by_tag,
                limit=3,
                timeout=100,
                search_params=models.SearchParams(
                    quantization=models.QuantizationSearchParams(
                        ignore=False,
                        rescore=True,
                        oversampling=2.0,
                    )
                )
            )
        except Exception as filter_error:
            # Qdrant reports "Index required" when a filtered payload field
            # has no payload index; degrade to an unfiltered search rather
            # than failing the whole request.
            if "Index required" in str(filter_error):
                st.warning("⚠️ Metadata filtering not available. Performing search without filters. Please create payload indexes.")
                # Retry without filters
                search_result = qdrant_client.query_points(
                    collection_name=collection_name,
                    query=multivector_query,
                    limit=3,
                    timeout=100,
                    search_params=models.SearchParams(
                        quantization=models.QuantizationSearchParams(
                            ignore=False,
                            rescore=True,
                            oversampling=2.0,
                        )
                    )
                )
            else:
                raise
        end_time = time.time()
        elapsed_time = end_time - start_time
        print(f"Search completed in {elapsed_time:.4f} seconds")
        # Convert search results to list of documents
        print(f"Found {len(search_result.points)} results")
        relevant_docs = []
        for point in search_result.points:
            relevant_docs.append({
                "id": point.id,
                "score": point.score,
                "metadata": point.payload
            })
        return relevant_docs
    except Exception as e:
        st.error(f"Search error: {str(e)}")
        return []
# -------------------------------
# Header
# -------------------------------
st.markdown("""
<div class="main-header">
<h1 style="margin: 0; font-size: 1.9rem;">Document Search RAG</h1>
<p style="margin: 0.5rem 0 0 0; font-size: 1.1rem; opacity: 0.9;">
Using ColPali, Qdrant and llama-4-scout-17b-16e-instruct.
</p>
</div>
""", unsafe_allow_html=True)
# -------------------------------
# Sidebar for filters and mode toggle
# -------------------------------
def clear_all_filters():
    """Reset every search-related widget value in Streamlit session state.

    Bound to the sidebar "Clear All Filters" button; the keys below must
    match the `key=` arguments of the corresponding widgets.
    """
    # Shared query/page inputs plus the annual-report-specific dropdowns.
    for state_key in (
        "search_query",
        "page_no_filter",
        "company_filter",
        "fiscal_year_filter",
        "currency_filter",
    ):
        st.session_state[state_key] = ""
with st.sidebar:
    st.markdown("### 🎯 Search Filters")
    # Document type selector (single option today; extend the list to add more).
    doc_type = st.selectbox(
        "Document Type",
        ["annual_report"],
        key="doc_type_filter"
    )
    # Annual Report filters
    if doc_type == "annual_report":
        with st.expander("Annual Report Filters", expanded=False):
            # Example: company dropdown (populate with known companies)
            companies_list = ["Sherwin Williams", "MSCI", "Eaton"]
            company = st.selectbox("Company", [""] + companies_list, index=0, key="company_filter")
            # Fiscal year dropdown
            fiscal_years = ["2024"]
            fiscal_year = st.selectbox("Fiscal Year", [""] + fiscal_years, index=0, key="fiscal_year_filter")
            # Currency dropdown
            currencies = ["USD"]
            currency = st.selectbox("Currency", [""] + currencies, index=0, key="currency_filter")
            # Page number dropdown or text input
            # If you know page numbers, you can make it a dropdown
            # NOTE(review): company/fiscal_year/currency/page_no are only bound
            # inside this branch; later code relies on doc_type always being
            # "annual_report" — confirm if more doc types are added.
            page_no = st.text_input("Page Number", placeholder="e.g., 15", key="page_no_filter")
    # Reset button (clears all widget keys via clear_all_filters)
    st.button("Clear All Filters", on_click=clear_all_filters)
    # Model info
    st.markdown("---")
    st.markdown("### ℹ️ Model Info")
    st.info("**Embedding**: ColPali v1.2\n**Vector DB**: Qdrant\n**LLM**:llama-4-scout-17b-16e-instruct")
# -------------------------------
# Main search interface
# -------------------------------
col1, col2 = st.columns([3, 1])
with col1:
    # Label/placeholder differ by mode, but both branches share the
    # "search_query" session key, so exactly one widget exists per run.
    if st.session_state.chat_mode:
        query = st.text_input(
            "💬 Ask a question about your documents",
            placeholder="What would you like to know from the documents?",
            label_visibility="collapsed",
            key="search_query"
        )
    else:
        query = st.text_input(
            "🔍 Search Query",
            placeholder="What would you like to find in the documents?",
            label_visibility="collapsed",
            key="search_query"
        )
with col2:
    if st.session_state.chat_mode:
        search_clicked = st.button("💬 Ask AI", type="primary")
    else:
        search_clicked = st.button("🚀 Search", type="primary")
# -------------------------------
# Search functionality
# -------------------------------
# Runs on button click or whenever the query box is non-empty (Streamlit
# reruns the whole script on every widget interaction).
if search_clicked or (query and len(query.strip()) > 0):
    if not query.strip():
        st.warning("⚠️ Please enter a search query to continue.")
    else:
        # Build filter dictionary in the {"field": {"$eq": value}} shape
        # expected by search_documents().
        filter_dict = {}
        # Common filters
        # NOTE(review): doc_type is collected here but search_documents()
        # never converts it into a Qdrant condition — confirm intent.
        if doc_type and doc_type != "All Types":
            filter_dict["doc_type"] = {"$eq": doc_type}
        if page_no and page_no.strip():
            try:
                filter_dict["page_number"] = {"$eq": int(page_no.strip())}
            except ValueError:
                st.error("⚠️ Page number must be a valid integer.")
                st.stop()
        # Annual Report filters
        if doc_type == "annual_report":
            if company and company.strip():
                filter_dict["company"] = {"$eq": company.strip()}
            if fiscal_year and fiscal_year.strip():
                filter_dict["fiscal_year"] = {"$eq": fiscal_year.strip()}
            if currency and currency.strip():
                filter_dict["currency"] = {"$eq": currency.strip()}
        # Perform search with progress indicators
        start_time = time.time()
        with st.spinner("🔍 Searching through documents..."):
            relevant_docs = search_documents(query, filter_dict)
        # Generate AI response if in chat mode
        if st.session_state.chat_mode and relevant_docs:
            with st.spinner("🤖 Generating AI response..."):
                gemini_answer = analyze_image_with_query(relevant_docs, query)
                if gemini_answer:
                    st.markdown(
                        f"""
<div style="background: #303336; padding: 1rem; border-radius: 8px; margin: 1rem 0; line-height: 1.6; color: white;">
{gemini_answer}
</div>
""",
                        unsafe_allow_html=True
                    )
                else:
                    st.warning("No AI response received.")
            st.markdown("---")
        if relevant_docs:
            # NOTE(review): search_time is computed but never displayed.
            search_time = time.time() - start_time
            # Display source documents
            if st.session_state.chat_mode:
                st.markdown("### 📚 Evidence")
            for i, result in enumerate(relevant_docs, start=1):
                metadata = result["metadata"]
                image_url = metadata.get("image_url")  # URL of the page image
                doc_id = metadata.get("doc_id", "N/A")
                # NOTE(review): this rebinds page_no, shadowing the sidebar
                # text-input value for the rest of the script run.
                page_no = metadata.get("page_number", "N/A")
                title = metadata.get("title", "N/A")
                # Display document info
                st.markdown(
                    f"#### [{i}] DOC : <span style='color:green;'>{doc_id}</span> | Page: {page_no} | Title: {title}",
                    unsafe_allow_html=True
                )
                # Display image from URL
                if image_url:
                    st.image(image_url, caption=f"{title} | Page {page_no}", width='stretch')
                else:
                    st.warning("No image available for this document.")
                # Expandable full metadata
                doc_label = "Source" if st.session_state.chat_mode else "Result"
                with st.expander(f"🔍 View full metadata for {doc_label} #{i}"):
                    st.json(metadata)
        else:
            # No results found
            st.markdown("""
<div class="no-results">
    <h3>🤷‍♂️ No results found</h3>
    <p>Try adjusting your search query or filters to find what you're looking for.</p>
    <div style="margin-top: 2rem;">
        <h4>💡 Search Tips:</h4>
        <ul style="text-align: left; display: inline-block;">
            <li>Use specific keywords related to your topic</li>
            <li>Try removing some filters to broaden your search</li>
            <li>Check for typos in your query or filter values</li>
            <li>Use synonyms or related terms</li>
        </ul>
    </div>
</div>
""", unsafe_allow_html=True)
# -------------------------------
# Usage Instructions
# -------------------------------
# Shown only while the query box is empty (i.e. before the first search).
if not query:
    st.markdown("---")
    st.markdown("### 💡 How to Use")
    st.markdown("""
    **💬 AI Chat Mode:**
    - Ask natural language questions
    - Get AI-generated answers based on documents
    - View source documents used for the response
    """)
# -------------------------------
# Footer
# -------------------------------
st.markdown("---")
st.markdown("""
<div style="text-align: center; color: #666; padding: 1rem;">
    <small>🤖 Powered by Groq, ColPali, Qdrant, and Streamlit | Built with ❤️ for intelligent document search and chat</small>
</div>
""", unsafe_allow_html=True)