Spaces:
Paused
Paused
| import asyncio | |
| import time | |
| import uuid | |
| import logging | |
| from fastapi import FastAPI, HTTPException, Query | |
| from pydantic import BaseModel | |
| from typing import Optional, Dict, List | |
| from PyCharacterAI import get_client | |
| from PyCharacterAI.exceptions import SessionClosedError, RequestError | |
| import uvicorn | |
| import os | |
app: FastAPI = FastAPI()

# --- Configuration -----------------------------------------------------------
# Token for the Character AI client, read from the ``token`` environment
# variable; None when the variable is not set.
DEFAULT_TOKEN: Optional[str] = os.getenv("token")
# Character and voice used when a request does not name its own.
DEFAULT_CHARACTER_ID: str = "smtV3Vyez6ODkwS8BErmBAdgGNj-1XWU73wIFVOY1hQ"
DEFAULT_VOICE_ID: str = "974fea59-7c26-411b-ae0d-64ff2c4e9666"

# Retry policy: at most MAX_RETRIES extra attempts, waiting
# RETRY_DELAY * BACKOFF_MULTIPLIER ** attempt seconds before each retry.
MAX_RETRIES: int = 5
RETRY_DELAY: float = 1.0
BACKOFF_MULTIPLIER: float = 2.0

logging.basicConfig(level=logging.INFO)
logger: logging.Logger = logging.getLogger(__name__)

# --- Shared runtime state ----------------------------------------------------
# Character AI client shared by all requests; None until it is first created.
client = None
# One asyncio.Lock per chat_id, serializing messages sent to the same chat.
chat_locks: Dict[str, asyncio.Lock] = {}
# Serializes creation of new chats.
# NOTE(review): on Python < 3.10 an asyncio.Lock created at import time binds
# to the loop current at import, which may not be the server's loop — confirm
# the runtime Python version.
new_chat_lock: asyncio.Lock = asyncio.Lock()
async def startup_event():
    """Create the shared Character AI client using DEFAULT_TOKEN.

    NOTE(review): no startup decorator is visible in this copy of the file —
    confirm how this hook is registered with ``app``.
    """
    global client
    fresh_client = await get_client(token=DEFAULT_TOKEN)
    client = fresh_client
async def shutdown_event():
    """Close the shared client's session, doing nothing if no client exists."""
    active_client = client
    if not active_client:
        return
    await active_client.close_session()
async def reinitialize_client():
    """
    Replace the shared client with a freshly created one.

    The existing client's session is closed on a best-effort basis: an
    ordinary failure while closing is logged and ignored so that a broken
    session never prevents creating a new client. Cancellation and other
    BaseException subclasses are NOT swallowed and propagate normally.

    Raises:
        Whatever ``get_client`` raises when a new client cannot be created.
    """
    global client
    if client:
        try:
            await client.close_session()
        except Exception as exc:
            # Best effort: the old session may already be closed or broken.
            logger.warning(f"Closing the previous client session failed: {exc}")
    # NOTE(review): concurrent callers can each reach this line and create
    # their own client; only the last assignment wins.
    client = await get_client(token=DEFAULT_TOKEN)
def get_chat_lock(chat_id):
    """Return the asyncio.Lock for ``chat_id``, creating and storing it on first use.

    Locks are kept in the module-level ``chat_locks`` dict; this function
    only adds entries, it never removes them.
    """
    try:
        return chat_locks[chat_id]
    except KeyError:
        fresh_lock = chat_locks[chat_id] = asyncio.Lock()
        return fresh_lock
def is_retryable_error(exception):
    """
    Decide whether a failed call should be attempted again.

    An error counts as transient when it is an instance of one of the known
    network/session exception types, or when its lower-cased message contains
    one of a fixed set of substrings (timeouts, gateway errors, rate limiting,
    "maybe your token is invalid", ...). Token-related messages are included
    because retry_with_backoff recreates the client when it sees "token" in an
    error message.

    Args:
        exception: The exception raised by the failed call.
    Returns:
        bool: True if a retry is appropriate, False otherwise.
    """
    message = str(exception).lower()

    transient_types = (
        RequestError,
        SessionClosedError,
        ConnectionError,
        TimeoutError,
        asyncio.TimeoutError,
    )
    if isinstance(exception, transient_types):
        return True

    transient_markers = (
        "maybe your token is invalid",
        "session closed",
        "connection",
        "timeout",
        "request failed",
        "network",
        "server error",
        "internal server error",
        "bad gateway",
        "service unavailable",
        "gateway timeout",
        "rate limit",
        "too many requests",
        "temporary",
        "temporarily unavailable",
        "request timeout",
        "read timeout",
    )
    return any(marker in message for marker in transient_markers)
async def retry_with_backoff(func, *args, max_retries=MAX_RETRIES, base_delay=RETRY_DELAY, **kwargs):
    """
    Await ``func(*args, **kwargs)``, retrying transient failures with exponential backoff.

    Attempt ``n`` (0-based) that fails with a retryable error (see
    is_retryable_error) is followed by a sleep of
    ``base_delay * BACKOFF_MULTIPLIER ** n`` seconds. When the error looks like
    a session/token problem, the shared client is recreated before the next
    attempt. ``func`` is always called at least once, even if ``max_retries``
    is negative.

    Args:
        func: Zero-or-more-argument coroutine function to call.
        *args: Positional arguments forwarded to ``func``.
        max_retries: Number of retries after the first attempt (negative
            values are treated as 0).
        base_delay: Delay in seconds before the first retry.
        **kwargs: Keyword arguments forwarded to ``func``.
    Returns:
        Whatever ``func`` returns on its first successful call.
    Raises:
        The error from ``func`` immediately if it is not retryable, otherwise
        the last error once every attempt has failed.
    """
    # A negative max_retries previously skipped the loop entirely and returned
    # None without ever calling func.
    total_attempts = max(0, max_retries) + 1
    for attempt in range(total_attempts):
        try:
            return await func(*args, **kwargs)
        except Exception as e:
            if not is_retryable_error(e):
                logger.error(f"Non-retryable error occurred: {str(e)}")
                raise
            if attempt == total_attempts - 1:
                logger.error(f"All {total_attempts} attempts failed. Last error: {str(e)}")
                raise
            delay = base_delay * (BACKOFF_MULTIPLIER ** attempt)
            logger.warning(f"Attempt {attempt + 1} failed: {str(e)}. Retrying in {delay:.2f} seconds...")
            if isinstance(e, SessionClosedError) or "token" in str(e).lower():
                logger.info("Session/token issue detected, reinitializing client...")
                try:
                    await reinitialize_client()
                except Exception as reinit_error:
                    # A failed re-creation is usually transient too; keep the
                    # remaining attempts instead of aborting the whole retry.
                    logger.warning(f"Client reinitialization failed: {reinit_error}")
            await asyncio.sleep(delay)
async def search_characters(
    query: str = Query(..., description="Character name or keyword to search for"),
    token: Optional[str] = Query(None, description="API token for authentication."),
    limit: int = Query(10, description="Maximum number of results to return")
):
    """
    Search for characters by name or keyword.

    Returns a dict with the original ``query``, up to ``limit`` matches under
    ``results`` (each with id, name, greeting, description, avatar_url), the
    total number of matches in ``total_found`` and the number returned in
    ``returned``.

    Args:
        query: Character name or keyword forwarded to the search API.
        token: Accepted for API compatibility but not used; every request is
            served by the shared module-level client.
        limit: Maximum number of results; negative values are treated as 0.

    Raises:
        HTTPException: status 500 if the search still fails after retries.

    NOTE(review): no route decorator is visible in this copy of the file —
    confirm how this endpoint is registered with ``app``.
    """
    # A negative slice bound would drop items from the END of the list
    # instead of limiting the count, so clamp it.
    limit = max(limit, 0)
    try:
        if client is None:
            await reinitialize_client()

        # Reads the global ``client`` at call time, so a retry that follows
        # reinitialize_client() uses the new client.
        async def search_operation():
            return await client.character.search_characters(query)

        characters = await retry_with_backoff(search_operation)
        results = [
            {
                "id": char.character_id,
                "name": char.name,
                "greeting": char.greeting,
                "description": char.description,
                "avatar_url": char.avatar,
            }
            for char in characters[:limit]
        ]
        return {"query": query, "results": results, "total_found": len(characters), "returned": len(results)}
    except SessionClosedError as e:
        await reinitialize_client()
        raise HTTPException(
            status_code=500,
            detail="Session was closed after retries. Please try your request again."
        ) from e
    except RequestError as e:
        raise HTTPException(status_code=500, detail=f"Character AI API error after retries: {str(e)}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}") from e
async def send_message(
    message: str,
    token: Optional[str] = Query(None, description="API token for authentication."),
    character_id: str = Query(DEFAULT_CHARACTER_ID, description="Character ID for the chat session."),
    chat_id: Optional[str] = Query(None, description="ID of the existing chat session, if available."),
    voice_id: str = Query(DEFAULT_VOICE_ID, description="Voice ID for generating speech."),
    voice: bool = Query(True, description="Set to true to generate voice, false to skip voice generation.")
):
    """
    Send a message to a character and return its reply.

    If no ``chat_id`` is given, a new chat is created first (serialized by
    ``new_chat_lock``) and its greeting is included in the response. Messages
    to the same chat are serialized by that chat's lock. When ``voice`` is
    true, speech for the reply is requested; a voice failure is logged and
    leaves ``voice_url`` as None instead of failing the request.

    Args:
        message: Text to send.
        token: Accepted for API compatibility but not used; every request is
            served by the shared module-level client.
        character_id: Character to talk to.
        chat_id: Existing chat to continue, or None/empty to start a new one.
        voice_id: Voice used for speech generation.
        voice: Whether to generate speech for the reply.

    Returns:
        dict with ``chat_id``, ``author``, ``response``, ``voice_url`` and, for
        newly created chats, ``greeting_message`` (``author``, ``text``).

    Raises:
        HTTPException: status 500 on any failure after retries.

    NOTE(review): no route decorator is visible in this copy of the file —
    confirm how this endpoint is registered with ``app``.
    """
    try:
        if client is None:
            await reinitialize_client()
        greeting_text = None
        if not chat_id:
            async with new_chat_lock:
                try:
                    async def create_chat_operation():
                        return await client.chat.create_chat(character_id)

                    chat, greeting_message = await retry_with_backoff(create_chat_operation)
                    chat_id = chat.chat_id
                    chat_locks[chat_id] = asyncio.Lock()
                    greeting_text = {
                        "author": greeting_message.author_name,
                        "text": greeting_message.get_primary_candidate().text
                    }
                except Exception as e:
                    raise HTTPException(
                        status_code=500,
                        detail=f"Failed to create new chat after retries: {str(e)}"
                    ) from e

        chat_lock = get_chat_lock(chat_id)
        async with chat_lock:
            # NOTE(review): send_message is not idempotent — if an attempt
            # reached the server but timed out client-side, a retry may
            # deliver the same message twice.
            async def send_message_operation():
                return await client.chat.send_message(character_id, chat_id, message)

            answer = await retry_with_backoff(send_message_operation)
        response_text = answer.get_primary_candidate().text

        speech_url = None  # stays None unless voice generation succeeds
        if voice:
            try:
                async def generate_speech_operation():
                    return await client.utils.generate_speech(
                        chat_id,
                        answer.turn_id,
                        answer.get_primary_candidate().candidate_id,
                        voice_id,
                        return_url=True
                    )

                speech_url = await retry_with_backoff(
                    generate_speech_operation,
                    max_retries=5,
                    base_delay=0.5
                )
            except Exception as e:
                logger.warning(f"Voice generation failed after retries: {e}")

        response_data = {
            "chat_id": chat_id,
            "author": answer.author_name,
            "response": response_text,
            "voice_url": speech_url
        }
        if greeting_text:
            response_data["greeting_message"] = greeting_text
        return response_data
    except HTTPException:
        # Already carries the intended status and detail (e.g. chat creation
        # failure); without this, the generic handler below would re-wrap it
        # as "Error processing request: 500: ...".
        raise
    except SessionClosedError as e:
        await reinitialize_client()
        raise HTTPException(
            status_code=500,
            detail="Session was closed after retries. Please try your request again."
        ) from e
    except RequestError as e:
        raise HTTPException(status_code=500, detail=f"Character AI API error after retries: {str(e)}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}") from e
async def start_cleanup_task():
    """
    Start a background task that prunes idle per-chat locks once an hour.

    Each sweep removes every ``chat_locks`` entry whose lock is not currently
    held, so the dict does not grow without bound as new chats are created.
    The task is stored in the module-level ``_lock_cleanup_task`` because the
    event loop keeps only weak references to tasks; an unreferenced task can
    be garbage-collected mid-execution.

    NOTE(review): ``Lock.locked()`` is briefly False between a release and the
    next waiter's acquisition, so a lock with a pending waiter can be pruned;
    a later request for that chat would then receive a fresh lock. The window
    is small (one sweep per hour) but not zero.

    NOTE(review): no startup decorator is visible in this copy of the file —
    confirm how this hook is registered with ``app``.
    """
    global _lock_cleanup_task

    async def cleanup_locks():
        while True:
            await asyncio.sleep(3600)
            # There is no await between collecting and deleting, so no other
            # coroutine can acquire one of these locks in between.
            idle_chat_ids = [
                chat_id for chat_id, lock in chat_locks.items() if not lock.locked()
            ]
            for chat_id in idle_chat_ids:
                del chat_locks[chat_id]
            if idle_chat_ids:
                logger.info(f"Pruned {len(idle_chat_ids)} idle chat locks")

    _lock_cleanup_task = asyncio.create_task(cleanup_locks())
def root():
    """Return the static welcome payload for the service root."""
    welcome_text = "Welcome to the Character AI API with FastAPI!"
    return {"message": welcome_text}
def health_check():
    """Report liveness plus whether the shared client has been created."""
    has_client = client is not None
    return {"status": "healthy", "client_initialized": has_client}
if __name__ == "__main__":
    # chat_locks, new_chat_lock and the shared client all live in process
    # memory, so they only serialize requests handled by the same worker.
    # With several workers, two requests for the same chat can run at once in
    # different processes. Default to one worker (asyncio already multiplexes
    # concurrent requests within it); set WORKERS to opt into more processes.
    # NOTE(review): uvicorn's timeout_keep_alive is in seconds, so 60000 keeps
    # idle connections open for about 16.7 hours — confirm this is intended.
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=7860,
        workers=int(os.getenv("WORKERS", "1")),
        timeout_keep_alive=60000,
    )