diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..a2185e174e33dbce4328f7f1d22d46441a94bd8a --- /dev/null +++ b/.gitignore @@ -0,0 +1,23 @@ +__pycache__/ +env/ +venv/ +.venv/ +.idea/ +*.log +*.egg-info/ +pip-wheel-EntityData/ +.env +.DS_Store +static/ +test.py +rsa_key.p8 +rsa_key.pub +aws.pem +.vscode/ +data/ +*.csv +test.json +voiceagentcbh.pem +*.pem +download +investigation \ No newline at end of file diff --git a/cbh/__init__.py b/cbh/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bddf8feb0d104870773d1482720306cc436079eb --- /dev/null +++ b/cbh/__init__.py @@ -0,0 +1,94 @@ +# pylint: disable=C0415 +""" +ClipboardHealthAI application package. +""" +import asyncio + +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from starlette.exceptions import HTTPException as StarletteHTTPException + +from cbh.core.wrappers import CbhResponseWrapper, ErrorCbhResponse + + +def create_app() -> FastAPI: + """ + Create and configure the FastAPI application. 
+ """ + app = FastAPI(docs_url="/api/docs", openapi_url="/api/openapi.json") + + from cbh.api.account import account_router + + app.include_router(account_router, tags=["account"]) + + from cbh.api.calls import calls_router + + app.include_router(calls_router, tags=["calls"]) + + from cbh.api.reports import reports_router + + app.include_router(reports_router, tags=["reports"]) + + from cbh.api.reps import reps_router + + app.include_router(reps_router, tags=["reps"]) + + from cbh.api.security import security_router + + app.include_router(security_router, tags=["security"]) + + from cbh.api.userinsights import userinsights_router + + app.include_router(userinsights_router, tags=["userinsights"]) + + app.add_middleware( + CORSMiddleware, + allow_origin_regex=r"https?://([a-z0-9-]+\.)?(localhost|trainwitharena|cbhexp\.com)(:\d+)?", + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + @app.exception_handler(StarletteHTTPException) + async def http_exception_handler(_, exc): + """ + Handle HTTP exceptions and convert them to standardized error responses. + """ + return CbhResponseWrapper( + data=None, successful=False, error=ErrorCbhResponse(message=str(exc.detail)) + ).response(exc.status_code) + + @app.on_event("startup") + async def startup_event(): + """ + Execute startup tasks when the application starts. + """ + from cbh.api.calls.services import run_call_listener + # + asyncio.create_task(run_call_listener()) + + @app.get("/api/health") + async def health(): + """ + Health check endpoint for container orchestration and monitoring. + """ + try: + return {"status": "healthy", "database": "connected"} + except Exception as e: + return {"status": "unhealthy", "database": "disconnected", "error": str(e)} + + @app.get("/health") + async def root(): + """ + Root endpoint for the application. + """ + return {"message": "hello!"} + + @app.get("/api/test") + async def root(): + """ + Root endpoint for the application. 
+ """ + return {"message": "hi hello!"} + + return app diff --git a/cbh/api/account/__init__.py b/cbh/api/account/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e61dd2e2b5d54ae3ac5c938aeae5e64cad463600 --- /dev/null +++ b/cbh/api/account/__init__.py @@ -0,0 +1,14 @@ +""" +Account module initialization. + +This module defines the FastAPI router for account API endpoints +and imports related views for account management. +""" + +from fastapi import APIRouter + +account_router = APIRouter( + prefix="/api/account", +) + +from . import views # noqa # pylint: disable=C0413 diff --git a/cbh/api/account/dto.py b/cbh/api/account/dto.py new file mode 100644 index 0000000000000000000000000000000000000000..32edd1816c2742c8dcae8d37ef401afa4b33f55f --- /dev/null +++ b/cbh/api/account/dto.py @@ -0,0 +1,30 @@ +""" +Account DTOs. +""" + +from enum import Enum + +from pydantic import BaseModel + + +class AccountType(Enum): + """ + Enum for account types. + """ + + USER = 1 + ADMIN = 2 + OWNER = 3 + SUPER_ADMIN = 4 + + +class RegistrationType(Enum): + """ + Enum for registration types. + """ + + ORGANIC = 1 + GOOGLE = 2 + GITHUB = 3 + APPLE = 4 + diff --git a/cbh/api/account/models.py b/cbh/api/account/models.py new file mode 100644 index 0000000000000000000000000000000000000000..8d7509142fb18b8aa009fcdff61067e15f6e03b0 --- /dev/null +++ b/cbh/api/account/models.py @@ -0,0 +1,71 @@ +""" +Account models. +""" + +from datetime import datetime + +from passlib.context import CryptContext +from pydantic import Field, field_validator + +from cbh.api.account.dto import AccountType, RegistrationType +from cbh.core.database import MongoBaseModel, MongoBaseShortenModel + + +class AccountModel(MongoBaseModel): + """ + Account model class. + + This class represents a user account in the system. + It includes fields for email, password, and timestamps for creation and update. 
+ """ + + name: str | None = None + email: str + password: str | None = Field(exclude=True, default=None) + + accountType: AccountType = Field(default=AccountType.USER) + registrationType: RegistrationType | None = Field( + default=RegistrationType.ORGANIC, exclude=True + ) + + datetimeInserted: datetime = Field(default_factory=datetime.now) + datetimeUpdated: datetime = Field(default_factory=datetime.now) + + @field_validator("password", mode="before", check_fields=False) + @classmethod + def set_password_hash(cls, v: str | None) -> str | None: + """ + Set the password hash. + + Args: + v (str): The password to hash. + + Returns: + str: The hashed password. + """ + if isinstance(v, str) and not v.startswith("$2b$"): + return CryptContext(schemes=["bcrypt"], deprecated="auto").hash(v) + return v + + + class Config: # pylint: disable=R0903 + """ + Config for the AccountModel class. + """ + + arbitrary_types_allowed = True + populate_by_name = True + json_encoders = {datetime: lambda dt: dt.isoformat()} + + +class AccountShorten(MongoBaseShortenModel): + """ + Account shorten model. + """ + + id: str + name: str | None = None + email: str + pictureUrl: str | None = None + + accountType: AccountType diff --git a/cbh/api/account/utils.py b/cbh/api/account/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cbh/api/account/views.py b/cbh/api/account/views.py new file mode 100644 index 0000000000000000000000000000000000000000..04693e07d871f167e81d0bea82740e6df781f69c --- /dev/null +++ b/cbh/api/account/views.py @@ -0,0 +1,20 @@ +""" +Account views module. 
+""" + +from fastapi import Depends + +from cbh.api.account import account_router +from cbh.api.account.models import AccountModel +from cbh.core.security import PermissionDependency +from cbh.core.wrappers import CbhResponseWrapper + + +@account_router.get("") +async def get_own_account( + account: AccountModel = Depends(PermissionDependency()), +) -> CbhResponseWrapper[AccountModel]: + """ + Get own account. + """ + return CbhResponseWrapper(data=account) diff --git a/cbh/api/ari/__init__.py b/cbh/api/ari/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f173a0f912a949c1ab00526c82f9111363219b30 --- /dev/null +++ b/cbh/api/ari/__init__.py @@ -0,0 +1,5 @@ +from fastapi import APIRouter + +ari_router = APIRouter(prefix="/api/ari", tags=["ari"]) + +from . import views diff --git a/cbh/api/ari/db_requests.py b/cbh/api/ari/db_requests.py new file mode 100644 index 0000000000000000000000000000000000000000..787ce4b80f712fdfe684590fb506e59d361cdc7d --- /dev/null +++ b/cbh/api/ari/db_requests.py @@ -0,0 +1,41 @@ +from datetime import datetime +from pymongo import ReturnDocument +from cbh.api.ari.dto import Author +from cbh.api.ari.models import ChatModel, MessageModel +from cbh.api.account.models import AccountModel, AccountShorten +from cbh.api.ari.schemas import ChatFilter, CreateMessageRequest +from cbh.api.common.db_requests import get_all_objs +from cbh.api.common.schemas import FilterRequest +from cbh.core.config import settings +from cbh.core.wrappers import background_task + + +async def truncate_message_history(message_id: str, ids_to_delete: list[str]) -> None: + await settings.DB_CLIENT.messages.delete_many({"id": {"$in": [message_id, *ids_to_delete]}}) + + +@background_task() +async def add_messages_obj( + chat_id: str, + account_id: str, + request: CreateMessageRequest, + message_args: tuple[str, str, datetime, datetime], +) -> None: + user_message = MessageModel( + id=request.messageId, + chatId=chat_id, + accountId=account_id, 
+ role=Author.HUMAN, + content=request.content, + datetimeInserted=message_args[2], + ) + assistant_message = MessageModel( + id=message_args[1], + chatId=chat_id, + accountId=account_id, + role=Author.AI, + content=message_args[0], + datetimeInserted=message_args[3], + ) + await settings.DB_CLIENT.messages.insert_one(user_message.to_mongo()) + await settings.DB_CLIENT.messages.insert_one(assistant_message.to_mongo()) diff --git a/cbh/api/ari/dto.py b/cbh/api/ari/dto.py new file mode 100644 index 0000000000000000000000000000000000000000..607a0eab8983ea16ed210eb24bcd1857749c9879 --- /dev/null +++ b/cbh/api/ari/dto.py @@ -0,0 +1,7 @@ +from enum import Enum + + +class Author(str, Enum): + HUMAN = 'human' + AI = 'ai' + diff --git a/cbh/api/ari/schemas.py b/cbh/api/ari/schemas.py new file mode 100644 index 0000000000000000000000000000000000000000..51d3856e0862be0ea227d9c7ac3db3264701ae91 --- /dev/null +++ b/cbh/api/ari/schemas.py @@ -0,0 +1,15 @@ +from typing import Optional +from pydantic import BaseModel + + +class CreateMessageRequest(BaseModel): + content: str + messageId: Optional[str] = None + + +class TranscribeResponse(BaseModel): + """ + Response schema for the transcribe endpoint. + """ + + text: str diff --git a/cbh/api/ari/services/__init__.py b/cbh/api/ari/services/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4abd423eb5a21aeab3a2c1e9bd441cf19f459f00 --- /dev/null +++ b/cbh/api/ari/services/__init__.py @@ -0,0 +1,3 @@ +from .workflows import convert_audio_to_text + +__all__ = ["convert_audio_to_text"] diff --git a/cbh/api/ari/services/agent/__init__.py b/cbh/api/ari/services/agent/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dd478def2bd31887e1087bb49c591b10e21fa57e --- /dev/null +++ b/cbh/api/ari/services/agent/__init__.py @@ -0,0 +1,8 @@ +""" +This module contains the Ari Agent. 
+""" + +from .agent import AriAgent + + +__all__ = ["AriAgent"] diff --git a/cbh/api/ari/services/agent/agent.py b/cbh/api/ari/services/agent/agent.py new file mode 100644 index 0000000000000000000000000000000000000000..ed1cb12c0a69bf472956025fa0f26e79f44bf078 --- /dev/null +++ b/cbh/api/ari/services/agent/agent.py @@ -0,0 +1,113 @@ +# pylint: disable=R0801 +""" +This module contains the CBH Agent. +""" +import asyncio + +from langchain_classic.agents import AgentExecutor, create_openai_tools_agent +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder + +from cbh.api.account.models import AccountModel +from cbh.api.ari.schemas import CreateMessageRequest +from cbh.api.ari.services.agent.handler import StreamingAgentCallbackHandler +from cbh.core.config import settings +from .prompt import ARI_PROMPT +from .tools import ScenarioAgentTools + + +class AriAgent: + """ + CBH Agent for handling schedule creation. + """ + + def __init__(self, account: AccountModel): + """ + Initialize the CBHAgent with an account. The agent is stateless with + respect to chat history — history is provided per call to `stream`. + """ + self.tools = ScenarioAgentTools.load_tools(account) + self.agent_executor = self._get_agent() + + def _get_agent(self): + """ + Get the agent_pd instance. + """ + return AgentExecutor( + agent=create_openai_tools_agent( + llm=self._get_agent_model(), + tools=self.tools, + prompt=self._load_system_prompt(), + ), + tools=self.tools, + verbose=True, + return_intermediate_steps=True, + max_iterations=100, + ) + + def _get_agent_model(self): + """ + Get the language model used by the agent_pd. + """ + return settings.get_llm( + model="gpt-5.4", reasoning_effort="medium", reasoning_summary="auto" + ) + + def _load_system_prompt(self) -> ChatPromptTemplate: + """ + Load the system prompt from file. 
+ """ + try: + return ChatPromptTemplate.from_messages( + [ + ("system", ARI_PROMPT), + MessagesPlaceholder(variable_name="chat_history"), + ("human", "{content}"), + MessagesPlaceholder(variable_name="agent_scratchpad"), + ] + ) + except Exception as e: + raise Exception( # pylint: disable=W0719,W0707 + f"Failed to load system prompt: {str(e)}" + ) + + async def stream( + self, + message_history: list[dict], + request: CreateMessageRequest, + stop_event: asyncio.Event = None, + ): + """ + Stream the agent's response to the client. + """ + queue = [] + + async def send(data): + queue.append(data) + + handler = StreamingAgentCallbackHandler(send) + + task = asyncio.create_task( + self.agent_executor.ainvoke( + { + "content": request.content, + "chat_history": message_history, + }, + config={"callbacks": [handler]}, + ) + ) + + while not task.done() or queue: + if stop_event and stop_event.is_set(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + break + while queue: + data = queue.pop(0) + yield data + await asyncio.sleep(0.01) + + if not task.cancelled(): + await task diff --git a/cbh/api/ari/services/agent/handler.py b/cbh/api/ari/services/agent/handler.py new file mode 100644 index 0000000000000000000000000000000000000000..80fd9028fb0d27f74d9b53ab50a0e08fcba3fb1d --- /dev/null +++ b/cbh/api/ari/services/agent/handler.py @@ -0,0 +1,37 @@ +import pydash +from langchain_core.callbacks import AsyncCallbackHandler + + +class StreamingAgentCallbackHandler(AsyncCallbackHandler): # pylint: disable=R0901 + + def __init__(self, send): + self.send = send + + async def on_chat_model_start(self, *args, **kwargs): + await self.send({"type": "init", "content": ""}) + + async def on_tool_start(self, serialized, input_str, **kwargs): + pass + + async def on_tool_end(self, output, **kwargs): + await self.send({"type": "tool_response", "content": output}) + + async def on_llm_new_token(self, token: str, **kwargs): + chunk = kwargs.get("chunk") + if 
chunk: + message = pydash.get(chunk.message.content, "[0]") or {} + if message.get("type") == "reasoning": + content = pydash.get(message, "summary[0].text") + if content: + await self.send({"type": "thinking", "content": content}) + elif message.get("type") == "text": + await self.send({"type": "ai_token", "content": message.get("text")}) + + async def on_chain_end(self, outputs, **kwargs): + if ( + isinstance(outputs, dict) + and "intermediate_steps" in outputs.keys() + and "output" in outputs.keys() + ): + ai_message = "" + await self.send({"type": "ai", "content": ai_message}) diff --git a/cbh/api/ari/services/agent/prompt.py b/cbh/api/ari/services/agent/prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..14ecef822025879c5e8c70dc250efcd64139597f --- /dev/null +++ b/cbh/api/ari/services/agent/prompt.py @@ -0,0 +1,68 @@ +# pylint: disable-all +# flake8: noqa +""" +This module contains the prompt for the CBH agent. +""" +ARI_PROMPT = """You are **Ari**, a friendly analytics assistant. You help managers and admins understand their organization's sales training performance by answering questions using live data. 
+ +## Your Capabilities +You can retrieve the following data by calling the appropriate tool: +- **Accounts**: search users by name, role, or status +- **Dashboard overview**: session counts, average scores, engagement, and trends +- **Scenario performance**: per-scenario averages, top/bottom reps, most common mistakes +- **Skills averages**: org-wide or per-user breakdown across 6 skill dimensions (communication, active listening, conversation, objection handling, empathy, overall) +- **Leaderboard**: ranked list of reps by score, filterable by scenario and date +- **Attention needs**: reps who are inactive, scenarios with low scores, or scenarios never attempted +- **Scenario details**: statistics, skills breakdown, and leaderboard for a specific scenario +- **Team details**: statistics, skills breakdown, and leaderboard for a specific team +- **Insights**: top mistakes and achievements for a user, team, scenario, or the entire organization + +## Rules +1. **Always call a tool before answering when user asks for data** — never guess or fabricate data. Every data point in your response must come from a tool result. +2. **Resolve names to IDs first** — if the user references a person, team, or scenario by name and you need an ID, call the appropriate search tool first (`search_account`, `search_teams`, `search_scenarios`), then call the target tool with the resolved ID. +3. **Use only the data returned** — summarize and present what the tool gives you. Do not add assumptions, projections, or context that was not returned. +4. **Stay on topic** — only answer questions about the organization's training performance. Politely decline anything unrelated. +5. **Never expose internals** — do not mention tool names, internal errors, raw JSON, or how you work under the hood. +6. **Handle empty results gracefully** — if a tool returns no data, tell the user clearly and suggest a next step (e.g. broaden the date range or check the spelling). 
+ +## Response Format +- Be friendly, concise, and conversational +- Keep replies short — 2–4 sentences or a brief list; avoid walls of text +- Present numbers clearly (e.g. "Average score: **78/100**") +- When showing multiple items (e.g. leaderboard, insights), use a short bullet list +- Do NOT dump raw data — always interpret and summarize it for the user + +## Examples + +**User:** How is the team doing overall? +**Ari:** *[calls retrieve_admin_intro_statistics]* +The team completed **142 sessions** this month (+12% vs last month), with an average score of **74/100**. **38 out of 45 members** were active. Want a breakdown by scenario or skill? + +--- + +**User:** Who are the top performers? +**Ari:** *[calls retrieve_leaderboard with page_size=5]* +Here are your top 5 reps this month: +1. **Sarah M.** — 91 +2. **James T.** — 88 +3. **Priya K.** — 85 +... +Want to see how they perform on a specific scenario? + +--- + +**User:** What mistakes is John making? +**Ari:** *[calls search_account with search_term="John"]* +*[calls get_top_account_insights with account_id=..., type_="mistake"]* +John's most common mistake is **rushing the discovery phase** — he tends to jump to the pitch before fully understanding the customer's situation. Want tips on how to coach this? + +--- + +**User:** Which scenarios need attention? +**Ari:** *[calls retrieve_attention_needs]* +Three scenarios haven't been attempted by anyone yet, and **"Cold Call Objection Drill"** has an average score below the threshold. I'd recommend assigning it and scheduling a coaching session. + +--- + +**User:** Write me a poem. +**Ari:** I can only help with questions about your team's training performance. 
Ask me about scores, leaderboards, or coaching insights!""" diff --git a/cbh/api/ari/services/agent/tools.py b/cbh/api/ari/services/agent/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..0844d491e71605b44309867c6e42fa773175e7eb --- /dev/null +++ b/cbh/api/ari/services/agent/tools.py @@ -0,0 +1,69 @@ +""" +This module contains the tools for the CBH Agent. +""" + +# pylint: disable=R0801,C0301 +# flake8: noqa +from typing import Literal +from langchain_core.tools import StructuredTool + +from cbh.api.account.models import AccountModel +from cbh.api.common.dto import OrderType + + +GET_TEAM_LEADERBOARD_DESCRIPTION = """Retrieve a ranked leaderboard of users within a specific team. + +Use this tool when the user wants to compare individual performance within a team — +for example, to identify the best and worst performers or run intra-team competitions. +Call search_teams first to obtain the team_id if needed. + +Parameters: +- team_id: The unique ID of the team (required). +- order: Sort direction. Accepted values: + "desc" — highest scores first (top performers at the top). + "asc" — lowest scores first (use to surface struggling members). +- page_size: Number of leaderboard positions to return (default 1). Increase to see more users. + +Returns a ranked list of positions, each containing the user (id, name, picture) +and their Scores across all 6 skill dimensions.""" + + +class ScenarioAgentTools: + """ + Tools for the CBH Agent. 
+ """ + + @staticmethod + def create_team_leaderboard(account: AccountModel): + async def get_team_leaderboard( + team_id: str, + order: Literal["asc", "desc"], + page_size: int = 1, + ): + sort_order = OrderType.DESCENDING.value + if order == "asc": + sort_order = OrderType.ASCENDING.value + results = await get_team_leaderboard_obj( + team_id=team_id, + account=account, + order=sort_order, + page_size=page_size, + page_index=0, + ) + return { + "tool": "get_team_leaderboard", + "value": results.model_dump(mode="json"), + } + + return get_team_leaderboard + + @staticmethod + def load_tools(account: AccountModel) -> list[StructuredTool]: + return [ + + StructuredTool.from_function( + name="get_team_leaderboard", + description=GET_TEAM_LEADERBOARD_DESCRIPTION, + coroutine=ScenarioAgentTools.create_team_leaderboard(account), + ), + ] diff --git a/cbh/api/ari/services/workflows.py b/cbh/api/ari/services/workflows.py new file mode 100644 index 0000000000000000000000000000000000000000..c2eadc847251c9c1435e5d7201be05488ece5662 --- /dev/null +++ b/cbh/api/ari/services/workflows.py @@ -0,0 +1,21 @@ +import io + +from fastapi import HTTPException +from cbh.core.config import settings + + +async def convert_audio_to_text(file: bytes, name: str) -> str: + """ + Convert an audio file to text using OpenAI's Whisper model. 
+ """ + file_content = io.BytesIO(file) + file_content.name = name + + transcription = await settings.OPENAI_CLIENT.audio.transcriptions.create( + file=file_content, model="whisper-1", language="en" + ) + if isinstance(transcription, str): + return transcription + if transcription.text and isinstance(transcription.text, str): + return str(transcription.text) + raise HTTPException(status_code=500, detail="Failed to convert audio to text") diff --git a/cbh/api/ari/utils.py b/cbh/api/ari/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ccc04e4e8697be5099a876ad64d55582ed26ea49 --- /dev/null +++ b/cbh/api/ari/utils.py @@ -0,0 +1,207 @@ +import asyncio +import io +import statistics +from datetime import datetime +from typing import AsyncGenerator, Callable +from fastapi import WebSocket, WebSocketDisconnect, HTTPException, UploadFile +from cbh.api.ari.schemas import CreateMessageRequest +from cbh.api.ari.db_requests import add_messages_obj, truncate_message_history +from cbh.api.ari.models import MessageModel +from pydub import AudioSegment + +async def send_exception_stream(websocket: WebSocket) -> None: + """Send an exception stream to the client.""" + message = "An error occurred while processing your request. Please try again." + await websocket.send_json({"type": "init", "content": ""}) + for c in message: + await websocket.send_json({"type": "ai_token", "content": c}) + await asyncio.sleep(0.01) + await websocket.send_json({"type": "ai", "content": None}) + await websocket.send_json({"type": "finish", "content": None, "messageId": None}) + + +async def handle_websocket_streaming( + websocket: WebSocket, + pipeline_executor: Callable[[asyncio.Event], AsyncGenerator[dict, None]], + message_id: str | None = None, + ai_message_id: str | None = None, +) -> str: + """ + Handle WebSocket streaming with stop signal support. 
+ + Args: + websocket: WebSocket connection + pipeline_executor: Async generator function that yields chunks and accepts stop_event + """ + stop_event = asyncio.Event() + partial_text = "" + + recv_task = asyncio.create_task(create_stop_receiver(websocket, stop_event)) + + try: + async for chunk in pipeline_executor(stop_event): + await websocket.send_json(chunk) + + extracted_text = extract_partial_text(chunk) + if extracted_text: + partial_text += extracted_text + except asyncio.CancelledError: + stop_event.set() + raise + except RuntimeError: + pass + finally: + recv_task.cancel() + await asyncio.gather(recv_task, return_exceptions=True) + + if stop_event.is_set() and partial_text.strip(): + print(partial_text) + await websocket.send_json( + {"type": "finish", "content": None, "messageId": message_id, "aiMessageId": ai_message_id} + ) + return partial_text + + +async def create_stop_receiver(websocket: WebSocket, stop_event: asyncio.Event) -> None: + """ + Create a receiver task that listens for stop signals from WebSocket. + """ + while True: + try: + msg = await websocket.receive_json() + if msg.get("type") == "stop": + stop_event.set() + break + except WebSocketDisconnect: + break + except Exception: + break + + +def extract_partial_text(chunk: dict) -> str: + """ + Extract text content from a chunk for partial text accumulation. + """ + try: + if chunk.get("type") == "ai_token" and chunk.get("content"): + return chunk.get("content", "") + elif chunk.get("type") == "ai" and chunk.get("content"): + return chunk.get("content", "") + except Exception: + pass + return "" + + +def prepare_messages_from_history(message_history: list[MessageModel]) -> list[dict]: + """ + Prepare messages from message history. 
+ """ + result = [] + for message in message_history: + result.append( + { + "role": message.role.value, + "content": message.content, + } + ) + return result + + +async def truncate_from_message( + message_history: list[MessageModel], request: CreateMessageRequest +) -> list[dict]: + result = [] + ids_to_delete = [] + found = False + + for message in message_history: + msg_id = message.id + if found or msg_id == request.messageId: + found = True + ids_to_delete.append(msg_id) + continue + result.append({"role": message.role.value, "content": message.content}) + + asyncio.create_task(truncate_message_history(request.messageId, ids_to_delete)) + return result + + +async def add_messages( + chat_id: str, + account_id: str, + request: CreateMessageRequest, + messages: list[dict], + message_args: tuple[str, str, datetime, datetime], +) -> list[dict]: + messages.append({"role": "human", "content": request.content}) + messages.append({"role": "ai", "content": message_args[0]}) + asyncio.create_task( + add_messages_obj(chat_id, account_id, request, message_args) + ) + return messages + + + + +async def compress_audio(audio_file: UploadFile) -> bytes | None: + """ + Compress an uploaded audio file to MP3 format. 
+ """ + file_as_bytes = await audio_file.read() + + try: + audio_segment = AudioSegment.from_file(io.BytesIO(file_as_bytes)) + duration = audio_segment.duration_seconds + + if duration > 300: + return None + + mp3_data = io.BytesIO() + audio_segment.export(mp3_data, format="mp3") + mp3_data.seek(0) + + except OSError: + return file_as_bytes + + except Exception as e: # pylint: disable=W0703 + raise HTTPException(status_code=500, detail=str(e)) # pylint: disable=W0707 + + return mp3_data.read() + + +def detect_silence( + audio_data: bytes, silence_threshold: float = -40.0, min_speech_duration: float = 1.0 +) -> bool: + try: + audio_segment = AudioSegment.from_file(io.BytesIO(audio_data)) + chunk_length_ms = 2000 + rms_values = [] + for start_ms in range(0, len(audio_segment), chunk_length_ms): + end_ms = min(start_ms + chunk_length_ms, len(audio_segment)) + chunk = audio_segment[start_ms:end_ms] + if len(chunk) > 0: + chunk_mono = chunk.set_channels(1) + rms_db = chunk_mono.dBFS + rms_values.append(rms_db) + + if not rms_values: + return False + + mean_rms = statistics.mean(rms_values) + + non_silent_chunks = sum(1 for rms in rms_values if rms > silence_threshold) + total_chunks = len(rms_values) + non_silent_ratio = non_silent_chunks / total_chunks if total_chunks > 0 else 0 + + non_silent_duration = (non_silent_chunks * chunk_length_ms) / 1000.0 # Convert to seconds + + has_speech = ( + mean_rms > silence_threshold - 10 + and non_silent_ratio > 0.1 + and non_silent_duration >= min_speech_duration + ) + + return has_speech + + except Exception: + return True diff --git a/cbh/api/ari/views.py b/cbh/api/ari/views.py new file mode 100644 index 0000000000000000000000000000000000000000..3b4cca419bea35fb2c8ee4027f7c4d3178bc041f --- /dev/null +++ b/cbh/api/ari/views.py @@ -0,0 +1,99 @@ +import asyncio +from datetime import datetime +from fastapi import Depends, WebSocket, WebSocketDisconnect, UploadFile, File, HTTPException +from bson import ObjectId + +from 
cbh.api.account.models import AccountModel +from cbh.api.ari import ari_router +from cbh.api.ari.services import convert_audio_to_text +from cbh.api.ari.schemas import ( + CreateMessageRequest, TranscribeResponse, +) +from cbh.api.ari.services.agent import AriAgent +from cbh.api.ari.utils import ( + prepare_messages_from_history, + send_exception_stream, + handle_websocket_streaming, + truncate_from_message, + add_messages, detect_silence, compress_audio, +) +from cbh.api.common.db_requests import get_all_objs, get_obj_by_id +from cbh.core.security import PermissionDependency, check_account_token +from cbh.core.wrappers import CbhResponseWrapper + + +@ari_router.websocket("/{chatId}/send") +async def send_ari_message(chatId: str, websocket: WebSocket): + await websocket.accept() + token = websocket.query_params.get("token") + token = check_account_token(token) + if not token: + await websocket.close(code=1008) + return + + account, (message_history, _) = await asyncio.gather( + get_obj_by_id(AccountModel, token["account_id"]), + get_all_objs( + MessageModel, + 100000, + 0, + additional_filter={"accountId": token["account_id"], "chatId": chatId}, + sort=("id", 1), + ), + ) + agent = AriAgent(account) + + messages = prepare_messages_from_history(message_history) + while True: + try: + request = await websocket.receive_json() + request = CreateMessageRequest(**request) + user_message_time = datetime.now() + + if request.messageId: + messages = await truncate_from_message(message_history, request) + else: + request.messageId = str(ObjectId()) + + async def agent_executor(stop_event): + async for chunk in agent.stream(messages, request, stop_event): + yield chunk + await asyncio.sleep(0.01) + + ai_message_id = str(ObjectId()) + ai_response = await handle_websocket_streaming( + websocket=websocket, + pipeline_executor=agent_executor, + message_id=request.messageId, + ai_message_id=ai_message_id, + ) + ai_message_time = datetime.now() + + message_args = (ai_response, 
ai_message_id, ai_message_time, user_message_time) + messages = await add_messages(chatId, account.id, request, messages, message_args) + + except WebSocketDisconnect: + return + except Exception as e: + await send_exception_stream(websocket) + + +@ari_router.post("/voice/transcript") +async def get_voice_transcript( + file: UploadFile = File(...), + _: AccountModel = Depends(PermissionDependency()), +) -> CbhResponseWrapper[TranscribeResponse]: + """ + Transcribe an uploaded audio file. + """ + mp3_data = await compress_audio(file) + if mp3_data is None: + raise HTTPException(status_code=400, detail="Could not compress audio file") + if not detect_silence(mp3_data): + raise HTTPException( + status_code=422, + detail="Recording appears to contain only silence or background noise.", + ) + + transcribed_text = await convert_audio_to_text(mp3_data, file.filename) + return CbhResponseWrapper(data=TranscribeResponse(text=transcribed_text)) diff --git a/cbh/api/chats/__init__.py b/cbh/api/chats/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..159433367f0fa1b0f654cb4bceb2ebb9b8d78ac2 --- /dev/null +++ b/cbh/api/chats/__init__.py @@ -0,0 +1,5 @@ +from fastapi import APIRouter + +chats_router = APIRouter(prefix="/chats", tags=["chats"]) + +from . 
async def filter_chats_objs(
    account: AccountModel, request: FilterRequest[ChatFilter]
) -> tuple[list[ChatModel], int]:
    """
    Return one page of the account's chats, optionally filtered by name prefix.

    Args:
        account: Owner of the chats; only documents whose ``account.id``
            equals ``account.id`` are matched.
        request: Paging values plus an optional ``filter.searchTerm``. When
            present, chats whose name starts with the term
            (case-insensitive) are matched.

    Returns:
        The page of chats and the total number of matching documents, as
        returned by ``get_all_objs``.
    """
    # Function-scope import: the module's import block is not changed by this edit.
    import re

    filter_ = {"account.id": account.id}
    search_term = request.filter.searchTerm
    if search_term:
        # The term is user input: escape it so characters such as "(" or ".*"
        # are matched literally instead of being interpreted as a regex.
        filter_["name"] = {"$regex": f"^{re.escape(search_term)}", "$options": "i"}

    chats, total_count = await get_all_objs(
        ChatModel, request.pageSize, request.pageIndex, additional_filter=filter_
    )
    return chats, total_count
+ """ + + generate_chat_name = """You are a title generator for a sales team management assistant. + +Generate a concise, descriptive title (2–4 words) for a chat based on the user's first message. + +## User message: +{query} + +## Rules +- Use title case +- No punctuation or filler words +- Capture the specific intent, not just the topic +- Prefer action + subject format when applicable +- Output only the title, nothing else + +## Context +The assistant helps admins who manage sales teams. Users ask for statistics, rep performance, scenario creation, and other team management tasks. + +## Examples + +query: Who is the best sales rep in the "Connor's team" team? +title: Top Rep Connor's Team + + +query: What is the average score for Maksim Shymanouski? +title: Maksim Average Score + + +query: Create me a friendly scenario with 5 objections +title: Friendly 5-Objection Scenario + + +query: Show me the call conversion rate for last month +title: Last Month Conversion Rate + + +query: Which reps haven't hit their targets this quarter? +title: Reps Missing Quarterly Targets +""" + + +@lru_cache() +def get_prompts() -> ChatPrompts: + """ + Get prompts. + """ + return ChatPrompts() + + +chat_prompts = get_prompts() diff --git a/cbh/api/chats/services/workflows.py b/cbh/api/chats/services/workflows.py new file mode 100644 index 0000000000000000000000000000000000000000..29dd08bc8b4b5ccb771096b7ece1dbea13d40665 --- /dev/null +++ b/cbh/api/chats/services/workflows.py @@ -0,0 +1,23 @@ +from langchain_core.prompts import ChatPromptTemplate +from pydantic import BaseModel, Field + +from cbh.api.chats.services.prompts import chat_prompts +from cbh.core.config import settings + + +class ChatNameSchema(BaseModel): + """ + Chat name schema. + """ + + name: str = Field(description="A name for the chat.") + + +async def generate_chat_name(query: str) -> str: + """ + Generate a chat name. 
+ """ + prompt = ChatPromptTemplate.from_messages([("system", chat_prompts.generate_chat_name)]) + chain = prompt | settings.get_llm(schema=ChatNameSchema, model="gpt-4.1-nano") + result = await chain.ainvoke({"query": query}) + return result.name diff --git a/cbh/api/chats/views.py b/cbh/api/chats/views.py new file mode 100644 index 0000000000000000000000000000000000000000..27dab5c58bd0a0363d29ff339c541d9c88047bc6 --- /dev/null +++ b/cbh/api/chats/views.py @@ -0,0 +1,49 @@ +from fastapi import Depends + +from cbh.api.account.dto import AccountType +from cbh.api.account.models import AccountModel +from cbh.api.chats import chats_router +from cbh.api.chats.db_requests import create_chat_obj, update_chat_obj, filter_chats_objs +from cbh.api.chats.models import ChatModel +from cbh.api.chats.schemas import CreateChatRequest, UpdateChatRequest, ChatFilter +from cbh.api.chats.services import generate_chat_name +from cbh.api.common.schemas import AllObjectsResponse, Paging +from cbh.api.common.schemas import FilterRequest +from cbh.core.security import PermissionDependency +from cbh.core.wrappers import CbhResponseWrapper + +@chats_router.post("/chats") +async def create_new_chat( + request: CreateChatRequest, + account: AccountModel = Depends(PermissionDependency()), +) -> CbhResponseWrapper[ChatModel]: + chat_name = await generate_chat_name(request.query) + chat = await create_chat_obj(account, chat_name) + return CbhResponseWrapper(data=chat) + + +@chats_router.put("/chats/{chatId}") +async def update_chat( + chatId: str, + request: UpdateChatRequest, + account: AccountModel = Depends(PermissionDependency([AccountType.ADMIN, AccountType.OWNER])), +) -> CbhResponseWrapper[ChatModel]: + chat = await update_chat_obj(account, chatId, request.name) + return CbhResponseWrapper(data=chat) + + +@chats_router.post("/chats/filter") +async def filter_chats( + request: FilterRequest[ChatFilter], + account: AccountModel = Depends(PermissionDependency([AccountType.ADMIN, 
# Maps a model class name (``model.__name__``) to the MongoDB collection that
# stores it. Every helper in this module resolves its collection through this
# table, so a model that is missing here fails with ``KeyError``.
collection_map = {
    "AccountModel": "accounts",
    "AccountShorten": "accounts",
    "UserInsightModel": "userinsights",
    "UserInsightShorten": "userinsights",
    "CallModel": "calls",
    "RepModel": "reps",
    # Chats are written through ``settings.DB_CLIENT.chats`` and read back
    # with ``get_all_objs(ChatModel, ...)``.
    "ChatModel": "chats",
    # NOTE(review): collection name assumed from the naming convention of the
    # entries above - confirm against the code that inserts messages.
    "MessageModel": "messages",
}
# Matches a date filter value such as "2024-05-17;+2" (local day ; UTC offset in hours).
_DATE_FILTER_PATTERN = re.compile(r"^(\d{4}-\d{2}-\d{2});([+-]\d{1,2})$")


def _parse_date_filter(value: str) -> dict | None:
    """Parse a ``YYYY-MM-DD;<offset>`` filter value, or return None if it is not one."""
    match = _DATE_FILTER_PATTERN.fullmatch(value)
    if match is None:
        return None
    return {
        "date": datetime.strptime(match.group(1), "%Y-%m-%d"),
        "timezone_offset": int(match.group(2)),
    }


def _date_range_filter(field_name: str, dates: list[dict]) -> dict:
    """Build the ``$gte``/``$lt`` clause spanning the earliest to the latest requested day."""
    if len(dates) > 2:
        ordered = sorted(dates, key=lambda item: item["date"])
        first, last = ordered[0], ordered[-1]
    else:
        # One date yields the same record twice, i.e. a single-day range.
        first = min(dates, key=lambda item: item["date"])
        last = max(dates, key=lambda item: item["date"])

    # Shift each local-day boundary by its own offset; the upper bound is the
    # exclusive start of the day after the last requested day.
    start = first["date"] - timedelta(hours=first["timezone_offset"])
    end = last["date"] + timedelta(days=1) - timedelta(hours=last["timezone_offset"])
    return {field_name: {"$gte": start.isoformat(), "$lt": end.isoformat()}}


async def search_objs(
    model: T,
    data: SearchRequest,
    additional_filter: dict | None = None,
    projection: dict | None = None,
) -> tuple[list[T], int]:
    """
    Search a collection using the filters of a ``SearchRequest``.

    Each entry of ``data.filter`` becomes one clause:

    * a non-string value is matched for equality;
    * a string of the form ``YYYY-MM-DD;<offset>`` is collected as a date for
      that field - all dates given for one field are combined into a single
      range from the earliest to the latest day, each boundary shifted by its
      own hour offset and compared as an ISO-8601 string;
    * any other string is matched as a case-insensitive, escaped prefix.

    Args:
        model: Model class; its name selects the collection via ``collection_map``.
        data: Filters and paging (``pageSize``, ``pageIndex``).
        additional_filter: Extra Mongo filter that is AND-ed with the clauses.
        projection: Optional Mongo projection passed to ``find``.

    Returns:
        The requested page (sorted by ``id`` descending) converted with
        ``model.from_mongo``, and the total number of matching documents.

    Raises:
        ValueError: If a date-shaped value is not a valid calendar date.
    """
    filters = []
    date_filters: dict[str, list[dict]] = {}

    for search_filter in data.filter:
        value = search_filter.value
        if not isinstance(value, str):
            filters.append({search_filter.name: value})
            continue

        parsed_date = _parse_date_filter(value)
        if parsed_date is not None:
            date_filters.setdefault(search_filter.name, []).append(parsed_date)
        else:
            filters.append(
                {
                    search_filter.name: {
                        "$regex": f"^{re.escape(value)}",
                        "$options": "i",
                    }
                }
            )

    filters.extend(
        _date_range_filter(field_name, dates) for field_name, dates in date_filters.items()
    )
    if additional_filter:
        filters.append(additional_filter)

    # An empty "$and" list is rejected by MongoDB, so fall back to "match all".
    query = {"$and": filters} if filters else {}
    collection = settings.DB_CLIENT[collection_map[model.__name__]]
    objects, total_count = await asyncio.gather(
        collection.find(query, projection)
        .sort("id", -1)
        .skip(data.pageSize * data.pageIndex)
        .limit(data.pageSize)
        .to_list(length=data.pageSize),
        collection.count_documents(query),
    )
    return [model.from_mongo(obj) for obj in objects], total_count
class IDNamePicture(IDName):
    """
    ``IDName`` extended with a picture URL.
    """

    pictureUrl: str | None

    @field_validator("pictureUrl", mode="before")
    @classmethod
    def serialize_picture_url(cls, v: str | None) -> str | None:
        """
        Normalize the incoming picture value before validation.

        Falsy values become ``None``, values that already start with
        ``https://`` are kept, and anything else is passed to
        ``settings.S3_CLIENT.generate_presigned_url`` with a 3600 expiration.
        """
        if not v:
            return None
        if v.startswith("https://"):
            return v
        return settings.S3_CLIENT.generate_presigned_url(v, expiration=3600)
+ """ + + name: str + order: OrderType + + +class TokenUsage(BaseModel): + inputTokens: int + outputTokens: int + + +class SkillStatistics(BaseModel): + """ + Skill statistics model. + """ + + score: int + bestAccount: IDNamePicture | None = None + + +class InsightType(Enum): + MISTAKE = 1 + ACHIEVEMENT = 2 diff --git a/cbh/api/common/schemas.py b/cbh/api/common/schemas.py new file mode 100644 index 0000000000000000000000000000000000000000..9a272bb3ffd5859b324638d9e4489dbaafef519e --- /dev/null +++ b/cbh/api/common/schemas.py @@ -0,0 +1,121 @@ +""" +Common schemas. +""" + +from datetime import date +from typing import TypeVar, Generic + +from pydantic import BaseModel + +from cbh.api.common.dto import ( + OrderType, + Paging, + SearchFilter, + SkillStatistics, + SortBy, + DateValue, InsightType, +) + +T = TypeVar("T", bound=BaseModel) + + +class AllObjectsResponse(BaseModel, Generic[T]): + """ + Response model for all objects. + """ + + paging: Paging + data: list[T] + + +class SearchRequest(BaseModel): + """ + Request schema for searching calls or statistics. + + Attributes: + filter (list[SearchFilter]): List of filters to apply + pageSize (int): Number of items to return per page + pageIndex (int): Page index to retrieve + """ + + filter: list[SearchFilter] + pageSize: int + pageIndex: int + + +class FilterRequest(BaseModel, Generic[T]): + """ + Filter request. + """ + + filter: T + sortBy: SortBy | None = None + pageSize: int = 10 + pageIndex: int = 0 + + +class PlainTextResponse(BaseModel): + """ + Response model for plain text. + """ + + text: str + + +class BatchIdsRequest(BaseModel): + """ + Batch ids request. + """ + + ids: list[str] + + +class EmailRequest(BaseModel): + """ + Email request. + """ + + email: str + + +class OrderTypeRequest(BaseModel): + """ + Order type request. + """ + + order: OrderType + + +class SkillsStatisticsResponse(BaseModel): + """ + Scenario skills statistics response. 
def calculate_avg_scores(reports: list) -> Scores:
    """
    Average every score field over the given reports.

    Each field of the result is the rounded mean of ``report.scores.<field>``
    across ``reports``; an empty list yields a ``Scores`` of all zeros.
    """
    score_fields = (
        "communication",
        "activeListening",
        "conversation",
        "objection",
        "empathy",
        "final",
    )
    if not reports:
        return Scores(**dict.fromkeys(score_fields, 0))

    report_count = len(reports)
    averages = {}
    for field_name in score_fields:
        total = sum(getattr(report.scores, field_name) for report in reports)
        averages[field_name] = round(total / report_count)
    return Scores(**averages)
def build_leaderboard(
    session_reports: list,
    position_builder: Callable[[dict, Scores], T],
    order: OrderType | None = None,
) -> list[T]:
    """
    Build a leaderboard with one position per account.

    Reports are grouped per account with ``form_user_stats``, averaged with
    ``calculate_avg_scores`` and ordered by ``leaderboard_sort_key`` using the
    average ``final`` score.

    Args:
        session_reports: Reports to aggregate.
        position_builder: Called with the per-user stats dict and the user's
            average scores; its return value becomes the leaderboard entry.
        order: Requested order; ``None`` is treated as descending.

    Returns:
        The entries produced by ``position_builder``, in leaderboard order.
    """
    user_stats = form_user_stats(session_reports)
    # Average once per user and reuse the result for both sorting and building.
    scored_users = [
        (user_data, calculate_avg_scores(user_data["reports"]))
        for user_data in user_stats.values()
    ]

    is_descending = order is None or order == OrderType.DESCENDING
    scored_users.sort(
        key=lambda item: leaderboard_sort_key((item[0], item[1].final), reverse=is_descending)
    )

    return [position_builder(user_data, avg_scores) for user_data, avg_scores in scored_users]
async def convert_document_to_text(file: bytes, filename: str) -> str:
    """
    Convert an uploaded document to plain text.

    ``.txt`` files (any letter case) are decoded locally as UTF-8, dropping
    undecodable bytes. Every other file is sent base64-encoded to the
    Convertio API, the conversion status is polled once per second, and the
    converted text is downloaded when the job reports ``finish``.

    Args:
        file: Raw document bytes.
        filename: Original file name; characters other than word characters,
            whitespace, ``.`` and ``-`` are removed before it is sent.

    Returns:
        The document text, or ``""`` when the conversion request is not
        accepted (response ``code`` other than 200).

    Raises:
        Exception: If a status poll fails or the job has not finished after
            51 polls.
        httpx.HTTPStatusError: If downloading the converted file fails.
    """
    # Case-insensitive so "NOTES.TXT" is not sent to the external service.
    if filename.lower().endswith(".txt"):
        return file.decode("utf-8", errors="ignore")

    max_polls = 51
    filename = re.sub(r"[^\w\s.-]", "", filename)
    data = {
        "apikey": settings.CONVERTIO_API_KEY,
        "input": "base64",
        "file": base64.b64encode(file).decode("utf-8"),
        "filename": filename,
        "outputformat": "txt",
    }
    headers = {"Content-Type": "application/json"}

    async with httpx.AsyncClient(timeout=httpx.Timeout(timeout=120)) as client:
        response = await client.post("https://api.convertio.co/convert", json=data, headers=headers)
        created = response.json()
        if created["code"] != 200:
            return ""

        status_url = f"https://api.convertio.co/convert/{created['data']['id']}/status"
        for _ in range(max_polls):
            status_response = (await client.get(status_url)).json()
            if status_response["code"] != 200:
                raise Exception("Please, try again")
            if status_response["data"]["step"] == "finish":
                break
            # Wait only while the job is still running, not after it finished.
            await asyncio.sleep(1)
        else:
            raise Exception("Please, try again")

        converted = await client.get(status_response["data"]["output"]["url"])
        converted.raise_for_status()
        return converted.content.decode("utf-8", errors="ignore")
async def get_chat_message_history(chat_id: str, account_id: str) -> list[dict]:
    """
    Get the message history for a chat as role/content dictionaries.

    Args:
        chat_id: Value matched against ``MessageModel.chatId``.
        account_id: Value matched against ``MessageModel.accountId``.

    Returns:
        One ``{"role": ..., "content": ...}`` dict per stored message, oldest
        first (sorted by ``id`` ascending, like the other readers of this
        collection).
    """
    # get_all_objs returns (objects, total_count); only the objects are needed.
    messages, _ = await get_all_objs(
        MessageModel,
        100000,
        0,
        # MessageModel stores flat "accountId"/"chatId" fields.
        additional_filter={"accountId": account_id, "chatId": chat_id},
        sort=("id", 1),
    )
    return [{"role": message.role, "content": message.content} for message in messages]
@messages_router.get("/messages/{chatId}")
async def get_chat_messages(
    chatId: str,
    account: AccountModel = Depends(PermissionDependency([AccountType.ADMIN, AccountType.OWNER])),
) -> CbhResponseWrapper[AllObjectsResponse[MessageModel]]:
    """
    Return the messages of one chat that belong to the calling account.

    Up to 100000 messages are fetched in a single page, sorted by ``id``
    ascending; the reported page size is the number of messages returned.
    """
    ownership_filter = {"accountId": account.id, "chatId": chatId}
    messages, total_count = await get_all_objs(
        MessageModel,
        100000,
        0,
        sort=("id", 1),
        additional_filter=ownership_filter,
    )
    paging = Paging(pageSize=len(messages), pageIndex=0, totalCount=total_count)
    payload = AllObjectsResponse(data=messages, paging=paging)
    return CbhResponseWrapper(data=payload)
# MongoDB-backed catalogue entry describing one recommendable platform/tool.
# Every field is required (no defaults are declared) and carries its own
# human-readable explanation in the ``Field(description=...)`` metadata.
# The enum-typed fields (category, level, toolType, focus,
# monetizationPriority) use the integer enums from cbh.api.platforms.dto.
class PlatformModel(MongoBaseModel):
    name: str = Field(description="Platform name, e.g. 'v0', 'Lovable'")
    category: Category = Field(description="Primary category. One of: 1=Web apps/SaaS/MVP, 2=Websites/Landing pages, 3=Mobile apps, 4=UI/UX Design, 5=AI Coding tools, 6=Automation/AI agents, 7=Video/Creative, 8=SEO/GEO, 9=Growth/Social/Reddit, 10=Research/Analytics")
    subcategory: str = Field(description="Specific subcategory within the category, e.g. 'UI + app generation', 'Prompt-based app builder'")
    oneLinePos: str = Field(description="One-line positioning statement describing what the tool does")
    description: str = Field(description="Detailed description of the tool: when to use it, strengths, and best-fit scenarios")
    userQueries: list[str] = Field(description="List of typical user queries/intents this tool covers, e.g. ['I want to build a SaaS', 'I need a dashboard']")
    idealCases: str = Field(description="Description of the ideal client scenario, e.g. 'the client wants a fast visual MVP and is comfortable refining later'")
    personas: list[str] = Field(description="List of recommended user personas, e.g. ['Founder', 'PM', 'Developer']")
    level: Level = Field(description="Required skill level. One of: 1=Low, 2=Low-to-Medium, 3=Medium, 4=Medium-to-High, 5=High")
    toolType: ToolType = Field(description="Tool type. One of: 1=No-code, 2=Hybrid, 3=Dev")
    # A platform may target several focus areas at once, hence the list.
    focus: list[Focus] = Field(description="Platform focus areas. List of: 1=Web, 2=Mobile, 3=Desktop, 4=Multi-platform, 5=Mobile design, 6=Developer workflow, 7=Desktop/Multi-platform dev")
    productStage: list[str] = Field(description="Best product stages, e.g. ['Ideation', 'MVP', 'MVP UI']")
    keyStrengths: list[str] = Field(description="List of key strengths of the tool")
    caveats: list[str] = Field(description="List of main caveats or limitations")
    monetizationPriority: MonetizationPriority = Field(description="Monetization priority. One of: 1=Low, 2=Medium, 3=High")
    website: str = Field(description="Official website URL")
    internalNotes: str = Field(description="Internal notes about when/how to recommend this tool")
async def authenticate_account(data: LoginAccountRequest) -> AccountModel:
    """
    Authenticate a user account using email and password.

    The email is matched case-insensitively against the ``accounts``
    collection.

    Args:
        data: Login payload carrying ``email`` and ``password``.

    Returns:
        The matching account.

    Raises:
        HTTPException: 404 when no account matches the email, 422 when the
            account was not registered organically, 400 when the password
            does not verify.
    """
    import re  # local import: ``re`` is not in this module's import block

    # Emails legally contain regex metacharacters ("+", "."). Unescaped,
    # "user+tag@x.com" would never match itself and "." would match any
    # character, so the address is escaped and matched literally.
    email_pattern = f"^{re.escape(str(data.email))}$"
    account = await settings.DB_CLIENT.accounts.find_one(
        {"email": {"$regex": email_pattern, "$options": "i"}}
    )
    if account is None:
        raise HTTPException(status_code=404, detail="Invalid email or password.")

    account = AccountModel.from_mongo(account)
    # Social-login accounts must not be authenticated by password.
    if account.registrationType != RegistrationType.ORGANIC:
        raise HTTPException(status_code=422, detail="Please sign in with social providers.")

    if not verify_password(data.password, account.password):
        raise HTTPException(status_code=400, detail="Invalid email or password.")

    return account
async def verify_code_obj(
    code_: str, types: list[VerificationCodeType], exception: bool = True, set_used: bool = True
) -> VerificationCodeModel:
    """
    Look up a verification code by id and validate it.

    Args:
        code_: The ``id`` of the verification code document.
        types: Code types that are acceptable for this lookup.
        exception: When true, a missing, already used or expired code raises.
        set_used: When true, a found code is scheduled to be marked as used.

    Returns:
        The newest matching code, or ``None`` when nothing matches and
        ``exception`` is false.

    Raises:
        HTTPException: 404 (not found), 400 (already used) or 410 (expired),
            only when ``exception`` is true.
    """
    accepted = [t.value for t in types]
    matches = (
        await settings.DB_CLIENT.verificationcodes.find(
            {"id": code_, "type": {"$in": accepted}},
        )
        .sort("_id", -1)
        .to_list(length=1)
    )
    record = VerificationCodeModel.from_mongo(matches[0]) if matches else None
    # The wording of not-found/expired errors depends on the requested types.
    for_password_reset = VerificationCodeType.PASSWORD_RESET in types

    if record is None:
        if exception:
            if for_password_reset:
                not_found = "Invalid password reset link. Please request a new one."
            else:
                not_found = "Invalid invitation link. Please ask your manager to resend the invite."
            raise HTTPException(
                status_code=404,
                detail=not_found,
            )
        return record

    if exception:
        if record.status == VerificationCodeStatus.USED:
            used_messages = {
                VerificationCodeType.ORG_INVITATION: "You already created an account. Please sign in.",
                VerificationCodeType.TEAM_INVITATION: "You already accepted this invitation. Please sign in.",
                VerificationCodeType.PASSWORD_RESET: "You already used this reset link. Please request a new one.",
                VerificationCodeType.ORG_CREATION: "You already created an organization. Please sign in.",
            }
            raise HTTPException(status_code=400, detail=used_messages[record.type])

        if record.expiresAt < datetime.now():
            if for_password_reset:
                expired = "Expired password reset link. Please request a new one."
            else:
                expired = "Expired invitation link. Please ask your manager to resend the invite."
            raise HTTPException(status_code=410, detail=expired)

    if set_used:
        # Marking the code as used is scheduled as a task and not awaited here.
        asyncio.create_task(set_used_code(record))
    return record
async def create_account(data: RegisterAccountRequest) -> AccountModel:
    """
    Create and persist a new organically registered account.

    Args:
        data: Registration payload carrying ``name``, ``email`` and ``password``.

    Returns:
        The newly inserted account.

    Raises:
        HTTPException: 409 when an account with the same email already exists.
    """
    # Inserting without a lookup would let the same email be registered more
    # than once; later email lookups would then return an arbitrary duplicate.
    if await check_unique_email(data.email) is not None:
        raise HTTPException(
            status_code=409, detail="An account with this email already exists."
        )

    # NOTE(review): ``data.password`` is handed to AccountModel unchanged,
    # whereas reset_password_obj stores a bcrypt hash. Confirm that
    # AccountModel hashes the password on construction.
    account = AccountModel(
        email=data.email,
        password=data.password,
        name=data.name,
        accountType=AccountType.USER,
        registrationType=RegistrationType.ORGANIC,
    )
    await settings.DB_CLIENT.accounts.insert_one(account.to_mongo())
    return account
# Persisted verification code (password reset, invitations, org creation).
# The document ``id`` inherited from MongoBaseModel is the code value that
# is looked up when the code is verified.
class VerificationCodeModel(MongoBaseModel):
    # Shortened snapshot of the account the code was issued for.
    account: AccountShorten
    # Moment after which the code is rejected as expired.
    expiresAt: datetime
    type: VerificationCodeType
    # New codes start out as PENDING.
    status: VerificationCodeStatus = Field(default=VerificationCodeStatus.PENDING)
    # Creation timestamp, filled in when the model is instantiated.
    datetimeInserted: datetime = Field(default_factory=datetime.now)
async def handle_account_oauth(
    user_info: dict,
) -> RedirectResponse:
    """
    Sign in (or sign up) the Google user described by ``user_info``.

    An existing account with the same email is reused; otherwise a new
    Google-registered account is created. The caller is then redirected to
    the frontend login callback with a fresh access token in the query string.
    """
    existing = await get_account_by_email(user_info["email"])
    account = existing if existing else await create_google_account(user_info)

    access_token = create_access_token(account.email, str(account.id), account.accountType)
    target = f"{settings.Audience}/login/callback?accessToken={access_token}"
    return RedirectResponse(target)
def _sign_oauth_state(nonce: str) -> str:
    """
    Return the hex HMAC-SHA256 of ``nonce`` keyed with the application secret.
    """
    import hashlib
    import hmac

    return hmac.new(
        settings.SECRET_KEY.encode("utf-8"), nonce.encode("utf-8"), hashlib.sha256
    ).hexdigest()


def form_google_login_url():
    """
    Form the Google login URL with the given parameters.

    The ``state`` parameter carries a random nonce and its HMAC signature.
    The application secret itself must never be placed in the URL: ``state``
    travels through the browser and through Google, so embedding the secret
    would disclose it to every visitor.
    """
    import secrets

    nonce = secrets.token_urlsafe(16)
    params = {
        "client_id": settings.GOOGLE_CLIENT_ID,
        "redirect_uri": f"{settings.Issuer}/api/security/google/callback",
        "response_type": "code",
        "scope": "openid email profile",
        "access_type": "offline",
        "state": json.dumps({"nonce": nonce, "sig": _sign_oauth_state(nonce)}),
    }
    return f"https://accounts.google.com/o/oauth2/auth?{urlencode(params)}"


async def get_google_access_token(code: str) -> str:
    """
    Get the Google access token from the given code.

    Raises:
        httpx.HTTPStatusError: When Google answers with an error status.
    """
    params = {
        "client_id": settings.GOOGLE_CLIENT_ID,
        "client_secret": settings.GOOGLE_CLIENT_SECRET,
        "code": code,
        "grant_type": "authorization_code",
        "redirect_uri": f"{settings.Issuer}/api/security/google/callback",
    }
    async with httpx.AsyncClient() as client:
        response = await client.post("https://oauth2.googleapis.com/token", data=params)
        response.raise_for_status()
        return response.json()["access_token"]


async def get_google_user_info(access_token: str) -> dict:
    """
    Get the Google user info from the given access token.

    Raises:
        httpx.HTTPStatusError: When Google answers with an error status.
    """
    headers = {"Authorization": f"Bearer {access_token}"}
    async with httpx.AsyncClient() as client:
        response = await client.get(
            "https://www.googleapis.com/oauth2/v1/userinfo", headers=headers
        )
        response.raise_for_status()
        return response.json()


async def form_google_user_info(
    request: Request,
) -> dict:
    """
    Form the Google user info from the given request.

    Validates the signed ``state`` produced by ``form_google_login_url``,
    exchanges the authorization ``code`` for an access token and fetches the
    user's profile.

    Raises:
        HTTPException: 400 when ``code`` is missing, 403 when ``state`` is
            missing, malformed or carries a wrong signature.
    """
    import hmac

    code = request.query_params.get("code")
    if not code:
        raise HTTPException(status_code=400, detail="Missing authorization code")

    # ``state`` is attacker-controlled input: it may be absent, not JSON, or
    # JSON of the wrong shape. All of those are a denial, not a server error.
    raw_state = request.query_params.get("state")
    try:
        state = json.loads(raw_state) if raw_state else None
    except ValueError:
        state = None
    nonce = state.get("nonce") if isinstance(state, dict) else None
    signature = state.get("sig") if isinstance(state, dict) else None
    if not isinstance(nonce, str) or not isinstance(signature, str):
        raise HTTPException(status_code=403, detail="Permission denied")

    # Compare as bytes in constant time (compare_digest rejects non-ASCII str).
    expected = _sign_oauth_state(nonce).encode("utf-8")
    if not hmac.compare_digest(signature.encode("utf-8"), expected):
        raise HTTPException(status_code=403, detail="Permission denied")

    access_token = await get_google_access_token(code)
    user_info = await get_google_user_info(access_token)
    return user_info
def check_account_to_reset(account_obj: AccountModel, account: AccountModel | None = None) -> None:
    """
    Ensure a password reset is allowed for ``account_obj``.

    Only organically registered accounts can reset a password. For any other
    registration type a 422 is raised; the message differs depending on
    whether an ``account`` was supplied by the caller.

    Raises:
        HTTPException: 422 when ``account_obj`` is not an organic account.
    """
    if account_obj.registrationType == RegistrationType.ORGANIC:
        return

    if account:
        reason = "Password reset is not available for social login accounts"
    else:
        reason = "Please sign in with social providers"
    raise HTTPException(
        status_code=422,
        detail=reason,
    )
@security_router.post("/login")
async def login(data: LoginAccountRequest) -> CbhResponseWrapper[LoginAccountResponse]:
    """
    Authenticate a user by email and password and issue an access token.

    Authentication errors raised by ``authenticate_account`` propagate
    unchanged.
    """
    account = await authenticate_account(data)
    bearer = AccessToken(
        value=create_access_token(account.email, str(account.id), account.accountType)
    )
    return CbhResponseWrapper(
        data=LoginAccountResponse(accessToken=bearer, account=account)
    )
@security_router.post("/login/as/user")
async def login_as_user(
    email: str,
    caller: AccountModel = Depends(PermissionDependency()),
) -> CbhResponseWrapper[LoginAccountResponse]:
    """
    Login as a user (impersonation), restricted to administrators.

    Args:
        email: Email of the account to impersonate.
        caller: The authenticated account performing the request.

    Returns:
        An access token and the account data of the impersonated user.

    Raises:
        HTTPException: 403 when the caller is not an administrator, 404 when
            no account has the given email.
    """
    # Local import: AccountType is not part of this module's import block.
    from cbh.api.account.dto import AccountType

    # This endpoint mints a token for an arbitrary account, so it must never
    # be reachable without an authenticated administrator.
    if caller.accountType != AccountType.ADMIN:
        raise HTTPException(status_code=403, detail="Permission denied")

    account = await get_obj_by_id(AccountModel, None, additional_filter={"email": email})
    if account is None:
        raise HTTPException(status_code=404, detail="User not found")
    token = create_access_token(account.email, str(account.id), account.accountType)
    response = LoginAccountResponse(
        accessToken=AccessToken(value=token),
        account=account,
    )
    return CbhResponseWrapper(data=response)
class BaseConfig:
    """
    Base configuration class containing common settings for all environments.

    All attributes are evaluated once, when the class body is executed at
    import time; this includes creating the MongoDB and email clients.
    """

    BASE_DIR: pathlib.Path = pathlib.Path(__file__).parent.parent.parent
    # NOTE(review): falls back to an empty secret when SECRET is unset —
    # confirm that deployments always provide it.
    SECRET_KEY: str = os.getenv("SECRET", "")

    GOOGLE_CLIENT_ID: str = os.getenv("GOOGLE_CLIENT_ID", "")
    GOOGLE_CLIENT_SECRET: str = os.getenv("GOOGLE_CLIENT_SECRET", "")

    OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")

    DB_CLIENT: AsyncIOMotorDatabase = AsyncIOMotorClient(
        os.getenv("MONGO_DB_URL")
    ).urelocate

    EMAIL_CLIENT: EmailClient = EmailClient(
        smtp_server=os.getenv("SMTP_SERVER"),
        # ``int(None)`` would raise TypeError and make this module impossible
        # to import when PORT is unset. 587 is the standard STARTTLS
        # submission port, matching the client's ``start_tls=True`` usage.
        port=int(os.getenv("PORT", "587")),
        sender_email=os.getenv("SENDER_EMAIL"),
        password=os.getenv("PASSWORD"),
        sender_name=os.getenv("SENDER_NAME"),
    )

    @staticmethod
    def get_headers(api_key: str) -> dict:
        """
        Generate HTTP headers for API requests.

        Returns a bearer ``Authorization`` header plus JSON ``Content-Type``
        and ``Accept`` headers.
        """
        return {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        }

    @staticmethod
    @lru_cache()
    def get_llm(
        model: str = "gpt-4.1-mini",
        temperature: float = 0.0,
        reasoning: str = "none",
        schema: Optional[Type[BaseModel]] = None,
        is_json: bool = False,
    ) -> Runnable:
        """
        Get a configured LLM instance.

        Results are memoised per argument combination by ``lru_cache``.

        Args:
            model: OpenAI chat model name.
            temperature: Sampling temperature (overridden to 1 for gpt-5 models).
            reasoning: Reasoning effort, only sent for gpt-5 models.
            schema: When given, the model is wrapped for structured output.
            is_json: When true (and no schema), JSON mode is enabled.
        """
        kwargs = {"model": model, "temperature": temperature}
        if model.startswith("gpt-5"):
            kwargs["reasoning_effort"] = reasoning
            kwargs["temperature"] = 1
        if schema:
            return ChatOpenAI(**kwargs).with_structured_output(schema)
        if is_json:
            return ChatOpenAI(**kwargs).with_structured_output(method="json_mode")
        return ChatOpenAI(**kwargs)

    @staticmethod
    @lru_cache()
    def get_embedding_client(
        model: str = "text-embedding-3-small", dimensions: int = 384
    ) -> OpenAIEmbeddings:
        """
        Get a cached OpenAI embeddings client for the given model and size.
        """
        return OpenAIEmbeddings(model=model, dimensions=dimensions)
@lru_cache()
def get_settings() -> DevelopmentConfig | ProductionConfig:
    """
    Build the configuration object selected by ``FASTAPI_CONFIG``.

    The environment variable defaults to ``"development"``. The result is
    cached, so the configuration class is instantiated only once.

    Raises:
        KeyError: When ``FASTAPI_CONFIG`` names an unknown environment.
    """
    registry = {
        "development": DevelopmentConfig,
        "production": ProductionConfig,
    }
    selected = registry[os.getenv("FASTAPI_CONFIG", default="development")]
    return selected()  # type: ignore
class MongoBaseModel(BaseModel):
    """
    Base model for MongoDB documents with serialization support.
    """

    # String form of a freshly generated ObjectId unless a value is supplied.
    id: str = Field(default_factory=lambda: str(PyObjectId()))

    class Config:  # pylint: disable=R0903
        """
        Configuration for the model.
        """

        arbitrary_types_allowed = True
        # Unknown keys are dropped instead of raising.
        extra = "ignore"
        populate_by_name = True

    @staticmethod
    def serialize_s3_url(value: Any) -> Any:
        """
        Serialize an S3 URL.

        A string containing both "AWSAccessKeyId" and "Expires" is reduced to
        the path that follows ``s3.amazonaws.com/`` (up to the query string).
        Any other value, or a string without that host, is returned unchanged.
        """
        if value and isinstance(value, str) and "AWSAccessKeyId" in value and "Expires" in value:
            match = re.search(r"s3\.amazonaws\.com/([^?]+)", value)
            if match:
                return match.group(1)
        return value

    def to_mongo(self) -> Dict[str, Any]:
        """
        Convert the model instance to a MongoDB-compatible dictionary.

        Fields are written under their alias when one is defined. Nested
        models (and lists made up only of models) are converted recursively,
        enums are replaced by their value, datetimes by their ISO-8601 string
        and URLs by ``str``. Everything else passes through
        ``serialize_s3_url``.
        """

        def model_to_dict(model: BaseModel) -> Dict[str, Any]:
            doc = {}
            for name in model.__fields__.keys():
                value = getattr(model, name)
                key = model.__fields__[name].alias or name

                if isinstance(value, BaseModel):
                    doc[key] = model_to_dict(value)
                # An empty list also takes this branch (all() of nothing is True).
                elif isinstance(value, list) and all(isinstance(i, BaseModel) for i in value):
                    doc[key] = [model_to_dict(item) for item in value]  # type: ignore
                # Only truthy enum members are unwrapped here; lists of enums
                # are not handled by this branch and fall through to ``else``.
                elif value and isinstance(value, Enum):
                    doc[key] = value.value
                elif isinstance(value, datetime):
                    doc[key] = value.isoformat()  # type: ignore
                elif value and isinstance(value, AnyUrl):
                    doc[key] = str(value)  # type: ignore
                else:
                    doc[key] = self.serialize_s3_url(value)

            return doc

        result = model_to_dict(self)
        return result

    @classmethod
    def from_mongo(cls, data: Dict[str, Any]):
        """
        Create a model instance from MongoDB document data.

        Returns ``None`` when ``data`` is ``None``. Otherwise the model is
        built from the document and enum-annotated fields are re-wrapped in
        their enum class, recursing into nested models.
        """

        def restore_enums(inst: Any, model_cls: Type[BaseModel]) -> None:
            for name, field in model_cls.__fields__.items():  # type: ignore
                value = getattr(inst, name)
                # Field annotated directly with an Enum class.
                if (
                    field
                    and isinstance(field.annotation, type)
                    and issubclass(field.annotation, Enum)
                ):
                    setattr(inst, name, field.annotation(value))
                elif isinstance(value, BaseModel):
                    restore_enums(value, value.__class__)
                elif isinstance(value, list):
                    for i, item in enumerate(value):
                        if isinstance(item, BaseModel):
                            restore_enums(item, item.__class__)
                        # NOTE(review): this test repeats the first ``if`` of
                        # the chain, so it looks unreachable here — confirm.
                        elif isinstance(field.annotation, type) and issubclass(
                            field.annotation, Enum
                        ):
                            value[i] = field.annotation(item)
                elif isinstance(value, dict):
                    for k, v in value.items():
                        if isinstance(v, BaseModel):
                            restore_enums(v, v.__class__)
                        # NOTE(review): same apparently unreachable test as in
                        # the list branch above.
                        elif isinstance(field.annotation, type) and issubclass(
                            field.annotation, Enum
                        ):
                            value[k] = field.annotation(v)

        if data is None:
            return None
        instance = cls(**data)
        restore_enums(instance, instance.__class__)
        return instance
class EmailClient:
    """
    Minimal asynchronous SMTP client for HTML emails.

    Every message is built as ``multipart/alternative``: a plain-text part
    derived from the HTML body, followed by the HTML part itself.
    """

    # <style>/<script> elements are dropped together with their content;
    # stripping only the tags would leak CSS/JS into the plain-text part.
    _NON_TEXT_BLOCKS = re.compile(
        r"<(style|script)\b[^>]*>.*?</\1\s*>", re.IGNORECASE | re.DOTALL
    )
    _TAG = re.compile(r"<[^<]+?>")

    def __init__(
        self,
        smtp_server: str,
        port: int,
        sender_email: str,
        password: str,
        sender_name: str,
    ):
        """
        Store the SMTP connection settings and the sender identity.

        Args:
            smtp_server: Hostname of the SMTP server.
            port: Port of the SMTP server.
            sender_email: Address used as sender, reply-to and SMTP username.
            password: SMTP password for ``sender_email``.
            sender_name: Display name used in ``From`` and ``X-Mailer``.
        """
        self.smtp_server = smtp_server
        self.port = port
        self.sender_email = sender_email
        self.password = password
        self.sender_name = sender_name

    @classmethod
    def _html_to_text(cls, html_body: str) -> str:
        """Derive the plain-text alternative from an HTML body."""
        import html  # stdlib; local import keeps this block self-contained

        without_blocks = cls._NON_TEXT_BLOCKS.sub("", html_body)
        without_tags = cls._TAG.sub("", without_blocks)
        # Decode entities (&amp;, &nbsp;, ...) and turn non-breaking spaces
        # into ordinary ones so the text part reads naturally.
        return html.unescape(without_tags).replace("\xa0", " ").strip()

    def _build_message(
        self, recipient_email: str, subject: str, html_body: str
    ) -> MIMEMultipart:
        """Assemble the multipart message without sending it."""
        message = MIMEMultipart("alternative")

        message["From"] = formataddr((self.sender_name, self.sender_email))
        message["To"] = recipient_email
        message["Subject"] = subject
        message["Reply-To"] = self.sender_email

        message["X-Mailer"] = self.sender_name
        message["X-Priority"] = "3"

        # Order matters for multipart/alternative: the last part is the
        # preferred rendering, so plain text goes first and HTML last.
        message.attach(MIMEText(self._html_to_text(html_body), "plain", "utf-8"))
        message.attach(MIMEText(html_body, "html", "utf-8"))
        return message

    async def send_email(
        self, recipient_email: str, subject: str, html_body: str
    ) -> None:
        """
        Send an HTML email with an auto-generated plain-text alternative.

        Args:
            recipient_email: Destination address.
            subject: Subject line.
            html_body: Rendered HTML content of the email.

        Errors raised by ``aiosmtplib.send`` propagate to the caller.
        """
        message = self._build_message(recipient_email, subject, html_body)

        await aiosmtplib.send(
            message,
            hostname=self.smtp_server,
            port=self.port,
            start_tls=True,
            username=self.sender_email,
            password=self.password,
            timeout=30,
        )
0000000000000000000000000000000000000000..652fb33776d7d969a517c8be9499dccc3bac60f9 --- /dev/null +++ b/cbh/core/s3_client.py @@ -0,0 +1,125 @@ +import re +from datetime import datetime +import mimetypes +from typing import Literal +import boto3 +from botocore.exceptions import ClientError + +from cbh.core.database import ObjectId + + +class S3Client: + def __init__(self, region: str, bucket_name: str, profile_name: str | None = None): + self.bucket_name = bucket_name + + if profile_name: + session = boto3.Session(profile_name=profile_name) + self.s3_client = session.client("s3", region_name=region) + else: + self.s3_client = boto3.client("s3", region_name=region) + + @staticmethod + def _generate_file_name( + filename: str, + type_: Literal[ + "pictures", + "orgpictures", + "recordings", + "voices", + "scenarioimages", + "organizationdocuments", + "agentfiles", + ], + ) -> str: + random_id = str(ObjectId()) + extension = filename.split(".")[-1] + return f"{type_}/{random_id}.{extension}" + + @staticmethod + def _get_content_type(filename: str) -> str: + return mimetypes.guess_type(filename)[0] + + def generate_presigned_url(self, file_key: str, expiration: int = 3600) -> str: + try: + url = self.s3_client.generate_presigned_url( + ClientMethod="get_object", + Params={"Bucket": self.bucket_name, "Key": file_key}, + ExpiresIn=expiration, + ) + return url + except ClientError as e: + raise Exception(f"Failed to generate presigned URL: {str(e)}") + + def upload_file( + self, + file: bytes, + filename: str, + type_: Literal[ + "pictures", + "orgpictures", + "recordings", + "voices", + "scenarioimages", + "organizationdocuments", + "agentfiles", + ], + ) -> str | None: + if not file: + return None + file_key = self._generate_file_name(filename, type_) + content_type = self._get_content_type(filename) + + self.s3_client.put_object( + Bucket=self.bucket_name, + Key=file_key, + Body=file, + ContentType=content_type, + Metadata={"upload_date": datetime.now().isoformat()}, + 
) + + return file_key + + def update_file( + self, + file: bytes, + filename: str, + type_: Literal[ + "pictures", + "orgpictures", + "recordings", + "voices", + "scenarioimages", + "organizationdocuments", + "agentfiles", + ], + ) -> None: + match = re.search(rf"/({type_}/.+)\?", filename) + file_key = match.group(1) if match else filename + content_type = self._get_content_type(file_key) + + response = self.s3_client.put_object( + Bucket=self.bucket_name, + Key=file_key, + Body=file, + ContentType=content_type, + Metadata={"upload_date": datetime.now().isoformat()}, + ) + print(response) + + def delete_file(self, file_key: str) -> bool: + """Delete a file from S3 bucket.""" + try: + self.s3_client.delete_object(Bucket=self.bucket_name, Key=file_key) + return True + except Exception as e: + print(f"Failed to delete file {file_key}: {str(e)}") + return False + + def get_file_content(self, file_key: str) -> bytes | None: + """Get file content from S3 bucket.""" + try: + response = self.s3_client.get_object(Bucket=self.bucket_name, Key=file_key) + return response["Body"].read() + except ClientError as e: + print(f"Failed to get file {file_key}: {str(e)}") + return None diff --git a/cbh/core/security.py b/cbh/core/security.py new file mode 100644 index 0000000000000000000000000000000000000000..b542b38d47bf4b1cc33c93974da9ed7f2e8865d0 --- /dev/null +++ b/cbh/core/security.py @@ -0,0 +1,195 @@ +""" +Security utilities for ClipboardHealthAI application. 
# Built once: creating a CryptContext per call only repeats the same
# scheme configuration.
_PWD_CONTEXT = CryptContext(schemes=["bcrypt"], deprecated="auto")

_CLAIM_NAME = "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/name"
_CLAIM_NAME_IDENTIFIER = (
    "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/nameidentifier"
)


def _decode_claims(token: str) -> dict:
    """
    Decode and verify a JWT issued by ``create_access_token``.

    The algorithm is passed as a list so it is matched as a whole value.
    Any ``JWTError`` raised by ``jwt.decode`` propagates to the caller.
    """
    return jwt.decode(
        token, settings.SECRET_KEY, algorithms=["HS256"], audience=settings.Audience
    )


def verify_password(plain_password, hashed_password) -> bool:
    """
    Verify a password against its hashed version.

    Args:
        plain_password: The plain text password to verify
        hashed_password: The bcrypt hash to check against

    Returns:
        bool: True if the password matches, False otherwise
    """
    return _PWD_CONTEXT.verify(plain_password, hashed_password)


def create_access_token(email: str, account_id: str, account_type: AccountType):
    """
    Create a JWT access token for a user.

    The token is signed with HS256 using ``settings.SECRET_KEY`` and
    expires 30 days after it is issued.

    Args:
        email: User's email address
        account_id: User's account ID
        account_type: User's account type

    Returns:
        str: Encoded JWT token
    """
    from datetime import timezone  # local import: not among the file's imports

    payload = {
        _CLAIM_NAME: email,
        _CLAIM_NAME_IDENTIFIER: account_id,
        "accountId": account_id,
        "accountType": account_type.value,
        "iss": settings.Issuer,
        "aud": settings.Audience,
        # Timezone-aware "now"; datetime.utcnow() is deprecated.
        "exp": datetime.now(timezone.utc) + timedelta(days=30),
    }
    return jwt.encode(payload, settings.SECRET_KEY, algorithm="HS256")


class PermissionDependency:
    """
    FastAPI dependency for protecting endpoints with authentication.

    Instances are callable: FastAPI invokes ``__call__`` with the bearer
    credentials of the request, and the call either returns the
    authenticated account or raises an ``HTTPException``.
    """

    def __init__(
        self,
        account_type: list[AccountType] | None = None,
        required: bool = True,
    ):
        """
        Args:
            account_type: Account types allowed to pass; ``None`` or an
                empty list allows every type.
            required: When False, a request without credentials yields
                ``None`` instead of a 401 error.
        """
        self.account_types = account_type
        self.required = required

    async def __call__(
        self,
        credentials: HTTPAuthorizationCredentials | None = Depends(HTTPBearer(auto_error=False)),
    ) -> AccountModel | None:
        """
        Validate authorization credentials and return account details.

        Args:
            credentials: The HTTP authorization credentials from the request

        Returns:
            AccountModel | None: The account on success; ``None`` when no
            credentials were sent and authentication is optional.

        Raises:
            HTTPException: 401 when credentials are required but missing;
                403 for every other failure (invalid token, unknown
                account, disallowed account type, or any unexpected error).
        """
        try:
            if not credentials:
                if self.required:
                    raise HTTPException(status_code=401, detail="Unauthorized")
                return None

            account_id = self.authenticate_jwt_token(credentials.credentials)
            account_data = await self.get_account_by_id(account_id)
            self.check_account_health(account_data)
            return AccountModel.from_mongo(account_data)

        except JWTError as exc:
            raise HTTPException(status_code=403, detail="Permission denied") from exc
        except HTTPException as exc:
            # Only the explicit "missing credentials" error keeps its status.
            if exc.status_code == 401:
                raise
            raise HTTPException(status_code=403, detail="Permission denied") from exc
        except Exception as exc:  # pylint: disable=W0718
            raise HTTPException(status_code=403, detail="Permission denied") from exc

    @staticmethod
    async def get_account_by_id(account_id: str) -> dict:
        """
        Retrieve account data from the database by ID.

        Args:
            account_id: The account ID to look up

        Returns:
            dict: Account document from the ``accounts`` collection

        Raises:
            HTTPException: 403 when no account has this ID
        """
        account = await settings.DB_CLIENT.accounts.find_one({"id": account_id})
        if not account:
            raise HTTPException(status_code=403, detail="Permission denied")
        return account

    def check_account_health(self, account: dict):
        """
        Verify that the account exists and has an allowed account type.

        Args:
            account: Account data dictionary

        Raises:
            HTTPException: 403 when ``account`` is empty, or when account
                types are restricted and this account's type is not allowed
        """
        if not account:
            raise HTTPException(status_code=403, detail="Permission denied")
        if self.account_types and AccountType(account["accountType"]) not in self.account_types:
            raise HTTPException(status_code=403, detail="Permission denied")

    @staticmethod
    def authenticate_jwt_token(token: str) -> str:
        """
        Validate a JWT token and extract the account ID.

        Args:
            token: JWT token string

        Returns:
            str: Account ID from the token

        Raises:
            JWTError: If the token cannot be decoded or verified
            HTTPException: 403 when the email or account ID claim is missing
        """
        payload = _decode_claims(token)
        email: str | None = payload.get(_CLAIM_NAME)
        account_id: str | None = payload.get("accountId")

        if email is None or account_id is None:
            raise HTTPException(status_code=403, detail="Permission denied")

        return account_id


def check_account_token(token: str) -> dict | None:
    """
    Decode a token without raising.

    Returns:
        dict | None: ``email``, ``account_id`` and ``account_type`` taken
        from the token, or ``None`` when the token is invalid or lacks the
        email or account ID claim.
    """
    try:
        payload = _decode_claims(token)
        email: str | None = payload.get(_CLAIM_NAME)
        account_id: str | None = payload.get("accountId")
        if email is None or account_id is None:
            return None
        return {
            "email": email,
            "account_id": account_id,
            "account_type": payload.get("accountType"),
        }
    except Exception:  # pylint: disable=W0718
        return None
T = TypeVar("T")


class ErrorCbhResponse(BaseModel):
    """
    Error response model for standardized error formatting.

    Attributes:
        message: Error message describing what went wrong
    """

    message: str


class CbhResponseWrapper(BaseModel, Generic[T]):
    """
    Standard response wrapper for all API endpoints.

    Provides one consistent envelope for API responses: the payload, a
    success flag, and error details.

    Attributes:
        data: The response data (optional)
        successful: Whether the request was successful
        error: Error details if the request failed
    """

    data: Optional[T] = None
    successful: bool = True
    error: Optional[ErrorCbhResponse] = None

    def response(self, status_code: int):
        """
        Create a JSONResponse with proper status code and formatting.

        The envelope is passed through ``jsonable_encoder`` first, so a
        payload that is not JSON-native (models, datetimes, enums) is
        converted instead of failing when the response is rendered.

        Args:
            status_code: HTTP status code for the response

        Returns:
            JSONResponse: Formatted API response with the keys ``data``,
            ``successful`` and ``error``
        """
        # Local import: fastapi is already a dependency of this module.
        from fastapi.encoders import jsonable_encoder

        return JSONResponse(
            status_code=status_code,
            content=jsonable_encoder(
                {
                    "data": self.data,
                    "successful": self.successful,
                    "error": self.error,
                }
            ),
        )
def exception_wrapper(http_error: int, error_message: str):
    """
    Decorator for handling exceptions in route handlers.

    Catches any exception raised by the wrapped coroutine and converts it
    to an ``HTTPException`` with the given status code and message. The
    original exception is kept as ``__cause__``.

    Args:
        http_error: HTTP status code to use for the exception
        error_message: Error message to include

    Returns:
        Decorator function that wraps route handlers
    """

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            try:
                return await func(*args, **kwargs)
            except Exception as e:
                # NOTE(review): this also rewrites HTTPExceptions that the
                # handler raised on purpose (e.g. a 404) -- confirm that
                # masking their status and detail is intended.
                raise HTTPException(status_code=http_error, detail=error_message) from e

        return wrapper

    return decorator


def background_task():
    """
    Decorator for background tasks that should not crash the application.

    The wrapped coroutine's result is returned unchanged. If it raises, the
    exception is logged with its traceback and ``None`` is returned, so a
    failing task does not propagate the error to its caller.

    Returns:
        Decorator function for background tasks
    """
    import logging  # local import: not among the file's imports

    def decorator(func):
        logger = logging.getLogger(func.__module__)

        @wraps(func)
        async def wrapper(*args, **kwargs):
            try:
                return await func(*args, **kwargs)
            except Exception:  # pylint: disable=W0718
                # Suppressed on purpose, but never silently.
                logger.exception("Background task %s failed", func.__qualname__)
                return None

        return wrapper

    return decorator
+ + {% block header %} + + + + {% endblock %} + + {% block main %}{% endblock %} + + {% block after_content %}{% endblock %} + + {% block footer %} + + + + {% endblock %} +
+ + + + +
+ Arena Logo +
+
+ + + + +
+ + + + +
+ Arena Logo +
+

+ Arena is a service of Clipboard Health LLC, a licensed provider of health organization services. + All money transmission is provided by Clipboard Health LLC, pursuant to Clipboard Health LLC's licenses. +

+ Clipboard Health is located at 2211 North First Street, San Jose, CA 95131 +

+
+
+
+ {% block scripts %}{% endblock %} + + + diff --git a/cbh/templates/emails/resetPassword.html b/cbh/templates/emails/resetPassword.html new file mode 100644 index 0000000000000000000000000000000000000000..1e86fd2b6ca98a0a1151e8fe12e84559710d7dea --- /dev/null +++ b/cbh/templates/emails/resetPassword.html @@ -0,0 +1,484 @@ +{% extends "base.html" %} + +{% block title %}Reset your password{% endblock %} + +{% block head %} + +{% endblock %} + +{% block body_extra_style %}; background-color: #0d0d0f{% endblock %} + +{% block header %}{% endblock %} + +{% block main %} + + + + + + + + + + + + +
+ + + + +
+ + + + + + +
+ STAFFILY AI + + + Log In + + + Reset + +
+
+
+ + + + + +
+

+ HI +

+

+ Reset password +

+

+ We received a request to reset the password for your Staffily AI account. Use the secure button below to create a new password and get back into your workspace. +

+

+ Core value: +
+ • quick account recovery +
+ • protected access +
+ • seamless return to work +

+ + + + + +
+ + Reset password + + + + Open website + +
+

+ Next steps +

+ + + + + + +
+ + + + +
+
+ Recommended +
+

+ Secure reset +

+

+ Create a new password with one protected link. +

+
+
+ + + + +
+
+ Fast +
+

+ Sign back in +

+

+ Return to {{audience_link}} as soon as your password is updated. +

+
+
+ + + + +
+
+ Support +
+

+ Need help? +

+

+ If this wasn't you, ignore this email and keep your account safe. +

+
+
+
+
+ + + + +
+

+ This reset link expires in 5 minutes. +

+

+ If you didn't request a password reset, you can safely ignore this email. You can always sign in at + {{audience_link}} +

+
+
+ + +{% endblock %} + +{% block after_content %}{% endblock %} + +{% block footer %}{% endblock %} + + diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..543e211bbfaf2bfe4faf729cef7992ab284628b7 --- /dev/null +++ b/main.py @@ -0,0 +1,10 @@ +""" +FastAPI application entry point for AI Sales Coach. + +This module creates and launches the FastAPI application by importing the create_app +function from the cbh package. +""" + +from cbh import create_app + +app = create_app() diff --git a/migrations/upload_data.py b/migrations/upload_data.py new file mode 100644 index 0000000000000000000000000000000000000000..c1ecb84827dd24c4559a58c9fe3fa3160c3bde00 --- /dev/null +++ b/migrations/upload_data.py @@ -0,0 +1,61 @@ +import asyncio +import json +from cbh.core.config import settings +from langchain_core.prompts import ChatPromptTemplate +from cbh.api.platforms.models import PlatformModel + +system_prompt = """You are a data structuring assistant. Your task is to convert raw AI tool data into a strictly typed structured format. + +Given the raw tool data below, return a single JSON object matching the PlatformModel schema exactly. + +Rules: +- "name": extract the tool name as-is. +- "category": map to an integer enum: + 1=Web apps/SaaS/MVP, 2=Websites/Landing pages, 3=Mobile apps, 4=UI/UX Design, + 5=AI Coding tools, 6=Automation/AI agents, 7=Video/Creative, 8=SEO/GEO, + 9=Growth/Social/Reddit, 10=Research/Analytics. +- "subcategory": use the subcategory string as-is. +- "oneLinePos": use the "One-line positioning" value. +- "description": use the "Detailed description" value. +- "userQueries": split "User query covered" into a list of distinct user intents/queries. If there is only one, return a single-element list. +- "idealCases": use the "Best if client wants" value. +- "personas": split "Recommended persona" by commas into a list of individual personas. 
async def upload_data(item: dict) -> PlatformModel:
    """
    Convert one raw tool record into a ``PlatformModel`` via the LLM.

    The record is serialized to JSON (non-ASCII characters kept as they
    are) and substituted for ``raw_data`` in ``system_prompt``; the model
    is asked for output matching the ``PlatformModel`` schema.

    Args:
        item: Raw tool record as loaded from the input JSON file.

    Returns:
        PlatformModel: The structured record produced by the chain.
    """
    prompt = ChatPromptTemplate.from_messages([
        ("system", system_prompt),
    ])
    chain = prompt | settings.get_llm(model="gpt-5.4", schema=PlatformModel)
    result = await chain.ainvoke({"raw_data": json.dumps(item, ensure_ascii=False)})
    print(f"Processed: {result.name}")
    return result


async def main(path: str = "ai_tools.json", batch_size: int = 10) -> None:
    """
    Structure every record of the input file and store the results.

    Records are processed ``batch_size`` at a time (concurrently within a
    batch) and written to the ``platforms`` collection in a single insert
    once all batches have finished.

    Args:
        path: JSON file containing a list of raw tool records.
        batch_size: Number of records sent to the LLM concurrently.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    # Explicit encoding: the platform default is not guaranteed to be UTF-8.
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    results = []
    for start in range(0, len(data), batch_size):
        batch = data[start:start + batch_size]
        platforms = await asyncio.gather(*(upload_data(item) for item in batch))
        results.extend(platforms)

    # insert_many rejects an empty document list, so skip it for empty input.
    if not results:
        return

    await settings.DB_CLIENT.platforms.insert_many(
        [platform.to_mongo() for platform in results]
    )


if __name__ == "__main__":
    asyncio.run(main())
http://127.0.0.1:8000/hello/User +Accept: application/json + +###