Spaces:
Running
Running
| import gradio as gr | |
| import json | |
| import matplotlib.pyplot as plt | |
| import pandas as pd | |
| import io | |
| import base64 | |
| import math | |
# Functions to process and visualize log probs

_TABLE_COLUMNS = [
    "Token",
    "Log Prob",
    "Top 1 Alternative",
    "Top 2 Alternative",
    "Top 3 Alternative",
]


def _parse_payload(raw_text):
    """Decode the pasted text as JSON, falling back to a Python literal.

    The UI invites "JSON or Python dict" input, so a repr-style dict (single
    quotes, ``None``) has to be accepted as well as strict JSON.
    ``ast.literal_eval`` evaluates literals only, never arbitrary code.

    Raises:
        ValueError: if the text is neither valid JSON nor a Python literal.
    """
    import ast  # function-scope import keeps this block self-contained

    try:
        return json.loads(raw_text)
    except (TypeError, ValueError) as json_error:
        try:
            return ast.literal_eval(raw_text)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            raise ValueError(
                f"Input is neither valid JSON nor a Python literal ({json_error})"
            ) from json_error


def _is_finite_number(value):
    """Return True when *value* is a real number that is neither NaN nor inf."""
    try:
        return math.isfinite(value)
    except (TypeError, OverflowError):
        # None, strings, containers, or ints too large for a float.
        return False


def _valid_entries(content):
    """Keep the entries that are dicts with a token and a finite ``logprob``."""
    return [
        entry
        for entry in content
        if isinstance(entry, dict)
        and entry.get("token") is not None
        and _is_finite_number(entry.get("logprob"))
    ]


def _top_alternatives(top_logprobs, limit=3):
    """Return up to *limit* ``(token, logprob)`` pairs, most likely first.

    Accepts a ``{token: logprob}`` mapping or a list of dicts carrying
    ``"token"`` and ``"logprob"`` keys. Alternatives whose log prob is not a
    finite number are dropped so that sorting and formatting cannot fail.
    """
    if isinstance(top_logprobs, dict):
        pairs = list(top_logprobs.items())
    elif isinstance(top_logprobs, (list, tuple)):
        pairs = [
            (alt.get("token"), alt.get("logprob"))
            for alt in top_logprobs
            if isinstance(alt, dict)
        ]
    else:
        pairs = []
    finite_pairs = [pair for pair in pairs if _is_finite_number(pair[1])]
    finite_pairs.sort(key=lambda pair: pair[1], reverse=True)
    return finite_pairs[:limit]


def _table_rows(entries):
    """Build one five-column table row per entry that has ``top_logprobs``."""
    rows = []
    for entry in entries:
        top_logprobs = entry.get("top_logprobs")
        if top_logprobs is None:
            continue
        row = [entry["token"], f"{entry['logprob']:.4f}"]
        row.extend(
            f"{alt_token}: {alt_logprob:.4f}"
            for alt_token, alt_logprob in _top_alternatives(top_logprobs)
        )
        # Pad with empty strings if fewer than 3 alternatives.
        row.extend([""] * (len(_TABLE_COLUMNS) - len(row)))
        rows.append(row)
    return rows


def _plot_html(tokens, logprobs):
    """Render the log-prob line chart and return it as an inline ``<img>`` tag."""
    positions = list(range(len(logprobs)))
    # Matplotlib parses text between "$" pairs as mathtext; escape them so
    # tokens such as "$x$" are drawn literally instead of failing to parse.
    labels = [str(token).replace("$", r"\$") for token in tokens]

    # Work on an explicit figure rather than pyplot's implicit "current
    # figure", and always close it so a failed render does not leak it.
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.plot(positions, logprobs, marker="o", linestyle="-", color="b")
        ax.set_title("Log Probabilities of Generated Tokens")
        ax.set_xlabel("Token Position")
        ax.set_ylabel("Log Probability")
        ax.grid(True)
        ax.set_xticks(positions)
        ax.set_xticklabels(labels, rotation=45, ha="right")
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight")
    finally:
        plt.close(fig)

    img_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    return (
        f'<img src="data:image/png;base64,{img_base64}" '
        'style="max-width: 100%; height: auto;">'
    )


def _colored_text_html(tokens, logprobs):
    """Return the tokens as HTML spans shaded from red (low) to green (high)."""
    import html  # function-scope import keeps this block self-contained

    min_logprob = min(logprobs)
    max_logprob = max(logprobs)
    spans = []
    for token, logprob in zip(tokens, logprobs):
        if max_logprob == min_logprob:
            # A single distinct value has no range to normalize over.
            norm_prob = 0.5
        else:
            norm_prob = (logprob - min_logprob) / (max_logprob - min_logprob)
        r = int(255 * (1 - norm_prob))  # Red for low confidence
        g = int(255 * norm_prob)  # Green for high confidence
        # Tokens are arbitrary model output; escape them before embedding.
        safe_token = html.escape(str(token), quote=False)
        spans.append(
            f'<span style="color: rgb({r}, {g}, 0); font-weight: bold;">'
            f"{safe_token}</span>"
        )
    return f"<p>{' '.join(spans)}</p>"


def visualize_logprobs(json_input):
    """Turn pasted log-prob data into a plot, a table and colored text.

    Args:
        json_input: Text holding either a list of token entries or a dict
            with a ``"content"`` list, written as JSON or as a Python
            literal. Each entry is expected to carry ``"token"`` and
            ``"logprob"`` and may carry ``"top_logprobs"``.

    Returns:
        A 3-tuple ``(img_html, df, colored_text_html)``:
        ``img_html`` is an ``<img>`` tag with the chart (or a placeholder
        message), ``df`` is a pandas DataFrame of tokens with their top three
        alternatives (``None`` when there are no rows), and
        ``colored_text_html`` is the confidence-colored text (or a
        placeholder message). On any failure the tuple is
        ``("Error: <message>", None, None)``.

    Entries that are not dicts, lack a token, or have a missing or non-finite
    ``logprob`` are skipped rather than aborting the whole visualization.
    """
    import html  # function-scope import keeps this block self-contained

    try:
        data = _parse_payload(json_input)
        if isinstance(data, dict) and "content" in data:
            content = data["content"]
        elif isinstance(data, (list, tuple)):
            content = data
        else:
            raise ValueError("Input must be a list or dictionary with 'content' key")
        if not isinstance(content, (list, tuple)):
            raise ValueError("'content' must be a list of token entries")

        entries = _valid_entries(content)
        tokens = [entry["token"] for entry in entries]
        logprobs = [entry["logprob"] for entry in entries]
        table_data = _table_rows(entries)

        if logprobs:
            img_html = _plot_html(tokens, logprobs)
            colored_text_html = _colored_text_html(tokens, logprobs)
        else:
            img_html = "No finite log probabilities to plot."
            colored_text_html = "No finite log probabilities to display."

        df = pd.DataFrame(table_data, columns=_TABLE_COLUMNS) if table_data else None
        return img_html, df, colored_text_html
    except Exception as e:
        # UI boundary: report the failure in the HTML output instead of
        # raising. The message is escaped because it is rendered as HTML.
        return f"Error: {html.escape(str(e), quote=False)}", None, None
# Gradio interface: one text box in, three views out.
_APP_TITLE = "Log Probability Visualizer"
_INTRO_TEXT = (
    "Paste your JSON or Python dictionary log prob data below "
    "to visualize the tokens and their probabilities."
)

with gr.Blocks(title=_APP_TITLE) as app:
    # Page header.
    gr.Markdown(f"# {_APP_TITLE}")
    gr.Markdown(_INTRO_TEXT)

    # Raw log-prob data pasted by the user.
    json_input = gr.Textbox(
        placeholder="Paste your JSON or Python dict here...",
        lines=10,
        label="JSON Input",
    )

    # Output widgets, declared in the order visualize_logprobs returns them.
    plot_output = gr.HTML(label="Log Probability Plot")
    table_output = gr.Dataframe(label="Token Log Probabilities and Top Alternatives")
    text_output = gr.HTML(label="Colored Text (Confidence Visualization)")

    # Clicking the button runs the visualization and fills all three outputs.
    visualize_button = gr.Button("Visualize")
    visualize_button.click(
        outputs=[plot_output, table_output, text_output],
        inputs=json_input,
        fn=visualize_logprobs,
    )

# NOTE(review): the original indentation was lost; the launch call is placed
# after the Blocks context here - confirm against the deployed file.
app.launch()