|
|
import matplotlib.pyplot as plt |
|
|
import numpy as np |
|
|
from colorsys import hsv_to_rgb |
|
|
import string |
|
|
import networkx as nx |
|
|
import re |
|
|
|
|
|
|
|
|
import matplotlib.animation as animation |
|
|
from itertools import combinations |
|
|
import random |
|
|
|
|
|
|
|
|
|
|
|
def plot_string(text, figsize=(18, 4), output_file='transcript.svg'):
    """
    Plot a string with identical colors for identical letters.

    Characters are laid out left to right, one letter width apart plus a
    gap of 0.1x the letter width, and the figure is saved as SVG.

    Supports lowercase letters, uppercase letters (colored like their
    lowercase form), and underscores (drawn as a short gray stroke below
    the text line). Any other character is drawn in gray.

    Args:
        text: The string to plot.
        figsize: Matplotlib figure size in inches.
        output_file: Path of the SVG file to write. Defaults to
            ``'transcript.svg'``, the name previously hard-coded here.
    """
    plt.rcParams['font.family'] = 'Times New Roman'
    plt.figure(figsize=figsize)

    gray = (0.5, 0.5, 0.5)
    letters = sorted(string.ascii_lowercase)
    hues = np.linspace(0, 1, 10, endpoint=False)

    # a-j: saturated colors with ten evenly spaced hues.
    color_map = {char: hsv_to_rgb(hue, 0.8, 0.9)
                 for char, hue in zip(letters[:10], hues)}
    # k-z: pastel versions reusing the same hues in order (k shares a's hue,
    # l shares b's, ...; u-z wrap around). This is an update, not a rebuild,
    # so a-j keep their colors and every letter has an entry.
    color_map.update({char: hsv_to_rgb(hues[i % len(hues)], 0.3, 0.8)
                      for i, char in enumerate(letters[10:])})
    # Underscore is the character actually drawn specially below.
    color_map['_'] = gray

    spacing = 0.1  # gap between letters, as a fraction of the letter width
    width = 1.0
    step = width * (1 + spacing)
    total_width = len(text) * step - spacing * width

    for i, char in enumerate(text):
        x_pos = i * step
        if char == '_':
            # Draw underscores as a short horizontal stroke below the text line.
            plt.plot([x_pos - width / 3, x_pos + width / 3],
                     [-0.2, -0.2],
                     color=color_map['_'],
                     linewidth=2)
        else:
            color = color_map.get(char.lower(), gray)
            plt.text(x_pos, 0, char, fontsize=14, color=color,
                     ha='center', va='center')

    plt.xlim(-width / 2, total_width - width / 2)
    plt.ylim(-0.5, 0.5)
    plt.axis('off')
    plt.tight_layout()
    plt.savefig(output_file, format='svg')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_sequence(sequence):
    """Parse a comma-separated edge sequence into a main path and branches.

    Each comma-separated part is expected to hold exactly one edge, written
    as two node labels. A label is a letter or underscore followed by
    digits, e.g. ``"A1B2"``. Parts with any other number of labels are
    ignored. Parts enclosed in parentheses form a branch, so
    ``"(B2C3, C3D4)"`` is one branch. That includes the part carrying the
    closing ``)``.

    Args:
        sequence: The sequence string, e.g. ``"A1B2, (B2C3, C3D4), B2E3"``.

    Returns:
        ``(main_path, branches)``. ``main_path`` is a flat list of node
        labels in which consecutive pairs are edges. ``branches`` is a list
        of such flat lists, one per parenthesised group.
    """
    branches = []
    main_path = []
    current_branch = []
    in_branch = False

    for part in sequence.split(','):
        part = part.strip()
        opens_branch = '(' in part
        closes_branch = ')' in part
        if opens_branch:
            in_branch = True
        # Drop the brackets before matching labels.
        part = part.replace('(', '').replace(')', '')

        nodes = re.findall(r'([A-Za-z_]\d+)', part)
        if len(nodes) == 2:
            if in_branch:
                current_branch.extend(nodes)
            else:
                main_path.extend(nodes)

        # Close the branch only after this part's edge has been recorded, so
        # the last edge inside the parentheses stays in the branch.
        if closes_branch:
            in_branch = False
            if current_branch:
                branches.append(current_branch)
                current_branch = []

    # An unclosed "(" still yields its branch.
    if current_branch:
        branches.append(current_branch)

    return main_path, branches
|
|
|
|
|
def get_node_number(node):
    """Return the first run of decimal digits in *node* as an int.

    For example ``"B12"`` gives ``12``. Raises IndexError when *node*
    contains no digits.
    """
    digit_runs = re.findall(r'\d+', node)
    first_run = digit_runs[0]
    return int(first_run)
|
|
|
|
|
def plot_sequence(sequence, figsize=(8, 3), output_file='transcript.svg'):
    """
    Draw an edge sequence as a left-to-right graph and save it as SVG.

    The sequence is parsed with ``parse_sequence``. Each node is placed in
    the column given by its numeric suffix (``get_node_number``). Nodes
    sharing a number are stacked vertically and centred on y = 0. Every
    node is drawn as a translucent circle labelled with its leading
    character and colored by that character.

    Args:
        sequence: Sequence string accepted by ``parse_sequence``.
        figsize: Matplotlib figure size in inches.
        output_file: Path of the SVG file to write. Defaults to
            ``'transcript.svg'``, the name previously hard-coded here.
    """
    plt.figure(figsize=figsize)

    gray = (0.5, 0.5, 0.5)
    letters = sorted(string.ascii_uppercase)
    hues = np.linspace(0, 1, 10, endpoint=False)
    # A-J: saturated colors with ten evenly spaced hues.
    color_map = {char: hsv_to_rgb(hue, 0.8, 0.9)
                 for char, hue in zip(letters[:10], hues)}
    color_map['_'] = gray
    # K-Z: pastel versions reusing the same hues in order (U-Z wrap around).
    # Previously zip() stopped at T and U-Z raised KeyError.
    color_map.update({char: hsv_to_rgb(hues[i % len(hues)], 0.3, 0.8)
                      for i, char in enumerate(letters[10:])})

    main_path, branches = parse_sequence(sequence)
    G = nx.DiGraph()

    x_spacing = 1
    y_spacing = 0.5

    # Group every node label by its numeric suffix, i.e. its column.
    all_nodes = set(main_path)
    for branch in branches:
        all_nodes.update(branch)
    nodes_by_number = {}
    for node in all_nodes:
        nodes_by_number.setdefault(get_node_number(node), []).append(node)

    pos = {}
    for num in sorted(nodes_by_number):
        column = sorted(nodes_by_number[num])
        x = (num - 1) * x_spacing
        # Centre the column vertically around y = 0; a lone node sits at 0.
        start_y = -(len(column) - 1) * y_spacing / 2
        for i, node in enumerate(column):
            pos[node] = (x, start_y + i * y_spacing)

    # Consecutive label pairs in each flat list are edges.
    for path in [main_path, *branches]:
        for i in range(0, len(path) - 1, 2):
            G.add_edge(path[i], path[i + 1])

    nx.draw_networkx_edges(G, pos, edge_color='gray',
                           arrowsize=10, width=1.5,
                           arrowstyle='->')

    ax = plt.gca()
    for node in G.nodes():
        letter = node[0]
        # parse_sequence also accepts lowercase labels; color them like their
        # uppercase form instead of raising KeyError.
        color = color_map.get(letter.upper(), gray)
        ax.add_patch(plt.Circle(pos[node], 0.1, color=color, alpha=0.3))
        plt.text(pos[node][0], pos[node][1], letter,
                 color=color, fontsize=8,
                 ha='center', va='center',
                 fontweight='bold')

    plt.axis('equal')
    plt.axis('off')
    plt.tight_layout()
    plt.savefig(output_file, format='svg')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_routes_animation(G, routes, colors, output_file, fps=20, duration_sec=10):
    """
    Create an animation in which routes are drawn one after another.

    The animation is divided into ``len(routes)`` equal time slices. During
    slice ``i``, route ``i`` grows point by point while earlier routes stay
    fully drawn; the final frame shows every route complete. A legend lists
    each route visible so far as ``Period <n>``.

    Args:
        G (networkx.MultiDiGraph): Street network graph. Each node used by a
            route must carry ``'x'`` and ``'y'`` attributes.
        routes (list): List of routes (each route is a list of nodes)
        colors (list): List of colors, one per route
        output_file (str): Output filename. A name ending in ``.gif`` is
            written with the Pillow writer; any other name with FFmpeg.
        fps (int): Frames per second
        duration_sec (int): Total animation duration in seconds
    """
    # Set the font before creating the figure so text created together with
    # the axes (such as tick labels) uses it as well.
    plt.rcParams['font.family'] = 'Times New Roman'
    fig, ax = plt.subplots(figsize=(10, 8))

    # One initially empty line per color; update() fills them in.
    route_lines = []
    for color in colors:
        line, = ax.plot([], [], marker='D', color=color, linewidth=2, alpha=0.8, zorder=2)
        route_lines.append(line)

    all_route_coords = [
        [(G.nodes[node]['x'], G.nodes[node]['y']) for node in route]
        for route in routes
    ]

    # The lines start empty and set_data() does not rescale the axes, so fit
    # the view to all route points up front. Otherwise the default (0, 1)
    # limits would leave the routes off-screen.
    all_points = [point for coords in all_route_coords for point in coords]
    if all_points:
        ax.update_datalim(all_points)
        ax.autoscale_view()

    total_frames = max(1, int(round(fps * duration_sec)))
    n_routes = len(routes)

    def update(frame):
        # Map frame indices onto [0, 1] so the last frame reaches 1.0 and
        # every route is drawn in full.
        progress = frame / (total_frames - 1) if total_frames > 1 else 1.0

        for i, coords in enumerate(all_route_coords):
            # Route i grows while overall progress is in [i/n, (i+1)/n].
            route_progress = min(1.0, progress * n_routes - i)
            if route_progress <= 0:
                route_lines[i].set_data([], [])
                continue

            # Show at least one segment once a route has started.
            num_points = max(2, int(route_progress * len(coords)))
            visible_coords = coords[:num_points]
            xs, ys = zip(*visible_coords) if visible_coords else ([], [])
            route_lines[i].set_data(xs, ys)

        visible_routes = [i for i, line in enumerate(route_lines)
                          if len(line.get_xdata()) > 0]
        if visible_routes:
            # Rebuild the legend so it lists only the routes drawn so far.
            ax.legend([route_lines[i] for i in visible_routes],
                      [f'Period {i+1}' for i in visible_routes],
                      loc='upper right', prop={'size': 14},
                      bbox_to_anchor=(1, 1))

        return route_lines

    ani = animation.FuncAnimation(
        fig, update, frames=total_frames,
        interval=1000 / fps, blit=True
    )

    plt.tight_layout()

    if output_file.endswith('.gif'):
        ani.save(output_file, writer='pillow', fps=fps, dpi=150)
    else:
        writer = animation.FFMpegWriter(fps=fps, bitrate=5000)
        ani.save(output_file, writer=writer, dpi=150)

    plt.close(fig)

    print(f"Animation saved to {output_file}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_routes(G, routes, colors, output_file):
    """
    Plot multiple routes on the same map, save the figure, and show it.

    Each route is drawn as a faint (alpha 0.2) line with diamond markers
    through its nodes' ``'x'``/``'y'`` coordinates. A legend labels the
    routes ``Period 1``, ``Period 2``, ... in order.

    Args:
        G (networkx.MultiDiGraph): Street network graph. Each node used by a
            route must carry ``'x'`` and ``'y'`` attributes.
        routes (list): List of routes (each route is a list of nodes)
        colors (list): List of colors for each route
        output_file (str): Output filename
    """
    # Set the font before creating the figure so tick labels created with
    # the axes use it too.
    plt.rcParams['font.family'] = 'Times New Roman'
    fig, ax = plt.subplots(figsize=(8, 6))

    route_lines = []
    for route, color in zip(routes, colors):
        xs = [G.nodes[node]['x'] for node in route]
        ys = [G.nodes[node]['y'] for node in route]
        line, = ax.plot(xs, ys, marker='D', color=color, linewidth=2, alpha=0.2, zorder=2)
        route_lines.append(line)

    # One label per drawn line: zip() above stops at the shorter of
    # routes/colors, so count the lines rather than the routes.
    ax.legend(route_lines,
              [f'Period {i+1}' for i in range(len(route_lines))],
              loc='upper right', prop={'size': 14},
              bbox_to_anchor=(1.25, 0.85))

    plt.tight_layout()

    # bbox_inches='tight' keeps the legend, which sits outside the axes.
    fig.savefig(output_file, dpi=100, bbox_inches='tight')
    plt.show()
    plt.close(fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_task_2(obs_len, gt_seq_len, pred_seq_len, figsize_w=10, title=None):
    """
    Plot both GT and Pred timelines in the same figure with aligned scales.

    Each row shows the observation window ``[0, obs_len]`` in light gray,
    followed by the rest of that sequence up to ``seq_len + 1`` (light green
    for GT, light blue for Pred). Both rows share the same x limits, so the
    two lengths can be compared directly.

    Args:
        obs_len: Length of observation period
        gt_seq_len: Total sequence length for ground truth
        pred_seq_len: Total sequence length for prediction
        figsize_w: Width of the figure
        title: Optional title for the figure

    Returns:
        The matplotlib Figure.
    """
    # Set the font before creating the axes so their tick labels use it too.
    plt.rcParams['font.family'] = 'Times New Roman'

    max_seq_len = max(gt_seq_len, pred_seq_len)

    fig, axes = plt.subplots(2, 1, figsize=(figsize_w, 2.5), gridspec_kw={'hspace': 0.3})

    if title:
        fig.suptitle(title, fontsize=14, fontweight='bold')

    bar_height = 0.5
    # Space to the left of x = 0 for the row labels.
    label_offset = max_seq_len * 0.15
    y_position = 0

    rows = [
        (axes[0], gt_seq_len, 'lightgreen', "GT:"),
        (axes[1], pred_seq_len, 'lightblue', "Pred:"),
    ]
    for ax, seq_len, color, label in rows:
        ax.barh(y_position, obs_len, height=bar_height, left=0, color='lightgray')
        ax.barh(y_position, seq_len + 1 - obs_len, height=bar_height, left=obs_len, color=color)
        ax.text(-label_offset, y_position, label, fontsize=12, fontweight='bold', verticalalignment='center')

        # Shared x range keeps both rows on the same scale.
        ax.set_xlim(-label_offset, max_seq_len + 1)
        ax.set_ylim(-0.5, 0.5)
        ax.set_yticks([])

        for spine in ax.spines.values():
            spine.set_visible(False)

        # Baseline under the bars in place of the hidden bottom spine.
        ax.axhline(y_position - bar_height / 2, color='black', linewidth=0.5)

        ax.set_xticks([0, obs_len, seq_len])

    # Leave headroom for the suptitle when present.
    plt.tight_layout(rect=[0, 0, 1, 0.95] if title else [0, 0, 1, 1])

    return fig
|
|
|
|
|
|
|
|
def plot_task_3(gt_seq_len, GT_start, GT_end, pred_start, pred_end, figsize_w=10, title=None):
    """
    Plot both GT and Pred timelines in the same figure with aligned scales.

    Each row shows the full sequence ``[0, gt_seq_len]`` in light gray, with
    a highlighted span on top: ``[GT_start, GT_end]`` in light green for GT,
    and ``[pred_start, pred_end]`` in light blue for Pred.

    Args:
        gt_seq_len: Total sequence length for ground truth
        GT_start: Start position of GT highlight bar
        GT_end: End position of GT highlight bar
        pred_start: Start position of prediction highlight bar
        pred_end: End position of prediction highlight bar
        figsize_w: Width of the figure
        title: Optional title for the figure

    Returns:
        The matplotlib Figure.
    """
    # Set the font before creating the axes so their tick labels use it too.
    plt.rcParams['font.family'] = 'Times New Roman'

    max_seq_len = gt_seq_len

    fig, axes = plt.subplots(2, 1, figsize=(figsize_w, 2.5), gridspec_kw={'hspace': 0.3})

    if title:
        fig.suptitle(title, fontsize=14, fontweight='bold')

    bar_height = 0.5
    # Space to the left of x = 0 for the row labels.
    label_offset = max_seq_len * 0.15
    y_position = 0

    rows = [
        (axes[0], GT_start, GT_end, 'lightgreen', "GT:"),
        (axes[1], pred_start, pred_end, 'lightblue', "Pred:"),
    ]
    for ax, start, end, color, label in rows:
        # Full-length background bar, then the highlighted span on top.
        ax.barh(y_position, gt_seq_len, height=bar_height, left=0, color='lightgray')
        ax.barh(y_position, end - start, height=bar_height, left=start, color=color)
        ax.text(-label_offset, y_position, label, fontsize=12, fontweight='bold', verticalalignment='center')

        # Shared x range keeps both rows on the same scale.
        ax.set_xlim(-label_offset, max_seq_len + 1)
        ax.set_ylim(-0.5, 0.5)
        ax.set_yticks([])

        for spine in ax.spines.values():
            spine.set_visible(False)

        # Baseline under the bars in place of the hidden bottom spine.
        ax.axhline(y_position - bar_height / 2, color='black', linewidth=0.5)

        ax.set_xticks([0, start, end, gt_seq_len])

    # Leave headroom for the suptitle when present.
    plt.tight_layout(rect=[0, 0, 1, 0.95] if title else [0, 0, 1, 1])

    return fig
|
|
|
|
|
|
|
|
import string |
|
|
def plot_images_with_token(images, tokens, n_rows=2, output_file='anchors.jpg'):
    """
    Show images in a grid, each titled with its token, and save the figure.

    Images fill the grid row by row, using as many columns as needed to fit
    all of them in ``n_rows`` rows; unused cells are removed. Each title is
    colored by the lowercased token: ``a``-``j`` get ten evenly spaced
    saturated hues, ``k``-``z`` pastel versions of the same hues, and ``_``
    gray. Any other token is drawn in gray.

    Args:
        images: Sequence of images accepted by ``Axes.imshow``.
        tokens: One string per image, used as its title.
        n_rows: Number of grid rows.
        output_file: Path of the saved figure. Defaults to ``'anchors.jpg'``,
            the name previously hard-coded here.

    Raises:
        ValueError: if ``images`` and ``tokens`` differ in length, or if
            there are no images.
    """
    if len(images) != len(tokens):
        raise ValueError("Each image must have a corresponding token")

    n_images = len(images)
    if n_images == 0:
        raise ValueError("At least one image is required")

    # Ceiling division: enough columns to fit every image in n_rows rows
    # (the old (n + 1) // n_rows was only correct for n_rows == 2).
    n_cols = -(-n_images // n_rows)

    plt.rcParams['font.family'] = 'Times New Roman'
    # squeeze=False always returns a 2-D array of Axes, including for a 1x1
    # grid, where a bare Axes object has no reshape/flatten.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows),
                             squeeze=False)

    gray = (0.5, 0.5, 0.5)
    letters = sorted(string.ascii_lowercase)
    hues = np.linspace(0, 1, 10, endpoint=False)
    color_map = {char: hsv_to_rgb(hue, 0.8, 0.9)
                 for char, hue in zip(letters[:10], hues)}
    # k-z reuse the a-j hues in pastel (u-z wrap around) so no letter fails.
    color_map.update({char: hsv_to_rgb(hues[i % len(hues)], 0.3, 0.8)
                      for i, char in enumerate(letters[10:])})
    color_map['_'] = gray

    axes_flat = axes.flatten()

    for ax, image, token in zip(axes_flat, images, tokens):
        color = color_map.get(token.lower(), gray)
        ax.imshow(image)
        ax.set_title(token, color=color, size=50)
        ax.axis('off')

    # Remove the empty grid cells after the last image.
    for ax in axes_flat[n_images:]:
        fig.delaxes(ax)

    plt.tight_layout()
    plt.savefig(output_file, bbox_inches='tight', pad_inches=0)
|
|
|
|
|
|
|
|
|
|
|
|