Himanshu kumar Vishwakrma committed on
Commit
968023b
·
1 Parent(s): 462faf7

HF Spaces compatible version

Browse files
Files changed (3) hide show
  1. Dockerfile +10 -16
  2. requirements.txt +11 -6
  3. src/streamlit_app.py +52 -229
Dockerfile CHANGED
@@ -1,21 +1,15 @@
1
- FROM python:3.9-slim
2
 
3
  WORKDIR /app
 
4
 
5
- RUN apt-get update && apt-get install -y \
6
- build-essential \
7
- curl \
8
- software-properties-common \
9
- git \
10
- && rm -rf /var/lib/apt/lists/*
11
 
12
- COPY requirements.txt ./
13
- COPY src/ ./src/
14
 
15
- RUN pip3 install -r requirements.txt
16
-
17
- EXPOSE 8501
18
-
19
- HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
20
-
21
- ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
 
1
# Hugging Face Spaces currently supports up to Python 3.10.
# NOTE: Dockerfile comments must start at the beginning of a line — an inline
# "# ..." after FROM is parsed as extra instruction arguments and fails the build.
FROM python:3.10-slim

WORKDIR /app
COPY . .

# gcc/python3-dev are needed to build native wheels; drop the apt cache
# afterwards to keep the image small.
RUN apt-get update && apt-get install -y --no-install-recommends gcc python3-dev && \
    rm -rf /var/lib/apt/lists/* && \
    pip install --upgrade pip && \
    pip install -r requirements.txt --no-cache-dir && \
    python -m spacy download en_core_web_sm && \
    python -m nltk.downloader punkt wordnet

# HF Spaces expects the app on port 7860.
ENV STREAMLIT_SERVER_PORT=7860
EXPOSE 7860

CMD ["streamlit", "run", "src/streamlit_app.py", "--server.port=7860", "--server.address=0.0.0.0"]
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,6 +1,11 @@
1
- altair
2
- pandas
3
- streamlit
4
- pypdf
5
- docx
6
- chromadb
 
 
 
 
 
 
1
+ streamlit==1.31.0
2
+ pypdf==4.2.0
3
+ python-docx==1.1.0 # Replaces 'docx' which causes the exceptions error
4
+ chromadb==0.4.24
5
+ sentence-transformers==2.6.0
6
+ transformers==4.38.2
7
+ torch==2.2.1
8
+ accelerate==0.29.3
9
+ huggingface-hub==0.22.2
10
+ spacy==3.7.4
11
+ nltk==3.8.1
src/streamlit_app.py CHANGED
@@ -1,293 +1,116 @@
1
  import streamlit as st
2
  from pypdf import PdfReader
3
  from docx import Document
4
- import os
5
- import time
6
  import chromadb
7
  from chromadb.utils import embedding_functions
8
- from typing import List, Tuple
9
- from transformers import AutoTokenizer, AutoModelForCausalLM
10
- import torch
11
-
12
- # Initialize ChromaDB
13
- client = chromadb.PersistentClient(path="./chroma_db")
14
- sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")
15
-
16
- try:
17
- collection = client.get_collection(name="documents", embedding_function=sentence_transformer_ef)
18
- except:
19
- collection = client.create_collection(name="documents", embedding_function=sentence_transformer_ef)
20
 
21
- # Initialize Hugging Face model and tokenizer
22
- @st.cache_resource
23
- def load_model():
24
- model_name = "google/gemma-1.1-7b-it" # Using the 7B instruct-tuned version
25
- tokenizer = AutoTokenizer.from_pretrained(model_name)
26
- model = AutoModelForCausalLM.from_pretrained(
27
- model_name,
28
- device_map="auto",
29
- torch_dtype=torch.float16
30
- )
31
- return model, tokenizer
32
 
33
- model, tokenizer = load_model()
 
34
 
35
  def chunk_text(text: str, chunk_size: int = 1000) -> List[str]:
36
- """Split text into chunks of approximately chunk_size characters"""
37
  chunks = []
38
  start = 0
39
  while start < len(text):
40
  end = min(start + chunk_size, len(text))
41
- # Try to split at sentence boundary
42
  if end < len(text):
43
  while end > start and text[end] not in {'.', '!', '?', '\n'}:
44
  end -= 1
45
- if end == start: # No sentence boundary found
46
  end = start + chunk_size
47
  chunks.append(text[start:end].strip())
48
  start = end
49
  return chunks
50
 
51
- def process_document(uploaded_file, progress_bar=None, status_text=None):
52
- """Extract text from document and store in ChromaDB with progress tracking"""
53
  text = ""
54
-
55
- # Update status
56
- if status_text:
57
- status_text.text(f"Extracting text from {uploaded_file.name}...")
58
-
59
  if uploaded_file.type == "application/pdf":
60
  reader = PdfReader(uploaded_file)
61
- total_pages = len(reader.pages)
62
- for i, page in enumerate(reader.pages):
63
- text += page.extract_text()
64
- if progress_bar:
65
- progress_bar.progress((i + 1) / (total_pages * 2)) # First half is for extraction
66
-
67
  elif uploaded_file.type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
68
  doc = Document(uploaded_file)
69
- total_paras = len(doc.paragraphs)
70
- for i, para in enumerate(doc.paragraphs):
71
- text += para.text + "\n"
72
- if progress_bar:
73
- progress_bar.progress((i + 1) / (total_paras * 2)) # First half is for extraction
74
-
75
  elif uploaded_file.type == "text/plain":
76
  text = str(uploaded_file.read(), "utf-8")
77
- if progress_bar:
78
- progress_bar.progress(0.5) # Mark extraction as 50% complete
79
-
80
- # Update status
81
- if status_text:
82
- status_text.text(f"Chunking and storing {uploaded_file.name} in database...")
83
 
84
- # Split text into chunks
85
  chunks = chunk_text(text)
86
-
87
- # Store in ChromaDB
88
  ids = [f"{uploaded_file.name}-{i}" for i in range(len(chunks))]
89
-
90
- # Add chunks in batches for smoother progress updates
91
- batch_size = max(1, len(chunks) // 10) # Create 10 progress updates
92
- for i in range(0, len(chunks), batch_size):
93
- end_idx = min(i + batch_size, len(chunks))
94
- collection.add(
95
- documents=chunks[i:end_idx],
96
- ids=ids[i:end_idx],
97
- metadatas=[{"source": uploaded_file.name} for _ in range(i, end_idx)]
98
- )
99
- if progress_bar:
100
- # Calculate progress for second half (storage)
101
- extraction_half = 0.5 # First 50% was for extraction
102
- storage_progress = (end_idx / len(chunks)) * 0.5 # Second 50% for storage
103
- progress_bar.progress(extraction_half + storage_progress)
104
-
105
- # Complete the progress
106
- if progress_bar:
107
- progress_bar.progress(1.0)
108
- if status_text:
109
- status_text.text(f"Completed processing {uploaded_file.name}")
110
-
111
  return len(chunks)
112
 
113
- @st.cache_data(ttl=300) # Cache results for 5 minutes
114
- def retrieve_relevant_chunks(query: str, k: int = 5) -> Tuple[List[str], List[str]]:
115
- """Retrieve relevant document chunks from ChromaDB with caching for performance"""
116
  results = collection.query(
117
  query_texts=[query],
118
  n_results=k
119
  )
120
  return results['documents'][0], results['metadatas'][0]
121
 
122
- @st.cache_data(ttl=60, show_spinner=False) # Cache for 1 minute
123
- def generate_response(query: str, context: str, temp: float = 0.7) -> str:
124
- """Generate response using Hugging Face Gemma with RAG context and caching"""
125
- prompt = f"""Use the following context to answer the question. If you don't know the answer, say you don't know.
126
-
127
- Context:
128
- {context}
129
-
130
- Question: {query}
131
-
132
- Answer:"""
133
-
134
- # Tokenize the input
135
- input_ids = tokenizer(prompt, return_tensors="pt").to(model.device)
136
-
137
- # Generate response
138
- with torch.no_grad():
139
- outputs = model.generate(
140
- **input_ids,
141
- max_new_tokens=512,
142
- temperature=temp,
143
- do_sample=True if temp > 0 else False,
144
- top_k=50,
145
- top_p=0.95
146
- )
147
-
148
- # Decode the response
149
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
150
-
151
- # Remove the input prompt from the response
152
- response = response[len(prompt):].strip()
153
- return response
154
-
155
- # Initialize session states
156
- if "messages" not in st.session_state:
157
- st.session_state.messages = []
158
- if "uploaded_files" not in st.session_state:
159
- st.session_state.uploaded_files = []
160
-
161
- # Initialize performance tracking
162
- if "performance_metrics" not in st.session_state:
163
- st.session_state.performance_metrics = {
164
- "total_queries": 0,
165
- "avg_response_time": 0,
166
- "last_response_time": 0
167
- }
168
 
169
- # App title
170
  st.title("📄 Document Q&A Assistant")
171
 
172
- # Sidebar for document upload
173
  with st.sidebar:
174
- st.header("Document Management")
175
  uploaded_files = st.file_uploader(
176
- "Upload documents",
177
  type=["pdf", "docx", "txt"],
178
  accept_multiple_files=True
179
  )
180
 
181
- st.markdown("---")
182
- st.header("Settings")
183
- temperature = st.slider("Temperature", 0.0, 1.0, 0.7)
184
- st.markdown("ℹ️ All processing happens locally")
185
-
186
- if uploaded_files and st.button("Process Documents", use_container_width=True):
187
- progress_container = st.container()
188
-
189
- with progress_container:
190
- st.markdown("### Processing Documents")
191
- progress_bar = st.progress(0)
192
- status_text = st.empty()
193
-
194
- st.markdown("**Progress Metrics**")
195
- metric_col1, metric_col2 = st.columns(2)
196
- total_chunks_metric = metric_col1.empty()
197
- eta_metric = metric_col2.empty()
198
-
199
- start_time = time.time()
200
- total_chunks = 0
201
- files_processed = 0
202
-
203
- for uploaded_file in uploaded_files:
204
- if uploaded_file.name not in st.session_state.uploaded_files:
205
- status_text.text(f"Starting to process {uploaded_file.name}...")
206
-
207
- chunks_count = process_document(uploaded_file, progress_bar, status_text)
208
- total_chunks += chunks_count
209
- files_processed += 1
210
-
211
- elapsed = time.time() - start_time
212
- eta = (elapsed / files_processed) * (len(uploaded_files) - files_processed) if files_processed > 0 else 0
213
- total_chunks_metric.metric("Chunks Created", f"{total_chunks}")
214
- eta_metric.metric("Time Remaining", f"{eta:.1f}s")
215
-
216
- st.session_state.uploaded_files.append(uploaded_file.name)
217
-
218
- progress_bar.progress(1.0)
219
- status_text.text("✅ Processing completed!")
220
-
221
- st.success(f"Successfully processed {files_processed} document(s) into {total_chunks} searchable chunks.")
222
- st.balloons()
223
- st.markdown("### 🎉 Your documents are now ready!")
224
- st.markdown("You can start asking questions about your documents in the chat below.")
225
 
226
- # Display chat messages
227
  for message in st.session_state.messages:
228
  with st.chat_message(message["role"]):
229
  st.markdown(message["content"])
230
 
231
- # Optional: Display performance metrics in an expandable section
232
- with st.sidebar:
233
- if st.session_state.performance_metrics["total_queries"] > 0:
234
- with st.expander("Performance Metrics"):
235
- st.metric("Average Response Time", f"{st.session_state.performance_metrics['avg_response_time']:.2f} seconds")
236
- st.metric("Last Response Time", f"{st.session_state.performance_metrics['last_response_time']:.2f} seconds")
237
- st.metric("Total Queries", f"{st.session_state.performance_metrics['total_queries']}")
238
-
239
- # Chat input
240
- if prompt := st.chat_input("Ask about your documents..."):
241
- query_start_time = time.time()
242
-
243
  st.session_state.messages.append({"role": "user", "content": prompt})
244
 
245
  with st.chat_message("user"):
246
  st.markdown(prompt)
247
 
248
  with st.chat_message("assistant"):
249
- message_placeholder = st.empty()
250
- full_response = ""
251
-
252
- with st.status("Searching documents for relevant information...", expanded=True) as status:
253
- st.write("🔍 Finding relevant information...")
254
- chunks, metadata = retrieve_relevant_chunks(prompt)
255
  context = "\n\n".join(chunks)
256
 
 
 
257
  sources = list(set([m['source'] for m in metadata]))
258
- st.write(f"📚 Found information in {len(sources)} document(s)")
259
 
260
- st.write("💭 Generating response...")
261
- response = generate_response(prompt, context, temp=temperature)
262
- status.update(label="✅ Answer ready!", state="complete", expanded=False)
263
-
264
- words = response.split()
265
- total_words = len(words)
266
- update_frequency = max(1, total_words // 20)
267
-
268
- for i in range(0, total_words, update_frequency):
269
- end_idx = min(i + update_frequency, total_words)
270
- full_response += " ".join(words[i:end_idx]) + " "
271
- message_placeholder.markdown(full_response + "▌")
272
- time.sleep(0.01)
273
-
274
- if sources:
275
- full_response += f"\n\nSources: {', '.join(sources)}"
276
-
277
- message_placeholder.markdown(full_response)
278
-
279
- st.session_state.messages.append({"role": "assistant", "content": full_response})
280
-
281
- end_time = time.time()
282
- query_time = end_time - query_start_time
283
-
284
- st.session_state.performance_metrics["total_queries"] += 1
285
- st.session_state.performance_metrics["last_response_time"] = query_time
286
-
287
- prev_avg = st.session_state.performance_metrics["avg_response_time"]
288
- prev_count = st.session_state.performance_metrics["total_queries"] - 1
289
 
290
- if prev_count > 0:
291
- st.session_state.performance_metrics["avg_response_time"] = (prev_avg * prev_count + query_time) / st.session_state.performance_metrics["total_queries"]
292
- else:
293
- st.session_state.performance_metrics["avg_response_time"] = query_time
 
1
import os
import time
from typing import List, Tuple

import chromadb
import streamlit as st
from chromadb.utils import embedding_functions
from docx import Document
from huggingface_hub import InferenceClient
from pypdf import PdfReader
 
 
 
 
 
 
 
 
 
9
 
10
# Initialize ChromaDB (ephemeral for HF Spaces — Spaces containers have no
# persistent disk, so an in-memory client avoids write-permission errors;
# the index is rebuilt from uploads on every restart).
client = chromadb.EphemeralClient()
# Embedding function used for both indexing and querying; must be the same
# model for both or similarity search is meaningless.
sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="all-MiniLM-L6-v2"
)
# get_or_create avoids the try/except create-vs-get dance on reruns.
collection = client.get_or_create_collection(
    name="documents",
    embedding_function=sentence_transformer_ef
)

# Initialize HF Inference Client — generation runs on the hosted Inference
# API instead of loading the model locally (no GPU needed in the Space).
hf_client = InferenceClient(model="google/gemma-2b-it")
22
 
23
def chunk_text(text: str, chunk_size: int = 1000) -> list[str]:
    """Split *text* into chunks of at most ``chunk_size`` characters.

    Chunks are cut at the nearest sentence boundary (., !, ?, or newline)
    at or before the size limit; if no boundary exists in the window, a
    hard cut at ``chunk_size`` is made. Whitespace is stripped from each
    chunk and empty chunks are dropped.

    Note: builtin generic ``list[str]`` is used for the return annotation
    (PEP 585, fine on Python >= 3.9) so no ``typing`` import is required.
    """
    chunks = []
    start = 0
    while start < len(text):
        end = min(start + chunk_size, len(text))
        if end < len(text):
            # Back up to a sentence boundary so we don't cut mid-sentence.
            while end > start and text[end] not in {'.', '!', '?', '\n'}:
                end -= 1
            if end == start:
                # No boundary in this window — fall back to a hard cut.
                end = start + chunk_size
            else:
                # Include the boundary character in this chunk instead of
                # letting it leak to the start of the next one.
                end += 1
        chunk = text[start:end].strip()
        if chunk:
            chunks.append(chunk)
        start = end
    return chunks
36
 
37
def process_document(uploaded_file):
    """Extract text from an uploaded PDF/DOCX/TXT file, chunk it, and index
    the chunks in the module-level ChromaDB ``collection``.

    Returns the number of chunks stored (0 if nothing extractable).
    Unsupported MIME types yield empty text and therefore return 0.
    """
    text = ""
    if uploaded_file.type == "application/pdf":
        reader = PdfReader(uploaded_file)
        # extract_text() may return None for pages with no extractable text
        # (e.g. scanned images) — guard so join() doesn't raise TypeError.
        text = "\n".join([(page.extract_text() or "") for page in reader.pages])
    elif uploaded_file.type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
        doc = Document(uploaded_file)
        text = "\n".join([para.text for para in doc.paragraphs])
    elif uploaded_file.type == "text/plain":
        text = str(uploaded_file.read(), "utf-8")

    chunks = chunk_text(text)
    if not chunks:
        # Nothing to index — avoid calling collection.add with empty lists.
        return 0
    # IDs are derived from the file name so chunks can be traced to sources.
    ids = [f"{uploaded_file.name}-{i}" for i in range(len(chunks))]
    collection.add(
        documents=chunks,
        ids=ids,
        metadatas=[{"source": uploaded_file.name} for _ in chunks]
    )
    return len(chunks)
56
 
57
def retrieve_chunks(query: str, k: int = 3) -> tuple[list[str], list[str]]:
    """Return the top-*k* most similar document chunks and their metadata
    for *query*, from the module-level ChromaDB ``collection``.

    Builtin generic annotations (PEP 585) are used so the signature does
    not depend on ``typing.Tuple``/``typing.List`` being imported.
    """
    results = collection.query(
        query_texts=[query],
        n_results=k
    )
    # Chroma returns per-query lists; [0] unwraps our single query.
    return results['documents'][0], results['metadatas'][0]
63
 
64
def generate_response(query: str, context: str) -> str:
    """Answer *query* with the hosted Gemma model, grounded in *context*."""
    # Assemble the RAG prompt: retrieved context first, then the question.
    rag_prompt = (
        f"Context: {context}\n\n"
        f"Question: {query}\n"
        "Answer:"
    )
    completion = hf_client.text_generation(
        rag_prompt,
        max_new_tokens=512,
        temperature=0.7
    )
    return completion
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
# Streamlit UI
st.title("📄 Document Q&A Assistant")

with st.sidebar:
    st.header("Upload Documents")
    uploaded_files = st.file_uploader(
        "Choose files",
        type=["pdf", "docx", "txt"],
        accept_multiple_files=True
    )

    if uploaded_files:
        # Streamlit re-runs this script on every interaction; track which
        # files were already indexed so we don't call collection.add with
        # duplicate ids (which raises) or re-embed the same document.
        if "processed_files" not in st.session_state:
            st.session_state.processed_files = set()
        with st.spinner("Processing documents..."):
            for file in uploaded_files:
                if file.name in st.session_state.processed_files:
                    continue
                chunks = process_document(file)
                st.session_state.processed_files.add(file.name)
                st.success(f"Processed {file.name} into {chunks} chunks")

if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay the chat history on every rerun.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

if prompt := st.chat_input("Ask about your documents"):
    st.session_state.messages.append({"role": "user", "content": prompt})

    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        with st.spinner("Searching documents..."):
            chunks, metadata = retrieve_chunks(prompt)
            context = "\n\n".join(chunks)

        with st.spinner("Generating response..."):
            response = generate_response(prompt, context)
            sources = list(set([m['source'] for m in metadata]))

        # Append source attribution so users can trace the answer.
        if sources:
            response += f"\n\nSources: {', '.join(sources)}"

        st.markdown(response)

    st.session_state.messages.append({"role": "assistant", "content": response})