Tannuyadav committed on
Commit
af86bb2
·
verified ·
1 Parent(s): b623f28

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +238 -0
app.py CHANGED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import tempfile
4
+ import torch
5
+ from langchain_community.document_loaders import PyPDFLoader
6
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
7
+ from langchain_huggingface import HuggingFaceEmbeddings
8
+ from langchain_community.vectorstores import FAISS
9
+ from langchain_huggingface import HuggingFacePipeline
10
+ from langchain_classic.prompts import PromptTemplate
11
+ from langchain_classic.chains import RetrievalQA
12
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
13
+ from huggingface_hub import login
14
+
15
+
16
# --- Page Config & Styling ---
st.set_page_config(
    page_title="DocTalk - Chat With PDF",
    page_icon="πŸ“—πŸ’¬",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Custom CSS: chat-input padding, a fixed footer bar, hidden Streamlit
# branding, and extra sidebar padding so the footer never covers content.
_CUSTOM_CSS = """
<style>
/* Chat styling */
.stChatInput {
    padding-bottom: 1rem;
}

/* Custom Footer */
.footer {
    position: fixed;
    left: 0;
    bottom: 0;
    width: 100%;
    background-color: white;
    color: #555;
    text-align: center;
    padding: 10px;
    font-size: 14px;
    border-top: 1px solid #eee;
    z-index: 100;
}

/* Hide Streamlit branding for cleaner look */
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}

/* Adjust sidebar padding for footer */
[data-testid="stSidebar"] {
    padding-bottom: 50px;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
57
+
58
# --- Session State Management ---
# Seed the keys that must survive Streamlit reruns: the QA chain, the chat
# transcript, and the "document processed" flag.
for _key, _default in (("qa_chain", None), ("messages", []), ("processing_done", False)):
    if _key not in st.session_state:
        st.session_state[_key] = _default

# --- Authentication (Secrets Only) ---
# The Hugging Face token comes from the environment (Space repository secret).
hf_token = os.environ.get("HF_TOKEN")
65
+
66
+ # --- Model Loading (Cached & CPU Optimized) ---
67
+
68
@st.cache_resource
def load_embedding_model():
    """Load the sentence-embedding model once per process (cached).

    Returns:
        A HuggingFaceEmbeddings instance, or None if loading fails
        (the failure is reported in the UI via st.error).
    """
    try:
        # Lightweight, fast sentence-transformer for document embeddings.
        return HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    except Exception as e:
        st.error(f"Error loading embedding model: {e}")
        return None
78
+
79
@st.cache_resource
def load_llm_model(token):
    """Load the Gemma LLM once (cached) and wrap it in a generation pipeline.

    Args:
        token: Hugging Face access token with access to the gated
            google/gemma-2-2b-it repository.

    Returns:
        A transformers text-generation pipeline, or None on failure.
    """
    try:
        login(token=token)
        model_id = "google/gemma-2-2b-it"

        tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)

        # Load model to CPU (float32 is safe for CPU stability)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="cpu",
            torch_dtype=torch.float32,
            token=token
        )

        pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=512,
            temperature=0.1,
            repetition_penalty=1.1,
            return_full_text=False
        )
        return pipe
    except Exception as e:
        # BUG FIX: the original swallowed the exception and returned None
        # without any diagnostics, so callers could only show a generic
        # "failed to load" message. Surface the actual error in the UI.
        st.error(f"Error loading LLM: {e}")
        return None
108
+
109
# --- PDF Processing ---
def process_document(uploaded_file, model_pipeline, embedding_model):
    """Build a RetrievalQA chain over the uploaded PDF.

    Args:
        uploaded_file: Streamlit UploadedFile holding the PDF bytes.
        model_pipeline: transformers text-generation pipeline (the LLM).
        embedding_model: embedding model used to vectorize the chunks.

    Returns:
        A RetrievalQA chain with source documents enabled, or None on error
        (the error is shown in the UI).
    """
    tmp_path = None
    try:
        # Save temp file (PyPDFLoader requires a filesystem path).
        with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
            tmp.write(uploaded_file.getvalue())
            tmp_path = tmp.name

        # Load & Split
        loader = PyPDFLoader(tmp_path)
        docs = loader.load()
        splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        chunks = splitter.split_documents(docs)

        # Vector Store (FAISS is faster for in-memory)
        vector_store = FAISS.from_documents(chunks, embedding_model)

        # Chain Setup
        llm = HuggingFacePipeline(pipeline=model_pipeline)

        # Gemma chat format: the model turn is left open for generation.
        template = """<start_of_turn>user
Answer the question based strictly on the context below. Keep answers concise.
Context: {context}
Question: {question}<end_of_turn>
<start_of_turn>model
"""
        prompt = PromptTemplate(template=template, input_variables=["context", "question"])

        qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            retriever=vector_store.as_retriever(search_kwargs={"k": 3}),
            chain_type_kwargs={"prompt": prompt},
            return_source_documents=True
        )
        return qa_chain
    except Exception as e:
        st.error(f"Error processing PDF: {e}")
        return None
    finally:
        # BUG FIX: the original created the file with delete=False and never
        # removed it, leaking a temp file for every processed PDF. The chunks
        # are already in memory, so the file can go as soon as we're done.
        if tmp_path and os.path.exists(tmp_path):
            os.unlink(tmp_path)
147
+
148
# --- Main Layout ---

# 1. Sidebar Configuration
with st.sidebar:
    st.title("πŸ€– Configuration")
    st.markdown("---")

    # Guard clause: without a token the gated Gemma model cannot be loaded.
    if not hf_token:
        st.error("🚨 **HF_TOKEN missing!**")
        st.info("Go to Space Settings -> Repository Secrets and add your Hugging Face Access Token as `HF_TOKEN`.")
        st.stop()
    st.success("βœ… Huggingface Active")

    st.subheader("πŸ“„ Document Upload")
    uploaded_file = st.file_uploader("Upload your PDF", type="pdf", help="Max file size ~200MB")

    # Only offer processing once a file is present; run it on click.
    if uploaded_file and st.button("πŸš€ Process Document", type="primary", use_container_width=True):
        with st.spinner("🧠 Analyzing PDF"):
            # Load models (cached)
            llm_pipeline = load_llm_model(hf_token)
            embed_model = load_embedding_model()

            if not (llm_pipeline and embed_model):
                st.error("Failed to load AI models. Check token permissions.")
            else:
                qa_chain = process_document(uploaded_file, llm_pipeline, embed_model)
                if qa_chain:
                    st.session_state.qa_chain = qa_chain
                    st.session_state.processing_done = True
                    st.success("Done! You can now chat.")
                else:
                    st.error("Failed to process document.")

    # Offer a reset once a document has been processed.
    if st.session_state.processing_done:
        st.markdown("---")
        if st.button("πŸ—‘οΈ Clear Chat History", use_container_width=True):
            st.session_state.messages = []
            st.rerun()
190
+
191
# 2. Main Chat Area
st.title("πŸ“—πŸ’¬ DocTalk - Chat With PDF")
#st.caption("Powered by Google Gemma-2-2B-IT")

if not st.session_state.processing_done:
    # Empty state: nothing processed yet, show onboarding instructions.
    st.info("πŸ‘‹ **Welcome!** Please upload a PDF in the sidebar to begin chatting.")
    st.markdown("""
**How it works:**
1. Upload a PDF document.
2. Click 'Process Document'.
3. Ask questions and get answers based strictly on your file.
""")
else:
    # Replay the stored transcript so history survives reruns.
    for msg in st.session_state.messages:
        with st.chat_message(msg["role"]):
            st.markdown(msg["content"])

    # Accept a new question and answer it through the QA chain.
    if user_input := st.chat_input("Ask a question about your document..."):
        st.session_state.messages.append({"role": "user", "content": user_input})
        with st.chat_message("user"):
            st.markdown(user_input)

        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                try:
                    response = st.session_state.qa_chain.invoke({"query": user_input})
                    answer = response['result']

                    st.markdown(answer)
                    st.session_state.messages.append({"role": "assistant", "content": answer})

                    # Optional: show the retrieved chunks behind the answer.
                    with st.expander("πŸ”Ž View Source Context"):
                        for doc in response['source_documents']:
                            st.caption(f"Page {doc.metadata.get('page', '?')}: {doc.page_content[:200]}...")
                except Exception as e:
                    st.error(f"An error occurred: {e}")
232
+
233
# --- Footer ---
# Fixed attribution bar styled by the .footer CSS class injected above.
_FOOTER_HTML = """
<div class="footer">
Made with ❀️ with Streamlit and Gemma model, by Tannu Yadav
</div>
"""
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)