import json
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import BoundaryNorm, ListedColormap

# Update this to the correct path of your folder, or set the ARC_TASK_FOLDER
# environment variable. Use an absolute path if necessary.
task_folder = os.environ.get("ARC_TASK_FOLDER", "./evaluation")

# Palette commonly used for ARC visualizations, indexed by cell value 0-9.
ARC_CMAP = ListedColormap([
    "#000000",  # 0 black
    "#0074D9",  # 1 blue
    "#FF4136",  # 2 red
    "#2ECC40",  # 3 green
    "#FFDC00",  # 4 yellow
    "#AAAAAA",  # 5 grey
    "#F012BE",  # 6 magenta
    "#FF851B",  # 7 orange
    "#7FDBFF",  # 8 light blue
    "#870C25",  # 9 maroon
])
# Fixed bin edges at k +/- 0.5 so cell value k always gets color k. Without an
# explicit norm, imshow autoscales each grid to its own min/max, so the same
# value could be drawn in different colors in the input and output panels.
ARC_NORM = BoundaryNorm(np.arange(-0.5, 10.5, 1), ARC_CMAP.N)


def _draw_grid(ax, grid, title):
    """Draw one ARC grid on *ax* with the fixed palette, or a placeholder if *grid* is None."""
    ax.set_title(title)
    ax.axis("off")
    if grid is None:
        # e.g. test pairs whose expected output is withheld
        ax.text(0.5, 0.5, "Not available", ha="center", va="center",
                transform=ax.transAxes)
        return
    ax.imshow(np.array(grid), cmap=ARC_CMAP, norm=ARC_NORM, interpolation="none")


def plot_arc_problem(task_file_path, data_type="train"):
    """
    Visualize input-output pairs for a given ARC task file.

    Each pair is drawn as one row of two panels (input, output) using a fixed
    0-9 color mapping shared by all panels. A pair that lacks an 'input' or
    'output' entry is shown with a "Not available" placeholder panel.

    Args:
        task_file_path (str): Path to the ARC task JSON file.
        data_type (str): 'train' or 'test' to visualize respective examples.

    Returns:
        matplotlib.figure.Figure or str: The plotted figure or an error message.
        The figure is detached from pyplot's global figure registry before being
        returned, so repeated calls do not accumulate open figures.
    """
    fig = None
    try:
        # Load the JSON file
        with open(task_file_path, "r", encoding="utf-8") as f:
            task = json.load(f)

        # Get the data type (train or test) pairs
        pairs = task.get(data_type, [])
        if not pairs:  # Check if the section exists and has data
            return f"No '{data_type}' data found in the selected file."

        label = data_type.capitalize()
        # squeeze=False keeps `axes` 2-D even when there is a single pair.
        fig, axes = plt.subplots(len(pairs), 2, figsize=(10, 5 * len(pairs)),
                                 squeeze=False)
        fig.suptitle(f"ARC Task: {os.path.basename(task_file_path)} ({label})",
                     fontsize=16)

        for idx, pair in enumerate(pairs):
            _draw_grid(axes[idx][0], pair.get("input"),
                       f"Input Pair {idx + 1} ({label})")
            _draw_grid(axes[idx][1], pair.get("output"),
                       f"Output Pair {idx + 1} ({label})")

        # Operate on this figure explicitly rather than pyplot's global
        # "current figure", which concurrent requests could change.
        fig.tight_layout(rect=[0, 0, 1, 0.96])
        # Release pyplot's reference; the Figure object stays renderable.
        plt.close(fig)
        return fig
    except Exception as e:
        if fig is not None:
            plt.close(fig)  # don't leak a half-built figure
        print("Error in plot_arc_problem:", e)
        return f"An error occurred while plotting the {data_type} data: {e}"


def visualize_task(file_name, data_type):
    """
    Load and visualize the ARC task for a given file and data type.

    Args:
        file_name (str): Name of the JSON file in the folder.
        data_type (str): 'train' or 'test'.

    Returns:
        matplotlib.figure.Figure or str: Figure of the visualized task or an error message.
    """
    try:
        print(f"Selected file: {file_name}, Data type: {data_type}")  # Debugging
        if not file_name:
            # The dropdown has no default, so it can be submitted empty.
            return "Please select an ARC task file."
        task_file_path = os.path.join(task_folder, file_name)
        return plot_arc_problem(task_file_path, data_type)
    except Exception as e:
        print("Error in visualize_task:", e)
        return f"An error occurred while visualizing the task: {e}"


def _visualize_for_ui(file_name, data_type):
    """Gradio adapter: show error strings as a UI error instead of passing them to gr.Plot.

    gr.Plot expects a figure object; a plain message string is not a renderable
    plot value, so it is raised as gr.Error for Gradio to display.
    """
    result = visualize_task(file_name, data_type)
    if isinstance(result, str):
        raise gr.Error(result)
    return result


def _list_task_files(folder):
    """Return the sorted names of ``.json`` files directly inside *folder*.

    A missing folder yields an empty list plus a console warning, so the UI
    can still start instead of raising at import time.
    """
    try:
        return sorted(f for f in os.listdir(folder) if f.endswith(".json"))
    except FileNotFoundError:
        print(f"Warning: task folder not found: {folder!r}")
        return []


# Gradio Interface
task_files = _list_task_files(task_folder)

interface = gr.Interface(
    fn=_visualize_for_ui,
    inputs=[
        gr.Dropdown(choices=task_files, label="Select ARC Task File"),
        gr.Radio(choices=["train", "test"], label="Select Data Type to Visualize", value="train"),
    ],
    outputs="plot",
    title="ARC Task Visualizer",
    description="Select a task file and data type (train or test) to visualize its input-output grids.",
)

if __name__ == "__main__":
    print("Task files:", task_files)  # Debugging
    interface.launch()