| | import torch |
| | import numpy as np |
| | import matplotlib.pyplot as plt |
| |
|
| | from ...logging_utils import logger |
| |
|
| |
|
def generate_dynamic_walk(indices, steps=50, start_idx=None, allow_backtracking=True,
                          generator=None, plot=False):
    """
    Generate a random-walk trajectory over discrete positions (PyTorch-seeded).

    At every move the walker steps to an adjacent position in ``indices``
    (position - 1 or position + 1), chosen uniformly with ``torch.randint``.
    Passing a seeded ``torch.Generator`` makes the trajectory reproducible.

    Args:
        indices (Sequence): Discrete position values (e.g., [0, 2, 4, ...]).
            The walk moves over positions in this sequence; the returned
            path contains the corresponding values.
        steps (int): Number of moves. The returned path has ``steps + 1``
            entries (the start plus one per move); non-positive values yield
            only the start.
        start_idx (int | None): Starting position in ``indices``; drawn
            uniformly at random when None.
        allow_backtracking (bool): If False, the walker avoids returning to the
            position it just left, unless that is its only neighbor (at either
            end of ``indices``).
        generator (torch.Generator | None): Generator passed to every
            ``torch.randint`` call to control the random seed.
        plot (bool): If True, plot the trajectory with matplotlib and call
            ``plt.show()``.

    Returns:
        list: The visited values, in order.

    Raises:
        ValueError: If ``indices`` is empty or ``start_idx`` is outside
            ``[0, len(indices) - 1]``.
    """
    num_positions = len(indices)
    if num_positions == 0:
        raise ValueError("indices must contain at least one position")

    if start_idx is None:
        start_idx = torch.randint(0, num_positions, (1,), generator=generator).item()
    elif not 0 <= start_idx < num_positions:
        # A negative index would silently wrap for the value lookup but then
        # break the neighbor arithmetic below, so reject it explicitly.
        raise ValueError(
            f"start_idx must be in [0, {num_positions - 1}], got {start_idx}"
        )

    start_val = indices[start_idx]
    logger.debug(f" -> Random start index: {start_idx} (Value: {start_val})")

    history_idxs = [start_idx]

    for _ in range(steps):
        current_idx = history_idxs[-1]
        prev_idx = history_idxs[-2] if len(history_idxs) > 1 else None

        # Adjacent positions that stay inside [0, num_positions - 1]; order
        # (left before right) is kept so seeded runs reproduce the same path.
        neighbors = [n for n in (current_idx - 1, current_idx + 1)
                     if 0 <= n < num_positions]

        if not neighbors:
            # Single-position input: nowhere to move, so stay put
            # (torch.randint over an empty range would raise).
            history_idxs.append(current_idx)
            continue

        candidates = neighbors
        if not allow_backtracking and prev_idx is not None:
            # Avoid reversing direction; fall back to all neighbors when
            # reversing is the only option (at a boundary).
            candidates = [n for n in neighbors if n != prev_idx] or neighbors

        rand_choice_idx = torch.randint(0, len(candidates), (1,), generator=generator).item()
        history_idxs.append(candidates[rand_choice_idx])

    path_values = [indices[i] for i in history_idxs]

    if plot:
        plt.figure(figsize=(12, 5))
        time_axis = range(len(path_values))
        color = '#1f77b4' if allow_backtracking else '#ff7f0e'
        mode_str = "With Backtracking" if allow_backtracking else "No Backtracking (Inertia)"

        # 'post' keeps each value flat until the next time step.
        plt.step(time_axis, path_values, where='post', marker='o', markersize=5,
                 linestyle='-', color=color, alpha=0.8, linewidth=2)
        plt.yticks(indices)
        plt.grid(axis='y', linestyle='--', alpha=0.5)
        plt.title(f'Random Walk (Torch Seeded): {mode_str}\nStart Value: {start_val}', fontsize=12)
        plt.xlabel('Time Step')
        plt.ylabel('Button Value')

        # Mark the first and last points. The x position of the end marker is
        # derived from the path itself so it matches even when steps <= 0.
        plt.scatter(0, path_values[0], c='green', s=150, label='Start', zorder=5, edgecolors='white')
        plt.scatter(len(path_values) - 1, path_values[-1], c='red', marker='X', s=150,
                    label='End', zorder=5, edgecolors='white')
        plt.legend()
        plt.tight_layout()
        plt.show()

    return path_values
| |
|
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |