| import torch |
| import numpy as np |
| import matplotlib.pyplot as plt |
|
|
| from ...logging_utils import logger |
|
|
|
|
def generate_dynamic_walk(indices, steps=50, start_idx=None, allow_backtracking=True,
                          generator=None, plot=False):
    """
    Generate a random walk over a discrete set of positions (supports PyTorch Generator).

    At every step the walker moves to an adjacent entry of ``indices`` (one
    slot to the left or right). All random draws go through ``torch.randint``,
    so passing a seeded ``torch.Generator`` makes the walk reproducible.

    Args:
        indices (Sequence): Discrete position values (e.g., [0, 2, 4, ...]).
            The walk moves between neighbouring entries of this sequence.
        steps (int): Number of moves. The returned path has ``steps + 1``
            entries (the start plus one per move); a non-positive value yields
            a path containing only the start value.
        start_idx (int | None): Index into ``indices`` to start from. Negative
            values count from the end, as with normal Python indexing. If
            None, a start index is drawn uniformly at random.
        allow_backtracking (bool): If False, the walker will not immediately
            return to the position it just left, unless that is its only
            possible move (at either end of ``indices``).
        generator (torch.Generator | None): Generator used for every random
            draw; None uses torch's global RNG.
        plot (bool): If True, draw the path with matplotlib and call
            ``plt.show()``.

    Returns:
        list: The visited values ``indices[i]`` in order, of length
        ``max(steps, 0) + 1``.

    Raises:
        ValueError: If ``indices`` is empty.
        IndexError: If ``start_idx`` is outside ``[-len(indices), len(indices))``.
    """
    n_positions = len(indices)
    if n_positions == 0:
        raise ValueError("indices must contain at least one position")

    random_start = start_idx is None
    if random_start:
        start_idx = torch.randint(0, n_positions, (1,), generator=generator).item()
    elif not -n_positions <= start_idx < n_positions:
        raise IndexError(
            f"start_idx {start_idx} out of range for {n_positions} positions")
    elif start_idx < 0:
        # Normalise negative indices: the neighbour arithmetic below assumes
        # 0 <= idx < n_positions.
        start_idx += n_positions

    # Always defined, because the plot title below uses it.
    start_val = indices[start_idx]
    if random_start:
        logger.debug(f" -> Random start index: {start_idx} (Value: {start_val})")

    history_idxs = [start_idx]

    for _ in range(steps):
        current_idx = history_idxs[-1]
        prev_idx = history_idxs[-2] if len(history_idxs) > 1 else None

        # Adjacent slots that exist within the bounds of `indices`.
        neighbors = []
        if current_idx > 0:
            neighbors.append(current_idx - 1)
        if current_idx < n_positions - 1:
            neighbors.append(current_idx + 1)

        if not neighbors:
            # Only one position exists: there is nowhere to move, so stay put
            # (torch.randint(0, 0, ...) would otherwise raise).
            history_idxs.append(current_idx)
            continue

        candidates = neighbors
        if not allow_backtracking and prev_idx is not None:
            # Exclude the slot we just left; at either end that leaves nothing,
            # in which case backtracking is the only option.
            candidates = [n for n in neighbors if n != prev_idx] or neighbors

        choice = torch.randint(0, len(candidates), (1,), generator=generator).item()
        history_idxs.append(candidates[choice])

    path_values = [indices[i] for i in history_idxs]

    if plot:
        plt.figure(figsize=(12, 5))
        time_axis = range(len(path_values))
        color = '#1f77b4' if allow_backtracking else '#ff7f0e'
        mode_str = "With Backtracking" if allow_backtracking else "No Backtracking (Inertia)"

        plt.step(time_axis, path_values, where='post', marker='o', markersize=5,
                 linestyle='-', color=color, alpha=0.8, linewidth=2)
        plt.yticks(indices)
        plt.grid(axis='y', linestyle='--', alpha=0.5)
        plt.title(f'Random Walk (Torch Seeded): {mode_str}\nStart Value: {start_val}', fontsize=12)
        plt.xlabel('Time Step')
        plt.ylabel('Button Value')

        # Mark the endpoints; the final time step is len(path) - 1, which
        # equals `steps` only when steps >= 0.
        last_t = len(path_values) - 1
        plt.scatter(0, path_values[0], c='green', s=150, label='Start', zorder=5, edgecolors='white')
        plt.scatter(last_t, path_values[-1], c='red', marker='X', s=150, label='End', zorder=5, edgecolors='white')
        plt.legend()
        plt.tight_layout()
        plt.show()

    return path_values
|
|
|
|
| |
| |
|
|
| |
| |
| |
| |
|
|
| |
|
|
| |
| |
| |
|
|
| |
| |
| |