HongzeFu's picture
HF Space: code-only (no binary assets)
06c11b0
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.animation import FuncAnimation
import numpy as np
import torch
from ...logging_utils import logger
def grid_adjacency(R=3, C=3, diagonals=False, by_index=True):
    """
    Build the neighbour table of an R x C grid.

    Args:
        R: number of rows
        C: number of columns
        diagonals: False = 4-connectivity; True = 8-connectivity
        by_index: True labels cells with their flat index (r * C + c, i.e. 0..R*C-1);
            False labels cells with their (r, c) coordinate tuple

    Returns:
        dict mapping every cell label to the list of its in-bounds neighbour labels.
        Cells are inserted row by row; neighbours are listed up, down, left, right,
        followed by the four diagonals when enabled.
    """
    # Orthogonal steps first, diagonal steps appended afterwards.
    offsets = ((-1, 0), (1, 0), (0, -1), (0, 1))
    if diagonals:
        offsets = offsets + ((-1, -1), (-1, 1), (1, -1), (1, 1))

    def label(row, col):
        # Same labelling scheme is used for keys and for neighbour entries.
        return row * C + col if by_index else (row, col)

    table = {}
    for row in range(R):
        for col in range(C):
            table[label(row, col)] = [
                label(row + d_row, col + d_col)
                for d_row, d_col in offsets
                if 0 <= row + d_row < R and 0 <= col + d_col < C
            ]
    return table
def dfs_path(adj, start, target, generator=None, blocked_nodes=None):
    """
    Run a depth-first search from start to target.

    The search uses an explicit stack instead of recursion, so long paths do not
    hit Python's recursion limit. Nodes may be any hashable label (flat indices
    or coordinate tuples).

    Args:
        adj: adjacency mapping, node -> sequence of neighbour nodes
        start: start node (always visited, even if listed in blocked_nodes)
        target: target node
        generator: torch.Generator for random neighbor selection; when given, the
            neighbours of each expanded node are permuted with one
            torch.randperm call. When None the order stored in adj is used.
        blocked_nodes: iterable of nodes to avoid (cannot pass through these nodes)

    Returns:
        visit_order: list of all nodes visited in order (dead ends included)
        path: final path from start to target, or [] if the target was not reached
        edges_used: (node, neighbor) edges traversed during search, dead ends included
    """
    blocked = set() if blocked_nodes is None else set(blocked_nodes)

    def ordered_neighbors(node):
        # Permute the Python list directly rather than round-tripping through a
        # tensor: tuple nodes would come back from .tolist() as unhashable lists.
        neighbors = list(adj[node])
        if generator is not None:
            perm = torch.randperm(len(neighbors), generator=generator).tolist()
            neighbors = [neighbors[i] for i in perm]
        return iter(neighbors)

    visited = {start}
    visit_order = [start]  # Track all nodes visited in order
    edges_used = []
    current_path = [start]

    if start == target:
        return visit_order, current_path, edges_used

    # Each frame holds a node and the iterator over its not-yet-tried neighbours.
    # A node's neighbours are ordered (and the generator consumed) at the moment
    # the node is first visited.
    stack = [(start, ordered_neighbors(start))]
    while stack:
        node, pending = stack[-1]
        for neighbor in pending:
            # Skip blocked nodes, already visited nodes
            if neighbor in visited or neighbor in blocked:
                continue
            edges_used.append((node, neighbor))
            visited.add(neighbor)
            visit_order.append(neighbor)
            current_path.append(neighbor)
            if neighbor == target:
                return visit_order, current_path, edges_used
            stack.append((neighbor, ordered_neighbors(neighbor)))
            break
        else:
            # Every neighbour exhausted: dead end, backtrack one step.
            stack.pop()
            current_path.pop()

    return visit_order, [], edges_used
def index_to_coord(idx, C):
    """Map a flat row-major cell index to its (row, col) pair for a grid with C columns."""
    row, col = divmod(idx, C)
    return row, col
def find_path_0_to_8(start, target, R=3, C=3, diagonals=False, generator=None, blocked_nodes=None):
    """Build an index-keyed R x C grid and search it with DFS from start to target.

    Args:
        start: start node (flat index)
        target: target node (flat index)
        R: number of rows
        C: number of columns
        diagonals: whether to use diagonal connections
        generator: torch.Generator for random neighbor selection
        blocked_nodes: list of node indices to avoid (cannot pass through these nodes)

    Returns:
        On success, a 4-tuple:
            path: list of nodes in the final path from start to target
            visit_order: list of all nodes visited in order during DFS
            edges_used: edges traversed during search
            adj: adjacency list of the grid
        When no path exists, (None, None, None, None).
    """
    grid = grid_adjacency(R, C, diagonals=diagonals, by_index=True)
    visited_nodes, route, traversed = dfs_path(
        grid, start, target, generator=generator, blocked_nodes=blocked_nodes
    )

    if route:
        # Only the resulting path is logged, not the whole search.
        arrow = ' β†’ '
        logger.debug(f"Path: {arrow.join(map(str, route))}")
        return route, visited_nodes, traversed, grid

    logger.debug(f"❌ Cannot find path from {start} to {target}!")
    return None, None, None, None
def run_path_generation(run_id, generator, segments, R=3, C=7, backtrack_enable=True):
    """
    Run a single path generation with segments

    Every segment is routed with a forward DFS leg. When backtracking is drawn
    for a segment, the segment becomes forward + backward + forward again, each
    leg searched independently with the same blocked nodes.

    Args:
        run_id: Run identifier for printing
        generator: torch.Generator for reproducible randomness
        segments: List of segment tuples, each as (start, end, blocked_nodes)
                  e.g., [(7, 9, [8, 3, 10, 17]), (9, 11, [1, 8, 15]), ...]
                  Each segment's first node is dropped when paths are joined, so
                  consecutive segments are expected to share their waypoint.
        R: Number of rows in grid
        C: Number of columns in grid
        backtrack_enable: Whether random backtracking passes are allowed. When
                  False, no random number is drawn for the backtracking decision.

    Returns:
        combined_path: Complete path through all waypoints
        combined_edges: All edges used in the path
        adj: Adjacency list of the grid
        start_node: Starting node (first segment's start)
        end_node: Ending node (last segment's end)
        Returns (None, None, None, None, None) if failed, including when
        segments is empty
    """
    failure = (None, None, None, None, None)

    # Nothing to route: report it like any other failed run instead of raising
    # IndexError on segments[0].
    if not segments:
        return failure

    # Extract waypoints for display: first segment's start, then every segment's end
    waypoints = [segments[0][0]] + [seg[1] for seg in segments]

    logger.debug(f"\n{'=' * 60}")
    logger.debug(f"Run {run_id}: Finding path through waypoints {' β†’ '.join(map(str, waypoints))}")
    logger.debug(f"{'=' * 60}")

    combined_path = []
    combined_edges = []
    adj = None

    for seg_idx, (seg_start, seg_target, seg_blocked) in enumerate(segments):
        # Randomly decide whether to use backtracking for this segment
        # Use torch generator for reproducible randomness
        do_backtrack = backtrack_enable and (torch.rand(1, generator=generator).item() > 0.5)

        # Print segment header
        logger.debug(f"\nSegment {seg_idx + 1}: {seg_start} β†’ {seg_target}")

        # Legs as (direction label, name used in failure message, from, to).
        # The forward leg is always executed.
        legs = [("Forward", "forward", seg_start, seg_target)]
        if do_backtrack:
            legs.append(("Backward", "backward", seg_target, seg_start))
            legs.append(("Forward", "second forward", seg_start, seg_target))

        seg_combined = []
        seg_edges = []
        for leg_idx, (direction, leg_name, leg_start, leg_target) in enumerate(legs):
            logger.debug(f" Path {leg_idx + 1} ({direction} {leg_start}β†’{leg_target}):")
            leg_path, _, leg_edges, adj = find_path_0_to_8(
                start=leg_start, target=leg_target, R=R, C=C, diagonals=False,
                generator=generator, blocked_nodes=seg_blocked
            )
            if not leg_path:
                logger.debug(
                    f"\n❌ Failed to find {leg_name} path for segment {seg_idx + 1}, skipping this run"
                )
                return failure
            # Later legs start where the previous one ended; drop the duplicate node.
            seg_combined.extend(leg_path if leg_idx == 0 else leg_path[1:])
            seg_edges.extend(leg_edges)

        if do_backtrack:
            logger.debug(f" β†’ Segment generated 3 paths (with backtracking)")
        else:
            logger.debug(f" β†’ Segment generated 1 path (no backtracking)")

        # Join segments, skipping the first element (waypoint) of every later segment
        combined_path.extend(seg_combined if seg_idx == 0 else seg_combined[1:])
        combined_edges.extend(seg_edges)

    logger.debug(f"\nβœ… Final combined path: {' β†’ '.join(map(str, combined_path))}")

    # Return data separately (not as tuple)
    return combined_path, combined_edges, adj, waypoints[0], waypoints[-1]
def visualize_single_path(path, edges_used, adj, start, target, R, C, blocked_nodes, run_id, segment_info,
                          output_dir=None):
    """
    Visualize a single DFS path and save it as grid_path_run_<run_id>.png

    Args:
        path: list of nodes in the path; nothing is drawn when it is empty or None
        edges_used: list of edges used in the path (kept in the signature; not
            used for drawing, the path itself defines the drawn edges)
        adj: adjacency dictionary
        start: start node
        target: target node
        R: number of rows
        C: number of columns
        blocked_nodes: iterable of blocked node indices, or None
        run_id: run identifier for file naming
        segment_info: string describing the segment selection (e.g., "segments[0:2]")
        output_dir: directory the PNG is written to; created if missing. None
            keeps the original hard-coded location.
    """
    import os
    from matplotlib.lines import Line2D

    if not path:
        logger.debug("❌ No path to visualize!")
        return

    # Sets give O(1) membership tests inside the per-node drawing loop
    blocked_nodes = set() if blocked_nodes is None else set(blocked_nodes)
    path_nodes = set(path)

    if output_dir is None:
        # Legacy default, preserved for backward compatibility
        output_dir = '/Users/fuhongze/Desktop/robotic/verl/9grid'

    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    try:
        ax.set_xlim(-0.5, C - 0.5)
        ax.set_ylim(-0.5, R - 0.5)
        ax.set_aspect('equal')
        ax.invert_yaxis()  # row 0 at the top
        ax.set_title(f'Run {run_id}: {segment_info}\nPath: {start} β†’ {target}',
                     fontsize=14, fontweight='bold', pad=20)
        ax.set_xlabel('Column')
        ax.set_ylabel('Row')
        ax.grid(True, alpha=0.3)

        # Draw all possible edges (adjacency relationships) - very light
        for node, neighbors in adj.items():
            r1, c1 = index_to_coord(node, C)
            for neighbor in neighbors:
                # Each undirected edge appears twice in adj; draw it once
                if node < neighbor:
                    r2, c2 = index_to_coord(neighbor, C)
                    ax.plot([c1, c2], [r1, r2], 'gray', alpha=0.15, linewidth=1, zorder=1)

        # Draw path edges in blue
        color = 'blue'
        for node, neighbor in zip(path, path[1:]):
            r1, c1 = index_to_coord(node, C)
            r2, c2 = index_to_coord(neighbor, C)
            ax.plot([c1, c2], [r1, r2], color=color, linewidth=3, alpha=0.7, zorder=3)
            # Add arrows (drawn along the first 60% of the edge)
            dx, dy = c2 - c1, r2 - r1
            ax.arrow(c1, r1, dx * 0.6, dy * 0.6,
                     head_width=0.12, head_length=0.08,
                     fc=color, ec=color, alpha=0.7, zorder=3)

        # Draw all nodes
        for idx in range(R * C):
            r, c = index_to_coord(idx, C)
            is_blocked = idx in blocked_nodes

            # Blocked takes precedence over start/target/path colouring
            if is_blocked:
                node_color = 'dimgray'
            elif idx == start:
                node_color = 'lightgreen'
            elif idx == target:
                node_color = 'lightcoral'
            elif idx in path_nodes:
                node_color = 'lightyellow'
            else:
                node_color = 'lightgray'

            circle = plt.Circle((c, r), 0.35, color=node_color, ec='black', linewidth=2, zorder=2)
            ax.add_patch(circle)
            ax.text(c, r, str(idx), ha='center', va='center',
                    fontsize=16, fontweight='bold',
                    color='white' if is_blocked else 'black', zorder=10)

            # Add X mark for blocked nodes
            if is_blocked:
                ax.plot([c - 0.2, c + 0.2], [r - 0.2, r + 0.2], 'r', linewidth=3, zorder=11)
                ax.plot([c - 0.2, c + 0.2], [r + 0.2, r - 0.2], 'r', linewidth=3, zorder=11)

        # Add legend
        legend_elements = [
            Line2D([0], [0], marker='o', color='w', markerfacecolor='lightgreen',
                   markersize=12, label='Start Node', markeredgecolor='black', markeredgewidth=2),
            Line2D([0], [0], marker='o', color='w', markerfacecolor='lightcoral',
                   markersize=12, label='Target Node', markeredgecolor='black', markeredgewidth=2),
        ]
        # Add blocked node indicator to legend if there are blocked nodes
        if blocked_nodes:
            legend_elements.append(
                Line2D([0], [0], marker='o', color='w', markerfacecolor='dimgray',
                       markersize=12, label='Blocked Node', markeredgecolor='black', markeredgewidth=2)
            )
        ax.legend(handles=legend_elements, loc='upper left', fontsize=10, bbox_to_anchor=(1.02, 1))

        fig.tight_layout()

        # Save with run-specific filename
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, f'grid_path_run_{run_id}.png')
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
        logger.debug(f"βœ… Visualization saved to: {output_path}")
    finally:
        # Close the figure even if drawing or saving failed, to free memory
        plt.close(fig)
if __name__ == "__main__":
    # Demo driver: for each seed, pick a random contiguous slice of one of two
    # waypoint chains, generate a path through it and save a picture of the result.

    # Configuration
    num_runs = 10  # You can change this number
    R, C = 3, 7

    print("=" * 60)
    # Banner is derived from the configured grid so it cannot drift from R, C
    print(f"{R}x{C} Grid: Multiple DFS Path Findings with Randomization")
    print("=" * 60)

    # Define segments as (start, end, blocked_nodes) tuples
    segments = [
        (7, 9, [8, 3, 10, 17]),               # Segment 1: 7 -> 9
        (9, 11, [1, 8, 15, 10, 5, 12, 19]),   # Segment 2: 9 -> 11
        (11, 13, [3, 10, 17, 12])             # Segment 3: 11 -> 13
    ]

    # Same chain traversed in the opposite direction
    segment_2 = [
        (13, 11, [3, 10, 17, 12]),            # Segment 1: 13 -> 11
        (11, 9, [1, 8, 15, 10, 5, 12, 19]),   # Segment 2: 11 -> 9
        (9, 7, [8, 3, 10, 17]),               # Segment 3: 9 -> 7
    ]

    # Run multiple times and visualize each separately
    for i in range(num_runs):
        # Create generator with specific seed for this run
        generator = torch.Generator()
        generator.manual_seed(i)

        # Randomly choose between segments and segment_2
        segment_choice = torch.randint(0, 2, (1,), generator=generator).item()
        chosen_segments = segments if segment_choice == 0 else segment_2
        segment_name = "segments" if segment_choice == 0 else "segment_2"

        # Randomly select a continuous, non-empty slice from chosen segments
        # Possible slices: [0:1], [1:2], [2:3], [0:2], [1:3], [0:3]
        num_segments = len(chosen_segments)
        start_idx = torch.randint(0, num_segments, (1,), generator=generator).item()
        end_idx = torch.randint(start_idx + 1, num_segments + 1, (1,), generator=generator).item()

        selected_segments = chosen_segments[start_idx:end_idx]
        segment_info = f"{segment_name}[{start_idx}:{end_idx}]"

        # Print selection
        print(f"\n🎲 Selected: {segment_info}")
        print(f" Segments to use:")
        for idx, (seg_start, seg_target, seg_blocked) in enumerate(selected_segments, start=start_idx):
            print(f" [{idx}] {seg_start} β†’ {seg_target}, blocked: {seg_blocked}")

        # Collect all blocked nodes (union over the selected segments) for visualization
        all_blocked_nodes = set()
        for _, _, blocked in selected_segments:
            all_blocked_nodes.update(blocked)

        # Run path generation
        combined_path, combined_edges, adj, start_node, end_node = run_path_generation(
            run_id=i+1,
            generator=generator,
            segments=selected_segments,
            R=R,
            C=C
        )

        # Visualize this run if generation was successful
        if combined_path is not None:
            print(f"\n{'=' * 60}")
            print(f"Visualizing Run {i+1}")
            print(f"{'=' * 60}")
            visualize_single_path(
                path=combined_path,
                edges_used=combined_edges,
                adj=adj,
                start=start_node,
                target=end_node,
                R=R,
                C=C,
                blocked_nodes=all_blocked_nodes,
                run_id=i+1,
                segment_info=segment_info
            )