anneee266333 committed on
Commit 863c39b · verified · 1 Parent(s): 93f0d23

Create app.py

Files changed (1)
  1. app.py +295 -0
app.py ADDED
@@ -0,0 +1,295 @@
+ import os
+ import hashlib
+ from typing import BinaryIO, List
+
+ import streamlit as st
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+ from langchain_community.llms import HuggingFacePipeline
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from langchain_community.vectorstores import FAISS
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
+ from langchain_core.output_parsers import StrOutputParser
+ from langchain_core.runnables import RunnablePassthrough, RunnableLambda
+ from langchain_core.prompts import ChatPromptTemplate
+ from pypdf import PdfReader
+ from streamlit_chat import message
+
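+ # The imports above imply roughly these pip packages (names are an
+ # assumption; this commit pins nothing): streamlit, streamlit-chat, torch,
+ # transformers, langchain-community, langchain-core,
+ # langchain-text-splitters, faiss-cpu, sentence-transformers, pypdf.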
+
+ # --------------------------
+ # App Config
+ # --------------------------
+ st.set_page_config(
+     page_title="Simple QA - Built-in PDF",
+     page_icon="📘",
+     layout="wide"
+ )
+
+ DEFAULT_MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
+ DEFAULT_MAX_NEW_TOKENS = 256
+ DEFAULT_TEMPERATURE = 0.2
+
+ SYSTEM_PROMPT = (
+     "You are a careful assistant for question answering. "
+     "Use ONLY the provided context to answer. "
+     "If the answer is not in the context, say you don't know."
+ )
+
+
+
+ # --------------------------
+ # Utilities
+ # --------------------------
+ def read_pdf_bytes_to_text(file_like: BinaryIO) -> str:
+     file_like.seek(0)
+     reader = PdfReader(file_like)
+     texts = []
+     for page in reader.pages:
+         texts.append(page.extract_text() or "")
+     return "\n".join(texts)
+
+
+ def compute_texts_hash(texts: List[str]) -> str:
+     data = "\n".join(texts)
+     return hashlib.sha256(data.encode("utf-8")).hexdigest()
+
+
+ def format_docs(docs):
+     return "\n\n".join(f"[{i+1}] {d.page_content}" for i, d in enumerate(docs))
+
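+ # format_docs numbers each chunk [1], [2], ... in the prompt context; the
+ # "Chunk [i]" expanders further down reuse that numbering, so an answer can
+ # be checked against the chunk it most likely came from.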
+
+
+ # --------------------------
+ # Caches
+ # --------------------------
+ @st.cache_resource(show_spinner=True)
+ def get_embeddings():
+     return HuggingFaceEmbeddings(
+         model_name="sentence-transformers/all-MiniLM-L6-v2",
+         model_kwargs={"device": "cpu"}
+     )
+
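+ # all-MiniLM-L6-v2 is a small 384-dimensional sentence-transformers model,
+ # which keeps CPU-only embedding of the whole handbook reasonably fast.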
+
+ @st.cache_resource(show_spinner=True)
+ def load_llm(
+     model_id=DEFAULT_MODEL_ID,
+     temperature=DEFAULT_TEMPERATURE,
+     max_new_tokens=DEFAULT_MAX_NEW_TOKENS
+ ):
+     tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
+     model = AutoModelForCausalLM.from_pretrained(
+         model_id,
+         torch_dtype=torch.float32,
+         low_cpu_mem_usage=True
+     )
+     gen = pipeline(
+         "text-generation",
+         model=model,
+         tokenizer=tokenizer,
+         device=-1,  # CPU
+         do_sample=temperature > 0,  # temperature is ignored unless sampling is enabled
+         temperature=temperature,
+         max_new_tokens=max_new_tokens,
+         repetition_penalty=1.1,
+         pad_token_id=tokenizer.eos_token_id,
+         return_full_text=False,
+     )
+     return HuggingFacePipeline(pipeline=gen)
+
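+ # Note: HuggingFacePipeline is a plain-text LLM wrapper, so the chat prompt
+ # built below is flattened to a "System: ... Human: ..." string rather than
+ # going through the model's own chat template; tokenizer.apply_chat_template
+ # could improve answers for instruct models, but the plain string also works.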
97
+
98
+ def build_faiss_index(texts: List[str], chunk_size=800, chunk_overlap=120):
99
+ splitter = RecursiveCharacterTextSplitter(
100
+ chunk_size=chunk_size,
101
+ chunk_overlap=chunk_overlap
102
+ )
103
+ docs = splitter.create_documents(texts)
104
+ emb = get_embeddings()
105
+ vs = FAISS.from_documents(docs, embedding=emb)
106
+ return vs
107
+
108
+
109
+ def make_rag_chain(retriever, llm):
110
+ prompt = ChatPromptTemplate.from_messages([
111
+ ("system", SYSTEM_PROMPT),
112
+ ("human", "Context:\n{context}\n\nQuestion: {question}")
113
+ ])
114
+
115
+ chain = (
116
+ {
117
+ "context": retriever | RunnableLambda(format_docs),
118
+ "question": RunnablePassthrough()
119
+ }
120
+ | prompt
121
+ | llm
122
+ | StrOutputParser()
123
+ )
124
+ return chain
125
+
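+ # The dict branch fans the question out: the retriever fetches the top-k
+ # chunks (numbered by format_docs) while RunnablePassthrough forwards the
+ # question string itself. A hypothetical call, for illustration only:
+ #   make_rag_chain(retriever, llm).invoke("When was the handbook last revised?")
+ # returns a plain answer string.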
+
+
+ # --------------------------
+ # UI
+ # --------------------------
+ st.title("📘 Simple QA with Built-in Handbook PDF")
+
+ with st.sidebar:
+     st.header("⚙️ Model Settings")
+     model_id = st.text_input("Model ID", value=DEFAULT_MODEL_ID)
+     temperature = st.slider("Temperature", 0.0, 1.0, DEFAULT_TEMPERATURE, 0.05)
+     max_new_tokens = st.slider("Max new tokens", 32, 1024, DEFAULT_MAX_NEW_TOKENS, 32)
+     chunk_size = st.slider("Chunk size (chars)", 200, 1500, 800, 50)
+     chunk_overlap = st.slider("Chunk overlap (chars)", 0, 400, 120, 10)
+
+
+ # --------------------------
+ # Build Knowledge Base Automatically
+ # --------------------------
+ st.subheader("📚 Knowledge Base")
+ st.info("Using built-in handbook PDF as the knowledge base")
+
+ pdf_path = "USTP Student Handbook 2023 Edition.pdf"  # must be in the same folder
+
+ if not os.path.exists(pdf_path):
+     st.error(f"'{pdf_path}' not found. Please place it in the same folder as this app.")
+ else:
+     with open(pdf_path, "rb") as f:
+         texts = [read_pdf_bytes_to_text(f)]
+
+     # Key the index on the text and the chunking sliders so it is only
+     # rebuilt when one of them changes, not on every Streamlit rerun.
+     kb_hash = f"{compute_texts_hash(texts)}-{chunk_size}-{chunk_overlap}"
+
+     if st.session_state.get("kb_hash") != kb_hash:
+         with st.spinner("Building FAISS index..."):
+             vs = build_faiss_index(texts, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+
+         st.session_state["kb_hash"] = kb_hash
+         st.session_state["vector_store"] = vs
+         st.success("Knowledge base built successfully!")
+
+
+
+ # --------------------------
+ # Conversational Q&A Section
+ # --------------------------
+ st.subheader("💬 Chat with the Student Handbook")
+
+ # Initialize chat history
+ if "messages" not in st.session_state:
+     st.session_state["messages"] = [
+         {"role": "assistant", "content": "Hi 👋! Ask me anything about the student handbook."}
+     ]
+
+ # Display chat bubbles
+ for i, msg in enumerate(st.session_state["messages"]):
+     message(
+         msg["content"],
+         is_user=(msg["role"] == "user"),
+         key=f"{i}_{msg['role']}",
+         avatar_style="big-smile" if msg["role"] == "user" else "bottts"
+     )
+
+ # Input box for user
+ with st.form(key="chat_form", clear_on_submit=True):
+     question = st.text_input(
+         "💬 Type your question:",
+         placeholder="e.g. What are the rules for student discipline?",
+         key="chat_input"
+     )
+     submitted = st.form_submit_button("Send")
+
+ show_sources = st.checkbox("📖 Show retrieved chunks", value=True)
+
+ # Show the chunks retrieved for the most recent question. They are stashed
+ # in session state below, because anything rendered just before st.rerun()
+ # is discarded when the script restarts.
+ if show_sources and st.session_state.get("last_sources"):
+     st.markdown("### 📚 Retrieved Chunks")
+     for i, d in enumerate(st.session_state["last_sources"], start=1):
+         with st.expander(f"Chunk [{i}]"):
+             st.write(d.page_content[:800])
+
+ # Load LLM
+ if "llm" not in st.session_state:
+     with st.spinner("Loading model..."):
+         st.session_state["llm"] = load_llm(model_id, temperature, max_new_tokens)
+
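+ # The pipeline lives in session state, so the sidebar model settings are
+ # read only once per browser session; restart the app to apply a new model
+ # ID, temperature, or token limit.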
+ # Handle user question
+ if submitted and question:
+     st.session_state["messages"].append({"role": "user", "content": question})
+
+     if "vector_store" not in st.session_state:
+         st.warning("Knowledge base not built yet.")
+     else:
+         vs = st.session_state["vector_store"]
+         llm = st.session_state["llm"]
+
+         retriever = vs.as_retriever(search_type="similarity", search_kwargs={"k": 3})
+         chain = make_rag_chain(retriever, llm)
+
+         with st.spinner("Thinking..."):
+             answer = chain.invoke(question)
+
+         st.session_state["messages"].append({"role": "assistant", "content": answer})
+
+         # The LCEL chain only returns the answer string, so re-run the
+         # similarity search to capture the supporting chunks, and stash them
+         # for the display block above to render after the rerun.
+         st.session_state["last_sources"] = vs.similarity_search(question, k=3)
+
+         st.rerun()
+
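+ # Flow per question: append the user turn, run the RAG chain, stash the
+ # answer and supporting chunks in session state, then st.rerun() so the
+ # chat loop above redraws with the new turns included.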
+ # --------------------------
+ # Styling
+ # --------------------------
+ st.markdown("""
+ <style>
+ /* Overall background */
+ .stApp {
+     background-color: #f4f4ea;
+     font-family: 'Segoe UI', sans-serif;
+ }
+
+ /* Sidebar styling */
+ section[data-testid="stSidebar"] {
+     background-color: #e2e1f5;
+     color: black;
+ }
+
+ /* Buttons */
+ div.stButton > button {
+     background-color: #4a4a4a;
+     color: white;
+     border-radius: 8px;
+     font-size: 16px;
+ }
+ div.stButton > button:hover {
+     background-color: #2980b9;
+ }
+ h1, h2, h3 {
+     color: #2c3e50;
+ }
+
+ /* ---- Continuous Chat Background Fix ---- */
+
+ /* Remove vertical gaps between chat messages */
+ [data-testid="stVerticalBlock"] {
+     padding: 0 !important;
+     margin: 0 !important;
+ }
+
+ /* Prevent white padding above chat */
+ div[data-testid="stVerticalBlock"] > div:nth-child(1) {
+     margin-top: 0 !important;
+ }
+
+ /* Chat message bubble styles (note: these target Streamlit's native
+    st.chat_message classes; streamlit_chat bubbles ship their own styles) */
+ [class*="stChatMessage"] {
+     background-color: #f7f7f0 !important;
+     border-radius: 16px;
+     padding: 10px 16px !important;
+     margin-bottom: 4px !important;
+ }
+
+ /* User bubble color */
+ [class*="stChatMessageUser"] {
+     background-color: #e6f0ff !important;
+ }
+
+ /* Assistant bubble color */
+ [class*="stChatMessageAssistant"] {
+     background-color: #f0f0f0 !important;
+ }
+
+ /* Optional: smooth continuous background */
+ .stApp {
+     background: linear-gradient(to bottom, #f4f4ea 0%, #f4f4ea 100%);
+ }
+ </style>
+ """, unsafe_allow_html=True)