# src/streamlit_app.py — AI Document Search & Chat (Streamlit app)
import streamlit as st
from sentence_transformers import SentenceTransformer, CrossEncoder
from pinecone import Pinecone
from groq import Groq
import uuid
import time
from pinecone_text.sparse import BM25Encoder
import os
import pickle
import nltk
import markdown2
# Fetch NLTK tokenizer data used by the text-processing pipeline (quiet: no-op if cached).
nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)

# API credentials come from the environment. Each may be None when unset;
# the corresponding client will then fail at call time, not at import time.
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
HF_TOKEN = os.environ.get("HF_TOKEN")
# -------------------------------
# Page Configuration
# -------------------------------
# Page configuration — must be the first Streamlit call in the script.
st.set_page_config(
    page_title="AI Document Search & Chat",
    page_icon="πŸ”",
    layout="wide",
    initial_sidebar_state="expanded"
)
# Custom CSS for modern styling.
# Injected once as raw HTML; the class names below are referenced by the
# HTML snippets rendered further down (e.g. .main-header, .no-results).
st.markdown("""
<style>
.main-header {
background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
padding: 2rem;
border-radius: 10px;
margin-bottom: 2rem;
text-align: center;
color: white;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.search-container {
background: white;
padding: 2rem;
border-radius: 15px;
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
border: 1px solid #e1e5e9;
margin-bottom: 2rem;
}
.filter-section {
background: #f8f9fa;
padding: 1.5rem;
border-radius: 10px;
border-left: 4px solid #667eea;
}
.result-card {
background: #303336;
padding: 0.8rem 1rem;
border-radius: 10px;
border: 0.5px solid rgba(255, 255, 255, 0.1);
margin-bottom: 0.8rem;
box-shadow: 0 2px 6px rgba(0, 0, 0, 0.05);
transition: transform 0.2s;
}
.result-card:hover {
transform: translateY(-2px);
box-shadow: 0 4px 10px rgba(0, 0, 0, 0.15);
}
.ai-response-card {
background: linear-gradient(135deg, #f8f9fa 0%, #e9ecef 100%);
border: 2px solid #667eea;
border-radius: 15px;
padding: 2rem;
margin: 2rem 0;
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.1);
}
.ai-response-header {
display: flex;
align-items: center;
margin-bottom: 1rem;
color: #667eea;
font-weight: bold;
font-size: 1.1rem;
}
.ai-response-content {
background: #303336;
padding: 1.5rem;
border-radius: 10px;
border-left: 4px solid #667eea;
line-height: 1.7;
font-size: 1rem;
}
.source-section {
background: #f8f9fa;
padding: 1rem;
border-radius: 8px;
margin-top: 1rem;
border: 1px solid #e1e5e9;
}
.score-badge {
background: linear-gradient(45deg, #667eea, #764ba2);
color: white;
padding: 0.3rem 0.8rem;
border-radius: 20px;
font-size: 0.8rem;
font-weight: bold;
display: inline-block;
margin-bottom: 1rem;
}
.metadata-chip {
background: #e3f2fd;
color: #1565c0;
padding: 0.2rem 0.6rem;
border-radius: 15px;
font-size: 0.75rem;
display: inline-block;
margin: 0.2rem;
font-weight: 500;
}
.no-results {
text-align: center;
padding: 3rem;
color: #666;
}
.stats-container {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 1rem;
border-radius: 10px;
margin-bottom: 1rem;
}
.search-input {
border: 2px solid #e1e5e9 !important;
border-radius: 10px !important;
padding: 0.75rem !important;
font-size: 1rem !important;
}
.search-input:focus {
border-color: #667eea !important;
box-shadow: 0 0 0 0.2rem rgba(102, 126, 234, 0.25) !important;
}
.stButton > button {
background: linear-gradient(45deg, #667eea, #764ba2);
color: white;
border: none;
border-radius: 10px;
padding: 0.75rem 2rem;
font-weight: 600;
transition: all 0.3s;
width: 100%;
}
.stButton > button:hover {
transform: translateY(-2px);
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.3);
}
.sidebar-content {
background: #f8f9fa;
padding: 1rem;
border-radius: 10px;
margin-bottom: 1rem;
}
.chat-mode-toggle {
background: linear-gradient(45deg, #28a745, #20c997);
color: white;
border: none;
border-radius: 10px;
padding: 0.5rem 1rem;
font-weight: 600;
margin-bottom: 1rem;
}
</style>
""", unsafe_allow_html=True)
# -------------------------------
# Load models with better caching
# -------------------------------
@st.cache_resource(show_spinner=False)
def load_models():
    """Load and cache the dense embedder and the cross-encoder reranker.

    Returns:
        tuple: (SentenceTransformer embedder, CrossEncoder reranker).
    """
    with st.spinner("πŸ€– Loading AI models..."):
        embedder = SentenceTransformer("google/embeddinggemma-300m", token=HF_TOKEN)
        cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    return embedder, cross_encoder
@st.cache_resource(show_spinner=False)
def initialize_pinecone():
    """Return a cached handle to the 'rag-latest' Pinecone index."""
    return Pinecone(api_key=PINECONE_API_KEY).Index("rag-latest")
@st.cache_resource(show_spinner=False)
def initialize_bm25():
    """Load the pre-fitted BM25 sparse encoder from disk (cached across reruns)."""
    with open("src/bm25_model.pkl", "rb") as handle:
        return pickle.load(handle)
@st.cache_resource(show_spinner=False)
def initialize_groq():
    """Return a cached Groq API client built from the environment key."""
    return Groq(api_key=GROQ_API_KEY)
# Initialize models and services (all cached by @st.cache_resource, so this
# is only slow on the first run of a fresh session).
with st.spinner("πŸš€ Initializing AI services..."):
    embed_model, reranker = load_models()
    index = initialize_pinecone()
    bm25 = initialize_bm25()
    groq_client = initialize_groq()

# Initialize session state. chat_mode drives UI labels and whether an LLM
# answer is generated in addition to showing retrieved chunks.
if "chat_mode" not in st.session_state:
    st.session_state.chat_mode = True
# -------------------------------
# Helper Functions
# -------------------------------
def search_documents(query, filter_dict, top_k, final_k=3):
    """Hybrid (dense + sparse) Pinecone search, reranked by a cross-encoder.

    Args:
        query: Natural-language search query.
        filter_dict: Pinecone metadata filter; an empty dict means no filtering.
        top_k: Candidate pool size to retrieve for reranking. A floor of 10 is
            applied so the reranker always sees the pool size the original
            implementation used (the parameter was previously ignored).
        final_k: Number of reranked results to return (default 3, matching
            the previous hard-coded cut).

    Returns:
        list: Top `final_k` matches sorted by descending "rerank_score",
        or [] when nothing matched.
    """
    dense_query = embed_model.encode(query).tolist()
    sparse_query = bm25.encode_queries([query])[0]

    res = index.query(
        vector=dense_query,
        sparse_vector=sparse_query,
        top_k=max(top_k, 10),  # honor caller's top_k; was hard-coded to 10
        include_metadata=True,
        hybrid=True,
        filter=filter_dict or None,  # Pinecone expects None (not {}) for "no filter"
    )
    candidates = res["matches"]
    if not candidates:
        return []

    # Rerank candidates with the cross-encoder on (query, chunk text) pairs.
    pairs = [(query, match["metadata"].get("text", "")) for match in candidates]
    scores = reranker.predict(pairs)
    for match, score in zip(candidates, scores):
        match["rerank_score"] = float(score)
    reranked = sorted(candidates, key=lambda x: x["rerank_score"], reverse=True)
    return reranked[:final_k]
def generate_ai_response(query, relevant_docs):
    """Answer `query` with the Groq LLM, grounded strictly in `relevant_docs`.

    Args:
        query: The user's natural-language question.
        relevant_docs: Reranked Pinecone matches; each item's "metadata" dict
            is read via .get, so missing keys degrade to None in the context.

    Returns:
        str: A Markdown answer with [DOC i] citations, the standard
        "not enough information" sentence when nothing was retrieved, or a
        "❌ Error ..." string if the Groq API call fails.
    """
    no_info_msg = "The document does not provide enough information to answer this question."
    # With no retrieved evidence the model is instructed to return exactly
    # this sentence anyway, so short-circuit and save the API call.
    if not relevant_docs:
        return no_info_msg

    # Build the numbered context block the model is told to cite as [DOC i].
    context_parts = []
    for i, doc in enumerate(relevant_docs, 1):
        metadata = doc["metadata"]
        text = metadata.get("text")
        doc_id = metadata.get("doc_id")
        title = metadata.get("title")
        fiscal_year = metadata.get("fiscal_year")
        page_no = metadata.get("page_no")
        context_parts.append(
            f"[CHUNK {i} DOC {doc_id} {title} fiscal year {fiscal_year} ] (Page {page_no})\n{text}"
        )
    context = "\n\n".join(context_parts)

    # User-turn prompt: question plus the retrieved excerpts.
    prompt = f"""
You will answer the question using ONLY the provided document excerpts.
When you use information from a document, cite it with the format [DOC i],
where i corresponds to the document number given in CONTEXT DOCUMENTS.
If multiple docs are relevant, cite all of them (e.g., [DOC 1][DOC 3]).
CONTEXT DOCUMENTS:
{context}
USER QUESTION: {query}
ANSWER : " "
"""
    try:
        # Non-streaming call: the full answer comes back in one response.
        chat_completion = groq_client.chat.completions.create(
            messages=[
                {
                    "role": "system",
                    # The refusal sentence is interpolated so it always matches
                    # the no_info_msg constant above.
                    "content": f"""You are a professional assistant that answers user questions based **only on the content of provided document excerpts**. The user will ask a question, and you will also receive related text chunks retrieved from company documents or PDFs.
Instructions:
1. Use **only** the retrieved chunks to answer the user's question. Do **not** add information from memory or outside sources.
2. If multiple chunks provide relevant info, combine them into a **clear, concise answer**.
3. If the answer is **not found** in the chunks, respond exactly with: "{no_info_msg}"
4. Keep the style **professional, factual, and concise**.
5. Return the response in markdown format.
6. Refuse to answer or speculate if no reliable evidence is found in the chunks.
"""
                },
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            model="llama-3.3-70b-versatile",
            stream=False
        )
        return chat_completion.choices[0].message.content
    except Exception as e:
        # Surface API failures to the UI instead of crashing the app.
        return f"❌ Error generating AI response: {str(e)}"
# -------------------------------
# Header
# -------------------------------
# Hero banner rendered as raw HTML; styled by the .main-header CSS class
# injected earlier in this file.
st.markdown("""
<div class="main-header">
<h1 style="margin: 0; font-size: 1.9rem;"> Hybrid Search RAG </h1>
<p style="margin: 0.5rem 0 0 0; font-size: 1.1rem; opacity: 0.9;">
Using Groq LLM, Pinecone, and Sentence Transformers
</p>
</div>
""", unsafe_allow_html=True)
# -------------------------------
# Sidebar for filters and mode toggle
# -------------------------------
def clear_all_filters():
    """Reset every active search widget in session state to an empty string.

    Fixes the previous version, which reset only legacy annual-report keys
    (whose widgets were removed from the sidebar) and left the live
    contract-report filter widgets untouched, so "Clear All Filters" did not
    actually clear the visible filters.
    """
    widget_keys = (
        # Common inputs
        "search_query",
        "page_no_filter",
        # Contract-report filter widgets (keys match the sidebar text_inputs)
        "agreement_date_filter",
        "promoter_filter",
        "allottee_filter",
        "project_name_filter",
        "apartment_block_filter",
        "apartment_floor_filter",
        "apartment_type_filter",
        "jurisdiction_filter",
    )
    for key in widget_keys:
        st.session_state[key] = ""
with st.sidebar:
    st.markdown("### 🎯 Search Filters")

    # Only contract reports are indexed at the moment; the annual_report
    # option was removed from this selectbox.
    doc_type = st.selectbox(
        "Document Type",
        ["contract_report"],
        key="doc_type_filter"
    )

    # Contract-report metadata filters. Each widget key is the session-state
    # key the reset callback clears.
    if doc_type == "contract_report":
        with st.expander("Contract Report Filters", expanded=False):
            agreement_date = st.text_input("Agreement Date", placeholder="YYYY-MM-DD", key="agreement_date_filter")
            promoter = st.text_input("Promoter / Developer", placeholder="Enter promoter name...", key="promoter_filter")
            allottee = st.text_input("Allottee (Buyer)", placeholder="Enter allottee name...", key="allottee_filter")
            project_name = st.text_input("Project Name", placeholder="Enter project name...", key="project_name_filter")
            apartment_block = st.text_input("Block", placeholder="e.g., Tower A", key="apartment_block_filter")
            apartment_floor = st.text_input("Floor", placeholder="e.g., 10th floor", key="apartment_floor_filter")
            apartment_type = st.text_input("Apartment Type", placeholder="e.g., 2BHK", key="apartment_type_filter")
            jurisdiction = st.text_input("Jurisdiction", placeholder="e.g., Madras High Court", key="jurisdiction_filter")

    # Common filter, applied regardless of document type.
    page_no = st.text_input("Page Number", placeholder="e.g., 15", key="page_no_filter")

    # Reset button — clears all filter widgets via the callback.
    st.button("Clear All Filters", on_click=clear_all_filters)

    # Model info. Kept in sync with the code: the Groq call below uses
    # llama-3.3-70b-versatile (the old text said 3.1 — a stale label).
    st.markdown("---")
    st.markdown("### ℹ️ Model Info")
    st.info("**Embedding**: Google EmbeddingGemma-300M\n**Reranker**: MS-MARCO MiniLM-L-6-v2\n**LLM**: Groq Llama-3.3-70B")
# -------------------------------
# Main search interface
# -------------------------------
# Two-column layout: wide text input on the left, action button on the right.
col1, col2 = st.columns([3, 1])
with col1:
    # Label/placeholder change with chat mode; the widget key is the same
    # either way, so the typed query survives a mode switch.
    if st.session_state.chat_mode:
        query = st.text_input(
            "πŸ’¬ Ask a question about your documents",
            placeholder="What would you like to know from the documents?",
            label_visibility="collapsed",
            key="search_query"
        )
    else:
        query = st.text_input(
            "πŸ” Search Query",
            placeholder="What would you like to find in the documents?",
            label_visibility="collapsed",
            key="search_query"
        )
with col2:
    # Same button slot, different caption depending on mode.
    if st.session_state.chat_mode:
        search_clicked = st.button("πŸ’¬ Ask AI", type="primary")
    else:
        search_clicked = st.button("πŸš€ Search", type="primary")
# -------------------------------
# Search functionality
# -------------------------------
# Main retrieval/answer flow. Runs when the button is clicked OR whenever a
# non-empty query is present (so plain Enter in the text box also triggers it).
if search_clicked or (query and len(query.strip()) > 0):
    if not query.strip():
        st.warning("⚠️ Please enter a search query to continue.")
    else:
        # Build the Pinecone metadata filter from the sidebar widget values.
        filter_dict = {}
        # Common filters
        if doc_type and doc_type != "All Types":
            filter_dict["doc_type"] = {"$eq": doc_type}
        if page_no and page_no.strip():
            try:
                # page_no is stored as an integer in the index metadata.
                filter_dict["page_no"] = {"$eq": int(page_no.strip())}
            except ValueError:
                st.error("⚠️ Page number must be a valid integer.")
                st.stop()
        # Contract-report filters — each maps a sidebar text input to an
        # exact-match ($eq) metadata constraint.
        if doc_type == "contract_report":
            if agreement_date and agreement_date.strip():
                filter_dict["agreement_date"] = {"$eq": agreement_date.strip()}
            if promoter and promoter.strip():
                filter_dict["promoter_legal_name"] = {"$eq": promoter.strip()}
            if allottee and allottee.strip():
                filter_dict["allottee_name"] = {"$eq": allottee.strip()}
            if project_name and project_name.strip():
                filter_dict["project_name"] = {"$eq": project_name.strip()}
            if apartment_block and apartment_block.strip():
                filter_dict["apartment_block"] = {"$eq": apartment_block.strip()}
            if apartment_floor and apartment_floor.strip():
                filter_dict["apartment_floor"] = {"$eq": apartment_floor.strip()}
            if apartment_type and apartment_type.strip():
                filter_dict["apartment_type"] = {"$eq": apartment_type.strip()}
            if jurisdiction and jurisdiction.strip():
                filter_dict["jurisdiction"] = {"$eq": jurisdiction.strip()}

        # Hybrid search + rerank; spinner keeps the UI responsive-looking.
        start_time = time.time()
        with st.spinner("πŸ” Searching through documents..."):
            relevant_docs = search_documents(query, filter_dict, top_k=5)

        # Chat mode: ask the LLM for a grounded answer and render it as HTML
        # (markdown2 converts the model's markdown output).
        if st.session_state.chat_mode:
            with st.spinner("πŸ€– Generating AI response..."):
                ai_response = generate_ai_response(query, relevant_docs)
            # Display AI response
            # st.markdown(ai_response,unsafe_allow_html=True)
            html_content1 = markdown2.markdown(ai_response)
            st.markdown(f'<div style="background: #303336; padding: 1rem; border-radius: 8px; margin: 1rem 0; line-height: 1.6; color: white;">{html_content1}</div>', unsafe_allow_html=True)
            st.markdown("---")

        if relevant_docs:
            # NOTE(review): search_time is computed but never displayed.
            search_time = time.time() - start_time
            # Show the retrieved chunks as evidence below the answer.
            if st.session_state.chat_mode:
                st.markdown("### Evidence")
            # else:
            # st.markdown("### πŸ“‹ Search Results")
            for i, result in enumerate(relevant_docs, start=1):
                metadata = result["metadata"]
                text_content = metadata.get("text", "No text available")
                doc_id = metadata.get("doc_id", "N/A")
                page_no = metadata.get("page_no", "N/A")
                title = metadata.get("title")
                # st.markdown("#### [{i}] DOC : {doc_id} | Page: {page_no} | Title {title}".format(i=i, doc_id=doc_id, page_no=page_no, title=title))
                st.markdown(
                    "#### [{i}] DOC : <span style='color:green;'>{doc_id}</span> | Page: {page_no} | Title: {title}".format(
                        i=i, doc_id=doc_id, page_no=page_no, title=title
                    ),
                    unsafe_allow_html=True
                )
                # Chunk text is markdown-rendered into a dark card.
                html_content = markdown2.markdown(text_content)
                st.markdown(f'<div style="background: #303336; padding: 1rem; border-radius: 8px; margin: 1rem 0; line-height: 1.6; color: white;">{html_content}</div>', unsafe_allow_html=True)
                # Expandable full metadata for each hit.
                doc_label = "Source" if st.session_state.chat_mode else "Result"
                with st.expander(f"πŸ” View full metadata for {doc_label} #{i}"):
                    st.json(metadata)
                st.markdown("</div>", unsafe_allow_html=True)
        else:
            # No results found — show guidance instead of an empty page.
            st.markdown("""
<div class="no-results">
<h3>πŸ€·β€β™‚οΈ No results found</h3>
<p>Try adjusting your search query or filters to find what you're looking for.</p>
<div style="margin-top: 2rem;">
<h4>πŸ’‘ Search Tips:</h4>
<ul style="text-align: left; display: inline-block;">
<li>Use specific keywords related to your topic</li>
<li>Try removing some filters to broaden your search</li>
<li>Check for typos in your query or filter values</li>
<li>Use synonyms or related terms</li>
</ul>
</div>
</div>
""", unsafe_allow_html=True)
# -------------------------------
# Usage Instructions
# -------------------------------
# Usage instructions — shown only while the query box is empty.
if not query:
    st.markdown("---")
    st.markdown("### πŸ’‘ How to Use")
    st.markdown("""
**πŸ’¬ AI Chat Mode:**
- Ask natural language questions
- Get AI-generated answers based on documents
- View source documents used for the response
""")
# -------------------------------
# Footer
# -------------------------------
# Footer — always rendered at the bottom of the page.
st.markdown("---")
st.markdown("""
<div style="text-align: center; color: #666; padding: 1rem;">
<small>πŸ€– Powered by Groq, Sentence Transformers, Pinecone, and Streamlit | Built with ❀️ for intelligent document search and chat</small>
</div>
""", unsafe_allow_html=True)