# llms-demo/app.py

import os

# On Hugging Face ZeroGPU Spaces the SPACES_ZERO_GPU environment variable is
# set and the real `spaces` package is available; elsewhere, fall back to a
# no-op stand-in so the @spaces.GPU decorator still works.
if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    class spaces:
        @staticmethod
        def GPU(func):
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)
            return wrapper

@spaces.GPU
def fake_gpu():
    pass
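
# ZeroGPU Spaces expect at least one @spaces.GPU-decorated function; the
# placeholder above appears to exist only to satisfy that requirement, since
# the actual inference functions below run undecorated (i.e. on CPU).
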
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

# Available models (display name -> Hugging Face Hub repo id)
AVAILABLE_MODELS = {
    "distilgpt2": "distilgpt2",
    "bloomz-560m": "bigscience/bloomz-560m",
    "gpt2-medium": "gpt2-medium",
    "opt-350m": "facebook/opt-350m",
    "pythia-160m": "EleutherAI/pythia-160m",
}
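
# Any other causal-LM checkpoint from the Hub could be added here in the same
# way, e.g. "gpt2-large": "gpt2-large" (an illustrative entry, not part of
# the original list).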

# The model and tokenizer are cached in module-level globals so that weights
# are only reloaded when the selected model actually changes.
current_model = None
current_tokenizer = None
current_model_name = None

def load_model(model_name):
    global current_model, current_tokenizer, current_model_name
    if current_model_name != model_name:
        current_model = AutoModelForCausalLM.from_pretrained(AVAILABLE_MODELS[model_name])
        current_tokenizer = AutoTokenizer.from_pretrained(AVAILABLE_MODELS[model_name])
        current_model_name = model_name
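
# A sketch (an assumption, not part of the original app): to use a GPU when
# one is available, the model could be moved to the active device after
# loading, e.g.
#
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   current_model.to(device)
#
# in which case the tokenized inputs below would also need .to(device).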

def get_next_token_predictions(text, model_name, top_k=10):
    # Load the model if the selection changed
    if current_model_name != model_name:
        load_model(model_name)
    # Forward pass: the logits at the last position score every vocabulary
    # entry as a candidate next token; softmax turns them into probabilities
    inputs = current_tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = current_model(**inputs)
    logits = outputs.logits[0, -1, :]
    probs = torch.nn.functional.softmax(logits, dim=-1)
    top_k_probs, top_k_indices = torch.topk(probs, k=top_k)
    top_k_tokens = [current_tokenizer.decode([idx.item()]) for idx in top_k_indices]
    return top_k_tokens, top_k_probs.tolist()
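
# Illustrative usage (actual tokens and probabilities vary by model):
#   tokens, probs = get_next_token_predictions("The quick brown fox", "distilgpt2")
#   tokens[0] might be something like " jumps", with probs[0] its probability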

def predict_next_token(text, model_name, custom_token=""):
    # Append the custom token, if provided, before predicting
    if custom_token:
        text += custom_token
    # Guard against empty input, which would produce an empty token sequence
    if not text:
        return text, gr.update(choices=[]), ""
    tokens, probs = get_next_token_predictions(text, model_name)
    # One "'token' : probability" line per prediction
    predictions = "\n".join(f"'{token}' : {prob:.4f}" for token, prob in zip(tokens, probs))
    return text, gr.update(choices=[f"'{t}'" for t in tokens]), predictions
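
# Selecting a predicted token appends it to the text, as the app description
# promises. The dropdown choices wrap each token in single quotes for display,
# so exactly one quote is stripped from each end before appending.
def select_token(text, model_name, selected_token):
    if selected_token:
        text += selected_token[1:-1]
    return predict_next_token(text, model_name)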

# Create the interface
with gr.Blocks() as demo:
    gr.Markdown("# Interactive Text Generation with Transformer Models")
    gr.Markdown("""
    This app lets you generate text interactively with a range of small transformer models.
    At each step you can either pick one of the model's predicted next tokens or type your own token to continue the text.
    Select a model, start typing or choose from the predictions, and see how the model continues your text!
    """)

    with gr.Row():
        text_input = gr.Textbox(
            lines=5,
            label="Text",
            placeholder="Type your text here...",
            value="The quick brown fox"
        )

    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=list(AVAILABLE_MODELS.keys()),
            value="distilgpt2",
            label="Select Model"
        )

    with gr.Row():
        custom_input = gr.Textbox(
            label="Custom token (optional)",
            placeholder="Type a custom token..."
        )

    with gr.Row():
        token_dropdown = gr.Dropdown(
            label="Predicted tokens",
            choices=[]
        )

    with gr.Row():
        predictions_output = gr.Textbox(
            lines=10,
            label="Token probabilities"
        )

    # Set up event handlers. `.input` fires only on direct user edits, so the
    # programmatic writes back to text_input below cannot re-trigger the
    # handler in a feedback loop (which `.change` would allow).
    text_input.input(
        predict_next_token,
        inputs=[text_input, model_dropdown],
        outputs=[text_input, token_dropdown, predictions_output]
    )

    model_dropdown.change(
        predict_next_token,
        inputs=[text_input, model_dropdown],
        outputs=[text_input, token_dropdown, predictions_output]
    )

    # Append the custom token on Enter rather than on every keystroke, so
    # partially typed tokens are not appended repeatedly.
    custom_input.submit(
        predict_next_token,
        inputs=[text_input, model_dropdown, custom_input],
        outputs=[text_input, token_dropdown, predictions_output]
    )

    # Selecting a predicted token appends it to the text and re-predicts
    token_dropdown.input(
        select_token,
        inputs=[text_input, model_dropdown, token_dropdown],
        outputs=[text_input, token_dropdown, predictions_output]
    )
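
# queue() enables Gradio's request queue, which by default processes requests
# to each handler one at a time; that matters here because the model,
# tokenizer, and model name are shared module-level globals.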
demo.queue().launch()