Spaces:
Runtime error
Runtime error
| import matplotlib.pyplot as plt | |
| from matplotlib.gridspec import GridSpec | |
| import numpy as np | |
def display_image_and_text(image, captions, prompt=None, max_length=100):
    """Render an image above a block of caption text and return the figure.

    Args:
        image: Anything ``Axes.imshow`` accepts; drawn with the ``'gray'``
            colormap in the upper (larger) of two stacked axes.
        captions: Iterable of caption strings. Each caption is hard-split
            every ``max_length`` characters (mid-word if necessary), and
            captions are separated from each other by a blank line.
        prompt: Optional string drawn centred near the top of the figure.
        max_length: Maximum number of characters per rendered caption line.
            Defaults to 100, the value previously hard-coded.

    Returns:
        The ``matplotlib.figure.Figure``. It is not shown here; the caller
        is responsible for displaying, saving, or closing it.

    Raises:
        ValueError: If ``max_length`` is less than 1.
    """
    if max_length < 1:
        raise ValueError("max_length must be >= 1")

    # ``gridspec_kw`` works on every matplotlib version; the direct
    # ``height_ratios=`` keyword of ``plt.subplots`` only exists in >= 3.6.
    fig, ax = plt.subplots(
        2, 1, figsize=(10, 5), gridspec_kw={'height_ratios': [0.9, 0.1]}
    )

    # Hard-wrap each caption every ``max_length`` characters. A caption that
    # already fits produces a single chunk (an empty caption produces ''),
    # so no separate length check is needed.
    wrapped = [
        '\n'.join(
            caption[i:i + max_length]
            for i in range(0, len(caption), max_length)
        )
        for caption in captions
    ]
    text_lines = '\n\n'.join(wrapped)

    # Upper axes: the image, without ticks or frame.
    ax[0].imshow(image, cmap='gray')
    ax[0].axis('off')

    # Optional prompt, positioned in figure coordinates above the image.
    if prompt is not None:
        fig.text(0.5, 0.95, prompt, fontsize=14, ha='center', va='center')

    # Lower axes: the caption text centred in axes coordinates.
    ax[1].text(0.5, 0.5, text_lines, fontsize=12, ha='center', va='center')
    ax[1].axis('off')
    return fig
def display_images_and_text(images, captions, model):
    """Draw up to nine ranked images in a 3x3 grid with captions and a title.

    Args:
        images: Indexable sequence of images (anything ``Axes.imshow``
            accepts), in rank order: ``images[0]`` is titled "rank 1".
            Cells are filled row-major. Only the first nine images are drawn;
            if fewer are supplied, the remaining cells are left blank instead
            of raising ``IndexError``.
        captions: Iterable of strings, joined with blank lines and drawn
            centred near the bottom of the figure.
        model: Figure-level title passed to ``fig.suptitle``
            (e.g. a model name).

    Returns:
        The ``matplotlib.figure.Figure``. It is not shown here; the caller
        is responsible for displaying, saving, or closing it.
    """
    text_lines = '\n\n'.join(captions)
    fig, ax = plt.subplots(3, 3, figsize=(26, 10))

    # ``ax.flat`` walks the 3x3 axes grid row-major, i.e. the same order as
    # ``ax[i // 3, i % 3]`` for i in 0..8.
    n_images = len(images)
    for i, cell in enumerate(ax.flat):
        if i < n_images:
            cell.imshow(images[i], cmap='gray')
            cell.axis('off')
            cell.set_title(f"rank {i + 1}", fontsize=14)
        else:
            # Not enough images for this cell: hide its empty frame.
            cell.axis('off')

    # Caption block in figure coordinates, just above the bottom edge.
    fig.text(0.5, 0.05, text_lines, fontsize=16, ha='center', va='center')
    fig.suptitle(model, fontsize=16)
    return fig