aki-008 committed on
Commit
657674a
Β·
1 Parent(s): 0d22fa6

feat:Chat session mgmt added

Browse files
Backend/app/api/v1/endpoints/auth.py CHANGED
@@ -4,7 +4,6 @@ from sqlalchemy import select
4
  from datetime import timedelta
5
  from app.schema import UserCreate, LoginRequest
6
  from app.schema.models import LoginResponse
7
- # from app.schema.models import LoginRequest
8
  from app.models import User
9
  from app.core import verify_password, get_password_hash, create_access_token
10
  from app.api.deps import get_db
 
4
  from datetime import timedelta
5
  from app.schema import UserCreate, LoginRequest
6
  from app.schema.models import LoginResponse
 
7
  from app.models import User
8
  from app.core import verify_password, get_password_hash, create_access_token
9
  from app.api.deps import get_db
Backend/app/api/v1/endpoints/notes.py CHANGED
@@ -13,9 +13,15 @@ from llama_index.readers.file import PyMuPDFReader
13
  from llama_index.core.node_parser import SentenceSplitter
14
  from typing import Annotated
15
  import shutil
 
16
  import os
17
  from sentence_transformers import SentenceTransformer
18
  from .quiz import search_logic
 
 
 
 
 
19
 
20
  router = APIRouter(prefix="/notes")
21
 
@@ -28,6 +34,7 @@ embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
28
  async def ai_chat(
29
  Input_model: AI_chat_input,
30
  collection: Collection = Depends(get_chroma_collection),
 
31
  current_user: User = Depends(get_current_user)
32
  ):
33
  messages_dict = [msg.model_dump() for msg in Input_model.messages]
@@ -39,6 +46,8 @@ async def ai_chat(
39
  media_type="text/plain"
40
  )
41
 
 
 
42
  @router.post("/upload_notes")
43
  async def upload_notes(
44
  file: Annotated[UploadFile, File(description="A PDF file to upload")],
@@ -46,20 +55,14 @@ async def upload_notes(
46
  db: AsyncSession = Depends(get_db),
47
  current_user: User = Depends(get_current_user)
48
  ):
49
- file_content = file.read()
50
-
51
- await file.seek(0)
52
-
53
 
54
  safe_filename = f"{uuid.uuid4()}_{file.filename}"
55
  file_path = Path(UPLOAD_DIRECTORY) / safe_filename
56
 
57
  try:
58
-
59
  with open(file_path, "wb") as buffer:
60
  shutil.copyfileobj(file.file, buffer)
61
 
62
- # 2. Process PDF into chunks
63
  chunks = await pdf_process(str(file_path))
64
 
65
  if not chunks:
@@ -68,11 +71,11 @@ async def upload_notes(
68
  full_text_preview = " ".join(chunks)[:2000]
69
  doc_embedding = embedding_model.encode(full_text_preview).tolist()
70
 
71
-
 
72
  new_doc = PDFData(
73
- pdf_blob=file_path.read_bytes(),
74
- messages_list=[],
75
- pdf_embedding=doc_embedding,
76
  user_id=current_user.id
77
  )
78
 
@@ -80,13 +83,14 @@ async def upload_notes(
80
  await db.commit()
81
  await db.refresh(new_doc)
82
 
83
- # Generate unique IDs for each chunk
84
  ids = [str(uuid.uuid4()) for _ in chunks]
85
-
86
- # Create metadata so you know which file the chunk came from
87
- metadatas = [{"source_file": file.filename, "chunk_index": new_doc.id,"chunk_index": i} for i in range(len(chunks))]
88
 
89
- # Add to ChromaDB
 
 
 
 
 
90
  await collection.add(
91
  ids=ids,
92
  documents=chunks,
@@ -96,15 +100,16 @@ async def upload_notes(
96
  return {
97
  "status": "success",
98
  "filename": file.filename,
 
99
  "chunks_ingested": len(chunks)
100
  }
101
 
102
  except Exception as e:
103
- print(f"Error: {e}") # Log for server console
104
  raise HTTPException(status_code=500, detail=f"Error processing PDF: {str(e)}")
105
 
106
  finally:
107
- # 3. Cleanup: Remove the temp file
108
  if file_path.exists():
109
  os.remove(file_path)
110
 
@@ -132,4 +137,179 @@ async def pdf_process(pdf_path: str):
132
  return text_chunks
133
  except Exception as e:
134
  print(f"PDF Processing Error: {e}")
135
- raise e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  from llama_index.core.node_parser import SentenceSplitter
14
  from typing import Annotated
15
  import shutil
16
+ import tempfile
17
  import os
18
  from sentence_transformers import SentenceTransformer
19
  from .quiz import search_logic
20
+ from sqlalchemy import select, desc, asc
21
+ from app.models.tables import ChatSession, ChatMessage
22
+ from app.schema.models import SessionCreate, SessionResponse, MessageResponse
23
+ from app.database import async_session_maker
24
+ from typing import List
25
 
26
  router = APIRouter(prefix="/notes")
27
 
 
34
  async def ai_chat(
35
  Input_model: AI_chat_input,
36
  collection: Collection = Depends(get_chroma_collection),
37
+ db: AsyncSession = Depends(get_db),
38
  current_user: User = Depends(get_current_user)
39
  ):
40
  messages_dict = [msg.model_dump() for msg in Input_model.messages]
 
46
  media_type="text/plain"
47
  )
48
 
49
+ # Backend/app/api/v1/endpoints/notes.py
50
+
51
  @router.post("/upload_notes")
52
  async def upload_notes(
53
  file: Annotated[UploadFile, File(description="A PDF file to upload")],
 
55
  db: AsyncSession = Depends(get_db),
56
  current_user: User = Depends(get_current_user)
57
  ):
 
 
 
 
58
 
59
  safe_filename = f"{uuid.uuid4()}_{file.filename}"
60
  file_path = Path(UPLOAD_DIRECTORY) / safe_filename
61
 
62
  try:
 
63
  with open(file_path, "wb") as buffer:
64
  shutil.copyfileobj(file.file, buffer)
65
 
 
66
  chunks = await pdf_process(str(file_path))
67
 
68
  if not chunks:
 
71
  full_text_preview = " ".join(chunks)[:2000]
72
  doc_embedding = embedding_model.encode(full_text_preview).tolist()
73
 
74
+ file.file.seek(0)
75
+
76
  new_doc = PDFData(
77
+ pdf_blob=file.file.read(),
78
+ pdf_embedding=doc_embedding,
 
79
  user_id=current_user.id
80
  )
81
 
 
83
  await db.commit()
84
  await db.refresh(new_doc)
85
 
 
86
  ids = [str(uuid.uuid4()) for _ in chunks]
 
 
 
87
 
88
+ metadatas = [{
89
+ "source_file": file.filename,
90
+ "pdf_id": new_doc.id,
91
+ "chunk_index": i
92
+ } for i in range(len(chunks))]
93
+
94
  await collection.add(
95
  ids=ids,
96
  documents=chunks,
 
100
  return {
101
  "status": "success",
102
  "filename": file.filename,
103
+ "doc_id": new_doc.id,
104
  "chunks_ingested": len(chunks)
105
  }
106
 
107
  except Exception as e:
108
+ print(f"Error: {e}")
109
  raise HTTPException(status_code=500, detail=f"Error processing PDF: {str(e)}")
110
 
111
  finally:
112
+ # Cleanup temp file
113
  if file_path.exists():
114
  os.remove(file_path)
115
 
 
137
  return text_chunks
138
  except Exception as e:
139
  print(f"PDF Processing Error: {e}")
140
+ raise e
141
+
142
+ # -------------------------
143
+ # 1. Session Management
144
+ # -------------------------
145
+
146
@router.post("/sessions", response_model=SessionResponse)
async def create_session(
    session_in: SessionCreate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Create a new chat session bound to one of the caller's PDFs.

    Raises 404 when the referenced PDF does not exist or is not owned
    by the requesting user.
    """
    # The target PDF must exist and belong to the requesting user.
    pdf_query = (
        select(PDFData)
        .where(PDFData.id == session_in.pdf_id)
        .where(PDFData.user_id == current_user.id)
    )
    pdf_row = (await db.execute(pdf_query)).scalar_one_or_none()
    if pdf_row is None:
        raise HTTPException(404, "PDF not found")

    # Sessions use API-generated UUID strings as primary keys.
    session = ChatSession(
        id=str(uuid.uuid4()),
        name=session_in.name,
        pdf_id=session_in.pdf_id,
        user_id=current_user.id,
    )
    db.add(session)
    await db.commit()
    await db.refresh(session)
    return session
167
+
168
@router.get("/sessions/{pdf_id}", response_model=List[SessionResponse])
async def get_sessions(
    pdf_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """List the caller's chat sessions for the given PDF, newest first."""
    query = (
        select(ChatSession)
        .where(ChatSession.pdf_id == pdf_id)
        .where(ChatSession.user_id == current_user.id)
        .order_by(desc(ChatSession.created_at))
    )
    rows = await db.execute(query)
    return rows.scalars().all()
182
+
183
@router.get("/history/{session_id}", response_model=List[MessageResponse])
async def get_history(
    session_id: str,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Return every message of a chat session in chronological order.

    Raises 404 when the session does not exist or is not owned by the
    calling user. (Fix: the original never used `current_user`, so any
    authenticated user could read any session's history by guessing or
    leaking a session ID.)
    """
    # Ownership check: the session must belong to the caller.
    session_res = await db.execute(
        select(ChatSession)
        .where(ChatSession.id == session_id)
        .where(ChatSession.user_id == current_user.id)
    )
    if session_res.scalar_one_or_none() is None:
        raise HTTPException(404, "Session not found")

    result = await db.execute(
        select(ChatMessage)
        .where(ChatMessage.session_id == session_id)
        .order_by(asc(ChatMessage.created_at))
    )
    return result.scalars().all()
195
+
196
+ # -------------------------
197
+ # 2. Chat with Memory
198
+ # -------------------------
199
+
200
@router.post("/chat/{session_id}")
async def chat_session(
    session_id: str,
    user_prompt: str,
    db: AsyncSession = Depends(get_db),
    collection: Collection = Depends(get_chroma_collection),
    current_user: User = Depends(get_current_user)
):
    """Stream an AI reply to *user_prompt* inside an existing chat session.

    Persists the user message before generation and the assistant message
    after the stream completes. Returns a text/plain StreamingResponse.
    Raises 404 when the session does not exist or is not owned by the caller.
    """
    # 1. Verify the session exists AND belongs to the caller. (Fix: the
    #    original only checked existence, letting any authenticated user
    #    chat inside — and read the history of — another user's session.)
    session_res = await db.execute(
        select(ChatSession)
        .where(ChatSession.id == session_id)
        .where(ChatSession.user_id == current_user.id)
    )
    session = session_res.scalar_one_or_none()
    if not session:
        raise HTTPException(404, "Session not found")

    # 2. Make sure this PDF's chunks are present in Chroma; restores them
    #    from the SQL blob if the vector store was wiped.
    await ensure_pdf_in_chroma(session.pdf_id, db, collection)

    # 3. Save the user message before generating, so history is consistent
    #    even if generation fails mid-stream.
    user_msg = ChatMessage(session_id=session_id, role="user", content=user_prompt)
    db.add(user_msg)
    await db.commit()

    # 4. Retrieve context restricted to this session's PDF.
    filter_dict = {"pdf_id": session.pdf_id}
    retrieved_context = await search_logic(user_prompt, collection, filter_dict)

    # 5. Full conversation history becomes the LLM payload.
    history_res = await db.execute(
        select(ChatMessage)
        .where(ChatMessage.session_id == session_id)
        .order_by(asc(ChatMessage.created_at))
    )
    history_msgs = history_res.scalars().all()
    messages_payload = [{"role": m.role, "content": m.content} for m in history_msgs]

    async def response_generator():
        # Accumulate streamed chunks so the complete reply can be stored
        # once streaming finishes.
        full_response = ""
        async for chunk in stream_chat(messages_payload, "", retrieved_context):
            full_response += chunk
            yield chunk

        # The request-scoped `db` session may already be closed when the
        # stream ends, so persist the assistant message with a fresh session.
        async with async_session_maker() as new_db_session:
            ai_msg = ChatMessage(session_id=session_id, role="assistant", content=full_response)
            new_db_session.add(ai_msg)
            await new_db_session.commit()

    return StreamingResponse(response_generator(), media_type="text/plain")
247
+
248
+
249
+
250
async def ensure_pdf_in_chroma(pdf_id: int, db: AsyncSession, collection: Collection):
    """
    Checks if embeddings exist in Chroma for the given PDF ID.
    If not, fetches the blob from SQL, re-chunks it, and re-uploads to Chroma.

    Raises:
        HTTPException 404: the PDF row is missing from SQL.
        HTTPException 500: chunking or re-upload failed.
    """
    # 1. Fast check: one hit with this metadata means embeddings already exist.
    existing = await collection.get(
        where={"pdf_id": pdf_id},
        limit=1
    )
    if existing and len(existing['ids']) > 0:
        print(f"βœ… Embeddings found for PDF {pdf_id}. No action needed.")
        return

    print(f"⚠️ Embeddings missing for PDF {pdf_id}. Restoring from SQL...")

    # 2. Fetch the stored blob from SQL.
    result = await db.execute(select(PDFData).where(PDFData.id == pdf_id))
    pdf_record = result.scalar_one_or_none()
    if not pdf_record:
        raise HTTPException(404, "PDF Data not found in database")

    # 3. pdf_process needs a real file path; the ".pdf" suffix lets PyMuPDF
    #    detect the format. The write now happens INSIDE try so a failed
    #    write can no longer leak the temp file or leave tmp_path unbound.
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
            tmp_path = tmp_file.name
            tmp_file.write(pdf_record.pdf_blob)

        # 4. Re-process with the same chunking logic as upload_notes.
        chunks = await pdf_process(tmp_path)
        if not chunks:
            print("Warning: Restored PDF has no text.")
            return

        # 5. Re-embed and upload; fresh UUIDs per chunk.
        ids = [str(uuid.uuid4()) for _ in chunks]

        # BUG FIX: PDFData defines no `filename` column (see models/tables.py),
        # so `pdf_record.filename` raised AttributeError. Fall back to a
        # synthetic name when the attribute is absent.
        source_file = getattr(pdf_record, "filename", None) or f"pdf_{pdf_id}"

        # Same metadata structure as upload_notes so filters keep working.
        metadatas = [{
            "source_file": source_file,
            "pdf_id": pdf_id,
            "chunk_index": i
        } for i in range(len(chunks))]

        await collection.add(
            ids=ids,
            documents=chunks,
            metadatas=metadatas
        )
        print(f"♻️ Successfully restored {len(chunks)} chunks for PDF {pdf_id}")

    except HTTPException:
        # Don't re-wrap our own HTTP errors.
        raise
    except Exception as e:
        print(f"❌ Error restoring PDF: {e}")
        raise HTTPException(500, f"Failed to restore PDF embeddings: {str(e)}")

    finally:
        # Cleanup temp file (tmp_path is None if creation itself failed).
        if tmp_path and os.path.exists(tmp_path):
            os.remove(tmp_path)
Backend/app/api/v1/endpoints/quiz.py CHANGED
@@ -14,27 +14,23 @@ import logging
14
 
15
  router = APIRouter(prefix="/quiz")
16
 
 
17
 
18
- # 1. Set up a logger (if you haven't already globally)
19
- logger = logging.getLogger("uvicorn.error") # reusing uvicorn's logger ensures it shows up in your terminal
20
-
21
- async def search_logic(query: str, collection: Collection):
22
- # Log the incoming query
23
  logger.info(f"πŸ” [Search Logic] Starting search for query: '{query}'")
24
 
25
  try:
26
  results = await collection.query(
27
- query_texts=[query],
28
- n_results=5
29
- )
 
30
 
31
- # Log the raw results to see exactly what ChromaDB returned (helps spot NoneTypes)
32
  logger.info(f"πŸ“„ [Search Logic] Raw results from DB: {results}")
33
 
34
  if results and results.get('documents') and len(results['documents']) > 0:
35
  raw_docs = results['documents'][0]
36
-
37
- # Filter None values and Log how many were found vs valid
38
  valid_docs = [str(doc) for doc in raw_docs if doc is not None]
39
 
40
  logger.info(f"βœ… [Search Logic] Processing: Found {len(raw_docs)} items. Valid text items: {len(valid_docs)}")
@@ -42,7 +38,6 @@ async def search_logic(query: str, collection: Collection):
42
  if len(raw_docs) != len(valid_docs):
43
  logger.warning("⚠️ [Search Logic] Warning: Some documents contained NoneType and were skipped.")
44
 
45
- # Join with a space (safer than empty string)
46
  final_context = " ".join(valid_docs)
47
  return final_context
48
 
@@ -51,9 +46,7 @@ async def search_logic(query: str, collection: Collection):
51
  return ""
52
 
53
  except Exception as e:
54
- # Log the full error if something crashes
55
  logger.error(f"❌ [Search Logic] CRITICAL ERROR: {str(e)}")
56
- # You might want to re-raise the error or return empty depending on your needs
57
  return ""
58
 
59
  @router.get("/search_docs")
 
14
 
15
  router = APIRouter(prefix="/quiz")
16
 
17
+ logger = logging.getLogger("uvicorn.error")
18
 
19
+ async def search_logic(query: str, collection: Collection, filter_dict: dict = None):
 
 
 
 
20
  logger.info(f"πŸ” [Search Logic] Starting search for query: '{query}'")
21
 
22
  try:
23
  results = await collection.query(
24
+ query_texts=[query],
25
+ n_results=5,
26
+ where=filter_dict
27
+ )
28
 
 
29
  logger.info(f"πŸ“„ [Search Logic] Raw results from DB: {results}")
30
 
31
  if results and results.get('documents') and len(results['documents']) > 0:
32
  raw_docs = results['documents'][0]
33
+
 
34
  valid_docs = [str(doc) for doc in raw_docs if doc is not None]
35
 
36
  logger.info(f"βœ… [Search Logic] Processing: Found {len(raw_docs)} items. Valid text items: {len(valid_docs)}")
 
38
  if len(raw_docs) != len(valid_docs):
39
  logger.warning("⚠️ [Search Logic] Warning: Some documents contained NoneType and were skipped.")
40
 
 
41
  final_context = " ".join(valid_docs)
42
  return final_context
43
 
 
46
  return ""
47
 
48
  except Exception as e:
 
49
  logger.error(f"❌ [Search Logic] CRITICAL ERROR: {str(e)}")
 
50
  return ""
51
 
52
  @router.get("/search_docs")
Backend/app/models/tables.py CHANGED
@@ -1,4 +1,4 @@
1
- from sqlalchemy import String, LargeBinary, JSON, ForeignKey
2
  from sqlalchemy.orm import Mapped, mapped_column, relationship
3
  from datetime import datetime
4
  from app.database import Base
@@ -19,7 +19,33 @@ class PDFData(Base):
19
 
20
  id: Mapped[int] = mapped_column(primary_key=True, index=True)
21
  pdf_blob: Mapped[bytes] = mapped_column(LargeBinary)
22
- messages_list: Mapped[List] = mapped_column(JSON)
23
  pdf_embedding: Mapped[list[float]] = mapped_column(JSON)
24
  user_id: Mapped[int] = mapped_column(ForeignKey('users.id'))
25
- user: Mapped["User"] = relationship(back_populates="pdf_data")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sqlalchemy import String, LargeBinary, JSON, ForeignKey, Text, DateTime
2
  from sqlalchemy.orm import Mapped, mapped_column, relationship
3
  from datetime import datetime
4
  from app.database import Base
 
19
 
20
  id: Mapped[int] = mapped_column(primary_key=True, index=True)
21
  pdf_blob: Mapped[bytes] = mapped_column(LargeBinary)
 
22
  pdf_embedding: Mapped[list[float]] = mapped_column(JSON)
23
  user_id: Mapped[int] = mapped_column(ForeignKey('users.id'))
24
+
25
+ user: Mapped["User"] = relationship(back_populates="pdf_data")
26
+ chat_sessions: Mapped[List["ChatSession"]] = relationship(back_populates="pdf_data", cascade="all, delete-orphan")
27
+
28
class ChatSession(Base):
    """One chat conversation a user holds about a single uploaded PDF."""
    __tablename__ = "chat_sessions"

    # UUID string assigned by the API layer (create_session), not autoincrement.
    id: Mapped[str] = mapped_column(String, primary_key=True)
    # Display name; defaults to "New Chat" at the schema layer.
    name: Mapped[str] = mapped_column(String(100))
    # NOTE(review): datetime.utcnow is deprecated (Python 3.12+) and yields a
    # naive timestamp — consider a timezone-aware default; confirm DB expectations.
    created_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)

    # Owning PDF; PDFData.chat_sessions cascades delete-orphan to sessions.
    pdf_id: Mapped[int] = mapped_column(ForeignKey('pdf_data.id'))
    pdf_data: Mapped["PDFData"] = relationship(back_populates="chat_sessions")

    # Owning user.
    user_id: Mapped[int] = mapped_column(ForeignKey('users.id'))

    # Messages are deleted together with their session.
    messages: Mapped[List["ChatMessage"]] = relationship(back_populates="session", cascade="all, delete-orphan")
41
+
42
class ChatMessage(Base):
    """A single message ("user" or "assistant") within a ChatSession."""
    __tablename__ = "chat_messages"

    id: Mapped[int] = mapped_column(primary_key=True, index=True)
    # Parent session; string FK because ChatSession uses UUID string PKs.
    session_id: Mapped[str] = mapped_column(ForeignKey('chat_sessions.id'))
    session: Mapped["ChatSession"] = relationship(back_populates="messages")

    # Role as used in the LLM payload ("user" / "assistant").
    role: Mapped[str] = mapped_column(String(20))
    content: Mapped[str] = mapped_column(Text)
    # NOTE(review): datetime.utcnow is deprecated (Python 3.12+) and naive —
    # message ordering relies on this column; consider an aware default.
    created_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
Backend/app/schema/__init__.py CHANGED
@@ -1,3 +1,3 @@
1
- from app.schema.models import UserCreate, Token, LoginRequest, Quiz_input, QuizOutput, IngestRequest, ChatMessage, AI_chat_input, pdf_input
2
 
3
- __all__ = ["UserCreate", "Token", "LoginRequest", "Quiz_input", "QuizOutput", "IngestRequest", "ChatMessage", "AI_chat_input", "pdf_input"]
 
1
+ from app.schema.models import UserCreate, Token, LoginRequest, Quiz_input, QuizOutput, IngestRequest, ChatMessage, AI_chat_input, pdf_input, SessionCreate, SessionResponse, MessageResponse
2
 
3
+ __all__ = ["UserCreate", "Token", "LoginRequest", "Quiz_input", "QuizOutput", "IngestRequest", "ChatMessage", "AI_chat_input", "pdf_input", "SessionCreate", "SessionResponse", "MessageResponse"]
Backend/app/schema/models.py CHANGED
@@ -61,6 +61,22 @@ class AI_chat_input(BaseModel):
61
  None, description="The unique ID of the current chat session (optional)."
62
  )
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  #--------Notes page models--------#
65
 
66
  class pdf_input(BaseModel):
 
61
  None, description="The unique ID of the current chat session (optional)."
62
  )
63
 
64
class SessionCreate(BaseModel):
    """Request body for POST /notes/sessions."""
    # ID of an already-uploaded PDFData row owned by the caller.
    pdf_id: int
    # Optional display name for the session.
    name: str = "New Chat"
67
+
68
class SessionResponse(BaseModel):
    """Serialized chat session returned by the session endpoints.

    Fix: the endpoints return ChatSession ORM objects as `response_model`;
    Pydantic v2 (this project uses v2 — see `model_dump()` usage) requires
    `from_attributes=True` to validate from object attributes instead of
    a dict, otherwise serialization fails at runtime.
    """
    model_config = {"from_attributes": True}

    id: str
    name: str
    created_at: datetime
    pdf_id: int
73
+
74
class MessageResponse(BaseModel):
    """Serialized chat message returned by GET /notes/history/{session_id}.

    Fix: the endpoint returns ChatMessage ORM objects as `response_model`;
    Pydantic v2 needs `from_attributes=True` to read fields from object
    attributes, otherwise response validation fails at runtime.
    """
    model_config = {"from_attributes": True}

    id: int
    role: str
    content: str
    created_at: datetime
79
+
80
  #--------Notes page models--------#
81
 
82
  class pdf_input(BaseModel):