Spaces:
Running
Running
| """ | |
| Visualization utilities for LLaVA. | |
| """ | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from PIL import Image, ImageDraw, ImageFont | |
| import torch | |
| from typing import List, Dict, Tuple, Optional, Union | |
| import cv2 | |
def display_image_with_caption(
    image_path: str,
    caption: str,
    figsize: Tuple[int, int] = (10, 10)
) -> None:
    """
    Show an image in a matplotlib figure, using the caption as its title.

    Args:
        image_path: Path to the image file
        caption: Text rendered as the figure title
        figsize: Figure size passed to matplotlib
    """
    picture = Image.open(image_path).convert('RGB')

    # Work on explicit figure/axes handles instead of pyplot's implicit state.
    fig, ax = plt.subplots(figsize=figsize)
    ax.imshow(picture)
    ax.axis('off')
    ax.set_title(caption)
    fig.tight_layout()
    plt.show()
def visualize_attention(
    image_path: str,
    attention_weights: torch.Tensor,
    figsize: Tuple[int, int] = (15, 5)
) -> None:
    """
    Visualize attention weights on an image.

    Shows three panels side by side: the original image, the attention map
    rendered with OpenCV's JET colormap, and an equal-weight blend of the two.

    Args:
        image_path: Path to the image file
        attention_weights: Attention weights tensor. A tensor with more than
            two dimensions is averaged over its first two dimensions before
            being resized to the image.
        figsize: Figure size
    """
    # Load image; convert() returns an independent copy, so the file handle
    # can be released immediately.
    with Image.open(image_path) as source:
        image = source.convert('RGB')
    image_np = np.array(image)

    # Cast to float32 up front: half-precision tensors cannot be handed to
    # cv2.resize, and bfloat16 cannot be converted to NumPy at all.
    weights = attention_weights.detach().to(torch.float32)
    if weights.dim() > 2:
        # Average across heads and layers if necessary.
        # NOTE(review): a 3-D input collapses to a 1-D vector here; confirm
        # callers always pass 2-D or 4-D tensors.
        weights = weights.mean(dim=(0, 1))
    attention = weights.cpu().numpy()

    # Min-max normalize to [0, 1]. A constant (or NaN) map has no usable
    # range, so render it as all zeros instead of dividing by zero.
    lowest = attention.min()
    span = attention.max() - lowest
    if span > 0:
        attention = (attention - lowest) / span
    else:
        attention = np.zeros_like(attention)

    # Resize attention map to image size (cv2 expects (width, height)).
    attention_map = cv2.resize(attention, (image.width, image.height))

    # Create heatmap; OpenCV returns BGR, matplotlib expects RGB.
    heatmap = cv2.applyColorMap(np.uint8(255 * attention_map), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    # Overlay heatmap on image
    alpha = 0.5
    overlay = (heatmap * alpha + image_np * (1 - alpha)).astype(np.uint8)

    # Display original image, attention map and overlay
    panels = (
        ('Original Image', image_np),
        ('Attention Map', heatmap),
        ('Overlay', overlay),
    )
    fig, axes = plt.subplots(1, 3, figsize=figsize)
    for ax, (title, pixels) in zip(axes, panels):
        ax.imshow(pixels)
        ax.set_title(title)
        ax.axis('off')
    plt.tight_layout()
    plt.show()
def create_comparison_grid(
    image_path: str,
    responses: List[Dict[str, str]],
    output_path: Optional[str] = None,
    figsize: Tuple[int, int] = (12, 10)
) -> None:
    """
    Render the input image followed by one text row per model response.

    Args:
        image_path: Path to the image file
        responses: List of dictionaries with 'model' and 'response' keys
        output_path: Optional path to save the figure
        figsize: Figure size
    """
    picture = Image.open(image_path).convert('RGB')

    # One row for the image plus one row per response; every cell spans
    # all three columns.
    grid_shape = (len(responses) + 1, 3)
    plt.figure(figsize=figsize)

    # Top row: the input image
    image_ax = plt.subplot2grid(grid_shape, (0, 0), colspan=3)
    image_ax.imshow(picture)
    image_ax.set_title('Input Image')
    image_ax.axis('off')

    # Remaining rows: "<model>: <response>" centred in an otherwise empty axes
    for row, entry in enumerate(responses, start=1):
        text_ax = plt.subplot2grid(grid_shape, (row, 0), colspan=3)
        label = f"{entry['model']}: {entry['response']}"
        text_ax.text(
            0.5,
            0.5,
            label,
            wrap=True,
            horizontalalignment='center',
            verticalalignment='center',
            fontsize=10,
        )
        text_ax.axis('off')

    plt.tight_layout()

    # Save before showing, and only when a destination was given
    if output_path:
        plt.savefig(output_path, bbox_inches='tight')
    plt.show()
def add_caption_to_image(
    image_path: str,
    caption: str,
    output_path: str,
    font_size: int = 20,
    font_color: Tuple[int, int, int] = (255, 255, 255),
    bg_color: Tuple[int, int, int] = (0, 0, 0)
) -> None:
    """
    Add a caption to an image and save it.

    The output is the original image with a solid band appended below it;
    the caption is drawn horizontally centred inside that band.

    Args:
        image_path: Path to the input image
        caption: Caption text
        output_path: Path to save the output image
        font_size: Font size
        font_color: Font color (RGB)
        bg_color: Background color (RGB)
    """
    padding = 10  # Space above and below the caption text

    # Load image; convert() returns an independent copy, so the file handle
    # can be released immediately.
    with Image.open(image_path) as source:
        image = source.convert('RGB')

    # Create a new image with space for the caption
    caption_height = font_size + 2 * padding
    new_image = Image.new('RGB', (image.width, image.height + caption_height), bg_color)
    new_image.paste(image, (0, 0))

    # Pick a font. Arial is often missing outside Windows, so fall back to
    # Pillow's default font, asking for the requested size when this Pillow
    # version accepts it; older versions only offer a fixed-size default.
    try:
        font = ImageFont.truetype("arial.ttf", font_size)
    except OSError:
        try:
            font = ImageFont.load_default(size=font_size)
        except TypeError:
            font = ImageFont.load_default()

    # Calculate text position; clamp to the left edge so a caption wider
    # than the image keeps its beginning visible instead of being clipped
    # on both sides.
    draw = ImageDraw.Draw(new_image)
    text_width = draw.textlength(caption, font=font)
    text_x = max(0, (image.width - text_width) // 2)
    text_position = (text_x, image.height + padding)

    # Draw text
    draw.text(text_position, caption, font=font, fill=font_color)

    # Save image
    new_image.save(output_path)