|
|
import streamlit as st |
|
|
import time |
|
|
|
|
|
import markdown2 |
|
|
import torch |
|
|
from qdrant_client import QdrantClient |
|
|
from qdrant_client import models |
|
|
from qdrant_client.models import Filter, FieldCondition, Range, MatchValue |
|
|
from llm import analyze_image_with_query |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Page-level settings: browser-tab title and icon, wide layout, sidebar open.
# NOTE(review): the icon glyph looks mis-encoded (probably an emoji in the
# original file) — confirm the file's encoding before shipping.
_PAGE_SETTINGS = {
    "page_title": "AI Document Search & Chat",
    "page_icon": "π",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_SETTINGS)
|
|
|
|
|
|
|
|
# Stylesheet injected into the page. It defines the CSS classes referenced by
# the HTML snippets rendered further down (e.g. `main-header`, `no-results`)
# and restyles Streamlit's own buttons via `.stButton > button`.
_CUSTOM_CSS = """
<style>
    .main-header {
        background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
        padding: 2rem;
        border-radius: 10px;
        margin-bottom: 2rem;
        text-align: center;
        color: white;
        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
    }

    .search-container {
        background: white;
        padding: 2rem;
        border-radius: 15px;
        box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
        border: 1px solid #e1e5e9;
        margin-bottom: 2rem;
    }

    .filter-section {
        background: #f8f9fa;
        padding: 1.5rem;
        border-radius: 10px;
        border-left: 4px solid #667eea;
    }

    .result-card {
        background: #303336;
        padding: 0.8rem 1rem;
        border-radius: 10px;
        border: 0.5px solid rgba(255, 255, 255, 0.1);
        margin-bottom: 0.8rem;
        box-shadow: 0 2px 6px rgba(0, 0, 0, 0.05);
        transition: transform 0.2s;
    }
    .result-card:hover {
        transform: translateY(-2px);
        box-shadow: 0 4px 10px rgba(0, 0, 0, 0.15);
    }

    .ai-response-card {
        background: linear-gradient(135deg, #f8f9fa 0%, #e9ecef 100%);
        border: 2px solid #667eea;
        border-radius: 15px;
        padding: 2rem;
        margin: 2rem 0;
        box-shadow: 0 4px 12px rgba(102, 126, 234, 0.1);
    }

    .ai-response-header {
        display: flex;
        align-items: center;
        margin-bottom: 1rem;
        color: #667eea;
        font-weight: bold;
        font-size: 1.1rem;
    }

    .ai-response-content {
        background: #303336;
        padding: 1.5rem;
        border-radius: 10px;
        border-left: 4px solid #667eea;
        line-height: 1.7;
        font-size: 1rem;
    }

    .source-section {
        background: #f8f9fa;
        padding: 1rem;
        border-radius: 8px;
        margin-top: 1rem;
        border: 1px solid #e1e5e9;
    }

    .score-badge {
        background: linear-gradient(45deg, #667eea, #764ba2);
        color: white;
        padding: 0.3rem 0.8rem;
        border-radius: 20px;
        font-size: 0.8rem;
        font-weight: bold;
        display: inline-block;
        margin-bottom: 1rem;
    }

    .metadata-chip {
        background: #e3f2fd;
        color: #1565c0;
        padding: 0.2rem 0.6rem;
        border-radius: 15px;
        font-size: 0.75rem;
        display: inline-block;
        margin: 0.2rem;
        font-weight: 500;
    }

    .no-results {
        text-align: center;
        padding: 3rem;
        color: #666;
    }

    .stats-container {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        color: white;
        padding: 1rem;
        border-radius: 10px;
        margin-bottom: 1rem;
    }

    .search-input {
        border: 2px solid #e1e5e9 !important;
        border-radius: 10px !important;
        padding: 0.75rem !important;
        font-size: 1rem !important;
    }

    .search-input:focus {
        border-color: #667eea !important;
        box-shadow: 0 0 0 0.2rem rgba(102, 126, 234, 0.25) !important;
    }

    .stButton > button {
        background: linear-gradient(45deg, #667eea, #764ba2);
        color: white;
        border: none;
        border-radius: 10px;
        padding: 0.75rem 2rem;
        font-weight: 600;
        transition: all 0.3s;
        width: 100%;
    }

    .stButton > button:hover {
        transform: translateY(-2px);
        box-shadow: 0 4px 12px rgba(102, 126, 234, 0.3);
    }

    .sidebar-content {
        background: #f8f9fa;
        padding: 1rem;
        border-radius: 10px;
        margin-bottom: 1rem;
    }
    .chat-mode-toggle {
        background: linear-gradient(45deg, #28a745, #20c997);
        color: white;
        border: none;
        border-radius: 10px;
        padding: 0.5rem 1rem;
        font-weight: 600;
        margin-bottom: 1rem;
    }
</style>
"""

# unsafe_allow_html is needed so the <style> element is passed through as HTML.
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@st.cache_resource(show_spinner=False)
def initialize_qdrant():
    """Create the Qdrant client used for vector search.

    Connection settings are read from the environment so that the API key
    is not stored in source control:

    * ``QDRANT_URL`` - optional; falls back to the cluster URL this app has
      always used.
    * ``QDRANT_API_KEY`` - required.

    Returns:
        A ``QdrantClient`` connected to the configured cluster. The
        ``st.cache_resource`` decorator caches the returned client.

    Raises:
        RuntimeError: If ``QDRANT_API_KEY`` is not set.
    """
    # Imported here so this function is self-contained and the file-level
    # import block does not need to change.
    import os

    default_url = (
        "https://c75f218b-ecf5-4693-9b49-a2253478cf80"
        ".us-east4-0.gcp.cloud.qdrant.io:6333"
    )
    url = os.environ.get("QDRANT_URL", default_url)

    # SECURITY: a key used to be hard-coded at this spot. Treat that key as
    # leaked and rotate it; only the environment supplies credentials now.
    api_key = os.environ.get("QDRANT_API_KEY")
    if not api_key:
        raise RuntimeError(
            "QDRANT_API_KEY is not set. Export it in the environment before "
            "starting the app."
        )

    return QdrantClient(url=url, api_key=api_key)
|
|
|
|
|
@st.cache_resource(show_spinner=False)
def initialize_colpali():
    """Load the ColPali retrieval model together with its processor.

    Returns:
        A ``(model, processor)`` tuple. If an ``ImportError`` is raised while
        importing or loading (typically because ``colpali_engine`` is not
        installed), an error is shown in the UI and ``(None, None)`` is
        returned instead.
    """
    try:
        # Imported lazily so a missing package surfaces as a UI message
        # rather than as an import failure of the whole script.
        from colpali_engine.models import ColPali, ColPaliProcessor

        checkpoint = "vidore/colpali-v1.2"
        processor_checkpoint = "vidore/colpaligemma-3b-pt-448-base"

        # Use the first CUDA device when one is available, otherwise the CPU.
        target_device = "cuda:0" if torch.cuda.is_available() else "cpu"

        model = ColPali.from_pretrained(
            checkpoint,
            torch_dtype=torch.bfloat16,
            device_map=target_device,
        )
        processor = ColPaliProcessor.from_pretrained(processor_checkpoint)
    except ImportError:
        st.error("ColPali not installed. Please install with: pip install colpali-engine")
        return None, None

    return model, processor
|
|
|
|
|
|
|
|
|
|
|
# Obtain the shared services while a spinner is displayed.
with st.spinner("π Initializing AI services..."):
    qdrant_client = initialize_qdrant()
    colpali_model, colpali_processor = initialize_colpali()

# initialize_colpali() returns (None, None) when loading failed; nothing below
# can work without the model, so end this script run here.
_colpali_ready = colpali_model is not None and colpali_processor is not None
if not _colpali_ready:
    st.stop()

# Chat mode is on by default the first time a session runs the script.
if "chat_mode" not in st.session_state:
    st.session_state["chat_mode"] = True

# Qdrant collection queried by search_documents().
collection_name = "colpaliWithQdrant"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _build_metadata_filter(filter_dict):
    """Translate the UI filter mapping into a Qdrant ``Filter``.

    Only ``fiscal_year``, ``page_number``, ``company`` and ``currency`` are
    turned into conditions; any other key in ``filter_dict`` is ignored.
    """
    filter_dict = filter_dict or {}
    # These values are cast to str before matching — presumably because the
    # payload stores them as strings; confirm against the ingestion code.
    string_fields = ("fiscal_year", "page_number")

    conditions = []
    for key in ("fiscal_year", "page_number", "company", "currency"):
        if not filter_dict.get(key):
            continue
        value = filter_dict[key]["$eq"]
        if key in string_fields:
            value = str(value)
        conditions.append(FieldCondition(key=key, match=MatchValue(value=value)))
    return Filter(must=conditions)


def _query_collection(multivector_query, query_filter=None):
    """Run the 3-result multivector query, optionally restricted by a filter."""
    return qdrant_client.query_points(
        collection_name=collection_name,
        query=multivector_query,
        query_filter=query_filter,
        limit=3,
        timeout=100,
        search_params=models.SearchParams(
            quantization=models.QuantizationSearchParams(
                ignore=False,
                rescore=True,
                oversampling=2.0,
            )
        ),
    )


def search_documents(query, filter_dict):
    """Search for relevant documents using embedding similarity with metadata filtering.

    Args:
        query: The user's query text.
        filter_dict: Mapping of field name to ``{"$eq": value}``. ``None`` or
            an empty mapping means "no metadata filter".

    Returns:
        A list of ``{"id", "score", "metadata"}`` dicts, one per returned
        point (at most three). An empty list is returned, after an error is
        shown in the UI, when the model is unavailable or anything fails.
    """
    if colpali_model is None or colpali_processor is None:
        st.error("ColPali model not initialized")
        return []

    try:
        # Embed the query; gradients are not needed for inference.
        with torch.no_grad():
            batch_query = colpali_processor.process_queries([query])
            if batch_query is None:
                st.error("ColPali processor failed to process query.")
                return []
            batch_query = batch_query.to(colpali_model.device)

            query_embedding = colpali_model(**batch_query)
            if query_embedding is None:
                st.error("ColPali model failed to generate embedding.")
                return []
            multivector_query = query_embedding[0].cpu().float().numpy().tolist()

        filter_by_tag = _build_metadata_filter(filter_dict)

        start_time = time.time()
        try:
            search_result = _query_collection(multivector_query, filter_by_tag)
        except Exception as filter_error:
            # Only the "Index required" failure is retried without the filter;
            # every other error is handed to the outer handler.
            if "Index required" not in str(filter_error):
                raise
            st.warning("β οΈ Metadata filtering not available. Performing search without filters. Please create payload indexes.")
            search_result = _query_collection(multivector_query)
        elapsed_time = time.time() - start_time

        print(f"Search completed in {elapsed_time:.4f} seconds")
        print(f"Found {len(search_result.points)} results")

        return [
            {"id": point.id, "score": point.score, "metadata": point.payload}
            for point in search_result.points
        ]

    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing the page.
        st.error(f"Search error: {str(e)}")
        return []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Banner at the top of the page; styled by the `main-header` CSS class.
_HEADER_HTML = """
<div class="main-header">
    <h1 style="margin: 0; font-size: 1.9rem;">Document Search RAG</h1>
    <p style="margin: 0.5rem 0 0 0; font-size: 1.1rem; opacity: 0.9;">
        Using ColPali, Qdrant and llama-4-scout-17b-16e-instruct.
    </p>
</div>
"""
st.markdown(_HEADER_HTML, unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def clear_all_filters():
    """Reset the query box and the filter widgets to an empty string.

    Used as the ``on_click`` callback of the "Clear All Filters" button; it
    writes to the session-state keys the corresponding widgets are bound to.
    """
    widget_keys = (
        "search_query",
        "page_no_filter",
        "company_filter",
        "fiscal_year_filter",
        "currency_filter",
    )
    for widget_key in widget_keys:
        st.session_state[widget_key] = ""
|
|
|
|
|
# Sidebar: filter widgets plus a short model summary.
# NOTE(review): nesting was reconstructed from a paste without indentation;
# confirm that the page-number input, the clear button and the model info
# belong in the sidebar.
st.sidebar.markdown("### π― Search Filters")

doc_type = st.sidebar.selectbox(
    "Document Type",
    ["annual_report"],
    key="doc_type_filter",
)

# Filters that only apply to annual reports. Each list is prefixed with an
# empty entry, which is the initial selection (index 0).
if doc_type == "annual_report":
    with st.sidebar.expander("Annual Report Filters", expanded=False):
        companies_list = ["Sherwin Williams", "MSCI", "Eaton"]
        company = st.selectbox(
            "Company", ["", *companies_list], index=0, key="company_filter"
        )

        fiscal_years = ["2024"]
        fiscal_year = st.selectbox(
            "Fiscal Year", ["", *fiscal_years], index=0, key="fiscal_year_filter"
        )

        currencies = ["USD"]
        currency = st.selectbox(
            "Currency", ["", *currencies], index=0, key="currency_filter"
        )

# Free-text page filter; it is validated as an integer when a search runs.
page_no = st.sidebar.text_input(
    "Page Number", placeholder="e.g., 15", key="page_no_filter"
)

st.sidebar.button("Clear All Filters", on_click=clear_all_filters)

st.sidebar.markdown("---")
st.sidebar.markdown("### βΉοΈ Model Info")
st.sidebar.info("**Embedding**: ColPali v1.2\n**Vector DB**: Qdrant\n**LLM**:llama-4-scout-17b-16e-instruct")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Query box (wide column) next to the action button (narrow column). Only the
# wording depends on chat mode; the widgets themselves are the same.
_chat_mode_on = st.session_state.chat_mode
_query_col, _action_col = st.columns([3, 1])

with _query_col:
    query = st.text_input(
        "π¬ Ask a question about your documents" if _chat_mode_on else "π Search Query",
        placeholder=(
            "What would you like to know from the documents?"
            if _chat_mode_on
            else "What would you like to find in the documents?"
        ),
        label_visibility="collapsed",
        key="search_query",
    )

with _action_col:
    search_clicked = st.button(
        "π¬ Ask AI" if _chat_mode_on else "π Search",
        type="primary",
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Markup shown when a search returns nothing; styled by the `no-results` class.
_NO_RESULTS_HTML = """
<div class="no-results">
    <h3>π€·ββοΈ No results found</h3>
    <p>Try adjusting your search query or filters to find what you're looking for.</p>
    <div style="margin-top: 2rem;">
        <h4>π‘ Search Tips:</h4>
        <ul style="text-align: left; display: inline-block;">
            <li>Use specific keywords related to your topic</li>
            <li>Try removing some filters to broaden your search</li>
            <li>Check for typos in your query or filter values</li>
            <li>Use synonyms or related terms</li>
        </ul>
    </div>
</div>
"""

# Run a search when the button was pressed or the query box holds text.
if search_clicked or (query and query.strip()):
    if not query.strip():
        st.warning("β οΈ Please enter a search query to continue.")
    else:
        # Collect the active filters as {field: {"$eq": value}}.
        filter_dict = {}

        if doc_type and doc_type != "All Types":
            filter_dict["doc_type"] = {"$eq": doc_type}

        if page_no and page_no.strip():
            try:
                filter_dict["page_number"] = {"$eq": int(page_no.strip())}
            except ValueError:
                st.error("β οΈ Page number must be a valid integer.")
                st.stop()

        if doc_type == "annual_report":
            if company and company.strip():
                filter_dict["company"] = {"$eq": company.strip()}
            if fiscal_year and fiscal_year.strip():
                filter_dict["fiscal_year"] = {"$eq": fiscal_year.strip()}
            if currency and currency.strip():
                filter_dict["currency"] = {"$eq": currency.strip()}

        with st.spinner("π Searching through documents..."):
            relevant_docs = search_documents(query, filter_dict)

        # In chat mode, ask the LLM for an answer based on the retrieved pages.
        if st.session_state.chat_mode and relevant_docs:
            with st.spinner("π€ Generating AI response..."):
                ai_answer = analyze_image_with_query(relevant_docs, query)

            if ai_answer:
                # Built without leading indentation so the wrapper <div> sits
                # at column 0 whatever the answer's own line layout is.
                # NOTE(review): the answer is inserted as raw HTML; confirm
                # analyze_image_with_query returns trusted/sanitised markup.
                answer_html = (
                    '<div style="background: #303336; padding: 1rem; border-radius: 8px; '
                    'margin: 1rem 0; line-height: 1.6; color: white;">\n'
                    f"{ai_answer}\n"
                    "</div>"
                )
                st.markdown(answer_html, unsafe_allow_html=True)
            else:
                st.warning("No AI response received.")
            st.markdown("---")

        if relevant_docs:
            if st.session_state.chat_mode:
                st.markdown("### π Evidence")

            # The label does not change between results, so compute it once.
            doc_label = "Source" if st.session_state.chat_mode else "Result"

            for i, result in enumerate(relevant_docs, start=1):
                # A point may carry no payload; treat that as empty metadata.
                metadata = result["metadata"] or {}
                image_url = metadata.get("image_url")
                doc_id = metadata.get("doc_id", "N/A")
                # Named differently from the sidebar's `page_no` so the
                # filter value is not overwritten while rendering results.
                result_page = metadata.get("page_number", "N/A")
                title = metadata.get("title", "N/A")

                st.markdown(
                    f"#### [{i}] DOC : <span style='color:green;'>{doc_id}</span> | Page: {result_page} | Title: {title}",
                    unsafe_allow_html=True,
                )

                if image_url:
                    st.image(image_url, caption=f"{title} | Page {result_page}", width='stretch')
                else:
                    st.warning("No image available for this document.")

                with st.expander(f"π View full metadata for {doc_label} #{i}"):
                    st.json(metadata)
        else:
            st.markdown(_NO_RESULTS_HTML, unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Short usage guide, displayed only while the query box is empty.
_USAGE_TIPS = """
**π¬ AI Chat Mode:**
- Ask natural language questions
- Get AI-generated answers based on documents
- View source documents used for the response
"""

# Footer shown on every run.
_FOOTER_HTML = """
<div style="text-align: center; color: #666; padding: 1rem;">
    <small>π€ Powered by Groq, ColPali, Qdrant, and Streamlit | Built with β€οΈ for intelligent document search and chat</small>
</div>
"""

if not query:
    st.markdown("---")
    st.markdown("### π‘ How to Use")
    st.markdown(_USAGE_TIPS)

st.markdown("---")
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)
|
|
|
|
|
|