| import requests |
| import json |
| import os |
| from collections import Counter |
| from langgraph.graph import StateGraph, END |
| from typing import TypedDict, Annotated |
| import operator |
| from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, ToolMessage |
| from langchain_community.tools.tavily_search import TavilySearchResults |
| from langchain.chat_models import init_chat_model |
| from langchain.schema import HumanMessage |
| from langchain.tools import tool |
| from langchain_core.messages import convert_to_messages |
| from langgraph.prebuilt import create_react_agent |
| from langchain_core.tools import tool |
| from typing import List |
| from langgraph_supervisor import create_supervisor |
| from langchain.chat_models import init_chat_model |
| import re |
| import gradio as gr |
| from PIL import Image |
| import shutil |
|
|
# API keys read from the environment; each is None when the variable is unset.
# They are not referenced by the visible code — the chat-model integrations
# presumably read MISTRAL_API_KEY / GOOGLE_API_KEY from the environment
# themselves (TODO confirm).
MISTRAL_API_KEY, GOOGLE_API_KEY = (
    os.environ.get(env_name) for env_name in ("MISTRAL_API_KEY", "GOOGLE_API_KEY")
)
|
|
def pretty_print_message(message, indent=False):
    """Print the pretty (HTML-flavoured) representation of a message.

    Args:
        message: any object exposing ``pretty_repr(html=...)``, e.g. a
            LangChain message.
        indent: when true, every output line is prefixed with a tab so that
            messages coming from a subgraph appear nested.
    """
    text = message.pretty_repr(html=True)
    if indent:
        text = "\n".join(f"\t{line}" for line in text.split("\n"))
    print(text)
|
|
|
|
def pretty_print_messages(update, last_message=False):
    """Print each node's messages from one LangGraph stream chunk.

    Args:
        update: either a ``{node_name: {"messages": [...]}}`` mapping, or a
            ``(namespace, mapping)`` tuple. For tuples, an empty namespace
            means the update came from the parent graph and nothing is
            printed; otherwise a subgraph header is printed and the node
            output is indented with a tab.
        last_message: when true, only the final message of each node is shown.
    """
    nested = isinstance(update, tuple)
    if nested:
        namespace, update = update
        if not namespace:
            return
        subgraph_id = namespace[-1].split(":")[0]
        print(f"Update from subgraph {subgraph_id}:")
        print("\n")

    prefix = "\t" if nested else ""
    for node_name, node_update in update.items():
        print(f"{prefix}Update from node {node_name}:")
        print("\n")

        node_messages = convert_to_messages(node_update["messages"])
        shown = node_messages[-1:] if last_message else node_messages
        for msg in shown:
            pretty_print_message(msg, indent=nested)
        print("\n")
|
|
@tool
def parse_design_intent(prompt: str) -> dict:
    """
    Parses a level design prompt into a structured configuration.

    Recognises the "spooky" theme and the number of enemies and traps, written
    either as digits or as English number words (zero to ten), singular or
    plural, case-insensitively.
    E.g. "a spooky forest with 2 enemies and 1 trap" or
    "a Spooky maze with one enemy and three traps".

    Returns a dict with keys "theme" (str), "enemies" (int) and "traps" (int);
    anything not mentioned keeps its default ("generic", 0, 0).
    """
    import re

    number_words = {
        "zero": 0, "one": 1, "two": 2, "three": 3, "four": 4, "five": 5,
        "six": 6, "seven": 7, "eight": 8, "nine": 9, "ten": 10,
    }
    # A count is a run of digits or a whole number word, optionally followed
    # by whitespace (so "2enemies" still matches, as before).
    count_pattern = r"(\d+|\b(?:" + "|".join(number_words) + r")\b)\s*"

    def extract_count(noun_pattern):
        """Return the first count written before *noun_pattern*, or None."""
        match = re.search(count_pattern + noun_pattern, prompt, flags=re.IGNORECASE)
        if match is None:
            return None
        token = match.group(1).lower()
        return int(token) if token.isdigit() else number_words[token]

    config = {"theme": "generic", "enemies": 0, "traps": 0}

    if "spooky" in prompt.lower():
        config["theme"] = "spooky"

    # "enem(?:y|ies)" accepts the singular "enemy"; the previous "enemies?"
    # only matched "enemie"/"enemies", so "1 enemy" was silently parsed as 0.
    enemies = extract_count(r"enem(?:y|ies)")
    traps = extract_count(r"traps?")

    if enemies is not None:
        config["enemies"] = enemies
    if traps is not None:
        config["traps"] = traps

    return config
|
|
# Sub-agent (model "mistral-large-latest") that turns a free-form level prompt
# into a config dict using the parse_design_intent tool.
design_intent_agent = create_react_agent(
    name="design_intent_agent",
    model="mistral-large-latest",
    tools=[parse_design_intent],
    prompt="\n".join(
        (
            "You are a design intent parsing agent.",
            "- Take user level prompts and extract theme, number of enemies, traps.",
            "- Return only the structured config dictionary.",
            "- Do not include any extra explanation.",
        )
    ),
)
|
|
@tool
def generate_level_layout(config: dict) -> list:
    """
    Generates a maze-like 2D grid based on the config.
    Tiles: 'P'=Path, 'W'=Wall, 'E'=Enemy, 'T'=Trap, 'S'=Start, 'G'=Goal

    Config keys (all optional):
    - "enemies": number of enemies to place (default 2)
    - "traps": number of traps to place (default 1)
    - "size": side length of the square grid (default 11; even values are
      rounded up to the next odd number, and the result is clamped to 5..51)
    - "seed": int or str seed for a reproducible layout (default: random)

    Start is at (1, 1) and the goal at (size-2, size-2). An enemy or trap is
    only kept where a start-to-goal path avoiding all enemies and traps still
    exists, so fewer items than requested are placed when there is no room.
    Returns the grid as a list of rows, each a list of one-character tiles.
    """
    import random
    from collections import deque

    config = config or {}

    def as_count(value, default):
        # The LLM may send counts as strings ("2") or junk; coerce to a
        # non-negative int and fall back to the default when impossible.
        try:
            return max(0, int(value))
        except (TypeError, ValueError, OverflowError):
            return default

    # Odd size so carving on odd coordinates reaches the (size-2, size-2) goal.
    size = min(51, max(5, as_count(config.get("size", 11), 11) | 1))

    seed = config.get("seed")
    rng = random.Random(seed if isinstance(seed, (int, str)) else None)

    start, goal = (1, 1), (size - 2, size - 2)
    grid = [["W"] * size for _ in range(size)]

    def carve(x, y):
        """Randomised depth-first maze carving with an explicit stack.

        Odd coordinates are cells; the tile between two cells is the wall
        knocked out to connect them. Iterative so larger grids cannot hit
        Python's recursion limit.
        """
        grid[x][y] = "P"
        stack = [(x, y)]
        while stack:
            cx, cy = stack[-1]
            options = [
                (dx, dy)
                for dx, dy in ((0, 2), (0, -2), (2, 0), (-2, 0))
                if 1 <= cx + dx < size - 1
                and 1 <= cy + dy < size - 1
                and grid[cx + dx][cy + dy] == "W"
            ]
            if not options:
                stack.pop()
                continue
            dx, dy = rng.choice(options)
            grid[cx + dx // 2][cy + dy // 2] = "P"
            grid[cx + dx][cy + dy] = "P"
            stack.append((cx + dx, cy + dy))

    def is_solvable(g):
        """BFS from start to goal over 'P'/'G' tiles; walls, enemies and traps block."""
        seen = {start}
        queue = deque([start])
        while queue:
            x, y = queue.popleft()
            if (x, y) == goal:
                return True
            for dx, dy in ((-1, 0), (1, 0), (0, -1), (0, 1)):
                nx, ny = x + dx, y + dy
                if (
                    0 <= nx < size
                    and 0 <= ny < size
                    and (nx, ny) not in seen
                    and g[nx][ny] in {"P", "G"}
                ):
                    seen.add((nx, ny))
                    queue.append((nx, ny))
        return False

    def place_items_safely(base):
        """Return a copy of *base* with enemies and traps added on path tiles."""
        temp_grid = [row[:] for row in base]

        def place_item(symbol, count):
            # Try each free path tile once, in random order. Blind random
            # (row, col) sampling wasted attempts on walls and could give up
            # while valid tiles remained.
            candidates = [
                (x, y) for x in range(size) for y in range(size) if temp_grid[x][y] == "P"
            ]
            rng.shuffle(candidates)
            placed = 0
            for x, y in candidates:
                if placed >= count:
                    break
                temp_grid[x][y] = symbol
                if is_solvable(temp_grid):
                    placed += 1
                else:
                    temp_grid[x][y] = "P"

        place_item("E", as_count(config.get("enemies", 2), 2))
        place_item("T", as_count(config.get("traps", 1), 1))
        return temp_grid

    carve(*start)
    grid[start[0]][start[1]] = "S"
    grid[goal[0]][goal[1]] = "G"

    return place_items_safely(grid)
|
|
|
|
# Sub-agent (model "google_genai:gemini-2.0-flash") that builds the maze grid
# from a level config using the generate_level_layout tool.
layout_agent = create_react_agent(
    name="layout_agent",
    model="google_genai:gemini-2.0-flash",
    tools=[generate_level_layout],
    prompt="\n".join(
        (
            "You are a layout generation agent.",
            "- Take level config and return a valid grid layout (as a list of lists).",
            "- Place start at (1,1), goal at bottom right, enemies and traps randomly.",
            "- No additional explanation.",
        )
    ),
)
|
|
@tool
def evaluate_level(grid: list) -> dict:
    """
    Checks if a path exists from start to goal (simple BFS) and returns a
    basic difficulty estimate.

    The search starts at the 'S' tile, or at (1, 1) when the grid has no 'S'.
    Every tile except 'W' is walkable, and the search succeeds on reaching any
    'G' tile. Rows may have different lengths. An empty grid is reported as
    not playable.

    Difficulty uses the total number of enemies ('E') and traps ('T'):
    more than 4 -> "Hard", more than 2 -> "Medium", otherwise "Easy".

    Returns a dict with keys "playable", "difficulty", "enemy_count" and
    "trap_count".
    """
    from collections import deque

    def in_bounds(r, c):
        # Per-row bounds so ragged grids don't raise IndexError.
        return 0 <= r < len(grid) and 0 <= c < len(grid[r])

    def find_start():
        # Prefer the actual 'S' tile. (1, 1) is where generate_level_layout
        # places it and stays the fallback for grids without an 'S'.
        for r, row in enumerate(grid):
            for c, cell in enumerate(row):
                if cell == "S":
                    return r, c
        return 1, 1

    def bfs():
        start = find_start()
        if not in_bounds(*start):
            return False
        visited = {start}
        queue = deque([start])
        while queue:
            r, c = queue.popleft()
            if grid[r][c] == "G":
                return True
            for dr, dc in ((1, 0), (-1, 0), (0, 1), (0, -1)):
                nr, nc = r + dr, c + dc
                if in_bounds(nr, nc) and (nr, nc) not in visited and grid[nr][nc] != "W":
                    visited.add((nr, nc))
                    queue.append((nr, nc))
        return False

    enemies = sum(row.count("E") for row in grid)
    traps = sum(row.count("T") for row in grid)
    playable = bfs()

    difficulty = "Easy"
    if enemies + traps > 4:
        difficulty = "Hard"
    elif enemies + traps > 2:
        difficulty = "Medium"

    return {
        "playable": playable,
        "difficulty": difficulty,
        "enemy_count": enemies,
        "trap_count": traps,
    }
|
|
# Sub-agent (model "mistral-large-latest") that checks playability and
# estimates difficulty of a grid using the evaluate_level tool.
test_agent = create_react_agent(
    name="test_agent",
    model="mistral-large-latest",
    tools=[evaluate_level],
    prompt="\n".join(
        (
            "You are a level test agent.",
            "- Evaluate if the level has a valid path from Start to Goal.",
            "- Count enemies and traps.",
            "- Return a dictionary with playability and difficulty.",
            "- No explanation.",
        )
    ),
)
|
|
@tool
def render_grid_image(grid: List[List[str]]) -> str:
    """
    Renders a visually appealing grid image and returns the file path (served via Gradio).

    Tiles: 'P'=Path, 'W'=Wall, 'E'=Enemy, 'T'=Trap, 'S'=Start, 'G'=Goal
    (case-insensitive; surrounding whitespace is ignored). The image is always
    written to /tmp/level.png, overwriting the previous render.

    Raises:
        ValueError: if the grid is empty, its rows differ in length, or it
            contains an unknown tile.
    """
    import numpy as np
    from matplotlib.colors import ListedColormap
    from matplotlib.figure import Figure

    tile_map = {"P": 0, "W": 1, "E": 2, "T": 3, "S": 4, "G": 5}
    colors = [
        "#e0e0e0",  # P path
        "#2e2e2e",  # W wall
        "#d32f2f",  # E enemy
        "#fbc02d",  # T trap
        "#1976d2",  # S start
        "#388e3c",  # G goal
    ]
    # Plain letters: emoji such as the bomb are missing from matplotlib's
    # default font and were drawn as empty boxes.
    labels = {"S": "S", "G": "G", "E": "E", "T": "T"}

    tiles = [[str(cell).strip().upper() for cell in row] for row in grid]
    if not tiles or not tiles[0]:
        raise ValueError("grid must contain at least one non-empty row")
    n_rows, n_cols = len(tiles), len(tiles[0])
    if any(len(row) != n_cols for row in tiles):
        raise ValueError("all grid rows must have the same length")
    unknown = {cell for row in tiles for cell in row} - set(tile_map)
    if unknown:
        raise ValueError(
            f"unknown tile(s) {sorted(unknown)}; expected one of {sorted(tile_map)}"
        )

    numeric_grid = np.array([[tile_map[cell] for cell in row] for row in tiles])
    image_path = "/tmp/level.png"

    # Object-oriented Figure instead of pyplot: no shared global figure state
    # between concurrent Gradio requests, and no GUI backend is required.
    fig = Figure(figsize=(6, 6))
    ax = fig.subplots()
    # Pin the colour range to the tile codes. Without vmin/vmax, imshow
    # normalises to the data's own min/max, so a grid lacking e.g. 'G' would
    # shift every colour.
    ax.imshow(
        numeric_grid,
        cmap=ListedColormap(colors),
        vmin=-0.5,
        vmax=len(colors) - 0.5,
        interpolation="none",
    )

    # Minor ticks on cell boundaries draw the grid lines.
    # The x axis spans columns and the y axis spans rows.
    ax.set_xticks(np.arange(-0.5, n_cols, 1), minor=True)
    ax.set_yticks(np.arange(-0.5, n_rows, 1), minor=True)
    ax.grid(which="minor", color="black", linewidth=0.5)
    ax.tick_params(which="both", bottom=False, left=False, labelbottom=False, labelleft=False)

    for i, row in enumerate(tiles):
        for j, cell in enumerate(row):
            if cell in labels:
                ax.text(j, i, labels[cell], ha="center", va="center", fontsize=12, color="black")

    fig.tight_layout()
    fig.savefig(image_path, bbox_inches="tight", dpi=150)

    return image_path
|
|
|
|
|
|
# Sub-agent (model "google_genai:gemini-2.0-flash") that renders a grid to an
# image with render_grid_image and replies with just the file path.
render_image_agent = create_react_agent(
    name="render_image_agent",
    model="google_genai:gemini-2.0-flash",
    tools=[render_grid_image],
    prompt="\n".join(
        (
            "You are a rendering agent.",
            "- Only use the render_grid_image tool to render the grid.",
            "- When responding, ONLY return the raw file path returned by the tool.",
            "- Do NOT add any commentary or formatting.",
            "- Example: /tmp/level.png",
        )
    ),
)
|
|
# Supervisor graph that routes work to the four sub-agents one at a time.
# add_handoff_back_messages / output_mode="full_history" keep the sub-agents'
# messages in the shared history that run_streamed_supervisor later scans.
supervisor = create_supervisor(
    model=init_chat_model("google_genai:gemini-2.0-flash-lite"),
    agents=[design_intent_agent, layout_agent, test_agent, render_image_agent],
    prompt=(
        # Keep the agent count and list below in sync with `agents`. The
        # prompt previously said "three" while listing and passing four.
        "You are a supervisor managing four agents:\n"
        "- a design intent agent. Assign design intent-related tasks to this agent\n"
        "- a layout agent. Assign layout related tasks to this agent\n"
        "- a test agent. Assign test related tasks to this agent\n"
        "- a render image agent. Assign render image tasks to this agent\n"
        "Assign work to one agent at a time, do not call agents in parallel.\n"
        "Do not do any work yourself."
    ),
    add_handoff_back_messages=True,
    output_mode="full_history",
).compile()
|
|
def run_streamed_supervisor(user_input):
    """
    Run the multi-agent supervisor on a user instruction and return the maze image path.

    The graph is streamed in "values" mode and the last full state is kept.
    Its messages are then scanned newest-first for the renderer's output
    path (/tmp/level.png).

    Args:
        user_input: the instruction typed into the Gradio textbox.

    Returns:
        The image file path, suitable for a ``gr.Image(type="filepath")`` output.

    Raises:
        gr.Error: if the input is blank, or no rendered image was found.
            Gradio shows the message in the UI instead of trying to load the
            error text as an image file.
    """
    if not user_input or not user_input.strip():
        raise gr.Error("Please enter an instruction describing the maze.")

    def message_text(message):
        # Message content is usually a str but may be a list of content
        # blocks (strings or dicts with a "text" key); flatten it to text.
        content = getattr(message, "content", None)
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            return " ".join(
                part if isinstance(part, str) else str(part.get("text", ""))
                for part in content
                if isinstance(part, (str, dict))
            )
        return ""

    # stream_mode="values" yields the whole graph state after each step, so the
    # last item holds the complete message history. This avoids the previous
    # reliance on the final "updates" chunk being keyed "supervisor", and the
    # NameError when the stream yielded nothing.
    final_state = None
    for state in supervisor.stream(
        {"messages": [{"role": "user", "content": user_input}]},
        stream_mode="values",
    ):
        final_state = state

    final_messages = final_state.get("messages", []) if final_state else []

    image_path = None
    for message in reversed(final_messages):
        match = re.search(r"(/tmp/level\.png)", message_text(message))
        if match:
            image_path = match.group(1)
            break

    if image_path and os.path.exists(image_path):
        return image_path

    raise gr.Error("❌ Image path not found or file does not exist.")
|
|
def clear_fields():
    """Reset the UI: an empty instruction textbox and no image.

    Returns:
        A ``(textbox_value, image_value)`` pair of ``("", None)``.
    """
    blank_instruction, no_image = "", None
    return blank_instruction, no_image
|
|
# (swatch colour, label) for each tile type, in legend order. The colours
# mirror the tile palette used by render_grid_image.
_LEGEND_ITEMS = (
    ("#e0e0e0", "Path"),
    ("#2e2e2e", "Wall"),
    ("#d32f2f", "Enemy"),
    ("#fbc02d", "Trap"),
    ("#1976d2", "Start"),
    ("#388e3c", "Goal"),
)

# HTML legend: one swatch + label per tile type, generated from the table
# above. Every swatch has a border so the dark "Wall" swatch stays visible on
# dark themes. All labels inherit the theme's text colour; previously only
# "Wall" hard-coded #2e2e2e, which is unreadable on a dark background.
legend_html = (
    '<div style="display: flex; flex-wrap: wrap; gap: 10px; margin-top: 10px;">\n'
    + "".join(
        '  <div style="display: flex; align-items: center;">\n'
        f'    <div style="width: 20px; height: 20px; background: {color}; '
        'border: 1px solid #ccc;"></div>\n'
        f'    <span style="margin-left: 5px;">{label}</span>\n'
        "  </div>\n"
        for color, label in _LEGEND_ITEMS
    )
    + "</div>\n"
)
|
|
| |
# Gradio UI: instruction box and generated-maze image, generate/clear buttons,
# the tile legend, and a gallery of reference mazes.
with gr.Blocks() as demo:
    gr.Markdown("## Maze Generator (LangGraph Multi-Agent Supervisor)")

    with gr.Row():
        inp = gr.Textbox(
            label="Instruction (e.g., 'design, test and render a spooky maze level with 1 enemy and 2 traps')"
        )
        out = gr.Image(label="Generated Maze", type="filepath")

    with gr.Row():
        btn = gr.Button("🎮 Generate Maze")
        clr = gr.Button("🧹 Clear")

    btn.click(fn=run_streamed_supervisor, inputs=inp, outputs=out)
    clr.click(fn=clear_fields, inputs=None, outputs=[inp, out])

    gr.Markdown("### 🗺️ Maze Legend")
    gr.HTML(legend_html)

    # Reference images are resolved relative to the working directory. Only
    # files that exist are handed to the gallery, so a missing
    # generated_mazes/ folder doesn't break app start-up.
    # NOTE(review): confirm how your Gradio version reacts to missing files.
    reference_mazes = [
        path
        for path in (
            "generated_mazes/maze_design_1.png",
            "generated_mazes/maze_design_2.png",
            "generated_mazes/maze_design_3.png",
        )
        if os.path.isfile(path)
    ]
    if reference_mazes:
        gr.Markdown("### 📸 Reference Mazes generated by AI agent")
        gr.Gallery(
            reference_mazes,
            label="Reference Mazes",
            columns=3,
            height="auto",
        )


# Launch only when executed as a script: importing this module must not start
# a server or open a public share link.
if __name__ == "__main__":
    demo.launch(share=True)
|
|
|
|