Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException, BackgroundTasks, Response | |
| from pydantic import BaseModel, Field, ConfigDict | |
| from typing import Optional, List, Dict | |
| import sys | |
| import re | |
| import json | |
| from huggingface_hub import repo_info | |
| from db import db_client | |
| from agents import create_archivist, create_librarian, run_recruiter, create_registrar, narrator, get_model, IMAGE_MODELS, AUDIO_MODELS, REASONING_MODELS | |
| from persistence import save_to_dataset, load_from_dataset, list_saves, delete_save, persistence_manager, get_cached_media, save_cached_media | |
| from vector import vector_model | |
| import os | |
| import asyncio | |
| import httpx | |
# FastAPI application instance; all route handlers below are registered on it.
app = FastAPI(title="Grim Fable: World Memory")
async def startup_event():
    """Validate configuration, verify dataset access, bootstrap rules, and start the autosave loop."""
    token = os.getenv("HF_TOKEN")
    dataset_id = os.getenv("DATASET_ID")

    # Loudly flag missing configuration but keep booting.
    if not token:
        print("CRITICAL: HF_TOKEN is missing!")
    if not dataset_id:
        print("CRITICAL: DATASET_ID is missing!")

    # Probe the dataset repo so access problems surface at startup instead of first save.
    if token and dataset_id:
        try:
            repo_info(repo_id=dataset_id, repo_type="dataset", token=token)
            print(f"Verified accessibility of dataset: {dataset_id}")
        except Exception as err:
            print(f"WARNING: Could not access dataset {dataset_id}: {err}")

    # Seed the vector DB with the rulebook before serving traffic.
    try:
        print("Starting rule bootstrapping...")
        await perform_bootstrap()
        print("Rule bootstrapping completed.")
    except Exception as err:
        print(f"FAILED to bootstrap world: {err}")
        import traceback
        traceback.print_exc()

    # Fire-and-forget periodic autosave watcher.
    asyncio.create_task(inactivity_autosave_loop())
async def inactivity_autosave_loop():
    """Run forever: once a minute, autosave the active slot when the persistence manager asks for it."""
    while True:
        try:
            await asyncio.sleep(60)
            if not persistence_manager.should_autosave(interaction_happened=False):
                continue
            print(f"Periodic autosave triggered for: {persistence_manager.current_save_name}")
            await perform_save(persistence_manager.current_save_name)
        except Exception as err:
            print(f"Error in inactivity_autosave_loop: {err}")
            await asyncio.sleep(10)  # Wait a bit before retrying
class InteractRequest(BaseModel):
    """Payload for one gameplay interaction turn."""
    user_input: str                        # The player's free-text action.
    narrator_model: Optional[str] = None   # Optional override for the narrator LLM.
    reasoning_model: Optional[str] = None  # Optional override for the agents' reasoning LLM.
    image_model: Optional[str] = None      # Optional override for image generation.
    audio_model: Optional[str] = None      # Optional override for audio generation.
class ChatMessageDto(BaseModel):
    """One message in a creation/consultation chat transcript."""
    text: str      # Message body.
    is_user: bool  # True if sent by the player, False if by the GM/agent.
class ConsultGmRequest(BaseModel):
    """Payload for an out-of-character GM consultation."""
    user_input: str                        # The question being asked.
    history: List[ChatMessageDto]          # Prior messages; only the last five are fed to the model.
    reasoning_model: Optional[str] = None  # Optional override for the reasoning LLM.
class CharacterDraft(BaseModel):
    """A fully-specified character proposed during creation."""
    # Allow population by field name as well as by the "class" alias.
    model_config = ConfigDict(populate_by_name=True)
    name: str
    race: str
    # "class" is a Python keyword, so the JSON key is mapped via an alias.
    class_name: str = Field(alias="class")
    stats: Dict[str, int]  # Ability scores keyed by stat name.
    skills: List[str]
    items: List[str]
    starting_context: str  # Narrative seed describing where the character begins.
class WorldInitRequest(BaseModel):
    """Payload for one step of the interactive character-creation flow."""
    prompt: Optional[str] = None                    # Player's latest message; defaults to "Hello" downstream.
    history: Optional[List[ChatMessageDto]] = None  # Conversation so far.
    save_name: Optional[str] = None
    description: Optional[str] = None
    reasoning_model: Optional[str] = None           # Optional override for the recruiter LLM.
class ConfirmRequest(BaseModel):
    """Payload to finalize a character draft and create the initial save."""
    draft: CharacterDraft                  # The validated character to register.
    save_name: str                         # Requested slot name (NOTE: confirm_world currently saves under draft.name instead).
    description: Optional[str] = None      # Optional save-slot description.
    reasoning_model: Optional[str] = None  # Optional override for the registrar LLM.
async def root():
    """Basic liveness info: engine name and whether the DB layer is mocked."""
    payload = {"status": "online", "engine": "FalkorDB", "mock_mode": db_client.is_mock}
    return payload
async def health():
    """Readiness probe: run a trivial query to confirm DB connectivity."""
    try:
        # Simple query to verify DB connectivity
        db_client.query("MATCH (n) RETURN count(n)")
    except Exception as err:
        return {"status": "degraded", "database": "disconnected", "error": str(err)}
    return {"status": "ready", "database": "connected"}
async def consult_gm(request: ConsultGmRequest):
    """Answer an out-of-character GM question using rules and world state.

    Read-only: searches the Rule vector index, peeks at current player stats,
    and asks the reasoning model for an in-character answer. Never mutates
    the database.

    Raises:
        HTTPException 500 if the Librarian cannot be created or the pipeline fails.
    """
    try:
        print(f"Received GM consultation request: {request.user_input}")
        librarian = create_librarian(request.reasoning_model)
        if not librarian:
            raise HTTPException(status_code=500, detail="Failed to initialize Librarian for GM consultation.")
        # Agent runs are blocking; off-load to a worker thread.
        rules_context = await asyncio.to_thread(
            librarian.run, f"Use vector_search on index 'Rule' to find information that helps answer this GM question: {request.user_input}"
        )
        # Best-effort world state; a missing Player node is not an error here.
        world_context = ""
        try:
            player_data = db_client.query("MATCH (p:Player) RETURN p.name, p.hp, p.max_hp, p.strength, p.dexterity, p.constitution, p.intelligence, p.wisdom, p.charisma, p.x, p.y")
            if player_data:
                world_context = f"Current Player Stats: {player_data}"
        except Exception:
            # Narrowed from a bare `except:` so cancellation/shutdown still propagate.
            pass
        # Keep the prompt small: only the last 5 messages for context.
        history_str = "\n".join([f"{'User' if m.is_user else 'GM'}: {m.text}" for m in request.history[-5:]])
        model = get_model(request.reasoning_model)
        prompt = f"""
You are the Game Master (GM). A player is consulting you for information or rule clarification.
Relevant Rules: {rules_context}
World Context: {world_context}
Recent Consultation History:
{history_str}
Current Question: {request.user_input}
Provide a helpful, concise, and in-character response based on the rules and world state. Do NOT change any data in the database.
"""
        # model is an HfApiModel, it has a generate method (blocking).
        response = await asyncio.to_thread(model.generate, messages=[{"role": "user", "content": prompt}])
        response_text = response.content if hasattr(response, 'content') else str(response)
        return {"response": response_text}
    except HTTPException:
        # Preserve deliberate status codes instead of re-wrapping them in a 500.
        raise
    except Exception as e:
        print(f"Error in /consult_gm: {e}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
async def interact(request: InteractRequest, background_tasks: BackgroundTasks):
    """Run one gameplay turn: retrieve context, apply world changes, narrate.

    Pipeline: Librarian (rules + memories) -> Archivist (DB mutation) ->
    Narrator (player-facing prose). Memory recording and autosave run as
    background tasks so they do not delay the response.

    Raises:
        HTTPException 500 if agents cannot be created or the pipeline fails.
    """
    try:
        print(f"Received interaction request: {request.user_input}")
        persistence_manager.update_interaction()
        # Decide on autosave before the slow agent calls.
        should_save = persistence_manager.should_autosave(interaction_happened=True)
        librarian = create_librarian(request.reasoning_model)
        archivist = create_archivist(request.reasoning_model)
        if not librarian or not archivist:
            raise HTTPException(status_code=500, detail="Failed to initialize Librarian or Archivist agents.")
        rules_context = await asyncio.to_thread(
            librarian.run, f"Use vector_search on index 'Rule' to find relevant game mechanics for: {request.user_input}"
        )
        memories_context = await asyncio.to_thread(
            librarian.run, f"Use vector_search on index 'Memory' to find relevant past events for: {request.user_input}"
        )
        context = f"Rules:\n{rules_context}\n\nMemories:\n{memories_context}"
        # Archivist agent execution is blocking, run in thread
        changes = await asyncio.to_thread(
            archivist.run, f"Context: {context}\nUser Action: {request.user_input}\nUpdate the world state in FalkorDB. Use the smart_validator to ensure the Cypher respects game rules and user intent. Coordinate tracking (x,y) is mandatory. Summarize the changes."
        )
        response = await asyncio.to_thread(
            narrator.run, context=context, user_input=request.user_input, changes=changes, model_id=request.narrator_model
        )
        background_tasks.add_task(record_memory, request.user_input, response)
        if should_save:
            background_tasks.add_task(perform_save, persistence_manager.current_save_name)
        return {"response": response}
    except HTTPException:
        # Previously these were logged with a full traceback as if they were
        # crashes before being re-raised; pass deliberate HTTP errors through.
        raise
    except Exception as e:
        print(f"Error in /interact: {e}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
async def record_memory(user_input: str, response: str):
    """Summarize one exchange and persist it as a :Memory node with an embedding.

    Best-effort: failures are logged and swallowed so a memory write never
    breaks the interaction that triggered it.
    """
    try:
        summary_prompt = (
            f"Summarize this game exchange for long-term memory:\nUser: {user_input}\nNarrator: {response}"
        )
        # Model inference is blocking, run in thread
        llm = get_model()
        summary = await asyncio.to_thread(llm.generate, messages=[{"role": "user", "content": summary_prompt}])
        summary_text = summary.content if hasattr(summary, 'content') else str(summary)
        # The embedding indexes the summary; the raw exchange is kept as content.
        params = {
            "content": f"User: {user_input} | Narrator: {response}",
            "summary": summary_text,
            "embedding": vector_model.encode(summary_text),
        }
        cypher = "CREATE (m:Memory {content: $content, summary: $summary, embedding: vecf32($embedding), timestamp: timestamp()})"
        await asyncio.to_thread(db_client.query, cypher, params)
    except Exception as err:
        print(f"Error recording memory: {err}")
async def perform_save(save_name: str, description: Optional[str] = None) -> None:
    """Dump the live DB to disk and upload it to the dataset save slot.

    On success clears the dirty flag and records the dump file's mtime as the
    last save time. Best-effort: failures are logged, never raised.
    """
    try:
        db_path = "/data/world.db"
        await asyncio.to_thread(db_client.save_db, db_path)
        if await asyncio.to_thread(save_to_dataset, save_name, db_path, description=description):
            persistence_manager.needs_save = False
            # NOTE(review): uses the dump file's mtime as last_save_time — presumably
            # compared against wall-clock time in persistence.py; confirm there.
            persistence_manager.last_save_time = os.path.getmtime(db_path) if os.path.exists(db_path) else persistence_manager.last_save_time
    except Exception as e:
        print(f"Error performing save: {e}")
async def save_world(save_name: str, background_tasks: BackgroundTasks, description: Optional[str] = None):
    """Mark *save_name* as the active slot and kick off the save in the background."""
    persistence_manager.current_save_name = save_name
    background_tasks.add_task(perform_save, save_name, description)
    result = {"status": "save_initiated", "save_name": save_name}
    return result
async def delete_world_save(save_name: str):
    """Remove a save slot; 404 if it does not exist or cannot be removed."""
    if not delete_save(save_name):
        raise HTTPException(status_code=404, detail="Save not found or could not be deleted")
    return {"status": "deleted", "save_name": save_name}
async def load_world(save_name: str, background_tasks: BackgroundTasks):
    """Restore a saved world from the dataset and restart to apply it.

    Downloads the DB snapshot, loads it into the client, then schedules a
    process restart (the snapshot is only applied on boot).

    Raises:
        HTTPException 404 if the save does not exist, 500 on other failures.
    """
    try:
        db_path = "/data/world.db"
        if await asyncio.to_thread(load_from_dataset, save_name, db_path):
            await asyncio.to_thread(db_client.load_db, db_path)
            persistence_manager.current_save_name = save_name
            # Restart after sending response to apply new RDB file
            background_tasks.add_task(restart_server)
            return {"status": "loaded", "save_name": save_name, "notice": "Server restarting to apply world state"}
        raise HTTPException(status_code=404, detail="Save not found")
    except HTTPException:
        # Let the deliberate 404 through without logging it as a server error.
        raise
    except Exception as e:
        print(f"Error in /world/load: {e}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
async def get_saves():
    """List all available save slots."""
    try:
        saves = list_saves()
    except Exception as err:
        print(f"Error in /world/saves: {err}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(err))
    return {"saves": saves}
async def get_media(entity_id: str, media_type: str, prompt: str, model_id: Optional[str] = None):
    """Fetches cached media or generates new media for an entity asynchronously with fallback logic.

    Serves from the per-save media cache when possible; otherwise tries the
    requested model first, then every default model for the media type,
    caching the first successful result.

    Raises:
        HTTPException 400 if generation is explicitly disabled,
        HTTPException 500 if every model attempt fails.
    """
    if model_id and model_id.lower() == "disabled":
        raise HTTPException(status_code=400, detail="Media generation is disabled")
    mime = "image/webp" if media_type == "image" else "audio/mpeg"
    save_name = persistence_manager.current_save_name
    cached_content = get_cached_media(save_name, entity_id, media_type)
    if cached_content:
        return Response(content=cached_content, media_type=mime)
    # Build the candidate list: explicit model first, then the defaults (deduped).
    models_to_try = []
    if model_id:
        models_to_try.append(model_id)
    base_list = IMAGE_MODELS if media_type == "image" else AUDIO_MODELS
    for m in base_list:
        if m not in models_to_try:
            models_to_try.append(m)
    # Only attach an Authorization header when a token exists; previously an
    # unset HF_TOKEN produced a bogus "Bearer None" header.
    token = os.getenv("HF_TOKEN")
    headers = {"Authorization": f"Bearer {token}"} if token else {}
    async with httpx.AsyncClient(timeout=60.0) as client:
        last_error = "Unknown error"
        for mid in models_to_try:
            try:
                api_url = f"https://api-inference.huggingface.co/models/{mid}"
                response = await client.post(api_url, headers=headers, json={"inputs": prompt})
                if response.status_code == 200:
                    save_cached_media(save_name, entity_id, media_type, response.content)
                    return Response(content=response.content, media_type=mime)
                # Remember why this model failed and fall through to the next one.
                last_error = f"Model {mid} failed: {response.status_code} {response.text}"
            except Exception as e:
                last_error = str(e)
    raise HTTPException(status_code=500, detail=f"All generation attempts failed. Last error: {last_error}")
async def get_npc(npc_id: str):
    """Fetch a single Player or NPC node, matching by id or (legacy) by name."""
    # RETURN properties(n) yields a plain dict in both mock and live DB modes.
    cypher = "MATCH (n) WHERE (n:NPC OR n:Player) AND (n.id = $id OR n.name = $id) RETURN properties(n)"
    rows = db_client.query(cypher, {"id": npc_id})
    if not rows:
        raise HTTPException(status_code=404, detail="Entity not found")
    # First column of the first row is the node's property map.
    return {"npc": rows[0][0]}
async def get_map():
    """Return Player and Location nodes with coordinates for the map view.

    NPCs are tracked in the graph but intentionally not displayed.
    """
    query = """
    MATCH (n)
    WHERE n:Player OR n:Location
    RETURN labels(n)[0], n.id, n.name, n.x, n.y
    """
    print(f"Executing map query: {query}")
    rows = db_client.query(query)
    print(f"Map query result: {rows}")
    entities = [
        {
            "type": row[0],
            "id": row[1] or row[2],  # Fallback to name if id is missing
            "name": row[2],
            "x": int(row[3]) if row[3] is not None else 0,
            "y": int(row[4]) if row[4] is not None else 0,
        }
        for row in (rows or [])
    ]
    return {"entities": entities}
def robust_json_extract(text: str):
    """Extracts JSON and returns (draft_dict, clean_text).

    Tries a fenced ```json``` block first (tolerating trailing commas, a
    common LLM artifact), then any bare {...} span. clean_text is *text*
    with the extracted JSON removed; on failure returns (None, text).
    """
    draft = None
    clean_text = text
    # Try markdown block first
    json_match = re.search(r"(```(?:json)?\s*(.*?)\s*```)", text, re.DOTALL)
    if json_match:
        full_block = json_match.group(1)
        content = json_match.group(2).strip()
        clean_text = text.replace(full_block, "").strip()
        try:
            draft = json.loads(content)
        except json.JSONDecodeError:
            # Strip trailing commas before ] or } and retry once.
            # (Narrowed from bare `except:` so real errors still propagate.)
            try:
                content = re.sub(r",\s*([\]}])", r"\1", content)
                draft = json.loads(content)
            except json.JSONDecodeError:
                pass
    if draft:
        return draft, clean_text
    # Try searching for any { ... } block
    brace_match = re.search(r"({.*})", text, re.DOTALL)
    if brace_match:
        full_match = brace_match.group(1)
        try:
            draft = json.loads(full_match.strip())
            clean_text = text.replace(full_match, "").strip()
            return draft, clean_text
        except json.JSONDecodeError:
            pass
    return None, text
async def init_world(request: WorldInitRequest):
    """Advance the character-creation conversation one step.

    Runs the recruiter agent on the player's prompt + history and extracts a
    structured character draft from its reply when one is present.

    Returns:
        dict with "response" (draft-free text) and "draft" (dict or None).
    """
    try:
        print(f"Character creation step for: {request.prompt}")
        user_input = request.prompt or "Hello"
        history = request.history or []
        # Recruiter execution is blocking; run it off the event loop.
        response_text = await asyncio.to_thread(run_recruiter, user_input, history, request.reasoning_model)
        draft, clean_text = robust_json_extract(response_text)
        return {
            "response": clean_text,
            "draft": draft
        }
    except HTTPException:
        raise
    except Exception as e:
        print(f"Error in /world/init: {e}")
        if "BadRequestError" in str(type(e)):
            # Best-effort: surface the provider's response body for debugging.
            # (Narrowed from a bare `except:`.)
            try:
                if hasattr(e, 'response') and hasattr(e.response, 'text'):
                    print(f"Bad Request Response Body: {e.response.text}")
            except Exception:
                pass
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
async def confirm_world(request: ConfirmRequest, background_tasks: BackgroundTasks):
    """Register a confirmed character in the DB and create its initial save.

    Note: the character's name — not request.save_name — is used as the slot.

    Raises:
        HTTPException 500 if the Registrar cannot be created or registration fails.
    """
    try:
        print(f"Confirming character: {request.draft.name}")
        registrar = create_registrar(request.reasoning_model)
        if not registrar:
            raise HTTPException(status_code=500, detail="Failed to initialize Registrar agent.")
        # Registrar handles DB creation.
        # Use by_alias=True to ensure 'class' is used in the JSON passed to the registrar.
        await asyncio.to_thread(registrar.run, f"Validated Character Draft: {request.draft.model_dump_json(by_alias=True)}")
        # Use character name as the save slot name
        save_name = request.draft.name
        persistence_manager.current_save_name = save_name
        # Character confirmed - immediate save
        await perform_save(save_name, request.description)
        return {"status": "initialized", "save_name": save_name}
    except HTTPException:
        # Previously re-wrapped in a generic 500, which mangled the detail text.
        raise
    except Exception as e:
        print(f"Error in /world/confirm: {e}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
async def bootstrap_world(background_tasks: BackgroundTasks):
    """Kick off rule bootstrapping in the background without blocking the request."""
    background_tasks.add_task(perform_bootstrap)
    return {"status": "bootstrap_initiated"}
async def perform_bootstrap():
    """Load SRD rules from disk into the graph as :Rule nodes with embeddings.

    Idempotent: ensures the vector indices exist and skips rules whose title
    is already present, so repeated startups do not duplicate data.
    """
    rules_path = "dnd_srd_rules.json"
    if not os.path.exists(rules_path):
        print(f"Rules file not found at {rules_path}")
        return
    # json is imported at module level; the previous local re-import was redundant.
    with open(rules_path, "r") as f:
        rules = json.load(f)
    # Attempt to create indices, ignore if they already exist
    # Using the new syntax for FalkorDB 4.x+
    try:
        db_client.query("CREATE VECTOR INDEX FOR (r:Rule) ON (r.embedding) OPTIONS {dimension: 384, similarityFunction: 'cosine'}")
    except Exception as e:
        if "already exists" not in str(e).lower(): print(f"Notice creating Rule index: {e}")
    try:
        db_client.query("CREATE VECTOR INDEX FOR (m:Memory) ON (m.embedding) OPTIONS {dimension: 384, similarityFunction: 'cosine'}")
    except Exception as e:
        if "already exists" not in str(e).lower(): print(f"Notice creating Memory index: {e}")
    # Fetch all existing rule titles once to optimize bootstrapping
    existing_rules = db_client.query("MATCH (r:Rule) RETURN r.title")
    existing_titles = set()
    if existing_rules:
        for row in existing_rules:
            if isinstance(row, list) and len(row) > 0:
                existing_titles.add(row[0])
    created = 0
    for rule in rules:
        title = rule.get("title")
        if title in existing_titles:
            continue
        content = rule.get("content")
        tags = rule.get("tags", [])
        embedding = vector_model.encode(content)
        query = """
        CREATE (r:Rule {title: $title, content: $content, tags: $tags, embedding: vecf32($embedding)})
        """
        db_client.query(query, {
            "title": title,
            "content": content,
            "tags": tags,
            "embedding": embedding
        })
        created += 1
    # Report what actually happened; the old message always claimed len(rules)
    # inserts even when every rule was skipped as already present.
    print(f"Successfully bootstrapped {created} new rules ({len(rules)} total).")
async def restart_server():
    """Sleep briefly so the HTTP response can flush, then exit; the host supervisor restarts the process."""
    print("Initiating server restart in 2 seconds...")
    await asyncio.sleep(2)
    sys.exit(0)