Commit ·
fa152ae
0
Parent(s):
init
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitignore +23 -0
- cbh/__init__.py +94 -0
- cbh/api/account/__init__.py +14 -0
- cbh/api/account/dto.py +30 -0
- cbh/api/account/models.py +71 -0
- cbh/api/account/utils.py +0 -0
- cbh/api/account/views.py +20 -0
- cbh/api/ari/__init__.py +5 -0
- cbh/api/ari/db_requests.py +41 -0
- cbh/api/ari/dto.py +7 -0
- cbh/api/ari/schemas.py +15 -0
- cbh/api/ari/services/__init__.py +3 -0
- cbh/api/ari/services/agent/__init__.py +8 -0
- cbh/api/ari/services/agent/agent.py +113 -0
- cbh/api/ari/services/agent/handler.py +37 -0
- cbh/api/ari/services/agent/prompt.py +68 -0
- cbh/api/ari/services/agent/tools.py +69 -0
- cbh/api/ari/services/workflows.py +21 -0
- cbh/api/ari/utils.py +207 -0
- cbh/api/ari/views.py +99 -0
- cbh/api/chats/__init__.py +5 -0
- cbh/api/chats/db_requests.py +45 -0
- cbh/api/chats/models.py +12 -0
- cbh/api/chats/schemas.py +13 -0
- cbh/api/chats/services/__init__.py +3 -0
- cbh/api/chats/services/prompts.py +57 -0
- cbh/api/chats/services/workflows.py +23 -0
- cbh/api/chats/views.py +49 -0
- cbh/api/common/db_requests.py +196 -0
- cbh/api/common/dto.py +149 -0
- cbh/api/common/schemas.py +121 -0
- cbh/api/common/utils.py +244 -0
- cbh/api/messages/__init__.py +5 -0
- cbh/api/messages/db_requests.py +23 -0
- cbh/api/messages/dto.py +0 -0
- cbh/api/messages/models.py +13 -0
- cbh/api/messages/utils.py +0 -0
- cbh/api/messages/views.py +28 -0
- cbh/api/platforms/__init__.py +5 -0
- cbh/api/platforms/db_requests.py +0 -0
- cbh/api/platforms/dto.py +44 -0
- cbh/api/platforms/models.py +23 -0
- cbh/api/platforms/utils.py +0 -0
- cbh/api/platforms/views.py +0 -0
- cbh/api/security/__init__.py +11 -0
- cbh/api/security/db_requests.py +186 -0
- cbh/api/security/dto.py +48 -0
- cbh/api/security/models.py +15 -0
- cbh/api/security/schemas.py +48 -0
- cbh/api/security/services/__init__.py +11 -0
.gitignore
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
env/
|
| 3 |
+
venv/
|
| 4 |
+
.venv/
|
| 5 |
+
.idea/
|
| 6 |
+
*.log
|
| 7 |
+
*.egg-info/
|
| 8 |
+
pip-wheel-EntityData/
|
| 9 |
+
.env
|
| 10 |
+
.DS_Store
|
| 11 |
+
static/
|
| 12 |
+
test.py
|
| 13 |
+
rsa_key.p8
|
| 14 |
+
rsa_key.pub
|
| 15 |
+
aws.pem
|
| 16 |
+
.vscode/
|
| 17 |
+
data/
|
| 18 |
+
*.csv
|
| 19 |
+
test.json
|
| 20 |
+
voiceagentcbh.pem
|
| 21 |
+
*.pem
|
| 22 |
+
download
|
| 23 |
+
investigation
|
cbh/__init__.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pylint: disable=C0415
|
| 2 |
+
"""
|
| 3 |
+
ClipboardHealthAI application package.
|
| 4 |
+
"""
|
| 5 |
+
import asyncio
|
| 6 |
+
|
| 7 |
+
from fastapi import FastAPI
|
| 8 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 9 |
+
from starlette.exceptions import HTTPException as StarletteHTTPException
|
| 10 |
+
|
| 11 |
+
from cbh.core.wrappers import CbhResponseWrapper, ErrorCbhResponse
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def create_app() -> FastAPI:
|
| 15 |
+
"""
|
| 16 |
+
Create and configure the FastAPI application.
|
| 17 |
+
"""
|
| 18 |
+
app = FastAPI(docs_url="/api/docs", openapi_url="/api/openapi.json")
|
| 19 |
+
|
| 20 |
+
from cbh.api.account import account_router
|
| 21 |
+
|
| 22 |
+
app.include_router(account_router, tags=["account"])
|
| 23 |
+
|
| 24 |
+
from cbh.api.calls import calls_router
|
| 25 |
+
|
| 26 |
+
app.include_router(calls_router, tags=["calls"])
|
| 27 |
+
|
| 28 |
+
from cbh.api.reports import reports_router
|
| 29 |
+
|
| 30 |
+
app.include_router(reports_router, tags=["reports"])
|
| 31 |
+
|
| 32 |
+
from cbh.api.reps import reps_router
|
| 33 |
+
|
| 34 |
+
app.include_router(reps_router, tags=["reps"])
|
| 35 |
+
|
| 36 |
+
from cbh.api.security import security_router
|
| 37 |
+
|
| 38 |
+
app.include_router(security_router, tags=["security"])
|
| 39 |
+
|
| 40 |
+
from cbh.api.userinsights import userinsights_router
|
| 41 |
+
|
| 42 |
+
app.include_router(userinsights_router, tags=["userinsights"])
|
| 43 |
+
|
| 44 |
+
app.add_middleware(
|
| 45 |
+
CORSMiddleware,
|
| 46 |
+
allow_origin_regex=r"https?://([a-z0-9-]+\.)?(localhost|trainwitharena|cbhexp\.com)(:\d+)?",
|
| 47 |
+
allow_credentials=True,
|
| 48 |
+
allow_methods=["*"],
|
| 49 |
+
allow_headers=["*"],
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
@app.exception_handler(StarletteHTTPException)
|
| 53 |
+
async def http_exception_handler(_, exc):
|
| 54 |
+
"""
|
| 55 |
+
Handle HTTP exceptions and convert them to standardized error responses.
|
| 56 |
+
"""
|
| 57 |
+
return CbhResponseWrapper(
|
| 58 |
+
data=None, successful=False, error=ErrorCbhResponse(message=str(exc.detail))
|
| 59 |
+
).response(exc.status_code)
|
| 60 |
+
|
| 61 |
+
@app.on_event("startup")
|
| 62 |
+
async def startup_event():
|
| 63 |
+
"""
|
| 64 |
+
Execute startup tasks when the application starts.
|
| 65 |
+
"""
|
| 66 |
+
from cbh.api.calls.services import run_call_listener
|
| 67 |
+
#
|
| 68 |
+
asyncio.create_task(run_call_listener())
|
| 69 |
+
|
| 70 |
+
@app.get("/api/health")
|
| 71 |
+
async def health():
|
| 72 |
+
"""
|
| 73 |
+
Health check endpoint for container orchestration and monitoring.
|
| 74 |
+
"""
|
| 75 |
+
try:
|
| 76 |
+
return {"status": "healthy", "database": "connected"}
|
| 77 |
+
except Exception as e:
|
| 78 |
+
return {"status": "unhealthy", "database": "disconnected", "error": str(e)}
|
| 79 |
+
|
| 80 |
+
@app.get("/health")
|
| 81 |
+
async def root():
|
| 82 |
+
"""
|
| 83 |
+
Root endpoint for the application.
|
| 84 |
+
"""
|
| 85 |
+
return {"message": "hello!"}
|
| 86 |
+
|
| 87 |
+
@app.get("/api/test")
|
| 88 |
+
async def root():
|
| 89 |
+
"""
|
| 90 |
+
Root endpoint for the application.
|
| 91 |
+
"""
|
| 92 |
+
return {"message": "hi hello!"}
|
| 93 |
+
|
| 94 |
+
return app
|
cbh/api/account/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Account module initialization.
|
| 3 |
+
|
| 4 |
+
This module defines the FastAPI router for account API endpoints
|
| 5 |
+
and imports related views for account management.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from fastapi import APIRouter
|
| 9 |
+
|
| 10 |
+
account_router = APIRouter(
|
| 11 |
+
prefix="/api/account",
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
from . import views # noqa # pylint: disable=C0413
|
cbh/api/account/dto.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Account DTOs.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from enum import Enum
|
| 6 |
+
|
| 7 |
+
from pydantic import BaseModel
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class AccountType(Enum):
|
| 11 |
+
"""
|
| 12 |
+
Enum for account types.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
USER = 1
|
| 16 |
+
ADMIN = 2
|
| 17 |
+
OWNER = 3
|
| 18 |
+
SUPER_ADMIN = 4
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class RegistrationType(Enum):
|
| 22 |
+
"""
|
| 23 |
+
Enum for registration types.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
ORGANIC = 1
|
| 27 |
+
GOOGLE = 2
|
| 28 |
+
GITHUB = 3
|
| 29 |
+
APPLE = 4
|
| 30 |
+
|
cbh/api/account/models.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Account models.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
|
| 7 |
+
from passlib.context import CryptContext
|
| 8 |
+
from pydantic import Field, field_validator
|
| 9 |
+
|
| 10 |
+
from cbh.api.account.dto import AccountType, RegistrationType
|
| 11 |
+
from cbh.core.database import MongoBaseModel, MongoBaseShortenModel
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class AccountModel(MongoBaseModel):
|
| 15 |
+
"""
|
| 16 |
+
Account model class.
|
| 17 |
+
|
| 18 |
+
This class represents a user account in the system.
|
| 19 |
+
It includes fields for email, password, and timestamps for creation and update.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
name: str | None = None
|
| 23 |
+
email: str
|
| 24 |
+
password: str | None = Field(exclude=True, default=None)
|
| 25 |
+
|
| 26 |
+
accountType: AccountType = Field(default=AccountType.USER)
|
| 27 |
+
registrationType: RegistrationType | None = Field(
|
| 28 |
+
default=RegistrationType.ORGANIC, exclude=True
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
datetimeInserted: datetime = Field(default_factory=datetime.now)
|
| 32 |
+
datetimeUpdated: datetime = Field(default_factory=datetime.now)
|
| 33 |
+
|
| 34 |
+
@field_validator("password", mode="before", check_fields=False)
|
| 35 |
+
@classmethod
|
| 36 |
+
def set_password_hash(cls, v: str | None) -> str | None:
|
| 37 |
+
"""
|
| 38 |
+
Set the password hash.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
v (str): The password to hash.
|
| 42 |
+
|
| 43 |
+
Returns:
|
| 44 |
+
str: The hashed password.
|
| 45 |
+
"""
|
| 46 |
+
if isinstance(v, str) and not v.startswith("$2b$"):
|
| 47 |
+
return CryptContext(schemes=["bcrypt"], deprecated="auto").hash(v)
|
| 48 |
+
return v
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class Config: # pylint: disable=R0903
|
| 52 |
+
"""
|
| 53 |
+
Config for the AccountModel class.
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
arbitrary_types_allowed = True
|
| 57 |
+
populate_by_name = True
|
| 58 |
+
json_encoders = {datetime: lambda dt: dt.isoformat()}
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class AccountShorten(MongoBaseShortenModel):
|
| 62 |
+
"""
|
| 63 |
+
Account shorten model.
|
| 64 |
+
"""
|
| 65 |
+
|
| 66 |
+
id: str
|
| 67 |
+
name: str | None = None
|
| 68 |
+
email: str
|
| 69 |
+
pictureUrl: str | None = None
|
| 70 |
+
|
| 71 |
+
accountType: AccountType
|
cbh/api/account/utils.py
ADDED
|
File without changes
|
cbh/api/account/views.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Account views module.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from fastapi import Depends
|
| 6 |
+
|
| 7 |
+
from cbh.api.account import account_router
|
| 8 |
+
from cbh.api.account.models import AccountModel
|
| 9 |
+
from cbh.core.security import PermissionDependency
|
| 10 |
+
from cbh.core.wrappers import CbhResponseWrapper
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@account_router.get("")
|
| 14 |
+
async def get_own_account(
|
| 15 |
+
account: AccountModel = Depends(PermissionDependency()),
|
| 16 |
+
) -> CbhResponseWrapper[AccountModel]:
|
| 17 |
+
"""
|
| 18 |
+
Get own account.
|
| 19 |
+
"""
|
| 20 |
+
return CbhResponseWrapper(data=account)
|
cbh/api/ari/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter
|
| 2 |
+
|
| 3 |
+
ari_router = APIRouter(prefix="/api/ari", tags=["ari"])
|
| 4 |
+
|
| 5 |
+
from . import views
|
cbh/api/ari/db_requests.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import datetime
|
| 2 |
+
from pymongo import ReturnDocument
|
| 3 |
+
from cbh.api.ari.dto import Author
|
| 4 |
+
from cbh.api.ari.models import ChatModel, MessageModel
|
| 5 |
+
from cbh.api.account.models import AccountModel, AccountShorten
|
| 6 |
+
from cbh.api.ari.schemas import ChatFilter, CreateMessageRequest
|
| 7 |
+
from cbh.api.common.db_requests import get_all_objs
|
| 8 |
+
from cbh.api.common.schemas import FilterRequest
|
| 9 |
+
from cbh.core.config import settings
|
| 10 |
+
from cbh.core.wrappers import background_task
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
async def truncate_message_history(message_id: str, ids_to_delete: list[str]) -> None:
|
| 14 |
+
await settings.DB_CLIENT.messages.delete_many({"id": {"$in": [message_id, *ids_to_delete]}})
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@background_task()
|
| 18 |
+
async def add_messages_obj(
|
| 19 |
+
chat_id: str,
|
| 20 |
+
account_id: str,
|
| 21 |
+
request: CreateMessageRequest,
|
| 22 |
+
message_args: tuple[str, str, datetime, datetime],
|
| 23 |
+
) -> None:
|
| 24 |
+
user_message = MessageModel(
|
| 25 |
+
id=request.messageId,
|
| 26 |
+
chatId=chat_id,
|
| 27 |
+
accountId=account_id,
|
| 28 |
+
role=Author.HUMAN,
|
| 29 |
+
content=request.content,
|
| 30 |
+
datetimeInserted=message_args[2],
|
| 31 |
+
)
|
| 32 |
+
assistant_message = MessageModel(
|
| 33 |
+
id=message_args[1],
|
| 34 |
+
chatId=chat_id,
|
| 35 |
+
accountId=account_id,
|
| 36 |
+
role=Author.AI,
|
| 37 |
+
content=message_args[0],
|
| 38 |
+
datetimeInserted=message_args[3],
|
| 39 |
+
)
|
| 40 |
+
await settings.DB_CLIENT.messages.insert_one(user_message.to_mongo())
|
| 41 |
+
await settings.DB_CLIENT.messages.insert_one(assistant_message.to_mongo())
|
cbh/api/ari/dto.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import Enum
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class Author(str, Enum):
|
| 5 |
+
HUMAN = 'human'
|
| 6 |
+
AI = 'ai'
|
| 7 |
+
|
cbh/api/ari/schemas.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
from pydantic import BaseModel
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class CreateMessageRequest(BaseModel):
|
| 6 |
+
content: str
|
| 7 |
+
messageId: Optional[str] = None
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class TranscribeResponse(BaseModel):
|
| 11 |
+
"""
|
| 12 |
+
Response schema for the transcribe endpoint.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
text: str
|
cbh/api/ari/services/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .workflows import convert_audio_to_text
|
| 2 |
+
|
| 3 |
+
__all__ = ["convert_audio_to_text"]
|
cbh/api/ari/services/agent/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This module contains the Ari Agent.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from .agent import AriAgent
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
__all__ = ["AriAgent"]
|
cbh/api/ari/services/agent/agent.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pylint: disable=R0801
|
| 2 |
+
"""
|
| 3 |
+
This module contains the CBH Agent.
|
| 4 |
+
"""
|
| 5 |
+
import asyncio
|
| 6 |
+
|
| 7 |
+
from langchain_classic.agents import AgentExecutor, create_openai_tools_agent
|
| 8 |
+
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
| 9 |
+
|
| 10 |
+
from cbh.api.account.models import AccountModel
|
| 11 |
+
from cbh.api.ari.schemas import CreateMessageRequest
|
| 12 |
+
from cbh.api.ari.services.agent.handler import StreamingAgentCallbackHandler
|
| 13 |
+
from cbh.core.config import settings
|
| 14 |
+
from .prompt import ARI_PROMPT
|
| 15 |
+
from .tools import ScenarioAgentTools
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class AriAgent:
|
| 19 |
+
"""
|
| 20 |
+
CBH Agent for handling schedule creation.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
def __init__(self, account: AccountModel):
|
| 24 |
+
"""
|
| 25 |
+
Initialize the CBHAgent with an account. The agent is stateless with
|
| 26 |
+
respect to chat history — history is provided per call to `stream`.
|
| 27 |
+
"""
|
| 28 |
+
self.tools = ScenarioAgentTools.load_tools(account)
|
| 29 |
+
self.agent_executor = self._get_agent()
|
| 30 |
+
|
| 31 |
+
def _get_agent(self):
|
| 32 |
+
"""
|
| 33 |
+
Get the agent_pd instance.
|
| 34 |
+
"""
|
| 35 |
+
return AgentExecutor(
|
| 36 |
+
agent=create_openai_tools_agent(
|
| 37 |
+
llm=self._get_agent_model(),
|
| 38 |
+
tools=self.tools,
|
| 39 |
+
prompt=self._load_system_prompt(),
|
| 40 |
+
),
|
| 41 |
+
tools=self.tools,
|
| 42 |
+
verbose=True,
|
| 43 |
+
return_intermediate_steps=True,
|
| 44 |
+
max_iterations=100,
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
def _get_agent_model(self):
|
| 48 |
+
"""
|
| 49 |
+
Get the language model used by the agent_pd.
|
| 50 |
+
"""
|
| 51 |
+
return settings.get_llm(
|
| 52 |
+
model="gpt-5.4", reasoning_effort="medium", reasoning_summary="auto"
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
def _load_system_prompt(self) -> ChatPromptTemplate:
|
| 56 |
+
"""
|
| 57 |
+
Load the system prompt from file.
|
| 58 |
+
"""
|
| 59 |
+
try:
|
| 60 |
+
return ChatPromptTemplate.from_messages(
|
| 61 |
+
[
|
| 62 |
+
("system", ARI_PROMPT),
|
| 63 |
+
MessagesPlaceholder(variable_name="chat_history"),
|
| 64 |
+
("human", "{content}"),
|
| 65 |
+
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
| 66 |
+
]
|
| 67 |
+
)
|
| 68 |
+
except Exception as e:
|
| 69 |
+
raise Exception( # pylint: disable=W0719,W0707
|
| 70 |
+
f"Failed to load system prompt: {str(e)}"
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
async def stream(
|
| 74 |
+
self,
|
| 75 |
+
message_history: list[dict],
|
| 76 |
+
request: CreateMessageRequest,
|
| 77 |
+
stop_event: asyncio.Event = None,
|
| 78 |
+
):
|
| 79 |
+
"""
|
| 80 |
+
Stream the agent's response to the client.
|
| 81 |
+
"""
|
| 82 |
+
queue = []
|
| 83 |
+
|
| 84 |
+
async def send(data):
|
| 85 |
+
queue.append(data)
|
| 86 |
+
|
| 87 |
+
handler = StreamingAgentCallbackHandler(send)
|
| 88 |
+
|
| 89 |
+
task = asyncio.create_task(
|
| 90 |
+
self.agent_executor.ainvoke(
|
| 91 |
+
{
|
| 92 |
+
"content": request.content,
|
| 93 |
+
"chat_history": message_history,
|
| 94 |
+
},
|
| 95 |
+
config={"callbacks": [handler]},
|
| 96 |
+
)
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
while not task.done() or queue:
|
| 100 |
+
if stop_event and stop_event.is_set():
|
| 101 |
+
task.cancel()
|
| 102 |
+
try:
|
| 103 |
+
await task
|
| 104 |
+
except asyncio.CancelledError:
|
| 105 |
+
pass
|
| 106 |
+
break
|
| 107 |
+
while queue:
|
| 108 |
+
data = queue.pop(0)
|
| 109 |
+
yield data
|
| 110 |
+
await asyncio.sleep(0.01)
|
| 111 |
+
|
| 112 |
+
if not task.cancelled():
|
| 113 |
+
await task
|
cbh/api/ari/services/agent/handler.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pydash
|
| 2 |
+
from langchain_core.callbacks import AsyncCallbackHandler
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class StreamingAgentCallbackHandler(AsyncCallbackHandler): # pylint: disable=R0901
|
| 6 |
+
|
| 7 |
+
def __init__(self, send):
|
| 8 |
+
self.send = send
|
| 9 |
+
|
| 10 |
+
async def on_chat_model_start(self, *args, **kwargs):
|
| 11 |
+
await self.send({"type": "init", "content": ""})
|
| 12 |
+
|
| 13 |
+
async def on_tool_start(self, serialized, input_str, **kwargs):
|
| 14 |
+
pass
|
| 15 |
+
|
| 16 |
+
async def on_tool_end(self, output, **kwargs):
|
| 17 |
+
await self.send({"type": "tool_response", "content": output})
|
| 18 |
+
|
| 19 |
+
async def on_llm_new_token(self, token: str, **kwargs):
|
| 20 |
+
chunk = kwargs.get("chunk")
|
| 21 |
+
if chunk:
|
| 22 |
+
message = pydash.get(chunk.message.content, "[0]") or {}
|
| 23 |
+
if message.get("type") == "reasoning":
|
| 24 |
+
content = pydash.get(message, "summary[0].text")
|
| 25 |
+
if content:
|
| 26 |
+
await self.send({"type": "thinking", "content": content})
|
| 27 |
+
elif message.get("type") == "text":
|
| 28 |
+
await self.send({"type": "ai_token", "content": message.get("text")})
|
| 29 |
+
|
| 30 |
+
async def on_chain_end(self, outputs, **kwargs):
|
| 31 |
+
if (
|
| 32 |
+
isinstance(outputs, dict)
|
| 33 |
+
and "intermediate_steps" in outputs.keys()
|
| 34 |
+
and "output" in outputs.keys()
|
| 35 |
+
):
|
| 36 |
+
ai_message = ""
|
| 37 |
+
await self.send({"type": "ai", "content": ai_message})
|
cbh/api/ari/services/agent/prompt.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pylint: disable-all
|
| 2 |
+
# flake8: noqa
|
| 3 |
+
"""
|
| 4 |
+
This module contains the prompt for the CBH agent.
|
| 5 |
+
"""
|
| 6 |
+
ARI_PROMPT = """You are **Ari**, a friendly analytics assistant. You help managers and admins understand their organization's sales training performance by answering questions using live data.
|
| 7 |
+
|
| 8 |
+
## Your Capabilities
|
| 9 |
+
You can retrieve the following data by calling the appropriate tool:
|
| 10 |
+
- **Accounts**: search users by name, role, or status
|
| 11 |
+
- **Dashboard overview**: session counts, average scores, engagement, and trends
|
| 12 |
+
- **Scenario performance**: per-scenario averages, top/bottom reps, most common mistakes
|
| 13 |
+
- **Skills averages**: org-wide or per-user breakdown across 6 skill dimensions (communication, active listening, conversation, objection handling, empathy, overall)
|
| 14 |
+
- **Leaderboard**: ranked list of reps by score, filterable by scenario and date
|
| 15 |
+
- **Attention needs**: reps who are inactive, scenarios with low scores, or scenarios never attempted
|
| 16 |
+
- **Scenario details**: statistics, skills breakdown, and leaderboard for a specific scenario
|
| 17 |
+
- **Team details**: statistics, skills breakdown, and leaderboard for a specific team
|
| 18 |
+
- **Insights**: top mistakes and achievements for a user, team, scenario, or the entire organization
|
| 19 |
+
|
| 20 |
+
## Rules
|
| 21 |
+
1. **Always call a tool before answering when user asks for data** — never guess or fabricate data. Every data point in your response must come from a tool result.
|
| 22 |
+
2. **Resolve names to IDs first** — if the user references a person, team, or scenario by name and you need an ID, call the appropriate search tool first (`search_account`, `search_teams`, `search_scenarios`), then call the target tool with the resolved ID.
|
| 23 |
+
3. **Use only the data returned** — summarize and present what the tool gives you. Do not add assumptions, projections, or context that was not returned.
|
| 24 |
+
4. **Stay on topic** — only answer questions about the organization's training performance. Politely decline anything unrelated.
|
| 25 |
+
5. **Never expose internals** — do not mention tool names, internal errors, raw JSON, or how you work under the hood.
|
| 26 |
+
6. **Handle empty results gracefully** — if a tool returns no data, tell the user clearly and suggest a next step (e.g. broaden the date range or check the spelling).
|
| 27 |
+
|
| 28 |
+
## Response Format
|
| 29 |
+
- Be friendly, concise, and conversational
|
| 30 |
+
- Keep replies short — 2–4 sentences or a brief list; avoid walls of text
|
| 31 |
+
- Present numbers clearly (e.g. "Average score: **78/100**")
|
| 32 |
+
- When showing multiple items (e.g. leaderboard, insights), use a short bullet list
|
| 33 |
+
- Do NOT dump raw data — always interpret and summarize it for the user
|
| 34 |
+
|
| 35 |
+
## Examples
|
| 36 |
+
|
| 37 |
+
**User:** How is the team doing overall?
|
| 38 |
+
**Ari:** *[calls retrieve_admin_intro_statistics]*
|
| 39 |
+
The team completed **142 sessions** this month (+12% vs last month), with an average score of **74/100**. **38 out of 45 members** were active. Want a breakdown by scenario or skill?
|
| 40 |
+
|
| 41 |
+
---
|
| 42 |
+
|
| 43 |
+
**User:** Who are the top performers?
|
| 44 |
+
**Ari:** *[calls retrieve_leaderboard with page_size=5]*
|
| 45 |
+
Here are your top 5 reps this month:
|
| 46 |
+
1. **Sarah M.** — 91
|
| 47 |
+
2. **James T.** — 88
|
| 48 |
+
3. **Priya K.** — 85
|
| 49 |
+
...
|
| 50 |
+
Want to see how they perform on a specific scenario?
|
| 51 |
+
|
| 52 |
+
---
|
| 53 |
+
|
| 54 |
+
**User:** What mistakes is John making?
|
| 55 |
+
**Ari:** *[calls search_account with search_term="John"]*
|
| 56 |
+
*[calls get_top_account_insights with account_id=..., type_="mistake"]*
|
| 57 |
+
John's most common mistake is **rushing the discovery phase** — he tends to jump to the pitch before fully understanding the customer's situation. Want tips on how to coach this?
|
| 58 |
+
|
| 59 |
+
---
|
| 60 |
+
|
| 61 |
+
**User:** Which scenarios need attention?
|
| 62 |
+
**Ari:** *[calls retrieve_attention_needs]*
|
| 63 |
+
Three scenarios haven't been attempted by anyone yet, and **"Cold Call Objection Drill"** has an average score below the threshold. I'd recommend assigning it and scheduling a coaching session.
|
| 64 |
+
|
| 65 |
+
---
|
| 66 |
+
|
| 67 |
+
**User:** Write me a poem.
|
| 68 |
+
**Ari:** I can only help with questions about your team's training performance. Ask me about scores, leaderboards, or coaching insights!"""
|
cbh/api/ari/services/agent/tools.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This module contains the tools for the CBH Agent.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
# pylint: disable=R0801,C0301
|
| 6 |
+
# flake8: noqa
|
| 7 |
+
from typing import Literal
|
| 8 |
+
from langchain_core.tools import StructuredTool
|
| 9 |
+
|
| 10 |
+
from cbh.api.account.models import AccountModel
|
| 11 |
+
from cbh.api.common.dto import OrderType
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
GET_TEAM_LEADERBOARD_DESCRIPTION = """Retrieve a ranked leaderboard of users within a specific team.
|
| 15 |
+
|
| 16 |
+
Use this tool when the user wants to compare individual performance within a team —
|
| 17 |
+
for example, to identify the best and worst performers or run intra-team competitions.
|
| 18 |
+
Call search_teams first to obtain the team_id if needed.
|
| 19 |
+
|
| 20 |
+
Parameters:
|
| 21 |
+
- team_id: The unique ID of the team (required).
|
| 22 |
+
- order: Sort direction. Accepted values:
|
| 23 |
+
"desc" — highest scores first (top performers at the top).
|
| 24 |
+
"asc" — lowest scores first (use to surface struggling members).
|
| 25 |
+
- page_size: Number of leaderboard positions to return (default 1). Increase to see more users.
|
| 26 |
+
|
| 27 |
+
Returns a ranked list of positions, each containing the user (id, name, picture)
|
| 28 |
+
and their Scores across all 6 skill dimensions."""
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class ScenarioAgentTools:
|
| 32 |
+
"""
|
| 33 |
+
Tools for the CBH Agent.
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
@staticmethod
|
| 37 |
+
def create_team_leaderboard(account: AccountModel):
|
| 38 |
+
async def get_team_leaderboard(
|
| 39 |
+
team_id: str,
|
| 40 |
+
order: Literal["asc", "desc"],
|
| 41 |
+
page_size: int = 1,
|
| 42 |
+
):
|
| 43 |
+
sort_order = OrderType.DESCENDING.value
|
| 44 |
+
if order == "asc":
|
| 45 |
+
sort_order = OrderType.ASCENDING.value
|
| 46 |
+
results = await get_team_leaderboard_obj(
|
| 47 |
+
team_id=team_id,
|
| 48 |
+
account=account,
|
| 49 |
+
order=sort_order,
|
| 50 |
+
page_size=page_size,
|
| 51 |
+
page_index=0,
|
| 52 |
+
)
|
| 53 |
+
return {
|
| 54 |
+
"tool": "get_team_leaderboard",
|
| 55 |
+
"value": results.model_dump(mode="json"),
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
return get_team_leaderboard
|
| 59 |
+
|
| 60 |
+
@staticmethod
|
| 61 |
+
def load_tools(account: AccountModel) -> list[StructuredTool]:
|
| 62 |
+
return [
|
| 63 |
+
|
| 64 |
+
StructuredTool.from_function(
|
| 65 |
+
name="get_team_leaderboard",
|
| 66 |
+
description=GET_TEAM_LEADERBOARD_DESCRIPTION,
|
| 67 |
+
coroutine=ScenarioAgentTools.create_team_leaderboard(account),
|
| 68 |
+
),
|
| 69 |
+
]
|
cbh/api/ari/services/workflows.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import io
|
| 2 |
+
|
| 3 |
+
from fastapi import HTTPException
|
| 4 |
+
from cbh.core.config import settings
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
async def convert_audio_to_text(file: bytes, name: str) -> str:
|
| 8 |
+
"""
|
| 9 |
+
Convert an audio file to text using OpenAI's Whisper model.
|
| 10 |
+
"""
|
| 11 |
+
file_content = io.BytesIO(file)
|
| 12 |
+
file_content.name = name
|
| 13 |
+
|
| 14 |
+
transcription = await settings.OPENAI_CLIENT.audio.transcriptions.create(
|
| 15 |
+
file=file_content, model="whisper-1", language="en"
|
| 16 |
+
)
|
| 17 |
+
if isinstance(transcription, str):
|
| 18 |
+
return transcription
|
| 19 |
+
if transcription.text and isinstance(transcription.text, str):
|
| 20 |
+
return str(transcription.text)
|
| 21 |
+
raise HTTPException(status_code=500, detail="Failed to convert audio to text")
|
cbh/api/ari/utils.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import io
|
| 3 |
+
import statistics
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
from typing import AsyncGenerator, Callable
|
| 6 |
+
from fastapi import WebSocket, WebSocketDisconnect, HTTPException, UploadFile
|
| 7 |
+
from cbh.api.ari.schemas import CreateMessageRequest
|
| 8 |
+
from cbh.api.ari.db_requests import add_messages_obj, truncate_message_history
|
| 9 |
+
from cbh.api.ari.models import MessageModel
|
| 10 |
+
from pydub import AudioSegment
|
| 11 |
+
|
| 12 |
+
async def send_exception_stream(websocket: WebSocket) -> None:
|
| 13 |
+
"""Send an exception stream to the client."""
|
| 14 |
+
message = "An error occurred while processing your request. Please try again."
|
| 15 |
+
await websocket.send_json({"type": "init", "content": ""})
|
| 16 |
+
for c in message:
|
| 17 |
+
await websocket.send_json({"type": "ai_token", "content": c})
|
| 18 |
+
await asyncio.sleep(0.01)
|
| 19 |
+
await websocket.send_json({"type": "ai", "content": None})
|
| 20 |
+
await websocket.send_json({"type": "finish", "content": None, "messageId": None})
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
async def handle_websocket_streaming(
|
| 24 |
+
websocket: WebSocket,
|
| 25 |
+
pipeline_executor: Callable[[asyncio.Event], AsyncGenerator[dict, None]],
|
| 26 |
+
message_id: str | None = None,
|
| 27 |
+
ai_message_id: str | None = None,
|
| 28 |
+
) -> str:
|
| 29 |
+
"""
|
| 30 |
+
Handle WebSocket streaming with stop signal support.
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
websocket: WebSocket connection
|
| 34 |
+
pipeline_executor: Async generator function that yields chunks and accepts stop_event
|
| 35 |
+
"""
|
| 36 |
+
stop_event = asyncio.Event()
|
| 37 |
+
partial_text = ""
|
| 38 |
+
|
| 39 |
+
recv_task = asyncio.create_task(create_stop_receiver(websocket, stop_event))
|
| 40 |
+
|
| 41 |
+
try:
|
| 42 |
+
async for chunk in pipeline_executor(stop_event):
|
| 43 |
+
await websocket.send_json(chunk)
|
| 44 |
+
|
| 45 |
+
extracted_text = extract_partial_text(chunk)
|
| 46 |
+
if extracted_text:
|
| 47 |
+
partial_text += extracted_text
|
| 48 |
+
except asyncio.CancelledError:
|
| 49 |
+
stop_event.set()
|
| 50 |
+
raise
|
| 51 |
+
except RuntimeError:
|
| 52 |
+
pass
|
| 53 |
+
finally:
|
| 54 |
+
recv_task.cancel()
|
| 55 |
+
await asyncio.gather(recv_task, return_exceptions=True)
|
| 56 |
+
|
| 57 |
+
if stop_event.is_set() and partial_text.strip():
|
| 58 |
+
print(partial_text)
|
| 59 |
+
await websocket.send_json(
|
| 60 |
+
{"type": "finish", "content": None, "messageId": message_id, "aiMessageId": ai_message_id}
|
| 61 |
+
)
|
| 62 |
+
return partial_text
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
async def create_stop_receiver(websocket: WebSocket, stop_event: asyncio.Event) -> None:
|
| 66 |
+
"""
|
| 67 |
+
Create a receiver task that listens for stop signals from WebSocket.
|
| 68 |
+
"""
|
| 69 |
+
while True:
|
| 70 |
+
try:
|
| 71 |
+
msg = await websocket.receive_json()
|
| 72 |
+
if msg.get("type") == "stop":
|
| 73 |
+
stop_event.set()
|
| 74 |
+
break
|
| 75 |
+
except WebSocketDisconnect:
|
| 76 |
+
break
|
| 77 |
+
except Exception:
|
| 78 |
+
break
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def extract_partial_text(chunk: dict) -> str:
|
| 82 |
+
"""
|
| 83 |
+
Extract text content from a chunk for partial text accumulation.
|
| 84 |
+
"""
|
| 85 |
+
try:
|
| 86 |
+
if chunk.get("type") == "ai_token" and chunk.get("content"):
|
| 87 |
+
return chunk.get("content", "")
|
| 88 |
+
elif chunk.get("type") == "ai" and chunk.get("content"):
|
| 89 |
+
return chunk.get("content", "")
|
| 90 |
+
except Exception:
|
| 91 |
+
pass
|
| 92 |
+
return ""
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def prepare_messages_from_history(message_history: list[MessageModel]) -> list[dict]:
|
| 96 |
+
"""
|
| 97 |
+
Prepare messages from message history.
|
| 98 |
+
"""
|
| 99 |
+
result = []
|
| 100 |
+
for message in message_history:
|
| 101 |
+
result.append(
|
| 102 |
+
{
|
| 103 |
+
"role": message.role.value,
|
| 104 |
+
"content": message.content,
|
| 105 |
+
}
|
| 106 |
+
)
|
| 107 |
+
return result
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
async def truncate_from_message(
|
| 111 |
+
message_history: list[MessageModel], request: CreateMessageRequest
|
| 112 |
+
) -> list[dict]:
|
| 113 |
+
result = []
|
| 114 |
+
ids_to_delete = []
|
| 115 |
+
found = False
|
| 116 |
+
|
| 117 |
+
for message in message_history:
|
| 118 |
+
msg_id = message.id
|
| 119 |
+
if found or msg_id == request.messageId:
|
| 120 |
+
found = True
|
| 121 |
+
ids_to_delete.append(msg_id)
|
| 122 |
+
continue
|
| 123 |
+
result.append({"role": message.role.value, "content": message.content})
|
| 124 |
+
|
| 125 |
+
asyncio.create_task(truncate_message_history(request.messageId, ids_to_delete))
|
| 126 |
+
return result
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
async def add_messages(
|
| 130 |
+
chat_id: str,
|
| 131 |
+
account_id: str,
|
| 132 |
+
request: CreateMessageRequest,
|
| 133 |
+
messages: list[dict],
|
| 134 |
+
message_args: tuple[str, str, datetime, datetime],
|
| 135 |
+
) -> list[dict]:
|
| 136 |
+
messages.append({"role": "human", "content": request.content})
|
| 137 |
+
messages.append({"role": "ai", "content": message_args[0]})
|
| 138 |
+
asyncio.create_task(
|
| 139 |
+
add_messages_obj(chat_id, account_id, request, message_args)
|
| 140 |
+
)
|
| 141 |
+
return messages
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
async def compress_audio(audio_file: UploadFile) -> bytes | None:
|
| 147 |
+
"""
|
| 148 |
+
Compress an uploaded audio file to MP3 format.
|
| 149 |
+
"""
|
| 150 |
+
file_as_bytes = await audio_file.read()
|
| 151 |
+
|
| 152 |
+
try:
|
| 153 |
+
audio_segment = AudioSegment.from_file(io.BytesIO(file_as_bytes))
|
| 154 |
+
duration = audio_segment.duration_seconds
|
| 155 |
+
|
| 156 |
+
if duration > 300:
|
| 157 |
+
return None
|
| 158 |
+
|
| 159 |
+
mp3_data = io.BytesIO()
|
| 160 |
+
audio_segment.export(mp3_data, format="mp3")
|
| 161 |
+
mp3_data.seek(0)
|
| 162 |
+
|
| 163 |
+
except OSError:
|
| 164 |
+
return file_as_bytes
|
| 165 |
+
|
| 166 |
+
except Exception as e: # pylint: disable=W0703
|
| 167 |
+
raise HTTPException(status_code=500, detail=str(e)) # pylint: disable=W0707
|
| 168 |
+
|
| 169 |
+
return mp3_data.read()
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def detect_silence(
|
| 173 |
+
audio_data: bytes, silence_threshold: float = -40.0, min_speech_duration: float = 1.0
|
| 174 |
+
) -> bool:
|
| 175 |
+
try:
|
| 176 |
+
audio_segment = AudioSegment.from_file(io.BytesIO(audio_data))
|
| 177 |
+
chunk_length_ms = 2000
|
| 178 |
+
rms_values = []
|
| 179 |
+
for start_ms in range(0, len(audio_segment), chunk_length_ms):
|
| 180 |
+
end_ms = min(start_ms + chunk_length_ms, len(audio_segment))
|
| 181 |
+
chunk = audio_segment[start_ms:end_ms]
|
| 182 |
+
if len(chunk) > 0:
|
| 183 |
+
chunk_mono = chunk.set_channels(1)
|
| 184 |
+
rms_db = chunk_mono.dBFS
|
| 185 |
+
rms_values.append(rms_db)
|
| 186 |
+
|
| 187 |
+
if not rms_values:
|
| 188 |
+
return False
|
| 189 |
+
|
| 190 |
+
mean_rms = statistics.mean(rms_values)
|
| 191 |
+
|
| 192 |
+
non_silent_chunks = sum(1 for rms in rms_values if rms > silence_threshold)
|
| 193 |
+
total_chunks = len(rms_values)
|
| 194 |
+
non_silent_ratio = non_silent_chunks / total_chunks if total_chunks > 0 else 0
|
| 195 |
+
|
| 196 |
+
non_silent_duration = (non_silent_chunks * chunk_length_ms) / 1000.0 # Convert to seconds
|
| 197 |
+
|
| 198 |
+
has_speech = (
|
| 199 |
+
mean_rms > silence_threshold - 10
|
| 200 |
+
and non_silent_ratio > 0.1
|
| 201 |
+
and non_silent_duration >= min_speech_duration
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
return has_speech
|
| 205 |
+
|
| 206 |
+
except Exception:
|
| 207 |
+
return True
|
cbh/api/ari/views.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
from datetime import datetime
|
| 3 |
+
from fastapi import Depends, WebSocket, WebSocketDisconnect, UploadFile, File, HTTPException
|
| 4 |
+
from bson import ObjectId
|
| 5 |
+
|
| 6 |
+
from cbh.api.account.models import AccountModel
|
| 7 |
+
from cbh.api.ari import ari_router
|
| 8 |
+
from cbh.api.ari.services import convert_audio_to_text
|
| 9 |
+
from cbh.api.ari.schemas import (
|
| 10 |
+
CreateMessageRequest, TranscribeResponse,
|
| 11 |
+
)
|
| 12 |
+
from cbh.api.ari.services.agent import AriAgent
|
| 13 |
+
from cbh.api.ari.utils import (
|
| 14 |
+
prepare_messages_from_history,
|
| 15 |
+
send_exception_stream,
|
| 16 |
+
handle_websocket_streaming,
|
| 17 |
+
truncate_from_message,
|
| 18 |
+
add_messages, detect_silence, compress_audio,
|
| 19 |
+
)
|
| 20 |
+
from cbh.api.common.db_requests import get_all_objs, get_obj_by_id
|
| 21 |
+
from cbh.core.security import PermissionDependency, check_account_token
|
| 22 |
+
from cbh.core.wrappers import CbhResponseWrapper
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@ari_router.websocket("/{chatId}/send")
|
| 26 |
+
async def send_ari_message(chatId: str, websocket: WebSocket):
|
| 27 |
+
await websocket.accept()
|
| 28 |
+
token = websocket.query_params.get("token")
|
| 29 |
+
token = check_account_token(token)
|
| 30 |
+
if not token:
|
| 31 |
+
await websocket.close(code=1008)
|
| 32 |
+
return
|
| 33 |
+
|
| 34 |
+
account, (message_history, _) = await asyncio.gather(
|
| 35 |
+
get_obj_by_id(AccountModel, token["account_id"]),
|
| 36 |
+
get_all_objs(
|
| 37 |
+
MessageModel,
|
| 38 |
+
100000,
|
| 39 |
+
0,
|
| 40 |
+
additional_filter={"accountId": token["account_id"], "chatId": chatId},
|
| 41 |
+
sort=("id", 1),
|
| 42 |
+
),
|
| 43 |
+
)
|
| 44 |
+
agent = AriAgent(account)
|
| 45 |
+
|
| 46 |
+
messages = prepare_messages_from_history(message_history)
|
| 47 |
+
while True:
|
| 48 |
+
try:
|
| 49 |
+
request = await websocket.receive_json()
|
| 50 |
+
request = CreateMessageRequest(**request)
|
| 51 |
+
user_message_time = datetime.now()
|
| 52 |
+
|
| 53 |
+
if request.messageId:
|
| 54 |
+
messages = await truncate_from_message(message_history, request)
|
| 55 |
+
else:
|
| 56 |
+
request.messageId = str(ObjectId())
|
| 57 |
+
|
| 58 |
+
async def agent_executor(stop_event):
|
| 59 |
+
async for chunk in agent.stream(messages, request, stop_event):
|
| 60 |
+
yield chunk
|
| 61 |
+
await asyncio.sleep(0.01)
|
| 62 |
+
|
| 63 |
+
ai_message_id = str(ObjectId())
|
| 64 |
+
ai_response = await handle_websocket_streaming(
|
| 65 |
+
websocket=websocket,
|
| 66 |
+
pipeline_executor=agent_executor,
|
| 67 |
+
message_id=request.messageId,
|
| 68 |
+
ai_message_id=ai_message_id,
|
| 69 |
+
)
|
| 70 |
+
ai_message_time = datetime.now()
|
| 71 |
+
|
| 72 |
+
message_args = (ai_response, ai_message_id, ai_message_time, user_message_time)
|
| 73 |
+
messages = await add_messages(chatId, account.id, request, messages, message_args)
|
| 74 |
+
|
| 75 |
+
except WebSocketDisconnect:
|
| 76 |
+
return
|
| 77 |
+
except Exception as e:
|
| 78 |
+
await send_exception_stream(websocket)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
@ari_router.post("/voice/transcript")
|
| 82 |
+
async def get_voice_transcript(
|
| 83 |
+
file: UploadFile = File(...),
|
| 84 |
+
_: AccountModel = Depends(PermissionDependency()),
|
| 85 |
+
) -> CbhResponseWrapper[TranscribeResponse]:
|
| 86 |
+
"""
|
| 87 |
+
Transcribe an uploaded audio file.
|
| 88 |
+
"""
|
| 89 |
+
mp3_data = await compress_audio(file)
|
| 90 |
+
if mp3_data is None:
|
| 91 |
+
raise HTTPException(status_code=400, detail="Could not compress audio file")
|
| 92 |
+
if not detect_silence(mp3_data):
|
| 93 |
+
raise HTTPException(
|
| 94 |
+
status_code=422,
|
| 95 |
+
detail="Recording appears to contain only silence or background noise.",
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
transcribed_text = await convert_audio_to_text(mp3_data, file.filename)
|
| 99 |
+
return CbhResponseWrapper(data=TranscribeResponse(text=transcribed_text))
|
cbh/api/chats/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter
|
| 2 |
+
|
| 3 |
+
chats_router = APIRouter(prefix="/chats", tags=["chats"])
|
| 4 |
+
|
| 5 |
+
from . import views
|
cbh/api/chats/db_requests.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pymongo import ReturnDocument
|
| 2 |
+
|
| 3 |
+
from cbh.api.account.models import AccountModel, AccountShorten
|
| 4 |
+
from cbh.api.chats.models import ChatModel
|
| 5 |
+
from cbh.api.chats.schemas import ChatFilter
|
| 6 |
+
from cbh.api.common.db_requests import get_all_objs
|
| 7 |
+
from cbh.api.common.schemas import FilterRequest
|
| 8 |
+
from cbh.core.config import settings
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
async def create_chat_obj(account: AccountModel, name: str) -> ChatModel:
|
| 12 |
+
"""
|
| 13 |
+
Create a chat object.
|
| 14 |
+
"""
|
| 15 |
+
chat = ChatModel(account=AccountShorten(**account.model_dump()), name=name)
|
| 16 |
+
await settings.DB_CLIENT.chats.insert_one(chat.to_mongo())
|
| 17 |
+
return chat
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
async def update_chat_obj(account: AccountModel, chat_id: str, name: str) -> ChatModel:
|
| 21 |
+
"""
|
| 22 |
+
Update a chat object.
|
| 23 |
+
"""
|
| 24 |
+
chat = await settings.DB_CLIENT.chats.find_one_and_update(
|
| 25 |
+
{"id": chat_id, "account.id": account.id},
|
| 26 |
+
{"$set": {"name": name}},
|
| 27 |
+
return_document=ReturnDocument.AFTER,
|
| 28 |
+
)
|
| 29 |
+
return ChatModel.from_mongo(chat)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
async def filter_chats_objs(
|
| 33 |
+
account: AccountModel, request: FilterRequest[ChatFilter]
|
| 34 |
+
) -> tuple[list[ChatModel], int]:
|
| 35 |
+
"""
|
| 36 |
+
Filter chats objects.
|
| 37 |
+
"""
|
| 38 |
+
filter_ = {"account.id": account.id}
|
| 39 |
+
if request.filter.searchTerm:
|
| 40 |
+
filter_["name"] = {"$regex": f"^{request.filter.searchTerm}", "$options": "i"}
|
| 41 |
+
|
| 42 |
+
chats, total_count = await get_all_objs(
|
| 43 |
+
ChatModel, request.pageSize, request.pageIndex, additional_filter=filter_
|
| 44 |
+
)
|
| 45 |
+
return chats, total_count
|
cbh/api/chats/models.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import datetime
|
| 2 |
+
|
| 3 |
+
from pydantic import Field
|
| 4 |
+
|
| 5 |
+
from cbh.api.account.models import AccountShorten
|
| 6 |
+
from cbh.core.database import MongoBaseModel
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class ChatModel(MongoBaseModel):
|
| 10 |
+
name: str
|
| 11 |
+
account: AccountShorten
|
| 12 |
+
datetimeInserted: datetime = Field(default_factory=datetime.now)
|
cbh/api/chats/schemas.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class CreateChatRequest(BaseModel):
|
| 5 |
+
query: str
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class UpdateChatRequest(BaseModel):
|
| 9 |
+
name: str
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ChatFilter(BaseModel):
|
| 13 |
+
searchTerm: str | None = None
|
cbh/api/chats/services/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .workflows import generate_chat_name
|
| 2 |
+
|
| 3 |
+
__all__ = ["generate_chat_name"]
|
cbh/api/chats/services/prompts.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import lru_cache
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class ChatPrompts:
|
| 5 |
+
"""
|
| 6 |
+
Ari prompts.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
generate_chat_name = """You are a title generator for a sales team management assistant.
|
| 10 |
+
|
| 11 |
+
Generate a concise, descriptive title (2–4 words) for a chat based on the user's first message.
|
| 12 |
+
|
| 13 |
+
## User message:
|
| 14 |
+
{query}
|
| 15 |
+
|
| 16 |
+
## Rules
|
| 17 |
+
- Use title case
|
| 18 |
+
- No punctuation or filler words
|
| 19 |
+
- Capture the specific intent, not just the topic
|
| 20 |
+
- Prefer action + subject format when applicable
|
| 21 |
+
- Output only the title, nothing else
|
| 22 |
+
|
| 23 |
+
## Context
|
| 24 |
+
The assistant helps admins who manage sales teams. Users ask for statistics, rep performance, scenario creation, and other team management tasks.
|
| 25 |
+
|
| 26 |
+
## Examples
|
| 27 |
+
<example>
|
| 28 |
+
query: Who is the best sales rep in the "Connor's team" team?
|
| 29 |
+
title: Top Rep Connor's Team
|
| 30 |
+
</example>
|
| 31 |
+
<example>
|
| 32 |
+
query: What is the average score for Maksim Shymanouski?
|
| 33 |
+
title: Maksim Average Score
|
| 34 |
+
</example>
|
| 35 |
+
<example>
|
| 36 |
+
query: Create me a friendly scenario with 5 objections
|
| 37 |
+
title: Friendly 5-Objection Scenario
|
| 38 |
+
</example>
|
| 39 |
+
<example>
|
| 40 |
+
query: Show me the call conversion rate for last month
|
| 41 |
+
title: Last Month Conversion Rate
|
| 42 |
+
</example>
|
| 43 |
+
<example>
|
| 44 |
+
query: Which reps haven't hit their targets this quarter?
|
| 45 |
+
title: Reps Missing Quarterly Targets
|
| 46 |
+
</example>"""
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@lru_cache()
|
| 50 |
+
def get_prompts() -> ChatPrompts:
|
| 51 |
+
"""
|
| 52 |
+
Get prompts.
|
| 53 |
+
"""
|
| 54 |
+
return ChatPrompts()
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
chat_prompts = get_prompts()
|
cbh/api/chats/services/workflows.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 2 |
+
from pydantic import BaseModel, Field
|
| 3 |
+
|
| 4 |
+
from cbh.api.chats.services.prompts import chat_prompts
|
| 5 |
+
from cbh.core.config import settings
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class ChatNameSchema(BaseModel):
|
| 9 |
+
"""
|
| 10 |
+
Chat name schema.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
name: str = Field(description="A name for the chat.")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
async def generate_chat_name(query: str) -> str:
|
| 17 |
+
"""
|
| 18 |
+
Generate a chat name.
|
| 19 |
+
"""
|
| 20 |
+
prompt = ChatPromptTemplate.from_messages([("system", chat_prompts.generate_chat_name)])
|
| 21 |
+
chain = prompt | settings.get_llm(schema=ChatNameSchema, model="gpt-4.1-nano")
|
| 22 |
+
result = await chain.ainvoke({"query": query})
|
| 23 |
+
return result.name
|
cbh/api/chats/views.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import Depends
|
| 2 |
+
|
| 3 |
+
from cbh.api.account.dto import AccountType
|
| 4 |
+
from cbh.api.account.models import AccountModel
|
| 5 |
+
from cbh.api.chats import chats_router
|
| 6 |
+
from cbh.api.chats.db_requests import create_chat_obj, update_chat_obj, filter_chats_objs
|
| 7 |
+
from cbh.api.chats.models import ChatModel
|
| 8 |
+
from cbh.api.chats.schemas import CreateChatRequest, UpdateChatRequest, ChatFilter
|
| 9 |
+
from cbh.api.chats.services import generate_chat_name
|
| 10 |
+
from cbh.api.common.schemas import AllObjectsResponse, Paging
|
| 11 |
+
from cbh.api.common.schemas import FilterRequest
|
| 12 |
+
from cbh.core.security import PermissionDependency
|
| 13 |
+
from cbh.core.wrappers import CbhResponseWrapper
|
| 14 |
+
|
| 15 |
+
@chats_router.post("/chats")
|
| 16 |
+
async def create_new_chat(
|
| 17 |
+
request: CreateChatRequest,
|
| 18 |
+
account: AccountModel = Depends(PermissionDependency()),
|
| 19 |
+
) -> CbhResponseWrapper[ChatModel]:
|
| 20 |
+
chat_name = await generate_chat_name(request.query)
|
| 21 |
+
chat = await create_chat_obj(account, chat_name)
|
| 22 |
+
return CbhResponseWrapper(data=chat)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@chats_router.put("/chats/{chatId}")
|
| 26 |
+
async def update_chat(
|
| 27 |
+
chatId: str,
|
| 28 |
+
request: UpdateChatRequest,
|
| 29 |
+
account: AccountModel = Depends(PermissionDependency([AccountType.ADMIN, AccountType.OWNER])),
|
| 30 |
+
) -> CbhResponseWrapper[ChatModel]:
|
| 31 |
+
chat = await update_chat_obj(account, chatId, request.name)
|
| 32 |
+
return CbhResponseWrapper(data=chat)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@chats_router.post("/chats/filter")
|
| 36 |
+
async def filter_chats(
|
| 37 |
+
request: FilterRequest[ChatFilter],
|
| 38 |
+
account: AccountModel = Depends(PermissionDependency([AccountType.ADMIN, AccountType.OWNER])),
|
| 39 |
+
) -> CbhResponseWrapper[AllObjectsResponse[ChatModel]]:
|
| 40 |
+
chats, total_count = await filter_chats_objs(account, request)
|
| 41 |
+
return CbhResponseWrapper(
|
| 42 |
+
data=AllObjectsResponse(
|
| 43 |
+
data=chats,
|
| 44 |
+
paging=Paging(
|
| 45 |
+
pageSize=request.pageSize, pageIndex=request.pageIndex, totalCount=total_count
|
| 46 |
+
),
|
| 47 |
+
)
|
| 48 |
+
)
|
| 49 |
+
|
cbh/api/common/db_requests.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Common database requests.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import asyncio
|
| 6 |
+
import re
|
| 7 |
+
from datetime import timedelta, datetime
|
| 8 |
+
from typing import TypeVar
|
| 9 |
+
|
| 10 |
+
from fastapi import HTTPException
|
| 11 |
+
from pydantic import BaseModel
|
| 12 |
+
|
| 13 |
+
from cbh.api.common.schemas import (
|
| 14 |
+
SearchRequest,
|
| 15 |
+
)
|
| 16 |
+
from cbh.core.config import settings
|
| 17 |
+
|
| 18 |
+
T = TypeVar("T", bound=BaseModel)
|
| 19 |
+
|
| 20 |
+
collection_map = {
|
| 21 |
+
"AccountModel": "accounts",
|
| 22 |
+
"AccountShorten": "accounts",
|
| 23 |
+
"UserInsightModel": "userinsights",
|
| 24 |
+
"UserInsightShorten": "userinsights",
|
| 25 |
+
"CallModel": "calls",
|
| 26 |
+
"RepModel": "reps",
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
async def get_obj_by_id(
|
| 31 |
+
model: T,
|
| 32 |
+
obj_id: str | None,
|
| 33 |
+
additional_filter: dict | None = None,
|
| 34 |
+
projection: dict | None = None,
|
| 35 |
+
exception: bool = True,
|
| 36 |
+
) -> T | None:
|
| 37 |
+
"""
|
| 38 |
+
Get an object by ID.
|
| 39 |
+
"""
|
| 40 |
+
filter_ = {"id": obj_id} if obj_id else {}
|
| 41 |
+
if additional_filter:
|
| 42 |
+
filter_.update(additional_filter)
|
| 43 |
+
obj = await settings.DB_CLIENT[collection_map[model.__name__]].find_one(filter_, projection)
|
| 44 |
+
if obj is None:
|
| 45 |
+
if exception:
|
| 46 |
+
raise HTTPException(status_code=404, detail="Object not found.")
|
| 47 |
+
else:
|
| 48 |
+
return None
|
| 49 |
+
return model.from_mongo(obj)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
async def get_all_objs(
|
| 53 |
+
model: T,
|
| 54 |
+
page_size: int,
|
| 55 |
+
page_index: int,
|
| 56 |
+
sort: tuple[str, int] = ("id", -1),
|
| 57 |
+
additional_filter: dict | None = None,
|
| 58 |
+
projection: dict | None = None,
|
| 59 |
+
) -> tuple[list[T], int]:
|
| 60 |
+
"""
|
| 61 |
+
Get all objects.
|
| 62 |
+
"""
|
| 63 |
+
filter_ = additional_filter if additional_filter else {}
|
| 64 |
+
skip = page_index * page_size
|
| 65 |
+
objs, total_count = await asyncio.gather(
|
| 66 |
+
settings.DB_CLIENT[collection_map[model.__name__]]
|
| 67 |
+
.find(filter_, projection)
|
| 68 |
+
.sort(*sort)
|
| 69 |
+
.skip(skip)
|
| 70 |
+
.limit(page_size)
|
| 71 |
+
.to_list(page_size),
|
| 72 |
+
settings.DB_CLIENT[collection_map[model.__name__]].count_documents(filter_),
|
| 73 |
+
)
|
| 74 |
+
return [model.from_mongo(obj) for obj in objs], total_count
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
async def delete_obj(
|
| 78 |
+
model: T, obj_id: str | None = None, additional_filter: dict | None = None
|
| 79 |
+
) -> T:
|
| 80 |
+
"""
|
| 81 |
+
Delete an object.
|
| 82 |
+
"""
|
| 83 |
+
filter_ = {"id": obj_id} if obj_id else {}
|
| 84 |
+
if additional_filter:
|
| 85 |
+
filter_.update(additional_filter)
|
| 86 |
+
obj = await settings.DB_CLIENT[collection_map[model.__name__]].find_one(filter_)
|
| 87 |
+
if obj is None:
|
| 88 |
+
raise HTTPException(status_code=404, detail="Object not found.")
|
| 89 |
+
await settings.DB_CLIENT[collection_map[model.__name__]].delete_one(filter_)
|
| 90 |
+
return model.from_mongo(obj)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
async def search_objs(
|
| 94 |
+
model: T,
|
| 95 |
+
data: SearchRequest,
|
| 96 |
+
additional_filter: dict | None = None,
|
| 97 |
+
projection: dict | None = None,
|
| 98 |
+
) -> tuple[list[T], int]:
|
| 99 |
+
"""
|
| 100 |
+
Search for objects in a specified collection based on search filters.
|
| 101 |
+
"""
|
| 102 |
+
filters = []
|
| 103 |
+
date_filters = {}
|
| 104 |
+
|
| 105 |
+
for search_filter in data.filter:
|
| 106 |
+
if isinstance(search_filter.value, str):
|
| 107 |
+
date_match = re.fullmatch(r"^(\d{4}-\d{2}-\d{2});([+-]\d{1,2})$", search_filter.value)
|
| 108 |
+
|
| 109 |
+
if date_match:
|
| 110 |
+
if search_filter.name not in date_filters:
|
| 111 |
+
date_filters[search_filter.name] = []
|
| 112 |
+
|
| 113 |
+
date_filters[search_filter.name].append(
|
| 114 |
+
{
|
| 115 |
+
"date": datetime.strptime(date_match.group(1), "%Y-%m-%d"),
|
| 116 |
+
"timezone_offset": int(date_match.group(2)),
|
| 117 |
+
}
|
| 118 |
+
)
|
| 119 |
+
else:
|
| 120 |
+
filters.append(
|
| 121 |
+
{
|
| 122 |
+
search_filter.name: {
|
| 123 |
+
"$regex": f"^{re.escape(search_filter.value)}",
|
| 124 |
+
"$options": "i",
|
| 125 |
+
}
|
| 126 |
+
}
|
| 127 |
+
)
|
| 128 |
+
else:
|
| 129 |
+
filters.append({search_filter.name: search_filter.value})
|
| 130 |
+
|
| 131 |
+
for field_name, dates in date_filters.items():
|
| 132 |
+
if len(dates) == 1:
|
| 133 |
+
date_info = dates[0]
|
| 134 |
+
user_local_day_start = date_info["date"]
|
| 135 |
+
user_local_day_end = user_local_day_start + timedelta(days=1)
|
| 136 |
+
filters.append(
|
| 137 |
+
{
|
| 138 |
+
field_name: {
|
| 139 |
+
"$gte": (
|
| 140 |
+
user_local_day_start - timedelta(hours=date_info["timezone_offset"])
|
| 141 |
+
).isoformat(),
|
| 142 |
+
"$lt": (
|
| 143 |
+
user_local_day_end - timedelta(hours=date_info["timezone_offset"])
|
| 144 |
+
).isoformat(),
|
| 145 |
+
}
|
| 146 |
+
}
|
| 147 |
+
)
|
| 148 |
+
elif len(dates) == 2:
|
| 149 |
+
start_date = min(dates, key=lambda x: x["date"])
|
| 150 |
+
end_date = max(dates, key=lambda x: x["date"])
|
| 151 |
+
|
| 152 |
+
start_datetime = start_date["date"] - timedelta(hours=start_date["timezone_offset"])
|
| 153 |
+
end_datetime = (
|
| 154 |
+
end_date["date"] + timedelta(days=1) - timedelta(hours=end_date["timezone_offset"])
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
filters.append(
|
| 158 |
+
{
|
| 159 |
+
field_name: {
|
| 160 |
+
"$gte": start_datetime.isoformat(),
|
| 161 |
+
"$lt": end_datetime.isoformat(),
|
| 162 |
+
}
|
| 163 |
+
}
|
| 164 |
+
)
|
| 165 |
+
elif len(dates) > 2:
|
| 166 |
+
dates_sorted = sorted(dates, key=lambda x: x["date"])
|
| 167 |
+
start_date = dates_sorted[0]
|
| 168 |
+
end_date = dates_sorted[-1]
|
| 169 |
+
|
| 170 |
+
start_datetime = start_date["date"] - timedelta(hours=start_date["timezone_offset"])
|
| 171 |
+
end_datetime = (
|
| 172 |
+
end_date["date"] + timedelta(days=1) - timedelta(hours=end_date["timezone_offset"])
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
filters.append(
|
| 176 |
+
{
|
| 177 |
+
field_name: {
|
| 178 |
+
"$gte": start_datetime.isoformat(),
|
| 179 |
+
"$lt": end_datetime.isoformat(),
|
| 180 |
+
}
|
| 181 |
+
}
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
if additional_filter:
|
| 185 |
+
filters.append(additional_filter)
|
| 186 |
+
regex_filter = {"$and": filters} if filters else {}
|
| 187 |
+
objects, total_count = await asyncio.gather(
|
| 188 |
+
settings.DB_CLIENT[collection_map[model.__name__]]
|
| 189 |
+
.find(regex_filter, projection)
|
| 190 |
+
.sort("id", -1)
|
| 191 |
+
.skip(data.pageSize * data.pageIndex)
|
| 192 |
+
.limit(data.pageSize)
|
| 193 |
+
.to_list(length=data.pageSize),
|
| 194 |
+
settings.DB_CLIENT[collection_map[model.__name__]].count_documents(regex_filter),
|
| 195 |
+
)
|
| 196 |
+
return [model.from_mongo(obj) for obj in objects], total_count
|
cbh/api/common/dto.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Common DTOs.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
from enum import Enum
|
| 7 |
+
|
| 8 |
+
from pydantic import BaseModel, field_validator
|
| 9 |
+
|
| 10 |
+
from cbh.core.config import settings
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class Paging(BaseModel):
|
| 14 |
+
"""
|
| 15 |
+
Pagination model for API responses.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
pageSize: int
|
| 19 |
+
pageIndex: int
|
| 20 |
+
totalCount: int
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class SearchFilter(BaseModel):
|
| 24 |
+
"""
|
| 25 |
+
Search filter model for constructing database queries.
|
| 26 |
+
|
| 27 |
+
Attributes:
|
| 28 |
+
name (str): Field name to filter on
|
| 29 |
+
value (str | int): Value to search for
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
name: str
|
| 33 |
+
value: str | int
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class Scores(BaseModel):
|
| 37 |
+
"""
|
| 38 |
+
Scores for the recording.
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
communication: int
|
| 42 |
+
activeListening: int
|
| 43 |
+
conversation: int
|
| 44 |
+
objection: int
|
| 45 |
+
empathy: int
|
| 46 |
+
final: int
|
| 47 |
+
|
| 48 |
+
def __sub__(self, other: "Scores") -> "Scores":
|
| 49 |
+
return Scores(
|
| 50 |
+
communication=self.communication - other.communication,
|
| 51 |
+
activeListening=self.activeListening - other.activeListening,
|
| 52 |
+
conversation=self.conversation - other.conversation,
|
| 53 |
+
objection=self.objection - other.objection,
|
| 54 |
+
empathy=self.empathy - other.empathy,
|
| 55 |
+
final=self.final - other.final,
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
def __mod__(self, other: "Scores") -> "Scores":
|
| 59 |
+
def calc_percentage_diff(current: int, previous: int) -> int:
|
| 60 |
+
if previous == 0:
|
| 61 |
+
return 0
|
| 62 |
+
return int(((current - previous) / previous) * 100)
|
| 63 |
+
|
| 64 |
+
return Scores(
|
| 65 |
+
communication=calc_percentage_diff(self.communication, other.communication),
|
| 66 |
+
activeListening=calc_percentage_diff(
|
| 67 |
+
self.activeListening, other.activeListening
|
| 68 |
+
),
|
| 69 |
+
conversation=calc_percentage_diff(self.conversation, other.conversation),
|
| 70 |
+
objection=calc_percentage_diff(self.objection, other.objection),
|
| 71 |
+
empathy=calc_percentage_diff(self.empathy, other.empathy),
|
| 72 |
+
final=calc_percentage_diff(self.final, other.final),
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class DateValue(BaseModel):
|
| 77 |
+
"""
|
| 78 |
+
Date value.
|
| 79 |
+
"""
|
| 80 |
+
|
| 81 |
+
date: datetime
|
| 82 |
+
value: int
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class ValueDelta(BaseModel):
|
| 86 |
+
"""
|
| 87 |
+
Value delta.
|
| 88 |
+
"""
|
| 89 |
+
|
| 90 |
+
value: int
|
| 91 |
+
delta: int
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class IDName(BaseModel):
|
| 95 |
+
id: str | int
|
| 96 |
+
name: str
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class IDNamePicture(IDName):
|
| 100 |
+
pictureUrl: str | None
|
| 101 |
+
|
| 102 |
+
@field_validator("pictureUrl", mode="before")
|
| 103 |
+
@classmethod
|
| 104 |
+
def serialize_picture_url(cls, v: str | None) -> str | None:
|
| 105 |
+
"""
|
| 106 |
+
Serialize the picture URL.
|
| 107 |
+
"""
|
| 108 |
+
if v:
|
| 109 |
+
if not v.startswith("https://"):
|
| 110 |
+
return settings.S3_CLIENT.generate_presigned_url(v, expiration=3600)
|
| 111 |
+
return v
|
| 112 |
+
return None
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class OrderType(Enum):
|
| 116 |
+
"""
|
| 117 |
+
Order type.
|
| 118 |
+
"""
|
| 119 |
+
|
| 120 |
+
ASCENDING = 1
|
| 121 |
+
DESCENDING = -1
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class SortBy(BaseModel):
|
| 125 |
+
"""
|
| 126 |
+
Sort by.
|
| 127 |
+
"""
|
| 128 |
+
|
| 129 |
+
name: str
|
| 130 |
+
order: OrderType
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class TokenUsage(BaseModel):
|
| 134 |
+
inputTokens: int
|
| 135 |
+
outputTokens: int
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class SkillStatistics(BaseModel):
|
| 139 |
+
"""
|
| 140 |
+
Skill statistics model.
|
| 141 |
+
"""
|
| 142 |
+
|
| 143 |
+
score: int
|
| 144 |
+
bestAccount: IDNamePicture | None = None
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class InsightType(Enum):
|
| 148 |
+
MISTAKE = 1
|
| 149 |
+
ACHIEVEMENT = 2
|
cbh/api/common/schemas.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Common schemas.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from datetime import date
|
| 6 |
+
from typing import TypeVar, Generic
|
| 7 |
+
|
| 8 |
+
from pydantic import BaseModel
|
| 9 |
+
|
| 10 |
+
from cbh.api.common.dto import (
|
| 11 |
+
OrderType,
|
| 12 |
+
Paging,
|
| 13 |
+
SearchFilter,
|
| 14 |
+
SkillStatistics,
|
| 15 |
+
SortBy,
|
| 16 |
+
DateValue, InsightType,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
T = TypeVar("T", bound=BaseModel)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class AllObjectsResponse(BaseModel, Generic[T]):
|
| 23 |
+
"""
|
| 24 |
+
Response model for all objects.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
paging: Paging
|
| 28 |
+
data: list[T]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class SearchRequest(BaseModel):
|
| 32 |
+
"""
|
| 33 |
+
Request schema for searching calls or statistics.
|
| 34 |
+
|
| 35 |
+
Attributes:
|
| 36 |
+
filter (list[SearchFilter]): List of filters to apply
|
| 37 |
+
pageSize (int): Number of items to return per page
|
| 38 |
+
pageIndex (int): Page index to retrieve
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
filter: list[SearchFilter]
|
| 42 |
+
pageSize: int
|
| 43 |
+
pageIndex: int
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class FilterRequest(BaseModel, Generic[T]):
|
| 47 |
+
"""
|
| 48 |
+
Filter request.
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
filter: T
|
| 52 |
+
sortBy: SortBy | None = None
|
| 53 |
+
pageSize: int = 10
|
| 54 |
+
pageIndex: int = 0
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class PlainTextResponse(BaseModel):
|
| 58 |
+
"""
|
| 59 |
+
Response model for plain text.
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
text: str
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class BatchIdsRequest(BaseModel):
|
| 66 |
+
"""
|
| 67 |
+
Batch ids request.
|
| 68 |
+
"""
|
| 69 |
+
|
| 70 |
+
ids: list[str]
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class EmailRequest(BaseModel):
|
| 74 |
+
"""
|
| 75 |
+
Email request.
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
email: str
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class OrderTypeRequest(BaseModel):
|
| 82 |
+
"""
|
| 83 |
+
Order type request.
|
| 84 |
+
"""
|
| 85 |
+
|
| 86 |
+
order: OrderType
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class SkillsStatisticsResponse(BaseModel):
|
| 90 |
+
"""
|
| 91 |
+
Scenario skills statistics response.
|
| 92 |
+
"""
|
| 93 |
+
|
| 94 |
+
communication: SkillStatistics
|
| 95 |
+
activeListening: SkillStatistics
|
| 96 |
+
conversation: SkillStatistics
|
| 97 |
+
objection: SkillStatistics
|
| 98 |
+
empathy: SkillStatistics
|
| 99 |
+
final: int
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class RepProgressResponse(BaseModel):
|
| 103 |
+
"""
|
| 104 |
+
Rep progress response.
|
| 105 |
+
"""
|
| 106 |
+
|
| 107 |
+
progress: list[DateValue]
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class AIInsightsResponse(BaseModel):
|
| 111 |
+
"""
|
| 112 |
+
AI insights.
|
| 113 |
+
"""
|
| 114 |
+
|
| 115 |
+
aiInsights: str | None = None
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class InsightFilter(BaseModel):
|
| 119 |
+
type: InsightType
|
| 120 |
+
startDate: date | None = None
|
| 121 |
+
endDate: date | None = None
|
cbh/api/common/utils.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Common utilities.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import asyncio
|
| 6 |
+
import base64
|
| 7 |
+
import re
|
| 8 |
+
from typing import Callable, TypeVar
|
| 9 |
+
|
| 10 |
+
import httpx
|
| 11 |
+
|
| 12 |
+
from cbh.api.account.dto import AccountType
|
| 13 |
+
from cbh.api.account.models import AccountModel
|
| 14 |
+
from cbh.api.common.dto import (
|
| 15 |
+
LeaderboardStatisticsPosition,
|
| 16 |
+
Scores,
|
| 17 |
+
OrderType,
|
| 18 |
+
SkillStatistics,
|
| 19 |
+
IDNamePicture,
|
| 20 |
+
)
|
| 21 |
+
from cbh.api.common.schemas import SkillsStatisticsResponse
|
| 22 |
+
from cbh.core.config import settings
|
| 23 |
+
|
| 24 |
+
T = TypeVar("T")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def calculate_avg_scores(reports: list) -> Scores:
|
| 28 |
+
if not reports:
|
| 29 |
+
return Scores(
|
| 30 |
+
communication=0,
|
| 31 |
+
activeListening=0,
|
| 32 |
+
conversation=0,
|
| 33 |
+
objection=0,
|
| 34 |
+
empathy=0,
|
| 35 |
+
final=0,
|
| 36 |
+
)
|
| 37 |
+
return Scores(
|
| 38 |
+
communication=round(sum(report.scores.communication for report in reports) / len(reports)),
|
| 39 |
+
activeListening=round(
|
| 40 |
+
sum(report.scores.activeListening for report in reports) / len(reports)
|
| 41 |
+
),
|
| 42 |
+
conversation=round(sum(report.scores.conversation for report in reports) / len(reports)),
|
| 43 |
+
objection=round(sum(report.scores.objection for report in reports) / len(reports)),
|
| 44 |
+
empathy=round(sum(report.scores.empathy for report in reports) / len(reports)),
|
| 45 |
+
final=round(sum(report.scores.final for report in reports) / len(reports)),
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def form_user_stats(session_reports: list) -> dict[str, dict]:
|
| 50 |
+
user_stats = {}
|
| 51 |
+
for report in session_reports:
|
| 52 |
+
user_id = report.account.id
|
| 53 |
+
if user_id not in user_stats:
|
| 54 |
+
user_stats[user_id] = {
|
| 55 |
+
"account": report.account,
|
| 56 |
+
"reports": [],
|
| 57 |
+
"attempts": 0,
|
| 58 |
+
}
|
| 59 |
+
user_stats[user_id]["reports"].append(report)
|
| 60 |
+
user_stats[user_id]["attempts"] += 1
|
| 61 |
+
return user_stats
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def leaderboard_sort_key(item: tuple, reverse: bool = False) -> tuple:
|
| 65 |
+
user_data, score = item
|
| 66 |
+
if reverse:
|
| 67 |
+
return -score, user_data["attempts"], user_data["account"].name.lower()
|
| 68 |
+
|
| 69 |
+
return score, -user_data["attempts"], user_data["account"].name.lower()
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def build_leaderboard(
|
| 73 |
+
session_reports: list,
|
| 74 |
+
position_builder: Callable[[dict, Scores], T],
|
| 75 |
+
order: OrderType | None = None,
|
| 76 |
+
) -> list[T]:
|
| 77 |
+
user_stats = form_user_stats(session_reports)
|
| 78 |
+
user_scores = [
|
| 79 |
+
(user_data, calculate_avg_scores(user_data["reports"]).final)
|
| 80 |
+
for user_data in user_stats.values()
|
| 81 |
+
]
|
| 82 |
+
|
| 83 |
+
is_descending = order is None or order == OrderType.DESCENDING
|
| 84 |
+
sorted_users = sorted(
|
| 85 |
+
user_scores, key=lambda item: leaderboard_sort_key(item, reverse=is_descending)
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
leaderboard = []
|
| 89 |
+
for user_data, _ in sorted_users:
|
| 90 |
+
avg_scores = calculate_avg_scores(user_data["reports"])
|
| 91 |
+
leaderboard.append(position_builder(user_data, avg_scores))
|
| 92 |
+
|
| 93 |
+
return leaderboard
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def build_leaderboard_simple(
|
| 97 |
+
session_reports: list,
|
| 98 |
+
pageSize: int | None = None,
|
| 99 |
+
pageIndex: int | None = None,
|
| 100 |
+
order: OrderType | None = None,
|
| 101 |
+
) -> list[LeaderboardStatisticsPosition]:
|
| 102 |
+
def position_builder(user_data: dict, avg_scores: Scores) -> LeaderboardStatisticsPosition:
|
| 103 |
+
return LeaderboardStatisticsPosition(
|
| 104 |
+
account=IDNamePicture(
|
| 105 |
+
id=user_data["account"].id,
|
| 106 |
+
name=user_data["account"].name,
|
| 107 |
+
pictureUrl=user_data["account"].pictureUrl,
|
| 108 |
+
),
|
| 109 |
+
scores=avg_scores,
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
leaderboard = build_leaderboard(session_reports, position_builder, order)
|
| 113 |
+
return paginate_list(leaderboard, pageSize, pageIndex)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def paginate_list(
|
| 117 |
+
items: list[T], page_size: int | None = None, page_index: int | None = None
|
| 118 |
+
) -> list[T]:
|
| 119 |
+
if page_size is None or page_index is None:
|
| 120 |
+
return items
|
| 121 |
+
return items[page_index * page_size : (page_index + 1) * page_size]
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def form_additional_scenario_filter(account: AccountModel, allow_demo: bool = False):
|
| 125 |
+
from cbh.api.scenario.dto import AssigneesType, ScenarioStatus
|
| 126 |
+
|
| 127 |
+
filter_ = {"owner.organization.id": account.organization.id}
|
| 128 |
+
if account.accountType == AccountType.USER:
|
| 129 |
+
filter_.update(
|
| 130 |
+
{
|
| 131 |
+
"$or": [
|
| 132 |
+
{"assignees": {"$size": 0}},
|
| 133 |
+
{
|
| 134 |
+
"assignees": {
|
| 135 |
+
"$elemMatch": {
|
| 136 |
+
"type": AssigneesType.USER.value,
|
| 137 |
+
"account.id": account.id,
|
| 138 |
+
}
|
| 139 |
+
}
|
| 140 |
+
},
|
| 141 |
+
{
|
| 142 |
+
"assignees": {
|
| 143 |
+
"$elemMatch": {
|
| 144 |
+
"type": AssigneesType.TEAM.value,
|
| 145 |
+
"team.members": {"$elemMatch": {"id": account.id}},
|
| 146 |
+
}
|
| 147 |
+
}
|
| 148 |
+
},
|
| 149 |
+
],
|
| 150 |
+
"isTemplate": False,
|
| 151 |
+
"status": ScenarioStatus.ACTIVE.value,
|
| 152 |
+
}
|
| 153 |
+
)
|
| 154 |
+
if not allow_demo or account.accountType != AccountType.USER:
|
| 155 |
+
filter_.update(
|
| 156 |
+
{
|
| 157 |
+
"isDemo": False,
|
| 158 |
+
}
|
| 159 |
+
)
|
| 160 |
+
return filter_
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
async def convert_document_to_text(file: bytes, filename: str) -> str:
|
| 164 |
+
if filename.endswith(".txt"):
|
| 165 |
+
return file.decode("utf-8", errors="ignore")
|
| 166 |
+
filename = re.sub(r"[^\w\s.-]", "", filename)
|
| 167 |
+
base64_file = base64.b64encode(file).decode("utf-8")
|
| 168 |
+
headers = {"Content-Type": "application/json"}
|
| 169 |
+
data = {
|
| 170 |
+
"apikey": settings.CONVERTIO_API_KEY,
|
| 171 |
+
"input": "base64",
|
| 172 |
+
"file": base64_file,
|
| 173 |
+
"filename": filename,
|
| 174 |
+
"outputformat": "txt",
|
| 175 |
+
}
|
| 176 |
+
async with httpx.AsyncClient(timeout=httpx.Timeout(timeout=120)) as client:
|
| 177 |
+
response = await client.post("https://api.convertio.co/convert", json=data, headers=headers)
|
| 178 |
+
response = response.json()
|
| 179 |
+
if response["code"] == 200:
|
| 180 |
+
conversion_id = response["data"]["id"]
|
| 181 |
+
status = ""
|
| 182 |
+
attempt = 0
|
| 183 |
+
while status != "finish":
|
| 184 |
+
if attempt > 50:
|
| 185 |
+
raise Exception("Please, try again")
|
| 186 |
+
get_status_response = await client.get(
|
| 187 |
+
f"https://api.convertio.co/convert/{conversion_id}/status"
|
| 188 |
+
)
|
| 189 |
+
get_status_response = get_status_response.json()
|
| 190 |
+
if get_status_response["code"] != 200:
|
| 191 |
+
raise Exception("Please, try again")
|
| 192 |
+
else:
|
| 193 |
+
status = get_status_response["data"]["step"]
|
| 194 |
+
await asyncio.sleep(1)
|
| 195 |
+
attempt += 1
|
| 196 |
+
file_url = get_status_response["data"]["output"]["url"]
|
| 197 |
+
response = await client.get(file_url)
|
| 198 |
+
response.raise_for_status()
|
| 199 |
+
return response.content.decode("utf-8", errors="ignore")
|
| 200 |
+
else:
|
| 201 |
+
return ""
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def calculate_skills_statistics(session_reports: list) -> SkillsStatisticsResponse:
|
| 205 |
+
"""
|
| 206 |
+
Calculate team skills statistics.
|
| 207 |
+
"""
|
| 208 |
+
if not session_reports:
|
| 209 |
+
empty_skill = SkillStatistics(score=0, bestAccount=None)
|
| 210 |
+
return SkillsStatisticsResponse(
|
| 211 |
+
communication=empty_skill,
|
| 212 |
+
activeListening=empty_skill,
|
| 213 |
+
conversation=empty_skill,
|
| 214 |
+
objection=empty_skill,
|
| 215 |
+
empathy=empty_skill,
|
| 216 |
+
final=0,
|
| 217 |
+
)
|
| 218 |
+
avg_scores = calculate_avg_scores(session_reports)
|
| 219 |
+
skills = ["communication", "activeListening", "conversation", "objection", "empathy"]
|
| 220 |
+
|
| 221 |
+
skill_stats = {}
|
| 222 |
+
for skill in skills:
|
| 223 |
+
best_report = sorted(
|
| 224 |
+
session_reports,
|
| 225 |
+
key=lambda r, s=skill: (getattr(r.scores, s), r.datetimeInserted.timestamp()),
|
| 226 |
+
reverse=True,
|
| 227 |
+
)[0]
|
| 228 |
+
skill_stats[skill] = SkillStatistics(
|
| 229 |
+
score=getattr(avg_scores, skill),
|
| 230 |
+
bestAccount=IDNamePicture(
|
| 231 |
+
id=best_report.account.id,
|
| 232 |
+
name=best_report.account.name,
|
| 233 |
+
pictureUrl=best_report.account.pictureUrl,
|
| 234 |
+
),
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
return SkillsStatisticsResponse(
|
| 238 |
+
communication=skill_stats["communication"],
|
| 239 |
+
activeListening=skill_stats["activeListening"],
|
| 240 |
+
conversation=skill_stats["conversation"],
|
| 241 |
+
objection=skill_stats["objection"],
|
| 242 |
+
empathy=skill_stats["empathy"],
|
| 243 |
+
final=avg_scores.final,
|
| 244 |
+
)
|
cbh/api/messages/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter
|
| 2 |
+
|
| 3 |
+
messages_router = APIRouter(prefix="/messages", tags=["messages"])
|
| 4 |
+
|
| 5 |
+
from . import views
|
cbh/api/messages/db_requests.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from cbh.api.common.db_requests import get_all_objs
|
| 2 |
+
from cbh.api.messages.models import MessageModel
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
async def get_chat_message_history(chat_id: str, account_id: str) -> list[dict]:
|
| 6 |
+
"""
|
| 7 |
+
Get the message history for a chat.
|
| 8 |
+
"""
|
| 9 |
+
messages = await get_all_objs(
|
| 10 |
+
MessageModel,
|
| 11 |
+
100000,
|
| 12 |
+
0,
|
| 13 |
+
additional_filter={"chat.account.id": account_id, "chat.id": chat_id},
|
| 14 |
+
)
|
| 15 |
+
response = []
|
| 16 |
+
for message in messages:
|
| 17 |
+
response.append(
|
| 18 |
+
{
|
| 19 |
+
"role": message.role,
|
| 20 |
+
"content": message.content,
|
| 21 |
+
}
|
| 22 |
+
)
|
| 23 |
+
return messages
|
cbh/api/messages/dto.py
ADDED
|
File without changes
|
cbh/api/messages/models.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import datetime
|
| 2 |
+
from pydantic import Field
|
| 3 |
+
|
| 4 |
+
from cbh.api.ari.dto import Author
|
| 5 |
+
from cbh.core.database import MongoBaseModel
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class MessageModel(MongoBaseModel):
|
| 9 |
+
role: Author
|
| 10 |
+
content: str
|
| 11 |
+
chatId: str
|
| 12 |
+
accountId: str
|
| 13 |
+
datetimeInserted: datetime = Field(default_factory=datetime.now)
|
cbh/api/messages/utils.py
ADDED
|
File without changes
|
cbh/api/messages/views.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import Depends
|
| 2 |
+
from cbh.api.account.models import AccountModel
|
| 3 |
+
from cbh.api.account.dto import AccountType
|
| 4 |
+
from cbh.api.messages import messages_router
|
| 5 |
+
from cbh.api.messages.models import MessageModel
|
| 6 |
+
from cbh.api.common.schemas import AllObjectsResponse, Paging
|
| 7 |
+
from cbh.api.common.db_requests import get_all_objs
|
| 8 |
+
from cbh.core.security import PermissionDependency
|
| 9 |
+
from cbh.core.wrappers import CbhResponseWrapper
|
| 10 |
+
|
| 11 |
+
@messages_router.get("/messages/{chatId}")
|
| 12 |
+
async def get_chat_messages(
|
| 13 |
+
chatId: str,
|
| 14 |
+
account: AccountModel = Depends(PermissionDependency([AccountType.ADMIN, AccountType.OWNER])),
|
| 15 |
+
) -> CbhResponseWrapper[AllObjectsResponse[MessageModel]]:
|
| 16 |
+
messages, total_count = await get_all_objs(
|
| 17 |
+
MessageModel,
|
| 18 |
+
100000,
|
| 19 |
+
0,
|
| 20 |
+
additional_filter={"accountId": account.id, "chatId": chatId},
|
| 21 |
+
sort=("id", 1),
|
| 22 |
+
)
|
| 23 |
+
return CbhResponseWrapper(
|
| 24 |
+
data=AllObjectsResponse(
|
| 25 |
+
data=messages,
|
| 26 |
+
paging=Paging(pageSize=len(messages), pageIndex=0, totalCount=total_count),
|
| 27 |
+
)
|
| 28 |
+
)
|
cbh/api/platforms/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter
|
| 2 |
+
|
| 3 |
+
platforms_router = APIRouter(prefix="/platforms", tags=["platforms"])
|
| 4 |
+
|
| 5 |
+
from . import views
|
cbh/api/platforms/db_requests.py
ADDED
|
File without changes
|
cbh/api/platforms/dto.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import Enum
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class Category(int, Enum):
|
| 5 |
+
WEB_APPS_SAAS_MVP = 1
|
| 6 |
+
WEBSITES_LANDING_PAGES = 2
|
| 7 |
+
MOBILE_APPS = 3
|
| 8 |
+
UI_UX_DESIGN = 4
|
| 9 |
+
AI_CODING_TOOLS = 5
|
| 10 |
+
AUTOMATION_AI_AGENTS = 6
|
| 11 |
+
VIDEO_CREATIVE = 7
|
| 12 |
+
SEO_GEO = 8
|
| 13 |
+
GROWTH_SOCIAL_REDDIT = 9
|
| 14 |
+
RESEARCH_ANALYTICS = 10
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class Level(int, Enum):
|
| 18 |
+
LOW = 1
|
| 19 |
+
LOW_TO_MEDIUM = 2
|
| 20 |
+
MEDIUM = 3
|
| 21 |
+
MEDIUM_TO_HIGH = 4
|
| 22 |
+
HIGH = 5
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class ToolType(int, Enum):
|
| 26 |
+
NO_CODE = 1
|
| 27 |
+
HYBRID = 2
|
| 28 |
+
DEV = 3
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class Focus(int, Enum):
|
| 32 |
+
WEB = 1
|
| 33 |
+
MOBILE = 2
|
| 34 |
+
DESKTOP = 3
|
| 35 |
+
MULTI_PLATFORM = 4
|
| 36 |
+
MOBILE_DESIGN = 5
|
| 37 |
+
DEVELOPER_WORKFLOW = 6
|
| 38 |
+
DESKTOP_MULTI_PLATFORM_DEV = 7
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class MonetizationPriority(int, Enum):
|
| 42 |
+
LOW = 1
|
| 43 |
+
MEDIUM = 2
|
| 44 |
+
HIGH = 3
|
cbh/api/platforms/models.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import Field
|
| 2 |
+
|
| 3 |
+
from cbh.api.platforms.dto import Focus, ToolType, Level, Category, MonetizationPriority
|
| 4 |
+
from cbh.core.database import MongoBaseModel
|
| 5 |
+
|
| 6 |
+
class PlatformModel(MongoBaseModel):
|
| 7 |
+
name: str = Field(description="Platform name, e.g. 'v0', 'Lovable'")
|
| 8 |
+
category: Category = Field(description="Primary category. One of: 1=Web apps/SaaS/MVP, 2=Websites/Landing pages, 3=Mobile apps, 4=UI/UX Design, 5=AI Coding tools, 6=Automation/AI agents, 7=Video/Creative, 8=SEO/GEO, 9=Growth/Social/Reddit, 10=Research/Analytics")
|
| 9 |
+
subcategory: str = Field(description="Specific subcategory within the category, e.g. 'UI + app generation', 'Prompt-based app builder'")
|
| 10 |
+
oneLinePos: str = Field(description="One-line positioning statement describing what the tool does")
|
| 11 |
+
description: str = Field(description="Detailed description of the tool: when to use it, strengths, and best-fit scenarios")
|
| 12 |
+
userQueries: list[str] = Field(description="List of typical user queries/intents this tool covers, e.g. ['I want to build a SaaS', 'I need a dashboard']")
|
| 13 |
+
idealCases: str = Field(description="Description of the ideal client scenario, e.g. 'the client wants a fast visual MVP and is comfortable refining later'")
|
| 14 |
+
personas: list[str] = Field(description="List of recommended user personas, e.g. ['Founder', 'PM', 'Developer']")
|
| 15 |
+
level: Level = Field(description="Required skill level. One of: 1=Low, 2=Low-to-Medium, 3=Medium, 4=Medium-to-High, 5=High")
|
| 16 |
+
toolType: ToolType = Field(description="Tool type. One of: 1=No-code, 2=Hybrid, 3=Dev")
|
| 17 |
+
focus: list[Focus] = Field(description="Platform focus areas. List of: 1=Web, 2=Mobile, 3=Desktop, 4=Multi-platform, 5=Mobile design, 6=Developer workflow, 7=Desktop/Multi-platform dev")
|
| 18 |
+
productStage: list[str] = Field(description="Best product stages, e.g. ['Ideation', 'MVP', 'MVP UI']")
|
| 19 |
+
keyStrengths: list[str] = Field(description="List of key strengths of the tool")
|
| 20 |
+
caveats: list[str] = Field(description="List of main caveats or limitations")
|
| 21 |
+
monetizationPriority: MonetizationPriority = Field(description="Monetization priority. One of: 1=Low, 2=Medium, 3=High")
|
| 22 |
+
website: str = Field(description="Official website URL")
|
| 23 |
+
internalNotes: str = Field(description="Internal notes about when/how to recommend this tool")
|
cbh/api/platforms/utils.py
ADDED
|
File without changes
|
cbh/api/platforms/views.py
ADDED
|
File without changes
|
cbh/api/security/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Security module initialization.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from fastapi import APIRouter
|
| 6 |
+
|
| 7 |
+
security_router = APIRouter(
|
| 8 |
+
prefix="/api/security",
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
from . import views # pylint: disable=C0413 # noqa: E402,F401
|
cbh/api/security/db_requests.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Database requests module for security functionality.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import asyncio
|
| 6 |
+
from datetime import datetime, timedelta
|
| 7 |
+
|
| 8 |
+
from fastapi import HTTPException
|
| 9 |
+
from passlib.context import CryptContext
|
| 10 |
+
from pydantic import EmailStr
|
| 11 |
+
from pymongo import ReturnDocument
|
| 12 |
+
|
| 13 |
+
from cbh.api.account.dto import AccountType, RegistrationType
|
| 14 |
+
from cbh.api.account.models import AccountModel, AccountShorten
|
| 15 |
+
from cbh.api.security.dto import VerificationCodeStatus, VerificationCodeType
|
| 16 |
+
from cbh.api.security.models import VerificationCodeModel
|
| 17 |
+
from cbh.api.security.schemas import (
|
| 18 |
+
LoginAccountRequest,
|
| 19 |
+
RegisterAccountRequest,
|
| 20 |
+
)
|
| 21 |
+
from cbh.core.config import settings
|
| 22 |
+
from cbh.core.security import verify_password
|
| 23 |
+
from cbh.core.wrappers import background_task
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
async def check_unique_email(email: EmailStr | str) -> AccountModel | None:
|
| 27 |
+
"""
|
| 28 |
+
Check if a field value already exists in the database to ensure uniqueness.
|
| 29 |
+
"""
|
| 30 |
+
account = await settings.DB_CLIENT.accounts.find_one(
|
| 31 |
+
{"email": {"$regex": f"^{str(email)}$", "$options": "i"}}
|
| 32 |
+
)
|
| 33 |
+
account = AccountModel.from_mongo(account) if account else None
|
| 34 |
+
|
| 35 |
+
return account
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
async def authenticate_account(data: LoginAccountRequest) -> AccountModel:
|
| 39 |
+
"""
|
| 40 |
+
Authenticate a user account using mail and password.
|
| 41 |
+
"""
|
| 42 |
+
account = await settings.DB_CLIENT.accounts.find_one(
|
| 43 |
+
{"email": {"$regex": f"^{data.email}$", "$options": "i"}}
|
| 44 |
+
)
|
| 45 |
+
if account is None:
|
| 46 |
+
raise HTTPException(status_code=404, detail="Invalid email or password.")
|
| 47 |
+
|
| 48 |
+
account = AccountModel.from_mongo(account)
|
| 49 |
+
if account.registrationType != RegistrationType.ORGANIC:
|
| 50 |
+
raise HTTPException(status_code=422, detail="Please sign in with social providers.")
|
| 51 |
+
|
| 52 |
+
if not verify_password(data.password, account.password):
|
| 53 |
+
raise HTTPException(status_code=400, detail="Invalid email or password.")
|
| 54 |
+
|
| 55 |
+
return account
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
async def get_account_by_email(email: str) -> AccountModel | None:
|
| 59 |
+
"""
|
| 60 |
+
Verify if an account exists.
|
| 61 |
+
"""
|
| 62 |
+
account = await settings.DB_CLIENT.accounts.find_one(
|
| 63 |
+
{"email": {"$regex": f"^{email}$", "$options": "i"}}
|
| 64 |
+
)
|
| 65 |
+
return AccountModel.from_mongo(account) if account else None
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
async def create_code_obj(
|
| 69 |
+
account: AccountShorten, type_: VerificationCodeType, time_delta: timedelta
|
| 70 |
+
) -> VerificationCodeModel:
|
| 71 |
+
"""
|
| 72 |
+
Create a code object.
|
| 73 |
+
"""
|
| 74 |
+
prev_code = (
|
| 75 |
+
await settings.DB_CLIENT.verificationcodes.find(
|
| 76 |
+
{
|
| 77 |
+
"account.id": account.id,
|
| 78 |
+
"type": type_.value,
|
| 79 |
+
}
|
| 80 |
+
)
|
| 81 |
+
.sort("_id", -1)
|
| 82 |
+
.to_list(length=1)
|
| 83 |
+
)
|
| 84 |
+
prev_code = VerificationCodeModel.from_mongo(prev_code[0]) if prev_code else None
|
| 85 |
+
|
| 86 |
+
if prev_code and prev_code.datetimeInserted > datetime.now() - timedelta(minutes=1):
|
| 87 |
+
raise HTTPException(status_code=429, detail="Too many requests")
|
| 88 |
+
|
| 89 |
+
code = VerificationCodeModel(
|
| 90 |
+
account=account,
|
| 91 |
+
type=type_,
|
| 92 |
+
expiresAt=datetime.now() + time_delta,
|
| 93 |
+
)
|
| 94 |
+
await settings.DB_CLIENT.verificationcodes.insert_one(code.to_mongo())
|
| 95 |
+
return code
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
@background_task()
|
| 99 |
+
async def set_used_code(code: VerificationCodeModel | None):
|
| 100 |
+
"""
|
| 101 |
+
Set a code object as used.
|
| 102 |
+
"""
|
| 103 |
+
if code:
|
| 104 |
+
await settings.DB_CLIENT.verificationcodes.update_one(
|
| 105 |
+
{"id": code.id}, {"$set": {"status": VerificationCodeStatus.USED.value}}
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
async def verify_code_obj(
|
| 110 |
+
code_: str, types: list[VerificationCodeType], exception: bool = True, set_used: bool = True
|
| 111 |
+
) -> VerificationCodeModel:
|
| 112 |
+
"""
|
| 113 |
+
Verify a code object.
|
| 114 |
+
"""
|
| 115 |
+
code = (
|
| 116 |
+
await settings.DB_CLIENT.verificationcodes.find(
|
| 117 |
+
{"id": code_, "type": {"$in": [t.value for t in types]}},
|
| 118 |
+
)
|
| 119 |
+
.sort("_id", -1)
|
| 120 |
+
.to_list(length=1)
|
| 121 |
+
)
|
| 122 |
+
code = VerificationCodeModel.from_mongo(code[0]) if code else None
|
| 123 |
+
|
| 124 |
+
if not code and exception:
|
| 125 |
+
error_msg = "Invalid invitation link. Please ask your manager to resend the invite."
|
| 126 |
+
if VerificationCodeType.PASSWORD_RESET in types:
|
| 127 |
+
error_msg = "Invalid password reset link. Please request a new one."
|
| 128 |
+
raise HTTPException(
|
| 129 |
+
status_code=404,
|
| 130 |
+
detail=error_msg,
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
if code and code.status == VerificationCodeStatus.USED and exception:
|
| 134 |
+
error_map = {
|
| 135 |
+
VerificationCodeType.ORG_INVITATION: "You already created an account. Please sign in.",
|
| 136 |
+
VerificationCodeType.TEAM_INVITATION: "You already accepted this invitation. Please sign in.",
|
| 137 |
+
VerificationCodeType.PASSWORD_RESET: "You already used this reset link. Please request a new one.",
|
| 138 |
+
VerificationCodeType.ORG_CREATION: "You already created an organization. Please sign in.",
|
| 139 |
+
}
|
| 140 |
+
raise HTTPException(status_code=400, detail=error_map[code.type])
|
| 141 |
+
|
| 142 |
+
if code and code.expiresAt < datetime.now() and exception:
|
| 143 |
+
error_msg = "Expired invitation link. Please ask your manager to resend the invite."
|
| 144 |
+
if VerificationCodeType.PASSWORD_RESET in types:
|
| 145 |
+
error_msg = "Expired password reset link. Please request a new one."
|
| 146 |
+
raise HTTPException(status_code=410, detail=error_msg)
|
| 147 |
+
|
| 148 |
+
if code and set_used:
|
| 149 |
+
asyncio.create_task(set_used_code(code))
|
| 150 |
+
return code
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
async def reset_password_obj(account: AccountShorten, password: str) -> AccountShorten:
|
| 154 |
+
"""
|
| 155 |
+
Reset a password object.
|
| 156 |
+
"""
|
| 157 |
+
password = CryptContext(schemes=["bcrypt"], deprecated="auto").hash(password)
|
| 158 |
+
await settings.DB_CLIENT.accounts.update_one(
|
| 159 |
+
{"id": account.id},
|
| 160 |
+
{"$set": {"password": password}},
|
| 161 |
+
)
|
| 162 |
+
return account
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
async def create_google_account(user_info: dict) -> AccountModel:
|
| 167 |
+
account = AccountModel(
|
| 168 |
+
email=user_info["email"],
|
| 169 |
+
name=user_info.get("name"),
|
| 170 |
+
accountType=AccountType.USER,
|
| 171 |
+
registrationType=RegistrationType.GOOGLE,
|
| 172 |
+
)
|
| 173 |
+
await settings.DB_CLIENT.accounts.insert_one(account.to_mongo())
|
| 174 |
+
return account
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
async def create_account(data: RegisterAccountRequest) -> AccountModel:
|
| 178 |
+
account = AccountModel(
|
| 179 |
+
email=data.email,
|
| 180 |
+
password=data.password,
|
| 181 |
+
name=data.name,
|
| 182 |
+
accountType=AccountType.USER,
|
| 183 |
+
registrationType=RegistrationType.ORGANIC,
|
| 184 |
+
)
|
| 185 |
+
await settings.DB_CLIENT.accounts.insert_one(account.to_mongo())
|
| 186 |
+
return account
|
cbh/api/security/dto.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Data Transfer Objects (DTOs) for security functionality.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from enum import Enum
|
| 6 |
+
|
| 7 |
+
from pydantic import BaseModel
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class AccessToken(BaseModel):
|
| 11 |
+
"""
|
| 12 |
+
Access token model for authentication.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
type: str = "Bearer"
|
| 16 |
+
value: str
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class GoogleCallbackError(Enum):
|
| 20 |
+
"""
|
| 21 |
+
Error model for Google callback.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
ORGANIZATION_NOT_FOUND = "OrganizationNotFound"
|
| 25 |
+
BLOCKED_ACCOUNT = "BlockedAccount"
|
| 26 |
+
INVALID_CODE = "InvalidCode"
|
| 27 |
+
EMAIL_MISMATCH = "EmailMismatch"
|
| 28 |
+
EXPIRED_CODE = "ExpiredCode"
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class VerificationCodeType(Enum):
|
| 32 |
+
"""
|
| 33 |
+
Enum for verification code types.
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
PASSWORD_RESET = 1
|
| 37 |
+
TEAM_INVITATION = 2
|
| 38 |
+
ORG_INVITATION = 3
|
| 39 |
+
ORG_CREATION = 4
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class VerificationCodeStatus(Enum):
|
| 43 |
+
"""
|
| 44 |
+
Enum for verification code status.
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
PENDING = 1
|
| 48 |
+
USED = 2
|
cbh/api/security/models.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import datetime
|
| 2 |
+
|
| 3 |
+
from pydantic import Field
|
| 4 |
+
|
| 5 |
+
from cbh.api.account.models import AccountShorten
|
| 6 |
+
from cbh.api.security.dto import VerificationCodeStatus, VerificationCodeType
|
| 7 |
+
from cbh.core.database import MongoBaseModel
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class VerificationCodeModel(MongoBaseModel):
|
| 11 |
+
account: AccountShorten
|
| 12 |
+
expiresAt: datetime
|
| 13 |
+
type: VerificationCodeType
|
| 14 |
+
status: VerificationCodeStatus = VerificationCodeStatus.PENDING
|
| 15 |
+
datetimeInserted: datetime = Field(default_factory=datetime.now)
|
cbh/api/security/schemas.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Schema definitions for security API endpoints.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from pydantic import BaseModel, EmailStr
|
| 6 |
+
|
| 7 |
+
from cbh.api.account.models import AccountModel
|
| 8 |
+
from cbh.api.account.dto import AccountType
|
| 9 |
+
from cbh.api.security.dto import AccessToken
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class LoginAccountRequest(BaseModel):
|
| 14 |
+
"""
|
| 15 |
+
Request model for account login.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
email: EmailStr
|
| 19 |
+
password: str
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class LoginAccountResponse(BaseModel):
|
| 23 |
+
"""
|
| 24 |
+
Response model for successful login.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
accessToken: AccessToken | None = None
|
| 28 |
+
account: AccountModel
|
| 29 |
+
code: str | None = None
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class ResetPasswordConfirmRequest(BaseModel):
|
| 33 |
+
"""
|
| 34 |
+
Request model for confirming a password reset.
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
code: str
|
| 38 |
+
password: str
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class RegisterAccountRequest(BaseModel):
|
| 42 |
+
"""
|
| 43 |
+
Request model for registering a new account.
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
name: str
|
| 47 |
+
email: EmailStr
|
| 48 |
+
password: str
|
cbh/api/security/services/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .oauth import handle_account_oauth
|
| 2 |
+
from .utils import (
|
| 3 |
+
form_google_user_info,
|
| 4 |
+
form_google_login_url,
|
| 5 |
+
)
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
"handle_account_oauth",
|
| 9 |
+
"form_google_user_info",
|
| 10 |
+
"form_google_login_url",
|
| 11 |
+
]
|