File size: 4,363 Bytes
ac72c21
43130a6
 
 
 
 
 
 
 
 
 
 
 
 
 
ac72c21
 
617bd81
 
 
ac72c21
617bd81
 
 
 
 
 
 
 
ac72c21
 
 
 
6c99f7c
 
ac72c21
 
9a72c69
 
ac72c21
6c99f7c
ac72c21
 
617bd81
ac72c21
 
 
6c99f7c
ac72c21
 
6c99f7c
ac72c21
6c99f7c
 
617bd81
ac72c21
 
6c99f7c
ac72c21
617bd81
ac72c21
 
6c99f7c
 
617bd81
ac72c21
 
617bd81
ac72c21
 
617bd81
9a72c69
ac72c21
 
9a72c69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175fea5
9a72c69
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import functools
import os

# Use the real `spaces` package only when running on HF Spaces ZeroGPU
# hardware (signalled by the SPACES_ZERO_GPU env var); elsewhere install a
# no-op stand-in so the `@spaces.GPU` decorator below still works locally.
if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    class spaces:
        """Fallback shim mimicking the `spaces` package's GPU decorator."""

        @staticmethod
        def GPU(func):
            # Pass-through decorator. functools.wraps preserves the wrapped
            # function's __name__/__doc__ for introspection and debugging
            # (the original wrapper discarded them).
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)
            return wrapper

@spaces.GPU
def fake_gpu():
    """Dummy GPU-decorated function so ZeroGPU startup detection succeeds."""
    pass

import numpy as np
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

# Available models: maps the short display name shown in the UI dropdown to
# the Hugging Face Hub repo id passed to from_pretrained().
AVAILABLE_MODELS = {
    "distilgpt2": "distilgpt2",
    "bloomz-560m": "bigscience/bloomz-560m",
    "gpt2-medium": "gpt2-medium",
    "opt-350m": "facebook/opt-350m",
    "pythia-160m": "EleutherAI/pythia-160m"
}

# Initialize model and tokenizer globally. These act as a one-slot cache:
# load_model() swaps them out whenever a different model name is requested.
current_model = None
current_tokenizer = None
# Display name (AVAILABLE_MODELS key) of the currently loaded model, or None.
current_model_name = None

def load_model(model_name):
    """Load the requested model and tokenizer into the module-level cache.

    Does nothing when ``model_name`` is already the active model, so repeated
    calls with the same name are cheap.
    """
    global current_model, current_tokenizer, current_model_name
    if current_model_name == model_name:
        return
    repo_id = AVAILABLE_MODELS[model_name]
    current_model = AutoModelForCausalLM.from_pretrained(repo_id)
    current_tokenizer = AutoTokenizer.from_pretrained(repo_id)
    current_model_name = model_name

def get_next_token_predictions(text, model_name, top_k=10):
    """Return the model's ``top_k`` next-token candidates for ``text``.

    Ensures the requested model is loaded, runs a forward pass, and softmaxes
    the logits at the final position.

    Returns:
        (tokens, probabilities): a list of decoded token strings and the
        matching list of floats, both ordered by descending probability.
    """
    # Swap in the requested model if a different one is currently cached.
    if current_model_name != model_name:
        load_model(model_name)

    encoded = current_tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        # Only the logits at the last sequence position matter for
        # next-token prediction.
        last_logits = current_model(**encoded).logits[0, -1, :]
        distribution = torch.nn.functional.softmax(last_logits, dim=-1)

    probabilities, indices = torch.topk(distribution, k=top_k)
    tokens = [current_tokenizer.decode([i.item()]) for i in indices]
    return tokens, probabilities.tolist()

def predict_next_token(text, model_name, custom_token=""):
    # Add custom token if provided
    if custom_token:
        text += custom_token
    
    # Get predictions
    tokens, probs = get_next_token_predictions(text, model_name)
    
    # Format predictions
    predictions = "\n".join([f"'{token}' : {prob:.4f}" for token, prob in zip(tokens, probs)])
    
    return text, gr.update(choices=[f"'{t}'" for t in tokens]), predictions

def _append_selected_token(text, model_name, selected):
    """Handler for the prediction dropdown: append the chosen token.

    Dropdown choices are stored quoted as ``'token'`` (see
    predict_next_token), so the surrounding quotes are stripped before the
    token is appended to the text.
    """
    token = selected[1:-1] if selected else ""
    return predict_next_token(text, model_name, token)

# Create the interface
with gr.Blocks() as demo:
    gr.Markdown("# Interactive Text Generation with Transformer Models")

    gr.Markdown("""
    This application allows you to interactively generate text using various transformer models. 
    You can either select from the predicted next tokens or write your own tokens to continue the text generation.

    Select a model, start typing or choose from the predicted tokens, and see how the model continues your text!
    """)

    with gr.Row():
        text_input = gr.Textbox(
            lines=5,
            label="Text",
            placeholder="Type your text here...",
            value="The quick brown fox"
        )

    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=list(AVAILABLE_MODELS.keys()),
            value="distilgpt2",
            label="Select Model"
        )

    with gr.Row():
        custom_input = gr.Textbox(
            label="Custom token (optional)",
            placeholder="Type a custom token..."
        )

    with gr.Row():
        token_dropdown = gr.Dropdown(
            label="Predicted tokens",
            choices=[]
        )

    with gr.Row():
        predictions_output = gr.Textbox(
            lines=10,
            label="Token probabilities"
        )

    # Recompute predictions whenever the text, model, or custom token changes.
    # NOTE(review): these handlers write back into text_input, which itself
    # fires a change event — Gradio suppresses re-triggering only when the
    # value is unchanged; confirm no feedback loop when custom_input is set.
    text_input.change(
        predict_next_token,
        inputs=[text_input, model_dropdown, custom_input],
        outputs=[text_input, token_dropdown, predictions_output]
    )

    model_dropdown.change(
        predict_next_token,
        inputs=[text_input, model_dropdown, custom_input],
        outputs=[text_input, token_dropdown, predictions_output]
    )

    custom_input.change(
        predict_next_token,
        inputs=[text_input, model_dropdown, custom_input],
        outputs=[text_input, token_dropdown, predictions_output]
    )

    # BUG FIX: the original passed a freshly constructed, unrendered
    # gr.Textbox(value="") as an input here, so the user's dropdown selection
    # was never read and choosing a predicted token did nothing. Route the
    # dropdown's own value through a helper that appends the chosen token.
    token_dropdown.change(
        _append_selected_token,
        inputs=[text_input, model_dropdown, token_dropdown],
        outputs=[text_input, token_dropdown, predictions_output]
    )

demo.queue().launch()