Spaces:
Sleeping
Sleeping
from fastapi import FastAPI, HTTPException, Security, Query, status, Request, Depends
from fastapi.responses import FileResponse, StreamingResponse, JSONResponse
from fastapi.security import APIKeyHeader
import openai
from pydantic import BaseModel
from uuid import UUID
import os
import logging
import logging.config
import logging.handlers  # explicit: RotatingFileHandler is used by UserLogHandler
import json
import regex as re
from datetime import datetime, timezone
from app.user import User
from typing import List, Optional, Callable
from functools import wraps
from openai import OpenAI
import psycopg2
from psycopg2 import sql
import os  # NOTE: duplicate of the earlier `import os` (harmless, kept)
from app.utils import add_to_cache, download_file_from_s3, get_api_key, get_user_info, get_growth_guide_session, pop_cache, print_log, get_user, upload_mementos_to_db, get_user_summary, get_user_life_status, create_pre_gg_report
from dotenv import load_dotenv
import time
from starlette.middleware.base import BaseHTTPMiddleware
import sys
import boto3
import pickle
from app.exceptions import *
import re  # NOTE: shadows `import regex as re` above — stdlib `re` wins from here on
import traceback  # used by catch_endpoint_error's error logging
import sentry_sdk
| load_dotenv() | |
| AWS_ACCESS_KEY = os.getenv('AWS_ACCESS_KEY') | |
| AWS_SECRET_KEY = os.getenv('AWS_SECRET_KEY') | |
| REGION = os.getenv('AWS_REGION') | |
| SENTRY_DSN = os.getenv('SENTRY_DSN') | |
| sentry_sdk.init( | |
| dsn=SENTRY_DSN, | |
| # Set traces_sample_rate to 1.0 to capture 100% | |
| # of transactions for tracing. | |
| traces_sample_rate=1.0, | |
| _experiments={ | |
| # Set continuous_profiling_auto_start to True | |
| # to automatically start the profiler on when | |
| # possible. | |
| "continuous_profiling_auto_start": True, | |
| }, | |
| ) | |
| # Create required folders | |
| os.makedirs('logs', exist_ok=True) | |
| os.makedirs(os.path.join('logs', 'users'), exist_ok=True) | |
| if not os.path.exists(os.path.join('users', 'data')): | |
| os.makedirs(os.path.join('users', 'data')) | |
| else: | |
| # Folder exists, we want to clear all current user data | |
| for file in os.listdir(os.path.join('users', 'data')): | |
| os.remove(os.path.join('users', 'data', file)) | |
| if not os.path.exists(os.path.join('bookings', 'data')): | |
| os.makedirs(os.path.join('bookings', 'data')) | |
| else: | |
| # Folder exists, we want to clear all current booking data | |
| for file in os.listdir(os.path.join('bookings', 'data')): | |
| os.remove(os.path.join('bookings', 'data', file)) | |
| if not os.path.exists(os.path.join('bookings', 'to_upload')): | |
| os.makedirs(os.path.join('bookings', 'to_upload')) | |
| else: | |
| # Folder exists, we want to clear all current booking data | |
| for file in os.listdir(os.path.join('bookings', 'to_upload')): | |
| os.remove(os.path.join('bookings', 'to_upload', file)) | |
| if not os.path.exists(os.path.join('users', 'to_upload')): | |
| os.makedirs(os.path.join('users', 'to_upload')) | |
| if not os.path.exists(os.path.join('mementos', 'to_upload')): | |
| os.makedirs(os.path.join('mementos', 'to_upload')) | |
| # Custom filter for user-specific logs | |
| class UserFilter(logging.Filter): | |
| def filter(self, record): | |
| return hasattr(record, 'user_id') and record.user_id != "no-user" | |
| class NoUserFilter(logging.Filter): | |
| def filter(self, record): | |
| return not (hasattr(record, 'user_id') and record.user_id != "no-user") | |
| class UserLogHandler(logging.Handler): | |
| def __init__(self, **kwargs): | |
| super().__init__() | |
| self.base_path = kwargs.get('base_path', 'logs/users') | |
| self.maxBytes = kwargs.get('maxBytes', 10485760) | |
| self.backupCount = kwargs.get('backupCount', 3) | |
| self.handlers = {} | |
| # Ensure base path exists | |
| os.makedirs(self.base_path, exist_ok=True) | |
| def emit(self, record): | |
| if hasattr(record, 'user_id') and record.user_id != "no-user": | |
| # Remove brackets from filename | |
| if record.user_id: | |
| user_id = record.user_id.strip('[]').strip() | |
| else: | |
| user_id = "no-user" | |
| if user_id not in self.handlers: | |
| handler = logging.handlers.RotatingFileHandler( | |
| filename=os.path.join(self.base_path, f'{user_id}.log'), | |
| maxBytes=self.maxBytes, | |
| backupCount=self.maxBytes, | |
| encoding='utf-8' | |
| ) | |
| formatter = logging.Formatter('%(asctime)s [%(levelname)s] [%(endpoint)s] [%(user_id)s]: %(message)s') | |
| handler.setFormatter(formatter) | |
| self.handlers[user_id] = handler | |
| try: | |
| self.handlers[user_id].emit(record) | |
| except Exception: | |
| self.handleError(record) | |
| class ConditionalFormatter(logging.Formatter): | |
| def format(self, record): | |
| format_string = '%(asctime)s [%(levelname)s]' | |
| if getattr(record, 'endpoint', None): | |
| format_string += ' [%(endpoint)s]' | |
| if getattr(record, 'user_id', None): | |
| format_string += ' [%(user_id)s]' | |
| if getattr(record, 'duration', None): | |
| format_string += ' [Duration: %(duration).3fs]' | |
| format_string += ': %(message)s' | |
| self._style._fmt = format_string | |
| return super().format(record) | |
| # Add new filter class after existing filter classes | |
| class EndpointFilter(logging.Filter): | |
| def filter(self, record): | |
| return hasattr(record, 'endpoint') and record.endpoint.startswith('/') | |
| # Configure logging | |
| logging_config = { | |
| 'version': 1, | |
| 'disable_existing_loggers': False, | |
| 'formatters': { | |
| 'conditional': { | |
| '()': ConditionalFormatter, | |
| 'datefmt': '%Y-%m-%d %H:%M:%S', | |
| }, | |
| }, | |
| 'filters': { | |
| 'userfilter': { | |
| '()': UserFilter | |
| }, | |
| 'nouserfilter': { | |
| '()': NoUserFilter | |
| }, | |
| 'endpointfilter': { | |
| '()': EndpointFilter | |
| } | |
| }, | |
| 'handlers': { | |
| 'default': { | |
| 'level': 'INFO', | |
| 'formatter': 'conditional', | |
| 'class': 'logging.StreamHandler', | |
| 'stream': sys.stdout, # Use stdout instead of stderr | |
| 'filters': ['nouserfilter'] | |
| }, | |
| 'file': { | |
| 'level': 'INFO', | |
| 'formatter': 'conditional', | |
| 'class': 'logging.handlers.RotatingFileHandler', | |
| 'filename': 'logs/app.log', | |
| 'maxBytes': 10485760, # 10MB | |
| 'backupCount': 5, | |
| 'encoding': 'utf-8', # Add UTF-8 encoding | |
| 'filters': ['endpointfilter'] # Only log endpoints | |
| }, | |
| 'userfile': { | |
| 'level': 'INFO', | |
| 'formatter': 'conditional', | |
| '()': UserLogHandler, # Changed from 'class' to '()' | |
| 'base_path': 'logs/users', | |
| 'maxBytes': 10485760, | |
| 'backupCount': 3, | |
| 'filters': ['userfilter'] | |
| } | |
| }, | |
| 'loggers': { | |
| '': { # root logger | |
| 'handlers': ['default', 'file', 'userfile'], | |
| 'level': 'INFO', | |
| 'propagate': True | |
| } | |
| } | |
| } | |
| logging.config.dictConfig(logging_config) | |
| logger = logging.getLogger(__name__) | |
| # Suppress verbose logs from external libraries | |
| logging.getLogger("httpx").setLevel(logging.WARNING) | |
| logging.getLogger("urllib3").setLevel(logging.WARNING) | |
| # Request logging middleware | |
| class LoggingMiddleware(BaseHTTPMiddleware): | |
| async def dispatch(self, request: Request, call_next: Callable): | |
| start_time = time.time() | |
| endpoint = request.url.path | |
| user_id = None | |
| if "user_id" in request.query_params: | |
| user_id = request.query_params["user_id"] | |
| elif request.method == "POST": | |
| try: | |
| body = await request.json() | |
| user_id = body.get("user_id") | |
| except: | |
| pass | |
| # Log start of request | |
| logger.info( | |
| "[start]: Request received", | |
| extra={ | |
| "user_id": user_id, | |
| "endpoint": endpoint, | |
| } | |
| ) | |
| try: | |
| response = await call_next(request) | |
| duration = time.time() - start_time | |
| # Log end of request with duration | |
| logger.info( | |
| f"Request completed with status {response.status_code}", | |
| extra={ | |
| "user_id": user_id, | |
| "endpoint": endpoint, | |
| "duration": duration | |
| } | |
| ) | |
| return response | |
| except Exception as e: | |
| duration = time.time() - start_time | |
| logger.error( | |
| f"Request failed with error: {str(e)}", | |
| extra={ | |
| "user_id": user_id, | |
| "endpoint": endpoint, | |
| "duration": duration | |
| } | |
| ) | |
| raise | |
| # OpenAI Client | |
| # GENERAL_ASSISTANT = os.getenv('OPENAI_GENERAL_ASSISTANT') | |
| GENERAL_ASSISTANT = "asst_vnucWWELJlCWadfAARwyKkCW" | |
| # Initialize Logging (optional) | |
| # logging.basicConfig(filename='app.log', level=logging.INFO) | |
| # FastAPI App | |
| app = FastAPI(title="Ourcoach AI API", description="A FastAPI app for ourcoach's chatbot", version="0.1.0") | |
| app.add_middleware(LoggingMiddleware) | |
| # Pydantic Models | |
| class CreateUserItem(BaseModel): | |
| user_id: str | |
| class ChatItem(BaseModel): | |
| user_id: str | |
| message: str | |
| class PersonaItem(BaseModel): | |
| user_id: str | |
| persona: str | |
| class GGItem(BaseModel): | |
| user_id: str | |
| gg_session_id: str | |
| class AssistantItem(BaseModel): | |
| user_id: str | |
| assistant_id: str | |
| class ChangeDateItem(BaseModel): | |
| user_id: str | |
| date: str | |
| class BookingItem(BaseModel): | |
| booking_id: str | |
| def catch_endpoint_error(func): | |
| """Decorator to handle errors in FastAPI endpoints""" | |
| # Add this to preserve endpoint metadata | |
| async def wrapper(*args, **kwargs): | |
| try: | |
| # Extract api_key from kwargs if present and pass it to the wrapped function | |
| api_key = kwargs.pop('api_key', None) | |
| return await func(*args, **kwargs) | |
| except OpenAIRequestError as e: | |
| # OpenAI service error | |
| # Try to cancel the run so we dont get "Cannot add message to thread with active run" | |
| # if e.run_id: | |
| # user_id = e.user_id | |
| # if user_id != 'no-user': | |
| # user = get_user(user_id) | |
| # user.cancel_run(e.run_id) | |
| logger.error(f"OpenAI service error in {func.__name__}(...): {str(e)}", | |
| extra={ | |
| 'user_id': e.user_id, | |
| 'endpoint': func.__name__ | |
| }) | |
| # Extract thread_id and run_id from error message | |
| thread_match = re.search(r'thread_(\w+)', str(e)) | |
| run_match = re.search(r'run_(\w+)', str(e)) | |
| if thread_match and run_match: | |
| thread_id = f"thread_{thread_match.group(1)}" | |
| run_id = f"run_{run_match.group(1)}" | |
| user = get_user(e.user_id) | |
| logger.info(f"Cancelling run {run_id} for thread {thread_id}", extra={"user_id": e.user_id, "endpoint": func.__name__}) | |
| user.cancel_run(run_id, thread_id) | |
| logger.info(f"Run {run_id} cancelled for thread {thread_id}", extra={"user_id": e.user_id, "endpoint": func.__name__}) | |
| raise HTTPException( | |
| status_code=status.HTTP_502_BAD_GATEWAY, | |
| detail=e.get_formatted_details() | |
| ) | |
| except DBError as e: | |
| # check if code is one of ["NoOnboardingError", "NoBookingError"] if yes then return code 404 otherwise 500 | |
| if e.code == "NoOnboardingError" or e.code == "NoBookingError": | |
| # no onboarding or booking data (user not found) | |
| status_code = 404 | |
| else: | |
| status_code = 505 | |
| logger.error(f"Database error in {func.__name__}: {str(e)}", | |
| extra={ | |
| 'user_id': e.user_id, | |
| 'endpoint': func.__name__ | |
| }) | |
| raise HTTPException( | |
| status_code=status_code, | |
| detail=e.get_formatted_details() | |
| ) | |
| except (UserError, AssistantError, ConversationManagerError, UtilsError) as e: | |
| # Known internal errors | |
| logger.error(f"Internal error in {func.__name__}: {str(e)}", | |
| extra={ | |
| 'user_id': e.user_id, | |
| 'endpoint': func.__name__, | |
| 'traceback': traceback.extract_stack() | |
| }) | |
| raise HTTPException( | |
| status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | |
| # detail = traceback.extract_stack() | |
| detail=e.get_formatted_details() | |
| ) | |
| except openai.BadRequestError as e: | |
| # OpenAI request error | |
| user_id = kwargs.get('user_id', 'no-user') | |
| logger.error(f"OpenAI request error in {func.__name__}: {str(e)}", | |
| extra={ | |
| 'user_id': user_id, | |
| 'endpoint': func.__name__ | |
| }) | |
| raise HTTPException( | |
| status_code=status.HTTP_400_BAD_REQUEST, | |
| detail={ | |
| "type": "OpenAIError", | |
| "message": str(e), | |
| "user_id": user_id, | |
| "at": datetime.now(timezone.utc).isoformat() | |
| } | |
| ) | |
| except Exception as e: | |
| # Unknown errors | |
| user_id = kwargs.get('user_id', 'no-user') | |
| if len(args) and hasattr(args[0], 'user_id'): | |
| user_id = args[0].user_id | |
| logger.error(f"Unexpected error in {func.__name__}: {str(e)}", | |
| extra={ | |
| 'user_id': user_id, | |
| 'endpoint': func.__name__ | |
| }) | |
| raise HTTPException( | |
| status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | |
| detail={ | |
| "type": "FastAPIError", | |
| "message": str(e), | |
| "user_id": user_id, | |
| "at": datetime.now(timezone.utc).isoformat() | |
| } | |
| ) | |
| # raise FastAPIError( | |
| # user_id=user_id, | |
| # message=f"Unexpected error in {func.__name__}", | |
| # e=str(e) | |
| # ) | |
| return wrapper | |
# NOTE: catch_endpoint_error is intended to wrap every endpoint below, but neither
# the decorator nor any @app.<method> route registration appears in this file —
# confirm routes are registered (and wrapped) elsewhere.
| async def set_intro_done( | |
| user_id: str, | |
| api_key: str = Depends(get_api_key) # Change Security to Depends | |
| ): | |
| user = get_user(user_id) | |
| user.set_intro_done() | |
| logger.info("Intro done", extra={"user_id": user_id, "endpoint": "/set_intro_done"}) | |
| return {"response": "ok"} | |
| async def set_goal( | |
| user_id: str, | |
| goal: str, | |
| api_key: str = Depends(get_api_key) # Change Security to Depends | |
| ): | |
| user = get_user(user_id) | |
| user.set_goal(goal) | |
| logger.info(f"Goal set: {goal}", extra={"user_id": user_id, "endpoint": "/set_goal"}) | |
| return {"response": "ok"} | |
| async def do_micro( | |
| request: ChangeDateItem, | |
| day: int, | |
| api_key: str = Depends(get_api_key) # Change Security to Depends | |
| ): | |
| user = get_user(request.user_id) | |
| response = user.do_micro(request.date, day) | |
| logger.info(f"Micro action completed", extra={"user_id": request.user_id, "endpoint": "/do_micro"}) | |
| return {"response": response} | |
| # endpoint to change user assistant using user.change_to_latest_assistant() | |
| async def change_assistant( | |
| request: AssistantItem, | |
| api_key: str = Depends(get_api_key) # Change Security to Depends | |
| ): | |
| user = get_user(request.user_id) | |
| user.change_assistant(request.assistant_id) | |
| logger.info(f"Assistant changed to {request.assistant_id}", | |
| extra={"user_id": request.user_id, "endpoint": "/change_assistant"}) | |
| return {"assistant_id": request.assistant_id} | |
| async def clear_cache( | |
| api_key: str = Depends(get_api_key) # Change Security to Depends | |
| ): | |
| pop_cache(user_id='all') | |
| logger.info("Cache cleared successfully", extra={"endpoint": "/clear_cache"}) | |
| return {"response": "Cache cleared successfully"} | |
| async def migrate_user( | |
| request: CreateUserItem, | |
| api_key: str = Depends(get_api_key) # Change Security to Depends | |
| ): | |
| client = OpenAI(api_key=os.getenv('OPENAI_API_KEY')) | |
| if not client: | |
| raise OpenAIRequestError( | |
| user_id=request.user_id, | |
| message="Failed to initialize OpenAI client" | |
| ) | |
| user_file = os.path.join('users', 'data', f'{request.user_id}.pkl') | |
| download_file_from_s3(f'{request.user_id}.pkl', 'core-ai-assets') | |
| with open(user_file, 'rb') as f: | |
| old_user_object = pickle.load(f) | |
| user = User(request.user_id, old_user_object.user_info, client, GENERAL_ASSISTANT) | |
| user.conversations.current_thread = old_user_object.conversations.current_thread | |
| user.conversations.intro_done = True | |
| user.done_first_reflection = old_user_object.done_first_reflection | |
| user.client = client | |
| user.conversations.client = client | |
| api_response = { | |
| "user": user.user_info, | |
| "user_messages": user.get_messages(), | |
| "general_assistant": user.conversations.assistants['general'].id, | |
| "intro_assistant": user.conversations.assistants['intro'].id, | |
| "goal": user.goal if user.goal else "No goal is not set yet", | |
| "current_day": user.growth_plan.current()['day'], | |
| "micro_actions": user.micro_actions, | |
| "recommended_actions": user.recommended_micro_actions, | |
| "challenges": user.challenges, | |
| "other_focusses": user.other_focusses, | |
| "scores": f"Personal Growth: {user.personal_growth_score} || Career: {user.career_growth_score} || Health/Wellness: {user.health_and_wellness_score} || Relationships: {user.relationship_score} || Mental Health: {user.mental_well_being_score}", | |
| "recent_wins": user.recent_wins | |
| } | |
| add_to_cache(user) | |
| pop_cache(user.user_id) | |
| os.remove(user_file) | |
| logger.info(f"User {user.user_id} loaded successfully from S3", extra={'user_id': user.user_id, 'endpoint': 'migrate_user'}) | |
| return api_response | |
| async def get_user_by_id( | |
| user_id: str, | |
| api_key: str = Depends(get_api_key) # Change Security to Depends | |
| ): | |
| print_log("INFO", "Getting user", extra={"user_id": user_id, "endpoint": "/get_user"}) | |
| logger.info("Getting user", extra={"user_id": user_id, "endpoint": "/get_user"}) | |
| user = get_user(user_id) | |
| print_log("INFO", "Successfully retrieved user", extra={"user_id": user_id, "endpoint": "/get_user"}) | |
| logger.info("Successfully retrieved user", extra={"user_id": user_id, "endpoint": "/get_user"}) | |
| api_response = {"user": user.user_info, "user_messages": user.get_messages(), "general_assistant": user.conversations.assistants['general'].id, "intro_assistant": user.conversations.assistants['intro'].id} | |
| if user.goal: | |
| api_response["goal"] = user.goal | |
| else: | |
| api_response["goal"] = "No goal is not set yet" | |
| api_response["current_day"] = user.growth_plan.current()['day'] | |
| api_response['micro_actions'] = user.micro_actions | |
| api_response['recommended_actions'] = user.recommended_micro_actions | |
| api_response['challenges'] = user.challenges | |
| api_response['other_focusses'] = user.other_focusses | |
| api_response['reminders'] = user.reminders | |
| api_response['scores'] = f"Personal Growth: {user.personal_growth_score} || Career: {user.career_growth_score} || Health/Wellness: {user.health_and_wellness_score} || Relationships: {user.relationship_score} || Mental Health: {user.mental_well_being_score}" | |
| api_response['recent_wins'] = user.recent_wins | |
| api_response['last_gg_session'] = user.last_gg_session | |
| return api_response | |
| async def update_user_persona( | |
| request: PersonaItem, | |
| api_key: str = Depends(get_api_key) # Change Security to Depends | |
| ): | |
| """Update user's legendary persona in the database""" | |
| user_id = request.user_id | |
| persona = request.persona | |
| user = get_user(user_id) | |
| user.update_user_info(f"User's new Legendary Persona is: {persona}") | |
| logger.info(f"Updated persona to {persona}", extra={"user_id": user_id, "endpoint": "/update_user_persona"}) | |
| # Connect to database | |
| db_params = { | |
| 'dbname': 'ourcoach', | |
| 'user': 'ourcoach', | |
| 'password': 'hvcTL3kN3pOG5KteT17T', | |
| 'host': 'staging-ourcoach.cx8se8o0iaiy.ap-southeast-1.rds.amazonaws.com', | |
| 'port': '5432' | |
| } | |
| conn = psycopg2.connect(**db_params) | |
| cur = conn.cursor() | |
| # Get current onboarding data | |
| cur.execute("SELECT onboarding FROM users WHERE id = %s", (user_id,)) | |
| result = cur.fetchone() | |
| if not result: | |
| raise DBError( | |
| user_id=user_id, | |
| code="NoOnboardingError", | |
| message="User not found in database" | |
| ) | |
| # Update legendPersona in onboarding JSON | |
| onboarding = json.loads(result[0]) | |
| onboarding['legendPersona'] = persona | |
| # Update database | |
| cur.execute( | |
| "UPDATE users SET onboarding = %s WHERE id = %s", | |
| (json.dumps(onboarding), user_id) | |
| ) | |
| conn.commit() | |
| if 'cur' in locals(): | |
| cur.close() | |
| if 'conn' in locals(): | |
| conn.close() | |
| return {"status": "success", "message": f"Updated persona to {persona}"} | |
| async def add_ai_message( | |
| request: ChatItem, | |
| api_key: str = Depends(get_api_key) # Change Security to Depends | |
| ): | |
| user_id = request.user_id | |
| message = request.message | |
| logger.info("Adding AI response", extra={"user_id": user_id, "endpoint": "/add_ai_message"}) | |
| print_log("INFO", "Adding AI response", extra={"user_id": user_id, "endpoint": "/add_ai_message"}) | |
| user = get_user(user_id) | |
| user.add_ai_message(message) | |
| add_to_cache(user) | |
| pop_cache(user.user_id) | |
| print_log("INFO", "AI response added", extra={"user_id": user_id, "endpoint": "/add_ai_message"}) | |
| return {"response": "ok"} | |
| async def schedule_gg_reminder( | |
| request: ChangeDateItem, | |
| api_key: str = Depends(get_api_key) # Change Security to Depends | |
| ): | |
| # session_id = request.gg_session_id | |
| user_id = request.user_id | |
| logger.info(f"Scheduling GG session reminder for {request.date}", extra={"user_id": user_id, "endpoint": "/schedule_gg_reminder"}) | |
| print_log("INFO", f"Scheduling GG session: reminder for {request.date}", extra={"user_id": user_id, "endpoint": "/schedule_gg_reminder"}) | |
| # get user | |
| user = get_user(user_id) | |
| # call user.ask_to_schedule_growth_guide_reminder(session_id) | |
| response = user.ask_to_schedule_growth_guide_reminder(request.date) | |
| logger.info(f"GG session reminder scheduled, response: {response}", extra={"user_id": user_id, "endpoint": "/schedule_gg_reminder"}) | |
| return {"response": response} | |
| async def process_gg_session( | |
| request: GGItem, | |
| api_key: str = Depends(get_api_key) # Change Security to Depends | |
| ): | |
| logger.info("Processing growth guide session", extra={"user_id": request.user_id, "endpoint": "/process_gg_session"}) | |
| user = get_user(request.user_id) | |
| session_data = get_growth_guide_session(request.user_id, request.gg_session_id) | |
| response = user.process_growth_guide_session(session_data, request.gg_session_id) | |
| add_to_cache(user) | |
| pop_cache(user.user_id) | |
| return {"response": response} | |
| async def get_daily_message( | |
| user_id: str, | |
| api_key: str = Depends(get_api_key) # Change Security to Depends | |
| ): | |
| logger.info("Getting daily messages", extra={"user_id": user_id, "endpoint": "/user_daily_messages"}) | |
| user = get_user(user_id) | |
| daily_messages = user.get_daily_messages() | |
| return {"response": daily_messages} | |
| async def refresh_multiple_users( | |
| user_ids: List[str], | |
| api_key: str = Depends(get_api_key) # Change Security to Depends | |
| ): | |
| logger.info("Refreshing multiple users", extra={"endpoint": "/batch_refresh_users"}) | |
| client = OpenAI(api_key=os.getenv('OPENAI_API_KEY')) | |
| failed_users = [] | |
| for i,user_id in enumerate(user_ids): | |
| old_user = get_user(user_id) | |
| user = old_user.refresh(client) | |
| add_to_cache(user) | |
| pop_cache(user.user_id) | |
| logger.info(f"Successfully refreshed user {i+1}/{len(user_ids)}", extra={"user_id": user_id, "endpoint": "/batch_refresh_users"}) | |
| if failed_users: | |
| return {"status": "partial", "failed_users": failed_users} | |
| return {"status": "success", "failed_users": []} | |
| async def refresh_user( | |
| request: CreateUserItem, | |
| api_key: str = Depends(get_api_key) # Change Security to Depends | |
| ): | |
| print_log("INFO","Refreshing user", extra={"user_id": request.user_id, "endpoint": "/refresh_user"}) | |
| logger.info("Refreshing user", extra={"user_id": request.user_id, "endpoint": "/refresh_user"}) | |
| client = OpenAI(api_key=os.getenv('OPENAI_API_KEY')) | |
| old_user = get_user(request.user_id) | |
| user = old_user.refresh(client) | |
| add_to_cache(user) | |
| pop_cache(user.user_id) | |
| print_log("INFO","User refreshed", extra={"user_id": request.user_id, "endpoint": "/refresh_user"}) | |
| logger.info(f"User refreshed -> {user}", extra={"user_id": request.user_id, "endpoint": "/refresh_user"}) | |
| return {"response": "ok"} | |
| async def create_user( | |
| request: CreateUserItem, | |
| api_key: str = Depends(get_api_key) # Change Security to Depends | |
| ): | |
| logger.info("Creating new user", extra={"user_id": request.user_id, "endpoint": "/create_user"}) | |
| client = OpenAI(api_key=os.getenv('OPENAI_API_KEY')) | |
| if not client: | |
| raise OpenAIRequestError("client_init", "Failed to initialize OpenAI client") | |
| if os.path.exists(f'users/data/{request.user_id}.pkl'): | |
| return {"message": f"[OK] User already exists: {request.user_id}"} | |
| user_info, _ = get_user_info(request.user_id) | |
| user = User(request.user_id, user_info, client, GENERAL_ASSISTANT) | |
| folder_path = os.path.join("mementos", "to_upload", request.user_id) | |
| os.makedirs(folder_path, exist_ok=True) | |
| add_to_cache(user) | |
| pop_cache(request.user_id) | |
| logger.info(f"Successfully created user", extra={"user_id": request.user_id, "endpoint": "/create_user"}) | |
| return {"message": {"info": f"[OK] User created: {user}", "messages": user.get_messages()}} | |
| async def chat( | |
| request: ChatItem, | |
| api_key: str = Depends(get_api_key) # Change Security to Depends | |
| ): | |
| logger.info("Processing chat request", extra={"user_id": request.user_id, "endpoint": "/chat"}) | |
| user = get_user(request.user_id) | |
| response = user.send_message(request.message) | |
| logger.info(f"Assistant response generated", extra={"user_id": request.user_id, "endpoint": "/chat"}) | |
| return {"response": response} | |
| async def get_reminders( | |
| user_id: str, | |
| date:str, | |
| api_key: str = Depends(get_api_key) # Change Security to Depends | |
| ): | |
| print_log("INFO","Getting reminders", extra={"user_id": user_id, "endpoint": "/reminders"}) | |
| logger.info("Getting reminders", extra={"user_id": user_id, "endpoint": "/reminders"}) | |
| user = get_user(user_id) | |
| reminders = user.get_reminders(date) | |
| if len(reminders) == 0: | |
| print_log("INFO",f"No reminders for {date}", extra={"user_id": user_id, "endpoint": "/reminders"}) | |
| logger.info(f"No reminders for {date}", extra={"user_id": user_id, "endpoint": "/reminders"}) | |
| reminders = None | |
| print_log("INFO",f"Successfully retrieved reminders: {reminders}", extra={"user_id": user_id, "endpoint": "/reminders"}) | |
| logger.info(f"Successfully retrieved reminders: {reminders} for {date}", extra={"user_id": user_id, "endpoint": "/reminders"}) | |
| return {"reminders": reminders} | |
| async def change_date( | |
| request: ChangeDateItem, | |
| api_key: str = Depends(get_api_key) # Change Security to Depends | |
| ): | |
| logger.info(f"Processing date change request", extra={"user_id": request.user_id, "endpoint": "/change_date"}) | |
| user = get_user(request.user_id) | |
| # Validate date format | |
| try: | |
| datetime.strptime(request.date, "%d-%m-%Y %a %H:%M:%S") | |
| except ValueError: | |
| # HF format is YYYY-MM-DD | |
| try: | |
| request.date = datetime.strptime(request.date, "%Y-%m-%d") | |
| # convert to '%d-%m-%Y %a 10:00:00' | |
| request.date = request.date.strftime("%d-%m-%Y %a 10:00:00") | |
| except ValueError as e: | |
| raise FastAPIError( | |
| message="Invalid date format", | |
| e=str(e) | |
| ) | |
| # Upload mementos to DB | |
| upload_mementos_to_db(request.user_id) | |
| # Change date and get response | |
| response = user.change_date(request.date) | |
| response['user_id'] = request.user_id | |
| # Update cache | |
| add_to_cache(user) | |
| pop_cache(user.user_id) | |
| return response | |
| async def reset_user_messages( | |
| request: CreateUserItem, | |
| api_key: str = Depends(get_api_key) # Change Security to Depends | |
| ): | |
| print_log("INFO","Resetting messages", extra={"user_id": request.user_id, "endpoint": "/reset_user"}) | |
| logger.info("Resetting messages", extra={"user_id": request.user_id, "endpoint": "/reset_user"}) | |
| user = get_user(request.user_id) | |
| user.reset_conversations() | |
| print_log("INFO",f"Successfully reset messages for user: {request.user_id}", extra={"user_id": request.user_id, "endpoint": "/reset_user"}) | |
| logger.info(f"Successfully reset messages for user: {request.user_id}", extra={"user_id": request.user_id, "endpoint": "/reset_user"}) | |
| add_to_cache(user) | |
| update = pop_cache(user.user_id) | |
| print_log("INFO",f"Successfully updated user pickle: {request.user_id}", extra={"user_id": request.user_id, "endpoint": "/reset_user"}) | |
| logger.info(f"Successfully updated user pickle: {request.user_id}", extra={"user_id": request.user_id, "endpoint": "/reset_user"}) | |
| return {"response": "ok"} | |
| async def get_logs( | |
| user_id: str = Query(default="", description="User ID to fetch logs for") | |
| ): | |
| if (user_id): | |
| log_file_path = os.path.join('logs', 'users', f'{user_id}.log') | |
| if not os.path.exists(log_file_path): | |
| print_log("INFO",f"Log file not found for user: {user_id}", extra={"user_id": user_id, "endpoint": "/get_logs"}) | |
| logger.error(f"Log file not found for user: {user_id}", extra={"user_id": user_id, "endpoint": "/get_logs"}) | |
| raise HTTPException( | |
| status_code=status.HTTP_404_NOT_FOUND, | |
| detail=f"Log file for user {user_id} not found" | |
| ) | |
| else: | |
| log_file_path = 'logs/app.log' | |
| def file_iterator(): | |
| with open(log_file_path, 'rb') as f: | |
| while chunk := f.read(8192): | |
| yield chunk | |
| return StreamingResponse( | |
| file_iterator(), | |
| media_type='text/plain', | |
| headers={'Content-Disposition': f'attachment; filename="{os.path.basename(log_file_path)}"'} | |
| ) | |
| async def is_user_responsive( | |
| user_id: str, | |
| api_key: str = Depends(get_api_key) # Change Security to Depends | |
| ): | |
| logger.info("Checking if user is responsive", extra={"user_id": user_id, "endpoint": "/is_user_responsive"}) | |
| user = get_user(user_id) | |
| messages = user.get_messages() | |
| if len(messages) >= 3 and messages[-1]['role'] == 'assistant' and messages[-2]['role'] == 'assistant': | |
| return {"response": False} | |
| else: | |
| return {"response": True} | |
| async def get_summary_by_id( | |
| user_id: str, | |
| api_key: str = Depends(get_api_key) # Change Security to Depends | |
| ): | |
| print_log("INFO", "Getting user's summary", extra={"user_id": user_id, "endpoint": "/get_user_summary"}) | |
| logger.info("Getting user's summary", extra={"user_id": user_id, "endpoint": "/get_user_summary"}) | |
| user_summary = get_user_summary(user_id) | |
| print_log("INFO", "Successfully generated summary", extra={"user_id": user_id, "endpoint": "/get_user_summary"}) | |
| logger.info("Successfully generated summary", extra={"user_id": user_id, "endpoint": "/get_user_summary"}) | |
| return user_summary | |
| async def get_life_status_by_id( | |
| user_id: str, | |
| api_key: str = Depends(get_api_key) # Change Security to Depends | |
| ): | |
| print_log("INFO", "Getting user's life status", extra={"user_id": user_id, "endpoint": "/get_life_status"}) | |
| logger.info("Getting user's life status", extra={"user_id": user_id, "endpoint": "/get_life_status"}) | |
| life_status = get_user_life_status(user_id) | |
| print_log("INFO", "Successfully generated life status", extra={"user_id": user_id, "endpoint": "/get_life_status"}) | |
| logger.info("Successfully generated life status", extra={"user_id": user_id, "endpoint": "/get_life_status"}) | |
| return life_status | |
| async def add_booking_point_by_user( | |
| user_id: str, | |
| api_key: str = Depends(get_api_key) # Change Security to Depends | |
| ): | |
| user = get_user(user_id) | |
| user.add_point_for_booking() | |
| return {"response": "ok"} | |
| async def add_session_completion_point_by_user( | |
| user_id: str, | |
| api_key: str = Depends(get_api_key) # Change Security to Depends | |
| ): | |
| user = get_user(user_id) | |
| user.add_point_for_completing_session() | |
| return {"response": "ok"} | |
| async def create_pre_gg_by_booking( | |
| request: BookingItem, | |
| api_key: str = Depends(get_api_key) # Change Security to Depends | |
| ): | |
| create_pre_gg_report(request.booking_id) | |
| return {"response": "ok"} | |
| async def get_user_persona( | |
| user_id: str, | |
| api_key: str = Depends(get_api_key) # Change Security to Depends | |
| ): | |
| """Get user's legendary persona from the database""" | |
| logger.info("Getting user's persona", extra={"user_id": user_id, "endpoint": "/get_user_persona"}) | |
| # Connect to database | |
| db_params = { | |
| 'dbname': 'ourcoach', | |
| 'user': 'ourcoach', | |
| 'password': 'hvcTL3kN3pOG5KteT17T', | |
| 'host': 'staging-ourcoach.cx8se8o0iaiy.ap-southeast-1.rds.amazonaws.com', | |
| 'port': '5432' | |
| } | |
| conn = psycopg2.connect(**db_params) | |
| cur = conn.cursor() | |
| # Get onboarding data | |
| cur.execute("SELECT onboarding FROM users WHERE id = %s", (user_id,)) | |
| result = cur.fetchone() | |
| if not result: | |
| raise DBError( | |
| user_id=user_id, | |
| code="NoOnboardingError", | |
| message="User not found in database" | |
| ) | |
| # Extract persona from onboarding JSON | |
| onboarding = json.loads(result[0]) | |
| persona = onboarding.get('legendPersona', '') | |
| if 'cur' in locals(): | |
| cur.close() | |
| if 'conn' in locals(): | |
| conn.close() | |
| return {"persona": persona} | |
| async def get_recent_booking( | |
| user_id: str, | |
| api_key: str = Depends(get_api_key) # Change Security to Depends | |
| ): | |
| """Get the most recent booking ID for a user""" | |
| logger.info("Getting recent booking", extra={"user_id": user_id, "endpoint": "/get_recent_booking"}) | |
| # Connect to database | |
| db_params = { | |
| 'dbname': 'ourcoach', | |
| 'user': 'ourcoach', | |
| 'password': 'hvcTL3kN3pOG5KteT17T', | |
| 'host': 'staging-ourcoach.cx8se8o0iaiy.ap-southeast-1.rds.amazonaws.com', | |
| 'port': '5432' | |
| } | |
| conn = psycopg2.connect(**db_params) | |
| cur = conn.cursor() | |
| # Get most recent booking where status == 2 | |
| cur.execute(""" | |
| SELECT booking_id | |
| FROM public.user_notes | |
| WHERE user_id = %s | |
| ORDER BY created_at DESC | |
| LIMIT 1 | |
| """, (user_id,)) | |
| result = cur.fetchone() | |
| if not result: | |
| raise DBError( | |
| user_id=user_id, | |
| code="NoBookingError", | |
| message="No bookings found for user" | |
| ) | |
| booking_id = result[0] | |
| logger.info(f"Found recent booking: {booking_id}", extra={"user_id": user_id, "endpoint": "/get_recent_booking"}) | |
| if 'cur' in locals(): | |
| cur.close() | |
| if 'conn' in locals(): | |
| conn.close() | |
| return {"booking_id": booking_id} | |