File size: 5,003 Bytes
ac72c21
d7f8dad
43130a6
 
 
 
 
 
 
 
 
 
 
 
93a3f9a
d7f8dad
 
 
 
93cbacc
d7f8dad
 
e5cdcee
 
 
 
 
43130a6
ac72c21
617bd81
 
 
c9abde4
214932f
949913d
 
617bd81
 
310d018
ac72c21
 
 
310d018
6c99f7c
 
310d018
ac72c21
 
310d018
9a72c69
ac72c21
6c99f7c
310d018
949913d
310d018
d7f8dad
ac72c21
310d018
ac72c21
617bd81
310d018
ac72c21
 
6c99f7c
310d018
 
6c99f7c
ac72c21
6c99f7c
 
310d018
ac72c21
 
6c99f7c
310d018
617bd81
93cbacc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310d018
93cbacc
6c99f7c
 
310d018
 
 
93cbacc
 
 
 
ac72c21
310d018
 
 
5f1ad57
 
310d018
 
 
9a72c69
310d018
 
 
 
 
9a72c69
 
 
 
 
 
 
93a3f9a
9a72c69
93a3f9a
 
310d018
93a3f9a
 
9a72c69
93a3f9a
310d018
 
 
 
 
 
 
 
 
93a3f9a
 
310d018
9a72c69
 
310d018
9a72c69
 
310d018
 
9a72c69
93cbacc
310d018
 
93a3f9a
9a72c69
310d018
93cbacc
9a72c69
175fea5
310d018
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import os

# Handle Spaces GPU: on Hugging Face Spaces the real `spaces` package supplies
# the @spaces.GPU decorator; elsewhere we install a no-op stand-in so the rest
# of this file can decorate functions unconditionally.
if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    class spaces:
        @staticmethod
        def GPU(func=None, *decorator_args, **decorator_kwargs):
            """No-op stand-in for spaces.GPU.

            Supports both decorator forms used in this file:
              @spaces.GPU        -> receives the function directly
              @spaces.GPU(...)   -> called first, must return a decorator
            The original stand-in only handled the bare form, so
            @spaces.GPU() (used below) raised TypeError when running
            off-Spaces; this version handles both.
            """
            if func is None or not callable(func):
                # Called as @spaces.GPU() or @spaces.GPU(duration=...):
                # return a decorator that leaves the function unchanged.
                return lambda f: f
            # Bare @spaces.GPU: return the function unchanged.
            return func

@spaces.GPU
def fake_gpu():
    """Dummy GPU-decorated function (ZeroGPU startup requirement)."""
    pass
    
import numpy as np
import torch
import gradio as gr
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login
# NOTE(review): the original re-imported `spaces` here, which clobbered the
# no-op fallback class defined above and raised ImportError whenever the
# `spaces` package is absent (i.e. any run outside HF Spaces). The duplicate
# import has been removed.

# Authenticate with the Hugging Face Hub (required for gated repos such as
# Mistral and Llama). Skip when no token is configured so local runs that
# only use public models can still start.
HF_TOKEN = os.getenv('HF_TOKEN')
if HF_TOKEN:
    login(token=HF_TOKEN)

# Available models
# Maps the short label shown in the UI dropdown to the Hugging Face Hub repo
# id passed to from_pretrained(). The Mistral and Llama repos are gated and
# require a valid HF_TOKEN.
AVAILABLE_MODELS = {
    "bloomz-560m": "bigscience/bloomz-560m",
    "opt-350m": "facebook/opt-350m",
    "pythia-160m": "EleutherAI/pythia-160m",
    "mistral-7b": "mistralai/Mistral-7B-Instruct-v0.3",
    "deepseek-small": "deepseek-ai/DeepSeek-V2-Lite",
    "llama": "meta-llama/Llama-3.2-1B"
}

# Initialize model and tokenizer
# Module-level cache of the currently loaded model/tokenizer so switching
# models in the UI only triggers a reload when the selection actually changes.
current_model = None
current_tokenizer = None
current_model_name = None
device = "cuda" if torch.cuda.is_available() else "cpu"

def load_model(model_name):
    """Load `model_name`'s weights and tokenizer into the module-level cache.

    A no-op when `model_name` is already the active model. `model_name` must
    be a key of AVAILABLE_MODELS.
    """
    global current_model, current_tokenizer, current_model_name
    if current_model_name == model_name:
        return  # already loaded; nothing to do
    repo_id = AVAILABLE_MODELS[model_name]
    current_model = AutoModelForCausalLM.from_pretrained(repo_id).to(device)
    current_tokenizer = AutoTokenizer.from_pretrained(repo_id)
    current_model_name = model_name

# Load the default model at startup so the first prediction request does not
# pay the full download/initialization cost of mistral-7b.
load_model("mistral-7b")

@spaces.GPU()
def get_next_token_predictions(text, model_name, top_k=10):
    """Return the `top_k` most likely next tokens for `text`.

    Returns a (tokens, probs) pair: a list of decoded token strings and a
    parallel list of their softmax probabilities.
    """
    global current_model, current_tokenizer

    # Swap in the requested model if the UI selection changed.
    if current_model_name != model_name:
        load_model(model_name)

    encoded = current_tokenizer(text, return_tensors="pt").to(device)

    # Single forward pass; gradients are unnecessary for inference.
    with torch.no_grad():
        all_logits = current_model(**encoded).logits
        # Distribution over the vocabulary at the final position.
        distribution = torch.nn.functional.softmax(all_logits[0, -1, :], dim=-1)

    probabilities, indices = torch.topk(distribution, k=top_k)
    decoded_tokens = [current_tokenizer.decode([i.item()]) for i in indices]

    return decoded_tokens, probabilities.cpu().tolist()

def plot_probabilities(tokens, probs):
    """Render token probabilities as a horizontal bar chart.

    Saves the chart to "token_probabilities.png" in the working directory and
    returns that path for the Gradio Image component.
    NOTE(review): the fixed filename is shared by all concurrent requests —
    consider a per-request temp file if the queue allows concurrency.
    """
    figure, axis = plt.subplots(figsize=(8, 5))
    # Reverse so the most probable token is drawn at the top of the chart.
    axis.barh(list(reversed(tokens)), list(reversed(probs)), color="skyblue")
    axis.set_xlabel("Probability")
    axis.set_title("Next Token Predictions")
    plt.tight_layout()

    plot_path = "token_probabilities.png"
    plt.savefig(plot_path)
    plt.close(figure)

    return plot_path

def predict_next_token(model_name, text, top_k, custom_token=""):
    """Run a prediction and build the dropdown choices plus the chart.

    `custom_token`, when non-empty, is appended verbatim to the input text
    before predicting.
    """
    prompt = text + custom_token if custom_token else text

    tokens, probs = get_next_token_predictions(prompt, model_name, top_k)

    # Generate bar chart
    chart_path = plot_probabilities(tokens, probs)

    # Quote each token so leading/trailing whitespace stays visible in the UI.
    quoted_tokens = [f"'{t}'" for t in tokens]
    return gr.update(choices=quoted_tokens), chart_path

def append_selected_token(text, selected_token):
    """Append the dropdown-selected token (minus its quote marks) to `text`.

    Returns `text` unchanged when nothing is selected.
    """
    if not selected_token:
        return text
    # Dropdown entries are displayed as 'token'; strip the surrounding quotes.
    return text + " " + selected_token.strip("'")

# Create the UI
with gr.Blocks() as demo:
    gr.Markdown("# 🔥 Interactive Text Prediction with Transformers")
    gr.Markdown(
        "This application lets you interactively generate text using multiple transformer models. "
        "Choose a model, type your text, and explore token predictions."
    )

    with gr.Row():
        # FIX: the original default was "distilgpt2", which is not a key of
        # AVAILABLE_MODELS, so the dropdown started on an invalid value.
        # Use "mistral-7b" to match the model preloaded at startup.
        model_dropdown = gr.Dropdown(
            choices=list(AVAILABLE_MODELS.keys()),
            value="mistral-7b",
            label="Select Model"
        )

    with gr.Row():
        text_input = gr.Textbox(
            lines=5,
            label="Input Text",
            placeholder="Type your text here...",
            value="The quick brown fox"
        )

    with gr.Row():
        top_k_slider = gr.Slider(
            minimum=1,
            maximum=20,
            value=10,
            step=1,
            label="Top-k Predictions"
        )

    with gr.Row():
        predict_button = gr.Button("Predict")

    with gr.Row():
        # Filled with quoted token strings by predict_next_token.
        token_dropdown = gr.Dropdown(
            label="Predicted Tokens",
            choices=[]
        )
        append_button = gr.Button("Append Token")

    with gr.Row():
        chart_output = gr.Image(label="Token Probability Chart")

    # Button click events
    # custom_token is not wired in the UI; predict_next_token's default ""
    # covers it.
    predict_button.click(
        predict_next_token,
        inputs=[model_dropdown, text_input, top_k_slider],
        outputs=[token_dropdown, chart_output]
    )

    append_button.click(
        append_selected_token,
        inputs=[text_input, token_dropdown],
        outputs=text_input
    )

demo.queue().launch()