Sarp Bilgiç commited on
Commit
ad84a49
·
1 Parent(s): 2d97ece

chat history ready for frontend integration, me endpoint

Browse files
src/api/routers/sessions.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, Depends, Query
2
+ from typing import Annotated, List
3
+ from src.api.models.user import User
4
+ from src.api.models.chat import ChatMessageRole
5
+ from src.api.dependencies.auth import get_current_user_required, get_current_user_optional
6
+ from src.api.dependencies.clients import get_db, get_chat_history_service
7
+ from sqlmodel.ext.asyncio.session import AsyncSession
8
+ from src.api.selectors.chat.get_session import get_chat_sessions
9
+ from src.api.schemas.session import ChatSessionList, ChatMessageRead
10
+ from src.api.services.chat_history_service import ChatHistoryService
11
+ from llama_index.core.llms import MessageRole
12
+ import uuid
13
+
14
+ router = APIRouter(
15
+ prefix="/api/v1",
16
+ tags=["Sessions"],
17
+ )
18
+
19
+ def _llama_role_to_chat_role(role: MessageRole) -> ChatMessageRole:
20
+ mapping = {
21
+ MessageRole.USER: ChatMessageRole.USER,
22
+ MessageRole.ASSISTANT: ChatMessageRole.ASSISTANT,
23
+ MessageRole.SYSTEM: ChatMessageRole.SYSTEM,
24
+ }
25
+ return mapping.get(role, ChatMessageRole.USER)
26
+
27
+ @router.get("/sessions", response_model=List[ChatSessionList])
28
+ async def list_sessions(
29
+ user: Annotated[User, Depends(get_current_user_required)],
30
+ db: Annotated[AsyncSession, Depends(get_db)],
31
+ limit: Annotated[int, Query(default=20, le=100)] = 20,
32
+ offset: Annotated[int, Query(default=0)] = 0
33
+ ):
34
+ return await get_chat_sessions(
35
+ user_id=user.id,
36
+ db=db,
37
+ limit=limit,
38
+ offset=offset
39
+ )
40
+
41
+ @router.get("/sessions/{session_id}/messages", response_model=List[ChatMessageRead])
42
+ async def get_messages(
43
+ session_id: uuid.UUID,
44
+ chat_history_service: Annotated[ChatHistoryService, Depends(get_chat_history_service)],
45
+ user: Annotated[User, Depends(get_current_user_optional)],
46
+ db: Annotated[AsyncSession, Depends(get_db)],
47
+ ):
48
+ messages = await chat_history_service.get_messages(
49
+ session_id=session_id,
50
+ user=user,
51
+ db=db
52
+ )
53
+ return [
54
+ ChatMessageRead(
55
+ role=_llama_role_to_chat_role(msg.role),
56
+ content=str(msg.content)
57
+ )
58
+ for msg in messages
59
+ ]
60
+
61
+
src/api/routers/user.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, Depends
2
+ from src.api.dependencies.auth import get_current_user_required
3
+ from src.api.models.user import User
4
+ from src.api.schemas.user import UserRead
5
+ from typing import Annotated
6
+
7
+ router = APIRouter(
8
+ prefix="/api/v1/user",
9
+ tags=["User"],
10
+ )
11
+
12
+ @router.get("/me", response_model=UserRead)
13
+ async def get_me(
14
+ user: Annotated[User, Depends(get_current_user_required)],
15
+ ):
16
+ return user
src/api/schemas/session.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+ import uuid
3
+ from datetime import datetime
4
+ from src.api.models.chat import ChatMessageRole
5
+
6
+ class ChatSessionList(BaseModel):
7
+ id: uuid.UUID
8
+ title: str
9
+ updated_at: datetime
10
+
11
+ class Config:
12
+ from_attributes = True
13
+
14
+ class ChatMessageRead(BaseModel):
15
+ role: ChatMessageRole
16
+ content: str
17
+
18
+ class Config:
19
+ from_attributes = True
src/api/schemas/user.py CHANGED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, EmailStr
2
+
3
+ class UserRead(BaseModel):
4
+ id: int
5
+ username: str
6
+ email: EmailStr
7
+
8
+ class Config:
9
+ from_attributes = True
src/api/selectors/chat/create_session.py CHANGED
@@ -3,18 +3,28 @@ from src.api.models.chat import ChatSession
3
  from typing import Optional
4
  import uuid
5
 
6
- async def create_sessions(
7
  user_id: Optional[int],
8
  title: str,
9
- db: AsyncSession
 
10
  ) -> ChatSession:
11
  session = ChatSession(
12
- id = uuid.uuid4(),
13
- title = title,
14
- user_id = user_id,
15
- is_active = True,
16
  )
17
  db.add(session)
18
  await db.commit()
19
  await db.refresh(session)
20
- return session
 
 
 
 
 
 
 
 
 
 
3
  from typing import Optional
4
  import uuid
5
 
6
+ async def create_session(
7
  user_id: Optional[int],
8
  title: str,
9
+ db: AsyncSession,
10
+ session_id: Optional[uuid.UUID] = None
11
  ) -> ChatSession:
12
  session = ChatSession(
13
+ id=session_id or uuid.uuid4(),
14
+ title=title,
15
+ user_id=user_id,
16
+ is_active=True,
17
  )
18
  db.add(session)
19
  await db.commit()
20
  await db.refresh(session)
21
+ return session
22
+
23
+ async def create_sessions(
24
+ user_id: Optional[int],
25
+ title: str,
26
+ db: AsyncSession,
27
+ session_id: Optional[uuid.UUID] = None
28
+ ) -> ChatSession:
29
+ """Deprecated"""
30
+ return await create_session(user_id, title, db, session_id)
src/api/selectors/chat/get_messages.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uuid
2
+ from typing import List
3
+ from sqlmodel import select
4
+ from sqlmodel.ext.asyncio.session import AsyncSession
5
+ from src.api.models.chat import ChatMessage
6
+
7
+ async def get_chat_messages_by_session(
8
+ session_id: uuid.UUID,
9
+ db: AsyncSession
10
+ ) -> List[ChatMessage]:
11
+ query = (
12
+ select(ChatMessage)
13
+ .where(ChatMessage.session_id == session_id)
14
+ .order_by(ChatMessage.timestamp)
15
+ )
16
+ result = await db.exec(query)
17
+ return result.all()
src/api/selectors/chat/get_session.py CHANGED
@@ -1,17 +1,42 @@
1
  from sqlmodel import select
2
  from sqlmodel.ext.asyncio.session import AsyncSession
3
  from src.api.models.chat import ChatSession
4
- from typing import Optional
5
  import uuid
6
 
7
- async def get_chat_session(
8
  session_id: uuid.UUID,
9
  user_id: Optional[int],
10
  db: AsyncSession
11
  ) -> Optional[ChatSession]:
12
  query = select(ChatSession).where(ChatSession.id == session_id)
13
- if user_id:
14
- query = query.where(ChatSession.user_id == user_id)
 
 
 
 
 
 
 
 
 
 
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  result = await db.exec(query)
17
- return result.first()
 
 
1
  from sqlmodel import select
2
  from sqlmodel.ext.asyncio.session import AsyncSession
3
  from src.api.models.chat import ChatSession
4
+ from typing import Optional, List
5
  import uuid
6
 
7
+ async def get_chat_session_by_id(
8
  session_id: uuid.UUID,
9
  user_id: Optional[int],
10
  db: AsyncSession
11
  ) -> Optional[ChatSession]:
12
  query = select(ChatSession).where(ChatSession.id == session_id)
13
+ result = await db.exec(query)
14
+ session = result.first()
15
+ if not session:
16
+ return None
17
+ if session.user_id is not None:
18
+ if session.user_id != user_id or user_id is None:
19
+ return None
20
+ else:
21
+ if user_id is not None:
22
+ return None
23
+ return session
24
+
25
 
26
+ async def get_chat_sessions(
27
+ user_id: int,
28
+ db: AsyncSession,
29
+ limit: int = 20,
30
+ offset: int = 0
31
+ ) -> List[ChatSession]:
32
+ query = (
33
+ select(ChatSession)
34
+ .where(ChatSession.user_id == user_id)
35
+ .where(ChatSession.is_active == True)
36
+ .order_by(ChatSession.updated_at.desc())
37
+ .limit(limit)
38
+ .offset(offset)
39
+ )
40
  result = await db.exec(query)
41
+ return result.all()
42
+
src/api/selectors/chat/save_messages.py CHANGED
@@ -1,4 +1,5 @@
1
  from sqlmodel.ext.asyncio.session import AsyncSession
 
2
  from src.api.models.chat import ChatSession, ChatMessage, ChatMessageRole
3
  from llama_index.core.llms import ChatMessage as LlamaChatMessage, MessageRole
4
  from typing import List
@@ -16,19 +17,27 @@ def llama_to_db_role(llama_role: MessageRole) -> ChatMessageRole:
16
  async def save_messages_to_db(
17
  session: ChatSession,
18
  messages: List[LlamaChatMessage],
19
- db: AsyncSession
 
20
  ) -> List[ChatMessage]:
 
 
 
 
 
 
21
  db_messages = []
22
  now = datetime.now(timezone.utc)
23
- for msg in messages:
24
  db_message = ChatMessage(
25
- content=msg.content,
26
  role=llama_to_db_role(msg.role),
27
  timestamp=now,
28
  session_id=session.id,
29
  )
30
  db_messages.append(db_message)
31
  db.add(db_message)
 
32
  await db.commit()
33
  for msg in db_messages:
34
  await db.refresh(msg)
 
1
  from sqlmodel.ext.asyncio.session import AsyncSession
2
+ from sqlmodel import select, delete
3
  from src.api.models.chat import ChatSession, ChatMessage, ChatMessageRole
4
  from llama_index.core.llms import ChatMessage as LlamaChatMessage, MessageRole
5
  from typing import List
 
17
  async def save_messages_to_db(
18
  session: ChatSession,
19
  messages: List[LlamaChatMessage],
20
+ db: AsyncSession,
21
+ replace_existing: bool = True
22
  ) -> List[ChatMessage]:
23
+ if replace_existing:
24
+ await db.exec(
25
+ delete(ChatMessage).where(ChatMessage.session_id == session.id)
26
+ )
27
+ await db.commit()
28
+
29
  db_messages = []
30
  now = datetime.now(timezone.utc)
31
+ for i, msg in enumerate(messages):
32
  db_message = ChatMessage(
33
+ content=str(msg.content),
34
  role=llama_to_db_role(msg.role),
35
  timestamp=now,
36
  session_id=session.id,
37
  )
38
  db_messages.append(db_message)
39
  db.add(db_message)
40
+
41
  await db.commit()
42
  for msg in db_messages:
43
  await db.refresh(msg)
src/api/services/chat_history_service.py CHANGED
@@ -4,9 +4,10 @@ from llama_index.core.llms import ChatMessage, MessageRole
4
  from sqlmodel.ext.asyncio.session import AsyncSession
5
  from src.api.models.user import User
6
  from src.api.models.chat import ChatSession
7
- from src.api.selectors.chat.get_session import get_chat_session
8
  from src.api.selectors.chat.create_session import create_sessions
9
  from src.api.selectors.chat.save_messages import save_messages_to_db
 
10
  import uuid
11
  import logging
12
  from src.core.settings import settings
@@ -28,11 +29,30 @@ class ChatHistoryService:
28
  async def get_messages(
29
  self,
30
  session_id: uuid.UUID,
31
- user: Optional[User]
 
32
  ) -> List[ChatMessage]:
33
  try:
34
  key = self._get_redis_key(session_id, user)
35
- return await self.redis_store.aget_messages(key)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  except Exception as e:
37
  logger.warning(f"Failed to get messages from Redis: {e}. Returning empty list.")
38
  return []
@@ -45,7 +65,7 @@ class ChatHistoryService:
45
  ) -> None:
46
  try:
47
  key = self._get_redis_key(session_id, user)
48
- await self.redis_store.async_add_message(key, message)
49
  except Exception as e:
50
  logger.warning(f"Failed to add message to Redis: {e}. Message not stored.")
51
 
@@ -73,6 +93,8 @@ class ChatHistoryService:
73
  logger.warning(f"Failed to delete session from Redis: {e}.")
74
  return None
75
 
 
 
76
  async def sync_to_postgres(
77
  self,
78
  session_id: uuid.UUID,
@@ -80,12 +102,12 @@ class ChatHistoryService:
80
  db: AsyncSession,
81
  title: Optional[str] = None
82
  ) -> Optional[ChatSession]:
83
- messages = await self.get_messages(session_id, user)
84
 
85
  if not messages:
86
  return None
87
 
88
- db_session = await get_chat_session(session_id, user.id, db)
89
 
90
  if not db_session:
91
  if not title:
@@ -99,16 +121,15 @@ class ChatHistoryService:
99
  db_session = await create_sessions(
100
  user_id=user.id,
101
  title=title,
102
- db=db
 
103
  )
104
- db_session.id = session_id
105
- await db.commit()
106
- await db.refresh(db_session)
107
 
108
  await save_messages_to_db(
109
  session=db_session,
110
  messages=messages,
111
- db=db
 
112
  )
113
 
114
  return db_session
 
4
  from sqlmodel.ext.asyncio.session import AsyncSession
5
  from src.api.models.user import User
6
  from src.api.models.chat import ChatSession
7
+ from src.api.selectors.chat.get_session import get_chat_session_by_id
8
  from src.api.selectors.chat.create_session import create_sessions
9
  from src.api.selectors.chat.save_messages import save_messages_to_db
10
+ from src.api.selectors.chat.get_messages import get_chat_messages_by_session
11
  import uuid
12
  import logging
13
  from src.core.settings import settings
 
29
  async def get_messages(
30
  self,
31
  session_id: uuid.UUID,
32
+ user: Optional[User],
33
+ db: Optional[AsyncSession] = None
34
  ) -> List[ChatMessage]:
35
  try:
36
  key = self._get_redis_key(session_id, user)
37
+ messages = await self.redis_store.aget_messages(key)
38
+ if messages:
39
+ return messages
40
+ if db:
41
+ session = await get_chat_session_by_id(session_id, user.id if user else None, db)
42
+ if not session:
43
+ return []
44
+ db_messages = await get_chat_messages_by_session(session_id, db)
45
+ if db_messages:
46
+ llama_messages = [
47
+ ChatMessage(
48
+ role=message.role,
49
+ content=message.content)
50
+ for message in db_messages
51
+ ]
52
+ await self.redis_store.aset_messages(key, llama_messages)
53
+ return llama_messages
54
+
55
+ return []
56
  except Exception as e:
57
  logger.warning(f"Failed to get messages from Redis: {e}. Returning empty list.")
58
  return []
 
65
  ) -> None:
66
  try:
67
  key = self._get_redis_key(session_id, user)
68
+ await self.redis_store.aadd_message(key, message)
69
  except Exception as e:
70
  logger.warning(f"Failed to add message to Redis: {e}. Message not stored.")
71
 
 
93
  logger.warning(f"Failed to delete session from Redis: {e}.")
94
  return None
95
 
96
+
97
+
98
  async def sync_to_postgres(
99
  self,
100
  session_id: uuid.UUID,
 
102
  db: AsyncSession,
103
  title: Optional[str] = None
104
  ) -> Optional[ChatSession]:
105
+ messages = await self.get_messages(session_id, user, db)
106
 
107
  if not messages:
108
  return None
109
 
110
+ db_session = await get_chat_session_by_id(session_id, user.id, db)
111
 
112
  if not db_session:
113
  if not title:
 
121
  db_session = await create_sessions(
122
  user_id=user.id,
123
  title=title,
124
+ db=db,
125
+ session_id=session_id
126
  )
 
 
 
127
 
128
  await save_messages_to_db(
129
  session=db_session,
130
  messages=messages,
131
+ db=db,
132
+ replace_existing=True
133
  )
134
 
135
  return db_session
src/clients/postgres.py CHANGED
@@ -8,6 +8,10 @@ engine = create_async_engine(
8
  settings.async_database_url,
9
  echo=False,
10
  future=True,
 
 
 
 
11
  )
12
 
13
  async_session = sessionmaker(
 
8
  settings.async_database_url,
9
  echo=False,
10
  future=True,
11
+ pool_pre_ping=True,
12
+ pool_recycle=3600,
13
+ pool_size=15,
14
+ max_overflow=25,
15
  )
16
 
17
  async_session = sessionmaker(