Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from sklearn.decomposition import PCA | |
| from matplotlib.colors import LinearSegmentedColormap | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| from openai import OpenAI, AuthenticationError, RateLimitError | |
| from dotenv import load_dotenv | |
| import os | |
# Module-level setup: call load_dotenv() (python-dotenv; presumably reads a
# local .env file) before looking up the API key, then build the single
# OpenAI client that the embedding helper below uses.
load_dotenv()
openai_api_key = os.environ.get("OPENAI_API_KEY")
oai_client = OpenAI(api_key=openai_api_key)
def calculate_embeddings(words):
    """Return one embedding vector per entry of *words*.

    All of *words* are sent in a single ``embeddings.create`` request using
    the ``text-embedding-3-small`` model. The result is the ``embedding``
    attribute of every item in the response's ``data``, in response order.
    """
    reply = oai_client.embeddings.create(
        model="text-embedding-3-small",
        input=words,
    )
    return [item.embedding for item in reply.data]
def process_array(arr):
    """Fold the strict upper triangle of a square matrix into a masked array.

    For an ``n x n`` input the result is ``(n-1) x (n-1)``. Row ``i`` of the
    result is row ``i`` of the input; column ``j`` of the result is column
    ``n-1-j`` of the input. Only cells that came from above the diagonal are
    visible, so the visible values form a triangle anchored at the top-left
    corner. Every other cell is masked and its underlying data is 0.

    Args:
        arr: Square 2-D array-like (e.g. a pairwise similarity matrix).

    Returns:
        A ``numpy.ma.MaskedArray`` of shape ``(n-1, n-1)``.

    Raises:
        ValueError: If *arr* is not 2-D or not square.
    """
    arr = np.asarray(arr)
    if arr.ndim != 2 or arr.shape[0] != arr.shape[1]:
        raise ValueError("Input must be a square array")

    # Zero everything on/below the diagonal, mirror left-to-right, then drop
    # the last row and column: after mirroring they hold no upper-triangle
    # entries (the strict upper triangle has nothing in the last row or in
    # the first column).
    folded = np.fliplr(np.triu(arr, k=1))[:-1, :-1]

    # Mask by position, not by value: a genuine 0.0 entry in the upper
    # triangle is real data and must stay visible. Apply the same fold to a
    # boolean "is upper triangle" matrix to know which cells to keep.
    keep = np.fliplr(np.triu(np.ones(arr.shape, dtype=bool), k=1))[:-1, :-1]
    return np.ma.masked_array(folded, mask=~keep)
def plot_heatmap(masked_result, l1: list[str]):
    """Render a folded similarity matrix as an annotated heatmap.

    Args:
        masked_result: 2-D masked array, as returned by ``process_array``.
            Masked cells are drawn white and get no annotation.
        l1: The labels of the original (unfolded) matrix. Rows are labelled
            with ``l1[:-1]`` and columns with ``l1[1:]`` in reverse order,
            matching the horizontal flip done by ``process_array``.

    Returns:
        The matplotlib ``Figure`` containing the heatmap and its colorbar.
    """
    n_rows, n_cols = masked_result.shape

    fig, ax = plt.subplots(figsize=(12, 10))

    # Diverging colormap over [-1, 1]; masked cells fall back to white.
    cmap = LinearSegmentedColormap.from_list(
        "custom", ["darkred", "lightgray", "dodgerblue"], N=100
    )
    cmap.set_bad("white")

    im = ax.imshow(masked_result, cmap=cmap, vmin=-1, vmax=1)

    # Print the value inside every visible cell (imshow uses x=column, y=row).
    for i, j in np.ndindex(n_rows, n_cols):
        value = masked_result[i, j]
        if np.ma.is_masked(value):
            continue
        ax.text(j, i, f"{value:.2f}", ha="center", va="center", color="black")

    ax.set_yticks(range(n_rows))
    ax.set_yticklabels(l1[:-1])
    ax.set_xticks(range(n_cols))
    # Materialise the reversed labels: tick labels are expected as a sequence.
    ax.set_xticklabels(list(reversed(l1[1:])))

    # Column labels go on top, tilted so longer words do not overlap.
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position("top")
    plt.setp(ax.get_xticklabels(), rotation=45, ha="left", rotation_mode="anchor")

    # Use the explicit figure/axes rather than pyplot's implicit "current
    # figure": that global state can point at a different figure when several
    # requests are being plotted at once.
    cbar = fig.colorbar(im, ax=ax)
    cbar.set_ticks([-1, 0, 1])
    cbar.set_ticklabels(["-1", "0", "1"])

    ax.set_title("Correlation Heatmap", pad=20)

    fig.tight_layout()
    return fig
def plot_pca(embeddings, words):
    """Scatter-plot the embeddings projected onto their first two principal components.

    Args:
        embeddings: Sequence of equal-length embedding vectors, one per word.
        words: Labels for the points, in the same order as *embeddings*.

    Returns:
        The matplotlib ``Figure`` containing the labelled scatter plot.
    """
    # Fit PCA before creating the figure so a failing fit leaves no figure
    # behind.
    embeddings_2d = PCA(n_components=2).fit_transform(embeddings)

    # Create exactly one figure. The original created a second, unused figure
    # on every call that was never closed.
    fig, ax = plt.subplots(figsize=(10, 8))
    ax.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1])
    for i, word in enumerate(words):
        ax.annotate(word, (embeddings_2d[i, 0], embeddings_2d[i, 1]))

    ax.set_title("PCA of Word Embeddings")
    ax.set_xlabel("First Principal Component")
    ax.set_ylabel("Second Principal Component")

    # Lay out this figure explicitly instead of pyplot's "current" figure.
    fig.tight_layout()
    return fig
def word_similarity_heatmap(input_text):
    """Build the similarity heatmap and PCA plot for a comma-separated word list.

    Args:
        input_text: Words separated by commas. Surrounding whitespace is
            stripped and empty entries (e.g. from a trailing comma) are
            ignored.

    Returns:
        A ``(heatmap_figure, pca_figure)`` tuple, one item per output
        component of the interface.

    Raises:
        gr.Error: If fewer than two words are given, or if the OpenAI API key
            is rejected or the rate limit is hit. Gradio shows the message to
            the user.
        Exception: Any other failure is logged and re-raised unchanged.
    """
    # Drop blank entries so "a,,b" or "a," do not send empty strings to the
    # embeddings endpoint.
    stripped = (part.strip() for part in (input_text or "").split(","))
    words = [word for word in stripped if word]

    if len(words) < 2:
        # The interface has two plot outputs, so returning a single string
        # here cannot be rendered; raise a user-visible error instead.
        raise gr.Error("Please enter at least two words.")

    try:
        embeddings = calculate_embeddings(words)
        similarities = cosine_similarity(embeddings)
        new_array = process_array(similarities)
        heatmap = plot_heatmap(new_array, words)
        pca_plot = plot_pca(embeddings, words)
    except AuthenticationError as e:
        message = "OpenAI API key is invalid. Please check your API key."
        print(message)
        raise gr.Error(message) from e
    except RateLimitError as e:
        message = "OpenAI API rate limit exceeded. Please try again later."
        print(message)
        raise gr.Error(message) from e
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        raise  # bare raise keeps the original traceback intact

    return heatmap, pca_plot
# Gradio UI: one comma-separated textbox in, two plots out (heatmap, PCA).
iface = gr.Interface(
    title="Word Similarity Heatmap and PCA Plot using OpenAI Embeddings",
    description=(
        "Enter a list of words separated by commas. "
        "The app will calculate the cosine similarity between their OpenAI embeddings, "
        "display a compact heatmap of the upper triangle similarities, "
        "and show a PCA plot of the embeddings."
    ),
    fn=word_similarity_heatmap,
    inputs=gr.Textbox(placeholder="Enter words separated by commas", lines=2),
    outputs=[
        gr.Plot(label="Similarity Heatmap"),
        gr.Plot(label="PCA Plot"),
    ],
)

# Start serving the app; share=True is passed through to Gradio's launch().
iface.launch(share=True)