File size: 16,440 Bytes
06c11b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.animation import FuncAnimation
import numpy as np
import torch

from ...logging_utils import logger

def grid_adjacency(R=3, C=3, diagonals=False, by_index=True):
    """
    Build the neighbor map of an R x C grid.

    Args:
        R: number of rows
        C: number of columns
        diagonals: False for 4-connectivity; True also links the four diagonal
                   cells (8-connectivity)
        by_index: True keys the map by the row-major index r * C + c
                  (0..R*C-1) and lists neighbors as indices; False keys it by
                  (r, c) tuples and lists neighbors as (r, c) tuples

    Returns:
        dict mapping every cell (row-major order) to its in-bounds neighbors,
        listed in the fixed order up, down, left, right, followed (when
        diagonals=True) by up-left, up-right, down-left, down-right.
    """
    offsets = [(-1, 0), (1, 0), (0, -1), (0, 1)]
    if diagonals:
        offsets = offsets + [(-1, -1), (-1, 1), (1, -1), (1, 1)]

    def in_bounds(row, col):
        return 0 <= row < R and 0 <= col < C

    neighbors_of = {}
    for row in range(R):
        for col in range(C):
            cells = [
                (row + d_row, col + d_col)
                for d_row, d_col in offsets
                if in_bounds(row + d_row, col + d_col)
            ]
            if by_index:
                neighbors_of[row * C + col] = [n_row * C + n_col for n_row, n_col in cells]
            else:
                neighbors_of[(row, col)] = cells
    return neighbors_of


def dfs_path(adj, start, target, generator=None, blocked_nodes=None):
    """
    Depth-first search from ``start`` to ``target`` over an adjacency dict.

    Uses an explicit stack instead of recursion, so the search depth is not
    bounded by Python's recursion limit. A recursive version raises
    RecursionError once the DFS goes roughly 1000 nodes deep under the default
    limit, e.g. on large grids. Visit order, edge order and the sequence of
    ``torch.randperm`` draws are the same as a recursive DFS that shuffles each
    node's neighbors when it is entered.

    Args:
        adj: dict mapping node -> list of neighbor nodes
        start: start node (entered even if it appears in blocked_nodes)
        target: target node; the search stops as soon as it is entered
        generator: torch.Generator; when given, every non-target node's
                   neighbor list is shuffled with torch.randperm before it is
                   explored (None keeps adjacency order)
        blocked_nodes: iterable of nodes the search may never step onto

    Returns:
        visit_order: every node entered, in order (dead ends included)
        path: start..target as a list, or [] if target is unreachable
        edges_used: every (from, to) edge stepped along, in order, including
                    edges into dead ends that were later backtracked out of
    """
    blocked = set() if blocked_nodes is None else set(blocked_nodes)
    visited = set()
    visit_order = []
    edges_used = []
    current_path = []  # nodes from start to the node currently being expanded

    def enter(node):
        """Step onto ``node``; return None if it is the target, else an iterator over its neighbors."""
        visited.add(node)
        visit_order.append(node)
        current_path.append(node)
        if node == target:
            return None
        neighbors = list(adj[node])
        if generator is not None:
            # One randperm draw per expanded node keeps runs reproducible per seed.
            neighbors_tensor = torch.tensor(neighbors)
            perm = torch.randperm(len(neighbors), generator=generator)
            neighbors = neighbors_tensor[perm].tolist()
        return iter(neighbors)

    pending = enter(start)
    if pending is None:  # start == target
        return visit_order, current_path, edges_used

    # stack[i] holds the not-yet-tried neighbors of current_path[i]
    stack = [pending]
    while stack:
        for neighbor in stack[-1]:
            if neighbor not in visited and neighbor not in blocked:
                edges_used.append((current_path[-1], neighbor))
                pending = enter(neighbor)
                if pending is None:
                    return visit_order, current_path, edges_used
                stack.append(pending)
                # Descend; this node's iterator resumes here once the child is exhausted.
                break
        else:
            # Every neighbor tried without reaching the target: backtrack.
            stack.pop()
            current_path.pop()

    return visit_order, [], edges_used


def index_to_coord(idx, C):
    """Map a row-major grid index to its (row, col) pair on a grid with C columns."""
    row = idx // C
    col = idx % C
    return row, col


def find_path_0_to_8(start, target, R=3, C=3, diagonals=False, generator=None, blocked_nodes=None):
    """Find a path from ``start`` to ``target`` on an R x C grid using DFS.

    The function name is historical: ``start`` and ``target`` may be any node
    indices of the grid, not only 0 and 8.

    Args:
        start: start node (row-major index)
        target: target node (row-major index)
        R: number of rows
        C: number of columns
        diagonals: whether to use diagonal connections (8-connectivity)
        generator: torch.Generator for random neighbor ordering (None = adjacency order)
        blocked_nodes: list of node indices to avoid (cannot pass through these nodes)

    Returns:
        On success, the tuple (path, visit_order, edges_used, adj):
            path: list of nodes from start to target (note: first element, unlike dfs_path)
            visit_order: every node DFS entered, in order
            edges_used: every edge DFS traversed, including dead-end edges
            adj: the index-keyed adjacency dict that was searched
        (None, None, None, None) when target is unreachable.
    """
    adj = grid_adjacency(R, C, diagonals=diagonals, by_index=True)

    visit_order, path, edges_used = dfs_path(adj, start, target, generator=generator, blocked_nodes=blocked_nodes)

    if not path:
        logger.debug(f"❌ Cannot find path from {start} to {target}!")
        return None, None, None, None

    logger.debug(f"Path: {' → '.join(map(str, path))}")

    return path, visit_order, edges_used, adj


def run_path_generation(run_id, generator, segments, R=3, C=7, backtrack_enable=True):
    """
    Build one route through a chain of waypoint segments with randomized DFS.

    Every segment is walked forward (start → end). When backtracking is
    enabled, a coin drawn from ``generator`` decides per segment whether to also
    walk back (end → start) and forward again. That gives a
    forward/backward/forward detour for that segment.

    Args:
        run_id: run identifier, used only in log messages
        generator: torch.Generator for reproducible randomness (coin flips and
                   DFS neighbor shuffles). With None, coin flips use torch's
                   default RNG and DFS keeps adjacency order.
        segments: list of (start, end, blocked_nodes) tuples, assumed to be
                  chained (each start equals the previous segment's end),
                  e.g. [(7, 9, [8, 3, 10, 17]), (9, 11, [1, 8, 15]), ...]
        R: number of rows in grid
        C: number of columns in grid
        backtrack_enable: when False no coin is drawn and every segment is
                          forward-only

    Returns:
        (combined_path, combined_edges, adj, start_node, end_node):
            combined_path: node sequence through all waypoints; the endpoint
                           shared by consecutive legs/segments appears once
            combined_edges: every edge DFS traversed across all legs, in order,
                            including edges into dead ends it later backed out
                            of (a superset of the path's own edges)
            adj: adjacency list of the grid
            start_node: first segment's start
            end_node: last segment's end
        (None, None, None, None, None) if ``segments`` is empty or any DFS leg fails.
    """
    failure = (None, None, None, None, None)
    if not segments:
        return failure

    waypoints = [segments[0][0]] + [seg[1] for seg in segments]

    logger.debug(f"\n{'=' * 60}")
    logger.debug(f"Run {run_id}: Finding path through waypoints {' → '.join(map(str, waypoints))}")
    logger.debug(f"{'=' * 60}")

    combined_path = []
    combined_edges = []
    adj = None

    for seg_no, (seg_start, seg_target, seg_blocked) in enumerate(segments, start=1):
        # Coin flip from the shared generator; no draw at all when backtracking is disabled.
        do_backtrack = backtrack_enable and (torch.rand(1, generator=generator).item() > 0.5)

        logger.debug(f"\nSegment {seg_no}: {seg_start} → {seg_target}")

        # (leg number, direction label, leg start, leg end, wording used if the leg fails)
        legs = [(1, "Forward", seg_start, seg_target, "forward")]
        if do_backtrack:
            legs.append((2, "Backward", seg_target, seg_start, "backward"))
            legs.append((3, "Forward", seg_start, seg_target, "second forward"))

        seg_path = []
        for leg_no, direction, leg_start, leg_target, failure_word in legs:
            logger.debug(f"  Path {leg_no} ({direction} {leg_start}→{leg_target}):")
            leg_path, _visit_order, leg_edges, adj = find_path_0_to_8(
                start=leg_start, target=leg_target, R=R, C=C, diagonals=False,
                generator=generator, blocked_nodes=seg_blocked
            )
            if not leg_path:
                logger.debug(f"\n❌ Failed to find {failure_word} path for segment {seg_no}, skipping this run")
                return failure
            # Consecutive legs meet at a shared endpoint; keep it only once.
            seg_path.extend(leg_path if not seg_path else leg_path[1:])
            combined_edges.extend(leg_edges)

        if do_backtrack:
            logger.debug("  → Segment generated 3 paths (with backtracking)")
        else:
            logger.debug("  → Segment generated 1 path (no backtracking)")

        # Segments are chained, so each later segment starts on the previous waypoint.
        combined_path.extend(seg_path if not combined_path else seg_path[1:])

    logger.debug(f"\n✅ Final combined path: {' → '.join(map(str, combined_path))}")

    return combined_path, combined_edges, adj, waypoints[0], waypoints[-1]


def visualize_single_path(path, edges_used, adj, start, target, R, C, blocked_nodes, run_id, segment_info,
                          output_dir='/Users/fuhongze/Desktop/robotic/verl/9grid'):
    """
    Draw one grid path and save it as ``grid_path_run_<run_id>.png`` in ``output_dir``.

    Args:
        path: list of node indices in visiting order; an arrow is drawn
              between each pair of consecutive entries
        edges_used: accepted for call compatibility; not drawn
        adj: adjacency dictionary (node index -> neighbor indices); every
             adjacency edge is drawn faintly
        start: start node (light green)
        target: target node (light coral)
        R: number of rows
        C: number of columns
        blocked_nodes: iterable of blocked node indices (dark gray with a red X), or None
        run_id: run identifier used in the title and the file name
        segment_info: string describing the segment selection (e.g., "segments[0:2]")
        output_dir: directory to write the PNG into; created if missing.
                    Defaults to the previously hard-coded location so existing
                    calls behave the same. Pass an explicit directory on any
                    other machine.
    """
    import os

    from matplotlib.lines import Line2D

    if not path:
        logger.debug("❌ No path to visualize!")
        return

    blocked = set() if blocked_nodes is None else set(blocked_nodes)
    on_path = set(path)  # O(1) membership when coloring nodes

    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    try:
        ax.set_xlim(-0.5, C - 0.5)
        ax.set_ylim(-0.5, R - 0.5)
        ax.set_aspect('equal')
        ax.invert_yaxis()  # row 0 at the top, matching row-major index layout
        ax.set_title(f'Run {run_id}: {segment_info}\nPath: {start} → {target}',
                     fontsize=14, fontweight='bold', pad=20)
        ax.set_xlabel('Column')
        ax.set_ylabel('Row')
        ax.grid(True, alpha=0.3)

        # All adjacency edges, very light; node < neighbor draws each undirected edge once.
        for node, neighbors in adj.items():
            r1, c1 = index_to_coord(node, C)
            for neighbor in neighbors:
                if node < neighbor:
                    r2, c2 = index_to_coord(neighbor, C)
                    ax.plot([c1, c2], [r1, r2], 'gray', alpha=0.15, linewidth=1, zorder=1)

        # Path edges in blue, each with a direction arrow.
        color = 'blue'
        for node, nxt in zip(path, path[1:]):
            r1, c1 = index_to_coord(node, C)
            r2, c2 = index_to_coord(nxt, C)
            ax.plot([c1, c2], [r1, r2], color=color, linewidth=3, alpha=0.7, zorder=3)
            dx, dy = c2 - c1, r2 - r1
            ax.arrow(c1, r1, dx * 0.6, dy * 0.6,
                     head_width=0.12, head_length=0.08,
                     fc=color, ec=color, alpha=0.7, zorder=3)

        # Nodes: blocked status takes precedence over start/target/path coloring.
        for idx in range(R * C):
            r, c = index_to_coord(idx, C)
            if idx in blocked:
                node_color = 'dimgray'
            elif idx == start:
                node_color = 'lightgreen'
            elif idx == target:
                node_color = 'lightcoral'
            elif idx in on_path:
                node_color = 'lightyellow'
            else:
                node_color = 'lightgray'

            circle = plt.Circle((c, r), 0.35, facecolor=node_color, edgecolor='black', linewidth=2, zorder=2)
            ax.add_patch(circle)
            ax.text(c, r, str(idx), ha='center', va='center',
                    fontsize=16, fontweight='bold',
                    color='white' if idx in blocked else 'black', zorder=10)

            if idx in blocked:
                ax.plot([c - 0.2, c + 0.2], [r - 0.2, r + 0.2], 'r', linewidth=3, zorder=11)
                ax.plot([c - 0.2, c + 0.2], [r + 0.2, r - 0.2], 'r', linewidth=3, zorder=11)

        legend_elements = [
            Line2D([0], [0], marker='o', color='w', markerfacecolor='lightgreen',
                   markersize=12, label='Start Node', markeredgecolor='black', markeredgewidth=2),
            Line2D([0], [0], marker='o', color='w', markerfacecolor='lightcoral',
                   markersize=12, label='Target Node', markeredgecolor='black', markeredgewidth=2),
        ]
        if blocked:
            legend_elements.append(
                Line2D([0], [0], marker='o', color='w', markerfacecolor='dimgray',
                       markersize=12, label='Blocked Node', markeredgecolor='black', markeredgewidth=2)
            )
        ax.legend(handles=legend_elements, loc='upper left', fontsize=10, bbox_to_anchor=(1.02, 1))

        fig.tight_layout()

        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, f'grid_path_run_{run_id}.png')
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
        logger.debug(f"✅ Visualization saved to: {output_path}")
    finally:
        # Always release the figure, even if drawing or saving raised.
        plt.close(fig)


if __name__ == "__main__":
    # Grid dimensions; nodes are indexed row-major as r * C + c (0..R*C-1).
    R, C = 3, 7
    num_runs = 10  # number of independently seeded runs

    print("=" * 60)
    print(f"{R}x{C} Grid: Multiple DFS Path Findings with Randomization")
    print("=" * 60)

    # Waypoint segments as (start, end, blocked_nodes) tuples
    segments = [
        (7, 9, [8, 3, 10, 17]),              # Segment 1: 7 → 9
        (9, 11, [1, 8, 15, 10, 5, 12, 19]),  # Segment 2: 9 → 11
        (11, 13, [3, 10, 17, 12]),           # Segment 3: 11 → 13
    ]

    # Same route walked in reverse (13 → 11 → 9 → 7); each segment keeps its blocked set.
    segment_2 = [(seg_end, seg_start, blocked) for seg_start, seg_end, blocked in reversed(segments)]

    for i in range(num_runs):
        # Seed per run so every run is reproducible on its own.
        generator = torch.Generator()
        generator.manual_seed(i)

        # Randomly choose between segments and segment_2
        segment_choice = torch.randint(0, 2, (1,), generator=generator).item()
        chosen_segments = segments if segment_choice == 0 else segment_2
        segment_name = "segments" if segment_choice == 0 else "segment_2"

        # Randomly select a non-empty contiguous slice [start_idx:end_idx]
        # (for 3 segments: [0:1], [1:2], [2:3], [0:2], [1:3], [0:3]).
        num_segments = len(chosen_segments)
        start_idx = torch.randint(0, num_segments, (1,), generator=generator).item()
        end_idx = torch.randint(start_idx + 1, num_segments + 1, (1,), generator=generator).item()

        selected_segments = chosen_segments[start_idx:end_idx]
        segment_info = f"{segment_name}[{start_idx}:{end_idx}]"

        print(f"\n🎲 Selected: {segment_info}")
        print("   Segments to use:")
        for idx, (seg_start, seg_target, seg_blocked) in enumerate(selected_segments, start=start_idx):
            print(f"     [{idx}] {seg_start} → {seg_target}, blocked: {seg_blocked}")

        # Union of every selected segment's blocked nodes, for the plot only.
        all_blocked_nodes = set()
        for _, _, blocked in selected_segments:
            all_blocked_nodes.update(blocked)

        combined_path, combined_edges, adj, start_node, end_node = run_path_generation(
            run_id=i + 1,
            generator=generator,
            segments=selected_segments,
            R=R,
            C=C
        )

        # Visualize only runs whose generation succeeded.
        if combined_path is not None:
            print(f"\n{'=' * 60}")
            print(f"Visualizing Run {i + 1}")
            print(f"{'=' * 60}")
            visualize_single_path(
                path=combined_path,
                edges_used=combined_edges,
                adj=adj,
                start=start_node,
                target=end_node,
                R=R,
                C=C,
                blocked_nodes=all_blocked_nodes,
                run_id=i + 1,
                segment_info=segment_info
            )