Spaces:
Sleeping
Sleeping
File size: 4,363 Bytes
ac72c21 43130a6 ac72c21 617bd81 ac72c21 617bd81 ac72c21 6c99f7c ac72c21 9a72c69 ac72c21 6c99f7c ac72c21 617bd81 ac72c21 6c99f7c ac72c21 6c99f7c ac72c21 6c99f7c 617bd81 ac72c21 6c99f7c ac72c21 617bd81 ac72c21 6c99f7c 617bd81 ac72c21 617bd81 ac72c21 617bd81 9a72c69 ac72c21 9a72c69 175fea5 9a72c69 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
import os

# On a Hugging Face ZeroGPU Space the real `spaces` package supplies the
# @spaces.GPU decorator; everywhere else, install a no-op stand-in so the
# rest of the file can use @spaces.GPU unconditionally.
if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    class spaces:
        """No-op replacement for the HF `spaces` module outside ZeroGPU."""

        @staticmethod
        def GPU(func):
            # Pass-through decorator: invoke the wrapped function unchanged.
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)
            return wrapper

@spaces.GPU
def fake_gpu():
    # Dummy decorated function; exists only to exercise the decorator at
    # import time (ZeroGPU boilerplate).
    pass
import numpy as np
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
# Available models: short UI display names -> Hugging Face Hub repo ids.
AVAILABLE_MODELS = {
    "distilgpt2": "distilgpt2",
    "bloomz-560m": "bigscience/bloomz-560m",
    "gpt2-medium": "gpt2-medium",
    "opt-350m": "facebook/opt-350m",
    "pythia-160m": "EleutherAI/pythia-160m"
}

# Module-level cache holding the single active model/tokenizer pair.
# Populated lazily by load_model(); None until a model is first loaded.
current_model = None
current_tokenizer = None
current_model_name = None
def load_model(model_name):
    """Load *model_name* into the module-level cache unless already active.

    Args:
        model_name: Key into AVAILABLE_MODELS (a short display name, not a
            raw Hub repo id).

    Side effects:
        Rebinds the current_model / current_tokenizer / current_model_name
        globals when the requested model differs from the cached one.
    """
    global current_model, current_tokenizer, current_model_name
    if current_model_name != model_name:
        # Hoist the repeated dict lookup; model and tokenizer must come
        # from the same repo id.
        repo_id = AVAILABLE_MODELS[model_name]
        current_model = AutoModelForCausalLM.from_pretrained(repo_id)
        current_tokenizer = AutoTokenizer.from_pretrained(repo_id)
        current_model_name = model_name
def get_next_token_predictions(text, model_name, top_k=10):
    """Return the model's top-k next-token candidates for *text*.

    Args:
        text: Prompt string to condition on.
        model_name: Key into AVAILABLE_MODELS.
        top_k: Number of candidates to return.

    Returns:
        (tokens, probs): a list of decoded token strings and a parallel
        list of their softmax probabilities, most likely first.
    """
    # load_model() is a no-op when model_name is already active, so the
    # original's duplicate current_model_name check is unnecessary.
    # (No `global` statement needed either — the globals are only read.)
    load_model(model_name)

    inputs = current_tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = current_model(**inputs)

    # Distribution over the vocabulary at the final sequence position.
    logits = outputs.logits[0, -1, :]
    probs = torch.nn.functional.softmax(logits, dim=-1)
    top_k_probs, top_k_indices = torch.topk(probs, k=top_k)
    top_k_tokens = [current_tokenizer.decode([idx.item()]) for idx in top_k_indices]
    return top_k_tokens, top_k_probs.tolist()
def predict_next_token(text, model_name, custom_token=""):
    """Gradio handler: extend *text* and refresh the prediction widgets.

    Args:
        text: Current contents of the text box.
        model_name: Key into AVAILABLE_MODELS.
        custom_token: Optional text appended verbatim (no separator is
            inserted) before predicting.

    Returns:
        (new_text, dropdown_update, predictions_str) matching the
        [text_input, token_dropdown, predictions_output] outputs wired
        up in the Blocks section below.
    """
    # Append the user-supplied continuation, if any, before predicting.
    if custom_token:
        text += custom_token

    tokens, probs = get_next_token_predictions(text, model_name)

    # One "'token' : prob" line per candidate, most likely first.
    predictions = "\n".join(f"'{token}' : {prob:.4f}" for token, prob in zip(tokens, probs))
    return text, gr.update(choices=[f"'{t}'" for t in tokens]), predictions
def _select_predicted_token(text, model_name, selected):
    """Adapter for token_dropdown: unwrap the quoted choice and append it.

    The dropdown displays tokens wrapped in single quotes (e.g. "' fox'");
    strip exactly one surrounding quote pair so the raw token text is what
    gets appended via predict_next_token.
    """
    token = ""
    if selected and len(selected) >= 2 and selected[0] == "'" and selected[-1] == "'":
        token = selected[1:-1]
    return predict_next_token(text, model_name, token)


# Create the interface
with gr.Blocks() as demo:
    gr.Markdown("# Interactive Text Generation with Transformer Models")
    gr.Markdown("""
    This application allows you to interactively generate text using various transformer models.
    You can either select from the predicted next tokens or write your own tokens to continue the text generation.
    Select a model, start typing or choose from the predicted tokens, and see how the model continues your text!
    """)

    with gr.Row():
        text_input = gr.Textbox(
            lines=5,
            label="Text",
            placeholder="Type your text here...",
            value="The quick brown fox"
        )
    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=list(AVAILABLE_MODELS.keys()),
            value="distilgpt2",
            label="Select Model"
        )
    with gr.Row():
        custom_input = gr.Textbox(
            label="Custom token (optional)",
            placeholder="Type a custom token..."
        )
    with gr.Row():
        token_dropdown = gr.Dropdown(
            label="Predicted tokens",
            choices=[]
        )
    with gr.Row():
        predictions_output = gr.Textbox(
            lines=10,
            label="Token probabilities"
        )

    # Set up event handlers.
    # NOTE(review): predict_next_token writes back to text_input, which is
    # also a .change trigger — confirm the deployed gradio version does not
    # re-fire .change on programmatic updates (risk of a feedback loop when
    # custom_input is non-empty).
    text_input.change(
        predict_next_token,
        inputs=[text_input, model_dropdown, custom_input],
        outputs=[text_input, token_dropdown, predictions_output]
    )
    model_dropdown.change(
        predict_next_token,
        inputs=[text_input, model_dropdown, custom_input],
        outputs=[text_input, token_dropdown, predictions_output]
    )
    custom_input.change(
        predict_next_token,
        inputs=[text_input, model_dropdown, custom_input],
        outputs=[text_input, token_dropdown, predictions_output]
    )
    # Bug fix: the original passed a brand-new gr.Textbox(value="") as the
    # third input, so the user's selected token never reached the handler
    # and picking a predicted token had no effect. Feed the dropdown's own
    # value through the unwrapping adapter instead.
    token_dropdown.change(
        _select_predicted_token,
        inputs=[text_input, model_dropdown, token_dropdown],
        outputs=[text_input, token_dropdown, predictions_output]
    )

demo.queue().launch()