# llms-demo / app.py
import os

# Handle Spaces GPU: on a HF Space with ZeroGPU the real `spaces` package is
# present; elsewhere define a no-op stand-in so `@spaces.GPU` decorators below
# keep working.
if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    class spaces:
        @staticmethod
        def GPU(func=None, **_kwargs):
            """No-op replacement for spaces.GPU.

            Supports BOTH decorator forms used in this file:
            bare `@spaces.GPU` and called `@spaces.GPU()` — the original stub
            only handled the bare form and crashed on `@spaces.GPU()`.
            """
            def decorate(f):
                def wrapper(*args, **kwargs):
                    return f(*args, **kwargs)
                return wrapper
            if callable(func):
                # Bare usage: @spaces.GPU
                return decorate(func)
            # Called usage: @spaces.GPU() or @spaces.GPU(duration=...)
            return decorate

@spaces.GPU
def fake_gpu():
    pass
import numpy as np
import torch
import gradio as gr
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login

# Import the real `spaces` package when it is installed (i.e. on HF Spaces).
# The original unconditional import crashed with ImportError outside Spaces,
# defeating the fallback stub defined above — so guard it.
try:
    import spaces
except ImportError:
    pass

# Authenticate with the Hugging Face Hub only when a token is configured;
# calling login(token=None) fails instead of silently skipping auth.
HF_TOKEN = os.getenv('HF_TOKEN')
if HF_TOKEN:
    login(token=HF_TOKEN)
# Available models: UI label -> Hugging Face Hub repo id.
# NOTE(review): the gated/large entries (mistral-7b, llama) require the
# HF_TOKEN above to have access granted — confirm before deploying.
AVAILABLE_MODELS = {
    "bloomz-560m": "bigscience/bloomz-560m",
    "opt-350m": "facebook/opt-350m",
    "pythia-160m": "EleutherAI/pythia-160m",
    "mistral-7b": "mistralai/Mistral-7B-Instruct-v0.3",
    "deepseek-small": "deepseek-ai/DeepSeek-V2-Lite",
    "llama": "meta-llama/Llama-3.2-1B"
}

# Module-level cache of the currently loaded model/tokenizer pair; mutated by
# load_model() so only one model is resident at a time.
current_model = None
current_tokenizer = None
current_model_name = None
# Run on GPU when one is visible, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
def load_model(model_name):
    """Load the selected model and tokenizer into the module-level cache.

    Args:
        model_name: Key into AVAILABLE_MODELS (e.g. "mistral-7b").

    Side effects:
        Rebinds the globals current_model / current_tokenizer /
        current_model_name. No-op when *model_name* is already loaded.
    """
    global current_model, current_tokenizer, current_model_name
    if current_model_name != model_name:
        # Release the previously loaded weights first so two large models
        # never coexist in (GPU) memory while switching.
        if current_model is not None:
            current_model = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        current_model = AutoModelForCausalLM.from_pretrained(
            AVAILABLE_MODELS[model_name]
        ).to(device)
        current_tokenizer = AutoTokenizer.from_pretrained(AVAILABLE_MODELS[model_name])
        current_model_name = model_name

# Load the default model at startup so the first request doesn't pay the
# download/initialisation cost.
load_model("mistral-7b")
@spaces.GPU()
def get_next_token_predictions(text, model_name, top_k=10):
    """Compute the top-k next-token candidates for *text*.

    Args:
        text: Prompt string fed to the model.
        model_name: Key into AVAILABLE_MODELS; the model is swapped lazily
            when it differs from the one currently loaded.
        top_k: Number of candidates to return.

    Returns:
        (tokens, probs): a list of decoded token strings and a parallel list
        of their softmax probabilities (plain Python floats).
    """
    global current_model, current_tokenizer
    # Swap models lazily when the user picked a different one in the UI.
    if current_model_name != model_name:
        load_model(model_name)
    # Gradio sliders may deliver floats; torch.topk requires an integer k.
    top_k = int(top_k)
    inputs = current_tokenizer(text, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = current_model(**inputs)
    # Logits at the final position = distribution over the *next* token.
    logits = outputs.logits[0, -1, :]
    probs = torch.nn.functional.softmax(logits, dim=-1)
    top_k_probs, top_k_indices = torch.topk(probs, k=top_k)
    top_k_tokens = [current_tokenizer.decode([idx.item()]) for idx in top_k_indices]
    return top_k_tokens, top_k_probs.cpu().tolist()
def plot_probabilities(tokens, probs):
    """Render a horizontal bar chart of token probabilities to a PNG file.

    Args:
        tokens: Candidate token strings (most probable first).
        probs: Parallel list of probabilities.

    Returns:
        Filesystem path of the saved chart image.
    """
    import tempfile

    fig, ax = plt.subplots(figsize=(8, 5))
    # Reverse so the most probable token appears at the top of the chart.
    ax.barh(tokens[::-1], probs[::-1], color="skyblue")
    ax.set_xlabel("Probability")
    ax.set_title("Next Token Predictions")
    plt.tight_layout()
    # Save to a unique temp file instead of a fixed filename: the original
    # "token_probabilities.png" was shared by all users of the Space, so
    # concurrent requests overwrote each other's chart.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        plot_path = tmp.name
    fig.savefig(plot_path)
    plt.close(fig)
    return plot_path
def predict_next_token(model_name, text, top_k, custom_token=""):
    """Run one prediction round and produce the updated UI pieces.

    Returns:
        A gr.update carrying the quoted candidate tokens for the dropdown,
        and the path of the rendered probability chart image.
    """
    # Optionally extend the prompt with a manually supplied token.
    prompt = text + custom_token if custom_token else text
    tokens, probs = get_next_token_predictions(prompt, model_name, top_k)
    # Render the probability bar chart for the chart component.
    chart_path = plot_probabilities(tokens, probs)
    quoted_tokens = [f"'{t}'" for t in tokens]
    return gr.update(choices=quoted_tokens), chart_path
def append_selected_token(text, selected_token):
    """Append the token chosen in the dropdown to the running text.

    The dropdown entries are wrapped in single quotes; those are stripped
    before the token is appended with a separating space.
    """
    if not selected_token:
        # Nothing selected yet — leave the text untouched.
        return text
    token = selected_token.strip("'")
    return f"{text} {token}"
# Create the UI
with gr.Blocks() as demo:
    gr.Markdown("# 🔥 Interactive Text Prediction with Transformers")
    gr.Markdown(
        "This application lets you interactively generate text using multiple transformer models. "
        "Choose a model, type your text, and explore token predictions."
    )
    with gr.Row():
        # BUG FIX: the default value used to be "distilgpt2", which is not a
        # key of AVAILABLE_MODELS; use the model preloaded at startup instead.
        model_dropdown = gr.Dropdown(
            choices=list(AVAILABLE_MODELS.keys()),
            value="mistral-7b",
            label="Select Model"
        )
    with gr.Row():
        text_input = gr.Textbox(
            lines=5,
            label="Input Text",
            placeholder="Type your text here...",
            value="The quick brown fox"
        )
    with gr.Row():
        top_k_slider = gr.Slider(
            minimum=1,
            maximum=20,
            value=10,
            step=1,
            label="Top-k Predictions"
        )
    with gr.Row():
        predict_button = gr.Button("Predict")
    with gr.Row():
        # Dropdown is (re)populated by predict_next_token after each run.
        token_dropdown = gr.Dropdown(
            label="Predicted Tokens",
            choices=[]
        )
        append_button = gr.Button("Append Token")
    with gr.Row():
        chart_output = gr.Image(label="Token Probability Chart")

    # Button click events
    predict_button.click(
        predict_next_token,
        inputs=[model_dropdown, text_input, top_k_slider],
        outputs=[token_dropdown, chart_output]
    )
    append_button.click(
        append_selected_token,
        inputs=[text_input, token_dropdown],
        outputs=text_input
    )

demo.queue().launch()