Spaces:
Running on Zero
Running on Zero
File size: 8,347 Bytes
8788c39 c873cda 8788c39 7d0101d 8788c39 6ab2bc7 8788c39 6ab2bc7 8788c39 6ab2bc7 c090ed3 8788c39 afd2381 a17ff57 6ab2bc7 c928ab8 6ab2bc7 8788c39 c41c60e 6ab2bc7 617e44f 8788c39 afd2381 c41c60e 8788c39 c41c60e 6ab2bc7 afd2381 6ab2bc7 72ac605 a17ff57 6ab2bc7 a17ff57 6ab2bc7 afd2381 6ab2bc7 55bcef6 6ab2bc7 360bdcd 6ab2bc7 55bcef6 6ab2bc7 f91a0ec c090ed3 6ab2bc7 c090ed3 6ab2bc7 360bdcd 6ab2bc7 8788c39 6ab2bc7 c928ab8 6ab2bc7 c928ab8 55bcef6 c928ab8 6ab2bc7 8788c39 afd2381 3a2c4d3 55bcef6 afd2381 6ab2bc7 afd2381 6ab2bc7 afd2381 6ab2bc7 8788c39 afd2381 8788c39 afd2381 8788c39 6ab2bc7 7d9e2f5 6ab2bc7 55bcef6 6ab2bc7 55bcef6 6ab2bc7 7d9e2f5 8788c39 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 | import gradio as gr
import spaces
import math
from inflections_funcs import start_model, make_beams, get_beam_tokens, calculate_score_vectors
# Global model initialization.
# Runs once when the module is imported; the resulting `model` and `processor`
# are module-level names that predict() reads on every request.
print("Loading model...")
model, processor = start_model()
print("Model loaded.")
def generate_beam_html(index, all_beams_data, dark_mode):
    """
    Build the main and detailed HTML views for one generated response.

    Args:
        index: Zero-based position of the response inside ``all_beams_data``.
        all_beams_data: Dict holding parallel per-response lists under
            ``"beam_tokens"``, ``"score_vectors"`` and ``"sequences_scores"``,
            plus an optional ``"is_sampling"`` flag (treated as False when
            absent).
        dark_mode: When true, pick colours intended for a dark background.

    Returns:
        A ``(main_html, detail_html)`` tuple of HTML strings: the highlighted
        response text and a token/score table.

    Raises:
        IndexError: If ``index`` is outside the stored lists.
    """
    # Function-scope import so this block does not depend on edits elsewhere.
    import html

    low_score_threshold = 0.6

    beam_tokens = all_beams_data["beam_tokens"][index]
    beam_scores = all_beams_data["score_vectors"][index]

    # With the sampling flag set the stored score is displayed as-is;
    # otherwise it is treated as a log value and exponentiated for display.
    raw_score = all_beams_data["sequences_scores"][index]
    if all_beams_data.get("is_sampling", False):
        beam_overall_score = raw_score
    else:
        beam_overall_score = math.exp(raw_score)

    # SentencePiece space (U+2581), literal underscore (U+005F) and
    # non-breaking space (U+00A0) are all shown as a plain space. Explicit
    # escapes keep both views in agreement and survive copy/paste of the file.
    space_like = str.maketrans({"\u2581": " ", "_": " ", "\u00a0": " "})

    normal_color = "white" if dark_mode else "black"
    low_score_bg = "#441111" if dark_mode else "#ffe6e6"

    if dark_mode:
        th_style = "border: 1px solid #444; padding: 8px; text-align: left; background-color: #333;"
        td_style = "border: 1px solid #444; padding: 8px; text-align: left;"
    else:
        th_style = "border: 1px solid #ccc; padding: 8px; text-align: left; background-color: #f2f2f2;"
        td_style = "border: 1px solid #ccc; padding: 8px; text-align: left;"

    # One pass produces the fragments for both views so they cannot drift.
    token_spans = []
    table_rows = []
    for token, score in zip(beam_tokens, beam_scores):
        # Escape after normalising: token text comes from the model and may
        # contain characters that are significant in HTML.
        display_token = html.escape(token.translate(space_like))
        if score < low_score_threshold:
            color_css = f"color: red; background-color: {low_score_bg};"
        else:
            color_css = f"color: {normal_color};"
        token_spans.append(
            f'<span style="{color_css}" title="Score: {score:.4f}">{display_token}</span>'
        )
        table_rows.append(
            f'<tr><td style="{td_style} {color_css}">{display_token}</td>'
            f'<td style="{td_style}">{score:.4f}</td></tr>'
        )

    # Main view: titled box with the highlighted response text.
    main_container_style = "border: 2px solid lightblue; padding: 15px; border-radius: 8px; font-family: sans-serif;"
    title_style = "font-weight: bold; margin-bottom: 10px; font-size: 1.1em;"
    main_html = "".join([
        f'<div style="{main_container_style}">',
        f'<div style="{title_style}">Response {index + 1} | Score: {beam_overall_score:.4f}</div>',
        '<div style="font-family: monospace; white-space: pre-wrap; line-height: 1.5; overflow-wrap: break-word; word-wrap: break-word;">',
        *token_spans,
        '</div></div>',
    ])

    # Detail view: one table row per token with its score.
    detail_table_style = "border-collapse: collapse; width: 100%; max-width: 500px; font-family: monospace;"
    detail_html = "".join([
        f'<table style="{detail_table_style}"><thead><tr><th style="{th_style}">Token</th><th style="{th_style}">Score</th></tr></thead><tbody>',
        *table_rows,
        '</tbody></table>',
    ])

    return main_html, detail_html
@spaces.GPU
def predict(prompt, dark_mode, temperature):
    """
    Generate responses for a prompt and render the first one.

    Returns a 6-tuple in the order the UI wiring expects: main HTML, detail
    HTML, two visibility updates for the control buttons, the data for all
    responses (or None when nothing was generated) and the index of the
    response being shown (always 0 here).
    """
    generated_dicts, _transcription = make_beams(
        model, processor, prompt, temperature=temperature
    )

    # Per-response token strings and per-token scores.
    beam_tokens = get_beam_tokens(generated_dicts, processor)
    score_vectors = calculate_score_vectors(model, generated_dicts)

    if not (beam_tokens and score_vectors):
        # Nothing to display: show a placeholder and keep the controls hidden.
        placeholder = "<span style='color: grey'>No tokens generated.</span>"
        return placeholder, "", gr.update(visible=False), gr.update(visible=False), None, 0

    # Values kept in UI state are converted to plain lists where possible
    # to avoid ZeroGPU serialization errors.
    beam_level_scores = getattr(generated_dicts, 'sequences_scores', None)
    if beam_level_scores is None:
        # No sequence-level scores available: approximate each response's
        # overall score by the mean of its token scores.
        seq_scores = [
            sum(scores) / len(scores) if scores else 0.0
            for scores in score_vectors
        ]
        is_sampling = True
    else:
        if hasattr(beam_level_scores, 'tolist'):
            seq_scores = beam_level_scores.tolist()
        else:
            seq_scores = beam_level_scores
        is_sampling = False

    all_beams_data = dict(
        beam_tokens=beam_tokens,
        score_vectors=score_vectors,
        sequences_scores=seq_scores,
        is_sampling=is_sampling,
    )

    # Start on the first response (index 0) and reveal the controls.
    first_main, first_detail = generate_beam_html(0, all_beams_data, dark_mode)
    show = gr.update(visible=True)
    also_show = gr.update(visible=True)
    return first_main, first_detail, show, also_show, all_beams_data, 0
def switch_beam(current_index, all_beams_data, dark_mode):
    """
    Advance to the next stored response, wrapping around after the last one.

    Args:
        current_index: Index of the response currently displayed.
        all_beams_data: Dict of per-response lists as stored by ``predict``,
            or None when nothing has been generated yet.
        dark_mode: Passed through to ``generate_beam_html``.

    Returns:
        ``(main_html, detail_html, new_index)``. When there is no data, or no
        response is stored, returns ``(None, None, 0)``.
    """
    if all_beams_data is None:
        return None, None, 0

    # Cycle over the responses actually stored rather than assuming three:
    # a fixed modulus indexes past the end when fewer were produced and
    # never reaches any beyond the third.
    beam_count = min(
        len(all_beams_data.get(key) or ())
        for key in ("beam_tokens", "score_vectors", "sequences_scores")
    )
    if beam_count == 0:
        return None, None, 0

    new_index = (current_index + 1) % beam_count
    main_html, detail_html = generate_beam_html(new_index, all_beams_data, dark_mode)
    return main_html, detail_html, new_index
# Gradio Interface
# Builds the page layout and wires the event handlers onto `demo`.
# NOTE(review): indentation was lost in this copy of the file, so the exact
# nesting of the layout containers below should be confirmed against the original.
with gr.Blocks() as demo:
# Header: image beside the title and usage notes.
with gr.Row():
with gr.Column(scale=1):
gr.Image("Stochastic_parrot.JPG")
with gr.Column(scale=3):
gr.Markdown("# InflectionLM: Output and Token Visualization")
gr.Markdown('''Input a prompt. The model will:
- Generate multiple independent responses using Top-P and Top-K sampling to ensure diversity.
- Display the first response, highlighting any word or parts of words (tokens) with a probability score < 0.6 in red.
- These unconfident words/tokens can be considered inflection points, or places where the output could change easily.
- Provide a toggle button to view detailed word/token probabilities for the response.
- Switch between the generated responses to see different possible outputs and their associated token probabilities.
''')
# Prompt entry and generation controls.
with gr.Column():
prompt_input = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...", lines=1, show_label=True)
with gr.Row():
dark_mode_toggle = gr.Checkbox(label="Dark Mode", value=True)
temp_radio = gr.Radio(
label="Temperature",
choices=[0.0, 0.5, 1.0, 1.5],
value=1.0
)
submit_btn = gr.Button("Generate")
# Highlighted response text produced by generate_beam_html.
output_html = gr.HTML(label="Highlighted Beam")
# Both buttons start hidden; they are listed among predict's outputs,
# which returns visibility updates for them.
with gr.Row():
toggle_btn = gr.Button("Show Token Details", visible=False)
next_beam_btn = gr.Button("Next Beam", visible=False)
# Token/score table; hidden until toggled on.
detail_html = gr.HTML(label="Detailed Token Probabilities", visible=False)
# State components
# visibility_state: whether the detail table is currently shown.
# all_beams_state: data for every generated response (None before the first run).
# current_beam_state: index of the response being displayed.
visibility_state = gr.State(value=False)
all_beams_state = gr.State(value=None)
current_beam_state = gr.State(value=0)
# Flip the stored visibility flag and apply the new value to the detail table.
def toggle_view(visible):
return not visible, gr.update(visible=not visible)
toggle_btn.click(fn=toggle_view, inputs=visibility_state, outputs=[visibility_state, detail_html])
# Switch beam logic
# Re-renders both views for the next response and stores the new index.
next_beam_btn.click(
fn=switch_beam,
inputs=[current_beam_state, all_beams_state, dark_mode_toggle],
outputs=[output_html, detail_html, current_beam_state]
)
# Trigger generation on both button click and Enter key in textbox
# The two registrations use the same handler, inputs and outputs.
submit_btn.click(
fn=predict,
inputs=[prompt_input, dark_mode_toggle, temp_radio],
outputs=[output_html, detail_html, toggle_btn, next_beam_btn, all_beams_state, current_beam_state]
)
prompt_input.submit(
fn=predict,
inputs=[prompt_input, dark_mode_toggle, temp_radio],
outputs=[output_html, detail_html, toggle_btn, next_beam_btn, all_beams_state, current_beam_state]
)
# Launch the Gradio app only when this file is run as a script.
if __name__ == "__main__":
demo.launch()
|