Shubham170793's picture
Update src/streamlit_app.py
585fec8 verified
raw
history blame
9.5 kB
import os
import re
import shutil
import hashlib
import streamlit as st
import torch
# ==========================================================
# Environment Diagnostics
# ==========================================================
# Log which accelerator (if any) torch can see for this script run.
_cuda_ready = torch.cuda.is_available()
print("CUDA available:", _cuda_ready)
print("Device count:", torch.cuda.device_count())
if not _cuda_ready:
    print("Running on CPU")
else:
    print("GPU name:", torch.cuda.get_device_name(0))
# ==========================================================
# Page Configuration
# ==========================================================
# Page title and wide layout, applied before any other Streamlit element.
_PAGE_OPTIONS = {
    "page_title": "Enterprise Knowledge Assistant",
    "layout": "wide",
}
st.set_page_config(**_PAGE_OPTIONS)
# ==========================================================
# Cache Management (prevent HF overflow)
# ==========================================================
def _folder_size_bytes(folder: str) -> int:
    """Return the combined size in bytes of the regular files under folder."""
    total = 0
    for dirpath, _, filenames in os.walk(folder):
        for filename in filenames:
            path = os.path.join(dirpath, filename)
            # Skip symlinks: following them would count a link and its target
            # twice, and a dangling link would make getsize() raise.
            if os.path.islink(path):
                continue
            try:
                total += os.path.getsize(path)
            except OSError:
                # File disappeared or is unreadable mid-walk; ignore it.
                continue
    return total


def clean_cache(
    max_size_gb: float = 2.0,
    folders: "list[str] | None" = None,
    preserve_dir: str = "/tmp/hf_cache",
) -> None:
    """
    Delete oversized cache folders and ensure the preserved cache dir exists.

    A folder is removed when its files add up to more than ``max_size_gb``.
    A folder whose final path component is ``torch`` is removed regardless of
    its size. ``preserve_dir`` is never deleted here; it is only created.

    Args:
        max_size_gb: Size threshold in GiB above which a folder is deleted.
        folders: Cache folders to inspect. Defaults to the Hugging Face,
            transformers and torch caches under ``/root/.cache``.
        preserve_dir: Directory that is created if it does not exist yet.

    Returns:
        None. Progress is reported with ``print``.
    """
    if folders is None:
        folders = [
            "/root/.cache/huggingface",
            "/root/.cache/transformers",
            "/root/.cache/torch",
        ]

    total_deleted = 0.0
    for folder in folders:
        if not os.path.exists(folder):
            continue

        size_gb = _folder_size_bytes(folder) / (1024**3)
        # The torch cache is always dropped, independent of the threshold.
        is_torch_cache = os.path.basename(os.path.normpath(folder)) == "torch"

        if size_gb > max_size_gb or is_torch_cache:
            shutil.rmtree(folder, ignore_errors=True)
            total_deleted += size_gb
            print(f"πŸ—‘οΈ Deleted {folder} ({size_gb:.2f} GB)")
        else:
            print(f"βœ… Preserved {folder} ({size_gb:.2f} GB)")

    os.makedirs(preserve_dir, exist_ok=True)
    print(f"🧹 Cache cleanup done. ~{total_deleted:.2f} GB removed.")
def check_disk_usage():
    """
    Show the output of ``du -sh /root/.cache /tmp`` in the sidebar.

    Debug aid only: any failure while running the command is reported in the
    sidebar instead of being raised.
    """
    st.sidebar.markdown("### πŸ’Ύ Disk Usage (Debug)")
    try:
        # The context manager closes the pipe once the output has been read;
        # previously the popen handle was left open on every call.
        with os.popen("du -sh /root/.cache /tmp 2>/dev/null") as pipe:
            usage = pipe.read()
    except Exception as e:
        st.sidebar.text(f"⚠️ Disk usage check failed: {e}")
    else:
        st.sidebar.text(usage if usage else "No cache directories found.")
# Run cache cleanup once at startup.
# Streamlit re-executes this whole script on every user interaction, so a bare
# clean_cache() call here would walk and delete the cache folders on each
# rerun. st.cache_resource memoises the call, so the cleanup body runs only
# the first time.
@st.cache_resource(show_spinner=False)
def _startup_cache_cleanup() -> bool:
    """Run clean_cache() and return True so there is a value to memoise."""
    clean_cache()
    return True


_ = _startup_cache_cleanup()
# The disk-usage panel is a sidebar element, so it is redrawn on every run.
check_disk_usage()
# ==========================================================
# Hugging Face Cache Configuration
# ==========================================================
# Point every Hugging Face cache variable at one directory under /tmp.
CACHE_DIR = "/tmp/hf_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
os.environ.update(
    dict.fromkeys(
        ("HF_HOME", "TRANSFORMERS_CACHE", "HF_DATASETS_CACHE", "HF_MODULES_CACHE"),
        CACHE_DIR,
    )
)
# ==========================================================
# Project imports, deliberately placed AFTER the environment setup above
# ==========================================================
from ingestion import extract_text_from_pdf, chunk_text
from vectorstore import build_faiss_index
from qa import retrieve_chunks, generate_answer, cache_embeddings, embed_chunks
# ==========================================================
# Paths
# ==========================================================
# Bundled assets are looked up next to this script.
BASE_DIR = os.path.split(__file__)[0]
LOGO_PATH, SAMPLE_PATH = (
    os.path.join(BASE_DIR, asset_name) for asset_name in ("logo.png", "sample.pdf")
)
# ==========================================================
# UI Header
# ==========================================================
# Main-page title followed by a one-line description of the app.
_APP_TITLE = "πŸ“„ Enterprise Knowledge Assistant"
_APP_TAGLINE = (
    "Query SAP documentation and enterprise PDFs using natural language and reasoning."
)
st.title(_APP_TITLE)
st.caption(_APP_TAGLINE)
# ==========================================================
# Sidebar: Library, Settings, Diagnostics
# ==========================================================
# Every element below is created through the sidebar container itself, which
# places it in the sidebar in the order written.
_sidebar = st.sidebar

# App logo, shown only when the image file exists.
if os.path.exists(LOGO_PATH):
    _sidebar.image(LOGO_PATH, width=150)

# Reasoning-mode toggle; the current choice is kept in session state.
if "reasoning_mode" not in st.session_state:
    st.session_state.reasoning_mode = False
st.session_state.reasoning_mode = _sidebar.toggle(
    "🧠 Enable Reasoning Mode",
    value=st.session_state.reasoning_mode,
    help="When ON: GPT-4o uses reasoning + web-like synthesis.\nWhen OFF: Strictly factual from PDF.",
)
_sidebar.markdown("---")

# Document library: which PDF the rest of the script works on.
_sidebar.header("πŸ“š Document Library")
doc_choice = _sidebar.radio(
    "Choose a document:",
    options=["-- Select --", "Sample PDF", "Upload Custom PDF"],
    index=0,
)
_sidebar.markdown("---")

# Chunking and retrieval settings.
_sidebar.header("βš™οΈ Settings")
chunk_size = _sidebar.slider(
    "Chunk Size (characters)", min_value=200, max_value=1500, value=800, step=50
)
overlap = _sidebar.slider(
    "Chunk Overlap (characters)", min_value=50, max_value=200, value=120, step=10
)
top_k = _sidebar.slider("Top K Results", min_value=1, max_value=10, value=5)
_sidebar.markdown("---")
_sidebar.caption("πŸ‘¨β€πŸ’» Built by Shubham Sharma")
# ==========================================================
# Document Handling
# ==========================================================
def _file_md5(path):
    """Return the hex MD5 digest of the file at path, read in 1 MiB pieces."""
    digest = hashlib.md5()
    with open(path, "rb") as fh:
        for piece in iter(lambda: fh.read(1 << 20), b""):
            digest.update(piece)
    return digest.hexdigest()


def _load_document(path, chunk_size, extract_spinner, count_icon):
    """
    Extract, chunk, embed and index one PDF.

    Args:
        path: Location of the PDF on disk.
        chunk_size: Chunk size forwarded to chunk_text().
        extract_spinner: Spinner text shown while extracting and chunking.
        count_icon: Prefix of the "Extracted N chunks." message.

    Returns:
        (text, chunks, index, embeddings). index and embeddings are None when
        the document produced no chunks.
    """
    with st.spinner(extract_spinner):
        text = extract_text_from_pdf(path)
        # NOTE(review): the sidebar `overlap` value is never forwarded.
        # chunk_text's signature is not visible here, so the call is left as
        # it was; confirm whether it accepts an overlap argument.
        chunks = chunk_text(text, chunk_size=chunk_size) if text else []
    st.write(f"{count_icon} Extracted {len(chunks)} chunks.")

    if not chunks:
        # Nothing to embed; skip the embedding and index steps.
        st.error("No text could be extracted from this PDF.")
        return text, chunks, None, None

    name = os.path.basename(path)
    # The key covers file content and chunk size, so a different upload with
    # the same name, or a changed chunk-size slider, does not reuse embeddings
    # computed for other chunks.
    # Assumes cache_embeddings stores its result at the md5 of its first
    # argument under /tmp/embed_cache, as the path below mirrors. TODO confirm.
    cache_key = f"{name}:{_file_md5(path)}:{chunk_size}"
    cache_file = f"/tmp/embed_cache/{hashlib.md5(cache_key.encode()).hexdigest()}.pkl"
    # Checked BEFORE calling cache_embeddings, so this reflects whether the
    # file was already there rather than anything the call may create.
    cache_hit = os.path.exists(cache_file)

    with st.spinner("βš™οΈ Loading cached embeddings or generating new ones..."):
        embeddings = cache_embeddings(cache_key, chunks, embed_chunks)

    if cache_hit:
        st.info(f"🧠 Using cached embeddings for {name}")
    else:
        st.warning(f"πŸ’‘ Generated new embeddings for {name}")

    index = build_faiss_index(embeddings)
    return text, chunks, index, embeddings


text, chunks, index, embeddings = None, None, None, None

if doc_choice == "-- Select --":
    st.info("⬅️ Please choose a document from the sidebar.")

elif doc_choice == "Sample PDF":
    temp_path = SAMPLE_PATH
    if os.path.exists(temp_path):
        st.success("πŸ“˜ Using built-in Sample PDF")
        text, chunks, index, embeddings = _load_document(
            temp_path,
            chunk_size,
            "πŸ” Extracting and processing document...",
            "πŸ“‘",
        )
    else:
        # Report a missing bundled sample instead of failing in extraction.
        st.error(f"Sample PDF not found at {temp_path}")

elif doc_choice == "Upload Custom PDF":
    uploaded_file = st.file_uploader("πŸ“‚ Upload your PDF", type="pdf")
    if uploaded_file:
        # The file name comes from the client; keep only its final component
        # so the upload cannot be written outside /tmp.
        safe_name = os.path.basename(uploaded_file.name)
        if safe_name in ("", ".", ".."):
            safe_name = "upload.pdf"
        temp_path = os.path.join("/tmp", safe_name)
        with open(temp_path, "wb") as f:
            f.write(uploaded_file.getbuffer())
        st.success(f"βœ… File '{uploaded_file.name}' uploaded successfully")

        text, chunks, index, embeddings = _load_document(
            temp_path,
            chunk_size,
            "βš™οΈ Extracting and processing your document...",
            "πŸ“„",
        )
        if index is not None:
            st.success("πŸš€ Document processed successfully!")
# ==========================================================
# Document Preview
# ==========================================================
# Shown only when a document produced at least one chunk.
if chunks:
    st.subheader("πŸ“‘ Document Preview")
    _preview_text = text[:1000]
    st.text_area("Extracted text (first 1000 chars)", _preview_text, height=200)

    # Mean chunk length, truncated to a whole number of characters.
    _chunk_count = len(chunks)
    _total_chars = sum(map(len, chunks))
    avg_len = int(_total_chars / _chunk_count)
    st.caption(f"πŸ“¦ {_chunk_count} chunks | Avg length: {avg_len} chars")
# ==========================================================
# Query Section
# ==========================================================
import html  # only needed for the HTML snippets rendered in this section

# Explicit None check: whether a document is loaded should not depend on the
# truthiness of whatever object build_faiss_index returns.
if index is not None and chunks:
    st.markdown("---")
    st.subheader("πŸ€– Ask a Question")
    user_query = st.text_input("πŸ” Your question about the document:")

    if user_query:
        mode_label = (
            "🧠 Reasoning Mode (expanded thinking)"
            if st.session_state.reasoning_mode
            else "πŸ“„ Strict Document Mode (factual only)"
        )
        st.caption(f"Mode: {mode_label}")

        with st.spinner("🧠 Thinking... retrieving context and generating answer..."):
            retrieved = retrieve_chunks(
                user_query, index, chunks, top_k=top_k, embeddings=embeddings
            )
            answer = generate_answer(
                user_query, retrieved, reasoning_mode=st.session_state.reasoning_mode
            )

        # Display the answer.
        # The answer and the chunks come from a model and from PDF text, and
        # are rendered with unsafe_allow_html=True. They are escaped first so
        # that any markup inside them is shown as text, not interpreted.
        st.markdown("### βœ… Assistant’s Answer")
        safe_answer = html.escape(str(answer))
        st.markdown(
            f"<div style='background-color:#0E1117;padding:12px;border-radius:10px;color:white;'>{safe_answer}</div>",
            unsafe_allow_html=True,
        )

        # Supporting chunks that were passed to generate_answer.
        with st.expander("πŸ“„ Supporting Chunks (Context Used)"):
            for i, r in enumerate(retrieved, start=1):
                safe_chunk = html.escape(str(r))
                st.markdown(
                    "\n<div style='background-color:#111827;padding:10px;border-radius:8px;margin-bottom:6px;'>\n"
                    f"<b>Chunk {i}:</b><br>{safe_chunk}\n"
                    "</div>\n",
                    unsafe_allow_html=True,
                )
else:
    st.info("πŸ“₯ Upload or select a document to start exploring.")