File size: 4,117 Bytes
06c11b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import torch
import numpy as np  # Used for plotting and data structures
import matplotlib.pyplot as plt

from ...logging_utils import logger


def _plot_walk(path_values, indices, start_val, allow_backtracking):
    """Show a step plot of a walk produced by ``generate_dynamic_walk``.

    Args:
        path_values (list): Visited values, one per time step (start included).
        indices (sequence): All discrete position values, used as y-axis ticks.
        start_val: Value of the starting position, shown in the title.
        allow_backtracking (bool): Selects the colour and title text.
    """
    plt.figure(figsize=(12, 5))
    time_axis = range(len(path_values))
    color = '#1f77b4' if allow_backtracking else '#ff7f0e'
    mode_str = "With Backtracking" if allow_backtracking else "No Backtracking (Inertia)"

    plt.step(time_axis, path_values, where='post', marker='o', markersize=5,
             linestyle='-', color=color, alpha=0.8, linewidth=2)
    plt.yticks(indices)
    plt.grid(axis='y', linestyle='--', alpha=0.5)
    plt.title(f'Random Walk (Torch Seeded): {mode_str}\nStart Value: {start_val}', fontsize=12)
    plt.xlabel('Time Step')
    plt.ylabel('Button Value')

    # Mark start and end points. The end marker goes at the last plotted time
    # index (equal to `steps` for steps >= 0, but never off the plotted axis).
    last_t = len(path_values) - 1
    plt.scatter(0, path_values[0], c='green', s=150, label='Start', zorder=5, edgecolors='white')
    plt.scatter(last_t, path_values[-1], c='red', marker='X', s=150, label='End', zorder=5, edgecolors='white')
    plt.legend()
    plt.tight_layout()
    plt.show()


def generate_dynamic_walk(indices, steps=50, start_idx=None, allow_backtracking=True,
                          generator=None, plot=False):
    """
    Generate a random-walk trajectory over a 1-D row of discrete positions.

    At each step the walker moves to an adjacent position (index - 1 or
    index + 1), chosen uniformly among the allowed candidates with
    ``torch.randint``, so results are reproducible with a seeded
    ``torch.Generator``.

    Args:
        indices (sequence): Discrete position values (e.g., [0, 2, 4, ...]).
            The walk moves between neighbouring *positions* of this sequence.
        steps (int): Number of moves; the returned path has
            ``max(steps, 0) + 1`` entries.
        start_idx (int | None): Start position in ``indices``. Negative values
            count from the end, as in Python indexing. Drawn uniformly at
            random if None.
        allow_backtracking (bool): If False, the walker never immediately
            returns to the position it just left, unless that is the only
            reachable neighbour (i.e. at either end of the row).
        generator (torch.Generator | None): Source of randomness; None uses
            torch's global RNG.
        plot (bool): If True, show a step plot of the path with matplotlib.

    Returns:
        list: The values of ``indices`` visited at each time step, starting
        with the start value.

    Raises:
        ValueError: If ``indices`` is empty.
        IndexError: If ``start_idx`` is outside ``[-len(indices), len(indices))``.
    """
    # Use len() rather than truthiness so numpy arrays / tensors are accepted.
    n = len(indices)
    if n == 0:
        raise ValueError("indices must contain at least one position")

    # 1. Initialization
    if start_idx is None:
        # randint returns a 1-element tensor; .item() converts it to a Python int
        start_idx = torch.randint(0, n, (1,), generator=generator).item()
    elif not -n <= start_idx < n:
        raise IndexError(f"start_idx {start_idx} is out of range for {n} positions")
    else:
        # Normalise negative indices Python-style: the neighbour arithmetic
        # below assumes 0 <= idx < n (otherwise -1 would "neighbour" 0).
        start_idx %= n

    start_val = indices[start_idx]
    logger.debug(f"   -> Random start index: {start_idx} (Value: {start_val})")

    history_idxs = [start_idx]

    # 2. Generation loop
    for _ in range(steps):
        current_idx = history_idxs[-1]
        prev_idx = history_idxs[-2] if len(history_idxs) > 1 else None

        # Physically reachable neighbours (no wrap-around). Order (left, right)
        # matters: it keeps seeded results identical to earlier versions.
        neighbors = [j for j in (current_idx - 1, current_idx + 1) if 0 <= j < n]

        # Backtracking filter: drop the position just left, unless that leaves
        # nowhere to go (dead end at a boundary) -- then backtrack anyway.
        candidates = neighbors
        if not allow_backtracking and prev_idx is not None:
            forward = [j for j in neighbors if j != prev_idx]
            if forward:
                candidates = forward

        if not candidates:
            # Only happens when n == 1: there is nowhere to move, so stay put
            # (torch.randint(0, 0) would otherwise raise).
            next_idx = current_idx
        else:
            # Draw even when there is a single candidate so a seeded generator
            # consumes the same random stream as before.
            pick = torch.randint(0, len(candidates), (1,), generator=generator).item()
            next_idx = candidates[pick]

        history_idxs.append(next_idx)

    # Map positions back to their real values
    path_values = [indices[i] for i in history_idxs]

    # 3. Visualization
    if plot:
        _plot_walk(path_values, indices, start_val, allow_backtracking)

    return path_values


# # --- Comparison test (with seed) ---
# button_indices = [0, 2, 4, 6, 8]

# # Create a Generator and set seed (guarantee reproducible results)
# seed = 42
# rng = torch.Generator()
# rng.manual_seed(seed)

# print(f"--- Test Start (Seed: {seed}) ---")

# # 1. Enable backtracking
# print("Scheme 1: Allow backtracking")
# traj_1 = generate_dynamic_walk(button_indices, steps=30, allow_backtracking=True, generator=rng)

# # 2. Disable backtracking (Note: using same generator, random sequence continues from last call)
# print("\nScheme 2: Forbid backtracking (Inertia mode)")
# traj_2 = generate_dynamic_walk(button_indices, steps=30, allow_backtracking=False, generator=rng)