davidepanza commited on
Commit
1c5d91d
·
verified ·
1 Parent(s): ad74fc0

Upload 7 files

Browse files
src/__init__.py ADDED
File without changes
src/collections_setup.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import chromadb
2
+ import streamlit as st
3
+ import fitz
4
+ import os
5
+ from chromadb.utils import embedding_functions
6
+ from text_processing import lines_chunking, paragraphs_chunking
7
+
8
+
9
def get_chroma_client():
    """
    Build a fresh in-memory ChromaDB client.

    The ephemeral client keeps all vectors and metadata in RAM only, so the
    data disappears as soon as the user's browser session ends.
    """
    return chromadb.EphemeralClient()
15
+
16
+
17
@st.cache_resource
def initialize_chroma_client():
    """
    Return the per-session ChromaDB client.

    `st.cache_resource` ensures the same client object is reused across
    Streamlit reruns within a single session instead of being rebuilt.
    """
    return get_chroma_client()
24
+
25
+
26
@st.cache_resource
def initialize_chromadb(embedding_model):
    """
    Build the cached ChromaDB client together with its embedding function.

    Parameters
    ----------
    embedding_model : str
        Name of the Sentence Transformer model used to embed documents.

    Returns
    -------
    tuple
        (chromadb client, embedding function), both cached for the session
        so they are not recreated on every Streamlit rerun.
    """
    chroma_client = initialize_chroma_client()

    # Sentence Transformer model wrapped as a Chroma embedding function
    sentence_embedder = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=embedding_model
    )

    return chroma_client, sentence_embedder
41
+
42
+
43
def initialize_collection(client, embedding_func, collection_name):
    """
    Fetch (or create) the named ChromaDB collection.

    The collection is configured to use cosine distance for its HNSW index.
    """
    return client.get_or_create_collection(
        name=collection_name,
        embedding_function=embedding_func,
        metadata={"hnsw:space": "cosine"},
    )
54
+
55
+
56
def update_collection(collection, files_to_add_to_collection):
    """
    Add newly uploaded files to the ChromaDB collection.

    Each file name is matched against the raw uploads kept in
    `st.session_state.uploaded_files_raw`; its text is extracted (plain text,
    or PDF via PyMuPDF), chunked with `lines_chunking`, and stored in the
    collection. Names of successfully ingested files are appended to
    `st.session_state.collections_files_name`.

    Parameters
    ----------
    collection : chromadb Collection
        Target collection that receives the document chunks.
    files_to_add_to_collection : list[str]
        Names of files that are uploaded but not yet in the collection.

    Returns
    -------
    The (mutated) collection.
    """
    for file_to_add in files_to_add_to_collection:

        # Locate the raw uploaded file object matching this file name.
        current_file = next(
            (file for file in st.session_state.get('uploaded_files_raw', [])
             if file.name == file_to_add), None)

        if current_file is None:
            st.error(f"File '{file_to_add}' not found in uploaded files.")
            continue

        # Read file content
        try:
            if current_file.type == "text/plain":  # Handling TXT files
                file_text = current_file.getvalue().decode("utf-8")
            elif current_file.type == "application/pdf":  # Handling PDFs
                with fitz.open(stream=current_file.getvalue(), filetype="pdf") as pdf_document:
                    file_text = "\n".join([page.get_text("text") for page in pdf_document])
            else:
                st.warning(f"Unsupported file type: {current_file.name} type:{current_file.type}")
                continue

            # Tokenize text into chunks
            max_words = 200
            chunks = lines_chunking(file_text, max_words=max_words)

            if not chunks:  # Skip if no chunks generated
                st.warning(f"No content extracted from {current_file.name}")
                continue

            # Store chunks in the collection
            filename = current_file.name
            # splitext is equivalent to the old filename[:-4] for the only
            # accepted .txt/.pdf types, but robust to other extensions.
            base_name = os.path.splitext(filename)[0]
            collection.add(
                documents=chunks,
                ids=[f"id{base_name}.{j}" for j in range(len(chunks))],
                metadatas=[{"source": filename, "part": n} for n in range(len(chunks))],
            )

            st.session_state.collections_files_name.append(filename)
            # BUG FIX: the success message previously did not name the file.
            st.success(f"Added {len(chunks)} chunks from {filename}")

        except Exception as e:
            st.error(f"Error processing {current_file.name}: {str(e)}")
            # BUG FIX: `filename` was only bound late in the try-block, so the
            # old `.remove(filename)` raised NameError when extraction failed
            # early; use the loop variable and guard membership.
            if file_to_add in st.session_state.uploaded_files_name:
                st.session_state.uploaded_files_name.remove(file_to_add)

    return collection
107
+
src/mylogging.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import logging
3
+ import io
4
+
5
+
6
def configure_logging():
    """
    Set up root logging that writes into an in-memory buffer.

    Returns
    -------
    tuple
        (root logger, StringIO buffer that receives formatted log records)
    """
    buffer = io.StringIO()

    stream_handler = logging.StreamHandler(buffer)
    stream_handler.setFormatter(logging.Formatter('%(message)s'))
    stream_handler.setLevel(logging.WARNING)

    root = logging.getLogger()
    root.setLevel(logging.WARNING)
    root.addHandler(stream_handler)

    return root, buffer
20
+
21
+
22
def toggle_logging(level, logger):
    """
    Switch the logger (and all of its handlers) to the requested level.

    Unknown level names log a warning and fall back to WARNING.
    """
    known_levels = {
        'DEBUG': logging.DEBUG,
        'INFO': logging.INFO,
        'WARNING': logging.WARNING,
    }
    if level in known_levels:
        logger.setLevel(known_levels[level])
    else:
        logger.warning(f"Unknown logging level: {level}. Using WARNING as default.")
        logger.setLevel(logging.WARNING)
    # Keep every attached handler in sync with the logger's own level.
    for handler in logger.handlers:
        handler.setLevel(logger.level)
37
+
38
+
39
def display_logs(log_stream):
    """
    Render the accumulated log buffer as plain text inside the app.
    """
    log_stream.seek(0)
    st.text(log_stream.read())
src/run.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ from utils import load_background_image, apply_style, configure_page, breaks, file_uploader, initialise_session_state
4
+ from mylogging import configure_logging, toggle_logging, display_logs
5
+ from collections_setup import initialize_chromadb, initialize_collection, update_collection
6
+ from runpod_setup import get_relevant_text, generate_answer, get_contextual_prompt
7
+
8
if __name__ == "__main__":

    # ---- Page scaffolding: layout, CSS, header image, session defaults ----
    configure_page()
    apply_style()
    load_background_image()
    initialise_session_state()
    breaks(2)
    st.write(
        """
Welcome to this Streamlit app that demonstrates Retrieval-Augmented Generation (RAG) using a **Mistral-7B model hosted on Runpod** and **ChromaDB** for retrieval.

With this app, you can:
- Upload multiple PDF or text files to build a contextual knowledge base,
- Ask custom questions based on your uploaded documents, and
- Generate informed responses using a lightweight, hosted LLM.

**Note:** All uploaded files and generated embeddings are stored **in memory only** and will be **lost when the app is closed or restarted**. No data is persisted between sessions.
"""
    )

    # # Disable Chroma telemetry
    os.environ["CHROMA_TELEMETRY_ENABLED"] = "False"

    # Initialize logger (in-memory buffer; rendered only when use_logging is True)
    logger, log_stream = configure_logging()
    # Narrow the logging-level selectbox rendered further down.
    st.markdown(
        """
        <style>
        /* This targets the selectbox container */
        div[data-baseweb="select"] {
            max-width: 150px;
        }
        </style>
        """,
        unsafe_allow_html=True,
    )
    st.divider()

    # ---- Logging Setup ----
    # NOTE(review): hard-coded off; flip to True to expose the level selector
    # and the log display at the bottom of the page.
    use_logging = False
    if use_logging:
        logging_level = st.selectbox("Select logging level", ['INFO', 'DEBUG', 'WARNING'], index=2)
        toggle_logging(logging_level, logger)

    # ---- Vector Store Setup ----
    # Initialize ChromaDB and collection (both cached per session).
    EMBEDDING_MODEL = "all-MiniLM-L6-v2"
    client, embedding_func = initialize_chromadb(EMBEDDING_MODEL)
    collection_name = "my_collection"
    collection = initialize_collection(client, embedding_func, collection_name)

    # Upload files
    st.markdown(
        '<h3>Upload Files</h3>',
        unsafe_allow_html=True)
    st.html("""
    Uploaded files are processed to build a contextual knowledge base for the RAG model.<br>
    When you submit a prompt, the model retrieves relevant information from these documents to generate responses.
    """)
    col1_, _, col2_ = st.columns([.4, .1, .5])
    with col1_:
        file_uploader()

    # Get the current uploaded filenames
    logger.debug(f"\n\t-- Currently uploaded files: {st.session_state.get('uploaded_files_name', 'None')}")

    # Update collection with uploaded files: only names that were uploaded
    # but are not yet recorded as ingested into the collection.
    files_to_add_to_collection= [file_name for file_name in st.session_state.get("uploaded_files_name", []) if file_name not in st.session_state.get("collections_files_name", [])]
    logger.debug(f"\n\t-- Files not in collection: {files_to_add_to_collection}")

    if files_to_add_to_collection:
        collection = update_collection(collection, files_to_add_to_collection)

    # Update the session state / debug dump of the first few collection items
    logger.debug(f"Collection count: {collection.count()}")
    logger.debug(f"\n\t-- Collection data currently uploaded:")
    data_head = collection.get(limit=5)
    for i, (metadata, document) in enumerate(zip(data_head["metadatas"], data_head["documents"]), start=1):
        logger.debug(f"Item {i}:")
        logger.debug(f"Metadata: {metadata}")
        logger.debug(f"Document: {document}")
        logger.debug("-" * 40)

    # ---- Response Generation ----
    # Streamlit UI: prompt on the left column, generated answer on the right.
    st.divider()
    col1, _, col2 = st.columns([.6, .01, 1])
    with col1:
        st.subheader("Enter your prompt")
        query = st.text_area("", height=200)
        generate_clicked = st.button("Generate Response")
        if generate_clicked:
            if query.strip():
                # Get the number of available documents in ChromaDB
                available_docs = collection.count()

                if available_docs > 0:
                    # Ensure n_results doesn't exceed available_docs
                    n_results = min(2, available_docs)
                    relevant_text = get_relevant_text(collection, query=query, nresults=n_results)
                else:
                    relevant_text = ""  # No documents available, so no additional context
                    st.warning("No knowledge base available. Generating response based only on the prompt.")

                logger.debug("\n\t-- Relevant text retrieved:")
                logger.debug(relevant_text)

                # Wrap the query in the Mistral chat template and call RunPod.
                with st.spinner("Generating response..."):
                    context_query = get_contextual_prompt(query, relevant_text)
                    response = generate_answer(context_query, max_tokens=200)

                with col2:
                    st.subheader("Response:")
                    st.text_area("", value=response, height=200)
            else:
                logger.debug("No query provided; skipping relevant text retrieval.")
                st.warning("Please enter a prompt.")

    if use_logging:
        display_logs(log_stream)
src/runpod_setup.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import os
3
+ from dotenv import load_dotenv
4
+ from pathlib import Path
5
+
6
# Load .env from project root (one directory above this src/ file) so the
# RunPod credentials do not have to be hard-coded.
load_dotenv(dotenv_path=Path(__file__).resolve().parents[1] / ".env")

# RunPod serverless credentials; both are None when the .env is missing —
# requests would then fail at call time. NOTE(review): confirm intended
# failure mode when the environment is unconfigured.
API_KEY = os.getenv("RUNPOD_API_KEY")
ENDPOINT = os.getenv("RUNPOD_ENDPOINT")

# Shared HTTP headers for every RunPod API call.
HEADERS = {
    "Authorization": f"Bearer {API_KEY}",
    "Content-Type": "application/json"
}
16
+
17
+
18
def get_relevant_text(collection, query='', nresults=3, sim_th=None):
    """
    Retrieve the text of documents relevant to `query` from a collection.

    Parameters
    ----------
    collection : chromadb Collection
        Collection to query.
    query : str
        Query text.
    nresults : int
        Maximum number of documents to retrieve.
    sim_th : float or None
        Optional similarity threshold; when given, only documents with
        similarity (1 - distance) >= sim_th are kept.

    Returns
    -------
    str
        The retained documents concatenated with no separator.
    """
    query_result = collection.query(query_texts=query, n_results=nresults)
    docs = query_result.get('documents')[0]
    if sim_th is not None:
        similarities = [1 - d for d in query_result.get("distances")[0]]
        # BUG FIX: also drop None documents on this path, as the unfiltered
        # path below does; otherwise ''.join raises TypeError.
        relevant_docs = [d for d, s in zip(docs, similarities)
                         if s >= sim_th and d is not None]
        return ''.join(relevant_docs)
    return ''.join([doc for doc in docs if doc is not None])
29
+
30
+
31
def get_contextual_prompt(question, context):
    """
    Wrap a question and its retrieved context in Mistral-7B's [INST] chat
    template, instructing the model to answer strictly from the context.
    """
    return f"""<s>[INST] You are a helpful assistant that answers questions based on the provided context. Use only the information given in the context to answer the question. If the context doesn't contain enough information, say so clearly.

Context:
{context}

Question: {question} [/INST]"""
44
+
45
+
46
def generate_answer(prompt, max_tokens=150, temperature=0.7, HEADERS=HEADERS, ENDPOINT=ENDPOINT):
    """
    Submit a prompt to the RunPod SYNC endpoint and return the response string.

    Parameters
    ----------
    prompt : str
        Full prompt (already wrapped in the model's chat template).
    max_tokens : int
        Generation length cap forwarded to the endpoint.
    temperature : float
        Sampling temperature forwarded to the endpoint.
    HEADERS, ENDPOINT :
        Module-level defaults; overridable (e.g. for testing).

    Returns
    -------
    str
        The model's response text.

    Raises
    ------
    RuntimeError
        On missing configuration, HTTP/transport errors, timeout, a failed
        job, or an unexpected response payload.
    """
    # BUG FIX / robustness: fail fast with a clear message when the endpoint
    # is not configured, instead of issuing a request to "None/runsync".
    if not ENDPOINT:
        raise RuntimeError("RUNPOD_ENDPOINT is not configured (check your .env file).")

    payload = {
        "input": {
            "prompt": prompt,
            "max_tokens": max_tokens,
            "temperature": temperature
        }
    }

    try:
        # Use /runsync instead of /run - immediate response!
        response = requests.post(f"{ENDPOINT}/runsync", headers=HEADERS, json=payload, timeout=65)
        response.raise_for_status()
        result = response.json()

        print("[RunPod] Request completed successfully")

        if result.get("status") == "COMPLETED":
            try:
                return result["output"]["response"]
            except (KeyError, TypeError) as e:
                # Robustness: a COMPLETED job with an unexpected payload shape
                # previously surfaced as a bare KeyError.
                raise RuntimeError(f"RunPod returned an unexpected output shape: {result}") from e
        error_msg = result.get("error", "Unknown error")
        raise RuntimeError(f"RunPod job failed: {error_msg}")

    except requests.exceptions.Timeout:
        # BUG FIX: message now matches the actual timeout value (65s, not 60s).
        raise RuntimeError("Request timed out (>65s). Try reducing prompt length or max_tokens.")
    except requests.exceptions.RequestException as e:
        raise RuntimeError(f"RunPod API error: {e}")
76
+
src/text_processing.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import nltk
from nltk.tokenize import sent_tokenize
# Fetch both Punkt data sets: newer NLTK releases ship the sentence tokenizer
# model as 'punkt_tab', older ones as 'punkt'; downloading both keeps
# sent_tokenize working across NLTK versions.
nltk.download('punkt_tab')
nltk.download("punkt")
5
+
6
+
7
def paragraphs_chunking(text, max_words=200, max_sentence_words=50):
    """
    Split text into chunks, preserving paragraph integrity where possible.

    Paragraphs (separated by blank lines, "\\n\\n") that fit within
    `max_words` become single chunks; longer paragraphs are split at sentence
    boundaries via NLTK's sent_tokenize.

    Parameters
    ----------
    text : str
        Input text.
    max_words : int
        Target maximum number of words per chunk. A single sentence longer
        than this still becomes its own (oversized) chunk.
    max_sentence_words : int
        Currently unused; kept for backward compatibility of the signature.

    Returns
    -------
    list[str]
        Non-empty text chunks.
    """
    # Split text into paragraphs first
    paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]

    chunks = []
    for para in paragraphs:
        words = para.split()

        # If paragraph is within limit, keep as a single chunk
        if len(words) <= max_words:
            chunks.append(para)
            continue

        # Sentence-based chunking for large paragraphs
        sentences = sent_tokenize(para)
        chunk, chunk_word_count = [], 0

        for sentence in sentences:
            sentence_word_count = len(sentence.split())

            # If adding this sentence keeps chunk within word limit, add it
            if chunk_word_count + sentence_word_count <= max_words:
                chunk.append(sentence)
                chunk_word_count += sentence_word_count
            else:
                # BUG FIX: only finalize a non-empty chunk; previously an
                # opening sentence longer than max_words caused an empty ""
                # chunk to be emitted before it.
                if chunk:
                    chunks.append(" ".join(chunk))
                chunk = [sentence]
                chunk_word_count = sentence_word_count

        # Append any remaining chunk
        if chunk:
            chunks.append(" ".join(chunk))

    return chunks
47
+
48
+
49
def lines_chunking(text, max_words=200):
    """
    Split text into chunks by grouping consecutive non-empty lines.

    Lines are grouped into paragraphs (an empty line ends a paragraph); any
    paragraph longer than `max_words` is further split at sentence boundaries
    via NLTK's sent_tokenize. (Doc fix: the previous docstring was copied
    from paragraphs_chunking and wrongly described "\\n\\n" splitting.)

    Parameters
    ----------
    text : str
        Input text.
    max_words : int
        Target maximum number of words per chunk. A single sentence longer
        than this still becomes its own (oversized) chunk.

    Returns
    -------
    list[str]
        Non-empty text chunks.
    """
    # Split text into lines
    lines = text.splitlines()

    # Group lines into paragraphs
    paragraphs = []
    current_paragraph = []
    for line in lines:
        if line.strip():
            current_paragraph.append(line.strip())
        else:  # Empty line indicates end of paragraph
            if current_paragraph:
                paragraphs.append(" ".join(current_paragraph))
                current_paragraph = []
    if current_paragraph:
        paragraphs.append(" ".join(current_paragraph))

    # Process paragraphs
    chunks = []
    for para in paragraphs:
        words = para.split()
        if len(words) <= max_words:
            chunks.append(para)
        else:
            sentences = sent_tokenize(para)
            chunk, chunk_word_count = [], 0
            for sentence in sentences:
                sentence_word_count = len(sentence.split())
                if chunk_word_count + sentence_word_count <= max_words:
                    chunk.append(sentence)
                    chunk_word_count += sentence_word_count
                else:
                    # BUG FIX: guard against emitting an empty "" chunk when
                    # the first sentence alone exceeds max_words.
                    if chunk:
                        chunks.append(" ".join(chunk))
                    chunk = [sentence]
                    chunk_word_count = sentence_word_count
            if chunk:
                chunks.append(" ".join(chunk))

    return chunks
src/utils.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import sqlite3
4
+ import base64
5
+
6
+
7
# Default Streamlit session-state entries, created on first run by
# initialise_session_state().
DEFAULT_SESSION_STATE = {
    # PDF Upload
    'uploaded_files_name': [],       # names of files the user has uploaded
    'collections_files_name': [],    # names already ingested into ChromaDB
    'uploaded_files_raw': [],        # raw uploaded-file objects from st.file_uploader
}
13
+
14
+
15
def configure_page() -> None:
    """
    Apply the global Streamlit page configuration (title, layout, icon).
    """
    st.set_page_config(
        page_title="myRAG",
        layout="wide",
        page_icon=":rocket:",
    )
22
+
23
+
24
def apply_style():
    """
    Inject global CSS that loads the 'Work Sans' Google font and applies it
    to all Streamlit text elements, widgets and headings (bold at weight 700).
    """
    st.markdown("""
    <style>
    @import url('https://fonts.googleapis.com/css2?family=Work+Sans:wght@400;700&display=swap');
    html, body, .stApp,
    .css-1v3fvcr, .css-ffhzg2, .css-1d391kg,
    div[data-testid="stMarkdownContainer"],
    div[data-testid="stText"],
    div[data-testid="stTextInput"],
    div[data-testid="stSelectbox"],
    div[data-testid="stCheckbox"],
    div[data-testid="stSlider"],
    label, input, textarea, button, select,
    .stButton, .stTextInput > div, .stMarkdown, .stCaption,
    .streamlit-expanderHeader, .st-expander > div,
    h1, h2, h3, h4, h5, h6,
    .stMarkdown h1, .stMarkdown h2, .stMarkdown h3 {
        font-family: 'Work Sans', sans-serif !important;
    }
    /* Ensure bold text uses the correct font weight */
    strong, b, .stMarkdown strong, .stMarkdown b {
        font-family: 'Work Sans', sans-serif !important;
        font-weight: 700 !important;
    }
    </style>
    """, unsafe_allow_html=True)
50
+
51
+
52
def breaks(n=1):
    """
    Insert n vertical line breaks (n in 1..3; any other value yields 4,
    matching the original else-branch behavior).
    """
    count = n if n in (1, 2, 3) else 4
    st.markdown("<br>" * count, unsafe_allow_html=True)
64
+
65
+
66
def get_base64_encoded_image(image_path):
    """
    Return the Base64-encoded contents of the file at `image_path` as str.
    """
    with open(image_path, "rb") as img_file:
        raw_bytes = img_file.read()
    return base64.b64encode(raw_bytes).decode()
72
+
73
+
74
def load_background_image():
    """
    Load the header background image and display it with an overlaid title.

    Searches a couple of candidate paths (local dev vs. Docker), Base64-embeds
    the image into injected CSS, and renders a rounded banner with the app
    title centered on top. Shows an error and returns early when the image is
    missing.
    """

    possible_paths = [
        "../images/image6.jpg",  # Local development (from src/ folder)
        "images/image6.jpg",     # Docker container (from /app)
    ]

    # First existing candidate wins.
    image_path = None
    for path in possible_paths:
        if os.path.exists(path):
            image_path = path
            break

    if not image_path:
        st.error("Could not find image6.jpg in any expected location")
        return

    base64_image = get_base64_encoded_image(image_path)

    # Inject CSS for the background and title overlay
    st.markdown(
        f"""
        <style>
        /* Background container with image */
        .bg-container {{
            position: relative;
            background-image: url("data:image/png;base64,{base64_image}");
            background-size: cover; /* BUG FIX: 'container' is not a valid CSS keyword */
            background-position: center;
            height: 150px; /* Adjust the height of the background */
            width: 100%;
            margin: 0 auto;
            filter: contrast(110%) brightness(210%);
            border-radius: 100px; /* Makes the container's corners rounded */
            overflow: hidden;
        }}

        /* Overlay for dimming effect */
        .bg-container::after {{
            content: '';
            position: absolute;
            top: 0; /* BUG FIX: value was missing ('top: ;') */
            left: 0;
            width: 100%;
            height: 100%;
            background-color: rgba(20, 10, 20, 0.44); /* Semi-transparent overlay */
            z-index: 1; /* Ensure the overlay is above the image */
        }}

        /* Overlay title styling */
        .overlay-title {{
            position: absolute;
            top: 50%;
            left: 50%;
            transform: translate(-50%, -50%);
            color: black; /* Title color */
            font-size: 50px;
            font-weight: bold;
            text-shadow: 1px 1px 3px rgba(255, 255, 255, .0); /* NOTE(review): alpha 0 makes this shadow invisible — confirm intent */
            text-align: center;
            z-index: 2; /* Ensure the title is above the overlay */
        }}
        </style>
        """,
        unsafe_allow_html=True
    )

    # Create the background container with an overlaid title
    st.markdown(
        """
        <div class="bg-container">
            <div class="overlay-title">Mistral-RAG</div>
        </div>
        """,
        unsafe_allow_html=True
    )
153
+
154
+
155
def initialise_session_state():
    """
    Seed st.session_state with any missing default entries from
    DEFAULT_SESSION_STATE; existing keys are left untouched.
    """
    for key in DEFAULT_SESSION_STATE:
        st.session_state.setdefault(key, DEFAULT_SESSION_STATE[key])
162
+
163
+
164
def file_uploader():
    """
    Render the upload widget and register new files in session state.

    A file counts as new when its name is not yet listed in
    st.session_state.uploaded_files_name; new files are recorded both by name
    and as raw objects so they can be ingested into the collection later.
    """
    uploaded_files = st.file_uploader(
        "",
        type=["txt", "pdf"],
        accept_multiple_files=True)

    if not uploaded_files:
        st.info("Please upload a PDF file to proceed.")
        return

    for file in uploaded_files:
        if file.name in st.session_state.uploaded_files_name:
            continue
        # Append to session state lists safely
        st.session_state.uploaded_files_name.append(file.name)
        st.session_state.uploaded_files_raw.append(file)
        st.success(f"Added new file: {file.name}")
180
+
181
+