File size: 8,347 Bytes
8788c39
 
c873cda
8788c39
 
 
7d0101d
 
 
8788c39
6ab2bc7
8788c39
6ab2bc7
8788c39
6ab2bc7
 
c090ed3
 
 
 
 
 
 
8788c39
afd2381
a17ff57
 
 
6ab2bc7
c928ab8
6ab2bc7
8788c39
c41c60e
 
6ab2bc7
617e44f
 
 
8788c39
 
afd2381
c41c60e
8788c39
c41c60e
 
6ab2bc7
 
afd2381
 
 
 
 
 
6ab2bc7
 
72ac605
a17ff57
 
 
 
 
 
 
6ab2bc7
 
a17ff57
6ab2bc7
afd2381
6ab2bc7
55bcef6
6ab2bc7
360bdcd
6ab2bc7
 
55bcef6
6ab2bc7
 
 
 
 
 
 
 
 
f91a0ec
c090ed3
 
 
 
 
 
 
 
 
 
6ab2bc7
 
 
c090ed3
 
6ab2bc7
 
 
 
 
 
 
 
 
 
 
 
 
 
360bdcd
6ab2bc7
 
8788c39
 
 
6ab2bc7
 
 
 
c928ab8
6ab2bc7
c928ab8
 
55bcef6
c928ab8
 
6ab2bc7
8788c39
afd2381
3a2c4d3
55bcef6
 
 
 
 
 
 
afd2381
 
 
 
6ab2bc7
 
 
 
afd2381
 
6ab2bc7
afd2381
6ab2bc7
 
8788c39
afd2381
 
8788c39
afd2381
8788c39
6ab2bc7
 
 
 
 
 
 
7d9e2f5
6ab2bc7
 
55bcef6
6ab2bc7
 
 
 
55bcef6
6ab2bc7
 
7d9e2f5
8788c39
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import html
import math

import gradio as gr
import spaces

from inflections_funcs import start_model, make_beams, get_beam_tokens, calculate_score_vectors

# Global model initialization.
# start_model() runs once, at import time; the resulting `model` and
# `processor` are module-level globals that predict() passes to the
# generation and scoring helpers.
print("Loading model...")
model, processor = start_model()
print("Model loaded.")

# Tokens whose score falls below this value are highlighted as low-confidence.
_LOW_CONFIDENCE_THRESHOLD = 0.6

# SentencePiece space marker (U+2581), literal underscore (U+005F) and
# non-breaking space (U+00A0) are all displayed as a plain space. Escapes are
# used so the characters cannot be silently normalised by an editor or a paste.
_TOKEN_DISPLAY_TABLE = str.maketrans({"\u2581": " ", "_": " ", "\u00a0": " "})


def _display_token(token):
    """Return *token* with tokenizer space markers normalised and HTML-escaped."""
    return html.escape(token.translate(_TOKEN_DISPLAY_TABLE))


def generate_beam_html(index, all_beams_data, dark_mode,
                       low_confidence_threshold=_LOW_CONFIDENCE_THRESHOLD):
    """
    Build the highlighted-text HTML and the per-token table HTML for one beam.

    Args:
        index: Zero-based position of the beam inside ``all_beams_data``.
        all_beams_data: Dict with the keys ``"beam_tokens"``,
            ``"score_vectors"`` and ``"sequences_scores"`` (each indexable by
            beam) and an optional ``"is_sampling"`` flag.
        dark_mode: Selects the dark or light colour palette.
        low_confidence_threshold: Tokens scoring strictly below this value are
            rendered in red on a tinted background.

    Returns:
        A ``(main_html, detail_html)`` tuple of HTML strings.

    Raises:
        IndexError: If ``index`` is out of range for the stored beams.
        KeyError: If a required key is missing from ``all_beams_data``.
    """
    beam_tokens = all_beams_data["beam_tokens"][index]
    beam_scores = all_beams_data["score_vectors"][index]

    # When "is_sampling" is set the stored score is shown as-is; otherwise it
    # is treated as a log-probability and exponentiated for display.
    raw_score = all_beams_data["sequences_scores"][index]
    if all_beams_data.get("is_sampling", False):
        beam_overall_score = raw_score
    else:
        beam_overall_score = math.exp(raw_score)

    normal_color = "white" if dark_mode else "black"
    low_bg_color = "#441111" if dark_mode else "#ffe6e6"
    normal_style = f"color: {normal_color};"
    low_style = f"color: red; background-color: {low_bg_color};"

    # Compute display text and highlight style once so the main view and the
    # detail table can never disagree about a token.
    rendered = [
        (
            _display_token(token),
            score,
            low_style if score < low_confidence_threshold else normal_style,
        )
        for token, score in zip(beam_tokens, beam_scores)
    ]

    # Construct Main HTML output
    main_container_style = "border: 2px solid lightblue; padding: 15px; border-radius: 8px; font-family: sans-serif;"
    title_style = "font-weight: bold; margin-bottom: 10px; font-size: 1.1em;"

    main_parts = [
        f'<div style="{main_container_style}">',
        f'<div style="{title_style}">Response {index + 1} | Score: {beam_overall_score:.4f}</div>',
        '<div style="font-family: monospace; white-space: pre-wrap; line-height: 1.5; overflow-wrap: break-word; word-wrap: break-word;">',
    ]
    main_parts.extend(
        f'<span style="{style}" title="Score: {score:.4f}">{text}</span>'
        for text, score, style in rendered
    )
    main_parts.append('</div></div>')
    main_html = "".join(main_parts)

    # Construct Detailed HTML output
    detail_table_style = "border-collapse: collapse; width: 100%; max-width: 500px; font-family: monospace;"
    if dark_mode:
        th_style = "border: 1px solid #444; padding: 8px; text-align: left; background-color: #333;"
        td_style = "border: 1px solid #444; padding: 8px; text-align: left;"
    else:
        th_style = "border: 1px solid #ccc; padding: 8px; text-align: left; background-color: #f2f2f2;"
        td_style = "border: 1px solid #ccc; padding: 8px; text-align: left;"

    detail_parts = [
        f'<table style="{detail_table_style}"><thead><tr><th style="{th_style}">Token</th><th style="{th_style}">Score</th></tr></thead><tbody>'
    ]
    detail_parts.extend(
        f'<tr><td style="{td_style} {style}">{text}</td><td style="{td_style}">{score:.4f}</td></tr>'
        for text, score, style in rendered
    )
    detail_parts.append('</tbody></table>')
    detail_html = "".join(detail_parts)

    return main_html, detail_html

@spaces.GPU
def predict(prompt, dark_mode, temperature):
    """
    Run generation for a prompt and prepare the UI outputs for the first beam.

    Returns a 6-tuple: main HTML, detail HTML, two visibility updates for the
    control buttons, the per-beam data dict kept in state, and the current
    beam index (always 0 after a fresh generation).
    """
    # The second value returned by make_beams is not used here.
    generation, _transcription = make_beams(model, processor, prompt, temperature=temperature)

    # Per-beam tokens and per-token scores
    tokens_per_beam = get_beam_tokens(generation, processor)
    scores_per_beam = calculate_score_vectors(model, generation)

    # Nothing to show: keep the beam controls hidden and clear the state.
    if not (tokens_per_beam and scores_per_beam):
        hidden = gr.update(visible=False)
        hidden_too = gr.update(visible=False)
        return "<span style='color: grey'>No tokens generated.</span>", "", hidden, hidden_too, None, 0

    # Tensors are converted to plain lists to avoid ZeroGPU serialization errors.
    reported_scores = getattr(generation, 'sequences_scores', None)
    is_sampling = reported_scores is None
    if is_sampling:
        # No sequence scores available: approximate each beam's overall score
        # by the mean of its token scores (0.0 for an empty beam).
        seq_scores = []
        for token_scores in scores_per_beam:
            seq_scores.append(sum(token_scores) / len(token_scores) if token_scores else 0.0)
    elif hasattr(reported_scores, 'tolist'):
        seq_scores = reported_scores.tolist()
    else:
        seq_scores = reported_scores

    all_beams_data = {
        "beam_tokens": tokens_per_beam,
        "score_vectors": scores_per_beam,
        "sequences_scores": seq_scores,
        "is_sampling": is_sampling
    }

    # Render the first beam (index 0) and reveal the controls.
    first_main, first_detail = generate_beam_html(0, all_beams_data, dark_mode)
    show_toggle = gr.update(visible=True)
    show_next = gr.update(visible=True)
    return first_main, first_detail, show_toggle, show_next, all_beams_data, 0

def switch_beam(current_index, all_beams_data, dark_mode):
    """
    Advance to the next stored beam (wrapping around) and render it.

    Args:
        current_index: Index of the beam currently displayed.
        all_beams_data: Per-beam data dict produced by ``predict``, or ``None``
            when nothing has been generated yet.
        dark_mode: Selects the dark or light colour palette.

    Returns:
        ``(main_html, detail_html, new_index)``; ``(None, None, 0)`` when there
        is no beam to show.
    """
    if not all_beams_data:
        return None, None, 0

    # Wrap on the number of beams actually stored rather than a fixed count,
    # so fewer (or more) than three beams cannot index out of range. The
    # shortest of the three per-beam lists bounds what can be rendered.
    beam_count = min(
        len(all_beams_data.get(key) or ())
        for key in ("beam_tokens", "score_vectors", "sequences_scores")
    )
    if beam_count == 0:
        return None, None, 0

    new_index = (current_index + 1) % beam_count
    main_html, detail_html = generate_beam_html(new_index, all_beams_data, dark_mode)
    return main_html, detail_html, new_index

# Gradio Interface
with gr.Blocks() as demo:
    # Header row: illustration on the left, title and usage notes on the right.
    with gr.Row():
        with gr.Column(scale=1):
            gr.Image("Stochastic_parrot.JPG")
        with gr.Column(scale=3):
            gr.Markdown("# InflectionLM: Output and Token Visualization")
            gr.Markdown('''Input a prompt. The model will:
                            - Generate multiple independent responses using Top-P and Top-K sampling to ensure diversity.
                            - Display the first response, highlighting any word or parts of words (tokens) with a probability score < 0.6 in red.
                            - These unconfident words/tokens can be considered inflection points, or places where the output could change easily.
                            - Provide a toggle button to view detailed word/token probabilities for the response.
                            - Switch between the generated responses to see different possible outputs and their associated token probabilities.
                        ''')

    # Main column: prompt, generation settings, and the two output panes.
    with gr.Column():
        prompt_input = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...", lines=1, show_label=True)
        with gr.Row():
            dark_mode_toggle = gr.Checkbox(label="Dark Mode", value=True)
            temp_radio = gr.Radio(label="Temperature", choices=[0.0, 0.5, 1.0, 1.5], value=1.0)
        submit_btn = gr.Button("Generate")

        output_html = gr.HTML(label="Highlighted Beam")

        # Both controls start hidden; predict() reveals them once output exists.
        with gr.Row():
            toggle_btn = gr.Button("Show Token Details", visible=False)
            next_beam_btn = gr.Button("Next Beam", visible=False)

        detail_html = gr.HTML(label="Detailed Token Probabilities", visible=False)

    # Session state: detail-table visibility, stored beams, displayed beam index.
    visibility_state = gr.State(value=False)
    all_beams_state = gr.State(value=None)
    current_beam_state = gr.State(value=0)

    def toggle_view(visible):
        """Flip the stored visibility flag and show/hide the detail table to match."""
        now_visible = not visible
        return now_visible, gr.update(visible=now_visible)

    toggle_btn.click(fn=toggle_view, inputs=visibility_state, outputs=[visibility_state, detail_html])

    # Cycle to the next stored beam.
    next_beam_btn.click(
        fn=switch_beam,
        inputs=[current_beam_state, all_beams_state, dark_mode_toggle],
        outputs=[output_html, detail_html, current_beam_state],
    )

    # Generation is triggered identically by the button and by Enter in the
    # textbox, so both listeners share one set of inputs and outputs.
    predict_inputs = [prompt_input, dark_mode_toggle, temp_radio]
    predict_outputs = [output_html, detail_html, toggle_btn, next_beam_btn, all_beams_state, current_beam_state]
    for register_listener in (submit_btn.click, prompt_input.submit):
        register_listener(fn=predict, inputs=predict_inputs, outputs=predict_outputs)

if __name__ == "__main__":
    demo.launch()