nivakaran committed on
Commit
ecf72c3
·
verified ·
1 Parent(s): fa2dc21

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. app.py +30 -12
  2. src/llm/phi_model.py +17 -5
  3. src/rag/pipeline.py +16 -4
  4. src/session.py +223 -0
app.py CHANGED
@@ -154,11 +154,18 @@ def process_files(files):
154
  return "\n".join(messages), get_stats_text()
155
 
156
 
157
- def answer_question(question, top_k, chat_history):
158
- """Answer a question with production-grade error handling."""
 
 
 
 
 
 
 
159
  # Input validation
160
  if not question or not question.strip():
161
- return chat_history, ""
162
 
163
  question = question.strip()
164
 
@@ -167,7 +174,7 @@ def answer_question(question, top_k, chat_history):
167
  if len(question) > MAX_QUESTION_LENGTH:
168
  response = f"⚠️ Your question is too long ({len(question)} chars). Please keep it under {MAX_QUESTION_LENGTH} characters."
169
  chat_history.append((question[:100] + "...", response))
170
- return chat_history, ""
171
 
172
  try:
173
  pipe = get_pipeline()
@@ -180,7 +187,8 @@ def answer_question(question, top_k, chat_history):
180
  )
181
  else:
182
  try:
183
- result = pipe.query(question, top_k=int(top_k))
 
184
  response = result.get("answer", "I couldn't generate a response. Please try again.")
185
 
186
  # Add sources if available
@@ -188,6 +196,9 @@ def answer_question(question, top_k, chat_history):
188
  if sources:
189
  unique_sources = list(set(s.get("filename", "Unknown") for s in sources))
190
  response += f"\n\n---\n📚 *Sources: {', '.join(unique_sources)}*"
 
 
 
191
 
192
  except Exception as e:
193
  logger.error(f"Query error: {e}")
@@ -203,7 +214,7 @@ def answer_question(question, top_k, chat_history):
203
  response = "⚠️ The system is temporarily unavailable. Please try again in a moment."
204
 
205
  chat_history.append((question, response))
206
- return chat_history, ""
207
 
208
 
209
  def get_stats_text() -> str:
@@ -309,6 +320,9 @@ with gr.Blocks(
309
  with gr.Row():
310
  submit_btn = gr.Button("🔍 Ask", variant="primary", scale=2)
311
  clear_chat_btn = gr.Button("🧹 Clear Chat", scale=1)
 
 
 
312
 
313
  # Event handlers
314
  upload_btn.click(
@@ -319,14 +333,14 @@ with gr.Blocks(
319
 
320
  submit_btn.click(
321
  fn=answer_question,
322
- inputs=[question_input, top_k_slider, chatbot],
323
- outputs=[chatbot, question_input]
324
  )
325
 
326
  question_input.submit(
327
  fn=answer_question,
328
- inputs=[question_input, top_k_slider, chatbot],
329
- outputs=[chatbot, question_input]
330
  )
331
 
332
  clear_btn.click(
@@ -334,9 +348,13 @@ with gr.Blocks(
334
  outputs=[upload_status, stats_display]
335
  )
336
 
 
 
 
 
337
  clear_chat_btn.click(
338
- fn=lambda: [],
339
- outputs=[chatbot]
340
  )
341
 
342
  gr.Markdown("""
 
154
  return "\n".join(messages), get_stats_text()
155
 
156
 
157
+ def answer_question(question, top_k, chat_history, session_id):
158
+ """Answer a question with session-based conversation history."""
159
+ from src.session import get_session_manager
160
+ session_mgr = get_session_manager()
161
+
162
+ # Create session if needed
163
+ if not session_id:
164
+ session_id = session_mgr.create_session()
165
+
166
  # Input validation
167
  if not question or not question.strip():
168
+ return chat_history, "", session_id
169
 
170
  question = question.strip()
171
 
 
174
  if len(question) > MAX_QUESTION_LENGTH:
175
  response = f"⚠️ Your question is too long ({len(question)} chars). Please keep it under {MAX_QUESTION_LENGTH} characters."
176
  chat_history.append((question[:100] + "...", response))
177
+ return chat_history, "", session_id
178
 
179
  try:
180
  pipe = get_pipeline()
 
187
  )
188
  else:
189
  try:
190
+ # Pass session_id for conversation history
191
+ result = pipe.query(question, top_k=int(top_k), session_id=session_id)
192
  response = result.get("answer", "I couldn't generate a response. Please try again.")
193
 
194
  # Add sources if available
 
196
  if sources:
197
  unique_sources = list(set(s.get("filename", "Unknown") for s in sources))
198
  response += f"\n\n---\n📚 *Sources: {', '.join(unique_sources)}*"
199
+
200
+ # Store in session history
201
+ session_mgr.add_message(session_id, question, response)
202
 
203
  except Exception as e:
204
  logger.error(f"Query error: {e}")
 
214
  response = "⚠️ The system is temporarily unavailable. Please try again in a moment."
215
 
216
  chat_history.append((question, response))
217
+ return chat_history, "", session_id
218
 
219
 
220
  def get_stats_text() -> str:
 
320
  with gr.Row():
321
  submit_btn = gr.Button("🔍 Ask", variant="primary", scale=2)
322
  clear_chat_btn = gr.Button("🧹 Clear Chat", scale=1)
323
+
324
+ # Hidden state for session management
325
+ session_id = gr.State(value=None)
326
 
327
  # Event handlers
328
  upload_btn.click(
 
333
 
334
  submit_btn.click(
335
  fn=answer_question,
336
+ inputs=[question_input, top_k_slider, chatbot, session_id],
337
+ outputs=[chatbot, question_input, session_id]
338
  )
339
 
340
  question_input.submit(
341
  fn=answer_question,
342
+ inputs=[question_input, top_k_slider, chatbot, session_id],
343
+ outputs=[chatbot, question_input, session_id]
344
  )
345
 
346
  clear_btn.click(
 
348
  outputs=[upload_status, stats_display]
349
  )
350
 
351
+ def clear_chat_and_session():
352
+ """Clear chat history and reset session."""
353
+ return [], None
354
+
355
  clear_chat_btn.click(
356
+ fn=clear_chat_and_session,
357
+ outputs=[chatbot, session_id]
358
  )
359
 
360
  gr.Markdown("""
src/llm/phi_model.py CHANGED
@@ -255,14 +255,16 @@ class PhiModel:
255
  self,
256
  query: str,
257
  context: str,
258
- system_prompt: Optional[str] = None
 
259
  ) -> str:
260
- """Generate response with RAG context.
261
 
262
  Args:
263
  query: User's question.
264
  context: Retrieved context from documents.
265
  system_prompt: Optional system prompt.
 
266
 
267
  Returns:
268
  Generated response.
@@ -286,13 +288,22 @@ class PhiModel:
286
  if not context or not context.strip():
287
  context = "No relevant documents found."
288
 
289
- user_message = f"""Here's some information from the documents:
 
 
 
 
 
 
 
 
 
290
 
291
  {context}
292
 
293
- User's question: {query}
294
 
295
- Please respond naturally and helpfully:"""
296
 
297
  messages = [
298
  {"role": "system", "content": system_prompt},
@@ -300,3 +311,4 @@ Please respond naturally and helpfully:"""
300
  ]
301
 
302
  return self.chat(messages)
 
 
255
  self,
256
  query: str,
257
  context: str,
258
+ system_prompt: Optional[str] = None,
259
+ conversation_history: Optional[str] = None
260
  ) -> str:
261
+ """Generate response with RAG context and conversation history.
262
 
263
  Args:
264
  query: User's question.
265
  context: Retrieved context from documents.
266
  system_prompt: Optional system prompt.
267
+ conversation_history: Optional formatted conversation history (last 6 messages).
268
 
269
  Returns:
270
  Generated response.
 
288
  if not context or not context.strip():
289
  context = "No relevant documents found."
290
 
291
+ # Build message with optional history
292
+ history_section = ""
293
+ if conversation_history and conversation_history.strip():
294
+ history_section = f"""Previous conversation:
295
+ {conversation_history}
296
+
297
+ ---
298
+ """
299
+
300
+ user_message = f"""{history_section}Here's some information from the documents:
301
 
302
  {context}
303
 
304
+ User's current question: {query}
305
 
306
+ Please respond naturally and helpfully, considering the conversation context:"""
307
 
308
  messages = [
309
  {"role": "system", "content": system_prompt},
 
311
  ]
312
 
313
  return self.chat(messages)
314
+
src/rag/pipeline.py CHANGED
@@ -120,17 +120,18 @@ class RAGPipeline:
120
  print(f"Adding {len(chunks)} chunks to vector store...")
121
  return self.vector_store.add_chunks(chunks)
122
 
123
- def query(self, question: str, top_k: Optional[int] = None) -> Dict[str, Any]:
124
- """Query the RAG system with semantic caching.
125
 
126
  Lookup order:
127
  1. Exact cache match (instant)
128
  2. Semantic similarity match (instant)
129
- 3. Model generation (slow)
130
 
131
  Args:
132
  question: User's question.
133
  top_k: Number of documents to retrieve.
 
134
 
135
  Returns:
136
  Dict with answer and sources.
@@ -168,9 +169,20 @@ class RAGPipeline:
168
  }
169
 
170
  # 3. Generate answer using LLM (no cache hit)
 
 
 
 
 
 
 
171
  context = self.retriever.retrieve_text(question, top_k)
172
  sources = self.retriever.retrieve(question, top_k)
173
- answer = self.llm.chat_with_context(question, context)
 
 
 
 
174
 
175
  # Build source list
176
  source_list = [
 
120
  print(f"Adding {len(chunks)} chunks to vector store...")
121
  return self.vector_store.add_chunks(chunks)
122
 
123
+ def query(self, question: str, top_k: Optional[int] = None, session_id: Optional[str] = None) -> Dict[str, Any]:
124
+ """Query the RAG system with semantic caching and session history.
125
 
126
  Lookup order:
127
  1. Exact cache match (instant)
128
  2. Semantic similarity match (instant)
129
+ 3. Model generation with conversation history (slow)
130
 
131
  Args:
132
  question: User's question.
133
  top_k: Number of documents to retrieve.
134
+ session_id: Optional session ID for conversation history.
135
 
136
  Returns:
137
  Dict with answer and sources.
 
169
  }
170
 
171
  # 3. Generate answer using LLM (no cache hit)
172
+ # Get conversation history if session provided
173
+ conversation_history = None
174
+ if session_id:
175
+ from src.session import get_session_manager
176
+ session_mgr = get_session_manager()
177
+ conversation_history = session_mgr.get_history_for_prompt(session_id)
178
+
179
  context = self.retriever.retrieve_text(question, top_k)
180
  sources = self.retriever.retrieve(question, top_k)
181
+ answer = self.llm.chat_with_context(
182
+ question,
183
+ context,
184
+ conversation_history=conversation_history
185
+ )
186
 
187
  # Build source list
188
  source_list = [
src/session.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Session management for FreeRAG chat history."""
2
+
3
+ import json
4
+ import logging
5
+ import threading
6
+ import uuid
7
+ from datetime import datetime, timedelta
8
+ from pathlib import Path
9
+ from typing import Optional, Dict, Any, List, Tuple
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
class SessionManager:
    """Manages user sessions and chat history.

    Each session is identified by a UUID and stores:
    - Chat history (question-answer pairs)
    - Session metadata (created_at, last_active)

    Sessions are persisted as one JSON file per session under
    ``storage_dir``. All mutations of the in-memory session map are
    guarded by a single lock; disk writes are best-effort (failures are
    logged, never raised).
    """

    def __init__(
        self,
        storage_dir: str = "./.cache/sessions",
        max_history: int = 6,
        session_ttl_hours: int = 24
    ):
        """Initialize session manager.

        Args:
            storage_dir: Directory to store session data.
            max_history: Maximum messages to keep (for context).
            session_ttl_hours: Session expiry time in hours.
        """
        self.storage_dir = Path(storage_dir)
        self.storage_dir.mkdir(parents=True, exist_ok=True)
        self.max_history = max_history
        self.session_ttl = timedelta(hours=session_ttl_hours)
        self._lock = threading.Lock()
        self._sessions: Dict[str, Dict[str, Any]] = {}

        # Warm the in-memory map from disk, then drop anything stale.
        self._load_sessions()
        self._cleanup_expired()
        logger.info(f"📋 Session manager initialized with {len(self._sessions)} active sessions")

    @staticmethod
    def _new_record(session_id: str) -> Dict[str, Any]:
        """Build a fresh session record with empty history."""
        now = datetime.now().isoformat()
        return {
            "id": session_id,
            "created_at": now,
            "last_active": now,
            "history": []
        }

    def create_session(self) -> str:
        """Create a new session and return its ID."""
        session_id = str(uuid.uuid4())

        with self._lock:
            self._sessions[session_id] = self._new_record(session_id)
            self._save_session(session_id)

        logger.info(f"📝 Created new session: {session_id[:8]}...")
        return session_id

    def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Return the session record for *session_id*, or None if unknown.

        Falls back to loading the session from disk if it is not in
        memory. On a hit, ``last_active`` is refreshed (in memory only;
        the timestamp is persisted on the next write).
        """
        with self._lock:
            if session_id not in self._sessions:
                # Try to load from disk
                self._load_session(session_id)

            if session_id in self._sessions:
                self._sessions[session_id]["last_active"] = datetime.now().isoformat()
                return self._sessions[session_id]

            return None

    def add_message(
        self,
        session_id: str,
        question: str,
        answer: str
    ) -> None:
        """Add a Q&A pair to session history, creating the session if unknown."""
        with self._lock:
            if session_id not in self._sessions:
                self._sessions[session_id] = self._new_record(session_id)

            session = self._sessions[session_id]
            session["history"].append({
                "question": question,
                "answer": answer,
                "timestamp": datetime.now().isoformat()
            })

            # Hysteresis: let the stored history grow to 2x max_history
            # before trimming back to the most recent max_history entries,
            # so the list is not rewritten on every message. Prompt reads
            # are already capped at max_history via get_history().
            if len(session["history"]) > self.max_history * 2:
                session["history"] = session["history"][-self.max_history:]

            session["last_active"] = datetime.now().isoformat()
            self._save_session(session_id)

    def get_history(self, session_id: str, limit: Optional[int] = None) -> List[Tuple[str, str]]:
        """Get chat history for a session.

        Args:
            session_id: Session ID.
            limit: Max messages to return (default: max_history).

        Returns:
            List of (question, answer) tuples, oldest first.
        """
        limit = limit or self.max_history
        session = self.get_session(session_id)

        if not session:
            return []

        history = session.get("history", [])[-limit:]
        return [(h["question"], h["answer"]) for h in history]

    def get_history_for_prompt(self, session_id: str) -> str:
        """Get formatted history for including in prompt.

        Returns the last ``max_history`` exchanges formatted as
        "User: ...\\nAssistant: ..." blocks, separated by blank lines.
        Questions are truncated at 200 chars and answers at 300 chars
        to bound prompt size.
        """
        history = self.get_history(session_id, self.max_history)

        if not history:
            return ""

        formatted = []
        for q, a in history:
            # Truncate long messages
            q_short = q[:200] + "..." if len(q) > 200 else q
            a_short = a[:300] + "..." if len(a) > 300 else a
            formatted.append(f"User: {q_short}\nAssistant: {a_short}")

        return "\n\n".join(formatted)

    def clear_history(self, session_id: str) -> None:
        """Clear chat history for a session (the session itself survives)."""
        with self._lock:
            if session_id in self._sessions:
                self._sessions[session_id]["history"] = []
                self._save_session(session_id)

    def _save_session(self, session_id: str) -> None:
        """Persist one session to disk; failures are logged, not raised."""
        try:
            session_file = self.storage_dir / f"{session_id}.json"
            with open(session_file, 'w', encoding='utf-8') as f:
                json.dump(self._sessions[session_id], f, ensure_ascii=False, indent=2)
        except Exception as e:
            logger.warning(f"Failed to save session {session_id[:8]}: {e}")

    def _load_session(self, session_id: str) -> None:
        """Load one session from disk into memory, if its file exists."""
        try:
            session_file = self.storage_dir / f"{session_id}.json"
            if session_file.exists():
                with open(session_file, 'r', encoding='utf-8') as f:
                    self._sessions[session_id] = json.load(f)
        except Exception as e:
            logger.warning(f"Failed to load session {session_id[:8]}: {e}")

    def _load_sessions(self) -> None:
        """Load all sessions from disk, skipping unreadable files."""
        try:
            for session_file in self.storage_dir.glob("*.json"):
                try:
                    with open(session_file, 'r', encoding='utf-8') as f:
                        session = json.load(f)
                    self._sessions[session["id"]] = session
                except Exception as e:
                    # Corrupt or foreign JSON: skip it, but leave a trace
                    # instead of swallowing silently.
                    logger.warning(f"Skipping unreadable session file {session_file.name}: {e}")
        except Exception as e:
            logger.warning(f"Failed to load sessions: {e}")

    def _cleanup_expired(self) -> None:
        """Remove sessions whose last_active is older than the TTL."""
        now = datetime.now()
        expired = []

        for sid, session in self._sessions.items():
            try:
                last_active = datetime.fromisoformat(session["last_active"])
                if now - last_active > self.session_ttl:
                    expired.append(sid)
            except Exception:
                # Unparseable/missing last_active: keep the session rather
                # than risk deleting live data.
                pass

        for sid in expired:
            self._delete_session(sid)

        if expired:
            logger.info(f"♻️ Cleaned up {len(expired)} expired sessions")

    def _delete_session(self, session_id: str) -> None:
        """Delete a session from memory and disk."""
        with self._lock:
            if session_id in self._sessions:
                del self._sessions[session_id]

            session_file = self.storage_dir / f"{session_id}.json"
            if session_file.exists():
                session_file.unlink()
212
+
213
+
214
# Process-wide singleton; created lazily on first access.
_session_manager: Optional[SessionManager] = None


def get_session_manager() -> SessionManager:
    """Return the global SessionManager, instantiating it on first call."""
    global _session_manager
    if _session_manager is None:
        _session_manager = SessionManager()
    return _session_manager