XMichaelX committed
Commit 523bd62 · verified · 1 Parent(s): 3d958b3

Update app.py

Files changed (1)
  1. app.py +29 -18
app.py CHANGED
@@ -1,32 +1,43 @@
 import gradio as gr
+from transformers import pipeline, GPT2LMHeadModel, GPT2Tokenizer
 
 title = "GPT2"
-description = "Gradio Demo for OpenAI GPT2. To use it, simply add your text, or click one of the examples to load them. Read more at the links below."
-
+description = "Gradio Demo for OpenAI GPT2. To use it, simply add your text, or click one of the examples to load them."
 article = "<p style='text-align: center'><a href='https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf' target='_blank'>Language Models are Unsupervised Multitask Learners</a></p>"
 
 examples = [
     ['Paris is the capital of', "gpt2-medium"]
 ]
 
-# Load all models at startup
-io1 = gr.load("huggingface/distilgpt2")
-io2 = gr.load("huggingface/gpt2-large")
-io3 = gr.load("huggingface/gpt2-medium")
-io4 = gr.load("huggingface/gpt2-xl")
+# Initialize models dictionary to cache loaded models
+models = {}
+
+def load_model(model_name):
+    if model_name not in models:
+        tokenizer = GPT2Tokenizer.from_pretrained(model_name)
+        model = GPT2LMHeadModel.from_pretrained(model_name)
+        models[model_name] = pipeline("text-generation", model=model, tokenizer=tokenizer)
+    return models[model_name]
 
-def inference(text, model):
-    if model == "gpt2-large":
-        outtext = io2(text)
-    elif model == "gpt2-medium":
-        outtext = io3(text)
-    elif model == "gpt2-xl":
-        outtext = io4(text)
-    else:
-        outtext = io1(text)
-    return outtext
+def inference(text, model_name):
+    # Map the model names to their Hugging Face identifiers
+    model_map = {
+        "distilgpt2": "distilgpt2",
+        "gpt2-medium": "gpt2-medium",
+        "gpt2-large": "gpt2-large",
+        "gpt2-xl": "gpt2-xl"
+    }
+
+    # Get the correct model identifier
+    hf_model_name = model_map.get(model_name, "distilgpt2")
+
+    # Load the model (will be cached after first load)
+    generator = load_model(hf_model_name)
+
+    # Generate text
+    generated = generator(text, max_length=50, num_return_sequences=1)
+    return generated[0]['generated_text']
 
-# Create the interface
 iface = gr.Interface(
     inference,
     [
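
The hunk is truncated at the start of the gr.Interface input list. For reference, a minimal sketch of how the wiring plausibly continues, given the inference(text, model_name) signature and the title/description/article/examples variables defined above; the Textbox/Dropdown widgets and the launch() call below are assumptions, not part of the commit:

# Hypothetical completion of the truncated gr.Interface(...) call;
# the real argument list is not visible in this hunk.
iface = gr.Interface(
    inference,
    [
        gr.Textbox(label="Input text"),  # prompt to continue
        gr.Dropdown(                     # model picker matching model_map (assumed)
            choices=["distilgpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"],
            value="distilgpt2",
            label="Model",
        ),
    ],
    gr.Textbox(label="Generated text"),  # single text output (assumed)
    title=title,
    description=description,
    article=article,
    examples=examples,
)

iface.launch()

Note the effect of the cache in load_model: each checkpoint is downloaded and instantiated only on the first request that selects it, then reused from the module-level models dict. That trades the old version's startup cost of eagerly loading all four models for a one-time delay per model at request time, which matters most for gpt2-xl, whose weights run to gigabytes.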