Update app.py

#17
by Muthuraja18 - opened
Files changed (1) hide show
  1. app.py +72 -93
app.py CHANGED
@@ -7,49 +7,51 @@ from langchain_community.document_loaders import PyPDFLoader, TextLoader
7
  from langchain_text_splitters import RecursiveCharacterTextSplitter
8
  from langchain_community.embeddings import HuggingFaceEmbeddings
9
  from langchain_community.vectorstores import FAISS
10
- from langchain_community.llms import HuggingFacePipeline
11
  from langchain.chains import RetrievalQA
12
 
13
- # Transformers
14
- from transformers import pipeline
 
15
 
16
  # Charts
17
  import plotly.express as px
18
 
19
  # -------------------------------
20
- # PAGE CONFIG
21
  # -------------------------------
22
- st.set_page_config(page_title="RAG + Analytics", layout="wide")
23
- st.title("📄 RAG Chatbot + 📊 Analytics Dashboard")
24
 
25
  # -------------------------------
26
- # CACHE (VERY IMPORTANT ⚡)
27
  # -------------------------------
28
  @st.cache_resource
29
  def load_llm():
 
 
 
 
 
30
  pipe = pipeline(
31
  "text2text-generation",
32
- model="google/flan-t5-base",
 
33
  max_length=512
34
  )
35
- return HuggingFacePipeline(pipeline=pipe)
36
 
37
- @st.cache_resource
38
- def load_embeddings():
39
- return HuggingFaceEmbeddings(
40
- model_name="sentence-transformers/all-MiniLM-L6-v2"
41
- )
42
 
43
  # -------------------------------
44
- # LOAD DOCUMENTS
45
  # -------------------------------
46
- def load_documents(files):
47
  docs = []
48
  stats = []
49
 
 
 
50
  for file in files:
51
  path = os.path.join("temp", file.name)
52
- os.makedirs("temp", exist_ok=True)
53
 
54
  with open(path, "wb") as f:
55
  f.write(file.getbuffer())
@@ -73,11 +75,11 @@ def load_documents(files):
73
  return docs, pd.DataFrame(stats)
74
 
75
  # -------------------------------
76
- # SPLIT DOCUMENTS
77
  # -------------------------------
78
  def split_docs(docs):
79
  splitter = RecursiveCharacterTextSplitter(
80
- chunk_size=500,
81
  chunk_overlap=50
82
  )
83
  return splitter.split_documents(docs)
@@ -85,30 +87,43 @@ def split_docs(docs):
85
  # -------------------------------
86
  # VECTOR STORE
87
  # -------------------------------
 
 
 
 
 
 
88
  def create_vectorstore(chunks):
89
- embeddings = load_embeddings()
90
- return FAISS.from_documents(chunks, embeddings)
91
 
92
  # -------------------------------
93
- # QA CHAIN
94
  # -------------------------------
95
  def build_qa(vs):
96
  llm = load_llm()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  return RetrievalQA.from_chain_type(
98
  llm=llm,
99
- retriever=vs.as_retriever()
 
100
  )
101
 
102
  # -------------------------------
103
- # FILE UPLOAD
104
- # -------------------------------
105
- files = st.file_uploader(
106
- "Upload PDF / TXT files",
107
- accept_multiple_files=True
108
- )
109
-
110
- # -------------------------------
111
- # SESSION STATE
112
  # -------------------------------
113
  if "qa" not in st.session_state:
114
  st.session_state.qa = None
@@ -117,92 +132,56 @@ if "history" not in st.session_state:
117
  st.session_state.history = []
118
 
119
  # -------------------------------
120
- # PROCESS FILES
 
 
 
 
 
121
  # -------------------------------
122
  if files and st.session_state.qa is None:
123
- with st.spinner("Processing documents..."):
124
- docs, df = load_documents(files)
125
  chunks = split_docs(docs)
126
  vs = create_vectorstore(chunks)
127
  qa = build_qa(vs)
128
 
129
  st.session_state.qa = qa
130
  st.session_state.df = df
131
- st.session_state.chunk_count = len(chunks)
132
  st.session_state.doc_count = len(docs)
 
133
 
134
- st.success("✅ Documents processed!")
135
 
136
  # -------------------------------
137
  # DASHBOARD
138
  # -------------------------------
139
  if st.session_state.qa:
140
-
141
- st.subheader("📊 Analytics Dashboard")
142
 
143
  df = st.session_state.df
144
 
145
- col1, col2, col3 = st.columns(3)
 
146
 
147
- col1.metric("📄 Total Documents", st.session_state.doc_count)
148
- col2.metric("🧩 Total Chunks", st.session_state.chunk_count)
149
- col3.metric("📁 Files Uploaded", len(df))
150
-
151
- # ---- Bar Chart ----
152
- fig1 = px.bar(
153
- df,
154
- x="File",
155
- y="Pages",
156
- color="Type",
157
- title="Pages per File"
158
- )
159
- st.plotly_chart(fig1, use_container_width=True)
160
-
161
- # ---- Pie Chart ----
162
- fig2 = px.pie(
163
- df,
164
- names="Type",
165
- title="File Type Distribution"
166
- )
167
- st.plotly_chart(fig2, use_container_width=True)
168
-
169
- # ---- Line Chart ----
170
- growth_df = pd.DataFrame({
171
- "Stage": ["Documents", "Chunks"],
172
- "Count": [st.session_state.doc_count, st.session_state.chunk_count]
173
- })
174
-
175
- fig3 = px.line(
176
- growth_df,
177
- x="Stage",
178
- y="Count",
179
- markers=True,
180
- title="Processing Growth"
181
- )
182
- st.plotly_chart(fig3, use_container_width=True)
183
 
184
  # -------------------------------
185
- # CHATBOT
186
  # -------------------------------
187
- st.subheader("🤖 Chat with Documents")
188
-
189
- query = st.text_input("Ask your question...")
190
 
191
  if query and st.session_state.qa:
192
- with st.spinner("Thinking..."):
193
- result = st.session_state.qa.invoke({"query": query})
194
- answer = result["result"]
195
 
196
- # Save history
197
- st.session_state.history.append((query, answer))
198
 
199
  # -------------------------------
200
- # CHAT HISTORY
201
  # -------------------------------
202
- if st.session_state.history:
203
- st.subheader("💬 Chat History")
204
-
205
- for q, a in reversed(st.session_state.history):
206
- st.markdown(f"**🧑 Question:** {q}")
207
- st.markdown(f"**🤖 Answer:** {a}")
208
- st.markdown("---")
 
7
  from langchain_text_splitters import RecursiveCharacterTextSplitter
8
  from langchain_community.embeddings import HuggingFaceEmbeddings
9
  from langchain_community.vectorstores import FAISS
 
10
  from langchain.chains import RetrievalQA
11
 
12
+ # Local LLM
13
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
14
+ from langchain_community.llms import HuggingFacePipeline
15
 
16
  # Charts
17
  import plotly.express as px
18
 
19
  # -------------------------------
20
+ # CONFIG
21
  # -------------------------------
22
+ st.set_page_config(page_title="Offline GPT RAG", layout="wide")
23
+ st.title("🤖 Offline ChatGPT-like RAG + 📊 Dashboard")
24
 
25
  # -------------------------------
26
+ # CACHE MODEL (IMPORTANT ⚡)
27
  # -------------------------------
28
  @st.cache_resource
29
  def load_llm():
30
+ model_name = "google/flan-t5-base"
31
+
32
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
33
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
34
+
35
  pipe = pipeline(
36
  "text2text-generation",
37
+ model=model,
38
+ tokenizer=tokenizer,
39
  max_length=512
40
  )
 
41
 
42
+ return HuggingFacePipeline(pipeline=pipe)
 
 
 
 
43
 
44
  # -------------------------------
45
+ # LOAD DOCS
46
  # -------------------------------
47
+ def load_docs(files):
48
  docs = []
49
  stats = []
50
 
51
+ os.makedirs("temp", exist_ok=True)
52
+
53
  for file in files:
54
  path = os.path.join("temp", file.name)
 
55
 
56
  with open(path, "wb") as f:
57
  f.write(file.getbuffer())
 
75
  return docs, pd.DataFrame(stats)
76
 
77
  # -------------------------------
78
+ # SPLIT
79
  # -------------------------------
80
  def split_docs(docs):
81
  splitter = RecursiveCharacterTextSplitter(
82
+ chunk_size=400,
83
  chunk_overlap=50
84
  )
85
  return splitter.split_documents(docs)
 
87
  # -------------------------------
88
  # VECTOR STORE
89
  # -------------------------------
90
+ @st.cache_resource
91
+ def load_embeddings():
92
+ return HuggingFaceEmbeddings(
93
+ model_name="sentence-transformers/all-MiniLM-L6-v2"
94
+ )
95
+
96
  def create_vectorstore(chunks):
97
+ return FAISS.from_documents(chunks, load_embeddings())
 
98
 
99
  # -------------------------------
100
+ # QA CHAIN (BETTER PROMPT)
101
  # -------------------------------
102
  def build_qa(vs):
103
  llm = load_llm()
104
+
105
+ prompt_template = """
106
+ You are an intelligent assistant.
107
+ Answer ONLY from the provided context.
108
+ If the answer is not in the context, say "Not found in document".
109
+
110
+ Context:
111
+ {context}
112
+
113
+ Question:
114
+ {question}
115
+
116
+ Answer:
117
+ """
118
+
119
  return RetrievalQA.from_chain_type(
120
  llm=llm,
121
+ retriever=vs.as_retriever(search_kwargs={"k": 3}),
122
+ chain_type_kwargs={"prompt": prompt_template}
123
  )
124
 
125
  # -------------------------------
126
+ # SESSION
 
 
 
 
 
 
 
 
127
  # -------------------------------
128
  if "qa" not in st.session_state:
129
  st.session_state.qa = None
 
132
  st.session_state.history = []
133
 
134
  # -------------------------------
135
+ # UPLOAD
136
+ # -------------------------------
137
+ files = st.file_uploader("Upload PDF/TXT", accept_multiple_files=True)
138
+
139
+ # -------------------------------
140
+ # PROCESS
141
  # -------------------------------
142
  if files and st.session_state.qa is None:
143
+ with st.spinner("Processing..."):
144
+ docs, df = load_docs(files)
145
  chunks = split_docs(docs)
146
  vs = create_vectorstore(chunks)
147
  qa = build_qa(vs)
148
 
149
  st.session_state.qa = qa
150
  st.session_state.df = df
 
151
  st.session_state.doc_count = len(docs)
152
+ st.session_state.chunk_count = len(chunks)
153
 
154
+ st.success("✅ Ready!")
155
 
156
  # -------------------------------
157
  # DASHBOARD
158
  # -------------------------------
159
  if st.session_state.qa:
160
+ st.subheader("📊 Analytics")
 
161
 
162
  df = st.session_state.df
163
 
164
+ st.metric("Docs", st.session_state.doc_count)
165
+ st.metric("Chunks", st.session_state.chunk_count)
166
 
167
+ st.plotly_chart(px.bar(df, x="File", y="Pages", color="Type"))
168
+ st.plotly_chart(px.pie(df, names="Type"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
  # -------------------------------
171
+ # CHAT
172
  # -------------------------------
173
+ query = st.text_input("Ask your question")
 
 
174
 
175
  if query and st.session_state.qa:
176
+ result = st.session_state.qa.invoke({"query": query})
177
+ answer = result["result"]
 
178
 
179
+ st.session_state.history.append((query, answer))
 
180
 
181
  # -------------------------------
182
+ # HISTORY
183
  # -------------------------------
184
+ for q, a in reversed(st.session_state.history):
185
+ st.markdown(f"**Q:** {q}")
186
+ st.markdown(f"**A:** {a}")
187
+ st.markdown("---")