vitalune committed on
Commit
722a075
·
verified ·
1 Parent(s): 5600aa2

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +291 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,293 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import asyncio
import os
from pathlib import Path
from typing import List, Optional

import streamlit as st
from dotenv import load_dotenv
from llama_cloud_services import LlamaParse
from llama_index.core import (
    Document,
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
11
 
12
# Pull in variables from a local .env file when one is present.
load_dotenv()

# --- Fixed backend configuration (taken from llama_test.ipynb) ---
# Deliberately not exposed in the UI; edit here to change models/paths.
LLM_MODEL = "gpt-5-nano-2025-08-07"
EMBEDDING_MODEL = "text-embedding-3-small"
TEMPERATURE = 0.1
DATA_DIR = "data"
PERSIST_DIR = "./storage"

# --- System prompt ---
# Defines the assistant's behavior/personality; override at deploy time by
# setting the SYSTEM_PROMPT environment variable.
DEFAULT_SYSTEM_PROMPT = """You are a helpful AI assistant with access to a knowledge base.
Answer questions based on the provided context. If you cannot find the answer in the context,
let the user know that the information is not available in the documents."""

SYSTEM_PROMPT = os.getenv('SYSTEM_PROMPT', DEFAULT_SYSTEM_PROMPT)

# Page chrome for the Streamlit app.
st.set_page_config(page_title="CatBot", page_icon="😺", layout="centered")
+
40
# Helper function to get API keys from multiple sources
def get_api_key(key_name: str) -> Optional[str]:
    """Look up an API key by name, returning None when it is not configured.

    Sources are checked in priority order:
      1. Environment variables (works for local dev, Docker, and Hugging Face Spaces)
      2. Streamlit secrets (works for Streamlit Cloud)

    Hugging Face Spaces: Set secrets in Space Settings > Repository secrets
    Streamlit Cloud: Set secrets in App Settings > Secrets
    Local dev: Use .env file or export environment variables

    Args:
        key_name: Name of the key to look up, e.g. "OPENAI_API_KEY".

    Returns:
        The key value, or None when no source provides it.
        (The previous annotation claimed `-> str`, but the missing-key
        path has always returned None.)
    """
    # Environment variable wins over every other source.
    api_key = os.getenv(key_name)
    if api_key:
        return api_key

    # Fall back to Streamlit secrets. st.secrets raises FileNotFoundError
    # when no secrets.toml exists at all, which simply means "not configured".
    try:
        if key_name in st.secrets:
            return st.secrets[key_name]
    except (FileNotFoundError, KeyError):
        pass

    return None
+
65
# Resolve credentials via get_api_key (environment variables first, then
# Streamlit secrets):
#   Hugging Face Spaces -> Space settings > Repository secrets
#   Streamlit Cloud     -> App settings > Secrets
#   Local development   -> .env file
openai_api_key = get_api_key('OPENAI_API_KEY')
llama_cloud_api_key = get_api_key('LLAMA_CLOUD_API_KEY')

# Make sure the chat transcript exists before anything reads it.
if "messages" not in st.session_state:
    st.session_state["messages"] = []
+
76
# Helper function to load documents with LlamaParse
def load_documents_with_llamaparse(data_dir: str, llama_api_key: str) -> List[Document]:
    """Load documents from *data_dir*, using LlamaParse for complex formats.

    Complex formats (PDF, DOCX, PPTX, XLSX and their legacy variants) are
    parsed with LlamaParse; plain-text formats go through
    SimpleDirectoryReader. If LlamaParse fails, the affected files fall back
    to SimpleDirectoryReader.

    Args:
        data_dir: Directory scanned (non-recursively) for input files.
        llama_api_key: LlamaCloud API key used by LlamaParse.

    Returns:
        A list of LlamaIndex Documents; empty when the directory is missing
        or holds no supported files.
    """
    data_path = Path(data_dir)
    if not data_path.exists():
        return []

    # File extensions that benefit from LlamaParse
    llamaparse_extensions = {'.pdf', '.docx', '.pptx', '.xlsx', '.doc', '.ppt', '.xls'}
    # File extensions for simple text reading
    simple_extensions = {'.txt', '.md', '.csv', '.json', '.html', '.xml'}

    llamaparse_files: List[str] = []
    simple_files: List[str] = []
    for file_path in data_path.glob('*'):
        if not file_path.is_file():
            continue
        ext = file_path.suffix.lower()
        if ext in llamaparse_extensions:
            llamaparse_files.append(str(file_path))
        elif ext in simple_extensions:
            simple_files.append(str(file_path))

    documents: List[Document] = []

    # Process complex files with LlamaParse
    if llamaparse_files:
        st.info(f"📄 Processing {len(llamaparse_files)} complex file(s) with LlamaParse: {', '.join([Path(f).name for f in llamaparse_files])}")
        try:
            # Configure LlamaParse with optimal settings
            parser = LlamaParse(
                api_key=llama_api_key,
                parse_mode="parse_page_with_agent",
                model="openai-gpt-4-1-mini",
                high_res_ocr=True,
                adaptive_long_table=True,
                outlined_table_extraction=True,
                output_tables_as_HTML=True,
                num_workers=4,
                verbose=True,
                language="en"
            )

            # parser.aparse is async; asyncio.run gives it a fresh event loop
            # and tears it down cleanly. (The previous manual
            # new_event_loop/set_event_loop/close sequence left a *closed*
            # loop installed as the thread's current event loop.)
            if len(llamaparse_files) == 1:
                results = [asyncio.run(parser.aparse(llamaparse_files[0]))]
            else:
                results = asyncio.run(parser.aparse(llamaparse_files))

            # Convert JobResults to LlamaIndex Documents; splitting by page
            # gives better chunking downstream.
            for result in results:
                documents.extend(result.get_markdown_documents(split_by_page=True))

        except Exception as e:
            st.warning(f"LlamaParse processing failed for some files: {str(e)}")
            st.info("Falling back to SimpleDirectoryReader for these files...")
            # Fall back to simple reader if LlamaParse fails
            simple_files.extend(llamaparse_files)

    # Process simple text files with SimpleDirectoryReader
    if simple_files:
        st.info(f"📝 Processing {len(simple_files)} simple file(s) with SimpleDirectoryReader: {', '.join([Path(f).name for f in simple_files])}")
        for file_path in simple_files:
            try:
                file_docs = SimpleDirectoryReader(input_files=[file_path]).load_data()
                documents.extend(file_docs)
            except Exception as e:
                st.warning(f"Failed to load {file_path}: {str(e)}")

    return documents
+
163
# Initialize query engine
@st.cache_resource
def initialize_query_engine(_openai_api_key, _llama_api_key, _system_prompt):
    """Build (or reload from disk) the LlamaIndex query engine.

    Cached with st.cache_resource; the leading underscores tell Streamlit
    not to hash these arguments (keeps the API keys out of the cache key).

    Returns:
        (query_engine, status_message) on success, or (None, error_message)
        when there are no documents or initialization fails.
    """
    # Export keys so the downstream OpenAI / llama-cloud clients can find them.
    os.environ['OPENAI_API_KEY'] = _openai_api_key
    if _llama_api_key:
        os.environ['LLAMA_CLOUD_API_KEY'] = _llama_api_key

    # Configure models with the fixed backend configuration.
    llm = OpenAI(
        model=LLM_MODEL,
        temperature=TEMPERATURE,
        system_prompt=_system_prompt
    )
    embed_model = OpenAIEmbedding(model=EMBEDDING_MODEL)

    try:
        if not os.path.exists(PERSIST_DIR):
            # No persisted index yet: ingest documents and build one.
            if not os.path.exists(DATA_DIR):
                # exist_ok guards against a concurrent rerun creating the
                # directory between the check and the makedirs call.
                os.makedirs(DATA_DIR, exist_ok=True)
                return None, "Please add documents to the 'data' directory"

            # Use LlamaParse if API key is available, otherwise fall back to SimpleDirectoryReader
            if _llama_api_key:
                st.info("Using LlamaParse for advanced document processing...")
                documents = load_documents_with_llamaparse(DATA_DIR, _llama_api_key)
            else:
                st.info("Using SimpleDirectoryReader (LlamaParse API key not found)...")
                documents = SimpleDirectoryReader(DATA_DIR).load_data()

            if not documents:
                return None, "No documents found in the 'data' directory"

            index = VectorStoreIndex.from_documents(
                documents,
                llm=llm,
                embed_model=embed_model
            )
            # Persist so later runs can skip ingestion.
            index.storage_context.persist(persist_dir=PERSIST_DIR)
            status = f"Index created with {len(documents)} documents"
        else:
            # Reload the previously persisted index.
            storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
            index = load_index_from_storage(storage_context)
            status = "Index loaded from storage"

        # The LLM and embedding model are supplied explicitly here, so the
        # query engine uses them regardless of how the index was obtained.
        # (The previous code also assigned the non-public attributes
        # index._llm / index._embed_model on the loaded index; those
        # assignments had no effect and have been removed.)
        query_engine = index.as_query_engine(llm=llm, embed_model=embed_model)
        return query_engine, status

    except Exception as e:
        return None, f"❌ Error: {str(e)}"
+
225
# Main chat interface
# Hard requirement: without an OpenAI key nothing downstream can work,
# so explain how to configure it and halt the script.
if not openai_api_key:
    st.error("⚠️ OPENAI_API_KEY is required to run CatBot")
    st.info("""
    **How to set the API key:**

    - **Hugging Face Spaces**: Go to Settings → Repository secrets → Add `OPENAI_API_KEY`
    - **Local Development**: Create a `.env` file with `OPENAI_API_KEY=your_key_here`
    - **Streamlit Cloud**: Add to App Settings → Secrets

    Get your OpenAI API key from: https://platform.openai.com/api-keys
    """)
    st.stop()

# Display info about LlamaParse availability
if not llama_cloud_api_key:
    st.info("💡 Tip: Set LLAMA_CLOUD_API_KEY to enable advanced parsing of PDFs, DOCX, and other complex documents.")

# Initialize the query engine once per session.
if "query_engine" not in st.session_state:
    with st.spinner("Initializing RAG agent..."):
        query_engine, status = initialize_query_engine(
            openai_api_key,
            llama_cloud_api_key,
            SYSTEM_PROMPT
        )

    if query_engine is None:
        # Do NOT cache the failure. The previous code stored None in
        # session_state before stopping, so every rerun skipped this branch
        # and the chat handler later crashed calling .query on None.
        # Leaving the key unset lets the next rerun retry initialization.
        st.error(status)
        st.stop()

    st.session_state.query_engine = query_engine
    st.success(status)
+
259
# Replay the conversation so far.
for entry in st.session_state.messages:
    with st.chat_message(entry["role"]):
        st.markdown(entry["content"])

# Handle a new question from the user.
user_text = st.chat_input("Ask a question about your documents")
if user_text:
    # Echo the question and record it in the transcript.
    with st.chat_message("user"):
        st.markdown(user_text)
    st.session_state.messages.append({"role": "user", "content": user_text})

    # Query the index and render the answer into the assistant bubble; any
    # failure is shown (and recorded) as an assistant message instead of
    # crashing the app.
    with st.chat_message("assistant"), st.spinner("Thinking..."):
        try:
            answer = str(st.session_state.query_engine.query(user_text))
            st.markdown(answer)
            st.session_state.messages.append({
                "role": "assistant",
                "content": answer
            })
        except Exception as e:
            error_msg = f"Error generating response: {str(e)}"
            st.error(error_msg)
            st.session_state.messages.append({
                "role": "assistant",
                "content": error_msg
            })