Update app.py
Browse files
app.py
CHANGED
|
@@ -26,7 +26,8 @@ def load_embedding_model(model_name: str = DEFAULT_EMBED_MODEL):
|
|
| 26 |
|
| 27 |
def save_uploaded_file(uploaded_file) -> str:
|
| 28 |
tmpdir = tempfile.gettempdir()
|
| 29 |
-
|
|
|
|
| 30 |
with open(temp_path, "wb") as f:
|
| 31 |
f.write(uploaded_file.getbuffer())
|
| 32 |
return temp_path
|
|
@@ -112,8 +113,8 @@ def load_index_and_metadata(index_path: str = INDEX_PATH, meta_path: str = METAD
|
|
| 112 |
return None, None
|
| 113 |
|
| 114 |
|
| 115 |
-
def get_groq_client_from_env(
|
| 116 |
-
key = os.environ.get("GROQ_API_KEY") or (
|
| 117 |
if not key:
|
| 118 |
return None
|
| 119 |
return Groq(api_key=key)
|
|
@@ -142,13 +143,16 @@ def main():
|
|
| 142 |
top_k = int(st.sidebar.number_input("Top-k results to retrieve", min_value=1, max_value=50, value=DEFAULT_TOP_K))
|
| 143 |
|
| 144 |
st.sidebar.markdown("---")
|
| 145 |
-
st.sidebar.write("Groq API key
|
| 146 |
-
hardcoded_key = st.sidebar.text_input("Paste GROQ key here (will be used if env var not set)", type="password", value="")
|
| 147 |
|
| 148 |
with st.spinner("Loading embedding model..."):
|
| 149 |
embed_model = load_embedding_model()
|
| 150 |
|
| 151 |
index, metadata = load_index_and_metadata()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
if index is None:
|
| 153 |
st.info("No existing FAISS index found. Upload documents and click 'Ingest documents' to build index.")
|
| 154 |
else:
|
|
@@ -200,8 +204,17 @@ def main():
|
|
| 200 |
embeddings = normalize_embeddings(embeddings)
|
| 201 |
|
| 202 |
index = build_faiss_index(embeddings)
|
| 203 |
-
|
| 204 |
-
st.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
|
| 206 |
st.markdown("---")
|
| 207 |
|
|
@@ -214,9 +227,11 @@ def main():
|
|
| 214 |
if not query or not query.strip():
|
| 215 |
st.warning("Please enter a question.")
|
| 216 |
else:
|
| 217 |
-
if index
|
| 218 |
st.warning("No index available. Please ingest documents first.")
|
| 219 |
else:
|
|
|
|
|
|
|
| 220 |
q_emb = embed_model.encode([query], convert_to_numpy=True)
|
| 221 |
q_emb = normalize_embeddings(q_emb)
|
| 222 |
q_emb = np.ascontiguousarray(q_emb.astype('float32'))
|
|
@@ -234,9 +249,9 @@ def main():
|
|
| 234 |
meta = metadata[idx]
|
| 235 |
contexts.append(meta)
|
| 236 |
|
| 237 |
-
client = get_groq_client_from_env(
|
| 238 |
if client is None:
|
| 239 |
-
st.error("Groq client could not be created. Set
|
| 240 |
else:
|
| 241 |
system_msg = {
|
| 242 |
"role": "system",
|
|
@@ -273,9 +288,9 @@ def main():
|
|
| 273 |
|
| 274 |
st.sidebar.markdown("---")
|
| 275 |
st.sidebar.header("Index status")
|
| 276 |
-
if index
|
| 277 |
try:
|
| 278 |
-
st.sidebar.write(f"Vectors in index: {index.ntotal}")
|
| 279 |
except Exception:
|
| 280 |
st.sidebar.write("Vectors in index: unknown")
|
| 281 |
else:
|
|
|
|
| 26 |
|
| 27 |
def save_uploaded_file(uploaded_file) -> str:
|
| 28 |
tmpdir = tempfile.gettempdir()
|
| 29 |
+
safe_name = os.path.basename(uploaded_file.name)
|
| 30 |
+
temp_path = os.path.join(tmpdir, safe_name)
|
| 31 |
with open(temp_path, "wb") as f:
|
| 32 |
f.write(uploaded_file.getbuffer())
|
| 33 |
return temp_path
|
|
|
|
| 113 |
return None, None
|
| 114 |
|
| 115 |
|
| 116 |
+
def get_groq_client_from_env() -> Groq | None:
|
| 117 |
+
key = os.environ.get("GROQ_API_KEY") or os.environ.get("HF_GROQ_API_KEY") or os.environ.get("HF_API_TOKEN")
|
| 118 |
if not key:
|
| 119 |
return None
|
| 120 |
return Groq(api_key=key)
|
|
|
|
| 143 |
top_k = int(st.sidebar.number_input("Top-k results to retrieve", min_value=1, max_value=50, value=DEFAULT_TOP_K))
|
| 144 |
|
| 145 |
st.sidebar.markdown("---")
|
| 146 |
+
st.sidebar.write("Groq API key is loaded from Hugging Face secrets or environment variables. Do not paste keys in the UI.")
|
|
|
|
| 147 |
|
| 148 |
with st.spinner("Loading embedding model..."):
|
| 149 |
embed_model = load_embedding_model()
|
| 150 |
|
| 151 |
index, metadata = load_index_and_metadata()
|
| 152 |
+
if index is None and "index" in st.session_state:
|
| 153 |
+
index = st.session_state["index"]
|
| 154 |
+
metadata = st.session_state.get("metadata")
|
| 155 |
+
|
| 156 |
if index is None:
|
| 157 |
st.info("No existing FAISS index found. Upload documents and click 'Ingest documents' to build index.")
|
| 158 |
else:
|
|
|
|
| 204 |
embeddings = normalize_embeddings(embeddings)
|
| 205 |
|
| 206 |
index = build_faiss_index(embeddings)
|
| 207 |
+
st.session_state["index"] = index
|
| 208 |
+
st.session_state["metadata"] = metadata
|
| 209 |
+
try:
|
| 210 |
+
save_index_and_metadata(index, metadata)
|
| 211 |
+
st.success(f"Index built and saved. {len(all_chunks)} chunks indexed.")
|
| 212 |
+
except Exception:
|
| 213 |
+
st.success(f"Index built in memory. {len(all_chunks)} chunks indexed.")
|
| 214 |
+
st.info(f"Ingested {len(all_chunks)} chunks from {len(set(m['source'] for m in metadata))} file(s).")
|
| 215 |
+
if len(all_chunks) > 0:
|
| 216 |
+
st.markdown("### Example chunk")
|
| 217 |
+
st.write(all_chunks[0][:1000])
|
| 218 |
|
| 219 |
st.markdown("---")
|
| 220 |
|
|
|
|
| 227 |
if not query or not query.strip():
|
| 228 |
st.warning("Please enter a question.")
|
| 229 |
else:
|
| 230 |
+
if "index" not in st.session_state or "metadata" not in st.session_state:
|
| 231 |
st.warning("No index available. Please ingest documents first.")
|
| 232 |
else:
|
| 233 |
+
index = st.session_state["index"]
|
| 234 |
+
metadata = st.session_state["metadata"]
|
| 235 |
q_emb = embed_model.encode([query], convert_to_numpy=True)
|
| 236 |
q_emb = normalize_embeddings(q_emb)
|
| 237 |
q_emb = np.ascontiguousarray(q_emb.astype('float32'))
|
|
|
|
| 249 |
meta = metadata[idx]
|
| 250 |
contexts.append(meta)
|
| 251 |
|
| 252 |
+
client = get_groq_client_from_env()
|
| 253 |
if client is None:
|
| 254 |
+
st.error("Groq client could not be created. Set GROQ_API_KEY in your Hugging Face secrets or environment.")
|
| 255 |
else:
|
| 256 |
system_msg = {
|
| 257 |
"role": "system",
|
|
|
|
| 288 |
|
| 289 |
st.sidebar.markdown("---")
|
| 290 |
st.sidebar.header("Index status")
|
| 291 |
+
if "index" in st.session_state:
|
| 292 |
try:
|
| 293 |
+
st.sidebar.write(f"Vectors in index: {st.session_state['index'].ntotal}")
|
| 294 |
except Exception:
|
| 295 |
st.sidebar.write("Vectors in index: unknown")
|
| 296 |
else:
|