File size: 3,239 Bytes
ac72c21
 
 
617bd81
 
 
ac72c21
617bd81
 
 
 
 
 
 
 
ac72c21
 
 
 
 
 
 
6c99f7c
 
ac72c21
 
 
 
 
6c99f7c
ac72c21
 
617bd81
ac72c21
 
 
6c99f7c
ac72c21
 
6c99f7c
ac72c21
6c99f7c
 
617bd81
ac72c21
 
6c99f7c
ac72c21
617bd81
ac72c21
 
6c99f7c
 
617bd81
ac72c21
 
617bd81
ac72c21
 
617bd81
ac72c21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175fea5
ac72c21
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
import numpy as np
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

# Map each user-facing model label to its Hugging Face Hub repository id.
# Insertion order matters: the model dropdown lists the keys in this order.
AVAILABLE_MODELS = dict(
    [
        ("distilgpt2", "distilgpt2"),
        ("bloomz-560m", "bigscience/bloomz-560m"),
        ("gpt2-medium", "gpt2-medium"),
        ("opt-350m", "facebook/opt-350m"),
        ("pythia-160m", "EleutherAI/pythia-160m"),
    ]
)

# Hugging Face access token read from the HF_TOKEN environment variable;
# None when the variable is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Module-level cache holding the currently loaded model, its tokenizer, and the
# AVAILABLE_MODELS key they came from. Nothing is loaded until first use.
current_model = current_tokenizer = current_model_name = None

def load_model(model_name):
    """Load *model_name* into the module-level cache unless it is already loaded.

    Only one model/tokenizer pair is cached at a time; switching to another
    model replaces the previous pair. Calling this with the name that is
    already loaded does nothing.

    Args:
        model_name: A key of ``AVAILABLE_MODELS``.

    Raises:
        KeyError: If *model_name* is not a key of ``AVAILABLE_MODELS``.
        Any exception raised by ``from_pretrained`` (e.g. download/auth
        failures) propagates; in that case the cache is left unchanged.
    """
    global current_model, current_tokenizer, current_model_name
    if current_model_name == model_name:
        return

    repo_id = AVAILABLE_MODELS[model_name]
    # Load both pieces into locals before touching the globals. Previously the
    # model was committed first, so a tokenizer failure left the NEW model
    # paired with the OLD tokenizer and name, which a later call for the old
    # name would then use without reloading.
    model = AutoModelForCausalLM.from_pretrained(repo_id, use_auth_token=HF_TOKEN)
    tokenizer = AutoTokenizer.from_pretrained(repo_id, use_auth_token=HF_TOKEN)
    current_model, current_tokenizer, current_model_name = model, tokenizer, model_name

def get_next_token_predictions(text, model_name, top_k=10):
    """Return the *top_k* most likely next tokens after *text* with their probabilities.

    Args:
        text: Prompt to condition on. May be empty: the model is then
            conditioned on the tokenizer's BOS token (or EOS if there is no BOS).
        model_name: Key of ``AVAILABLE_MODELS``; loaded on demand via ``load_model``.
        top_k: Number of candidates to return; clamped to the vocabulary size.

    Returns:
        ``(tokens, probs)``: a list of decoded token strings and a list of
        their softmax probabilities (Python floats), most likely first.

    Raises:
        ValueError: If *text* tokenizes to nothing and the tokenizer has
            neither a BOS nor an EOS token to start from.
    """
    # load_model is a no-op when model_name is already the cached model.
    load_model(model_name)

    inputs = current_tokenizer(text, return_tensors="pt")
    if inputs["input_ids"].shape[-1] == 0:
        # Empty prompt (e.g. the text box was submitted blank). Tokenizers that
        # add no special tokens yield a zero-length sequence, which leaves no
        # last position for `logits[0, -1]`. Condition on a start token instead.
        start_id = current_tokenizer.bos_token_id
        if start_id is None:
            start_id = current_tokenizer.eos_token_id
        if start_id is None:
            raise ValueError("Cannot predict from empty text: tokenizer has no BOS/EOS token")
        input_ids = torch.tensor([[start_id]])
        inputs = {"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids)}

    with torch.no_grad():
        # Distribution over the vocabulary for the position after the last token.
        logits = current_model(**inputs).logits[0, -1, :]
        probs = torch.nn.functional.softmax(logits, dim=-1)

    # torch.topk raises if k exceeds the number of available candidates.
    k = min(top_k, probs.shape[-1])
    top_k_probs, top_k_indices = torch.topk(probs, k=k)
    top_k_tokens = [current_tokenizer.decode([idx]) for idx in top_k_indices.tolist()]

    return top_k_tokens, top_k_probs.tolist()

def predict_next_token(text, model_name, custom_token=""):
    """Append an optional custom token to *text* and predict what comes next.

    This is the Gradio callback: its three parameters line up with the three
    Interface inputs and its three return values with the three outputs.

    Args:
        text: Current contents of the "Text" box (``None`` is treated as "").
        model_name: Key of ``AVAILABLE_MODELS`` selected in the UI.
        custom_token: Optional string appended verbatim to *text* before predicting.

    Returns:
        ``(text, dropdown, predictions)``: the possibly-extended text, a
        ``gr.Dropdown`` whose choices are the predicted tokens, and a
        newline-separated ``<token> : <probability>`` listing.
    """
    if text is None:
        text = ""
    if custom_token:
        text += custom_token

    tokens, probs = get_next_token_predictions(text, model_name)

    # repr() quotes each token and escapes control characters. Tokens such as
    # "\n" are frequent predictions; quoted literally they split the
    # one-per-line listing and show up as blank dropdown choices.
    labels = [repr(token) for token in tokens]
    predictions = "\n".join(f"{label} : {prob:.4f}" for label, prob in zip(labels, probs))

    return text, gr.Dropdown(choices=labels), predictions

# Page content: passed to the Gradio Interface as its `title` and `description`.
# NOTE(review): the description tells users they can "select from the predicted
# next tokens", but the "Predicted tokens" dropdown is an Interface output and
# predict_next_token() takes no selected-token argument, so a selection is not
# fed back into the text. Confirm the intended UX before relying on this text.
title = "Interactive Text Generation with Transformer Models"
description = """
This application allows you to interactively generate text using various transformer models. 
You can either select from the predicted next tokens or write your own tokens to continue the text generation.

Select a model, start typing or choose from the predicted tokens, and see how the model continues your text!
"""

# Example inputs: one value per Interface input component, in order
# (text, model name, custom token). The custom token is given explicitly as ""
# so each row matches the three inputs. When the column was omitted, selecting
# an example could leave a previously typed custom token in place, and that
# token would then be appended to the example text.
examples = [
    ["The quick brown fox", "distilgpt2", ""],
    ["In a galaxy far", "gpt2-medium", ""],
    ["Once upon a time", "opt-350m", ""],
]

# Create the interface.
# Inputs map positionally onto predict_next_token(text, model_name, custom_token);
# outputs map onto its 3-tuple return (text, token dropdown, probability listing).
# NOTE(review): Gradio components are instantiated in the order written here;
# keep the order stable unless a UI change is intended.
app = gr.Interface(
    fn=predict_next_token,
    inputs=[
        gr.Textbox(lines=5, label="Text"),
        # Model selector: choices are the AVAILABLE_MODELS keys, default "distilgpt2".
        gr.Dropdown(choices=list(AVAILABLE_MODELS.keys()), value="distilgpt2", label="Model"),
        gr.Textbox(label="Custom token (optional)")
    ],
    outputs=[
        gr.Textbox(lines=5, label="Generated text"),
        # Populated from the choices returned by predict_next_token; it is an
        # output only and is not wired back into the "Text" input.
        gr.Dropdown(label="Predicted tokens"),
        gr.Textbox(lines=10, label="Token probabilities")
    ],
    # NOTE(review): string-named themes and the `allow_flagging` keyword have
    # changed across Gradio major versions. Verify both against the installed release.
    theme="huggingface",
    title=title,
    description=description,
    examples=examples,
    allow_flagging="manual"
)

# Launch the app.
# This runs at import time (there is no `if __name__ == "__main__":` guard).
app.launch()