"""FastAPI service for "Grim Fable: World Memory".

Exposes endpoints for player interaction (narration), GM consultation,
world persistence (save / load / delete), map and entity queries, media
generation with model fallback, and bootstrapping of game rules into a
FalkorDB vector index.
"""

from fastapi import FastAPI, HTTPException, BackgroundTasks, Response
from pydantic import BaseModel, Field, ConfigDict
from typing import Optional, List, Dict
import sys
import re
import json
import os
import asyncio

import httpx
from huggingface_hub import repo_info

from db import db_client
from agents import (
    create_archivist,
    create_librarian,
    run_recruiter,
    create_registrar,
    narrator,
    get_model,
    IMAGE_MODELS,
    AUDIO_MODELS,
    REASONING_MODELS,
)
from persistence import (
    save_to_dataset,
    load_from_dataset,
    list_saves,
    delete_save,
    persistence_manager,
    get_cached_media,
    save_cached_media,
)
from vector import vector_model

app = FastAPI(title="Grim Fable: World Memory")

# Hold a strong reference to the autosave task: asyncio keeps only weak
# references to tasks, so an unreferenced task may be garbage-collected
# and silently stop running.
_autosave_task: Optional["asyncio.Task"] = None


@app.on_event("startup")
async def startup_event():
    """Verify environment, bootstrap rules, and start the autosave loop."""
    global _autosave_task
    # Verify environment on startup
    token = os.getenv("HF_TOKEN")
    dataset_id = os.getenv("DATASET_ID")
    if not token:
        print("CRITICAL: HF_TOKEN is missing!")
    if not dataset_id:
        print("CRITICAL: DATASET_ID is missing!")
    # Check dataset accessibility
    if token and dataset_id:
        try:
            repo_info(repo_id=dataset_id, repo_type="dataset", token=token)
            print(f"Verified accessibility of dataset: {dataset_id}")
        except Exception as e:
            print(f"WARNING: Could not access dataset {dataset_id}: {e}")
    # Bootstrap rules into the vector DB before starting
    try:
        print("Starting rule bootstrapping...")
        await perform_bootstrap()
        print("Rule bootstrapping completed.")
    except Exception as e:
        print(f"FAILED to bootstrap world: {e}")
        import traceback
        traceback.print_exc()
    # Keep the returned task referenced (see note on _autosave_task above).
    _autosave_task = asyncio.create_task(inactivity_autosave_loop())


async def inactivity_autosave_loop():
    """Periodically autosave the current world while the player is idle."""
    while True:
        try:
            await asyncio.sleep(60)
            if persistence_manager.should_autosave(interaction_happened=False):
                print(f"Periodic autosave triggered for: {persistence_manager.current_save_name}")
                await perform_save(persistence_manager.current_save_name)
        except Exception as e:
            print(f"Error in inactivity_autosave_loop: {e}")
            await asyncio.sleep(10)  # Wait a bit before retrying


class InteractRequest(BaseModel):
    """Payload for /interact: a player action plus optional model overrides."""
    user_input: str
    narrator_model: Optional[str] = None
    reasoning_model: Optional[str] = None
    image_model: Optional[str] = None
    audio_model: Optional[str] = None


class ChatMessageDto(BaseModel):
    """A single chat message; `is_user` distinguishes player from GM lines."""
    text: str
    is_user: bool


class ConsultGmRequest(BaseModel):
    """Payload for /consult_gm: a question plus recent consultation history."""
    user_input: str
    history: List[ChatMessageDto]
    reasoning_model: Optional[str] = None


class CharacterDraft(BaseModel):
    """Character sheet draft; `class_name` is aliased to the JSON key
    "class" because `class` is a Python keyword."""
    model_config = ConfigDict(populate_by_name=True)
    name: str
    race: str
    class_name: str = Field(alias="class")
    stats: Dict[str, int]
    skills: List[str]
    items: List[str]
    starting_context: str


class WorldInitRequest(BaseModel):
    """Payload for /world/init (one character-creation conversation step)."""
    prompt: Optional[str] = None
    history: Optional[List[ChatMessageDto]] = None
    save_name: Optional[str] = None
    description: Optional[str] = None
    reasoning_model: Optional[str] = None


class ConfirmRequest(BaseModel):
    """Payload for /world/confirm: a finalized draft plus save metadata."""
    draft: CharacterDraft
    save_name: str
    description: Optional[str] = None
    reasoning_model: Optional[str] = None


@app.get("/")
async def root():
    """Liveness/info endpoint."""
    return {"status": "online", "engine": "FalkorDB", "mock_mode": db_client.is_mock}


@app.get("/health")
async def health():
    """Readiness check: runs a trivial query to verify DB connectivity."""
    try:
        # Simple query to verify DB connectivity
        db_client.query("MATCH (n) RETURN count(n)")
        return {"status": "ready", "database": "connected"}
    except Exception as e:
        return {"status": "degraded", "database": "disconnected", "error": str(e)}


@app.post("/consult_gm")
async def consult_gm(request: ConsultGmRequest):
    """Answer a rules/world question as the GM without mutating the database."""
    try:
        print(f"Received GM consultation request: {request.user_input}")
        librarian = create_librarian(request.reasoning_model)
        if not librarian:
            raise HTTPException(status_code=500, detail="Failed to initialize Librarian for GM consultation.")
        # Search for rules and context to answer the question
        rules_context = await asyncio.to_thread(
            librarian.run,
            f"Use vector_search on index 'Rule' to find information that helps answer this GM question: {request.user_input}"
        )
        # Get world state context if possible
        world_context = ""
        try:
            player_data = db_client.query("MATCH (p:Player) RETURN p.name, p.hp, p.max_hp, p.strength, p.dexterity, p.constitution, p.intelligence, p.wisdom, p.charisma, p.x, p.y")
            if player_data:
                world_context = f"Current Player Stats: {player_data}"
        except Exception:
            # World context is best-effort; answer from rules alone if it fails.
            pass
        history_str = "\n".join([f"{'User' if m.is_user else 'GM'}: {m.text}" for m in request.history[-5:]])  # Last 5 messages for context
        model = get_model(request.reasoning_model)
        prompt = f"""
You are the Game Master (GM). A player is consulting you for information or rule clarification.

Relevant Rules: {rules_context}

World Context: {world_context}

Recent Consultation History: {history_str}

Current Question: {request.user_input}

Provide a helpful, concise, and in-character response based on the rules and world state. Do NOT change any data in the database.
"""
        # model is an HfApiModel, it has a generate method
        response = await asyncio.to_thread(model.generate, messages=[{"role": "user", "content": prompt}])
        response_text = response.content if hasattr(response, 'content') else str(response)
        return {"response": response_text}
    except Exception as e:
        print(f"Error in /consult_gm: {e}")
        import traceback
        traceback.print_exc()
        # Re-raise HTTPExceptions unchanged (consistent with /interact);
        # only wrap unexpected failures in a 500.
        if isinstance(e, HTTPException):
            raise e
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/interact")
async def interact(request: InteractRequest, background_tasks: BackgroundTasks):
    """Process a player action: gather context, update the world, narrate."""
    try:
        print(f"Received interaction request: {request.user_input}")
        persistence_manager.update_interaction()
        should_save = persistence_manager.should_autosave(interaction_happened=True)
        librarian = create_librarian(request.reasoning_model)
        archivist = create_archivist(request.reasoning_model)
        if not librarian or not archivist:
            raise HTTPException(status_code=500, detail="Failed to initialize Librarian or Archivist agents.")
        rules_context = await asyncio.to_thread(
            librarian.run,
            f"Use vector_search on index 'Rule' to find relevant game mechanics for: {request.user_input}"
        )
        memories_context = await asyncio.to_thread(
            librarian.run,
            f"Use vector_search on index 'Memory' to find relevant past events for: {request.user_input}"
        )
        context = f"Rules:\n{rules_context}\n\nMemories:\n{memories_context}"
        # Archivist agent execution is blocking, run in thread
        changes = await asyncio.to_thread(
            archivist.run,
            f"Context: {context}\nUser Action: {request.user_input}\nUpdate the world state in FalkorDB. Use the smart_validator to ensure the Cypher respects game rules and user intent. Coordinate tracking (x,y) is mandatory. Summarize the changes."
        )
        response = await asyncio.to_thread(
            narrator.run,
            context=context,
            user_input=request.user_input,
            changes=changes,
            model_id=request.narrator_model
        )
        # Memory recording and autosave happen after the response is sent.
        background_tasks.add_task(record_memory, request.user_input, response)
        if should_save:
            background_tasks.add_task(perform_save, persistence_manager.current_save_name)
        return {"response": response}
    except Exception as e:
        print(f"Error in /interact: {e}")
        import traceback
        traceback.print_exc()
        if isinstance(e, HTTPException):
            raise e
        raise HTTPException(status_code=500, detail=str(e))


async def record_memory(user_input: str, response: str):
    """Summarize an exchange and store it as an embedded Memory node."""
    try:
        model = get_model()
        summary_prompt = f"Summarize this game exchange for long-term memory:\nUser: {user_input}\nNarrator: {response}"
        # Model inference is blocking, run in thread
        summary = await asyncio.to_thread(model.generate, messages=[{"role": "user", "content": summary_prompt}])
        summary_text = summary.content if hasattr(summary, 'content') else str(summary)
        content = f"User: {user_input} | Narrator: {response}"
        embedding = vector_model.encode(summary_text)
        query = "CREATE (m:Memory {content: $content, summary: $summary, embedding: vecf32($embedding), timestamp: timestamp()})"
        await asyncio.to_thread(db_client.query, query, {
            "content": content,
            "summary": summary_text,
            "embedding": embedding
        })
    except Exception as e:
        print(f"Error recording memory: {e}")


async def perform_save(save_name: str, description: Optional[str] = None):
    """Dump the DB to disk and upload it to the dataset; best-effort."""
    try:
        db_path = "/data/world.db"
        await asyncio.to_thread(db_client.save_db, db_path)
        if await asyncio.to_thread(save_to_dataset, save_name, db_path, description=description):
            persistence_manager.needs_save = False
            # Record the dump's mtime as the last save time when available.
            persistence_manager.last_save_time = os.path.getmtime(db_path) if os.path.exists(db_path) else persistence_manager.last_save_time
    except Exception as e:
        print(f"Error performing save: {e}")


@app.post("/world/save")
async def save_world(save_name: str, background_tasks: BackgroundTasks, description: Optional[str] = None):
    """Kick off an asynchronous save of the current world state."""
    persistence_manager.current_save_name = save_name
    background_tasks.add_task(perform_save, save_name, description)
    return {"status": "save_initiated", "save_name": save_name}


@app.delete("/world/save/{save_name}")
async def delete_world_save(save_name: str):
    """Delete a named save slot from the dataset."""
    if delete_save(save_name):
        return {"status": "deleted", "save_name": save_name}
    raise HTTPException(status_code=404, detail="Save not found or could not be deleted")


@app.post("/world/load")
async def load_world(save_name: str, background_tasks: BackgroundTasks):
    """Load a saved world and restart the server to apply the RDB file."""
    try:
        db_path = "/data/world.db"
        if await asyncio.to_thread(load_from_dataset, save_name, db_path):
            await asyncio.to_thread(db_client.load_db, db_path)
            persistence_manager.current_save_name = save_name
            # Restart after sending response to apply new RDB file
            background_tasks.add_task(restart_server)
            return {"status": "loaded", "save_name": save_name, "notice": "Server restarting to apply world state"}
        raise HTTPException(status_code=404, detail="Save not found")
    except Exception as e:
        print(f"Error in /world/load: {e}")
        import traceback
        traceback.print_exc()
        if isinstance(e, HTTPException):
            raise e
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/world/saves")
async def get_saves():
    """List all available save slots."""
    try:
        return {"saves": list_saves()}
    except Exception as e:
        print(f"Error in /world/saves: {e}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/world/media")
async def get_media(entity_id: str, media_type: str, prompt: str, model_id: Optional[str] = None):
    """Fetches cached media or generates new media for an entity asynchronously with fallback logic."""
    if model_id and model_id.lower() == "disabled":
        raise HTTPException(status_code=400, detail="Media generation is disabled")
    save_name = persistence_manager.current_save_name
    cached_content = get_cached_media(save_name, entity_id, media_type)
    if cached_content:
        mime = "image/webp" if media_type == "image" else "audio/mpeg"
        return Response(content=cached_content, media_type=mime)
    # Preparation for generation: requested model first, then the default
    # list for the media type, de-duplicated while preserving order.
    models_to_try = []
    if model_id:
        models_to_try.append(model_id)
    base_list = IMAGE_MODELS if media_type == "image" else AUDIO_MODELS
    for m in base_list:
        if m not in models_to_try:
            models_to_try.append(m)
    headers = {"Authorization": f"Bearer {os.getenv('HF_TOKEN')}"}
    async with httpx.AsyncClient(timeout=60.0) as client:
        last_error = "Unknown error"
        for mid in models_to_try:
            try:
                api_url = f"https://api-inference.huggingface.co/models/{mid}"
                response = await client.post(api_url, headers=headers, json={"inputs": prompt})
                if response.status_code == 200:
                    save_cached_media(save_name, entity_id, media_type, response.content)
                    mime = "image/webp" if media_type == "image" else "audio/mpeg"
                    return Response(content=response.content, media_type=mime)
                else:
                    last_error = f"Model {mid} failed: {response.status_code} {response.text}"
                    continue
            except Exception as e:
                last_error = str(e)
                continue
    raise HTTPException(status_code=500, detail=f"All generation attempts failed. Last error: {last_error}")


@app.get("/world/npc/{npc_id}")
async def get_npc(npc_id: str):
    """Fetch a Player or NPC node by id or (legacy) by name."""
    # Support retrieving both Players and NPCs, and match by id or name for legacy compatibility
    # RETURN properties(n) ensures we get a dictionary in both mock and live DB
    query = "MATCH (n) WHERE (n:NPC OR n:Player) AND (n.id = $id OR n.name = $id) RETURN properties(n)"
    result = db_client.query(query, {"id": npc_id})
    if not result or len(result) == 0:
        raise HTTPException(status_code=404, detail="Entity not found")
    # result[0][0] is the node properties dictionary
    entity = result[0][0]
    return {"npc": entity}


@app.get("/world/map")
async def get_map():
    """Return map entities (Player and Location nodes) with coordinates."""
    # Only return Player and Location nodes. NPCs are tracked but not displayed.
    query = """
    MATCH (n) WHERE n:Player OR n:Location
    RETURN labels(n)[0], n.id, n.name, n.x, n.y
    """
    print(f"Executing map query: {query}")
    result = db_client.query(query)
    print(f"Map query result: {result}")
    entities = []
    if result:
        for row in result:
            entities.append({
                "type": row[0],
                "id": row[1] or row[2],  # Fallback to name if id is missing
                "name": row[2],
                "x": int(row[3]) if row[3] is not None else 0,
                "y": int(row[4]) if row[4] is not None else 0
            })
    return {"entities": entities}


def robust_json_extract(text: str):
    """Extracts JSON and returns (draft_dict, clean_text)."""
    draft = None
    clean_text = text
    # Try markdown block first
    json_match = re.search(r"(```(?:json)?\s*(.*?)\s*```)", text, re.DOTALL)
    if json_match:
        full_block = json_match.group(1)
        content = json_match.group(2).strip()
        clean_text = text.replace(full_block, "").strip()
        try:
            draft = json.loads(content)
        except Exception:
            try:
                # Retry after stripping trailing commas, a common LLM artifact.
                content = re.sub(r",\s*([\]}])", r"\1", content)
                draft = json.loads(content)
            except Exception:
                pass
    if draft:
        return draft, clean_text
    # Try searching for any { ... } block
    brace_match = re.search(r"({.*})", text, re.DOTALL)
    if brace_match:
        full_match = brace_match.group(1)
        try:
            draft = json.loads(full_match.strip())
            clean_text = text.replace(full_match, "").strip()
            return draft, clean_text
        except Exception:
            pass
    return None, text


@app.post("/world/init")
async def init_world(request: WorldInitRequest):
    """Run one character-creation step and extract any draft JSON produced."""
    try:
        print(f"Character creation step for: {request.prompt}")
        user_input = request.prompt or "Hello"
        history = request.history or []
        response_text = await asyncio.to_thread(run_recruiter, user_input, history, request.reasoning_model)
        draft, clean_text = robust_json_extract(response_text)
        return {
            "response": clean_text,
            "draft": draft
        }
    except Exception as e:
        print(f"Error in /world/init: {e}")
        if "BadRequestError" in str(type(e)):
            try:
                # Attempt to log response content if it exists
                if hasattr(e, 'response') and hasattr(e.response, 'text'):
                    print(f"Bad Request Response Body: {e.response.text}")
            except Exception:
                pass
        import traceback
        traceback.print_exc()
        if isinstance(e, HTTPException):
            raise e
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/world/confirm")
async def confirm_world(request: ConfirmRequest, background_tasks: BackgroundTasks):
    """Register a confirmed character draft and immediately save the world."""
    try:
        print(f"Confirming character: {request.draft.name}")
        registrar = create_registrar(request.reasoning_model)
        if not registrar:
            raise HTTPException(status_code=500, detail="Failed to initialize Registrar agent.")
        # Registrar handles DB creation
        # Use by_alias=True to ensure 'class' is used in the JSON passed to the registrar
        await asyncio.to_thread(registrar.run, f"Validated Character Draft: {request.draft.model_dump_json(by_alias=True)}")
        # Use character name as the save slot name
        save_name = request.draft.name
        persistence_manager.current_save_name = save_name
        # Character confirmed - immediate save
        await perform_save(save_name, request.description)
        return {"status": "initialized", "save_name": save_name}
    except Exception as e:
        print(f"Error in /world/confirm: {e}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/world/bootstrap")
async def bootstrap_world(background_tasks: BackgroundTasks):
    """Trigger rule bootstrapping in the background."""
    background_tasks.add_task(perform_bootstrap)
    return {"status": "bootstrap_initiated"}


async def perform_bootstrap():
    """Load SRD rules from disk and index any missing ones in the vector DB."""
    import json
    rules_path = "dnd_srd_rules.json"
    if not os.path.exists(rules_path):
        print(f"Rules file not found at {rules_path}")
        return
    with open(rules_path, "r") as f:
        rules = json.load(f)
    # Attempt to create indices, ignore if they already exist
    # Using the new syntax for FalkorDB 4.x+
    try:
        db_client.query("CREATE VECTOR INDEX FOR (r:Rule) ON (r.embedding) OPTIONS {dimension: 384, similarityFunction: 'cosine'}")
    except Exception as e:
        if "already exists" not in str(e).lower():
            print(f"Notice creating Rule index: {e}")
    try:
        db_client.query("CREATE VECTOR INDEX FOR (m:Memory) ON (m.embedding) OPTIONS {dimension: 384, similarityFunction: 'cosine'}")
    except Exception as e:
        if "already exists" not in str(e).lower():
            print(f"Notice creating Memory index: {e}")
    # Fetch all existing rule titles once to optimize bootstrapping
    existing_rules = db_client.query("MATCH (r:Rule) RETURN r.title")
    existing_titles = set()
    if existing_rules:
        for row in existing_rules:
            if isinstance(row, list) and len(row) > 0:
                existing_titles.add(row[0])
    created = 0
    for rule in rules:
        title = rule.get("title")
        if title in existing_titles:
            continue
        content = rule.get("content")
        tags = rule.get("tags", [])
        embedding = vector_model.encode(content)
        query = """
        CREATE (r:Rule {title: $title, content: $content, tags: $tags, embedding: vecf32($embedding)})
        """
        db_client.query(query, {
            "title": title,
            "content": content,
            "tags": tags,
            "embedding": embedding
        })
        created += 1
    # Report the number actually inserted, not the size of the source file.
    print(f"Successfully bootstrapped {created} of {len(rules)} rules.")


async def restart_server():
    """Exit the process so the supervisor restarts it with the new RDB file."""
    print("Initiating server restart in 2 seconds...")
    await asyncio.sleep(2)
    # NOTE(review): sys.exit raises SystemExit from within the event loop;
    # this relies on the process supervisor restarting the worker — confirm
    # that the deployment environment auto-restarts on exit.
    sys.exit(0)