Spaces:
Sleeping
Sleeping
| import os | |
if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    import functools

    class spaces:
        """Local no-op stand-in for the ``spaces`` package.

        Used when the ``SPACES_ZERO_GPU`` environment variable is not set, so
        code decorated with ``spaces.GPU`` still runs unchanged on CPU.
        """

        @staticmethod
        def GPU(func=None, **_kwargs):
            """Decorate ``func`` without changing its behavior.

            Supports both ``@spaces.GPU`` and the call form
            ``@spaces.GPU(...)``. Keyword arguments such as ``duration`` are
            accepted and ignored.
            """

            def decorate(fn):
                @functools.wraps(fn)  # keep __name__/__doc__ of the wrapped fn
                def wrapper(*args, **kwargs):
                    return fn(*args, **kwargs)

                return wrapper

            # Called as @spaces.GPU(...): return the decorator itself.
            if func is None:
                return decorate
            return decorate(func)

        @staticmethod
        def fake_gpu():
            """No-op placeholder."""
            pass
| import numpy as np | |
| import torch | |
| import gradio as gr | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
# Selectable models: dropdown label -> Hugging Face Hub repository id.
AVAILABLE_MODELS = dict(
    [
        ("distilgpt2", "distilgpt2"),
        ("bloomz-560m", "bigscience/bloomz-560m"),
        ("gpt2-medium", "gpt2-medium"),
        ("opt-350m", "facebook/opt-350m"),
        ("pythia-160m", "EleutherAI/pythia-160m"),
    ]
)

# Module-level cache of the one model/tokenizer pair held in memory and the
# AVAILABLE_MODELS key it was loaded from; all None until load_model() runs.
current_model = current_tokenizer = current_model_name = None
def load_model(model_name):
    """Load ``model_name`` into the module-level cache unless it is already loaded.

    Args:
        model_name: Key into ``AVAILABLE_MODELS``.

    Raises:
        KeyError: If ``model_name`` is not a key of ``AVAILABLE_MODELS``
            (raised before anything is downloaded).
    """
    global current_model, current_tokenizer, current_model_name
    if current_model_name == model_name:
        return
    repo_id = AVAILABLE_MODELS[model_name]
    # Load into locals first: if either download fails, the previously cached
    # model, tokenizer and name stay together, so a later request for the old
    # name never pairs a new model with an old tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModelForCausalLM.from_pretrained(repo_id)
    current_model, current_tokenizer, current_model_name = model, tokenizer, model_name
def get_next_token_predictions(text, model_name, top_k=10):
    """Return the ``top_k`` most likely next tokens after ``text``.

    Args:
        text: Prompt to continue.
        model_name: Key into ``AVAILABLE_MODELS``; the model is (re)loaded via
            ``load_model`` when it differs from the cached one.
        top_k: Maximum number of candidates; clamped to the vocabulary size.

    Returns:
        ``(tokens, probs)``: the decoded token strings and their softmax
        probabilities, most likely first. Both lists are empty when the prompt
        encodes to no tokens and the tokenizer defines no BOS token.
    """
    if current_model_name != model_name:
        load_model(model_name)

    input_ids = current_tokenizer(text, return_tensors="pt")["input_ids"]
    if input_ids.shape[-1] == 0:
        # An empty prompt yields a zero-length sequence, which cannot be scored
        # (there is no last position); start from BOS when one exists.
        bos_id = current_tokenizer.bos_token_id
        if bos_id is None:
            return [], []
        input_ids = torch.tensor([[bos_id]])

    # Keep only the most recent tokens so long prompts stay within the model's
    # positional range. Configs without this attribute are left untruncated.
    max_len = getattr(current_model.config, "max_position_embeddings", None)
    if max_len:
        input_ids = input_ids[:, -max_len:]

    with torch.no_grad():
        logits = current_model(input_ids=input_ids).logits[0, -1, :]
    probs = torch.nn.functional.softmax(logits, dim=-1)
    top_probs, top_ids = torch.topk(probs, k=min(top_k, probs.shape[-1]))
    tokens = [current_tokenizer.decode([token_id]) for token_id in top_ids.tolist()]
    return tokens, top_probs.tolist()
def predict_next_token(text, model_name, custom_token=""):
    """Optionally extend ``text`` with ``custom_token``, then score its continuation.

    Returns:
        A 3-tuple ``(text, dropdown_update, predictions)``:
        - the text with ``custom_token`` appended when it is non-empty;
        - a ``gr.update`` whose choices are the predicted tokens, each wrapped
          in single quotes;
        - one ``'token' : probability`` line per prediction, joined by newlines.
    """
    full_text = text + custom_token if custom_token else text
    candidates, scores = get_next_token_predictions(full_text, model_name)

    listing = "\n".join(
        f"'{candidate}' : {score:.4f}" for candidate, score in zip(candidates, scores)
    )
    quoted_choices = [f"'{candidate}'" for candidate in candidates]
    return full_text, gr.update(choices=quoted_choices), listing
# Create the interface


def _unquote_choice(choice):
    """Return ``choice`` without the single quotes ``predict_next_token`` adds.

    Each dropdown choice is built as ``'token'``. Strip the quotes before
    inserting the token into the text.
    """
    if len(choice) >= 2 and choice[0] == choice[-1] == "'":
        return choice[1:-1]
    return choice


def _refresh_predictions(text, model_name):
    """Recompute the dropdown choices and the probability listing for ``text``.

    Does not write back to the text box, so running it from the text box's own
    change event cannot trigger itself again.
    """
    _, choices, predictions = predict_next_token(text, model_name)
    return choices, predictions


def _append_custom_token(text, custom_token):
    """Append the typed custom token to the text and clear the custom-token box."""
    return text + (custom_token or ""), ""


def _append_selected_token(text, evt: gr.SelectData):
    """Append the dropdown token the user picked and reset the selection."""
    choice = evt.value
    if not choice:
        return text, gr.update(value=None)
    return text + _unquote_choice(choice), gr.update(value=None)


with gr.Blocks() as demo:
    gr.Markdown("# Interactive Text Generation with Transformer Models")
    gr.Markdown("""
This application allows you to interactively generate text using various transformer models.
You can either select from the predicted next tokens or write your own tokens to continue the text generation.
Select a model, start typing or choose from the predicted tokens, and see how the model continues your text!
""")
    with gr.Row():
        text_input = gr.Textbox(
            lines=5,
            label="Text",
            placeholder="Type your text here...",
            value="The quick brown fox",
        )
    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=list(AVAILABLE_MODELS.keys()),
            value="distilgpt2",
            label="Select Model",
        )
    with gr.Row():
        custom_input = gr.Textbox(
            label="Custom token (optional)",
            placeholder="Type a custom token and press Enter...",
        )
    with gr.Row():
        token_dropdown = gr.Dropdown(
            label="Predicted tokens",
            choices=[],
        )
    with gr.Row():
        predictions_output = gr.Textbox(
            lines=10,
            label="Token probabilities",
        )

    # Any change to the text (typed, or appended by the handlers below) or to
    # the model only refreshes the predictions.
    text_input.change(
        _refresh_predictions,
        inputs=[text_input, model_dropdown],
        outputs=[token_dropdown, predictions_output],
    )
    model_dropdown.change(
        _refresh_predictions,
        inputs=[text_input, model_dropdown],
        outputs=[token_dropdown, predictions_output],
    )
    # Append the custom token once, on Enter, rather than on every keystroke.
    custom_input.submit(
        _append_custom_token,
        inputs=[text_input, custom_input],
        outputs=[text_input, custom_input],
    )
    # Append the token the user picked from the predictions.
    token_dropdown.select(
        _append_selected_token,
        inputs=[text_input],
        outputs=[text_input, token_dropdown],
    )
    # Show predictions for the default text as soon as the page opens.
    demo.load(
        _refresh_predictions,
        inputs=[text_input, model_dropdown],
        outputs=[token_dropdown, predictions_output],
    )

if __name__ == "__main__":
    demo.queue().launch()