Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import torch | |
# Dropdown label (used as the "Select Model" choices) -> Hugging Face Hub
# repository id passed to from_pretrained(). Insertion order is the order the
# labels appear in the dropdown.
AVAILABLE_MODELS = {
    "distilgpt2": "distilgpt2",
    "bloomz-560m": "bigscience/bloomz-560m",
    "gpt2-medium": "gpt2-medium",
    "opt-350m": "facebook/opt-350m",
    "pythia-160m": "EleutherAI/pythia-160m",
}

# Module-level cache holding the most recently loaded (model, tokenizer) pair;
# stays None until load_model() succeeds for the first time.
generator = None
def load_model(model_name):
    """Load the model/tokenizer pair for *model_name* into the module cache.

    Args:
        model_name: A key of AVAILABLE_MODELS.

    Returns:
        A human-readable status string. Failures (unknown key, download or
        load errors) are reported in the returned string instead of being
        raised. The cache is only replaced once both the model and the
        tokenizer have loaded, so a failed call leaves any previously cached
        pair in place.
    """
    global generator
    try:
        repo_id = AVAILABLE_MODELS[model_name]
        # Build the tuple first; the global is rebound only if both loads succeed.
        generator = (
            AutoModelForCausalLM.from_pretrained(repo_id),
            AutoTokenizer.from_pretrained(repo_id),
        )
        return f"Successfully loaded {model_name}"
    except Exception as exc:
        return "Error loading model: " + str(exc)
def _is_loaded(model_name):
    """Return True if the cached (model, tokenizer) pair was loaded for *model_name*."""
    if generator is None:
        return False
    repo_id = AVAILABLE_MODELS.get(model_name)
    if repo_id is None:
        return False
    _, tokenizer = generator
    # from_pretrained() records the repo id it was given as `name_or_path`,
    # which tells us which entry of AVAILABLE_MODELS is currently cached.
    return getattr(tokenizer, "name_or_path", None) == repo_id


def get_predictions(text, model_name, top_k=10):
    """Return the model's most likely next tokens after *text*.

    The (model, tokenizer) pair for *model_name* is loaded on demand. It is
    reloaded whenever the requested model differs from the cached one, so
    switching models in the UI takes effect.

    Args:
        text: Prompt whose continuation is predicted.
        model_name: A key of AVAILABLE_MODELS.
        top_k: Number of candidates to return (clamped to the vocabulary size).

    Returns:
        A tuple ``(tokens, predictions)`` where ``tokens`` is a list of decoded
        token strings ordered by decreasing probability and ``predictions`` is
        one ``'token' : prob`` line per candidate. If the model cannot be
        loaded, returns ``([], <load_model status message>)``. If *text*
        tokenizes to nothing, returns ``([], "")``.
    """
    if not _is_loaded(model_name):
        status = load_model(model_name)
        if not _is_loaded(model_name):
            # Loading failed: surface the error instead of crashing on unpack
            # (or silently predicting with a previously cached model).
            return [], status

    model, tokenizer = generator
    inputs = tokenizer(text, return_tensors="pt")
    if inputs["input_ids"].shape[-1] == 0:
        # No tokens means no "last position" to read logits from.
        return [], ""

    with torch.no_grad():
        outputs = model(**inputs)
    # Logits at the final position of the single batch row = next-token scores.
    logits = outputs.logits[0, -1, :]
    probs = torch.nn.functional.softmax(logits, dim=-1)
    k = min(top_k, probs.shape[-1])
    top_k_probs, top_k_indices = torch.topk(probs, k=k)

    top_k_tokens = [tokenizer.decode([token_id]) for token_id in top_k_indices.tolist()]
    predictions = "\n".join(
        f"'{token}' : {prob:.4f}"
        for token, prob in zip(top_k_tokens, top_k_probs.tolist())
    )
    return top_k_tokens, predictions
def generate(model_name, text, token_choice="", custom_token=""):
    """Append the selected and/or typed token to *text* and refresh predictions.

    Args:
        model_name: A key of AVAILABLE_MODELS, forwarded to get_predictions().
        text: The current text; a falsy value is treated as "".
        token_choice: A dropdown entry in the "'token'" display form produced by
            this function, or a falsy value for no selection.
        custom_token: Free-form text appended verbatim after *token_choice*.

    Returns:
        ``(new_text, dropdown_update, predictions_text)``. The dropdown update
        lists the new candidates, each wrapped in single quotes so leading or
        trailing whitespace in a token stays visible.
    """
    text = text or ""
    if token_choice:
        # Remove exactly the one pair of display quotes. str.strip("'") would
        # also eat apostrophes that belong to the token itself (e.g. "'s", "'").
        if len(token_choice) >= 2 and token_choice[0] == token_choice[-1] == "'":
            token_choice = token_choice[1:-1]
        text += token_choice
    if custom_token:
        text += custom_token
    tokens, predictions = get_predictions(text, model_name)
    return text, gr.Dropdown(choices=[f"'{t}'" for t in tokens]), predictions
with gr.Blocks() as demo:
    gr.Markdown("# Interactive Text Generation")
    model_name = gr.Dropdown(
        choices=list(AVAILABLE_MODELS.keys()),
        value="distilgpt2",
        label="Select Model",
    )
    text = gr.Textbox(
        lines=5,
        label="Text",
        placeholder="Type or select tokens to generate text...",
    )
    with gr.Row():
        token_choice = gr.Dropdown(
            choices=[],
            label="Select predicted token",
        )
        custom_token = gr.Textbox(
            label="Or type custom token",
        )
    predictions = gr.Textbox(
        label="Predictions",
        lines=10,
    )
    predict_button = gr.Button("Predict next token")

    refresh_outputs = [text, token_choice, predictions]

    # Each trigger appends only its own input. Feeding every widget's .change
    # into one handler that always appends BOTH the dropdown selection and the
    # custom text would re-append stale values on unrelated events and, since
    # .change fires on every keystroke and on programmatic updates, append each
    # partial prefix of the custom token.

    # Switching model or pressing the button only recomputes predictions.
    model_name.change(generate, inputs=[model_name, text], outputs=refresh_outputs)
    predict_button.click(generate, inputs=[model_name, text], outputs=refresh_outputs)

    # .input fires on user selection only, not when generate() replaces the choices.
    token_choice.input(
        generate,
        inputs=[model_name, text, token_choice],
        outputs=refresh_outputs,
    )

    # Append the custom token once (on Enter), then clear the box.
    custom_token.submit(
        lambda name, current, typed: (*generate(name, current, custom_token=typed), ""),
        inputs=[model_name, text, custom_token],
        outputs=[*refresh_outputs, custom_token],
    )
demo.queue().launch()