from fastapi import (
    APIRouter,
    status,
    Depends,
    BackgroundTasks,
    HTTPException,
    File,
    UploadFile,
    Form,
)
from fastapi.responses import JSONResponse
from src.utils.logger import logger
from src.agents.role_play.flow import role_play_agent
from pydantic import BaseModel, Field
from typing import List, Dict, Any, Optional
from src.agents.role_play.scenarios import get_scenarios, get_scenario_by_id
import json
import os
import uuid
from datetime import datetime
import base64

router = APIRouter(prefix="/ai", tags=["AI"])


# JSON-body request model for a roleplay turn. The /roleplay endpoint below
# takes multipart form fields instead, so it does not use this model.
class RoleplayRequest(BaseModel):
    query: str = Field(..., description="User's query for the AI agent")
    session_id: str = Field(
        ..., description="Session ID for tracking user interactions"
    )
    scenario: Dict[str, Any] = Field(..., description="The scenario for the roleplay")


# Request body for /get-messages.
class SessionRequest(BaseModel):
    session_id: str = Field(..., description="Session ID to perform operations on")


# Request body for POST /sessions.
class CreateSessionRequest(BaseModel):
    name: str = Field(..., description="Name for the new session")


# Request body for PUT /sessions/{session_id}. The path parameter is what
# identifies the session; the body's session_id is required but not read.
class UpdateSessionRequest(BaseModel):
    session_id: str = Field(..., description="Session ID to update")
    name: str = Field(..., description="New name for the session")


# Session management helper functions
#
# Session metadata (id, name, created_at, last_message, message_count) is kept
# in a flat JSON file. The path is relative, so it resolves against the
# process's current working directory. The conversation itself is stored in
# the agent's state, keyed by the same session id (passed as `thread_id`).
SESSIONS_FILE = "sessions.json"

# Scenario keys forwarded to the role-play agent; every one is required.
_SCENARIO_FIELDS = (
    "scenario_title",
    "scenario_description",
    "scenario_context",
    "your_role",
    "key_vocabulary",
)

# Uploaded-audio file extension -> MIME type. Unknown extensions fall back to
# audio/wav.
_AUDIO_MIME_TYPES = {
    "wav": "audio/wav",
    "mp3": "audio/mpeg",
    "ogg": "audio/ogg",
    "webm": "audio/webm",
    "m4a": "audio/mp4",
}


def load_sessions() -> List[Dict[str, Any]]:
    """Load all session records from SESSIONS_FILE.

    Returns:
        The list of session dicts. Returns an empty list if the file is
        missing, cannot be read or parsed, or does not hold a JSON array.
        Errors are logged, not raised.

    NOTE: callers save back whatever this returns, so an unreadable file
    followed by any write replaces the file's previous contents.
    """
    try:
        if not os.path.exists(SESSIONS_FILE):
            return []
        with open(SESSIONS_FILE, "r", encoding="utf-8") as f:
            sessions = json.load(f)
    except Exception as e:
        logger.error(f"Error loading sessions: {str(e)}")
        return []
    if not isinstance(sessions, list):
        # Callers append to / iterate the result as a list of dicts.
        logger.error("Error loading sessions: expected a JSON array")
        return []
    return sessions


def save_sessions(sessions: List[Dict[str, Any]]) -> None:
    """Persist ``sessions`` to SESSIONS_FILE. Errors are logged, not raised.

    The JSON is first written to a temporary sibling file, which is then moved
    over SESSIONS_FILE with ``os.replace``. This way a failure mid-write
    cannot leave a truncated SESSIONS_FILE behind. ``load_sessions`` would
    read such a file as an empty list, and the next save would then drop
    every session.
    """
    tmp_path = f"{SESSIONS_FILE}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            # default=str stringifies any non-JSON-native values.
            json.dump(sessions, f, ensure_ascii=False, indent=2, default=str)
        os.replace(tmp_path, SESSIONS_FILE)
    except Exception as e:
        logger.error(f"Error saving sessions: {str(e)}")


def create_session(name: str) -> Dict[str, Any]:
    """Create and persist a new session record.

    The record gets a random UUID4 id and a ``created_at`` timestamp in naive
    local time (ISO 8601).

    Args:
        name: Display name for the session.

    Returns:
        The newly created session dict.
    """
    session = {
        "id": str(uuid.uuid4()),
        "name": name,
        "created_at": datetime.now().isoformat(),
        "last_message": None,
        "message_count": 0,
    }
    sessions = load_sessions()
    sessions.append(session)
    save_sessions(sessions)
    return session


def get_session_by_id(session_id: str) -> Optional[Dict[str, Any]]:
    """Return the session record with ``id == session_id``, or None."""
    return next((s for s in load_sessions() if s["id"] == session_id), None)


def update_session_last_message(session_id: str, message: str) -> None:
    """Record ``message`` as the session's latest message and bump its count.

    If no session has this id, nothing is changed and the file is not
    rewritten.
    """
    sessions = load_sessions()
    session = next((s for s in sessions if s["id"] == session_id), None)
    if session is None:
        return
    session["last_message"] = message
    session["message_count"] = session.get("message_count", 0) + 1
    save_sessions(sessions)


def delete_session_by_id(session_id: str) -> bool:
    """Remove the session record with ``id == session_id`` from SESSIONS_FILE.

    Only the metadata entry is removed; the agent's stored conversation for
    that id is not touched here.

    Returns:
        True if a record was removed, False if no record matched.
    """
    sessions = load_sessions()
    remaining = [s for s in sessions if s["id"] != session_id]
    if len(remaining) == len(sessions):
        return False
    save_sessions(remaining)
    return True


@router.get("/scenarios", status_code=status.HTTP_200_OK)
async def list_scenarios():
    """Get all available scenarios"""
    return JSONResponse(content=get_scenarios())


def _parse_scenario(scenario: str) -> Dict[str, Any]:
    """Decode and validate the JSON scenario submitted as a form field.

    Raises:
        HTTPException(400): if the string is not valid JSON, decodes to an
            empty/falsy value, is not a JSON object, or lacks any of the
            keys in _SCENARIO_FIELDS.
    """
    try:
        scenario_dict = json.loads(scenario)
    except json.JSONDecodeError as e:
        raise HTTPException(
            status_code=400, detail="Invalid scenario JSON format"
        ) from e

    if not scenario_dict:
        raise HTTPException(status_code=400, detail="Scenario not provided")
    if not isinstance(scenario_dict, dict):
        raise HTTPException(status_code=400, detail="Scenario must be a JSON object")

    # Validate up front. Otherwise a missing key would raise KeyError deep
    # inside the agent call and be reported as a 500.
    missing = [field for field in _SCENARIO_FIELDS if field not in scenario_dict]
    if missing:
        raise HTTPException(
            status_code=400,
            detail=f"Scenario missing required fields: {', '.join(missing)}",
        )
    return scenario_dict


async def _audio_content_part(audio_file: UploadFile) -> Dict[str, Any]:
    """Read an uploaded audio file into a base64 message content part.

    The MIME type comes from the filename's last extension via
    _AUDIO_MIME_TYPES. It defaults to audio/wav when the filename is missing
    or the extension is unknown.

    Raises:
        HTTPException(400): if reading or encoding the upload fails.
    """
    try:
        audio_data = await audio_file.read()
        audio_base64 = base64.b64encode(audio_data).decode("utf-8")
        file_extension = (
            audio_file.filename.split(".")[-1].lower() if audio_file.filename else "wav"
        )
        mime_type = _AUDIO_MIME_TYPES.get(file_extension, "audio/wav")
    except Exception as e:
        logger.error(f"Error processing audio file: {str(e)}")
        raise HTTPException(
            status_code=400, detail=f"Error processing audio file: {str(e)}"
        ) from e

    return {
        "type": "audio",
        "source_type": "base64",
        "data": audio_base64,
        "mime_type": mime_type,
    }


@router.post("/roleplay", status_code=status.HTTP_200_OK)
async def roleplay(
    session_id: str = Form(
        ..., description="Session ID for tracking user interactions"
    ),
    scenario: str = Form(
        ..., description="The scenario for the roleplay as JSON string"
    ),
    text_message: Optional[str] = Form(None, description="Text message from user"),
    audio_file: Optional[UploadFile] = File(None, description="Audio file from user"),
):
    """Send a message (text or audio) to the roleplay agent"""
    # At least one kind of user input is required.
    if not text_message and not audio_file:
        raise HTTPException(
            status_code=400, detail="Either text_message or audio_file must be provided"
        )

    scenario_dict = _parse_scenario(scenario)

    # A single user turn may carry both a text part and an audio part.
    message_content: List[Dict[str, Any]] = []
    if text_message:
        message_content.append({"type": "text", "text": text_message})
    if audio_file:
        message_content.append(await _audio_content_part(audio_file))

    agent_input: Dict[str, Any] = {
        "messages": [{"role": "user", "content": message_content}],
    }
    agent_input.update((field, scenario_dict[field]) for field in _SCENARIO_FIELDS)

    try:
        # thread_id ties this turn to the session's stored conversation state.
        response = await role_play_agent().ainvoke(
            agent_input,
            {"configurable": {"thread_id": session_id}},
        )

        last_message = text_message if text_message else "[Audio message]"
        update_session_last_message(session_id, last_message)

        # The agent's reply is the final message in the returned state.
        ai_response = response["messages"][-1].content
        logger.info(f"AI response: {ai_response}")
        return JSONResponse(content={"response": ai_response})
    except Exception as e:
        logger.error(f"Error in roleplay: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Internal server error: {str(e)}"
        ) from e


def _serialize_message(msg: Any) -> Dict[str, Any]:
    """Convert one message from the agent state into a JSON-friendly dict.

    - Objects with both ``content`` and ``type`` use ``type`` as the role.
    - Objects with only ``content`` are labelled "human" when their class
      name contains "Human", otherwise "ai".
    - Anything else is stringified with role "unknown".

    ``timestamp`` is taken from the object when present, else None.
    """
    if hasattr(msg, "content") and hasattr(msg, "type"):
        role = msg.type
    elif hasattr(msg, "content"):
        role = "human" if "Human" in type(msg).__name__ else "ai"
    else:
        # Fallback for unexpected message format
        return {"role": "unknown", "content": str(msg), "timestamp": None}
    return {
        "role": role,
        "content": msg.content,
        "timestamp": getattr(msg, "timestamp", None),
    }


@router.post("/get-messages", status_code=status.HTTP_200_OK)
async def get_messages(request: SessionRequest):
    """Get all messages from a conversation session"""
    try:
        agent = role_play_agent()
        current_state = agent.get_state(
            {"configurable": {"thread_id": request.session_id}}
        )

        # A missing or empty state yields an empty message list.
        values = current_state.values if current_state else None
        raw_messages = values["messages"] if values and "messages" in values else []
        messages = [_serialize_message(msg) for msg in raw_messages]

        return JSONResponse(
            content={
                "session_id": request.session_id,
                "messages": messages,
                "total_messages": len(messages),
            }
        )
    except Exception as e:
        logger.error(
            f"Error getting messages for session {request.session_id}: {str(e)}"
        )
        raise HTTPException(
            status_code=500, detail=f"Failed to get messages: {str(e)}"
        ) from e


@router.get("/sessions", status_code=status.HTTP_200_OK)
async def get_sessions():
    """Get all sessions"""
    try:
        sessions = load_sessions()
        return JSONResponse(content={"sessions": sessions})
    except Exception as e:
        logger.error(f"Error getting sessions: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Failed to get sessions: {str(e)}"
        ) from e


@router.post("/sessions", status_code=status.HTTP_201_CREATED)
async def create_new_session(request: CreateSessionRequest):
    """Create a new session"""
    try:
        session = create_session(request.name)
        # A Response returned directly is sent as-is: FastAPI does not apply
        # the decorator's status_code to it, so set 201 here explicitly.
        return JSONResponse(
            content={"session": session}, status_code=status.HTTP_201_CREATED
        )
    except Exception as e:
        logger.error(f"Error creating session: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Failed to create session: {str(e)}"
        ) from e


@router.get("/sessions/{session_id}", status_code=status.HTTP_200_OK)
async def get_session(session_id: str):
    """Get a specific session by ID"""
    try:
        session = get_session_by_id(session_id)
        if not session:
            raise HTTPException(status_code=404, detail="Session not found")
        return JSONResponse(content={"session": session})
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting session {session_id}: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Failed to get session: {str(e)}"
        ) from e


@router.put("/sessions/{session_id}", status_code=status.HTTP_200_OK)
async def update_session(session_id: str, request: UpdateSessionRequest):
    """Update a session"""
    # The path parameter selects the session; request.session_id is ignored.
    try:
        sessions = load_sessions()
        session = next((s for s in sessions if s["id"] == session_id), None)
        if session is None:
            raise HTTPException(status_code=404, detail="Session not found")

        session["name"] = request.name
        save_sessions(sessions)

        # Re-read so the response reflects what is actually on disk.
        updated_session = get_session_by_id(session_id)
        return JSONResponse(content={"session": updated_session})
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error updating session {session_id}: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Failed to update session: {str(e)}"
        ) from e


@router.delete("/sessions/{session_id}", status_code=status.HTTP_200_OK)
async def delete_session(session_id: str):
    """Delete a session"""
    try:
        if not delete_session_by_id(session_id):
            raise HTTPException(status_code=404, detail="Session not found")
        return JSONResponse(content={"message": "Session deleted successfully"})
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error deleting session {session_id}: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Failed to delete session: {str(e)}"
        ) from e