Starburst15 committed
Commit 44669ca · verified · 1 Parent(s): 11f2ac4

Update src/streamlit_app.py

Files changed (1)
  1. src/streamlit_app.py +287 -33
src/streamlit_app.py CHANGED
@@ -1,40 +1,294 @@
- import altair as alt
  import numpy as np
- import pandas as pd
  import streamlit as st

- """
- # Welcome to Streamlit!

- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
- forums](https://discuss.streamlit.io).

- In the meantime, below is an example of what you can do with just a few lines of code:
  """

- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
-
- indices = np.linspace(0, 1, num_points)
- theta = 2 * np.pi * num_turns * indices
- radius = indices
-
- x = radius * np.cos(theta)
- y = radius * np.sin(theta)
-
- df = pd.DataFrame({
-     "x": x,
-     "y": y,
-     "idx": indices,
-     "rand": np.random.randn(num_points),
- })
-
- st.altair_chart(alt.Chart(df, height=700, width=700)
-     .mark_point(filled=True)
-     .encode(
-         x=alt.X("x", axis=None),
-         y=alt.Y("y", axis=None),
-         color=alt.Color("idx", legend=None, scale=alt.Scale()),
-         size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
-     ))
+ # =============================================================
+ # 📘 USTP Student Handbook Assistant (2023 Edition)
+ # =============================================================
+ # Enhanced: dynamic model selection + real (printed) page numbering
+
+ import os
+ import glob
+ import json
+ import time
+ from typing import List, Dict, Any
  import numpy as np
  import streamlit as st
+ import PyPDF2
+ import requests
+ from dotenv import load_dotenv
+ from huggingface_hub import InferenceClient, login
+ from streamlit_chat import message as st_message
+
+ # Optional: FAISS for fast vector search
+ try:
+     import faiss
+ except ImportError:
+     faiss = None
+
+ # =============================================================
+ # 🌐 Startup Fix for PermissionError
+ # =============================================================
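+ # Points Streamlit's home/config directory at /tmp, which is typically the
+ # only writable location on read-only hosts such as Hugging Face Spaces.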
+ os.environ["STREAMLIT_HOME"] = "/tmp/.streamlit"
+ os.makedirs("/tmp/.streamlit", exist_ok=True)
+
+ # =============================================================
+ # ⚙️ Streamlit Page Setup
+ # =============================================================
+ st.set_page_config(page_title="📘 Handbook Assistant", page_icon="📘", layout="wide")
+ st.title("📘 USTP Student Handbook Assistant (2023 Edition)")
+ st.caption("Answers sourced only from the official *USTP Student Handbook 2023 Edition.pdf*.")
+
+ load_dotenv()
+ HF_TOKEN = os.getenv("HF_TOKEN")
+
+ if not HF_TOKEN:
+     st.warning("⚠️ No Hugging Face API token found in .env file. Online models will be unavailable.")
+ else:
+     try:
+         login(HF_TOKEN)
+     except Exception:
+         pass
+
+ hf_client = InferenceClient(token=HF_TOKEN) if HF_TOKEN else None
+
+ # =============================================================
+ # ⚙️ Sidebar Configuration
+ # =============================================================
+ with st.sidebar:
+     st.header("⚙️ Settings")
+
+     model_options = {
+         "Qwen 2.5 14B Instruct": "Qwen/Qwen2.5-14B-Instruct",
+         "Mistral 7B Instruct": "mistralai/Mistral-7B-Instruct-v0.3",
+         "Llama 3 8B Instruct": "meta-llama/Meta-Llama-3-8B-Instruct",
+         "Mixtral 8x7B Instruct": "mistralai/Mixtral-8x7B-Instruct-v0.1",
+         "Falcon 7B Instruct": "tiiuae/falcon-7b-instruct",
+     }
+     model_choice = st.selectbox("Select reasoning model", list(model_options.keys()), index=0)
+     DEFAULT_MODEL = model_options[model_choice]
+
+     st.markdown("---")
+     similarity_threshold = st.slider("Similarity threshold", 0.3, 1.0, 0.6, 0.01)
+     top_k = st.slider("Top K retrieved chunks", 1, 10, 4)
+     chunk_size_chars = st.number_input("Chunk size (chars)", 400, 2500, 1200, 100)
+     chunk_overlap = st.number_input("Chunk overlap (chars)", 20, 600, 150, 10)
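+     # Offset between the PDF's physical page index and the handbook's printed
+     # page numbers (cover, table of contents, etc. before printed page 1).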
+     front_matter_pages = st.number_input(
+         "Pages before main content (e.g. table of contents, cover)", min_value=0, max_value=50, value=12
+     )
+     regenerate_index = st.button("🔁 Rebuild handbook index")
+
+ # =============================================================
+ # 📂 File Config
+ # =============================================================
+ INDEX_FILE = "handbook_faiss.index"
+ META_FILE = "handbook_metadata.json"
+ EMB_DIM_FILE = "handbook_emb_dim.json"
+ EMBED_MODEL = "sentence-transformers/all-mpnet-base-v2"
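+ # all-mpnet-base-v2 produces 768-dimensional sentence embeddings, which is
+ # what the zero-vector fallback in embed_texts() below assumes.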
+
+ # =============================================================
+ # 🧩 Utility Functions
+ # =============================================================
+ def find_handbook() -> List[str]:
+     preferred = "USTP Student Handbook 2023 Edition.pdf"
+     pdfs = glob.glob("*.pdf")
+     for f in pdfs:
+         if preferred.lower() in f.lower():
+             st.success(f"📘 Found handbook: {f}")
+             return [f]
+     if pdfs:
+         st.warning(f"⚠️ Preferred handbook not found. Using {os.path.basename(pdfs[0])}.")
+         return [pdfs[0]]
+     st.error("❌ No PDF found in current folder.")
+     return []
+
+
+ def load_pdf_texts(pdf_paths: List[str]) -> List[Dict[str, Any]]:
+     """Extract page text while adjusting page numbering to printed handbook numbers."""
+     pages = []
+     for path in pdf_paths:
+         with open(path, "rb") as f:
+             reader = PyPDF2.PdfReader(f)
+             for i, page in enumerate(reader.pages):
+                 text = page.extract_text() or ""
+                 if text.strip():
+                     # Adjust logical page number to printed numbering
+                     logical_page = i + 1
+                     printed_page = logical_page - front_matter_pages
+                     if printed_page < 1:
+                         printed_page = 1
+                     pages.append({
+                         "filename": os.path.basename(path),
+                         "page": printed_page,
+                         "text": text.strip()
+                     })
+     return pages
+
+
+ def chunk_text(pages: List[Dict[str, Any]], size: int, overlap: int) -> List[Dict[str, Any]]:
+     chunks = []
+     for p in pages:
+         text = p["text"]
+         start = 0
+         while start < len(text):
+             end = start + size
+             chunk = text[start:end]
+             chunks.append({
+                 "filename": p["filename"],
+                 "page": p["page"],
+                 "content": chunk.strip()
+             })
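+             # Advance the sliding window; assumes overlap < size, otherwise
+             # the window would not move forward.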
+             start += size - overlap
+     return chunks


+ def embed_texts(texts: List[str]) -> np.ndarray:
+     """Generate embeddings using Hugging Face feature extraction."""
+     if not HF_TOKEN or not hf_client:
+         st.error("❌ Missing Hugging Face token or client.")
+         return np.zeros((len(texts), 768))
+     try:
+         embeddings = hf_client.feature_extraction(texts, model=EMBED_MODEL)
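+         # The feature-extraction task can return token-level vectors (one per
+         # token); mean-pool them so each text maps to a single embedding.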
+         if isinstance(embeddings[0][0], list):
+             embeddings = [np.mean(np.array(e), axis=0) for e in embeddings]
+         return np.array(embeddings)
+     except Exception as e1:
+         st.warning(f"⚠️ feature_extraction failed, using REST API fallback: {e1}")
+         headers = {"Authorization": f"Bearer {HF_TOKEN}"}
+         resp = requests.post(
+             f"https://api-inference.huggingface.co/models/{EMBED_MODEL}",
+             headers=headers,
+             json={"inputs": texts}
+         )
+         data = resp.json()
+         if isinstance(data[0][0], list):
+             data = [np.mean(np.array(e), axis=0) for e in data]
+         return np.array(data)

+
+ def build_faiss_index(chunks: List[Dict[str, Any]]):
+     """Build FAISS index for chunks."""
+     texts = [c["content"] for c in chunks]
+     embeddings = embed_texts(texts)
+     if embeddings.size == 0:
+         st.error("❌ Embedding generation failed.")
+         return
+     dim = embeddings.shape[1]
+     index = faiss.IndexFlatL2(dim)
+     index.add(embeddings.astype("float32"))
+     faiss.write_index(index, INDEX_FILE)
+     with open(META_FILE, "w") as f:
+         json.dump(chunks, f)
+     with open(EMB_DIM_FILE, "w") as f:
+         json.dump({"dim": dim}, f)
+     st.success(f"✅ Indexed {len(chunks)} chunks.")
+
+
+ def load_faiss_index():
+     if not os.path.exists(INDEX_FILE) or not os.path.exists(META_FILE):
+         return None, None
+     index = faiss.read_index(INDEX_FILE)
+     with open(META_FILE) as f:
+         meta = json.load(f)
+     return index, meta
+
+
+ def search_index(query: str, index, meta, top_k: int, threshold: float):
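+     # IndexFlatL2 returns squared L2 distances (smaller = closer match).
+     # Note: the `threshold` argument from the sidebar slider is carried here
+     # but not yet applied to filter the retrieved chunks.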
+     query_emb = embed_texts([query])
+     distances, indices = index.search(query_emb.astype("float32"), top_k)
+     results = []
+     for i, dist in zip(indices[0], distances[0]):
+         if i < len(meta):
+             r = meta[i]
+             r["distance"] = float(dist)
+             results.append(r)
+     return results
+
+
+ def generate_answer(context: str, query: str) -> str:
+     """Generate model-based answer using selected open-source model."""
+     prompt = f"""
+ You are a precise academic assistant specialized in university policy.
+ Use only the *USTP Student Handbook 2023 Edition* below.
+ If the answer is not in the text, reply:
+ "The handbook does not specify that."
+
+ ---
+ 📘 Context:
+ {context}
+ ---
+ 🧭 Question:
+ {query}
+ ---
+ 🎯 Instructions:
+ - Be factual and concise.
+ - Cite the correct printed page number.
+ - Never make assumptions.
  """

+     try:
+         response = hf_client.text_generation(
+             model=DEFAULT_MODEL,
+             prompt=prompt,
+             max_new_tokens=400,
+             temperature=0.25
+         )
+         return response if isinstance(response, str) else str(response)
+     except Exception as e1:
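+         # Some hosted models only expose the chat-completions task, so retry
+         # the same prompt through the chat API before giving up.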
+         try:
+             chat_response = hf_client.chat.completions.create(
+                 model=DEFAULT_MODEL,
+                 messages=[{"role": "user", "content": prompt}],
+                 max_tokens=400
+             )
+             return chat_response.choices[0].message["content"]
+         except Exception as e2:
+             return f"⚠️ Error generating answer: {e2}"
+
+
+ def ensure_index():
+     """Ensure FAISS index exists or rebuild."""
+     if regenerate_index or not os.path.exists(INDEX_FILE):
+         pdfs = find_handbook()
+         if not pdfs:
+             st.stop()
+         st.info("📄 Extracting handbook text...")
+         pages = load_pdf_texts(pdfs)
+         chunks = chunk_text(pages, chunk_size_chars, chunk_overlap)
+         build_faiss_index(chunks)
+     index, meta = load_faiss_index()
+     if index is None or meta is None:
+         st.error("❌ Could not load FAISS index.")
+         st.stop()
+     return index, meta
+
+ # =============================================================
+ # 💬 Chat Interface
+ # =============================================================
+ st.divider()
+ st.subheader("💬 Ask about the Handbook")
+
+ if "history" not in st.session_state:
+     st.session_state.history = []
+
+ user_query = st.text_input("Enter your question:")
+ index, meta = ensure_index()
+
+ if st.button("Ask") and user_query.strip():
+     results = search_index(user_query, index, meta, top_k, similarity_threshold)
+     if not results:
+         st.warning("No relevant section found in the handbook.")
+     else:
+         context = "\n\n".join(
+             [f"(📄 Page {r['page']})\n{r['content']}" for r in results]
+         )
+         answer = generate_answer(context, user_query)
+         st.session_state.history.append({
+             "user": user_query,
+             "assistant": answer,
+             "timestamp": time.time()
+         })
+
+ # ✅ Ensure unique keys to prevent StreamlitDuplicateElementId
+ for i, chat in enumerate(st.session_state.history):
+     st_message(chat["user"], is_user=True, key=f"user_{i}")
+     st_message(chat["assistant"], key=f"assistant_{i}")
+
+ st.caption("⚡ Powered by FAISS + Open Source Models + Accurate Page Referencing")
+