# Hugging Face Space: interactive next-token prediction demo (Gradio app).
import os

# Handle Spaces GPU: on ZeroGPU hardware the real `spaces` package is
# importable; everywhere else install a no-op stand-in so `@spaces.GPU`
# decorated functions still run.
if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    class spaces:
        @staticmethod
        def GPU(func=None, **kwargs):
            # Support both bare `@spaces.GPU` and called `@spaces.GPU(...)`
            # usage — this file uses both forms. The previous shim only
            # handled the bare form, so `@spaces.GPU()` raised a TypeError
            # when running outside Spaces.
            def decorate(f):
                def wrapper(*args, **kw):
                    return f(*args, **kw)
                return wrapper

            if callable(func):
                return decorate(func)  # bare @spaces.GPU
            return decorate            # called @spaces.GPU(...)

    @spaces.GPU
    def fake_gpu():
        pass
import numpy as np
import torch
import gradio as gr
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer

try:
    # Real package on Spaces hardware. Absent locally, where the no-op
    # fallback class named `spaces` (defined earlier in this file) is used
    # instead — an unconditional import here crashed local runs.
    import spaces
except ImportError:
    pass

from huggingface_hub import login

# Authenticate with the Hub only when a token is configured:
# `login(token=None)` fails (or prompts) at import time on machines
# without HF_TOKEN in the environment.
HF_TOKEN = os.getenv('HF_TOKEN')
if HF_TOKEN:
    login(token=HF_TOKEN)
# Models selectable from the UI: display name -> Hugging Face Hub repo id.
AVAILABLE_MODELS = dict([
    ("bloomz-560m", "bigscience/bloomz-560m"),
    ("opt-350m", "facebook/opt-350m"),
    ("pythia-160m", "EleutherAI/pythia-160m"),
    ("mistral-7b", "mistralai/Mistral-7B-Instruct-v0.3"),
    ("deepseek-small", "deepseek-ai/DeepSeek-V2-Lite"),
    ("llama", "meta-llama/Llama-3.2-1B"),
])

# Lazily populated module-level cache: exactly one model/tokenizer pair is
# resident at a time; `load_model` swaps it when the user picks another.
current_model = None
current_tokenizer = None
current_model_name = None

# Run on GPU when one is visible, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
def load_model(model_name):
    """Load the selected model and tokenizer into the module-level cache.

    No-op when `model_name` is already the resident model.
    """
    global current_model, current_tokenizer, current_model_name
    if current_model_name == model_name:
        return  # already resident — avoid a redundant (expensive) reload
    repo_id = AVAILABLE_MODELS[model_name]
    current_model = AutoModelForCausalLM.from_pretrained(repo_id).to(device)
    current_tokenizer = AutoTokenizer.from_pretrained(repo_id)
    current_model_name = model_name
# Load the default model once at startup so the first request doesn't pay
# the full download/initialization cost.
# NOTE(review): this is a 7B-parameter, gated repo — it downloads several GB
# and needs substantial RAM/GPU at import time; confirm the Space hardware
# and HF_TOKEN access before deploying.
load_model("mistral-7b")
@spaces.GPU()
def get_next_token_predictions(text, model_name, top_k=10):
    """Return the top-k candidate next tokens and their probabilities.

    Returns a pair (tokens, probs): a list of decoded token strings and a
    parallel list of floats summing over the sampled candidates.
    """
    global current_model, current_tokenizer
    # Swap in the requested model if the user changed the selection.
    if current_model_name != model_name:
        load_model(model_name)

    encoded = current_tokenizer(text, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = current_model(**encoded)

    # Distribution over the vocabulary for the position after the prompt.
    last_logits = outputs.logits[0, -1, :]
    distribution = torch.nn.functional.softmax(last_logits, dim=-1)
    probabilities, indices = torch.topk(distribution, k=top_k)

    candidates = [current_tokenizer.decode([i.item()]) for i in indices]
    return candidates, probabilities.cpu().tolist()
def plot_probabilities(tokens, probs):
    """Generate a horizontal bar chart for token probabilities.

    Args:
        tokens: token labels for the y axis.
        probs: matching probabilities for the x axis.

    Returns:
        Path to a PNG file containing the chart.
    """
    import tempfile  # stdlib; local import keeps the top-of-file imports untouched

    fig, ax = plt.subplots(figsize=(8, 5))
    # Reverse so the highest-probability token renders at the top.
    ax.barh(tokens[::-1], probs[::-1], color="skyblue")
    ax.set_xlabel("Probability")
    ax.set_title("Next Token Predictions")
    plt.tight_layout()
    # BUG FIX: a fixed filename ("token_probabilities.png") was shared by
    # all requests, so concurrent users clobbered each other's charts.
    # Write to a unique temp file instead (Gradio serves it by path).
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        plot_path = tmp.name
    plt.savefig(plot_path)
    plt.close(fig)
    return plot_path
def predict_next_token(model_name, text, top_k, custom_token=""):
    """Get predictions and update the UI with a token dropdown and a chart."""
    prompt = text + custom_token if custom_token else text
    tokens, probs = get_next_token_predictions(prompt, model_name, top_k)
    chart_path = plot_probabilities(tokens, probs)
    # Quote each token so whitespace-only tokens stay visible in the dropdown.
    quoted = [f"'{t}'" for t in tokens]
    return gr.update(choices=quoted), chart_path
def append_selected_token(text, selected_token):
    """Append the dropdown-selected token (surrounding quotes removed) to the text."""
    if not selected_token:
        return text
    token = selected_token.strip("'")
    return f"{text} {token}"
# Create the UI
with gr.Blocks() as demo:
    gr.Markdown("# 🔥 Interactive Text Prediction with Transformers")
    gr.Markdown(
        "This application lets you interactively generate text using multiple transformer models. "
        "Choose a model, type your text, and explore token predictions."
    )
    with gr.Row():
        # BUG FIX: the previous default value "distilgpt2" is not a key of
        # AVAILABLE_MODELS, so the dropdown started on an invalid selection.
        # Default to the model that is loaded at startup instead.
        model_dropdown = gr.Dropdown(
            choices=list(AVAILABLE_MODELS.keys()),
            value="mistral-7b",
            label="Select Model"
        )
    with gr.Row():
        text_input = gr.Textbox(
            lines=5,
            label="Input Text",
            placeholder="Type your text here...",
            value="The quick brown fox"
        )
    with gr.Row():
        top_k_slider = gr.Slider(
            minimum=1,
            maximum=20,
            value=10,
            step=1,
            label="Top-k Predictions"
        )
    with gr.Row():
        predict_button = gr.Button("Predict")
    with gr.Row():
        token_dropdown = gr.Dropdown(
            label="Predicted Tokens",
            choices=[]
        )
        append_button = gr.Button("Append Token")
    with gr.Row():
        chart_output = gr.Image(label="Token Probability Chart")

    # Button click events: predict fills the token dropdown and chart;
    # append pushes the chosen token back into the text box.
    predict_button.click(
        predict_next_token,
        inputs=[model_dropdown, text_input, top_k_slider],
        outputs=[token_dropdown, chart_output]
    )
    append_button.click(
        append_selected_token,
        inputs=[text_input, token_dropdown],
        outputs=text_input
    )

demo.queue().launch()