Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import os | |
def _empty_entropy_figure():
    """Return the placeholder figure used when there is nothing to plot."""
    fig, ax = plt.subplots(figsize=(15, 5))
    ax.text(0.5, 0.5, "No data to plot.", ha='center', va='center', fontsize=12)
    ax.set_xlabel("Characters (on underlying byte sequence)")
    ax.set_ylabel("Entropy of Next Byte")
    ax.set_ylim(bottom=0)
    ax.set_xlim(left=-0.5, right=0.5)  # Default xlim for an empty plot
    return fig


def plot_entropies(
    patch_lengths: torch.Tensor,
    scores: torch.Tensor,
    tokens: torch.Tensor,
    chars: str,
    threshold: float,
):
    """Plot per-byte scores with patch boundaries and a threshold line.

    The x-axis is indexed by byte. The display string is ``"<" + chars`` with
    spaces replaced by ``"_"``; every character of that string is used as a
    tick label placed at the index of its first UTF-8 byte.

    Args:
        patch_lengths: Patch lengths in bytes. Flattened before use; the sum
            must equal the number of scores.
        scores: One value per byte. Flattened and converted to float before
            plotting.
        tokens: Not read by this function; kept so existing call sites keep
            working.
        chars: Text whose UTF-8 bytes (after the prepended ``"<"``) the
            scores correspond to.
        threshold: y-position of the horizontal red dashed line.

    Returns:
        The matplotlib figure. When ``chars`` is empty and ``scores`` has no
        elements, a placeholder figure with the text "No data to plot." is
        returned instead.

    Raises:
        ValueError: If a character cannot be encoded as UTF-8, if the UTF-8
            byte count of the display string differs from the number of
            scores, or if ``patch_lengths`` does not sum to the number of
            scores.
    """
    # detach() lets tensors that still track gradients be converted to numpy.
    patch_lengths_np = patch_lengths.detach().cpu().numpy().flatten()
    scores_np = scores.detach().cpu().float().numpy().flatten()
    num_scores = len(scores_np)

    # The empty case must be decided on `chars` itself: the display string
    # always contains the prepended '<', so it is never empty.
    if not chars and num_scores == 0:
        return _empty_entropy_figure()

    # Prepare display string (prepend '<', replace spaces with '_').
    processed_chars = chars.replace(" ", "_")
    display_string = "<" + processed_chars
    display_chars = list(display_string)

    # Number of UTF-8 bytes taken by each displayed character.
    try:
        byte_counts = [len(c.encode('utf-8')) for c in display_chars]
    except UnicodeEncodeError as e:
        raise ValueError(
            f"Could not encode characters in 'chars' string using UTF-8. "
            f"Problematic part: '{processed_chars}'. Error: {e}"
        ) from e

    # --- Validations (all run before any figure is created) ---
    total_display_bytes = sum(byte_counts)
    if total_display_bytes != num_scores:
        raise ValueError(
            f"Mismatch in byte counts: Sum of UTF-8 bytes for display_string "
            f"('{display_string}' -> {total_display_bytes} bytes) "
            f"does not match length of scores tensor ({num_scores}). "
            f"Ensure 'chars' (and the prepended '<') correctly correspond to the byte sequence "
            f"represented by 'scores'/'tokens'."
        )
    if patch_lengths_np.sum() != num_scores:
        raise ValueError(
            f"Sum of patch_lengths ({patch_lengths_np.sum()}) "
            f"does not match length of scores ({num_scores})."
        )

    # --- Plotting setup ---
    fig, ax = plt.subplots(figsize=(15, 5))
    byte_indices = np.arange(num_scores)

    # --- Scores, one point per byte ---
    ax.plot(byte_indices, scores_np, marker='.', linestyle='-', color='steelblue', label='Scores per byte')

    # --- Vertical patch boundary lines ---
    # Each line sits at x == cumulative patch length, i.e. at the index of the
    # first byte of the following patch. The final cumulative value (end of
    # all data) is skipped.
    # NOTE(review): an earlier comment described a half-step (-0.5) offset
    # here, but the code never applied one; the position is left unchanged.
    for boundary in np.cumsum(patch_lengths_np)[:-1]:
        ax.axvline(x=boundary, color='grey', linestyle='--', linewidth=1)

    # --- Horizontal threshold line and its label ---
    ax.axhline(y=threshold, color='red', linestyle='--', linewidth=1)
    # No arrow is requested, so only the text is drawn; it is positioned by
    # xytext, which is interpreted in data coordinates.
    ax.annotate(
        'Entropy Threshold',
        xy=(0.05, threshold),
        xytext=(0.05, threshold + 0.1),
        xycoords='axes fraction',
        textcoords='data',
        color='red',
    )

    # --- X ticks: one label per character, at the start of its byte run ---
    # Start offset of each character = bytes consumed by all earlier ones.
    char_start_positions = np.cumsum(byte_counts) - np.asarray(byte_counts)
    ax.set_xticks(char_start_positions)
    ax.set_xticklabels(display_chars, rotation=0, fontsize=8)

    # --- Axes configuration ---
    ax.set_ylabel("Entropy of Next Byte", fontsize=12)
    ax.set_xlabel("Characters (on underlying byte sequence)", fontsize=12)
    ax.set_ylim(bottom=0)
    # At least one byte (the '<') is present once validation has passed, so
    # the limits can always bracket the byte indices by half a step.
    ax.set_xlim(left=-0.5, right=num_scores - 0.5)

    # Hide the top and right spines.
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # Lay out this figure explicitly rather than whichever one is current.
    fig.tight_layout()
    return fig