vitalune committed on
Commit
8dcd999
·
verified ·
1 Parent(s): 1c8fb39

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +136 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,138 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import os
3
+ from dotenv import load_dotenv
4
+ from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, StorageContext, load_index_from_storage
5
+ from llama_index.llms.openai import OpenAI
6
+ from llama_index.embeddings.openai import OpenAIEmbedding
7
 
8
# Load environment variables from .env (if present)
load_dotenv()

# Backend configuration (from llama_test.ipynb)
# These values are fixed and cannot be changed from the UI
LLM_MODEL = "gpt-5-nano-2025-08-07"
EMBEDDING_MODEL = "text-embedding-3-small"
TEMPERATURE = 0.1
DATA_DIR = "data"
PERSIST_DIR = "./storage"

# Configure Streamlit page
st.set_page_config(
    page_title="LlamaIndex RAG Agent",
    page_icon="🦙",
    layout="centered",
)

# Resolve the API key exactly once: the environment variable wins, and
# Streamlit secrets are the fallback. (A second, environment-only assignment
# used to follow this lookup and silently threw the secrets value away.)
openai_api_key = os.getenv("OPENAI_API_KEY")
if not openai_api_key:
    try:
        openai_api_key = st.secrets.get("OPENAI_API_KEY")
    except FileNotFoundError:
        # NOTE(review): Streamlit is expected to raise a FileNotFoundError
        # (or a subclass of it) when no secrets.toml exists. Treat that as
        # "no key configured" so the app can show its own warning instead
        # of crashing here.
        openai_api_key = None

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []
36
+
37
+ # Initialize query engine
38
class _IndexSetupError(Exception):
    """Expected setup problem whose message is shown to the user verbatim."""


@st.cache_resource
def _build_query_engine(_api_key):
    """Build (or load) the index and return ``(query_engine, status)``.

    Only successful results are returned, so only successful results end up
    in Streamlit's resource cache; every failure is raised instead. The
    leading underscore on ``_api_key`` keeps the key out of the cache key.

    Raises:
        _IndexSetupError: the data directory was missing (it is created) or
            the reader returned no documents.
        Exception: anything raised by the LlamaIndex / OpenAI calls.
    """
    # The OpenAI clients below pick the key up from the environment.
    os.environ['OPENAI_API_KEY'] = _api_key

    # Configure models with backend configuration
    llm = OpenAI(model=LLM_MODEL, temperature=TEMPERATURE)
    embed_model = OpenAIEmbedding(model=EMBEDDING_MODEL)

    if os.path.exists(PERSIST_DIR):
        # Load the existing index, handing the embedding model to the index
        # constructor through the loader's keyword arguments rather than
        # overwriting private attributes on the loaded object.
        storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
        index = load_index_from_storage(storage_context, embed_model=embed_model)
        status = "✅ Index loaded from storage"
    else:
        if not os.path.exists(DATA_DIR):
            # Create the folder so the user has somewhere to drop files.
            os.makedirs(DATA_DIR)
            raise _IndexSetupError("Please add documents to the 'data' directory")

        documents = SimpleDirectoryReader(DATA_DIR).load_data()
        if not documents:
            raise _IndexSetupError("No documents found in the 'data' directory")

        index = VectorStoreIndex.from_documents(
            documents,
            llm=llm,
            embed_model=embed_model,
        )
        # Store for later so the next start can skip re-embedding
        index.storage_context.persist(persist_dir=PERSIST_DIR)
        status = f"✅ Index created with {len(documents)} documents"

    # Create query engine
    query_engine = index.as_query_engine(llm=llm, embed_model=embed_model)
    return query_engine, status


def initialize_query_engine(_api_key):
    """Initialize the LlamaIndex query engine with caching.

    Args:
        _api_key: OpenAI API key; exported as ``OPENAI_API_KEY`` before the
            models are created.

    Returns:
        ``(query_engine, status)`` on success, or ``(None, message)`` when
        setup failed. Failures are deliberately not cached, so fixing the
        problem (for example adding files to the data directory) takes
        effect on the next call without clearing the cache.
    """
    try:
        return _build_query_engine(_api_key)
    except _IndexSetupError as exc:
        # Expected, user-actionable condition: show the message as-is.
        return None, str(exc)
    except Exception as e:
        # UI boundary: surface any other failure as a status string.
        return None, f"❌ Error: {str(e)}"
86
+
87
# Main chat interface
if not openai_api_key:
    st.warning("⚠️ Please set the OPENAI_API_KEY environment variable to get started.")
    st.stop()

# Initialize query engine
# Only a successfully built engine is stored in the session. A failed attempt
# leaves the session untouched, so the next rerun retries initialization
# instead of going on to query a missing engine.
if st.session_state.get("query_engine") is None:
    with st.spinner("Initializing RAG agent..."):
        query_engine, status = initialize_query_engine(openai_api_key)

    if query_engine is None:
        st.error(status)
        st.stop()

    st.session_state.query_engine = query_engine
    st.success(status)
103
+
104
# Replay the stored conversation so it survives Streamlit reruns
for entry in st.session_state.messages:
    with st.chat_message(entry["role"]):
        st.markdown(entry["content"])

# Chat input
prompt = st.chat_input("Ask a question about your documents")
if prompt:
    # Echo the question and record it in the history
    with st.chat_message("user"):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Answer inside an assistant bubble while a spinner is shown
    with st.chat_message("assistant"), st.spinner("Thinking..."):
        try:
            reply = str(st.session_state.query_engine.query(prompt))
            st.markdown(reply)
        except Exception as e:
            # Failures are shown and also kept in the history as the reply
            reply = f"Error generating response: {str(e)}"
            st.error(reply)

        # Whatever was displayed (answer or error) becomes the assistant turn
        st.session_state.messages.append({
            "role": "assistant",
            "content": reply
        })