Spaces:
Sleeping
Sleeping
Upload 7 files
Browse files- src/__init__.py +0 -0
- src/collections_setup.py +107 -0
- src/mylogging.py +45 -0
- src/run.py +128 -0
- src/runpod_setup.py +76 -0
- src/text_processing.py +92 -0
- src/utils.py +181 -0
src/__init__.py
ADDED
|
File without changes
|
src/collections_setup.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import chromadb
|
| 2 |
+
import streamlit as st
|
| 3 |
+
import fitz
|
| 4 |
+
import os
|
| 5 |
+
from chromadb.utils import embedding_functions
|
| 6 |
+
from text_processing import lines_chunking, paragraphs_chunking
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def get_chroma_client():
    """
    Create an in-memory (ephemeral) ChromaDB client for session-scoped RAG.

    Nothing is persisted to disk: every collection vanishes when the
    Python process backing the user's session ends.
    """
    ephemeral_client = chromadb.EphemeralClient()
    return ephemeral_client
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@st.cache_resource
def initialize_chroma_client():
    """
    Return the session's ChromaDB client from Streamlit's resource cache.

    `st.cache_resource` guarantees the same client object is reused across
    script reruns instead of being rebuilt each time.
    """
    return get_chroma_client()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@st.cache_resource
def initialize_chromadb(embedding_model):
    """
    Build (and cache) the ChromaDB client plus an embedding function.

    Parameters
    ----------
    embedding_model : str
        Name of the Sentence-Transformers model used for embeddings.

    Returns
    -------
    tuple
        ``(client, embedding_func)`` — both cached for the session.
    """
    # Embedding function wraps a Sentence-Transformers model.
    embedder = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=embedding_model
    )
    # Client comes from the shared resource cache.
    cached_client = initialize_chroma_client()
    return cached_client, embedder
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def initialize_collection(client, embedding_func, collection_name):
    """
    Fetch the named collection, creating it if it does not exist yet.

    The collection embeds documents with `embedding_func` and uses cosine
    distance for its HNSW index.
    """
    return client.get_or_create_collection(
        name=collection_name,
        embedding_function=embedding_func,
        metadata={"hnsw:space": "cosine"},
    )
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def update_collection(collection, files_to_add_to_collection):
    """
    Ingest newly uploaded files into the ChromaDB collection.

    For each filename the matching raw upload is looked up in
    ``st.session_state.uploaded_files_raw``, its text extracted (TXT or
    PDF via PyMuPDF), chunked with ``lines_chunking``, and added to the
    collection. Successfully ingested names are appended to
    ``st.session_state.collections_files_name``.

    Parameters
    ----------
    collection : chromadb collection to add chunks to.
    files_to_add_to_collection : iterable of filenames not yet ingested.

    Returns
    -------
    The (mutated) collection.
    """
    for file_to_add in files_to_add_to_collection:

        current_file = next(
            (file for file in st.session_state.get('uploaded_files_raw', [])
             if file.name == file_to_add), None)

        if current_file is None:
            st.error(f"File '{file_to_add}' not found in uploaded files.")
            continue

        # Read file content
        try:
            if current_file.type == "text/plain":  # Handling TXT files
                file_text = current_file.getvalue().decode("utf-8")
            elif current_file.type == "application/pdf":  # Handling PDFs
                with fitz.open(stream=current_file.getvalue(), filetype="pdf") as pdf_document:
                    file_text = "\n".join(page.get_text("text") for page in pdf_document)
            else:
                st.warning(f"Unsupported file type: {current_file.name} type:{current_file.type}")
                continue

            # Tokenize text into chunks
            max_words = 200
            chunks = lines_chunking(file_text, max_words=max_words)

            if not chunks:  # Skip if no chunks generated
                st.warning(f"No content extracted from {current_file.name}")
                continue

            # Store chunks in the collection
            filename = current_file.name
            # BUG FIX: was filename[:-4], which blindly chopped 4 chars and
            # broke for extensions that are not exactly 3 characters long.
            stem = os.path.splitext(filename)[0]
            collection.add(
                documents=chunks,
                ids=[f"id{stem}.{j}" for j in range(len(chunks))],
                metadatas=[{"source": filename, "part": n} for n in range(len(chunks))],
            )

            st.session_state.collections_files_name.append(filename)
            # BUG FIX: message previously printed the literal "(unknown)"
            # instead of the file's name.
            st.success(f"Added {len(chunks)} chunks from {filename}")

        except Exception as e:
            st.error(f"Error processing {current_file.name}: {str(e)}")
            # BUG FIX: `filename` may be unbound if the error happened while
            # reading the file; use current_file.name and guard membership so
            # the cleanup itself cannot raise.
            if current_file.name in st.session_state.uploaded_files_name:
                st.session_state.uploaded_files_name.remove(current_file.name)

    return collection
|
| 107 |
+
|
src/mylogging.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import logging
|
| 3 |
+
import io
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def configure_logging():
    """
    Configure the root logger to write into an in-memory stream.

    Returns
    -------
    tuple
        ``(logger, log_stream)`` — the root logger and a StringIO buffer
        that ``display_logs`` can later render in the UI.

    Notes
    -----
    Streamlit reruns the whole script on every interaction. The original
    implementation attached a brand-new handler on each call, so handlers
    (and therefore duplicated log lines) accumulated. Handlers added by a
    previous call are now removed first.
    """
    log_stream = io.StringIO()
    handler = logging.StreamHandler(log_stream)
    handler.setFormatter(logging.Formatter('%(message)s'))
    handler.setLevel(logging.WARNING)
    # Marker attribute lets us recognize (and drop) our own handlers later.
    handler._myrag_ui_handler = True

    logger = logging.getLogger()
    logger.setLevel(logging.WARNING)
    # BUG FIX: remove handlers attached by earlier calls before adding a new
    # one, so reruns do not stack duplicate handlers.
    for old_handler in list(logger.handlers):
        if getattr(old_handler, "_myrag_ui_handler", False):
            logger.removeHandler(old_handler)
    logger.addHandler(handler)

    return logger, log_stream
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def toggle_logging(level, logger):
    """
    Set the logger — and every one of its handlers — to the given level.

    Unrecognized level names fall back to WARNING (with a logged warning).
    """
    known_levels = {
        'DEBUG': logging.DEBUG,
        'INFO': logging.INFO,
        'WARNING': logging.WARNING,
    }
    chosen = known_levels.get(level)
    if chosen is None:
        logger.warning(f"Unknown logging level: {level}. Using WARNING as default.")
        chosen = logging.WARNING
    logger.setLevel(chosen)
    # Keep every attached handler in lockstep with the logger itself.
    for handler in logger.handlers:
        handler.setLevel(logger.level)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def display_logs(log_stream):
    """
    Render everything accumulated in `log_stream` inside the app.
    """
    # Rewind so the whole buffer (not just unread bytes) is shown.
    log_stream.seek(0)
    st.text(log_stream.read())
|
src/run.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import os
|
| 3 |
+
from utils import load_background_image, apply_style, configure_page, breaks, file_uploader, initialise_session_state
|
| 4 |
+
from mylogging import configure_logging, toggle_logging, display_logs
|
| 5 |
+
from collections_setup import initialize_chromadb, initialize_collection, update_collection
|
| 6 |
+
from runpod_setup import get_relevant_text, generate_answer, get_contextual_prompt
|
| 7 |
+
|
| 8 |
+
if __name__ == "__main__":

    # ---- Page chrome & session bootstrap ----
    configure_page()
    apply_style()
    load_background_image()
    initialise_session_state()
    breaks(2)
    st.write(
        """
Welcome to this Streamlit app that demonstrates Retrieval-Augmented Generation (RAG) using a **Mistral-7B model hosted on Runpod** and **ChromaDB** for retrieval.

With this app, you can:
- Upload multiple PDF or text files to build a contextual knowledge base,
- Ask custom questions based on your uploaded documents, and
- Generate informed responses using a lightweight, hosted LLM.

**Note:** All uploaded files and generated embeddings are stored **in memory only** and will be **lost when the app is closed or restarted**. No data is persisted between sessions.
"""
    )

    # # Disable Chroma telemetry
    os.environ["CHROMA_TELEMETRY_ENABLED"] = "False"

    # Initialize logger
    logger, log_stream = configure_logging()
    # Narrow the select box used for the logging level below.
    st.markdown(
        """
        <style>
        /* This targets the selectbox container */
        div[data-baseweb="select"] {
            max-width: 150px;
        }
        </style>
        """,
        unsafe_allow_html=True,
    )
    st.divider()

    # ---- Logging Setup ----
    use_logging = False  # flip to True to expose the level selector and log view
    if use_logging:
        logging_level = st.selectbox("Select logging level", ['INFO', 'DEBUG', 'WARNING'], index=2)
        toggle_logging(logging_level, logger)

    # ---- Vector Store Setup ----
    # Initialize ChromaDB and collection
    EMBEDDING_MODEL = "all-MiniLM-L6-v2"
    client, embedding_func = initialize_chromadb(EMBEDDING_MODEL)
    collection_name = "my_collection"
    collection = initialize_collection(client, embedding_func, collection_name)

    # Upload files
    st.markdown(
        '<h3>Upload Files</h3>',
        unsafe_allow_html=True)
    st.html("""
    Uploaded files are processed to build a contextual knowledge base for the RAG model.<br>
    When you submit a prompt, the model retrieves relevant information from these documents to generate responses.
    """)
    col1_, _, col2_ = st.columns([.4, .1, .5])
    with col1_:
        file_uploader()

    # Get the current uploaded filenames
    logger.debug(f"\n\t-- Currently uploaded files: {st.session_state.get('uploaded_files_name', 'None')}")

    # Update collection with uploaded files: only names not yet ingested.
    files_to_add_to_collection= [file_name for file_name in st.session_state.get("uploaded_files_name", []) if file_name not in st.session_state.get("collections_files_name", [])]
    logger.debug(f"\n\t-- Files not in collection: {files_to_add_to_collection}")

    if files_to_add_to_collection:
        collection = update_collection(collection, files_to_add_to_collection)

    # Update the session state
    logger.debug(f"Collection count: {collection.count()}")
    logger.debug(f"\n\t-- Collection data currently uploaded:")
    data_head = collection.get(limit=5)
    for i, (metadata, document) in enumerate(zip(data_head["metadatas"], data_head["documents"]), start=1):
        logger.debug(f"Item {i}:")
        logger.debug(f"Metadata: {metadata}")
        logger.debug(f"Document: {document}")
        logger.debug("-" * 40)

    # ---- Response Generation ----
    # Streamlit UI
    st.divider()
    col1, _, col2 = st.columns([.6, .01, 1])
    with col1:
        st.subheader("Enter your prompt")
        # NOTE(review): two st.text_area widgets with empty labels exist on
        # this page — confirm this does not trip duplicate-element warnings
        # on the deployed Streamlit version.
        query = st.text_area("", height=200)
        generate_clicked = st.button("Generate Response")
        if generate_clicked:
            if query.strip():
                # Get the number of available documents in ChromaDB
                available_docs = collection.count()

                if available_docs > 0:
                    # Ensure n_results doesn't exceed available_docs
                    n_results = min(2, available_docs)
                    relevant_text = get_relevant_text(collection, query=query, nresults=n_results)
                else:
                    relevant_text = ""  # No documents available, so no additional context
                    st.warning("No knowledge base available. Generating response based only on the prompt.")

                logger.debug("\n\t-- Relevant text retrieved:")
                logger.debug(relevant_text)

                with st.spinner("Generating response..."):
                    context_query = get_contextual_prompt(query, relevant_text)
                    response = generate_answer(context_query, max_tokens=200)

                with col2:
                    st.subheader("Response:")
                    st.text_area("", value=response, height=200)
            else:
                logger.debug("No query provided; skipping relevant text retrieval.")
                st.warning("Please enter a prompt.")

    if use_logging:
        display_logs(log_stream)
|
src/runpod_setup.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import requests
|
| 2 |
+
import os
|
| 3 |
+
from dotenv import load_dotenv
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
# Load .env from project root
load_dotenv(dotenv_path=Path(__file__).resolve().parents[1] / ".env")

# RunPod credentials; either may be None when the .env file is absent —
# requests will then fail at call time with an auth/URL error rather than
# at import time.
API_KEY = os.getenv("RUNPOD_API_KEY")
ENDPOINT = os.getenv("RUNPOD_ENDPOINT")

# Shared headers for every RunPod HTTP call.
HEADERS = {
    "Authorization": f"Bearer {API_KEY}",
    "Content-Type": "application/json"
}
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_relevant_text(collection, query='', nresults=3, sim_th=None):
    """
    Retrieve text relevant to `query` from a ChromaDB collection.

    Parameters
    ----------
    collection : chromadb collection to search.
    query : str — the user's prompt.
    nresults : int — maximum number of documents to pull back.
    sim_th : float or None — optional similarity threshold; when given,
        only documents with cosine similarity >= sim_th are kept.

    Returns
    -------
    str — the retained documents concatenated with no separator.
    """
    result = collection.query(query_texts=query, n_results=nresults)
    documents = result.get('documents')[0]

    if sim_th is not None:
        # Chroma reports cosine *distances*; similarity = 1 - distance.
        scores = [1 - dist for dist in result.get("distances")[0]]
        kept = [doc for doc, score in zip(documents, scores) if score >= sim_th]
        return ''.join(kept)

    return ''.join(doc for doc in documents if doc is not None)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def get_contextual_prompt(question, context):
    """
    Wrap `question` and `context` in Mistral-7B's [INST] chat template.

    The instruction tells the model to rely only on the supplied context
    and to say so when the context is insufficient.
    """
    return (
        "<s>[INST] You are a helpful assistant that answers questions based on the provided context. "
        "Use only the information given in the context to answer the question. "
        "If the context doesn't contain enough information, say so clearly.\n"
        "\n"
        "Context:\n"
        f"{context}\n"
        "\n"
        f"Question: {question} [/INST]"
    )
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def generate_answer(prompt, max_tokens=150, temperature=0.7, HEADERS=HEADERS, ENDPOINT=ENDPOINT):
    """
    Submit a prompt to the RunPod SYNC endpoint and get back a response string.

    Parameters
    ----------
    prompt : str — full prompt text (already wrapped in the chat template).
    max_tokens : int — generation budget forwarded to the model.
    temperature : float — sampling temperature forwarded to the model.
    HEADERS, ENDPOINT : request headers / base URL; default to the
        module-level values loaded from the environment.

    Returns
    -------
    str — the model's generated text.

    Raises
    ------
    RuntimeError — on timeout, transport/HTTP errors, or a job whose
        status is not "COMPLETED".
    """
    payload = {
        "input": {
            "prompt": prompt,
            "max_tokens": max_tokens,
            "temperature": temperature
        }
    }

    try:
        # Use /runsync instead of /run - immediate response!
        # Timeout is 65s so the client outlives RunPod's own ~60s sync limit.
        response = requests.post(f"{ENDPOINT}/runsync", headers=HEADERS, json=payload, timeout=65)
        response.raise_for_status()
        result = response.json()

        print(f"[RunPod] Request completed successfully")

        if result.get("status") == "COMPLETED":
            # NOTE(review): assumes the worker returns
            # {"output": {"response": ...}} — confirm against the RunPod
            # handler implementation.
            return result["output"]["response"]
        else:
            error_msg = result.get("error", "Unknown error")
            raise RuntimeError(f"RunPod job failed: {error_msg}")

    except requests.exceptions.Timeout:
        raise RuntimeError("Request timed out (>60s). Try reducing prompt length or max_tokens.")
    except requests.exceptions.RequestException as e:
        raise RuntimeError(f"RunPod API error: {e}")
|
| 76 |
+
|
src/text_processing.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import nltk
from nltk.tokenize import sent_tokenize
# Fetch tokenizer data at import time (no-op when already downloaded).
# 'punkt_tab' is the resource name used by newer NLTK releases; 'punkt'
# covers older ones.
nltk.download('punkt_tab')
nltk.download("punkt")
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def paragraphs_chunking(text, max_words=200, max_sentence_words=50):
    """
    Split text into chunks of at most `max_words` words, preserving
    paragraph integrity and avoiding unnatural breaks.

    - Paragraphs (separated by blank lines) within the limit become one
      chunk each.
    - Longer paragraphs are split at sentence boundaries.
    - Sentences longer than `max_sentence_words` (e.g. run-on text,
      tables) are hard-split by word count.

    Parameters
    ----------
    text : str — input document text.
    max_words : int — word budget per chunk.
    max_sentence_words : int — cap on a single sentence's length before
        it is force-split.

    Returns
    -------
    list[str] — the chunks, in document order.
    """
    # Split text into paragraphs first
    paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]

    chunks = []
    for para in paragraphs:
        # If paragraph is within limit, keep as a single chunk
        if len(para.split()) <= max_words:
            chunks.append(para)
            continue

        # Sentence-based chunking for large paragraphs.
        # BUG FIX: max_sentence_words was accepted but never used; overly
        # long "sentences" are now hard-split so a single run-on sentence
        # cannot blow past max_words.
        pieces = []
        for sentence in sent_tokenize(para):
            words = sentence.split()
            if len(words) <= max_sentence_words:
                pieces.append(sentence)
            else:
                pieces.extend(
                    " ".join(words[i:i + max_sentence_words])
                    for i in range(0, len(words), max_sentence_words)
                )

        chunk, chunk_word_count = [], 0
        for piece in pieces:
            piece_word_count = len(piece.split())
            # If adding this piece keeps chunk within word limit, add it
            if chunk_word_count + piece_word_count <= max_words:
                chunk.append(piece)
                chunk_word_count += piece_word_count
            else:
                # Finalize current chunk and start a new one.
                # BUG FIX: guard against emitting an empty chunk when the
                # very first piece already exceeds max_words.
                if chunk:
                    chunks.append(" ".join(chunk))
                chunk = [piece]
                chunk_word_count = piece_word_count

        # Append any remaining chunk
        if chunk:
            chunks.append(" ".join(chunk))

    return chunks
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def lines_chunking(text, max_words=200):
    """
    Chunk `text` into pieces of at most `max_words` words.

    Lines are first grouped into paragraphs (a blank line ends a
    paragraph). Paragraphs within the word budget become one chunk each;
    longer paragraphs are split at sentence boundaries so chunks end on
    natural breaks.

    Returns
    -------
    list[str] — the chunks, in document order.
    """
    # Group consecutive non-empty lines into paragraphs.
    paragraphs, buffer = [], []
    for raw_line in text.splitlines():
        stripped = raw_line.strip()
        if stripped:
            buffer.append(stripped)
        elif buffer:  # blank line ends the current paragraph
            paragraphs.append(" ".join(buffer))
            buffer = []
    if buffer:
        paragraphs.append(" ".join(buffer))

    # Turn paragraphs into word-budgeted chunks.
    chunks = []
    for paragraph in paragraphs:
        if len(paragraph.split()) <= max_words:
            chunks.append(paragraph)
            continue

        # Paragraph too long: accumulate sentences until the budget is hit.
        current, count = [], 0
        for sentence in sent_tokenize(paragraph):
            sentence_len = len(sentence.split())
            if count + sentence_len <= max_words:
                current.append(sentence)
                count += sentence_len
            else:
                chunks.append(" ".join(current))
                current, count = [sentence], sentence_len
        if current:
            chunks.append(" ".join(current))

    return chunks
|
src/utils.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import os
|
| 3 |
+
import sqlite3
|
| 4 |
+
import base64
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# Default values for the Streamlit session-state keys this app relies on.
DEFAULT_SESSION_STATE = {
    # PDF Upload
    'uploaded_files_name': [],      # names of files the user has uploaded
    'collections_files_name': [],   # names already ingested into ChromaDB
    'uploaded_files_raw': [],       # the raw uploaded file objects
}
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def configure_page() -> None:
    """
    Set the Streamlit page title, wide layout and rocket icon.
    """
    st.set_page_config(
        page_title="myRAG",
        layout="wide",
        page_icon=":rocket:",
    )
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def apply_style():
    """
    Inject global CSS forcing the 'Work Sans' Google font across all
    Streamlit widgets, markdown and headings.
    """
    # NOTE(review): the .css-* selectors are Streamlit-version-specific
    # emotion class hashes — verify they still match the deployed version.
    st.markdown("""
    <style>
    @import url('https://fonts.googleapis.com/css2?family=Work+Sans:wght@400;700&display=swap');
    html, body, .stApp,
    .css-1v3fvcr, .css-ffhzg2, .css-1d391kg,
    div[data-testid="stMarkdownContainer"],
    div[data-testid="stText"],
    div[data-testid="stTextInput"],
    div[data-testid="stSelectbox"],
    div[data-testid="stCheckbox"],
    div[data-testid="stSlider"],
    label, input, textarea, button, select,
    .stButton, .stTextInput > div, .stMarkdown, .stCaption,
    .streamlit-expanderHeader, .st-expander > div,
    h1, h2, h3, h4, h5, h6,
    .stMarkdown h1, .stMarkdown h2, .stMarkdown h3 {
        font-family: 'Work Sans', sans-serif !important;
    }
    /* Ensure bold text uses the correct font weight */
    strong, b, .stMarkdown strong, .stMarkdown b {
        font-family: 'Work Sans', sans-serif !important;
        font-weight: 700 !important;
    }
    </style>
    """, unsafe_allow_html=True)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def breaks(n=1):
    """
    Insert vertical space: `n` <br> tags for n in 1..3, otherwise 4.
    """
    count = n if n in (1, 2, 3) else 4
    st.markdown("<br>" * count, unsafe_allow_html=True)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def get_base64_encoded_image(image_path):
    """
    Return the file at `image_path` encoded as a Base64 ASCII string.
    """
    with open(image_path, "rb") as img_file:
        raw_bytes = img_file.read()
    return base64.b64encode(raw_bytes).decode()
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def load_background_image():
    """
    Loads and displays a background image with an overlaid title.

    Looks for images/image6.jpg both relative to the src/ folder (local
    development) and the container working directory, Base64-inlines it,
    and renders it as a rounded banner with a dimming overlay and a
    centered "Mistral-RAG" title.
    """

    possible_paths = [
        "../images/image6.jpg",  # Local development (from src/ folder)
        "images/image6.jpg",  # Docker container (from /app)
    ]

    # Use the first candidate path that actually exists on disk.
    image_path = None
    for path in possible_paths:
        if os.path.exists(path):
            image_path = path
            break

    if not image_path:
        st.error("Could not find image6.jpg in any expected location")
        return

    base64_image = get_base64_encoded_image(image_path)

    # Inject CSS for the background and title overlay
    # NOTE(review): 'background-size: container' is not valid CSS (did you
    # mean 'contain' or 'cover'?) and 'top: ;' has an empty value — both
    # declarations are silently dropped by the browser. Confirm intent.
    st.markdown(
        f"""
        <style>
        /* Background container with image */
        .bg-container {{
            position: relative;
            background-image: url("data:image/png;base64,{base64_image}");
            background-size: container;
            background-position: center;
            height: 150px; /* Adjust the height of the background */
            width: 100%;
            margin: 0 auto;
            filter: contrast(110%) brightness(210%); /* Dim the brightness of the image */
            border-radius: 100px; /* Makes the container's corners rounded */
            overflow: hidden;
        }}

        /* Overlay for dimming effect */
        .bg-container::after {{
            content: '';
            position: absolute;
            top: ;
            left: 0;
            width: 100%;
            height: 100%;
            background-color: rgba(20, 10, 20, 0.44); /* Semi-transparent black overlay */
            z-index: 1; /* Ensure the overlay is above the image */
        }}

        /* Overlay title styling */
        .overlay-title {{
            position: absolute;
            top: 50%;
            left: 50%;
            transform: translate(-50%, -50%);
            color: black; /* Title color */
            font-size: 50px;
            font-weight: bold;
            text-shadow: 1px 1px 3px rgba(255, 255, 255, .0); /* Shadow for better visibility */
            text-align: center;
            z-index: 2; /* Ensure the title is above the overlay */
        }}
        </style>
        """,
        unsafe_allow_html=True
    )

    # Create the background container with an overlaid title
    st.markdown(
        """
        <div class="bg-container">
            <div class="overlay-title">Mistral-RAG</div>
        </div>
        """,
        unsafe_allow_html=True
    )
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def initialise_session_state():
    """
    Seed st.session_state with any missing default keys.

    Keys that already exist are left untouched, so reruns never wipe
    user data.
    """
    for key in DEFAULT_SESSION_STATE:
        if key not in st.session_state:
            st.session_state[key] = DEFAULT_SESSION_STATE[key]
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def file_uploader():
    """
    Render the upload widget and register newly uploaded files.

    Files not seen before (by name) are appended to both
    ``st.session_state.uploaded_files_name`` and
    ``st.session_state.uploaded_files_raw``; already-registered names are
    skipped so Streamlit reruns do not duplicate entries.
    """
    uploaded_files = st.file_uploader(
        "",
        type=["txt", "pdf"],
        accept_multiple_files=True)

    if uploaded_files:  # Check if list is not empty
        for file in uploaded_files:  # Process each file
            if file.name not in st.session_state.uploaded_files_name:
                # Append to session state lists safely
                st.session_state.uploaded_files_name.append(file.name)
                st.session_state.uploaded_files_raw.append(file)
                st.success(f"Added new file: {file.name}")

    else:
        # BUG FIX: the widget accepts TXT as well, so don't claim PDF-only.
        st.info("Please upload a PDF or text file to proceed.")
|
| 180 |
+
|
| 181 |
+
|