Mehak-Mazhar commited on
Commit
91b66ad
Β·
verified Β·
1 Parent(s): 967130e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -53
app.py CHANGED
@@ -1,60 +1,50 @@
1
  import gradio as gr
2
- from transformers import pipeline
3
-
4
- # Load text generation models
5
- models = {
6
- "🟦 GPT-2 Tiny": pipeline("text-generation", model="sshleifer/tiny-gpt2"),
7
- "🟩 GPT-Neo 125M": pipeline("text-generation", model="EleutherAI/gpt-neo-125M"),
8
- "🟨 T5 Small (CommonGen)": pipeline("text2text-generation", model="mrm8488/t5-small-finetuned-common_gen"),
9
- }
10
-
11
- # Generation function
12
- def generate_text(prompt, model_choice):
13
- if not prompt.strip():
14
- return "❗ Please enter a prompt."
15
-
16
- generator = models[model_choice]
17
- output = generator(prompt, max_length=100, num_return_sequences=1)[0]
18
-
19
- return output.get("generated_text", output.get("generated_text", output.get("text", "⚠️ Generation failed.")))
20
-
21
- # Gradio App UI
22
- with gr.Blocks(css="""
23
- body { background-color: #fffde7; }
24
- #main-title {
25
- text-align: center;
26
- font-size: 34px;
27
- font-weight: bold;
28
- color: #5d4037;
29
- margin-top: 20px;
30
- }
31
- .output-box {
32
- border: 2px solid #ffcc80;
33
- padding: 15px;
34
- border-radius: 10px;
35
- background-color: #fff8e1;
36
- }
37
- #footer {
38
- text-align: center;
39
- font-size: 12px;
40
- color: #aaa;
41
- margin-top: 30px;
42
- }
43
- """) as demo:
44
-
45
- gr.Markdown("<div id='main-title'>πŸ“ LLM for Content Generation</div>")
46
 
47
  with gr.Row():
48
- with gr.Column(scale=1):
49
- prompt_input = gr.Textbox(label="🟨 Prompt", placeholder="Enter your topic, e.g., 'Benefits of AI in education'", lines=5)
50
- model_selector = gr.Radio(choices=list(models.keys()), value="🟦 GPT-2 Tiny", label="Select a Model")
51
- generate_button = gr.Button("πŸš€ Generate Content")
52
 
53
- with gr.Column(scale=1):
54
- generated_output = gr.Textbox(label="πŸ“ Generated Text", lines=15, elem_classes=["output-box"])
55
 
56
- generate_button.click(fn=generate_text, inputs=[prompt_input, model_selector], outputs=generated_output)
57
-
58
- gr.Markdown("<div id='footer'>✨ Powered by Hugging Face Transformers</div>")
59
 
60
  demo.launch()
 
1
  import gradio as gr
2
+ from transformers import (
3
+ GPT2LMHeadModel, GPT2Tokenizer,
4
+ AutoModelForCausalLM, AutoTokenizer
5
+ )
6
+
7
+ # Load Models
8
+ gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2")
9
+ gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
10
+
11
+ distilgpt2_model = GPT2LMHeadModel.from_pretrained("distilgpt2")
12
+ distilgpt2_tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
13
+
14
+ bloom_model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
15
+ bloom_tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
16
+
17
+ # Inference Function
18
+ def generate_text(prompt, model_name):
19
+ if model_name == "🧠 GPT2":
20
+ inputs = gpt2_tokenizer.encode(prompt, return_tensors="pt")
21
+ output = gpt2_model.generate(inputs, max_length=100)
22
+ return gpt2_tokenizer.decode(output[0], skip_special_tokens=True)
23
+
24
+ elif model_name == "⚑ DistilGPT2":
25
+ inputs = distilgpt2_tokenizer.encode(prompt, return_tensors="pt")
26
+ output = distilgpt2_model.generate(inputs, max_length=100)
27
+ return distilgpt2_tokenizer.decode(output[0], skip_special_tokens=True)
28
+
29
+ elif model_name == "🌸 Bloom-560M":
30
+ inputs = bloom_tokenizer(prompt, return_tensors="pt")
31
+ output = bloom_model.generate(inputs["input_ids"], max_length=100)
32
+ return bloom_tokenizer.decode(output[0], skip_special_tokens=True)
33
+
34
+ # Gradio UI
35
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
36
+ gr.Markdown("<h1 style='text-align: center; color: brown;'>LLM for Content Generation</h1>")
37
+ gr.Markdown("<div style='text-align: center;'>Generate high-quality text using three powerful LLMs</div>")
 
 
 
 
 
 
 
 
38
 
39
  with gr.Row():
40
+ with gr.Column():
41
+ prompt = gr.Textbox(label="Enter a topic or prompt")
42
+ model_choice = gr.Radio(["🧠 GPT2", "⚑ DistilGPT2", "🌸 Bloom-560M"], label="Choose a Model")
43
+ submit = gr.Button("Generate")
44
 
45
+ with gr.Column():
46
+ output = gr.Textbox(label="Generated Text", lines=10)
47
 
48
+ submit.click(fn=generate_text, inputs=[prompt, model_choice], outputs=output)
 
 
49
 
50
  demo.launch()