feat: ai chat endpoint added
Changed files:
- Backend/app/api/v1/api.py +7 -1
- Backend/app/api/v1/endpoints/notes.py +14 -11
- Backend/app/llm.py +50 -2
- Backend/app/schema/__init__.py +2 -2
- Backend/app/schema/models.py +4 -7
Backend/app/api/v1/api.py
CHANGED

@@ -1,5 +1,5 @@
 from fastapi import APIRouter
-from app.api.v1.endpoints import auth, quiz
+from app.api.v1.endpoints import auth, quiz, notes

 api_router = APIRouter()

@@ -17,3 +17,9 @@ api_router.include_router(
     prefix="/quiz",
     tags=["quiz"]
 )
+
+api_router.include_router(
+    notes.router,
+    prefix="/notes",
+    tags=["notes"]
+)
Backend/app/api/v1/endpoints/notes.py
CHANGED

@@ -1,23 +1,26 @@
 from fastapi import APIRouter, Depends, HTTPException, status
-from
+from sqlalchemy.ext.asyncio import AsyncSession
 from app.models import User
-from app.api.deps import get_db, get_current_user
-from app.schema import
+from app.api.deps import get_db, get_current_user
+from app.schema import ChatMessage, AI_chat_input
 from .prompts import SYSTEM_PROMPT
-from
-from chromadb.api.models.Collection import Collection # Import Collection type
-from app.api.deps import get_chroma_collection
-from app.llm import call_llm
+from app.llm import call_llm, stream_chat
 import uuid
+from fastapi.responses import StreamingResponse



 router = APIRouter(prefix="/notes")

-@router.post("/
+@router.post("/stram_chat", response_class=StreamingResponse)
 async def ai_chat(
-    Input_model:
-
+    Input_model: AI_chat_input,
+    # db: AsyncSession = Depends(get_db),
     current_user: User = Depends(get_current_user)
 ):
-
+    messages_dict = [msg.model_dump() for msg in Input_model.messages]
+
+    return StreamingResponse(
+        stream_chat(messages_dict, Input_model.context),
+        media_type="text/plain"
+    )
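For reference, a minimal sketch of how a client could consume the new streaming route. Nothing below is part of the commit: the base URL, bearer token, and payload values are illustrative, and the final path depends on how the v1 router is mounted (both the include_router() call and the APIRouter in notes.py add a "/notes" prefix, and the route itself is registered as "stram_chat").

# Hypothetical client-side check of the streaming endpoint (illustrative values only).
import asyncio
import httpx

async def main() -> None:
    payload = {
        "messages": [{"role": "user", "content": "Summarize these notes."}],
        "context": "Photosynthesis converts light energy into chemical energy.",
        "session_id": None,
    }
    headers = {"Authorization": "Bearer <access-token>"}  # placeholder token
    # Assumed mount point; both prefixes are "/notes", so the segment may repeat.
    url = "http://localhost:8000/api/v1/notes/notes/stram_chat"
    async with httpx.AsyncClient() as client:
        async with client.stream("POST", url, json=payload, headers=headers) as response:
            response.raise_for_status()
            async for chunk in response.aiter_text():  # plain-text chunks from StreamingResponse
                print(chunk, end="", flush=True)

asyncio.run(main())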
Backend/app/llm.py
CHANGED

@@ -4,8 +4,9 @@ from pydantic import BaseModel, Field
 from typing import List, Optional, Any
 from app.schema.models import QuizOutput, QuizQuestion
 from app.config import settings
+from openai import AsyncOpenAI

-client =
+client = AsyncOpenAI(
     base_url="https://api.groq.com/openai/v1",
     api_key=settings.GROQ_API_KEY
 )

@@ -32,4 +33,51 @@ async def call_llm(prompt:str):

     except Exception as e:
         print(f"Error calling LiteLLM/Gemini: {e}")
-        raise e
+        raise e
+
+
+
+async def stream_chat(messages:List[dict], context:str):
+    system_instruction = {
+        "role": "system",
+        "content": "You are a helpful AI assistant. Answer the user's question strictly based on the provided context."
+    }
+
+    conversation_history = [msg.copy() for msg in messages]
+
+    if conversation_history and conversation_history[-1]['role'] == 'user':
+        last_user_msg = conversation_history[-1]
+        # Rewrite the content to: Context + \n\n + Question
+        last_user_msg['content'] = (
+            f"Here is the context/notes you must use:\n"
+            f"---------------------\n"
+            f"{context}\n"
+            f"---------------------\n\n"
+            f"User Question: {last_user_msg['content']}"
+        )
+    else:
+        # Fallback: If for some reason there is no user message, add one.
+        conversation_history.append({
+            "role": "user",
+            "content": f"Context:\n{context}\n\nPlease analyze this."
+        })
+
+    # 3. Combine System + Modified User History
+    full_history = [system_instruction] + conversation_history
+
+    try:
+        # Ensure you are using the async_client initialized earlier
+        stream = await client.chat.completions.create(
+            model="openai/gpt-oss-20b", # Recommended for speed/quality on Groq
+            messages=full_history,
+            temperature=0.7,
+            stream=True
+        )
+
+        async for chunk in stream:
+            if chunk.choices[0].delta.content:
+                yield chunk.choices[0].delta.content
+
+    except Exception as e:
+        print(f"Error in chat stream: {e}")
+        yield f"Error: {str(e)}"
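The new stream_chat generator can also be exercised outside FastAPI by driving it with asyncio directly. This is a quick local sketch, not part of the commit; the messages and context values are made up, and a valid GROQ_API_KEY must be configured in settings.

# Hypothetical standalone check of stream_chat (illustrative values only).
import asyncio
from app.llm import stream_chat

async def main() -> None:
    messages = [{"role": "user", "content": "What do my notes say about stacks?"}]
    context = "A stack is a LIFO structure; push and pop both run in O(1)."  # made-up note text
    async for token in stream_chat(messages, context):
        print(token, end="", flush=True)

asyncio.run(main())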
Backend/app/schema/__init__.py
CHANGED

@@ -1,3 +1,3 @@
-from app.schema.models import UserCreate, Token, LoginRequest, Quiz_input, QuizOutput, IngestRequest
+from app.schema.models import UserCreate, Token, LoginRequest, Quiz_input, QuizOutput, IngestRequest, ChatMessage, AI_chat_input

-__all__ = ["UserCreate", "Token", "LoginRequest", "Quiz_input", "QuizOutput", "IngestRequest"]
+__all__ = ["UserCreate", "Token", "LoginRequest", "Quiz_input", "QuizOutput", "IngestRequest", "ChatMessage", "AI_chat_input"]
Backend/app/schema/models.py
CHANGED

@@ -53,15 +53,12 @@ class IngestRequest(BaseModel):
 # #--------Notes models--------#

 class ChatMessage(BaseModel):
-    role: Literal[
-    content: str=Field(..., min_length=1, description="
+    role: Literal["user", "assistant", "system"] = Field(..., description="Role of the message sender")
+    content: str = Field(..., min_length=1, description="Message content")

 class AI_chat_input(BaseModel):
-    messages
-
-        min_length=1,
-        description="The complete conversation history (list of messages) to send to the LLM."
-    )
+    messages: List[ChatMessage] = Field(..., description="Conversation history")
+    context: str = Field(..., description="The content of the note/document to chat about")
     session_id: str | None = Field(
         None, description="The unique ID of the current chat session (optional)."
     )
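As a sanity check on the updated schema, a request body can be validated directly against AI_chat_input; the snippet below is not part of the commit and its values are illustrative only.

# Hypothetical payload that should validate against AI_chat_input (illustrative values).
from app.schema import ChatMessage, AI_chat_input

body = AI_chat_input(
    messages=[ChatMessage(role="user", content="Explain binary search from my notes.")],
    context="Binary search halves the search interval on every comparison.",
    session_id=None,
)
print(body.model_dump_json(indent=2))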