File size: 1,913 Bytes
0ec5fee 523bd62 0ec5fee 523bd62 0ec5fee 3d958b3 0ec5fee 523bd62 0ec5fee 523bd62 0ec5fee 3d958b3 0ec5fee 3d958b3 0ec5fee 3d958b3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
import gradio as gr
from transformers import pipeline, GPT2LMHeadModel, GPT2Tokenizer
# Static page copy handed to gr.Interface when the UI is built.
title = "GPT2"
description = (
    "Gradio Demo for OpenAI GPT2. To use it, simply add your text, "
    "or click one of the examples to load them."
)
article = (
    "<p style='text-align: center'>"
    "<a href='https://d4mucfpksywv.cloudfront.net/better-language-models/"
    "language_models_are_unsupervised_multitask_learners.pdf' target='_blank'>"
    "Language Models are Unsupervised Multitask Learners</a></p>"
)

# One example row per entry: [prompt text, model name], in the same order as
# the two inputs of the interface.
examples = [["Paris is the capital of", "gpt2-medium"]]

# Cache of already-built pipelines, keyed by model identifier, so each model
# is only loaded once per process.
models = {}
def load_model(model_name):
    """Return the text-generation pipeline for ``model_name``.

    The pipeline is built the first time a given name is requested and stored
    in the module-level ``models`` dict; later calls with the same name return
    the stored object.

    Args:
        model_name: Identifier passed to ``from_pretrained`` for both the
            tokenizer and the model.

    Returns:
        The ``pipeline("text-generation", ...)`` object for that model.
    """
    # Fast path: this model has already been built.
    if model_name in models:
        return models[model_name]

    gpt2_tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    gpt2_model = GPT2LMHeadModel.from_pretrained(model_name)
    text_generator = pipeline(
        "text-generation",
        model=gpt2_model,
        tokenizer=gpt2_tokenizer,
    )
    models[model_name] = text_generator
    return text_generator
def inference(text, model_name):
    """Generate a continuation of ``text`` with the selected GPT-2 variant.

    Args:
        text: Prompt to continue. ``None`` is treated as an empty prompt.
        model_name: One of ``"distilgpt2"``, ``"gpt2-medium"``,
            ``"gpt2-large"`` or ``"gpt2-xl"``. Any other value falls back to
            ``"distilgpt2"``.

    Returns:
        The ``generated_text`` field of the single sequence returned by the
        pipeline.
    """
    # The UI names are already the Hugging Face identifiers, so a membership
    # check replaces the former identity mapping; unknown names still fall
    # back to the smallest model.
    supported_models = ("distilgpt2", "gpt2-medium", "gpt2-large", "gpt2-xl")
    hf_model_name = model_name if model_name in supported_models else "distilgpt2"

    # Load the model (cached by load_model after the first call).
    generator = load_model(hf_model_name)

    prompt = "" if text is None else text

    # max_new_tokens bounds only the generated part. The previous
    # max_length=50 also counted the prompt tokens, so a prompt close to or
    # beyond 50 tokens left no room for any new text.
    generated = generator(prompt, max_new_tokens=50, num_return_sequences=1)
    return generated[0]["generated_text"]
# Build the demo UI: a prompt textbox plus a model selector as inputs, and a
# single textbox showing the text returned by inference().
iface = gr.Interface(
    fn=inference,
    inputs=[
        gr.Textbox(label="Input"),
        gr.Dropdown(
            choices=["distilgpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"],
            value="gpt2-medium",
            label="Model",
        ),
    ],
    outputs=gr.Textbox(label="Output"),
    examples=examples,
    article=article,
    title=title,
    description=description,
)

# Request queuing is enabled through .queue(); the `enable_queue` keyword of
# launch() is deprecated in favour of it and is rejected by newer Gradio
# releases.
iface.queue()
iface.launch()