cmd0160 committed on
Commit
e7113d1
·
1 Parent(s): b892fd9

Adding auto-ingest

Browse files
Files changed (1) hide show
  1. app.py +189 -56
app.py CHANGED
@@ -1,76 +1,209 @@
1
- """Streamlit app for Abalone RAG chatbot."""
2
  import os
 
3
  os.environ.setdefault("LANGCHAIN_TELEMETRY_ENABLED", "false")
4
  os.environ.setdefault("LANGCHAIN_DISABLE_TELEMETRY", "true")
5
  os.environ.setdefault("CHROMA_TELEMETRY_ENABLED", "false")
6
 
7
  import streamlit as st
8
-
9
  from src.vectorstore import get_retriever
10
  from src.qa_chain import make_conversational_chain
11
-
12
 
13
  st.set_page_config(page_title="Abalone RAG Chatbot", page_icon="🐚")
14
 
15
  st.title("Abalone RAG Chatbot")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  if "chat_history" not in st.session_state:
18
  st.session_state["chat_history"] = []
19
 
20
- with st.sidebar:
21
- st.header("Settings")
22
- model_name = st.selectbox("Model", ["gpt-3.5-turbo", "gpt-4"], index=0)
23
- top_k = st.number_input("Retriever top_k", min_value=1, max_value=10, value=4)
24
- if st.button("Rebuild vectorstore (ingest)"):
25
- st.info("Rebuild requested. Run ingestion script or push data to trigger rebuild.")
26
 
27
- OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
28
- if not OPENAI_API_KEY:
29
- st.error("OPENAI_API_KEY not found. Set the OPENAI_API_KEY environment variable or add it to Hugging Face Spaces Secrets.")
30
- st.stop()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
- persist_dir = "./vectorstore"
33
  retriever = None
34
- try:
35
- retriever = get_retriever(persist_dir=persist_dir, top_k=top_k)
36
- except Exception as e:
37
- st.warning("Vectorstore not found or not initialized. Please run the ingestion script to build it.\n" + str(e))
38
-
39
-
40
- if retriever:
41
- chain = make_conversational_chain(retriever, model_name=model_name)
42
-
43
- user_input = st.text_input("Ask a question about Abalone", key="input")
44
- if st.button("Send") and user_input:
45
- with st.spinner("Thinking..."):
46
- prior_history = [(h.get("question"), h.get("answer", "")) for h in st.session_state.get("chat_history", [])]
47
- result = chain({"question": user_input, "chat_history": prior_history})
48
- answer = result.get("answer") or result.get("output_text") or ""
49
- source_docs = result.get("source_documents") or []
50
- st.session_state.setdefault("chat_history", [])
51
- st.session_state["chat_history"].append({"question": user_input, "answer": answer, "sources": source_docs})
52
-
53
- if st.session_state.get("chat_history"):
54
- for item in reversed(st.session_state.get("chat_history", [])):
55
- st.markdown(f"**User:** {item.get('question')}")
56
- st.markdown(f"**Assistant:** {item.get('answer')}")
57
- sources = item.get("sources") or []
58
- if sources:
59
- with st.expander("Sources"):
60
- for sd in sources:
61
- if isinstance(sd, dict):
62
- meta = sd.get("metadata", {})
63
- content_preview = sd.get("page_content") or sd.get("content") or sd.get("text", "")
64
- else:
65
- meta = getattr(sd, "metadata", {}) or {}
66
- content_preview = getattr(sd, "page_content", None)
67
- if content_preview is None:
68
- content_preview = getattr(sd, "content", "")
69
  st.write(meta)
70
- if content_preview:
71
- try:
72
- st.write(content_preview[:400])
73
- except Exception:
74
- st.write(str(content_preview))
75
- else:
76
- st.info("No retriever available. Ingest data first.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+
3
  os.environ.setdefault("LANGCHAIN_TELEMETRY_ENABLED", "false")
4
  os.environ.setdefault("LANGCHAIN_DISABLE_TELEMETRY", "true")
5
  os.environ.setdefault("CHROMA_TELEMETRY_ENABLED", "false")
6
 
7
  import streamlit as st
 
8
  from src.vectorstore import get_retriever
9
  from src.qa_chain import make_conversational_chain
10
+ from src.ingest import ingest as run_ingest
11
 
12
  st.set_page_config(page_title="Abalone RAG Chatbot", page_icon="🐚")
13
 
14
  st.title("Abalone RAG Chatbot")
15
+ st.write(
16
+ "Ask natural-language questions about abalone studies and data. "
17
+ "The app uses a local Chroma vectorstore and OpenAI to retrieve and answer."
18
+ )
19
+
20
+ st.sidebar.header("Configuration")
21
+
22
+ model_name = st.sidebar.selectbox(
23
+ "Model",
24
+ options=["gpt-3.5-turbo", "gpt-4"],
25
+ index=0,
26
+ )
27
+
28
+ top_k = st.sidebar.slider(
29
+ "Number of retrieved chunks (k)",
30
+ min_value=2,
31
+ max_value=10,
32
+ value=4,
33
+ )
34
+
35
+ data_dir = st.sidebar.text_input("Data directory", value="./data")
36
+ persist_dir = st.sidebar.text_input("Vectorstore directory", value="./vectorstore")
37
+
38
+ chunk_size = st.sidebar.number_input(
39
+ "Chunk size",
40
+ min_value=200,
41
+ max_value=4000,
42
+ value=1000,
43
+ step=100,
44
+ )
45
+
46
+ chunk_overlap = st.sidebar.number_input(
47
+ "Chunk overlap",
48
+ min_value=0,
49
+ max_value=1000,
50
+ value=200,
51
+ step=50,
52
+ )
53
+
54
+ st.sidebar.markdown("---")
55
+ st.sidebar.caption(
56
+ "If the vectorstore is missing or invalid, the app will attempt to ingest "
57
+ "the data automatically using these settings."
58
+ )
59
 
60
  if "chat_history" not in st.session_state:
61
  st.session_state["chat_history"] = []
62
 
63
+ if "retriever_initialized" not in st.session_state:
64
+ st.session_state["retriever_initialized"] = False
 
 
 
 
65
 
66
+ def ensure_openai_key() -> bool:
67
+ if not os.environ.get("OPENAI_API_KEY"):
68
+ st.error("OPENAI_API_KEY is not set.")
69
+ return False
70
+ return True
71
+
72
+ @st.cache_resource(show_spinner=False)
73
+ def build_or_load_retriever_cached(
74
+ data_dir: str,
75
+ persist_dir: str,
76
+ top_k: int,
77
+ chunk_size: int,
78
+ chunk_overlap: int,
79
+ ):
80
+ try:
81
+ return get_retriever(persist_dir=persist_dir, top_k=top_k)
82
+ except Exception:
83
+ run_ingest(
84
+ data_dir=data_dir,
85
+ persist_dir=persist_dir,
86
+ chunk_size=chunk_size,
87
+ chunk_overlap=chunk_overlap,
88
+ )
89
+ return get_retriever(persist_dir=persist_dir, top_k=top_k)
90
+
91
+ def get_or_build_retriever_with_ui():
92
+ if not ensure_openai_key():
93
+ return None
94
+ try:
95
+ return build_or_load_retriever_cached(
96
+ data_dir=data_dir,
97
+ persist_dir=persist_dir,
98
+ top_k=top_k,
99
+ chunk_size=chunk_size,
100
+ chunk_overlap=chunk_overlap,
101
+ )
102
+ except Exception as e:
103
+ st.error(
104
+ "Could not initialize vectorstore.\n\n"
105
+ f"Details: `{e}`"
106
+ )
107
+ return None
108
 
 
109
  retriever = None
110
+ with st.spinner("Initializing vectorstore and retriever..."):
111
+ retriever = get_or_build_retriever_with_ui()
112
+
113
+ if retriever is None:
114
+ st.info("No retriever available. Fix the errors above and refresh the page.")
115
+ st.stop()
116
+
117
+ st.success("Vectorstore and retriever are ready.")
118
+
119
+ chain = make_conversational_chain(retriever, model_name=model_name)
120
+
121
+ if st.session_state["chat_history"]:
122
+ st.subheader("Conversation")
123
+ for i, turn in enumerate(st.session_state["chat_history"]):
124
+ st.markdown(f"**You:** {turn['question']}")
125
+ st.markdown(f"**Abalone Bot:** {turn['answer']}")
126
+ if turn.get("sources"):
127
+ with st.expander(f"Show sources for question {i + 1}"):
128
+ for j, src in enumerate(turn["sources"], start=1):
129
+ st.markdown(f"**Source {j}:**")
130
+ meta = src.get("metadata", {})
131
+ if meta:
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  st.write(meta)
133
+ preview = src.get("content_preview", "")
134
+ if preview:
135
+ st.write(preview)
136
+
137
+ st.subheader("Ask a question")
138
+
139
+ user_input = st.text_input(
140
+ "Ask a question about abalone (biology, data, methodology, etc.)",
141
+ key="user_question_input",
142
+ )
143
+
144
+ send_clicked = st.button("Send")
145
+
146
+ if send_clicked and user_input:
147
+ if not ensure_openai_key():
148
+ st.stop()
149
+
150
+ with st.spinner("Thinking..."):
151
+ prior_history = [
152
+ (h.get("question"), h.get("answer", "")) for h in st.session_state["chat_history"]
153
+ ]
154
+
155
+ result = chain(
156
+ {"question": user_input, "chat_history": prior_history}
157
+ )
158
+
159
+ answer = (
160
+ result.get("answer")
161
+ or result.get("result")
162
+ or result.get("output_text")
163
+ or ""
164
+ )
165
+ source_docs = result.get("source_documents") or []
166
+
167
+ sources_for_ui = []
168
+ for sd in source_docs:
169
+ if isinstance(sd, dict):
170
+ meta = sd.get("metadata", {}) or {}
171
+ content_preview = sd.get("page_content") or sd.get("content") or sd.get("text", "")
172
+ else:
173
+ meta = getattr(sd, "metadata", {}) or {}
174
+ content_preview = getattr(sd, "page_content", None)
175
+ if content_preview is None:
176
+ content_preview = getattr(sd, "content", "")
177
+ if content_preview is None:
178
+ content_preview = ""
179
+ sources_for_ui.append(
180
+ {
181
+ "metadata": meta,
182
+ "content_preview": str(content_preview)[:500],
183
+ }
184
+ )
185
+
186
+ st.session_state["chat_history"].append(
187
+ {
188
+ "question": user_input,
189
+ "answer": answer,
190
+ "sources": sources_for_ui,
191
+ }
192
+ )
193
+
194
+ st.markdown("---")
195
+ st.markdown("### Latest Answer")
196
+ st.markdown(f"**You:** {user_input}")
197
+ st.markdown(f"**Abalone Bot:** {answer}")
198
+
199
+ if sources_for_ui:
200
+ with st.expander("Show sources for this answer"):
201
+ for i, src in enumerate(sources_for_ui, start=1):
202
+ st.markdown(f"**Source {i}:**")
203
+ if src["metadata"]:
204
+ st.write(src["metadata"])
205
+ if src["content_preview"]:
206
+ st.write(src["content_preview"])
207
+
208
+ elif send_clicked and not user_input:
209
+ st.warning("Please enter a question before clicking Send.")