| """Helpers for visualization""" |
| import os |
| import itertools |
| from os.path import exists |
|
|
| import numpy as np |
| import matplotlib |
| import matplotlib.pyplot as plt |
| import cv2 |
| import pandas as pd |
| import PIL |
| from PIL import Image, ImageOps, ImageDraw |
| from tqdm import tqdm |
|
|
| |
| |
| |
| |
| |
|
|
|
|
| from IPython.display import Audio, Markdown, display |
| try: |
| from ipywidgets import Button, HBox, VBox, Text, Label, HTML, widgets |
| except: |
| exit("Failed to import ipywidgets. Please install.") |
|
|
| from shared.utils.log import tqdm_iterator |
|
|
| import warnings |
| warnings.filterwarnings("ignore") |
|
|
| try: |
| import torchvideotransforms |
| except: |
| print("Failed to import torchvideotransforms. Proceeding without.") |
| print("Please install using:") |
| print("pip install git+https://github.com/hassony2/torch_videovision") |
|
|
|
|
| |
# Named RGB triples used as a small color palette by helpers in this module.
COLORS = dict(
    pink=(242, 116, 223),
    cyan=(46, 242, 203),
    red=(255, 0, 0),
    green=(0, 255, 0),
    blue=(0, 0, 255),
    yellow=(255, 255, 0),
)
|
|
|
|
def get_predominant_color(color_key, mode="RGB", alpha=0):
    """Look up a named color from ``COLORS``.

    Args:
        color_key: Name of the color; must be a key of ``COLORS``.
        mode: ``"RGB"`` returns the stored 3-tuple, ``"RGBA"`` appends ``alpha``.
        alpha: Value appended as the fourth channel in ``"RGBA"`` mode.

    Returns:
        A 3-tuple (``"RGB"``) or 4-tuple (``"RGBA"``) of channel values.

    Raises:
        AssertionError: If ``color_key`` is not a key of ``COLORS``.
        ValueError: If ``mode`` is neither ``"RGB"`` nor ``"RGBA"``.
    """
    assert color_key in COLORS, f"Unknown color key: {color_key}"
    rgb = COLORS[color_key]
    if mode == "RGB":
        return rgb
    if mode == "RGBA":
        return rgb + (alpha,)
    # Previously an unknown mode fell through and returned None silently.
    raise ValueError(f"Unknown mode: {mode}. Expected 'RGB' or 'RGBA'.")
|
|
|
|
def show_single_image(image: np.ndarray, figsize: tuple = (8, 8), title: str = None, cmap: str = None, ticks=False):
    """Display one image in its own matplotlib figure.

    Args:
        image: Array accepted by ``Axes.imshow`` or a PIL image (converted
            to an array first).
        figsize: Figure size passed to ``plt.subplots``.
        title: Title set on the axes.
        cmap: Colormap forwarded to ``imshow``.
        ticks: When False, both axes' ticks are removed.
    """
    _, axis = plt.subplots(1, 1, figsize=figsize)

    pixels = np.asarray(image) if isinstance(image, Image.Image) else image

    axis.set_title(title)
    axis.imshow(pixels, cmap=cmap)

    if not ticks:
        for clear_ticks in (axis.set_xticks, axis.set_yticks):
            clear_ticks([])

    plt.show()
|
|
|
|
def show_grid_of_images(
        images: np.ndarray,
        n_cols: int = 4,
        figsize: tuple = (8, 8),
        subtitlesize=14,
        cmap=None,
        subtitles=None,
        title=None,
        save=False,
        savepath="sample.png",
        titlesize=20,
        ysuptitle=0.8,
        xlabels=None,
        sizealpha=0.7,
        show=True,
        row_labels=None,
        aspect=None,
        width_ratios=None,
        return_as_pil=False,
    ):
    """Show a grid of images.

    Args:
        images: Sequence of images (arrays or PIL images; PIL images are
            converted to arrays).
        n_cols: Maximum number of columns; capped at ``len(images)``.
        figsize: Figure size passed to ``plt.subplots``.
        subtitlesize: Font size of the per-image titles.
        cmap: Colormap for ``imshow``. When None, 2-D images use ``"gray"``.
        subtitles: Optional per-image titles.
        title: Figure-level title.
        save: If True, save the figure to ``savepath`` and close it.
        savepath: Destination used when ``save`` is True.
        titlesize: Font size of the figure-level title.
        ysuptitle: ``y`` position of the figure-level title.
        xlabels: Optional per-image x-axis labels.
        sizealpha: Multiplier applied to ``subtitlesize`` for axis labels.
        show: If True (and not saving), call ``plt.show()`` and close the figure.
        row_labels: Optional y-axis label for the first cell of each row.
        aspect: Forwarded to ``imshow``.
        width_ratios: Forwarded to ``plt.subplots``.
        return_as_pil: If True, render the figure and return it as an RGB
            PIL image; ``save`` and ``show`` are then ignored.

    Returns:
        A PIL image when ``return_as_pil`` is True, otherwise None.

    Raises:
        ValueError: If ``images`` is empty.
    """
    n_images = len(images)
    if n_images == 0:
        raise ValueError("show_grid_of_images needs at least one image.")

    n_cols = min(n_cols, n_images)
    n_rows = int(np.ceil(n_images / n_cols))

    # Convert PIL inputs without touching the caller's container.
    arrays = [np.asarray(im) if isinstance(im, Image.Image) else im for im in images]

    if subtitles is None:
        subtitles = [None] * n_images
    if xlabels is None:
        xlabels = [None] * n_images
    if row_labels is None:
        row_labels = [None] * n_rows

    # squeeze=False always yields a 2-D array of axes, so a single image
    # needs no special-casing.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=figsize, width_ratios=width_ratios, squeeze=False,
    )
    label_size = sizealpha * subtitlesize
    for i, ax in enumerate(axes.flat):
        if i < n_images:
            img = arrays[i]
            img_cmap = "gray" if (cmap is None and len(img.shape) == 2) else cmap
            ax.imshow(img, cmap=img_cmap, aspect=aspect)
            ax.set_title(subtitles[i], fontsize=subtitlesize)
            ax.set_xlabel(xlabels[i], fontsize=label_size)
        ax.set_xticks([])
        ax.set_yticks([])
        # NOTE(review): axis('off') is applied to every cell, which presumably
        # also hides the x-labels and row labels set here - confirm intent.
        ax.axis('off')

        if i % n_cols == 0:
            ax.set_ylabel(row_labels[i // n_cols], fontsize=label_size)

    fig.tight_layout()
    fig.suptitle(title, y=ysuptitle, fontsize=titlesize)

    if return_as_pil:
        fig.canvas.draw()
        # tostring_rgb() was deprecated in Matplotlib 3.8 and later removed;
        # read the RGBA buffer instead and drop the alpha channel.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        pil_image = Image.fromarray(rgba).convert("RGB")
        plt.close(fig)
        return pil_image

    if save:
        fig.savefig(savepath, bbox_inches='tight')
        plt.close(fig)
    elif show:
        plt.show()
        plt.close(fig)
|
|
|
|
def add_frame_around_image(image, color="red", thickness=3):
    """Return an RGB copy of ``image`` with a rectangle outlined along its border.

    Args:
        image: PIL image; it is copied, so the input is left untouched.
        color: Outline color passed to ``ImageDraw.rectangle``.
        thickness: Outline width in pixels.
    """
    framed = image.copy().convert("RGB")
    width, height = framed.size[0], framed.size[1]
    ImageDraw.Draw(framed).rectangle(
        [0, 0, width - 1, height - 1], outline=color, width=thickness,
    )
    return framed
|
|
|
|
def add_text_to_image(image, text):
    """Draw ``text`` at the top-left corner of ``image`` and return the image.

    The drawing happens on ``image`` itself (no copy is made). Text is black
    when the mean pixel value exceeds 200 and white otherwise.
    """
    from PIL import ImageDraw, ImageFont

    pen = ImageDraw.Draw(image)
    default_font = ImageFont.load_default()

    is_bright = np.mean(image) > 200
    fill = (0, 0, 0) if is_bright else (255, 255, 255)
    pen.text((0, 0), text, fill, font=default_font)

    return image
|
|
|
|
def show_keypoint_matches(
        img1, kp1, img2, kp2, matches,
        K=10, figsize=(10, 5), drawMatches_args=dict(matchesThickness=3, singlePointColor=(0, 0, 0)),
        choose_matches="random",
    ):
    """Displays matches found in the pair of images.

    Args:
        img1, kp1: First image and its keypoints, as passed to ``cv2.drawMatches``.
        img2, kp2: Second image and its keypoints.
        matches: Sequence of matches between ``kp1`` and ``kp2``.
        K: Number of matches to draw for ``"random"`` / ``"topk"``; capped at
            ``len(matches)``.
        figsize: Figure size for the displayed result.
        drawMatches_args: Extra keyword arguments for ``cv2.drawMatches``.
            ``matchColor`` and ``singlePointColor`` are always overridden.
        choose_matches: ``"random"`` (sample without replacement), ``"all"``,
            or ``"topk"`` (first ``K`` entries of ``matches``).

    Returns:
        The image produced by ``cv2.drawMatches``.

    Raises:
        ValueError: If ``choose_matches`` is not one of the supported values.
    """
    if choose_matches == "random":
        # Sample indices without replacement so no match is drawn twice and
        # K never exceeds the number of available matches.
        K = min(K, len(matches))
        picked = np.random.choice(len(matches), size=K, replace=False)
        selected_matches = [matches[i] for i in picked]
    elif choose_matches == "all":
        selected_matches = matches
        K = len(matches)
    elif choose_matches == "topk":
        selected_matches = matches[:K]
        K = len(selected_matches)
    else:
        raise ValueError(f"Unknown value for choose_matches: {choose_matches}")

    # Copy before updating: the default dict is shared across calls and must
    # not be mutated (nor should a dict supplied by the caller).
    draw_args = dict(drawMatches_args or {})
    draw_args.update({"matchColor": -1, "singlePointColor": (100, 100, 100)})

    img3 = cv2.drawMatches(img1, kp1, img2, kp2, selected_matches, outImg=None, **draw_args)
    show_single_image(
        img3,
        figsize=figsize,
        title=f"[{choose_matches.upper()}] Selected K = {K} matches between the pair of images.",
    )
    return img3
|
|
|
|
def draw_kps_on_image(image: np.ndarray, kps: np.ndarray, color=COLORS["red"], radius=5, thickness=-1, return_as="PIL"):
    """
    Draw keypoints on image.

    Args:
        image: Image to draw keypoints on (array or PIL image). The input is
            not modified; drawing happens on a copy.
        kps: Keypoints to draw. Note these should be in (x, y) format.
        color: A color name (str), a single color tuple applied to every
            keypoint, or a list with one entry per keypoint (each entry a
            color name or a color tuple/sequence).
        radius: Circle radius passed to ``cv2.circle``.
        thickness: Circle thickness passed to ``cv2.circle``.
        return_as: ``"PIL"`` returns a PIL image; anything else returns the array.

    Returns:
        The annotated image as a PIL image or a numpy array.

    Raises:
        TypeError: If ``color`` is not a str, tuple or list.
        AssertionError: If a color list does not match ``kps`` in length.
    """
    # Import the submodule explicitly instead of relying on ``PIL.ImageColor``
    # having been loaded as a side effect of another import.
    from PIL import ImageColor

    if isinstance(image, Image.Image):
        image = np.asarray(image)

    if isinstance(color, str):
        colors = [ImageColor.getrgb(color)] * len(kps)
    elif isinstance(color, tuple):
        colors = [color] * len(kps)
    elif isinstance(color, list):
        colors = [ImageColor.getrgb(c) if isinstance(c, str) else tuple(c) for c in color]
    else:
        raise TypeError(f"Unsupported type for color: {type(color).__name__}")
    assert len(colors) == len(kps), f"Number of colors ({len(colors)}) must be equal to number of keypoints ({len(kps)})"

    # Copy once up front (the source array may be read-only) rather than on
    # every iteration.
    canvas = image.copy()
    for kp, c in zip(kps, colors):
        canvas = cv2.circle(
            canvas, (int(kp[0]), int(kp[1])), radius=radius, color=c, thickness=thickness)

    if return_as == "PIL":
        return Image.fromarray(canvas)

    return canvas
|
|
|
|
def get_concat_h(im1, im2):
    """Paste ``im1`` and ``im2`` side by side on a new RGB canvas.

    The canvas is as tall as ``im1`` and as wide as both images combined.
    """
    canvas_size = (im1.width + im2.width, im1.height)
    canvas = Image.new('RGB', canvas_size)
    for tile, x_offset in ((im1, 0), (im2, im1.width)):
        canvas.paste(tile, (x_offset, 0))
    return canvas
|
|
|
|
def get_concat_v(im1, im2):
    """Stack ``im1`` above ``im2`` on a new RGB canvas.

    The canvas is as wide as ``im1`` and as tall as both images combined.
    """
    canvas_size = (im1.width, im1.height + im2.height)
    canvas = Image.new('RGB', canvas_size)
    for tile, y_offset in ((im1, 0), (im2, im1.height)):
        canvas.paste(tile, (0, y_offset))
    return canvas
|
|
|
|
def show_images_with_keypoints(images: list, kps: list, radius=15, color=(0, 220, 220), figsize=(10, 8)):
    """Draw each image's keypoints on it and show the results in a single row.

    Args:
        images: Images to annotate.
        kps: One keypoint collection per image (same length as ``images``).
        radius: Keypoint radius forwarded to ``draw_kps_on_image``.
        color: Keypoint color forwarded to ``draw_kps_on_image``.
        figsize: Figure size forwarded to ``show_grid_of_images``.
    """
    assert len(images) == len(kps)

    annotated = [
        draw_kps_on_image(img, pts, radius=radius, color=color, return_as="PIL")
        for img, pts in zip(images, kps)
    ]

    show_grid_of_images(annotated, n_cols=len(images), figsize=figsize)
|
|
|
|
def set_latex_fonts(usetex=True, fontsize=14, show_sample=False, **kwargs):
    """Update matplotlib rcParams for serif / LaTeX-style text (best effort).

    Args:
        usetex: Value for ``text.usetex``.
        fontsize: Value for ``font.size``.
        show_sample: If True, draw a small sample plot with a math title.
        **kwargs: Additional rcParams entries merged into the update.

    Any failure is reported on stdout and otherwise ignored, so callers can
    proceed without LaTeX fonts.
    """
    try:
        plt.rcParams.update({
            "text.usetex": usetex,
            "font.family": "serif",
            "font.size": fontsize,
            **kwargs,
        })
        if show_sample:
            plt.figure()
            plt.title("Sample $y = x^2$")
            plt.plot(np.arange(0, 10), np.arange(0, 10)**2, "--o")
            plt.grid()
            plt.show()
    except Exception as err:
        # Catch Exception rather than using a bare except so that
        # KeyboardInterrupt / SystemExit still propagate.
        print("Failed to setup LaTeX fonts. Proceeding without.")
        print(f"Reason: {err}")
|
|
|
|
|
|
def plot_2d_points(
        list_of_points_2d,
        colors=None,
        sizes=None,
        markers=None,
        alpha=0.75,
        h=256,
        w=256,
        ax=None,
        save=True,
        savepath="test.png",
    ):
    """Scatter several 2-D point sets on one axes.

    Args:
        list_of_points_2d: Point sets; each is indexed as ``[:, 0]`` (x) and
            ``[:, 1]`` (y).
        colors: One color per point set (default ``"gray"``).
        sizes: One marker size per point set (default ``0.1``).
        markers: One marker style per point set (default ``"o"``).
        alpha: Transparency of the markers.
        h: Upper y-limit set on the axes.
        w: Upper x-limit set on the axes.
        ax: Axes to draw on; a new figure is created when None.
        save: If True, save the current figure to ``savepath``.
        savepath: Output path used when ``save`` is True.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1)
    ax.set_xlim([0, w])
    ax.set_ylim([0, h])

    n_sets = len(list_of_points_2d)
    if sizes is None:
        sizes = [0.1] * n_sets
    if colors is None:
        colors = ["gray"] * n_sets
    if markers is None:
        markers = ["o"] * n_sets

    for pts, point_color, point_size, marker in zip(list_of_points_2d, colors, sizes, markers):
        xs, ys = pts[:, 0], pts[:, 1]
        ax.scatter(xs, ys, s=point_size, alpha=alpha, color=point_color, marker=marker)

    if save:
        plt.savefig(savepath, bbox_inches='tight')
|
|
|
|
def plot_2d_points_on_image(
        image,
        img_alpha=1.0,
        ax=None,
        list_of_points_2d=[],
        scatter_args=dict(),
    ):
    """Show ``image`` and scatter 2-D point sets on top of it.

    Args:
        image: Image passed to ``Axes.imshow``.
        img_alpha: Transparency of the image.
        ax: Axes to draw on; a new figure is created when None.
        list_of_points_2d: Point sets forwarded to ``plot_2d_points``.
        scatter_args: Extra keyword arguments for ``plot_2d_points``. A
            ``"save"`` entry is always overridden with False. The dict is
            not modified.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1)
    ax.imshow(image, alpha=img_alpha)

    # Build a fresh dict instead of writing "save" into the argument: the
    # default dict is shared across calls, and a caller's dict should not
    # change behind their back.
    point_kwargs = {**scatter_args, "save": False}
    plot_2d_points(list_of_points_2d, ax=ax, **point_kwargs)

    # Reverse the y-limits left by plot_2d_points; presumably this restores
    # the top-down orientation expected for images - TODO confirm.
    ax.set_ylim(ax.get_ylim()[::-1])
|
|
|
|
def compare_landmarks(
        image, ground_truth_landmarks, v2d, predicted_landmarks,
        save=False, savepath="compare_landmarks.png", num_kps_to_show=-1,
        show_matches=True,
    ):
    """Plot ground-truth vs. predicted landmarks on an image in three panels.

    Panels: (1) ground-truth landmarks, (2) ``v2d`` points together with the
    predicted landmarks, (3) ground-truth and predicted landmarks overlaid on
    a faded image, optionally joined by line segments.

    Args:
        image: Image shown in every panel.
        ground_truth_landmarks: 2-D points, indexed as ``[:, 0]`` / ``[:, 1]``.
        v2d: 2-D points drawn in the second panel.
        predicted_landmarks: 2-D points, indexed like the ground truth.
        save: If True, save the figure to ``savepath``.
        savepath: Output path used when ``save`` is True.
        num_kps_to_show: Number of leading landmarks shown in the third
            panel. A negative value (the default) or None shows all of them.
        show_matches: If True, connect each shown ground-truth landmark to
            its prediction.
    """
    # A negative count means "all". Using it directly as a slice bound would
    # drop trailing landmarks, and range(-1) would draw no match lines.
    show_all = num_kps_to_show is None or num_kps_to_show < 0
    stop = None if show_all else num_kps_to_show
    gt_subset = ground_truth_landmarks[:stop]
    pred_subset = predicted_landmarks[:stop]

    fig, axes = plt.subplots(1, 3, figsize=(11, 4))

    ax = axes[0]
    plot_2d_points_on_image(
        image,
        list_of_points_2d=[ground_truth_landmarks],
        scatter_args=dict(sizes=[15], colors=["limegreen"]),
        ax=ax,
    )
    ax.set_title("GT landmarks", fontsize=12)

    ax = axes[1]
    plot_2d_points_on_image(
        image,
        list_of_points_2d=[v2d, predicted_landmarks],
        scatter_args=dict(sizes=[0.08, 15], markers=["o", "x"], colors=["royalblue", "red"]),
        ax=ax,
    )
    ax.set_title("Projection of predicted mesh", fontsize=12)

    ax = axes[2]
    plot_2d_points_on_image(
        image,
        list_of_points_2d=[gt_subset, pred_subset],
        scatter_args=dict(sizes=[15, 15], markers=["o", "x"], colors=["limegreen", "red"]),
        ax=ax,
        img_alpha=0.5,
    )
    ax.set_title("GT and predicted landmarks", fontsize=12)

    if show_matches:
        # zip stops at the shorter set, so mismatched lengths cannot index
        # out of range.
        for gt_pt, pred_pt in zip(gt_subset, pred_subset):
            x_values = [gt_pt[0], pred_pt[0]]
            y_values = [gt_pt[1], pred_pt[1]]
            ax.plot(x_values, y_values, color="yellow", markersize=1, linewidth=2.)

    fig.tight_layout()
    if save:
        fig.savefig(savepath, bbox_inches="tight")
| |
|
|
|
|
def plot_historgam_values(
        X, display_vals=False,
        bins=50, figsize=(8, 5),
        show_mean=True,
        xlabel=None, ylabel=None,
        ax=None, title=None, show=False,
        **kwargs,
    ):
    """Plot a histogram of ``X``.

    Args:
        X: Values to histogram.
        display_vals: Accepted for backward compatibility; currently has no
            effect.
        bins: Forwarded to ``Axes.hist``.
        figsize: Figure size used when a new figure is created.
        show_mean: If True, append the mean of ``X`` to the title.
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.
        ax: Axes to draw on; a new figure is created when None.
        title: Plot title (defaults to ``"Histogram of values"``).
        show: If True, call ``plt.show()`` and return None; otherwise return
            the axes.
        **kwargs: Extra keyword arguments for ``Axes.hist``.

    Returns:
        The axes when ``show`` is False, otherwise None.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)

    ax.hist(X, bins=bins, **kwargs)
    if title is None:
        title = "Histogram of values"

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.grid(alpha=0.3)

    if show_mean:
        # Raw f-string: "\m" is not a valid escape in a normal string literal.
        mean_string = rf"$\mu$: {np.mean(X):.2f}"
        ax.set_title(title + f" ({mean_string}) ")
    else:
        ax.set_title(title)

    if show:
        plt.show()
        return None
    return ax
|
|
|
|
| """Helper functions for all kinds of 2D/3D visualization""" |
def bokeh_2d_scatter(x, y, desc, figsize=(700, 700), colors=None, use_nb=False, title="Bokeh scatter plot"):
    """Show an interactive Bokeh scatter plot with hover tooltips.

    Args:
        x: x-coordinates of the points.
        y: y-coordinates of the points.
        desc: Per-point description shown in the hover tooltip.
        figsize: (width, height) of the plot.
        colors: Per-point fill colors. When None, one color is picked at
            random from a fixed list and used for every point.
        use_nb: If True, call ``output_notebook()`` first.
        title: Plot title.
    """
    from bokeh.io import output_notebook
    from bokeh.models import HoverTool
    from bokeh.plotting import ColumnDataSource, figure, show

    if use_nb:
        output_notebook()

    if colors is None:
        palette = ["red", "green", "blue", "yellow", "pink", "black", "gray"]
        colors = [np.random.choice(palette)] * len(x)

    point_data = dict(x=x, y=y, desc=desc, color=colors)
    source = ColumnDataSource(data=point_data)

    tooltip_rows = [
        ("index", "$index"),
        ("(x, y)", "($x, $y)"),
        ("Desc", "@desc"),
    ]
    hover = HoverTool(tooltips=tooltip_rows)

    plot_w, plot_h = figsize[0], figsize[1]
    p = figure(plot_width=plot_w, plot_height=plot_h, tools=[hover], title=title)
    p.circle('x', 'y', size=10, source=source, fill_color="color")
    show(p)
|
|
|
|
|
|
|
|
def bokeh_2d_scatter_new(
        df, x, y, hue, label, color_column=None, size_col=None,
        figsize=(650, 600), use_nb=False, title="Bokeh scatter plot",
        legend_loc="bottom_left", edge_color="black", audio_col=None,
    ):
    """Show an interactive Bokeh scatter plot built from DataFrame columns.

    Args:
        df: DataFrame holding the data. It is not modified.
        x: Column name for x-coordinates.
        y: Column name for y-coordinates.
        hue: Column name used for legend grouping (and for coloring when no
            color column is given).
        label: Column name shown as the description in the hover tooltip.
        color_column: Optional column name holding per-point colors. When it
            is missing from ``df``, one color is assigned per unique ``hue``.
        size_col: Optional column name holding per-point sizes (default 10).
        figsize: (width, height) of the plot.
        use_nb: If True, call ``output_notebook()`` first.
        title: Plot title.
        legend_loc: Legend location.
        edge_color: Outline color used for every marker.
        audio_col: Currently unused.

    Raises:
        AssertionError: If a required column is missing from ``df``.
    """
    from bokeh.plotting import figure, show, ColumnDataSource
    from bokeh.models import HoverTool
    from bokeh.io import output_notebook

    if use_nb:
        output_notebook()

    assert {x, y, hue, label}.issubset(set(df.keys()))

    # Colors and sizes are computed as standalone arrays; earlier versions
    # wrote "color" and "size" columns into the caller's DataFrame.
    if isinstance(color_column, str) and color_column in df.keys():
        color_values = df[color_column].values
    else:
        import matplotlib.colors as mcolors
        palette = list(mcolors.BASE_COLORS.keys()) + list(mcolors.TABLEAU_COLORS.values())
        palette = itertools.cycle(np.unique(palette))

        hue_to_color = {_hue: next(palette) for _hue in np.unique(df[hue].values)}
        color_values = df[hue].apply(lambda k: hue_to_color[k]).values

    if size_col is not None:
        assert isinstance(size_col, str) and size_col in df.keys()
        size_values = df[size_col].values
    else:
        size_values = np.full(len(df), 10.)

    source = ColumnDataSource(
        dict(
            x=df[x].values,
            y=df[y].values,
            hue=df[hue].values,
            label=df[label].values,
            color=color_values,
            edge_color=[edge_color] * len(df),
            sizes=size_values,
        )
    )

    hover = HoverTool(
        tooltips=[
            ("index", "$index"),
            ("(x, y)", "($x, $y)"),
            ("Desc", "@label"),
            ("Cluster", "@hue"),
        ]
    )

    p = figure(
        plot_width=figsize[0],
        plot_height=figsize[1],
        tools=["pan", "wheel_zoom", "box_zoom", "save", "reset", "help"] + [hover],
        title=title,
    )
    p.circle(
        'x', 'y', size="sizes",
        source=source, fill_color="color",
        legend_group="hue", line_color="edge_color",
    )
    p.legend.location = legend_loc
    p.legend.click_policy = "hide"

    show(p)
|
|
| |
| import torch |
def get_sentence_embedding(model, tokenizer, sentence):
    """Return the first-token hidden state of ``sentence`` as its embedding.

    Args:
        model: Callable taking the tokenizer output as keyword arguments and
            returning an object with a ``last_hidden_state`` tensor.
        tokenizer: Object with ``encode_plus(sentence, return_tensors="pt")``.
        sentence: Text to embed.

    Returns:
        ``last_hidden_state[0, 0]``, i.e. the hidden vector at position 0 of
        the single sequence in the batch.

    Raises:
        AssertionError: If the model output does not have batch size 1.
    """
    encoded = tokenizer.encode_plus(sentence, return_tensors="pt")

    with torch.no_grad():
        output = model(**encoded)

    last_hidden_state = output.last_hidden_state
    assert last_hidden_state.shape[0] == 1
    # The hidden size is deliberately not checked: the indexing below works
    # for any width, and a fixed 768 rejected every other model size.

    return last_hidden_state[0, 0]
|
|
|
|
def lighten_color(color, amount=0.5):
    """
    Lightens the given color by multiplying (1-luminosity) by the given amount.
    Input can be matplotlib color string, hex string, or RGB tuple.

    The resulting lightness is clamped to [0, 1] so that an ``amount``
    outside that range still yields valid RGB components.

    Examples:
    >> lighten_color('g', 0.3)
    >> lighten_color('#F034A3', 0.6)
    >> lighten_color((.3,.55,.1), 0.5)

    Returns:
        Tuple ``(r, g, b)`` produced by ``colorsys.hls_to_rgb``.
    """
    import colorsys
    import matplotlib.colors as mc

    try:
        base = mc.cnames[color]
    except (KeyError, TypeError):
        # KeyError: not a named color; TypeError: unhashable input (e.g. list).
        base = color
    hue, lightness, saturation = colorsys.rgb_to_hls(*mc.to_rgb(base))
    new_lightness = min(1.0, max(0.0, 1 - amount * (1 - lightness)))
    return colorsys.hls_to_rgb(hue, new_lightness, saturation)
|
|
|
|
def plot_histogram(df, col, ax=None, color="blue", title=None, xlabel=None, **kwargs):
    """Plot and show a histogram of one DataFrame column.

    Args:
        df: DataFrame holding the data.
        col: Name of the column to histogram.
        ax: Axes to draw on; a new figure is created when None.
        color: Bar color; a lightened variant is used for the bar edges.
        title: Plot title (defaults to ``"Histogram of <col>"``).
        xlabel: x-axis label (defaults to ``col``).
        **kwargs: Extra keyword arguments for ``Axes.hist``. When ``bins`` is
            a positive integer, custom x-ticks are set as well.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(5, 4))
    ax.grid(alpha=0.3)
    ax.set_xlabel(col if xlabel is None else xlabel)
    ax.set_ylabel("Frequency")
    ax.set_title(f"Histogram of {col}" if title is None else title)

    label = f"Mean: {np.round(df[col].mean(), 1)}"
    ax.hist(
        df[col].values, density=False, color=color,
        edgecolor=lighten_color(color, 0.1), label=label, **kwargs,
    )

    # Axes.hist also accepts strings and edge sequences for ``bins``; the
    # tick computation below only makes sense for a positive integer count.
    bins = kwargs.get("bins")
    if isinstance(bins, (int, np.integer)) and bins > 0:
        xticks = list(np.arange(bins)[::5])
        xticks += list(np.linspace(xticks[-1], int(df[col].max()), 5, dtype=int))
        ax.set_xticks(xticks)

    ax.legend()
    plt.show()
|
|
|
|
def beautify_ax(ax, title=None, titlesize=20, sizealpha=0.7, xlabel=None, ylabel=None):
    """Apply a light grid, axis labels and a title to ``ax``.

    Args:
        ax: Axes to style.
        title: Title text, drawn at ``titlesize``.
        titlesize: Font size of the title.
        sizealpha: Axis labels use ``sizealpha * titlesize`` as font size.
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.
    """
    ax.grid(alpha=0.3)
    axis_label_font = sizealpha * titlesize
    for apply_label, text in ((ax.set_xlabel, xlabel), (ax.set_ylabel, ylabel)):
        apply_label(text, fontsize=axis_label_font)
    ax.set_title(title, fontsize=titlesize)
|
|
|
|
|
|
|
|
def get_text_features(text: list, model, device, batch_size=16):
    """Encode a list of strings in batches with ``model.encode_text``.

    Args:
        text: Strings to encode.
        model: Model exposing ``to``, ``eval`` and ``encode_text``.
        device: Device the model and token batches are moved to.
        batch_size: Number of strings per batch.

    Returns:
        Per-batch features concatenated along axis 0 as a numpy array.
    """
    import clip

    chunk_starts = range(0, len(text), batch_size)
    chunks = [text[begin:begin + batch_size] for begin in chunk_starts]

    model = model.to(device)
    model = model.eval()

    collected = []
    progress = tqdm(chunks, desc="Getting text features", bar_format="{l_bar}{bar:20}{r_bar}")
    for chunk in progress:
        tokens = clip.tokenize(chunk).to(device)
        with torch.no_grad():
            encoded = model.encode_text(tokens)
        collected.append(encoded.cpu().numpy())

    return np.concatenate(collected, axis=0)
|
|
|
|
| from sklearn.manifold import TSNE |
def reduce_dim(X, method="tsne", perplexity=30, n_iter=1000):
    """Project ``X`` to two dimensions.

    Args:
        X: Input data passed to the reducer's ``fit_transform``.
        method: ``"tsne"``, ``"pca"`` or ``"umap"``.
        perplexity: t-SNE perplexity (only used for ``"tsne"``).
        n_iter: Accepted but not forwarded to any reducer.

    Returns:
        The 2-D embedding returned by ``fit_transform``.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    if method == "tsne":
        reducer = TSNE(
            n_components=2,
            perplexity=perplexity,
            init='pca',
            random_state=42,
        )
        return reducer.fit_transform(X)

    if method == "pca":
        from sklearn.decomposition import PCA
        return PCA(n_components=2).fit_transform(X)

    if method == "umap":
        import umap
        return umap.UMAP(random_state=42).fit_transform(X)

    raise ValueError(f"Unknown method {method}")
|
|
|
|
| from IPython.display import Video |
def show_video(video_path):
    """Return an embedded IPython ``Video`` (width 480) for ``video_path``.

    Raises:
        AssertionError: If ``video_path`` does not exist.
    """
    video_found = exists(video_path)
    assert video_found, f"Video path {video_path} does not exist"

    player = Video(video_path, embed=True, width=480)
    return player
| |
| |
|
|
|
|
|
|
|
|
def show_single_audio(filepath=None, data=None, rate=None, start=None, end=None, label="Sample audio"):
    """Build a widget holding a text label and an audio player.

    Exactly one of ``filepath`` or (``data`` and ``rate``) must be given.

    Args:
        filepath: Audio file, decoded with ``librosa.load`` using its defaults.
        data: Audio samples (used when ``filepath`` is None).
        rate: Sampling rate of ``data``.
        start: Optional clip start in seconds; defaults to the beginning.
        end: Optional clip end in seconds; defaults to the end of the audio.
        label: Text shown above the player (``"Sample audio"`` when None).

    Returns:
        A ``VBox`` containing the label widget and the audio output widget.

    Raises:
        AssertionError: If the input arguments are inconsistent.
    """
    if filepath is None:
        assert data is not None and rate is not None, "Either filepath or data and rate must be provided"
    else:
        assert data is None and rate is None, "Either filepath or data and rate must be provided"
        # Imported lazily so that raw-data usage does not require librosa.
        import librosa
        data, rate = librosa.load(filepath)

    # Each bound is handled independently; previously a lone ``start`` or a
    # lone ``end`` was silently ignored.
    first = 0 if start is None else max(int(start * rate), 0)
    last = len(data) if end is None else min(int(end * rate), len(data))
    clip = data[first:last]

    if label is None:
        label = "Sample audio"

    label_width = 400
    label_html = f'<div style="width: {label_width}px; word-wrap: break-word; overflow-wrap: break-word;">{label}</div>'
    label_widget = HTML(value=label_html)

    out = widgets.Output()
    with out:
        display(Audio(data=clip, rate=rate))
    return VBox([label_widget, out])
|
|
|
|
def show_single_audio_with_spectrogram(filepath=None, data=None, rate=None, label="Sample audio", figsize=(6, 2)):
    """Build an audio-player widget with a spectrogram output appended.

    Args:
        filepath: Audio file, decoded with ``librosa.load`` using its
            defaults. When None, ``data`` and ``rate`` must be given.
        data: Audio samples (used when ``filepath`` is None).
        rate: Sampling rate of ``data``.
        label: Text shown above the player.
        figsize: Size of the spectrogram figure.

    Returns:
        The ``VBox`` from ``show_single_audio`` with one extra child: the
        output widget under which the spectrogram figure is created.

    Raises:
        AssertionError: If neither ``filepath`` nor ``data``/``rate`` is given.
    """
    import librosa
    # specshow lives in a submodule that ``import librosa`` alone is not
    # guaranteed to load.
    import librosa.display

    if filepath is None:
        assert data is not None and rate is not None, "Either filepath or data and rate must be provided"
    else:
        data, rate = librosa.load(filepath)

    vbox = show_single_audio(data=data, rate=rate, label=label)

    spec_out = widgets.Output()
    D = librosa.stft(data)
    S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max)
    with spec_out:
        # NOTE(review): the figure is created here but never explicitly
        # shown or closed - confirm how it is expected to reach the widget.
        fig, ax = plt.subplots(figsize=figsize)
        librosa.display.specshow(
            S_db,
            ax=ax,
            x_axis='time',
            # Pass the real sampling rate so the time axis is correct for
            # audio that is not at specshow's default rate.
            sr=rate,
        )

    vbox.children += (spec_out,)
    return vbox
|
|
def show_spectrogram(audio_path=None, data=None, rate=None, figsize=(6, 2), ax=None, show=True):
    """Render a dB-scaled STFT spectrogram into an ipywidgets ``Output``.

    Args:
        audio_path: Audio file, decoded with ``librosa.load`` using its
            defaults. Must be None when ``data`` / ``rate`` are given.
        data: Audio samples (alternative to ``audio_path``).
        rate: Sampling rate of ``data``.
        figsize: Figure size used when a new figure is created.
        ax: Axes to draw on; when None a new figure is created and closed
            again once it has been drawn.
        show: Currently unused.

    Returns:
        An ``Output`` widget in which the figure has been displayed.

    Raises:
        AssertionError: If both ``audio_path`` and ``data``/``rate`` are given.
    """
    import librosa

    if data is None and rate is None:
        data, rate = librosa.load(audio_path)
    else:
        assert audio_path is None, "Either audio_path or data and rate must be provided"

    hop_length = 512
    D = librosa.stft(data, n_fft=2048, hop_length=hop_length, win_length=2048)
    S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max)

    # Remember whether the figure is ours *before* ``ax`` is replaced; testing
    # ``ax is None`` afterwards can never be true. For a caller-supplied axes,
    # take its figure so ``fig`` is always defined.
    owns_figure = ax is None
    if owns_figure:
        fig, ax = plt.subplots(1, 1, figsize=figsize)
    else:
        fig = ax.figure
    ax.imshow(S_db, origin='lower', aspect='auto', cmap='inferno')

    # Relabel the x ticks (frame indices) as seconds.
    xticks = ax.get_xticks()
    time_in_seconds = librosa.frames_to_time(xticks, sr=rate, hop_length=hop_length)
    ax.set_xticklabels(np.round(time_in_seconds, 1))
    ax.set_xlabel('Time')
    ax.set_yticks([])
    if owns_figure:
        plt.close(fig)

    spec_out = widgets.Output()
    with spec_out:
        display(fig)
    return spec_out
|
|
|
|
def show_single_video_and_spectrogram(
        video_path, audio_path,
        label="Sample video", figsize=(6, 2),
        width=480,
        show_spec_stats=False,
    ):
    """Build a video widget with the audio spectrogram appended below it.

    Args:
        video_path: Video file shown via ``show_single_video``.
        audio_path: Audio file, decoded with ``librosa.load`` using its defaults.
        label: Text label for the video.
        figsize: Size of the spectrogram figure.
        width: Display width of the video.
        show_spec_stats: If True, also append a line plot of the
            per-frequency-normalized spectrogram summed over frequency.

    Returns:
        The ``VBox`` from ``show_single_video`` with the extra output widgets
        appended to its children.
    """
    import librosa

    def _as_output(figure):
        # Close the figure, then display it inside a fresh Output widget.
        plt.close(figure)
        holder = widgets.Output()
        with holder:
            display(figure)
        return holder

    panel = show_single_video(video_path, label=label, width=width)

    samples, sample_rate = librosa.load(audio_path)
    hop = 512
    stft = librosa.stft(samples, n_fft=2048, hop_length=hop, win_length=2048)
    spec_db = librosa.amplitude_to_db(np.abs(stft), ref=np.max)

    spec_fig, spec_ax = plt.subplots(1, 1, figsize=figsize)
    spec_ax.imshow(spec_db, origin='lower', aspect='auto', cmap='inferno')

    # Relabel the x ticks (frame indices) as seconds.
    tick_frames = spec_ax.get_xticks()
    tick_seconds = librosa.frames_to_time(tick_frames, sr=sample_rate, hop_length=hop)
    spec_ax.set_xticklabels(np.round(tick_seconds, 1))
    spec_ax.set_xlabel('Time')
    spec_ax.set_yticks([])
    panel.children += (_as_output(spec_fig),)

    if not show_spec_stats:
        return panel

    eps = 1e-5
    row_mean = spec_db.mean(axis=1)[:, None]
    row_std = spec_db.std(axis=1)[:, None]
    summed_over_freq = ((spec_db - row_mean) / (row_std + eps)).sum(axis=0)

    stats_fig, stats_ax = plt.subplots(1, 1, figsize=(6, 2))
    stats_ax.grid(alpha=0.5)
    frame_ids = np.arange(len(summed_over_freq))
    seconds = np.round(
        librosa.frames_to_time(frame_ids, sr=sample_rate, hop_length=hop), 1,
    )
    stats_ax.plot(seconds, summed_over_freq)
    stats_ax.set_xlabel('Time')
    stats_ax.set_yticks([])
    panel.children += (_as_output(stats_fig),)

    return panel
|
|
|
|
def show_single_spectrogram(
        filepath=None,
        data=None,
        rate=None,
        start=None,
        end=None,
        ax=None,
        label="Sample spectrogram",
        figsize=(6, 2),
        xlabel="Time",
    ):
    """Draw a dB-scaled STFT spectrogram and return it wrapped in a ``VBox``.

    Args:
        filepath: Audio file to load at 22050 Hz. When None, ``data`` and
            ``rate`` must be given.
        data: Audio samples (used when ``filepath`` is None).
        rate: Sampling rate of ``data``; overridden with 22050 when loading
            from ``filepath``.
        start: Optional clip start in seconds (file input only).
        end: Optional clip end in seconds (file input only).
        ax: Axes to draw on; a new figure is created when None.
        label: Currently unused.
        figsize: Figure size used when a new figure is created.
        xlabel: Label for the x-axis.

    Returns:
        A ``VBox`` whose only child is the spectrogram output widget.

    Raises:
        AssertionError: If neither ``filepath`` nor ``data``/``rate`` is given.
    """
    import librosa
    # specshow lives in a submodule that ``import librosa`` alone is not
    # guaranteed to load.
    import librosa.display

    if filepath is None:
        assert data is not None and rate is not None, "Either filepath or data and rate must be provided"
    else:
        rate = 22050
        offset = start or 0
        # Subtract the resolved offset: ``end - start`` raised a TypeError
        # when only ``end`` was supplied.
        clip_duration = None if end is None else end - offset
        data, rate = librosa.load(filepath, sr=rate, offset=offset, duration=clip_duration)

    spec_out = widgets.Output()
    D = librosa.stft(data)
    S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max)

    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    with spec_out:
        librosa.display.specshow(
            S_db,
            ax=ax,
            x_axis='time',
            sr=rate,
        )
        ax.set_xlabel(xlabel)
        ax.margins(x=0)
        plt.subplots_adjust(wspace=0, hspace=0)

    return VBox([spec_out])
| |
|
|
|
|
| |
def show_single_video(filepath, label="Sample video", width=480, fix_resolution=True):
    """Build a widget holding an embedded video with a text label below it.

    Args:
        filepath: Video file to embed.
        label: Text shown under the video (``"Sample video"`` when None).
        width: Display width of the video.
        fix_resolution: If True, the display height is derived from ``width``
            with a 16:9 aspect ratio; otherwise no height is set.

    Returns:
        A ``VBox`` containing the video output widget and the label widget.
    """
    caption = "Sample video" if label is None else label

    if fix_resolution:
        aspect_ratio = 16. / 9.
        height = int(width * (1 / aspect_ratio))
    else:
        height = None

    label_width = width - 20
    label_html = f'<div style="width: {label_width}px; word-wrap: break-word; overflow-wrap: break-word; line-height: 1.2; margin: 0; padding: 0;">{caption}</div>'
    caption_widget = HTML(value=label_html)

    player_out = widgets.Output()
    with player_out:
        display(Video(filepath, embed=True, width=width, height=height))

    return VBox([player_out, caption_widget])
|
|
|
|
|
|
def color_text(text: str, color: str):
    """Return an ipywidgets ``HTML`` widget showing ``text`` in ``color``.

    Args:
        text: Text to display; inserted into the markup as-is.
        color: Value used for the ``color`` attribute of the ``<font>`` tag.
    """
    from ipywidgets import widgets
    # Close the <font> element; the opening tag used to be left dangling.
    return widgets.HTML(value=f"<font color='{color}'>{text}</font>")
|
|
|
|
def show_colored_text(text: list, colors: list):
    """
    Display text fragments side by side, each in its own color.

    Args:
        text: list of strings
        colors: list of colors, paired with ``text`` by position

    Each (string, color) pair becomes one ``color_text`` widget; the widgets
    are placed in an ``HBox`` which is displayed. Nothing is returned.
    """
    from IPython.display import display

    pieces = []
    for fragment, shade in zip(text, colors):
        pieces.append(color_text(fragment, shade))

    display(HBox(pieces))
|
|
|
|
|
|
|
|
def show_single_image_sequence(
        filepath, n_frames=4, label="Sample image sequence",
        width=480, fix_resolution=True, max_width=1000, reverse=False,
    ):
    """Show evenly spaced frames of a video as one concatenated image.

    Args:
        filepath: Video file read with decord.
        n_frames: Number of frames to sample (capped at the video length).
        label: Optional text shown above the frames.
        width: Not used for the displayed result.
        fix_resolution: Not used for the displayed result.
        max_width: The concatenated image is downscaled to this width when
            it is wider.
        reverse: If True, the sampled frames are shown in reverse order.

    Returns:
        A ``VBox`` of label and image output when ``label`` is given,
        otherwise just the image output widget.
    """
    import decord
    from decord import VideoReader
    decord.bridge.set_bridge('native')

    # NOTE(review): the original also computed a width-limited <div> wrapper
    # for the label but passed the raw label to HTML(); that is preserved
    # here - confirm whether the wrapper was meant to be used.
    caption_widget = HTML(value=label) if label is not None else None

    frames_out = widgets.Output()

    reader = VideoReader(filepath, num_threads=1)
    n_frames = min(n_frames, len(reader))
    picks = np.linspace(0, len(reader) - 1, n_frames)
    if reverse:
        picks = picks[::-1]
    frame_images = [Image.fromarray(arr) for arr in reader.get_batch(picks).asnumpy()]

    strip = concat_images_with_border(frame_images)

    strip_w, strip_h = strip.size[0], strip.size[1]
    if strip_w > max_width:
        strip = strip.resize((max_width, int(max_width * strip_h / strip_w)))

    with frames_out:
        display(strip)

    if caption_widget is None:
        return frames_out
    return VBox([caption_widget, frames_out])
|
|
|
|
def show_grid_of_audio(files, starts=None, ends=None, labels=None, ncols=None, show_spec=False):
    """Display audio players for several files in a grid of widgets.

    Args:
        files: Audio file paths; every path must exist.
        starts: Optional per-file clip start (seconds).
        ends: Optional per-file clip end (seconds).
        labels: Optional per-file labels.
        ncols: Number of players per row (default 3).
        show_spec: If True, append a spectrogram under each player. The
            spectrogram is computed from the file path alone, i.e. it is not
            restricted to ``starts`` / ``ends``.

    Raises:
        AssertionError: If a file is missing or the per-file lists differ in
            length from ``files``.
    """
    for f in files:
        assert os.path.exists(f), f"File {f} does not exist."

    nfiles = len(files)
    if labels is None:
        labels = [None] * nfiles
    if starts is None:
        starts = [None] * nfiles
    if ends is None:
        ends = [None] * nfiles

    # Validate every per-file list up front (as the video grid helpers do)
    # instead of failing with an IndexError part-way through the display.
    assert len(files) == len(labels) == len(starts) == len(ends)

    if ncols is None:
        ncols = 3

    for row_start in range(0, nfiles, ncols):
        row_widgets = []
        for idx in range(row_start, min(row_start + ncols, nfiles)):
            vbox = show_single_audio(
                filepath=files[idx], label=labels[idx], start=starts[idx], end=ends[idx],
            )
            if show_spec:
                spec_box = show_spectrogram(files[idx], figsize=(3.6, 1))
                vbox.children += (spec_box,)
            row_widgets.append(vbox)
        display(HBox(row_widgets))
|
|
|
|
def show_grid_of_videos(
        files,
        cut=False,
        starts=None,
        ends=None,
        labels=None,
        ncols=None,
        width_overflow=False,
        show_spec=False,
        width_of_screen=1000,
    ):
    """Display several videos in a grid of widgets.

    Args:
        files: Video file paths; every path must exist.
        cut: If True, each video is first trimmed to its ``[start, end]``
            range and written to ``~/tmp/clip_<i>.mp4``. Forced to True when
            both ``starts`` and ``ends`` are given.
        starts: Optional per-file clip start (seconds); None means 0.
        ends: Optional per-file clip end (seconds); None means
            ``duration - 1``.
        labels: Optional per-file labels.
        ncols: Number of videos per row (default 3).
        width_overflow: If True, each video is 280 wide; otherwise
            ``width_of_screen // ncols``.
        show_spec: If True, append a spectrogram computed from the video file.
        width_of_screen: Total width shared by the videos in a row.

    Raises:
        AssertionError: If a file is missing or the per-file lists differ in
            length from ``files``.
    """
    for f in files:
        assert os.path.exists(f), f"File {f} does not exist."

    nfiles = len(files)
    if labels is None:
        labels = [None] * nfiles
    if starts is not None and ends is not None:
        cut = True
    if starts is None:
        starts = [None] * nfiles
    if ends is None:
        ends = [None] * nfiles

    assert len(files) == len(labels) == len(starts) == len(ends)

    if cut:
        # Imported only when trimming so plain display does not need moviepy.
        from moviepy.editor import VideoFileClip

        tmp_dir = os.path.join(os.path.expanduser("~"), "tmp")
        # Writing the clips fails if the directory is missing.
        os.makedirs(tmp_dir, exist_ok=True)

        cut_files = []
        for i, (f, start, end) in enumerate(zip(files, starts, ends)):
            tmp_f = os.path.join(tmp_dir, f"clip_{i}.mp4")
            video = VideoFileClip(f)
            try:
                start = 0 if start is None else start
                end = video.duration - 1 if end is None else end
                video.subclip(start, end).write_videofile(tmp_f, logger=None, verbose=False)
            finally:
                # Release the reader held by the clip.
                video.close()
            cut_files.append(tmp_f)
        files = cut_files

    if ncols is None:
        ncols = 3

    # ``width_of_screen`` is honoured as passed; it used to be overwritten
    # with a hard-coded 1000.
    if not width_overflow:
        width_of_single_video = width_of_screen // ncols
    else:
        width_of_single_video = 280

    for row_start in range(0, nfiles, ncols):
        row_widgets = []
        for idx in range(row_start, min(row_start + ncols, nfiles)):
            file, label = files[idx], labels[idx]
            if not show_spec:
                vbox = show_single_video(file, label, width_of_single_video)
            else:
                vbox = show_single_video_and_spectrogram(file, file, width=width_of_single_video, label=label)
            row_widgets.append(vbox)
        display(HBox(row_widgets))
|
|
|
|
def convert_video_to_gif(file, tmp_f, reverse=False):
    """Write the video at ``file`` as a GIF to ``tmp_f``.

    Args:
        file: Source video path.
        tmp_f: Destination GIF path.
        reverse: If True, apply ``vfx.time_mirror`` before writing.

    Returns:
        ``tmp_f``.
    """
    from moviepy.editor import VideoFileClip, vfx

    source = VideoFileClip(file)
    try:
        clip = source.fx(vfx.time_mirror) if reverse else source
        clip.write_gif(tmp_f, logger=None, verbose=False)
    finally:
        # Release the reader held by the clip even if writing fails.
        source.close()
    return tmp_f
|
|
|
|
|
|
def show_grid_of_gifs(files, labels=None, ncols=None, width_of_screen=1000):
    """Display several GIF files in a grid of widgets.

    Args:
        files: GIF file paths; every path must exist.
        labels: Optional per-file labels.
        ncols: Number of GIFs per row (default 3).
        width_of_screen: Total width; each GIF gets ``width_of_screen // ncols``.
    """
    for path in files:
        assert os.path.exists(path), f"File {path} does not exist."

    captions = [None] * len(files) if labels is None else labels
    per_row = 3 if ncols is None else ncols
    tile_width = width_of_screen // per_row

    total = len(files)
    for row_start in range(0, total, per_row):
        row_end = min(row_start + per_row, total)
        tiles = [
            show_single_gif(files[k], captions[k], tile_width)
            for k in range(row_start, row_end)
        ]
        display(HBox(tiles))
|
|
|
|
def show_single_gif(file, label, width_of_single_video):
    """Build an image widget for a GIF file, optionally with a label below.

    Args:
        file: Path of the GIF; its raw bytes become the widget value.
        label: Optional label; converted with ``str`` when given.
        width_of_single_video: Width set on the image widget.

    Returns:
        A ``VBox`` of image and label when ``label`` is given, otherwise the
        image widget alone.
    """
    from ipywidgets import Image as WidgetImage, VBox, Label

    # Read through a context manager so the file handle is always closed.
    with open(file, 'rb') as handle:
        gif_bytes = handle.read()
    image_widget = WidgetImage(value=gif_bytes, width=width_of_single_video)

    if label is None:
        return image_widget
    return VBox([image_widget, Label(value=str(label))])
|
|
|
|
def show_grid_of_videos_as_gifs(files, labels=None, reverse=None,ncols=None, width_of_screen=1000):
    """Convert videos to GIFs under ``~/tmp`` and display them in a grid.

    Args:
        files: Video file paths.
        labels: Optional per-file labels.
        reverse: Optional per-file flags; True plays that clip backwards.
        ncols: Number of GIFs per row, forwarded to ``show_grid_of_gifs``.
        width_of_screen: Total width, forwarded to ``show_grid_of_gifs``.

    Returns:
        List of the written GIF paths (``~/tmp/clip_<i>.gif``).
    """
    if reverse is None:
        reverse = [False] * len(files)

    gif_dir = os.path.join(os.path.expanduser("~"), "tmp")
    os.makedirs(gif_dir, exist_ok=True)

    # convert_video_to_gif returns the destination path it was given.
    gif_paths = [
        convert_video_to_gif(
            source, os.path.join(gif_dir, f"clip_{index}.gif"), reverse=reverse[index],
        )
        for index, source in enumerate(files)
    ]

    show_grid_of_gifs(gif_paths, labels, ncols, width_of_screen)

    return gif_paths
|
|
|
|
def show_grid_of_image_sequences(
        files,
        cut=False,
        starts=None,
        ends=None,
        labels=None,
        ncols=None,
        width_overflow=False,
        show_spec=False,
        width_of_screen=1200,
        n_frames=4,
    ):
    """Display frame strips of several videos in a grid of widgets.

    Args:
        files: Video file paths; every path must exist.
        cut: If True, each video is first trimmed to its ``[start, end]``
            range and written to ``~/tmp/clip_<i>.mp4``. Forced to True when
            both ``starts`` and ``ends`` are given.
        starts: Optional per-file clip start (seconds); None means 0.
        ends: Optional per-file clip end (seconds); None means
            ``duration - 1``.
        labels: Optional per-file labels.
        ncols: Number of strips per row (default 3).
        width_overflow: If True, the per-frame width is 280; otherwise
            ``width_of_screen // ncols``.
        show_spec: Not supported; True raises ``NotImplementedError``.
        width_of_screen: Total width used to derive the per-frame width.
        n_frames: Number of frames per strip; each strip may be up to
            ``per-frame width * n_frames`` wide.

    Raises:
        AssertionError: If a file is missing or the per-file lists differ in
            length from ``files``.
        NotImplementedError: If ``show_spec`` is True.
    """
    for f in files:
        assert os.path.exists(f), f"File {f} does not exist."

    nfiles = len(files)
    if labels is None:
        labels = [None] * nfiles
    if starts is not None and ends is not None:
        cut = True
    if starts is None:
        starts = [None] * nfiles
    if ends is None:
        ends = [None] * nfiles

    assert len(files) == len(labels) == len(starts) == len(ends)

    if cut:
        # Imported only when trimming so plain display does not need moviepy.
        from moviepy.editor import VideoFileClip

        tmp_dir = os.path.join(os.path.expanduser("~"), "tmp")
        # Writing the clips fails if the directory is missing.
        os.makedirs(tmp_dir, exist_ok=True)

        cut_files = []
        for i, (f, start, end) in enumerate(zip(files, starts, ends)):
            tmp_f = os.path.join(tmp_dir, f"clip_{i}.mp4")
            video = VideoFileClip(f)
            try:
                start = 0 if start is None else start
                end = video.duration - 1 if end is None else end
                video.subclip(start, end).write_videofile(tmp_f, logger=None, verbose=False)
            finally:
                # Release the reader held by the clip.
                video.close()
            cut_files.append(tmp_f)
        files = cut_files

    if ncols is None:
        ncols = 3

    # NOTE(review): ``width_of_screen`` is honoured as passed. It used to be
    # overwritten with a hard-coded 1000, which made the parameter (default
    # 1200) ineffective - confirm the wider default is acceptable.
    if not width_overflow:
        width_of_single_video = width_of_screen // ncols
    else:
        width_of_single_video = 280

    for row_start in range(0, nfiles, ncols):
        row_widgets = []
        for idx in range(row_start, min(row_start + ncols, nfiles)):
            file, label = files[idx], labels[idx]
            if show_spec:
                raise NotImplementedError
            vbox = show_single_image_sequence(
                file, n_frames, label, max_width=(width_of_single_video * n_frames),
            )
            row_widgets.append(vbox)
        display(HBox(row_widgets))
|
|
|
|
def preview_video(fp, label="Sample video frames", mode="uniform", frames_to_show=6):
    """Show a row of frames sampled from a video.

    Args:
        fp (str): path to the video; must exist.
        label (str): title of the frame grid.
        mode (str): ``"all"`` shows every frame, ``"uniform"`` shows
            ``frames_to_show`` evenly spaced frames, ``"random"`` shows up to
            ``frames_to_show`` distinct random frames in temporal order.
        frames_to_show (int): number of frames for the uniform/random modes.

    Raises:
        ValueError: if ``mode`` is not one of the supported values.
    """
    from decord import VideoReader

    assert exists(fp), f"Video does not exist at {fp}"
    vr = VideoReader(fp)
    nfs = len(vr)

    if mode == "all":
        frame_indices = np.arange(nfs)
    elif mode == "uniform":
        frame_indices = np.linspace(0, nfs - 1, frames_to_show, dtype=int)
    elif mode == "random":
        # np.random.randint has no `replace` argument; sampling without
        # replacement is done with np.random.choice. The sample size is capped
        # at the number of available frames.
        n_samples = min(frames_to_show, nfs)
        frame_indices = np.sort(np.random.choice(nfs, size=n_samples, replace=False))
    else:
        raise ValueError(f"Unknown frame viewing mode {mode}.")

    images = vr.get_batch(frame_indices).asnumpy()
    show_grid_of_images(images, n_cols=len(frame_indices), title=label, figsize=(12, 2.3), titlesize=10)
|
|
|
|
def preview_multiple_videos(fps, labels, mode="uniform", frames_to_show=6):
    """Preview several videos, one frame strip per video.

    All paths are checked up front so nothing is rendered when any video is
    missing. ``mode`` and ``frames_to_show`` are forwarded to ``preview_video``.
    """
    for path in fps:
        assert exists(path), f"Video does not exist at {path}"

    for path, caption in zip(fps, labels):
        preview_video(path, caption, mode=mode, frames_to_show=frames_to_show)
|
|
|
|
|
|
def show_small_clips_in_a_video(
        video_path,
        clip_segments: list,
        width=360,
        labels=None,
        show_spec=False,
        resize=False,
    ):
    """Cut segments out of a video and show them side by side.

    Each segment is written to ``~/tmp/clip_<i>.mp4`` and rendered in its own
    output widget with a caption beneath it.

    Args:
        video_path (str): path to the source video.
        clip_segments (list): ``(start, end)`` pairs passed to ``subclip``.
        width (int): display width of each clip (and resize target width).
        labels (list, optional): one caption per segment; defaults to
            ``"Clip <i> [start : end]"``.
        show_spec (bool): if True, render each clip through
            ``show_single_video_and_spectrogram`` instead of a plain video.
        resize (bool): if True, resize the source video to ``width`` first.
    """
    # Only a missing module should trigger the fallback import path.
    try:
        from moviepy.editor import VideoFileClip
    except ImportError:
        from moviepy import VideoFileClip

    from ipywidgets import Layout
    from IPython.display import Video

    video = VideoFileClip(video_path)

    if resize:
        print("Resizing the video to width", width)
        video = video.resize(width=width)

    if labels is None:
        labels = [
            f"Clip {i+1} [{clip_segments[i][0]} : {clip_segments[i][1]}]" for i in range(len(clip_segments))
        ]
    else:
        assert len(labels) == len(clip_segments)

    tmp_dir = os.path.join(os.path.expanduser("~"), "tmp")
    os.makedirs(tmp_dir, exist_ok=True)
    tmp_clippaths = [f"{tmp_dir}/clip_{i}.mp4" for i in range(len(clip_segments))]

    iterator = tqdm_iterator(zip(clip_segments, tmp_clippaths), total=len(clip_segments), desc="Preparing clips")
    try:
        # Writing the files is a side effect, so a plain loop is used.
        for (start, end), clip_path in iterator:
            video.subclip(start, end).write_videofile(clip_path, logger=None, verbose=False)
    finally:
        # All clips are on disk now; release the source reader.
        video.close()

    hbox = []
    for clip_path, label in zip(tmp_clippaths, labels):
        vbox = widgets.Output()
        with vbox:
            if show_spec:
                display(
                    show_single_video_and_spectrogram(
                        clip_path, clip_path,
                        width=width, figsize=(4.4, 1.5),
                    )
                )
            else:
                display(Video(clip_path, embed=True, width=width))
            display(Label(label, layout=Layout(margin="-8px 0px 0px 0px")))
        hbox.append(vbox)
    display(HBox(hbox))
|
|
|
|
def show_single_video_and_audio(
        video_path, audio_path, label="Sample video and audio",
        start=None, end=None, width=360, sr=44100, show=True,
    ):
    """Show a video with a separately loaded audio track below it.

    When both ``start`` and ``end`` are given, the video is trimmed to
    ``~/tmp/clip_sample.mp4`` and the waveform is sliced to the same window.

    Args:
        video_path (str): path to the video file.
        audio_path (str): path to the audio file, loaded with librosa.
        label (str): caption; the ``[start : end]`` range is appended.
        start, end (float, optional): segment bounds in seconds.
        width (int): display width of the video in pixels.
        sr (int): sampling rate requested from ``librosa.load``.
        show (bool): if True display the widget, else return it.

    Returns:
        ipywidgets.Output when ``show`` is False, otherwise None.
    """
    from moviepy.editor import VideoFileClip
    from IPython.display import Video
    import librosa

    video = VideoFileClip(video_path)
    video_args = {"embed": True, "width": width}
    filepath = video_path

    try:
        audio_waveform, sr = librosa.load(audio_path, sr=sr)
        audio_args = {"data": audio_waveform, "rate": sr}

        if start is not None and end is not None:
            # The scratch directory must exist before moviepy writes to it.
            tmp_dir = os.path.join(os.path.expanduser("~"), "tmp")
            os.makedirs(tmp_dir, exist_ok=True)
            clip_path = os.path.join(tmp_dir, "clip_sample.mp4")
            video.subclip(start, end).write_videofile(clip_path, logger=None, verbose=False)
            filepath = clip_path

            # Slice the waveform to the same time window as the video.
            audio_waveform = audio_waveform[int(start * sr): int(end * sr)]
            audio_args["data"] = audio_waveform
    finally:
        # The clip is only needed for trimming; release its reader.
        video.close()

    out = widgets.Output()
    with out:
        label_text = f"{label} [{start} : {end}]"
        # Constrain the caption to the video width so long labels wrap.
        label_width = width - 20
        label_html = f'<div style="width: {label_width}px; word-wrap: break-word; overflow-wrap: break-word; line-height: 1.2; margin: 0; padding: 0;">{label_text}</div>'
        label_widget = HTML(value=label_html)
        display(label_widget)
        display(Video(filepath, **video_args))
        display(Audio(**audio_args))

    if show:
        display(out)
    else:
        return out
|
|
|
|
def plot_waveform(waveform, sample_rate, figsize=(10, 2), ax=None, skip=100, show=True, title=None):
    """Plot an audio waveform against time.

    Args:
        waveform: 1-D array of samples, or 2-D array laid out as
            (channels, frames). Objects exposing ``.numpy()`` (e.g. torch
            tensors) are converted first.
        sample_rate (int): samples per second, used to build the time axis.
        figsize (tuple): figure size when a new figure is created.
        ax (matplotlib Axes, optional): axis to draw on; only single-channel
            input is supported in that case and the y-range is fixed to
            [-0.05, 0.05].
        skip (int): plot every ``skip``-th sample to keep the plot light.
        show (bool): if True, call ``plt.show(block=False)``.
        title (str, optional): figure title (new-figure case only).

    Raises:
        ValueError: if the waveform is neither 1-D nor 2-D.
    """
    # Duck-typed conversion so the function does not require torch itself.
    if not isinstance(waveform, np.ndarray) and hasattr(waveform, "numpy"):
        waveform = waveform.numpy()
    waveform = np.asarray(waveform)

    # Normalise to (channels, frames) BEFORE building the time axis, so that
    # both the time axis and the subsampling act on the frame dimension.
    if waveform.ndim == 1:
        waveform = waveform.reshape(1, -1)
    elif waveform.ndim != 2:
        raise ValueError(f"Waveform has invalid shape {waveform.shape}")
    num_channels, num_frames = waveform.shape

    time_axis = (np.arange(num_frames) / sample_rate)[::skip]
    waveform = waveform[:, ::skip]

    if ax is None:
        figure, axes = plt.subplots(num_channels, 1, figsize=figsize)
        if num_channels == 1:
            axes = [axes]
        for c in range(num_channels):
            axes[c].plot(time_axis, waveform[c], linewidth=1)
            axes[c].grid(True)
            if num_channels > 1:
                axes[c].set_ylabel(f"Channel {c+1}")
        figure.suptitle(title)
    else:
        assert num_channels == 1
        ax.plot(time_axis, waveform[0], linewidth=1)
        ax.grid(True)
        ax.set_ylim(-0.05, 0.05)

    if show:
        plt.show(block=False)
|
|
|
|
def show_waveform_as_image(waveform, sr=16000):
    """Plot a waveform on a matplotlib figure and return it as a PIL.Image.

    Args:
        waveform: single-channel waveform accepted by ``plot_waveform``.
        sr (int): sample rate used for the time axis.

    Returns:
        PIL.Image.Image: RGB rendering of the (10, 2)-inch figure.
    """
    fig, ax = plt.subplots(figsize=(10, 2))
    plot_waveform(waveform, sr, ax=ax, show=False)
    fig.canvas.draw()
    # `tostring_rgb` is deprecated/removed in recent Matplotlib; read the
    # rendered RGBA buffer instead and drop the alpha channel.
    rgba = np.asarray(fig.canvas.buffer_rgba())
    img = Image.fromarray(rgba).convert("RGB")
    plt.close(fig)
    return img
|
|
|
|
def plot_raw_audio_signal_with_markings(signal: np.ndarray, markings: list,
        title: str = 'Raw audio signal with markings',
        figsize: tuple = (23, 4),
    ):
    """Plot a raw signal and mark the given x-positions with red lines.

    Args:
        signal (np.ndarray): values to plot against their sample index.
        markings (list): x-positions at which a vertical red line is drawn.
        title (str): plot title.
        figsize (tuple): figure size.
    """
    axis = plt.figure(figsize=figsize).gca()
    axis.grid()
    axis.plot(signal)

    for position in markings:
        axis.axvline(x=position, c='red')

    axis.set_xlabel('Time')
    axis.set_title(title)

    plt.show()
    plt.close()
|
|
|
|
def get_concat_h(im1, im2):
    """Place two images side by side on a new RGB canvas.

    The canvas takes its height from ``im1``.
    """
    canvas = Image.new('RGB', (im1.width + im2.width, im1.height))
    for tile, x_offset in ((im1, 0), (im2, im1.width)):
        canvas.paste(tile, (x_offset, 0))
    return canvas
|
|
|
|
def concat_images(images):
    """Concatenate images left to right on a new RGB canvas.

    The canvas height is taken from the first image.
    """
    first = images[0]
    total_width = sum(tile.width for tile in images)
    canvas = Image.new('RGB', (total_width, first.height))

    x_offset = 0
    for tile in images:
        canvas.paste(tile, (x_offset, 0))
        x_offset += tile.width
    return canvas
|
|
|
|
def concat_images_with_border(images, border_width=5, border_color="white"):
    """Concatenate images horizontally with a coloured gap between them.

    Args:
        images (list): PIL images; the first one sets the canvas height.
        border_width (int): gap in pixels between neighbouring images.
        border_color: fill colour of the canvas, visible in the gaps.
    """
    first = images[0]
    gaps = (len(images) - 1) * border_width
    canvas_width = sum(tile.width for tile in images) + gaps
    canvas = Image.new('RGB', (canvas_width, first.height), border_color)

    x_offset = 0
    for tile in images:
        canvas.paste(tile, (x_offset, 0))
        x_offset += tile.width + border_width
    return canvas
|
|
|
|
def concat_images_vertically(images):
    """Stack images top to bottom on a new RGB canvas.

    The canvas width is taken from the first image.
    """
    first = images[0]
    total_height = sum(tile.height for tile in images)
    canvas = Image.new('RGB', (first.width, total_height))

    y_offset = 0
    for tile in images:
        canvas.paste(tile, (0, y_offset))
        y_offset += tile.height
    return canvas
|
|
|
|
def concat_images_vertically_with_border(images, border_width=5, border_color="white"):
    """Stack images vertically with a coloured gap between them.

    Args:
        images (list): PIL images; the first one sets the canvas width.
        border_width (int): gap in pixels between neighbouring images.
        border_color: fill colour of the canvas, visible in the gaps.
    """
    first = images[0]
    gaps = (len(images) - 1) * border_width
    canvas_height = sum(tile.height for tile in images) + gaps
    canvas = Image.new('RGB', (first.width, canvas_height), border_color)

    y_offset = 0
    for tile in images:
        canvas.paste(tile, (0, y_offset))
        y_offset += tile.height + border_width
    return canvas
|
|
|
|
def get_concat_v(im1, im2):
    """Stack two images vertically on a new RGB canvas.

    The canvas takes its width from ``im1``.
    """
    canvas = Image.new('RGB', (im1.width, im1.height + im2.height))
    for tile, y_offset in ((im1, 0), (im2, im1.height)):
        canvas.paste(tile, (0, y_offset))
    return canvas
|
|
|
|
def set_latex_fonts(usetex=True, fontsize=14, show_sample=False, **kwargs):
    """Configure matplotlib to use LaTeX-style serif fonts (best effort).

    Args:
        usetex (bool): value for the ``text.usetex`` rcParam.
        fontsize (int): value for the ``font.size`` rcParam.
        show_sample (bool): if True, draw a small sample plot to check the
            resulting rendering.
        **kwargs: additional rcParams merged into the update.

    Any failure is reported with a message and otherwise ignored, so callers
    can proceed with the default fonts.
    """
    try:
        plt.rcParams.update({
            "text.usetex": usetex,
            "font.family": "serif",
            "font.serif": ["Computer Modern Roman"],
            "font.size": fontsize,
            **kwargs,
        })
        if show_sample:
            plt.figure()
            plt.title("Sample $y = x^2$")
            plt.plot(np.arange(0, 10), np.arange(0, 10)**2, "--o")
            plt.grid()
            plt.show()
    except Exception:
        # Deliberately best effort, but do not swallow KeyboardInterrupt or
        # SystemExit the way a bare `except:` would.
        print("Failed to setup LaTeX fonts. Proceeding without.")
|
|
|
|
def get_colors(num_colors, palette="jet"):
    """Sample ``num_colors`` evenly spaced colours from a matplotlib colormap.

    Returns:
        list: one colour per sample, as returned by calling the colormap.
    """
    colormap = plt.get_cmap(palette)
    positions = np.linspace(0, 1, num_colors)
    return [colormap(position) for position in positions]
|
|
|
|
def add_box_on_image(image, bbox, color="red", thickness=3, resized=False, fillcolor=None, fillalpha=0.2):
    """
    Adds bounding box on image.

    Args:
        image (PIL.Image): image; it is copied and converted to RGB.
        bbox (list): [xmin, ymin, xmax, ymax]
        color (str or tuple): outline colour, as a colour name/hex string or
            an already-resolved RGB(A) tuple.
        thickness (int): outline width in pixels.
        resized (bool): if True, the result is resized to 320x240.
        fillcolor (str or tuple, optional): fill colour of the box. A colour
            without an alpha channel gets ``fillalpha`` appended.
        fillalpha (float): opacity in [0, 1] used for 3-channel fill colours.

    Returns:
        PIL.Image: new image with the box drawn on it.
    """
    from PIL import ImageColor

    image = image.copy().convert("RGB")

    # Only strings need resolving; tuples are passed through, consistent with
    # the other drawing helpers in this module.
    if isinstance(color, str):
        color = ImageColor.getrgb(color)

    if fillcolor is not None:
        if isinstance(fillcolor, str):
            fillcolor = ImageColor.getrgb(fillcolor)
        fillcolor = tuple(fillcolor)
        # Append alpha only when missing (getrgb may already return RGBA).
        if len(fillcolor) == 3:
            fillcolor = fillcolor + (int(fillalpha * 255),)

    # "RGBA" drawing mode lets the translucent fill blend with the RGB image.
    draw = ImageDraw.Draw(image, "RGBA")
    draw.rectangle(bbox, outline=color, width=thickness, fill=fillcolor)

    if resized:
        new_width, new_height = (320, 240)
        image = image.resize((new_width, new_height))

    return image
|
|
|
|
def add_multiple_boxes_on_image(image, bboxes, colors=None, thickness=3, resized=False, fillcolor=None, fillalpha=0.2):
    """Draw several bounding boxes on a copy of an image.

    Args:
        image (PIL.Image): input image.
        bboxes (list): boxes as [xmin, ymin, xmax, ymax], in the coordinate
            frame of the input image.
        colors (list, optional): one outline colour per box; defaults to red.
        thickness (int): outline width.
        resized (bool): if True, the final image is resized to 320x240.
        fillcolor, fillalpha: forwarded to ``add_box_on_image``.

    Returns:
        PIL.Image: new image with all boxes drawn.
    """
    image = image.copy().convert("RGB")
    if colors is None:
        colors = ["red"] * len(bboxes)

    # Draw every box at the original resolution. Resizing inside the loop
    # would shrink the image after the first box while the remaining boxes
    # are still expressed in original-image coordinates.
    for bbox, color in zip(bboxes, colors):
        image = add_box_on_image(image, bbox, color, thickness, False, fillcolor, fillalpha)

    if resized:
        # Same target size that add_box_on_image uses for `resized=True`.
        image = image.resize((320, 240))
    return image
|
|
|
|
def colorize_mask(mask, color="red"):
    """Colorize a grayscale mask: black stays black, white maps to ``color``.

    Args:
        mask (PIL.Image): grayscale mask.
        color (str): colour name or hex string resolved through PIL.
    """
    rgb = PIL.ImageColor.getrgb(color)
    return ImageOps.colorize(mask, (0, 0, 0, 0), rgb)
|
|
|
|
def convert_mask_to_image(mask, threshold=0.5):
    """Converts a numpy array or torch tensor between [0, 1] to a PIL image.

    Args:
        mask: 2-D array-like of scores; objects exposing ``detach()`` (e.g.
            torch tensors) are moved to CPU and converted to numpy first.
        threshold (float): values strictly above it become white (255),
            everything else black (0).

    Returns:
        PIL.Image.Image: binary mask in mode "L".
    """
    # `.astype` below exists only on numpy arrays, so convert tensors first.
    if hasattr(mask, "detach"):
        mask = mask.detach().cpu().numpy()
    mask = np.asarray(mask)

    binary_array = (mask > threshold).astype(np.uint8) * 255
    return Image.fromarray(binary_array, mode="L")
|
|
|
|
def add_mask_on_image(image: Image, mask: Image, color="green", alpha=0.5):
    """Blend a colorized mask over an image.

    Args:
        image (PIL.Image): base image; must match the mask in size and have
            the same mode as the colorized mask after conversion to RGB.
        mask (PIL.Image or np.ndarray): grayscale mask; arrays are converted
            with ``convert_mask_to_image``.
        color (str or tuple): colour given to the white regions of the mask.
        alpha (float): blend factor passed to ``Image.blend``.

    Returns:
        PIL.Image: the blended image.
    """
    if isinstance(mask, np.ndarray):
        mask = convert_mask_to_image(mask)

    # Work on copies so the inputs are left untouched.
    base = image.copy()
    overlay = mask.copy()

    rgb = PIL.ImageColor.getrgb(color) if isinstance(color, str) else color
    overlay = ImageOps.colorize(overlay, (0, 0, 0, 0), rgb).convert("RGB")

    assert (overlay.size == base.size)
    assert (overlay.mode == base.mode)

    return Image.blend(base, overlay, alpha)
|
|
|
|
def blend_images(img1, img2, alpha=0.5):
    """Blend two images and return the result as RGB.

    Both inputs are converted to RGBA before blending with
    ``Image.blend(img1, img2, alpha)``.
    """
    first, second = (im.convert("RGBA") for im in (img1, img2))
    return Image.blend(first, second, alpha=alpha).convert("RGB")
|
|
|
|
def visualize_youtube_clip(
        youtube_id, st, et, label="",
        show_spec=False,
        video_width=360, video_height=240,
    ):
    """Build a widget embedding a YouTube segment, optionally with a spectrogram.

    Args:
        youtube_id (str): YouTube video id.
        st, et (float): start / end time of the segment in seconds; truncated
            to ints in the embed URL.
        label (str): caption shown above the player.
        show_spec (bool): if True, download the audio stream to
            ``/tmp/sample.mp4`` via pytube and plot the mel-spectrogram of the
            segment below the player.
        video_width, video_height (int): iframe size in pixels.

    Returns:
        ipywidgets.VBox: caption, player and (possibly empty) output area.
    """
    url = f"https://www.youtube.com/embed/{youtube_id}?start={int(st)}&end={int(et)}"
    video_html_code = f"""
    <iframe height="{video_height}" width="{video_width}" src="{url}" frameborder="0" allowfullscreen></iframe>
    """
    label_html_code = f"""<b>Caption</b>: {label} <br> <b>Time</b>: {st} to {et}"""

    label = widgets.HTML(label_html_code)
    video = widgets.HTML(video_html_code)
    out = widgets.Output()

    if show_spec:
        # Heavy dependencies are imported only when they are needed.
        import librosa
        import librosa.display  # submodule is not loaded by `import librosa`
        import pytube
        from moviepy.audio.io.AudioFileClip import AudioFileClip

        video_url = f"https://www.youtube.com/watch?v={youtube_id}"
        yt = pytube.YouTube(video_url)
        audio_stream = yt.streams.filter(only_audio=True).first()
        audio_stream.download(output_path='/tmp', filename='sample.mp4')

        audio_clip = AudioFileClip("/tmp/sample.mp4")
        try:
            audio_subclip = audio_clip.subclip(st, et)
            sr = audio_subclip.fps
            # Average the channels to obtain a mono signal.
            y = audio_subclip.to_soundarray().mean(axis=1)
            audio_subclip.close()
        finally:
            audio_clip.close()

        # `y` must be passed by keyword in recent librosa versions.
        S_db = librosa.power_to_db(librosa.feature.melspectrogram(y=y, sr=sr), ref=np.max)

        # Size the figure so that it roughly matches the player width.
        width = video_width / plt.rcParams["figure.dpi"] + 0.63
        height = video_height / plt.rcParams["figure.dpi"]
        with out:
            fig, ax = plt.subplots(figsize=(width, height))
            librosa.display.specshow(S_db, sr=sr, x_axis='time', ax=ax)
            ax.set_ylabel("Frequency (Hz)")

    vbox = widgets.VBox([label, video, out])
    return vbox
| |
|
|
def visualize_pair_of_youtube_clips(clip_a, clip_b):
    """Show two YouTube clips side by side, each with its spectrogram.

    Args:
        clip_a, clip_b (dict): each provides the keys ``"youtube_id"``,
            ``"sentence"`` (caption) and ``"time"`` (a ``(start, end)`` pair).
    """
    # Read every field first so malformed input fails before any download.
    call_args = []
    for clip in (clip_a, clip_b):
        youtube_id = clip["youtube_id"]
        caption = clip["sentence"]
        start, end = clip["time"]
        call_args.append((youtube_id, start, end, caption))

    panels = [visualize_youtube_clip(*args, show_spec=True) for args in call_args]
    display(HBox(panels))
| |
|
|
def plot_1d(x: np.ndarray, figsize=(6, 2), title=None, xlabel=None, ylabel=None, show=True, **kwargs):
    """Plot a 1-D array against its indices.

    Args:
        x (np.ndarray): 1-D array of values.
        figsize (tuple): figure size.
        title, xlabel, ylabel (str, optional): plot annotations.
        show (bool): if True the figure is shown; otherwise it is closed and
            returned.
        **kwargs: forwarded to ``Axes.plot``.

    Returns:
        matplotlib Figure when ``show`` is False.
    """
    assert (x.ndim == 1)

    fig, ax = plt.subplots(figsize=figsize)
    ax.grid(alpha=0.3)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    positions = np.arange(len(x))
    ax.plot(positions, x, **kwargs)

    if not show:
        plt.close()
        return fig
    plt.show()
|
|
|
|
|
|
def make_grid(cols,rows):
    """Create a grid of streamlit columns.

    Builds ``cols`` containers, each split with ``st.columns(rows)``, and
    returns the per-container results as a list of length ``cols``.
    """
    import streamlit as st

    grid = []
    for _ in range(cols):
        with st.container():
            grid.append(st.columns(rows))
    return grid
|
|
|
|
def display_clip(video_path, stime, etime, label=None):
    """Display the ``[stime, etime]`` segment of a video with a caption.

    Args:
        video_path (str): path to the video; must exist.
        stime, etime (float): segment bounds.
        label (str, optional): caption shown under the clip.
    """
    assert exists(video_path), f"Video does not exist at {video_path}"
    # show_small_clips_in_a_video renders the widgets itself and returns
    # nothing, so wrapping it in display() would only show a stray `None`.
    show_small_clips_in_a_video(
        video_path, [(stime, etime)], labels=[label],
    )
|
|
|
|
def countplot(df, column, title=None, rotation=90, ylabel="Count", figsize=(8, 5), ax=None, show=True, show_counts=False):
    """Bar plot of the value counts of one dataframe column.

    Args:
        df (pd.DataFrame): input data.
        column (str): column whose values are counted.
        title (str, optional): plot title.
        rotation (float): rotation of the category tick labels.
        ylabel (str): y-axis label.
        figsize (tuple): figure size when a new figure is created.
        ax (matplotlib Axes, optional): axis to draw on.
        show (bool): if True, call ``plt.show()``.
        show_counts (bool): if True, write each count above its bar.
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)

    ax.grid(alpha=0.4)
    ax.set_xlabel(column)
    ax.set_ylabel(ylabel)
    ax.set_title(title)

    value_counts = df[column].value_counts()
    categories = list(value_counts.index)
    counts = list(value_counts.values)

    # Plot at explicit integer positions and pin the ticks there, so labels
    # and count annotations line up with the bars even for numeric categories
    # (which would otherwise be placed at their numeric value).
    positions = list(range(len(categories)))
    ax.bar(positions, counts)
    ax.set_xticks(positions)
    ax.set_xticklabels([str(category) for category in categories], rotation=rotation)

    if show_counts and counts:
        delta = 0.01 * max(counts)
        for position, count in zip(positions, counts):
            ax.text(position, count + delta, str(count), ha="center")

    if show:
        plt.show()
|
|
|
|
def get_linspace_colors(cmap_name='viridis', num_colors = 10):
    """Pick ``num_colors`` hex colours spread over a matplotlib colormap.

    Args:
        cmap_name (str): name of the colormap.
        num_colors (int): number of colours to return. One colour is taken
            from the middle of the map, two from positions 0.1 and 0.9,
            otherwise positions ``0, 1/n, ..., (n-1)/n`` are used.

    Returns:
        list: hex colour strings.
    """
    import matplotlib.colors as mcolors

    # `plt.cm.get_cmap` was removed in Matplotlib 3.9; `plt.get_cmap` is the
    # supported accessor and is what the rest of this module uses.
    cmap = plt.get_cmap(cmap_name)

    if num_colors == 1:
        indices = [0.5]
    elif num_colors == 2:
        indices = [0.1, 0.9]
    else:
        gap = 1 / (num_colors)
        indices = np.arange(0, 1, gap)

    colors = [mcolors.to_hex(cmap(idx)) for idx in indices]
    return colors
|
|
|
|
def hex_to_rgb(colors):
    """Convert a sequence of colour strings into a list of RGB tuples."""
    from PIL import ImageColor

    rgb_values = []
    for color_code in colors:
        rgb_values.append(ImageColor.getcolor(color_code, "RGB"))
    return rgb_values
|
|
|
|
def plot_audio_feature(times, feature, feature_label="Feature", xlabel="Time", figsize=(20, 2)):
    """Plot one audio feature over time as a thin dashed line.

    Args:
        times: x-values.
        feature: y-values, same length as ``times``.
        feature_label (str): y-axis label.
        xlabel (str): x-axis label.
        figsize (tuple): figure size.
    """
    _, axis = plt.subplots(1, 1, figsize=figsize)

    axis.grid(alpha=0.4)
    axis.set_xlabel(xlabel)
    axis.set_ylabel(feature_label)
    # The y tick values are hidden; only the shape of the curve is shown.
    axis.set_yticks([])

    axis.plot(times, feature, '--', linewidth=0.5)
    plt.show()
|
|
|
|
|
|
def compute_rms(y, frame_length=512):
    """Compute frame-wise RMS energy and the matching time stamps.

    Args:
        y (np.ndarray): audio waveform.
        frame_length (int): analysis window length in samples.

    Returns:
        tuple: ``(times, rms)``. ``times`` uses librosa's default sampling
        rate, as no rate is passed in.
    """
    import librosa

    # Frames advance by the hop length, not by the window length. librosa's
    # default hop (512) is made explicit and reused for the time axis so the
    # two stay consistent for any `frame_length`.
    hop_length = 512
    rms = librosa.feature.rms(y=y, frame_length=frame_length, hop_length=hop_length)[0]
    times = librosa.frames_to_time(np.arange(len(rms)), hop_length=hop_length)
    return times, rms
|
|
|
|
def plot_audio_features(path, label, show=True, show_video=True, features=("rms",), frame_length=512, figsize=(5, 2), return_features=False):
    """Show a media file together with plots of selected audio features.

    Args:
        path (str): path to a file that ``librosa.load`` can read.
        label (str): caption for the media widget.
        show (bool): if False nothing is displayed; features are only computed.
        show_video (bool): show video + spectrogram if True, otherwise audio
            + spectrogram.
        features (sequence of str): feature names; each name ``f`` must have a
            module-level function ``compute_<f>(y, frame_length)`` returning
            ``(xvals, yvals)``.
        frame_length (int): forwarded to the feature functions.
        figsize (tuple): size of the spectrogram / feature plots.
        return_features (bool): if True, return the computed features.

    Returns:
        dict mapping feature name to ``(xvals, yvals)`` when
        ``return_features`` is True, otherwise None.

    Raises:
        ValueError: if a requested feature has no ``compute_<name>`` function.
    """
    import librosa

    y, sr = librosa.load(path)

    if show:
        if show_video:
            display(
                show_single_video_and_spectrogram(
                    path, path, label=label, figsize=figsize,
                    width=410,
                )
            )
        else:
            display(
                show_single_audio_with_spectrogram(path, label=label, figsize=figsize)
            )

    feature_data = dict()
    for name in features:
        # Look the function up by name instead of eval()-ing a built string.
        fn = globals().get(f"compute_{name}")
        if not callable(fn):
            raise ValueError(f"Unknown audio feature: {name}")
        xvals, yvals = fn(y=y, frame_length=frame_length)
        feature_data[name] = (xvals, yvals)

        if show:
            # plot_audio_feature shows its own figure and returns None, so it
            # is called directly rather than through display().
            plot_audio_feature(
                xvals, yvals, feature_label=name.upper(), figsize=(figsize[0] - 0.25, figsize[1]),
            )

    if return_features:
        return feature_data
|
|
|
|
def rescale_frame(frame, scale=1.):
    """Resize a frame so both sides are multiplied by ``scale``.

    The new side lengths are truncated to integers.
    """
    new_width = int(frame.width * scale)
    new_height = int(frame.height * scale)
    return frame.resize((new_width, new_height))
|
|
|
|
def save_gif(images, path, duration=None, fps=30):
    """Write a sequence of images to ``path`` with ``imageio.mimsave``.

    Args:
        images (iterable): frames; each is converted with ``np.asarray``.
        path (str): output file path.
        duration: used only when ``fps`` is None; must then be given.
        fps: frame rate; takes precedence over ``duration`` when not None.
    """
    import imageio

    frames = list(map(np.asarray, images))
    if fps is None:
        assert duration is not None
        imageio.mimsave(path, frames, duration=duration)
    else:
        imageio.mimsave(path, frames, fps=fps)
|
|
|
|
def show_subsampled_frames(frames, n_show, figsize=(15, 3), as_canvas=True):
    """Pick ``n_show`` evenly spaced frames and show or return them.

    Args:
        frames (sequence): frames indexable by integer position.
        n_show (int): number of frames to keep.
        figsize (tuple): figure size for the grid view.
        as_canvas (bool): if True, return the frames concatenated into one
            image; otherwise draw them as a grid titled by frame index.

    Returns:
        The concatenated image when ``as_canvas`` is True, otherwise None.
    """
    picked = np.linspace(0, len(frames) - 1, n_show, dtype=int)
    selection = [frames[position] for position in picked]

    if not as_canvas:
        show_grid_of_images(selection, n_cols=n_show, figsize=figsize, subtitles=picked)
        return None
    return concat_images(selection)
|
|
|
|
def tensor_to_heatmap(x, scale=True, cmap="viridis", flip_vertically=False):
    """Render a 2-D array as a colour-mapped RGB PIL image.

    Args:
        x: 2-D numpy array, or an object exposing ``detach()`` (e.g. a torch
            tensor), which is moved to CPU and converted to numpy.
        scale (bool): if True, min-max normalise the values to [0, 1] before
            applying the colormap. A constant input maps to all zeros.
        cmap (str): matplotlib colormap name.
        flip_vertically (bool): if True, flip the rows before colouring.

    Returns:
        PIL.Image.Image: RGB heatmap.

    Raises:
        ValueError: if the colormap output has neither 3 nor 4 channels.
    """
    # Duck-typed tensor conversion; also handles tensors that require grad or
    # live on an accelerator.
    if hasattr(x, "detach"):
        x = x.detach().cpu().numpy()
    x = np.asarray(x)

    if scale:
        low, high = x.min(), x.max()
        value_range = high - low
        if value_range == 0:
            # Constant input: avoid 0/0, which would fill the map with NaNs.
            x = np.zeros(x.shape, dtype=float)
        else:
            x = (x - low) / value_range

    cm = plt.get_cmap(cmap)
    if flip_vertically:
        x = np.flip(x, axis=0)
    x = cm(x)
    x = (x * 255).astype(np.uint8)

    if x.shape[-1] == 3:
        x = Image.fromarray(x, mode="RGB")
    elif x.shape[-1] == 4:
        x = Image.fromarray(x, mode="RGBA").convert("RGB")
    else:
        raise ValueError(f"Invalid shape {x.shape}")
    return x
|
|
|
|
def batch_tensor_to_heatmap(
        x, scale=True, cmap="viridis", flip_vertically=False, resize=None,
        concat=False,
    ):
    """Convert each item of a batch to a heatmap image.

    Args:
        x: indexable batch; every ``x[i]`` goes through ``tensor_to_heatmap``.
        scale, cmap, flip_vertically: forwarded to ``tensor_to_heatmap``.
        resize (tuple, optional): if given, each heatmap is resized to it.
        concat (bool): if True, join the heatmaps horizontally with borders.

    Returns:
        list of PIL images, or a single image when ``concat`` is True.
    """
    heatmaps = []
    for index in range(len(x)):
        heatmap = tensor_to_heatmap(x[index], scale, cmap, flip_vertically)
        heatmaps.append(heatmap if resize is None else heatmap.resize(resize))

    if not concat:
        return heatmaps
    return concat_images_with_border(heatmaps)
|
|
|
|
def change_contrast(img, level):
    """Adjust image contrast with a point transform around mid-gray (128).

    Args:
        img: object with a PIL-style ``point(fn)`` method.
        level (number): contrast level; 0 leaves pixel values unchanged.
    """
    numerator = 259 * (level + 255)
    denominator = 255 * (259 - level)
    factor = numerator / denominator
    return img.point(lambda value: 128 + factor * (value - 128))
|
|
|
|
def change_brightness(img, alpha):
    """Scale the brightness of a PIL image.

    Args:
        img (PIL.Image.Image): input image.
        alpha (float): enhancement factor passed to
            ``ImageEnhance.Brightness.enhance`` (1.0 keeps the image as is).

    Returns:
        PIL.Image.Image: the adjusted image.
    """
    # `import PIL` alone does not load the ImageEnhance submodule, so import
    # it explicitly instead of relying on attribute access.
    from PIL import ImageEnhance

    enhancer = ImageEnhance.Brightness(img)
    return enhancer.enhance(alpha)
|
|
|
|
def draw_horizontal_lines(image, y_values, color=(255, 0, 0), colors=None, line_thickness=2):
    """
    Draw horizontal lines on a PIL image at specified Y positions.

    Args:
        image (PIL.Image.Image): The input PIL image.
        y_values (list or int): List of Y positions where lines will be drawn.
            If a single integer is provided, a line will be drawn at that Y position.
        color (tuple or str): Colour used for every line when `colors` is not
            given: an RGB tuple (e.g., (255, 0, 0) for red) or a colour string.
        colors (list, optional): One colour (tuple or string) per line;
            overrides `color`.
        line_thickness (int): Thickness of the lines.

    Returns:
        PIL.Image.Image: The PIL image with the drawn lines.
    """
    from PIL import ImageColor

    image = image.copy()

    # Normalise a single position to a list first: `len(y_values)` below
    # would fail on a bare integer.
    if isinstance(y_values, (int, np.integer)):
        y_values = [y_values]

    if isinstance(color, str):
        color = ImageColor.getcolor(color, "RGB")

    if colors is None:
        colors = [color] * len(y_values)
    else:
        # Resolve string colours entry by entry (also safe for empty lists).
        colors = [
            ImageColor.getcolor(c, "RGB") if isinstance(c, str) else c
            for c in colors
        ]

    draw = ImageDraw.Draw(image)
    for y, c in zip(y_values, colors):
        draw.line([(0, y), (image.width, y)], fill=c, width=line_thickness)

    return image
|
|
|
|
def draw_vertical_lines(image, x_values, color=(255, 0, 0), colors=None, line_thickness=2):
    """
    Draw vertical lines on a PIL image at specified X positions.

    Args:
        image (PIL.Image.Image): The input PIL image.
        x_values (list or int): List of X positions where lines will be drawn.
            If a single integer is provided, a line will be drawn at that X position.
        color (tuple or str): Colour used for every line when `colors` is not
            given: an RGB tuple (e.g., (255, 0, 0) for red) or a colour string.
        colors (list, optional): One colour (tuple or string) per line;
            overrides `color`.
        line_thickness (int): Thickness of the lines.

    Returns:
        PIL.Image.Image: The PIL image with the drawn lines.
    """
    from PIL import ImageColor

    image = image.copy()

    # Normalise a single position to a list first: `len(x_values)` below
    # would fail on a bare integer.
    if isinstance(x_values, (int, np.integer)):
        x_values = [x_values]

    if isinstance(color, str):
        color = ImageColor.getcolor(color, "RGB")

    if colors is None:
        colors = [color] * len(x_values)
    else:
        # Resolve string colours entry by entry (also safe for empty lists).
        colors = [
            ImageColor.getcolor(c, "RGB") if isinstance(c, str) else c
            for c in colors
        ]

    draw = ImageDraw.Draw(image)
    for x, c in zip(x_values, colors):
        draw.line([(x, 0), (x, image.height)], fill=c, width=line_thickness)

    return image
|
|
|
|
def show_arrow_on_image(image, start_loc, end_loc, color="red", thickness=3):
    """Return a copy of a PIL image with a straight line drawn on it.

    Args:
        image (PIL.Image): input image (left unmodified).
        start_loc, end_loc (tuple): (x, y) end points of the line.
        color (str): key into the module's ``COLORS`` table.
        thickness (int): line width in pixels.
    """
    canvas = image.copy()
    rgb = get_predominant_color(color)
    ImageDraw.Draw(canvas).line([start_loc, end_loc], fill=rgb, width=thickness)
    return canvas
|
|
|
|
def draw_arrow_on_image_cv2(image, start_loc, end_loc, color="red", thickness=2, both_ends=False):
    """Draw an arrow on an image using OpenCV.

    Args:
        image (PIL.Image or array-like): input image (left unmodified).
        start_loc, end_loc (tuple): integer (x, y) tail and head of the arrow.
        color (str or tuple): colour string or RGB tuple.
        thickness (int): line thickness.
        both_ends (bool): if True, an arrow head is drawn at both ends.

    Returns:
        PIL.Image.Image: new image with the arrow drawn.
    """
    # np.array makes a writable copy; np.asarray on a PIL image yields a
    # read-only array, which OpenCV cannot draw into in place.
    canvas = np.array(image)
    if isinstance(color, str):
        color = PIL.ImageColor.getcolor(color, "RGB")

    canvas = cv2.arrowedLine(canvas, start_loc, end_loc, color, thickness)
    if both_ends:
        canvas = cv2.arrowedLine(canvas, end_loc, start_loc, color, thickness)
    return PIL.Image.fromarray(canvas)
|
|
|
|
def draw_arrow_with_text(image, start_loc, end_loc, text="", color="red", thickness=2, font_size=20, both_ends=False, delta=5):
    """Draw an arrow with a text label centred on its midpoint.

    Args:
        image (PIL.Image or array-like): input image (left unmodified).
        start_loc, end_loc (tuple): integer (x, y) tail and head of the arrow.
        text (str): label drawn with PIL's default font.
        color (str or tuple): colour string or RGB tuple for arrow and text.
        thickness (int): arrow line thickness.
        font_size (int): currently unused; kept for interface compatibility.
        both_ends (bool): if True, an arrow head is drawn at both ends.
        delta (int): extra shift of the text to the left, in pixels.

    Returns:
        PIL.Image.Image: new image with arrow and text.
    """
    # Writable copy: OpenCV draws in place and cannot use the read-only array
    # that np.asarray returns for PIL images.
    canvas = np.array(image)
    if isinstance(color, str):
        color = PIL.ImageColor.getcolor(color, "RGB")

    center_x = (start_loc[0] + end_loc[0]) // 2
    center_y = (start_loc[1] + end_loc[1]) // 2

    canvas = cv2.arrowedLine(canvas, start_loc, end_loc, color, thickness)
    if both_ends:
        canvas = cv2.arrowedLine(canvas, end_loc, start_loc, color, thickness)

    image_with_text = Image.fromarray(canvas)
    draw = ImageDraw.Draw(image_with_text)

    # `textsize` was removed in Pillow 10; prefer `textbbox` and fall back to
    # `textsize` only on old Pillow versions that lack it.
    if hasattr(draw, "textbbox"):
        left, top, right, bottom = draw.textbbox((0, 0), text)
        text_width, text_height = right - left, bottom - top
    else:
        text_width, text_height = draw.textsize(text)

    text_x = center_x - (text_width // 2) - delta
    text_y = center_y - (text_height // 2)
    draw.text((text_x, text_y), text, color)

    return image_with_text
|
|
|
|
def draw_arrowed_line(image, start_loc, end_loc, color="red", thickness=2):
    """
    Draw an arrowed line on a PIL image from a starting point to an ending point.

    Args:
        image (PIL.Image.Image): The input PIL image.
        start_loc (tuple): Starting point (x, y) for the arrowed line.
        end_loc (tuple): Ending point (x, y) for the arrowed line.
        color (str or tuple): Color of the line (e.g., 'red', 'green', 'blue')
            or an RGB tuple.
        thickness (int): Thickness of the line.

    Returns:
        PIL.Image.Image: The PIL image with the drawn arrowed line.
    """
    image = image.copy()
    if isinstance(color, str):
        color = PIL.ImageColor.getcolor(color, "RGB")

    draw = ImageDraw.Draw(image)
    draw.line([start_loc, end_loc], fill=color, width=thickness)

    dx = end_loc[0] - start_loc[0]
    dy = end_loc[1] - start_loc[1]
    length = (dx ** 2 + dy ** 2) ** 0.5
    if length == 0:
        # Degenerate arrow: no direction to orient the head with, and the
        # normalisation below would divide by zero.
        return image

    # Unit direction of the line, used to place the head polygon's vertices.
    arrow_size = 10
    cos_theta = dx / length
    sin_theta = dy / length
    x1 = end_loc[0] - arrow_size * cos_theta
    y1 = end_loc[1] - arrow_size * sin_theta
    x2 = end_loc[0] - arrow_size * sin_theta
    y2 = end_loc[1] + arrow_size * cos_theta
    x3 = end_loc[0] + arrow_size * sin_theta
    y3 = end_loc[1] - arrow_size * cos_theta

    draw.polygon([end_loc, (x1, y1), (x2, y2), (x3, y3)], fill=color)

    return image
|
|
|
|
def center_crop_to_fraction(image, frac=0.5):
    """Crop the central region whose sides are ``frac`` of the original.

    Args:
        image: object with PIL-style ``size`` (width, height) and ``crop(box)``.
        frac (float): fraction of each side to keep.
    """
    full_w, full_h = image.size
    crop_w, crop_h = int(full_w * frac), int(full_h * frac)

    box = (
        (full_w - crop_w) // 2,   # left
        (full_h - crop_h) // 2,   # top
        (full_w + crop_w) // 2,   # right
        (full_h + crop_h) // 2,   # bottom
    )
    return image.crop(box)
|
|
|
|
def decord_load_frames(vr, frame_indices):
    """Load frames from a video reader as PIL images.

    Args:
        vr: reader whose ``get_batch(indices)`` result exposes ``asnumpy()``.
        frame_indices (int or list): frame index or indices to load.

    Returns:
        list: one PIL image per loaded frame.
    """
    indices = [frame_indices] if isinstance(frame_indices, int) else frame_indices
    batch = vr.get_batch(indices).asnumpy()
    return [Image.fromarray(frame) for frame in batch]
|
|
|
|
def paste_mask_on_image(original_image, bounding_box, mask):
    """
    Paste a 2D mask onto a copy of the image at the bounding-box location.

    The mask is cropped from its top-left corner to the bounding-box size and
    pasted with its top-left corner at (left, top).

    Parameters:
    - original_image (PIL.Image): The original image (left unmodified).
    - bounding_box (tuple): Bounding box coordinates (left, top, right, bottom).
    - mask (PIL.Image): The 2D mask.

    Returns:
    - PIL.Image: Image with the mask pasted on it.

    Example:
    ```
    original_image = Image.open('original.jpg')
    bounding_box = (100, 100, 200, 200)
    mask = Image.open('mask.png')
    result_image = paste_mask_on_image(original_image, bounding_box, mask)
    result_image.show()
    ```
    """
    left, top, right, bottom = bounding_box[0], bounding_box[1], bounding_box[2], bounding_box[3]

    pasted = original_image.copy()
    patch = mask.crop((0, 0, right - left, bottom - top))
    pasted.paste(patch, (left, top))
    return pasted
|
|
|
|
def display_images_as_video_moviepy(image_list, fps=5, show=True):
    """
    Display a list of PIL images as a video in Jupyter Notebook using MoviePy.

    Parameters:
    - image_list (list): List of PIL images.
    - fps (int): Frames per second for the video.
    - show (bool): Whether to display the video in the notebook.

    Example:
    ```
    image_list = [Image.open('frame1.jpg'), Image.open('frame2.jpg'), ...]
    display_images_as_video_moviepy(image_list, fps=10)
    ```
    """
    from IPython.display import display
    from moviepy.editor import ImageSequenceClip

    frames = [np.asarray(frame) for frame in image_list]
    clip = ImageSequenceClip(frames, fps=fps)

    if not show:
        return
    display(clip.ipython_display(width=200))
    # Remove the temporary file left in the working directory.
    os.remove("__temp__.mp4")
|
|
|
|
def resize_height(img, H):
    """Resize an image to height ``H``, keeping the aspect ratio.

    The new width is the aspect-preserving width rounded up.
    """
    width, height = img.size
    new_width = np.ceil((width / height) * H).astype(int)
    return img.resize((new_width, H))
|
|
|
|
def resize_width(img, W):
    """Resize an image to width ``W``, keeping the aspect ratio.

    The new height is the aspect-preserving height truncated to an integer.
    """
    width, height = img.size
    new_height = int(W / (width / height))
    return img.resize((W, new_height))
|
|
|
|
def resized_minor_side(img, size=256):
    """Resize an image so its shorter side equals ``size``, keeping aspect ratio.

    Args:
        img: PIL-style image exposing ``size`` as (width, height) and
            ``resize((width, height))``.
        size (int): target length of the shorter side.

    Returns:
        The resized image.
    """
    # PIL reports size as (width, height). Unpacking it in the opposite order
    # inverts the aspect ratio of non-square images.
    W, H = img.size
    if H < W:
        new_size = (int(size * W / H), size)
    else:
        new_size = (size, int(size * H / W))
    return img.resize(new_size)
|
|
|
|
def brighten_image(img, alpha=1.2):
    """Brighten a PIL image.

    Args:
        img (PIL.Image.Image): input image.
        alpha (float): brightness factor passed to
            ``ImageEnhance.Brightness.enhance``.

    Returns:
        PIL.Image.Image: the adjusted image.
    """
    # The ImageEnhance submodule is not loaded by `import PIL`; import it
    # explicitly instead of relying on attribute access.
    from PIL import ImageEnhance

    return ImageEnhance.Brightness(img).enhance(alpha)
|
|
|
|
def darken_image(img, alpha=0.8):
    """Darken a PIL image.

    Args:
        img (PIL.Image.Image): input image.
        alpha (float): brightness factor passed to
            ``ImageEnhance.Brightness.enhance``.

    Returns:
        PIL.Image.Image: the adjusted image.
    """
    # The ImageEnhance submodule is not loaded by `import PIL`; import it
    # explicitly instead of relying on attribute access.
    from PIL import ImageEnhance

    return ImageEnhance.Brightness(img).enhance(alpha)
|
|
|
|
def fig2img(fig):
    """Render a Matplotlib figure into memory and open it as a PIL Image."""
    import io

    buffer = io.BytesIO()
    fig.savefig(buffer)
    # Rewind so PIL reads the saved image from the start of the buffer.
    buffer.seek(0)
    return Image.open(buffer)
|
|
|
|
def show_temporal_tsne(
        tsne,
        timestamps=None,
        title="Feature projections over time",
        cmap='viridis',
        ax=None,
        fig=None,
        show=True,
        num_ticks=10,
        return_as_pil=False,
        dpi=100,
        label='Time (s)',
        figsize=(6, 4.5),
        xlim=None,
        ylim=None,
        s=None,
        tsne_kwargs=None,
        scatter=True,
        colorbar=True,
        marker="o",
        alpha=0.9,
        plot_label=None,
    ):
    """Scatter 2-D feature projections, coloured by temporal order.

    Args:
        tsne (np.ndarray): array of shape (N, D). If D > 2 it is reduced with
            ``reduce_dim(tsne, **tsne_kwargs)``; the result must have 2 columns.
        timestamps (array-like, optional): time of each point, used for the
            colorbar tick labels; defaults to ``0..N-1``.
        title (str): plot title; the projection method is appended.
        cmap (str): colormap name.
        ax, fig: existing axis/figure; a new pair is created unless both
            are given.
        show (bool): if True, call ``plt.show()``.
        num_ticks (int): number of colorbar ticks.
        return_as_pil (bool): when ``show`` is False, render the figure to a
            PIL image, close it and return the image.
        dpi (int), figsize (tuple): used when a new figure is created.
        label (str): colorbar label.
        xlim, ylim (tuple, optional): axis limits.
        s: marker size passed to ``scatter``.
        tsne_kwargs (dict, optional): arguments for ``reduce_dim``;
            ``"method"`` defaults to ``"tsne"``. The dict is not modified.
        scatter (bool): if False, consecutive points are additionally joined
            by coloured arrows.
        colorbar (bool): whether to add a colorbar below the plot.
        marker (str), alpha (float): marker style and opacity.
        plot_label (str, optional): legend label; adds a legend when given.

    Returns:
        PIL.Image.Image when ``show`` is False and ``return_as_pil`` is True,
        otherwise None.
    """
    # Work on a copy: neither the caller's dict nor a shared default may be
    # mutated across calls.
    tsne_kwargs = dict(tsne_kwargs) if tsne_kwargs else {}
    method = tsne_kwargs.setdefault("method", "tsne")
    title = f"{title} ({method.upper()})"

    if tsne.shape[1] > 2:
        tsne = reduce_dim(tsne, **tsne_kwargs)
    assert (tsne.shape[1] == 2), f"Invalid shape {tsne.shape}"

    if timestamps is None:
        timestamps = np.arange(len(tsne))
    # Ensure fancy indexing works for list inputs as well.
    timestamps = np.asarray(timestamps)

    if ax is None or fig is None:
        fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)

    cmap = plt.get_cmap(cmap)
    # The points are drawn in both modes; colour encodes temporal order.
    points = ax.scatter(
        tsne[:, 0], tsne[:, 1], c=np.arange(len(tsne)), cmap=cmap, s=s,
        edgecolor='k', linewidth=0.5, marker=marker, alpha=alpha, label=plot_label,
    )

    if not scatter:
        # Trajectory mode: connect consecutive points with arrows whose colour
        # is the mean of the two end-point colours.
        colors = cmap(np.linspace(0, 1, len(tsne)))
        for i in range(len(tsne) - 1):
            avg_color = (colors[i] + colors[i + 1]) / 2
            ax.plot(
                tsne[i:i + 2, 0], tsne[i:i + 2, 1], color=avg_color, linewidth=0.8, alpha=alpha,
            )
            ax.annotate(
                '', xy=tsne[i + 1], xytext=tsne[i],
                arrowprops=dict(
                    arrowstyle='->, head_width=0.3, head_length=0.5',
                    color=avg_color,
                    linewidth=0.9,
                    shrinkA=0,
                    shrinkB=0,
                ),
                alpha=0.5,
            )

    ax.grid(alpha=0.4)
    ax.set_title(f"{title}", fontsize=11)
    ax.set_xlabel("$z_{1}$")
    ax.set_ylabel("$z_{2}$")
    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)
    if plot_label is not None:
        ax.legend()

    if colorbar:
        cbar = fig.colorbar(
            points, ax=ax, label=label, location='bottom', fraction=0.1,
        )
        # Ticks are point indices; label them with the matching timestamps.
        # np.unique drops repeated ticks when there are fewer points than ticks.
        ticks = np.unique(np.linspace(0, len(tsne) - 1, num_ticks, dtype=int))
        tick_labels = np.round(timestamps[ticks], 1)
        cbar.set_ticks(ticks)
        cbar.set_ticklabels(tick_labels)

    plt.tight_layout()

    if show:
        plt.show()
    elif return_as_pil:
        plt.tight_layout(pad=0.2)
        image = fig2img(fig)
        plt.close(fig)
        return image
|
|
|
|
def show_projections_with_labels(
        X, labels, title="",
        ax=None, fig=None, show=True, s=10, figsize=(6, 4), dpi=100,
        cmap="viridis", verbose=True, diff_markers=True, method="tsne", legend=True,
        alpha=0.8, legend_ncol=2, legend_outside=False,
    ):
    """Scatter 2-D projections of features, one colour/marker per label.

    Args:
        X (np.ndarray): array of shape (N, D); reduced with ``reduce_dim``
            unless D == 2.
        labels (sequence): label of each row of ``X``.
        title (str): plot title; omitted when empty or None.
        ax, fig: existing axis/figure. A new pair is created unless both are
            given; when both are given ``show`` is forced to False.
        show (bool): if True, call ``plt.show()``.
        s: marker size.
        figsize (tuple), dpi (int): used when a new figure is created.
        cmap (str): palette name passed to ``get_colors``.
        verbose (bool): print the number of unique labels.
        diff_markers (bool): cycle through different markers per label.
        method (str): ``"tsne"``, ``"umap"`` or ``"pca"``.
        legend (bool): whether to draw a legend.
        alpha (float): marker opacity.
        legend_ncol (int): number of legend columns.
        legend_outside (bool): anchor the legend outside the axes.

    Raises:
        ValueError: if reduction is needed and ``method`` is unknown.
    """
    if ax is None or fig is None:
        fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)
    else:
        # The caller owns the figure and decides when to show it.
        show = False
    labels = np.array(labels)

    # Project to 2-D unless the input is already two-dimensional.
    if X.shape[1] == 2:
        Z = X
    elif method in ("tsne", "umap", "pca"):
        Z = reduce_dim(X, method=method)
    else:
        raise ValueError(f"Unknown method: {method}")

    unique_labels = np.unique(labels)
    n_labels = len(unique_labels)
    colors = get_colors(n_labels, palette=cmap)
    if verbose:
        print("Number of unique labels:", len(unique_labels))

    if diff_markers:
        base_markers = ['o', 'X', '+', 'p', 's', '^', 'v', '<', '>', 'd', 'h']
        # Repeat the marker list so that every label gets one.
        markers = base_markers * (n_labels // len(base_markers) + 1)
    else:
        markers = ['o'] * n_labels

    for label_value, color, marker in zip(unique_labels, colors, markers):
        (rows,) = np.where(labels == label_value)
        ax.scatter(
            Z[rows, 0],
            Z[rows, 1],
            label=label_value,
            s=s,
            color=color,
            marker=marker,
            alpha=alpha,
        )

    ax.grid(alpha=0.4)
    if title is not None and len(title):
        ax.set_title(f"{title}", fontsize=11)
    ax.set_xlabel("$z_{1}$")
    ax.set_ylabel("$z_{2}$")

    if legend:
        if legend_outside:
            ax.legend(ncol=legend_ncol, loc='upper right', bbox_to_anchor=(1.3, 1))
        else:
            ax.legend(ncol=legend_ncol)

    if show:
        plt.show()
|
|
|
|
def mark_keypoints(image, keypoints, color=(255, 255, 0), radius=1):
    """
    Return a copy of an image with a filled dot drawn at every keypoint.

    :param image: The input PIL image (left unmodified).
    :param keypoints: An iterable of (x, y) pairs giving the dot centers.
    :param color: Fill color of the dots (default: (255, 255, 0)).
    :param radius: Radius of each dot in image coordinates (default: 1).
    :return: A new PIL image with the keypoints marked.
    """
    marked = image.copy()
    canvas = ImageDraw.Draw(marked)

    for kx, ky in keypoints:
        # Bounding box of the dot centered on the keypoint.
        dot_bbox = (kx - radius, ky - radius, kx + radius, ky + radius)
        canvas.ellipse(dot_bbox, fill=color, width=2)

    return marked
|
|
|
|
def draw_line_on_image(image, x_coords, y_coords, color=(255, 255, 0), width=3):
    """
    Return a copy of an image with a polyline drawn through the given points.

    :param image: The input PIL image (left unmodified).
    :param x_coords: x-coordinates of the line's vertices.
    :param y_coords: y-coordinates of the line's vertices (same length as x_coords).
    :param color: Color of the line (default: (255, 255, 0)).
    :param width: Width of the line (default: 3).
    :return: A new PIL image with the line drawn.
    :raises ValueError: If x_coords and y_coords differ in length.
    """
    annotated = image.copy()

    if len(x_coords) != len(y_coords):
        raise ValueError("x_coords and y_coords must have the same length")

    pen = ImageDraw.Draw(annotated)

    # Pair the coordinates up into (x, y) vertices for ImageDraw.
    vertices = [(px, py) for px, py in zip(x_coords, y_coords)]
    pen.line(vertices, fill=color, width=width)

    return annotated
|
|
|
|
def add_binary_strip_vertically(
    image,
    binary_vector,
    strip_width=15,
    one_color="yellow",
    zero_color="gray",
):
    """
    Overlay a vertical binary strip on the right edge of an image.

    Row `i` of the strip is painted `one_color` when `binary_vector[i] == 1`
    and `zero_color` otherwise. The strip is pasted over the rightmost
    `strip_width` columns of a copy of the image, so the image size does not
    change and the input image is left untouched.

    :param image: PIL Image on which the strip will be drawn.
    :param binary_vector: Sequence with one entry per image row.
    :param strip_width: Width of the strip (default: 15).
    :param one_color: Color name/spec for entries equal to 1 (default: "yellow").
    :param zero_color: Color name/spec for all other entries (default: "gray").
    :return: New image with the binary strip on its right edge.
    :raises ValueError: If len(binary_vector) differs from the image height.
    """
    one_rgb = PIL.ImageColor.getrgb(one_color)
    zero_rgb = PIL.ImageColor.getrgb(zero_color)

    height = image.height
    if len(binary_vector) != height:
        # The required length is the image height, not a fixed 512.
        raise ValueError(
            f"Binary vector length ({len(binary_vector)}) must match "
            f"the image height ({height})"
        )

    strip = PIL.Image.new("RGB", (strip_width, height))

    # Fill each row of the strip with the color selected by its entry.
    pixels = strip.load()
    for row in range(height):
        row_color = one_rgb if binary_vector[row] == 1 else zero_rgb
        for col in range(strip_width):
            pixels[col, row] = row_color

    new_image = image.copy()
    new_image.paste(strip, (image.width - strip_width, 0))

    return new_image
|
|
|
|
def add_binary_strip_horizontally(
    image,
    binary_vector,
    strip_height=15,
    one_color="limegreen",
    zero_color="gray",
):
    """
    Overlay a horizontal binary strip on the top edge of an image.

    Column `i` of the strip is painted `one_color` when `binary_vector[i] == 1`
    and `zero_color` otherwise. The strip is pasted over the topmost
    `strip_height` rows of a copy of the image, so the image size does not
    change and the input image is left untouched.

    :param image: PIL Image on which the strip will be drawn.
    :param binary_vector: Sequence with one entry per image column.
    :param strip_height: Height of the strip (default: 15).
    :param one_color: Color name or hex for entries equal to 1 (default: "limegreen").
    :param zero_color: Color name or hex for all other entries (default: "gray").
    :return: New image with the binary strip on its top edge.
    :raises ValueError: If len(binary_vector) differs from the image width.
    """
    width = image.width
    if len(binary_vector) != width:
        # The required length is the image width, not a fixed 512.
        raise ValueError(
            f"Binary vector length ({len(binary_vector)}) must match "
            f"the image width ({width})"
        )

    one_color_rgb = PIL.ImageColor.getrgb(one_color)
    zero_color_rgb = PIL.ImageColor.getrgb(zero_color)

    strip = PIL.Image.new("RGB", (width, strip_height))

    # Fill each column of the strip with the color selected by its entry.
    pixels = strip.load()
    for col in range(width):
        col_color = one_color_rgb if binary_vector[col] == 1 else zero_color_rgb
        for row in range(strip_height):
            pixels[col, row] = col_color

    new_image = image.copy()
    new_image.paste(strip, (0, 0))

    return new_image
|
|
|
|
| |
def increase_font_sizes(ax, font_scale=1.6):
    """
    Scale the font size of an axes' title, axis labels and tick labels.

    Args:
        ax: Axes-like object exposing `title`, `xaxis.label`, `yaxis.label`,
            `get_xticklabels()` and `get_yticklabels()`.
        font_scale (float): Factor each current font size is multiplied by.
    """
    text_items = [ax.title, ax.xaxis.label, ax.yaxis.label]
    text_items.extend(ax.get_xticklabels())
    text_items.extend(ax.get_yticklabels())

    for text_item in text_items:
        scaled_size = text_item.get_fontsize() * font_scale
        text_item.set_fontsize(scaled_size)
|
|
|
|
def draw_multiple_boxes_on_image(image, boxes, colors=None, thickness=3):
    """
    Draw several boxes on a copy of an image, one `add_box_on_image` call each.

    Args:
        image: Input image; copied before drawing.
        boxes: Sequence of boxes, each passed as is to `add_box_on_image`.
        colors: One color per box. When `None`, colors are taken from
            `get_linspace_colors(num_colors=len(boxes), cmap_name="bwr")`.
        thickness: Thickness forwarded to `add_box_on_image`.

    Returns:
        The image returned by the last `add_box_on_image` call (or the plain
        copy when `boxes` is empty).
    """
    canvas = image.copy()

    if colors is not None:
        assert len(colors) == len(boxes)
    else:
        colors = get_linspace_colors(num_colors=len(boxes), cmap_name="bwr")

    for box, box_color in zip(boxes, colors):
        canvas = add_box_on_image(canvas, box, box_color, thickness)

    return canvas
|
|
|
|
|
|
def mask_out_bbox(image, box):
    """
    Zero out the pixels inside a bounding box and return the result as an image.

    Args:
        image (PIL.Image): The input image.
        box (list): The bounding box coordinates [xmin, ymin, xmax, ymax],
            used directly as array slice bounds.

    Returns:
        PIL.Image: Image built from the masked pixel array.
    """
    pixels = np.array(image)

    x0, y0, x1, y1 = box
    row_span = slice(y0, y1)
    col_span = slice(x0, x1)

    # Arrays are indexed (row, column), i.e. (y, x).
    pixels[row_span, col_span] = 0

    return Image.fromarray(pixels)
|
|
|
|
def sample_cmap(cmap='viridis', K=10):
    """
    Sample K evenly spaced colors from a matplotlib colormap.

    Args:
        cmap: Name of the matplotlib colormap to sample from. Defaults to 'viridis'.
        K: Number of colors to sample. Defaults to 10.

    Returns:
        A list of K colors, each being the value the colormap returns for a
        position in [0, 1] (endpoints included).
    """
    colormap = plt.get_cmap(cmap)
    return [colormap(position) for position in np.linspace(0, 1, K)]
|
|
|
|
def cut_long_string(s, max_len=40):
    """Truncate `s` to `max_len` characters, appending " ..." when it was cut."""
    if len(s) <= max_len:
        return s
    return s[:max_len] + " ..."
|
|
|
|
| import textwrap |
|
|
|
|
def get_terminal_width():
    """Return the terminal width in columns, as reported by `shutil.get_terminal_size()`."""
    from shutil import get_terminal_size

    terminal_size = get_terminal_size()
    return terminal_size.columns
|
|
|
|
def wrap_text(text: str, max_length: int = 100) -> str:
    """
    Wrap a long string so it prints comfortably in the current terminal.

    The effective line width is the smaller of `max_length` and the terminal
    width returned by `get_terminal_width()`.

    Args:
        text (str): The input string to wrap.
        max_length (int): Upper bound on the line width. Default is 100.

    Returns:
        str: The text re-filled with `textwrap.fill` at the effective width.
    """
    line_width = min(max_length, get_terminal_width())
    return textwrap.fill(text, width=line_width)
|
|
|
|
def blank_image(color="black", size=(256, 256), border_color="black", border_width=2):
    """
    Create a solid-color RGB image, optionally outlined with a border.

    Args:
        color (str): Fill color of the image. Default is 'black'.
        size (tuple): (width, height) of the image in pixels. Default is (256, 256).
        border_color (str): Outline color of the border. Default is 'black'.
        border_width (int): Border thickness; no border is drawn unless it is
            greater than 0. Default is 2.

    Returns:
        PIL.Image: The blank image, with the border drawn inside its edges.
    """
    canvas = Image.new("RGB", size, color)

    if border_width > 0:
        # The outline runs along the outermost pixels of the image.
        right, bottom = size[0] - 1, size[1] - 1
        ImageDraw.Draw(canvas).rectangle(
            [0, 0, right, bottom], outline=border_color, width=border_width,
        )

    return canvas
|
|
|
|
def insert_text_center(
    image,
    text,
    font_size=100,
    font_path="/users/piyush/.local/fonts/arial.ttf",
    color="black",
):
    """
    Draw text centered on an image.

    NOTE: the text is drawn directly onto `image` (no copy is made); the same
    object is returned for convenience.

    Args:
        image (PIL.Image): Image to draw on; modified in place.
        text (str): Text to render.
        font_size (int): Font size passed to `ImageFont.truetype`.
        font_path (str): Path to a TrueType font file; must exist.
        color: Fill color of the text. Default is 'black'.

    Returns:
        PIL.Image: The input image with the text drawn on it.
    """
    assert os.path.exists(font_path), f"Font file not found at {font_path}"

    from PIL import ImageFont

    font = ImageFont.truetype(font_path, font_size)
    draw = ImageDraw.Draw(image)

    # Bounding box of the rendered text when anchored at (0, 0).
    left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
    text_width = right - left
    text_height = bottom - top

    # The box generally does not start at the anchor: (left, top) is the
    # offset of the inked area from the drawing position. Subtract it so that
    # the inked area, not the anchor, ends up centered.
    x = (image.width - text_width) / 2 - left
    y = (image.height - text_height) / 2 - top

    draw.text((x, y), text, font=font, fill=color)

    return image
|
|
|
|
| from PIL import Image, ImageDraw, ImageColor |
|
|
def add_shape_to_image(
    image,
    shape="circle",
    size=0.1,
    location=(0.5, 0.5),
    facecolor="red",
    edgecolor="red",
    edgethickness=2,
    check_bounds=True,
):
    """
    Adds a shape to a copy of an image at a specified location.

    Geometry (with `s = int(size * min(width, height))` and the center at
    `location` scaled by the image size):
        - circle: diameter `s`.
        - rectangle: square spanning `s // 2` on each side of the center.
        - triangle: apex `s` above the center, base `s // 2` below it,
          base spanning `s // 2` on each side of the center.

    Args:
        image (PIL.Image): The input image (left unmodified).
        shape (str): The shape to add. Can be 'circle', 'rectangle', or 'triangle'.
        size (float): The size of the shape as a fraction of the image size (min side).
        location (tuple): The (x, y) location of the shape as a fraction of the image size.
        facecolor (str): The fill color of the shape.
        edgecolor (str): The edge color of the shape.
        edgethickness (int): The thickness of the edge.
        check_bounds (bool): Whether to check if the shape is within the image bounds.

    Returns:
        PIL.Image: The image with the shape added.

    Raises:
        ValueError: If `shape` is not recognized, or if `check_bounds` is set
            and the shape's bounding box is not fully inside the image.
    """
    if shape not in ("circle", "rectangle", "triangle"):
        raise ValueError("Shape not recognized. Use 'circle', 'rectangle', or 'triangle'.")

    image = image.copy()

    width, height = image.size
    center_x = int(location[0] * width)
    center_y = int(location[1] * height)
    shape_size = int(size * min(width, height))
    half = shape_size // 2

    # Bounding box actually covered by the drawn shape.
    if shape == "circle":
        radius = shape_size / 2.
        left, right = center_x - radius, center_x + radius
        top, bottom = center_y - radius, center_y + radius
    elif shape == "rectangle":
        left, right = center_x - half, center_x + half
        top, bottom = center_y - half, center_y + half
    else:  # triangle
        left, right = center_x - half, center_x + half
        top, bottom = center_y - shape_size, center_y + half

    # The triangle's base (bottom edge) is checked too; it used to be skipped.
    if check_bounds and not (0 <= left <= right < width and 0 <= top <= bottom < height):
        raise ValueError("The shape would be out of image bounds with the specified location and size.")

    draw = ImageDraw.Draw(image)
    if shape == "circle":
        draw.ellipse([left, top, right, bottom], fill=facecolor, outline=edgecolor, width=edgethickness)
    elif shape == "rectangle":
        draw.rectangle([left, top, right, bottom], fill=facecolor, outline=edgecolor, width=edgethickness)
    else:  # triangle
        points = [
            (center_x, top),
            (left, bottom),
            (right, bottom),
        ]
        # Pass the edge thickness so it is honored like for the other shapes.
        draw.polygon(points, fill=facecolor, outline=edgecolor, width=edgethickness)

    return image
|
|
|
|
| import math |
|
|
def draw_star(
    center_x,
    center_y,
    outer_radius,
    inner_radius,
    num_points,
    draw_obj,
    facecolor,
    edgecolor,
    edgethickness,
):
    """
    Draws a star on the given `draw_obj`.

    The outline alternates between outer and inner vertices, starting with an
    outer vertex straight above the center and proceeding with increasing
    angle (x grows with sin, y shrinks with cos).

    Args:
        center_x (int): X-coordinate of the star's center.
        center_y (int): Y-coordinate of the star's center.
        outer_radius (int): Radius of the outer points of the star.
        inner_radius (int): Radius of the inner points of the star.
        num_points (int): Number of points the star has (must be >= 5).
        draw_obj (ImageDraw.Draw): The drawing object to draw the star.
        facecolor (tuple): Color for the star's fill.
        edgecolor (tuple): Color for the star's edge.
        edgethickness (int): Thickness of the star's edge.

    Returns:
        None

    Raises:
        ValueError: If `num_points` is less than 5.
    """
    if num_points < 5:
        raise ValueError("Number of points must be at least 5 for a star.")

    # Angular step between consecutive (outer/inner) vertices.
    angle = 2 * math.pi / (2 * num_points)

    points = []
    for i in range(2 * num_points):
        radius = outer_radius if i % 2 == 0 else inner_radius
        x = center_x + int(radius * math.sin(i * angle))
        y = center_y - int(radius * math.cos(i * angle))
        points.append((x, y))

    # `edgethickness` used to be accepted but never applied to the outline.
    draw_obj.polygon(points, fill=facecolor, outline=edgecolor, width=edgethickness)
|
|
|
|
def add_shape_to_image_with_opacity(
    image,
    shape="circle",
    size=0.1,
    location=(0.5, 0.5),
    facecolor="red",
    edgecolor="red",
    edgethickness=2,
    opacity=1.0,
    check_bounds=True,
):
    """
    Alpha-blends a shape onto an image at a specified location.

    The shape is drawn on a transparent RGBA overlay which is then composited
    onto the image; the result is always returned in RGB mode.

    Geometry (with `s = int(size * min(width, height))` and the center at
    `location` scaled by the image size): the circle has diameter `s`; the
    rectangle, triangle and star all fit the square spanning `s // 2` on each
    side of the center (the star uses inner radius `s // 4` and 5 points).

    Args:
        image (PIL.Image): The input image (left unmodified).
        shape (str): One of 'circle', 'rectangle', 'triangle', or 'star'.
        size (float): The size of the shape as a fraction of the image size (min side).
        location (tuple): The (x, y) location of the shape as a fraction of the image size.
        facecolor: Fill color; a color string, or 3 RGB components given either
            in [0, 1] (scaled by 255) or in [0, 255].
        edgecolor: Edge color; same accepted forms as `facecolor`.
        edgethickness (int): The thickness of the edge.
        opacity (float): Opacity of the shape; clamped to [0, 1].
        check_bounds (bool): Whether to check if the shape is within the image bounds.

    Returns:
        PIL.Image: RGB image with the shape blended in.

    Raises:
        ValueError: If `shape` is not recognized, if a color has an
            unsupported type, or if `check_bounds` is set and the shape's
            bounding box is not fully inside the image.
    """
    # An unknown shape used to be ignored silently (nothing was drawn).
    if shape not in ("circle", "rectangle", "triangle", "star"):
        raise ValueError("Shape not recognized. Use 'circle', 'rectangle', 'triangle', or 'star'.")

    opacity = max(0.0, min(1.0, opacity))
    alpha = int(255 * opacity)

    def to_rgba(color, label):
        # Normalize a color spec to an (R, G, B, A) tuple of plain ints.
        if isinstance(color, str):
            rgb = ImageColor.getrgb(color)
        elif isinstance(color, (np.ndarray, tuple, list)):
            assert len(color) == 3, "RGB color must have 3 components"
            if np.max(color) <= 1:
                color = [int(255 * c) for c in color]
            rgb = color
        else:
            raise ValueError(f"{label} must be a string or RGB tuple.")
        # Keep only RGB (a color string may carry its own alpha) and cast to
        # int so that numpy scalars are accepted by the drawing calls.
        return (*(int(c) for c in rgb[:3]), alpha)

    rgba_facecolor = to_rgba(facecolor, "Facecolor")
    rgba_edgecolor = to_rgba(edgecolor, "Edgecolor")

    width, height = image.size
    shape_size = int(size * min(width, height))
    center_x = int(location[0] * width)
    center_y = int(location[1] * height)

    # Half-extent of the bounding box actually covered by the drawn shape.
    if shape == "circle":
        half_extent = shape_size / 2.
    else:
        half_extent = shape_size // 2
    left, right = center_x - half_extent, center_x + half_extent
    top, bottom = center_y - half_extent, center_y + half_extent

    # The triangle check now matches the drawn triangle (apex at `top`, base
    # at `bottom`); it used to test a taller apex and skip the base.
    if check_bounds and not (0 <= left <= right < width and 0 <= top <= bottom < height):
        raise ValueError("The shape would be out of image bounds with the specified location and size.")

    overlay = Image.new("RGBA", image.size, (0, 0, 0, 0))
    overlay_draw = ImageDraw.Draw(overlay)

    if shape == "circle":
        overlay_draw.ellipse(
            [left, top, right, bottom],
            fill=rgba_facecolor, outline=rgba_edgecolor, width=edgethickness,
        )
    elif shape == "rectangle":
        overlay_draw.rectangle(
            [left, top, right, bottom],
            fill=rgba_facecolor, outline=rgba_edgecolor, width=edgethickness,
        )
    elif shape == "triangle":
        points = [
            (center_x, top),
            (left, bottom),
            (right, bottom),
        ]
        # Pass the edge thickness so it is honored like for the other shapes.
        overlay_draw.polygon(
            points, fill=rgba_facecolor, outline=rgba_edgecolor, width=edgethickness,
        )
    else:  # star
        draw_star(
            center_x, center_y, shape_size // 2, shape_size // 4,
            5, overlay_draw, rgba_facecolor,
            rgba_edgecolor, edgethickness,
        )

    blended = Image.alpha_composite(image.convert("RGBA"), overlay)

    return blended.convert("RGB")
|
|
|
|
| from PIL import Image, ImageDraw, ImageFont |
|
|
def add_text_to_frames(
    image,
    text_numbers=["1", "2", "3"],
    position="top_left",
    font_size=50,
    font_color="red",
):
    """
    Adds text numbers to an image containing horizontally stacked frames.

    The image is split into `len(text_numbers)` equally wide frames and the
    i-th text is drawn in the chosen corner of the i-th frame.

    Parameters:
        image (PIL.Image): Image containing frames (left unmodified).
        text_numbers (list): List of text numbers to add (e.g., ['1', '2', '3']).
            The default list is only read, never modified.
        position (str): Corner of each frame in which to place the text.
        font_size (int): Font size of the text.
        font_color (str): Color of the font.

    Returns:
        PIL.Image: Copy of the image with the texts drawn; a plain copy when
        `text_numbers` is empty.

    Raises:
        KeyError: If `position` is not one of the supported positions.

    Supported Positions:
        - 'top_left', 'top_right', 'bottom_left', 'bottom_right'
    """
    image = image.copy()

    # Nothing to draw (and avoids dividing by zero below).
    if len(text_numbers) == 0:
        return image

    draw = ImageDraw.Draw(image)
    width, height = image.size
    frame_width = width // len(text_numbers)

    font = ImageFont.load_default(size=font_size)

    # (x within the frame, y, text anchor). The anchor makes the text grow
    # away from its corner; with the default top-left anchor the right and
    # bottom positions rendered the text outside the frame/image.
    placement_map = {
        'top_left': (10, 10, "la"),
        'top_right': (frame_width - 10, 10, "ra"),
        'bottom_left': (10, height - 10, "ld"),
        'bottom_right': (frame_width - 10, height - 10, "rd"),
    }
    x, y, anchor = placement_map[position]

    for i, text in enumerate(text_numbers):
        x_offset = i * frame_width + 5
        draw.text((x + x_offset, y), text, fill=font_color, font=font, anchor=anchor)

    return image
|
|
|
|
# Placement names accepted by `add_texts_to_frames`.
POSITIONS = [
    "top_left", "top_right", "bottom_left", "bottom_right", "center"
]


def add_texts_to_frames(
    frames,
    texts,
    fontsize=60,
    color="red",
    position="top_left",
):
    """
    Adds texts onto given frames, one text per frame.

    Args:
        frames (list): PIL images; each one is copied before drawing.
        texts (list): Texts to draw, same length as `frames`.
        fontsize (int): Font size of the text.
        color (str): Color of the text.
        position (str): Where to place the text; one of `POSITIONS`.

    Returns:
        list: New frames with the corresponding text drawn on each.
    """
    assert len(frames) == len(texts), \
        "Number of frames and texts must be the same."
    assert position in POSITIONS, \
        f"Position must be one of {POSITIONS}."

    # The font does not depend on the frame, so load it once.
    font = ImageFont.load_default(size=fontsize)

    # Text anchor per position, so that the text grows away from its corner
    # (or is centered on the point). With the default top-left anchor the
    # right/bottom positions spilled outside the frame and 'center' was
    # not actually centered.
    anchor_map = {
        'top_left': "la",
        'top_right': "ra",
        'bottom_left': "ld",
        'bottom_right': "rd",
        "center": "mm",
    }
    anchor = anchor_map[position]

    new_frames = []
    for frame, text in zip(frames, texts):
        frame = frame.copy()

        draw = ImageDraw.Draw(frame)
        width, height = frame.size

        # Reference point of the text within this frame.
        offset_map = {
            'top_left': (10, 10),
            'top_right': (width - 10, 10),
            'bottom_left': (10, height - 10),
            'bottom_right': (width - 10, height - 10),
            "center": (width // 2, height // 2),
        }

        x, y = offset_map[position]
        draw.text((x, y), text, fill=color, font=font, anchor=anchor)

        new_frames.append(frame)

    return new_frames
|
|
|
|
def add_border(image, color="red", thickness=10):
    """
    Return a copy of an image with a border drawn over its outermost pixels.

    The border is painted inside the existing canvas, so the image size is
    unchanged.
    """
    bordered = image.copy()
    img_w, img_h = bordered.size
    outline_box = [0, 0, img_w - 1, img_h - 1]
    ImageDraw.Draw(bordered).rectangle(outline_box, outline=color, width=thickness)
    return bordered
|
|
|
|
def visualize_dense_feature_norm(x, size=(224, 224), stitch=True):
    """
    Visualize the per-location feature norm of dense features as heatmaps.

    Args:
        x (torch.Tensor): [F, H', W', D]
        size (tuple): Spatial size the norm maps are bilinearly resized to.
        stitch (bool): If True, pass the heatmaps through
            `concat_images_with_border` before returning.

    Returns:
        The output of `batch_tensor_to_heatmap` on the resized norm maps, or of
        `concat_images_with_border` applied to it when `stitch` is True.
    """
    # Detach first: `.numpy()` below raises on tensors that require grad.
    norms = x.detach().norm(dim=-1)

    # interpolate() expects a leading batch dimension; add and remove it.
    norms = torch.nn.functional.interpolate(
        norms.unsqueeze(0), size=size, mode='bilinear', align_corners=False,
    ).squeeze(0)

    heatmaps = batch_tensor_to_heatmap(norms.cpu().numpy())
    if stitch:
        heatmaps = concat_images_with_border(heatmaps)
    return heatmaps
|
|
|
|
| import seaborn as sns |
def plot_confusion_matrix(
    x, cmap="viridis", title=None, show=False, return_as_pil=False,
    vmin=-1, vmax=1, figsize=(5, 4),
):
    """
    Plot a square matrix as an annotated seaborn heatmap.

    Args:
        x (torch.Tensor): [C, C]
        cmap (str): Colormap passed to `sns.heatmap`.
        title (str): Axes title.
        show (bool): Call `plt.show()` after drawing.
        return_as_pil (bool): If True, convert the figure with `fig2img`,
            close it and return the resulting image instead of the figure.
        vmin, vmax: Color range of the heatmap. The defaults (-1, 1) keep the
            previous fixed range; widen them for matrices with values outside
            of it (e.g. raw counts), which would otherwise saturate.
        figsize (tuple): Size of the created figure.

    Returns:
        The image produced by `fig2img` when `return_as_pil` is True,
        otherwise the matplotlib figure.
    """
    fig, ax = plt.subplots(1, 1, figsize=figsize)
    sns.heatmap(x, cmap=cmap, annot=True, ax=ax, vmin=vmin, vmax=vmax)
    ax.set_title(title)
    if show:
        plt.show()
    if return_as_pil:
        plt.tight_layout(pad=0.2)
        image = fig2img(fig)
        # Close the figure once rasterized so it does not linger in pyplot.
        plt.close(fig)
        return image
    return fig
|
|
|
|
| from IPython.display import display, Markdown |
def show_text(text):
    """Render `text` as Markdown in the notebook output."""
    rendered = Markdown(text)
    display(rendered)
|
|
|
|
def show_text_with_color(text, color="red"):
    """Render `text` in the notebook output, wrapped in a colored HTML span."""
    colored_markup = f"<span style='color: {color};'>{text}</span>"
    display(Markdown(colored_markup))