# Source: uploaded by nharshavardhana
# Commit: dfadb93
import requests
import json
import os
from collections import Counter
from langgraph.graph import StateGraph, END
from typing import TypedDict, Annotated
import operator
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, ToolMessage
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain.chat_models import init_chat_model
from langchain.schema import HumanMessage
from langchain.tools import tool
from langchain_core.messages import convert_to_messages
from langgraph.prebuilt import create_react_agent
from langchain_core.tools import tool
from typing import List
from langgraph_supervisor import create_supervisor
from langchain.chat_models import init_chat_model
import re
import gradio as gr
from PIL import Image
import shutil
# API keys read from the environment; os.getenv yields None when a variable is unset.
# NOTE(review): neither constant is referenced again in the visible part of this
# file - presumably the model clients read the same environment variables
# themselves; confirm before removing.
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
def pretty_print_message(message, indent=False):
    """Print one message using its ``pretty_repr(html=True)`` rendering.

    Args:
        message: Any object exposing ``pretty_repr(html=...)``.
        indent: When truthy, every line of the rendering is prefixed with a
            tab character before it is printed.
    """
    rendered = message.pretty_repr(html=True)
    if indent:
        # Prefix line by line so multi-line renderings stay aligned.
        rendered = "\n".join(f"\t{line}" for line in rendered.split("\n"))
    print(rendered)
def pretty_print_messages(update, last_message=False):
    """Print the messages contained in one streamed graph update.

    Args:
        update: Either a mapping ``{node_name: {"messages": [...]}}`` or a
            ``(namespace, mapping)`` tuple. A tuple is treated as a subgraph
            update; its output is tab-indented.
        last_message: When truthy, only the final message of each node is
            printed.
    """
    from_subgraph = isinstance(update, tuple)
    if from_subgraph:
        namespace, update = update
        # skip parent graph updates in the printouts
        if len(namespace) == 0:
            return

        # The subgraph id is the part of the last namespace entry before ":".
        subgraph_id = namespace[-1].split(":")[0]
        print(f"Update from subgraph {subgraph_id}:")
        print("\n")

    label_prefix = "\t" if from_subgraph else ""
    for node_name, node_update in update.items():
        print(label_prefix + f"Update from node {node_name}:")
        print("\n")

        node_messages = convert_to_messages(node_update["messages"])
        if last_message:
            node_messages = node_messages[-1:]

        for msg in node_messages:
            pretty_print_message(msg, indent=from_subgraph)
        print("\n")
@tool
def parse_design_intent(prompt: str) -> dict:
    """
    Parses level design prompt into structured configuration.
    E.g. "a spooky forest with 2 enemies and 1 trap"
    """
    # NOTE(review): the docstring above is kept verbatim because the @tool
    # decorator presumably exposes it to the model as the tool description.
    #
    # Returns a dict with keys "theme" ("spooky" or "generic"), "enemies" and
    # "traps" (ints, 0 when the prompt gives no count).
    text = prompt.lower()  # match theme and counts case-insensitively
    config = {"theme": "generic", "enemies": 0, "traps": 0}
    if "spooky" in text:
        config["theme"] = "spooky"

    # Accept the singular "enemy" as well as "enemie"/"enemies"; the previous
    # pattern `enemies?` silently ignored prompts such as "1 enemy".
    enemies = re.search(r"(\d+)\s*enem(?:ies?|y)", text)
    traps = re.search(r"(\d+)\s*traps?", text)
    if enemies:
        config["enemies"] = int(enemies.group(1))
    if traps:
        config["traps"] = int(traps.group(1))
    return config
# Agent built with create_react_agent whose only tool is parse_design_intent.
# Its prompt asks for the extracted theme / enemy count / trap count and
# nothing else.
# NOTE(review): unlike the Gemini agents below, the model string carries no
# "provider:" prefix - presumably the provider is inferred from the model
# name; confirm against the installed langchain version.
design_intent_agent = create_react_agent(
    model="mistral-large-latest",
    tools=[parse_design_intent],
    prompt=(
        "You are a design intent parsing agent.\n"
        "- Take user level prompts and extract theme, number of enemies, traps.\n"
        "- Return only the structured config dictionary.\n"
        "- Do not include any extra explanation."
    ),
    name="design_intent_agent"
)
@tool
def generate_level_layout(config: dict) -> list:
    """
    Generates a maze-like 2D grid based on the config.
    Tiles: 'P'=Path, 'W'=Wall, 'E'=Enemy, 'T'=Trap, 'S'=Start, 'G'=Goal
    """
    # NOTE(review): the docstring above is kept verbatim because the @tool
    # decorator presumably exposes it to the model as the tool description.
    #
    # config keys read here: "enemies" (default 2) and "traps" (default 1).
    # Returns an 11x11 list of lists of single-character tile codes.
    from collections import deque
    from random import randint, shuffle

    size = 11  # Must be odd for proper maze
    start = (1, 1)
    goal = (size - 2, size - 2)
    grid = [["W" for _ in range(size)] for _ in range(size)]

    def carve(x, y):
        """Randomised depth-first carve: open corridors two cells at a time."""
        dirs = [(0, 2), (0, -2), (2, 0), (-2, 0)]
        shuffle(dirs)
        for dx, dy in dirs:
            nx, ny = x + dx, y + dy
            if 1 <= nx < size - 1 and 1 <= ny < size - 1 and grid[nx][ny] == "W":
                # Open the wall between the current cell and the target cell,
                # then the target cell itself, and continue from there.
                grid[x + dx // 2][y + dy // 2] = "P"
                grid[nx][ny] = "P"
                carve(nx, ny)

    def is_solvable(g):
        """Breadth-first search from start to goal over 'P'/'G' tiles only."""
        visited = {start}
        queue = deque([start])
        while queue:
            x, y = queue.popleft()
            if (x, y) == goal:
                return True
            for dx, dy in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
                nx, ny = x + dx, y + dy
                if (
                    0 <= nx < size
                    and 0 <= ny < size
                    and g[nx][ny] in {"P", "G"}
                    and (nx, ny) not in visited
                ):
                    visited.add((nx, ny))
                    queue.append((nx, ny))
        return False

    def place_items_safely(base):
        """Return a copy of *base* with enemies and traps placed on path tiles."""
        temp_grid = [row[:] for row in base]

        def place_item(symbol, count):
            placed = 0
            tries = 0
            # Bounded retries: fewer than *count* items may end up placed.
            while placed < count and tries < 100:
                x, y = randint(1, size - 2), randint(1, size - 2)
                if temp_grid[x][y] == "P":
                    temp_grid[x][y] = symbol
                    if is_solvable(temp_grid):
                        placed += 1
                    else:
                        # The item blocked the route to the goal; undo it.
                        temp_grid[x][y] = "P"
                tries += 1

        place_item("E", config.get("enemies", 2))
        place_item("T", config.get("traps", 1))
        return temp_grid

    # Generate maze
    grid[start[0]][start[1]] = "P"
    carve(*start)
    grid[start[0]][start[1]] = "S"
    grid[goal[0]][goal[1]] = "G"

    # Safely place enemies and traps
    return place_items_safely(grid)
# Agent built with create_react_agent whose only tool is generate_level_layout.
# Its prompt asks for a grid (list of lists) with the start at (1,1), the goal
# at the bottom right, and randomly placed enemies and traps.
layout_agent = create_react_agent(
    model="google_genai:gemini-2.0-flash",
    tools=[generate_level_layout],
    prompt=(
        "You are a layout generation agent.\n"
        "- Take level config and return a valid grid layout (as a list of lists).\n"
        "- Place start at (1,1), goal at bottom right, enemies and traps randomly.\n"
        "- No additional explanation."
    ),
    name="layout_agent"
)
@tool
def evaluate_level(grid: list) -> dict:
    """
    Checks if path exists from start to goal (very simple BFS),
    and returns a basic difficulty estimate.
    """
    # NOTE(review): the docstring above is kept verbatim because the @tool
    # decorator presumably exposes it to the model as the tool description.
    #
    # Returns a dict with keys "playable" (bool), "difficulty"
    # ("Easy" / "Medium" / "Hard"), "enemy_count" and "trap_count".
    from collections import deque

    def find_start(level):
        """Locate the 'S' tile; fall back to (1, 1) when no tile is marked."""
        for r, row in enumerate(level):
            for c, cell in enumerate(row):
                if cell == "S":
                    return (r, c)
        # Legacy behaviour: grids without an explicit start begin at (1, 1),
        # provided that cell exists.
        if len(level) > 1 and len(level[1]) > 1:
            return (1, 1)
        return None

    def bfs(level):
        """Return True when a 'G' tile is reachable without crossing 'W'."""
        origin = find_start(level)
        if origin is None:
            # Empty or too-small grid: nothing to walk, so not playable.
            return False
        visited = {origin}
        queue = deque([origin])
        while queue:
            x, y = queue.popleft()
            if level[x][y] == "G":
                return True
            for dx, dy in [(1, 0), (-1, 0), (0, 1), (0, -1)]:
                nx, ny = x + dx, y + dy
                # Bounds are checked per row so ragged grids cannot raise.
                if not (0 <= nx < len(level) and 0 <= ny < len(level[nx])):
                    continue
                if level[nx][ny] == "W" or (nx, ny) in visited:
                    continue
                visited.add((nx, ny))
                queue.append((nx, ny))
        return False

    enemies = sum(row.count("E") for row in grid)
    traps = sum(row.count("T") for row in grid)
    playable = bfs(grid)

    # Difficulty depends only on the total number of hazards.
    hazards = enemies + traps
    if hazards > 4:
        difficulty = "Hard"
    elif hazards > 2:
        difficulty = "Medium"
    else:
        difficulty = "Easy"

    return {
        "playable": playable,
        "difficulty": difficulty,
        "enemy_count": enemies,
        "trap_count": traps,
    }
# Agent built with create_react_agent whose only tool is evaluate_level.
# Its prompt asks for a dictionary reporting playability and difficulty.
# NOTE(review): as with design_intent_agent, the model string has no
# "provider:" prefix - presumably inferred from the model name; confirm.
test_agent = create_react_agent(
    model="mistral-large-latest",
    tools=[evaluate_level],
    prompt=(
        "You are a level test agent.\n"
        "- Evaluate if the level has a valid path from Start to Goal.\n"
        "- Count enemies and traps.\n"
        "- Return a dictionary with playability and difficulty.\n"
        "- No explanation."
    ),
    name="test_agent"
)
@tool
def render_grid_image(grid: List[List[str]]) -> str:
    """
    Renders a visually appealing grid image and returns the file path (served via Gradio).
    """
    # NOTE(review): the docstring above is kept verbatim because the @tool
    # decorator presumably exposes it to the model as the tool description.
    #
    # Raises KeyError when the grid contains a tile code missing from tile_map.
    import matplotlib.pyplot as plt
    import numpy as np
    from matplotlib.colors import ListedColormap

    # Mapping from tile type to index
    tile_map = {"P": 0, "W": 1, "E": 2, "T": 3, "S": 4, "G": 5}
    colors = [
        "#e0e0e0",  # Path - light gray
        "#2e2e2e",  # Wall - dark gray
        "#d32f2f",  # Enemy - red
        "#fbc02d",  # Trap - yellow
        "#1976d2",  # Start - blue
        "#388e3c",  # Goal - green
    ]
    color_map = ListedColormap(colors)
    numeric_grid = np.array([[tile_map[cell] for cell in row] for row in grid])
    n_rows, n_cols = len(grid), len(grid[0])

    # Text overlays for the key tiles
    labels = {
        "S": "S",  # Start
        "G": "G",  # Goal
        "E": "⚠️",  # Enemy
        "T": "💣",  # Trap
    }

    # Save path (run_streamed_supervisor looks for exactly this path)
    image_path = "/tmp/level.png"

    # Render image
    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        # Pin the colour scale to the full tile-code range. Without vmin/vmax
        # imshow rescales to the data's own min/max, so a grid lacking the
        # lowest or highest code would be drawn with the wrong colours.
        ax.imshow(
            numeric_grid,
            cmap=color_map,
            interpolation='none',
            vmin=-0.5,
            vmax=len(colors) - 0.5,
        )

        # Grid lines: x ticks span the columns, y ticks span the rows.
        ax.set_xticks(np.arange(-0.5, n_cols, 1), minor=True)
        ax.set_yticks(np.arange(-0.5, n_rows, 1), minor=True)
        ax.grid(which='minor', color='black', linewidth=0.5)
        ax.tick_params(which='both', bottom=False, left=False, labelbottom=False, labelleft=False)

        # Emoji/text labels for key tiles
        for i, row in enumerate(grid):
            for j, cell in enumerate(row):
                if cell in labels:
                    ax.text(j, i, labels[cell], ha='center', va='center', fontsize=12, color='black')

        fig.tight_layout()
        fig.savefig(image_path, bbox_inches="tight", dpi=150)
    finally:
        # Always release this figure, even when drawing or saving fails.
        plt.close(fig)

    return image_path  # ✅ Serve this in gr.Image(type="filepath")
# Agent built with create_react_agent whose only tool is render_grid_image.
# Its prompt tells the model to reply with nothing but the raw file path;
# run_streamed_supervisor later searches message text for "/tmp/level.png".
render_image_agent = create_react_agent(
    model="google_genai:gemini-2.0-flash",
    tools=[render_grid_image],
    prompt=(
        "You are a rendering agent.\n"
        "- Only use the render_grid_image tool to render the grid.\n"
        "- When responding, ONLY return the raw file path returned by the tool.\n"
        "- Do NOT add any commentary or formatting.\n"
        "- Example: /tmp/level.png"
    ),
    name="render_image_agent"
)
# Compiled supervisor graph that routes work between the four worker agents.
# output_mode="full_history" and add_handoff_back_messages=True are passed
# through to create_supervisor unchanged.
supervisor = create_supervisor(
    model=init_chat_model("google_genai:gemini-2.0-flash-lite"),
    agents=[design_intent_agent, layout_agent, test_agent, render_image_agent],
    prompt=(
        # The count must match the list below and the `agents` argument
        # (the prompt previously said "three" while describing four agents).
        "You are a supervisor managing four agents:\n"
        "- a design intent agent. Assign design intent-related tasks to this agent\n"
        "- a layout agent. Assign layout related tasks to this agent\n"
        "- a test agent. Assign test related tasks to this agent\n"
        "- a render image agent. Assign render image tasks to this agent\n"
        "Assign work to one agent at a time, do not call agents in parallel.\n"
        "Do not do any work yourself."
    ),
    add_handoff_back_messages=True,
    output_mode="full_history",
).compile()
def run_streamed_supervisor(user_input):
    """Run the supervisor graph for one instruction and return the maze image path.

    Args:
        user_input: The instruction text typed by the user.

    Returns:
        The path "/tmp/level.png" when a supervisor message mentions it and
        the file exists on disk.

    Raises:
        gr.Error: When no supervisor message mentions the image path or the
            file does not exist. The output component is a
            ``gr.Image(type="filepath")``, so the failure is raised rather
            than returned as a string, which would be read as a file path.
    """
    request = {"messages": [{"role": "user", "content": user_input}]}

    # Drain the stream, remembering the most recent update from the
    # supervisor node. This avoids an unbound `chunk` when the stream yields
    # nothing and a KeyError when the last update comes from another node.
    final_messages = []
    for chunk in supervisor.stream(request):
        supervisor_update = chunk.get("supervisor") if isinstance(chunk, dict) else None
        if supervisor_update and "messages" in supervisor_update:
            final_messages = supervisor_update["messages"]

    # Look for static image path in last messages
    image_path = None
    for message in reversed(final_messages):
        content = getattr(message, "content", None)
        if not isinstance(content, str):
            continue
        match = re.search(r"(/tmp/level\.png)", content)
        if match:
            image_path = match.group(1)
            break

    # ✅ Return static file path directly (already saved by render_grid_image)
    if image_path and os.path.exists(image_path):
        return image_path  # Gradio will use it directly

    raise gr.Error("❌ Image path not found or file does not exist.")
def clear_fields():
    """Return the reset values for the instruction textbox and the maze image."""
    cleared_text, cleared_image = "", None
    return cleared_text, cleared_image
# HTML snippet for the colour legend shown under the maze (passed to gr.HTML
# in the UI section). Each entry pairs a 20x20 swatch with a tile name; the
# swatch colours are the same hex values used for the tiles in
# render_grid_image.
legend_html = """
<div style="display: flex; flex-wrap: wrap; gap: 10px; margin-top: 10px;">
<div style="display: flex; align-items: center;">
<div style="width: 20px; height: 20px; background: #e0e0e0; border: 1px solid #ccc;"></div>
<span style="margin-left: 5px;">Path</span>
</div>
<div style="display: flex; align-items: center;">
<div style="width: 20px; height: 20px; background: #2e2e2e;"></div>
<span style="margin-left: 5px; color: #2e2e2e;">Wall</span>
</div>
<div style="display: flex; align-items: center;">
<div style="width: 20px; height: 20px; background: #d32f2f;"></div>
<span style="margin-left: 5px;">Enemy</span>
</div>
<div style="display: flex; align-items: center;">
<div style="width: 20px; height: 20px; background: #fbc02d;"></div>
<span style="margin-left: 5px;">Trap</span>
</div>
<div style="display: flex; align-items: center;">
<div style="width: 20px; height: 20px; background: #1976d2;"></div>
<span style="margin-left: 5px;">Start</span>
</div>
<div style="display: flex; align-items: center;">
<div style="width: 20px; height: 20px; background: #388e3c;"></div>
<span style="margin-left: 5px;">Goal</span>
</div>
</div>
"""
# Gradio UI: one instruction box, one image output, generate/clear buttons,
# the colour legend, and a gallery of reference mazes.
# NOTE(review): the nesting of the rows is reconstructed from the statement
# order (the source had lost its indentation) - confirm the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("## Maze Generator (LangGraph Multi-Agent Supervisor)")

    with gr.Row():
        inp = gr.Textbox(
            label="Instruction (e.g., 'design, test and render a spooky maze level with 1 enemy and 2 traps')"
        )
        out = gr.Image(label="Generated Maze", type="filepath")

    with gr.Row():
        btn = gr.Button("🎮 Generate Maze")
        clr = gr.Button("🧹 Clear")

    # Wire the buttons: generate runs the supervisor, clear resets both fields.
    btn.click(fn=run_streamed_supervisor, inputs=inp, outputs=out)
    clr.click(fn=clear_fields, inputs=None, outputs=[inp, out])

    gr.Markdown("### 🗺️ Maze Legend")
    gr.HTML(legend_html)

    gr.Markdown("### 📸 Reference Mazes generated by AI agent")
    gr.Gallery(
        [
            "generated_mazes/maze_design_1.png",
            "generated_mazes/maze_design_2.png",
            "generated_mazes/maze_design_3.png",
        ],
        label="Reference Mazes",
        columns=3,
        height="auto",
    )

demo.launch(share=True)