# Transformers / app.py
# Author: adityaardak; commit e54c024 "Create app.py" (verified)
import gradio as gr
import pandas as pd
from transformers import pipeline
# ------------------------------------------------------------
# Lazy-load pipelines (loads only when you use that tab)
# ------------------------------------------------------------
# Already-built pipelines, keyed by (task, model), so each is loaded once.
PIPES = {}


def get_pipe(task: str, model: str = None):
    """Return the pipeline for *task*/*model*, building and caching it on first use.

    Building a pipeline loads a model, so it is only done the first time a
    given ``(task, model)`` pair is requested (i.e. when that tab is used).

    Args:
        task: Task name passed to ``pipeline``.
        model: Optional model id. When falsy, ``pipeline`` is called with the
            task alone so the library picks its default model.

    Returns:
        The cached pipeline object for this ``(task, model)`` pair.
    """
    cache_key = (task, model)
    try:
        return PIPES[cache_key]
    except KeyError:
        pass  # first request for this pair: build it below
    built = pipeline(task, model=model) if model else pipeline(task)
    PIPES[cache_key] = built
    return built
# ------------------------------------------------------------
# Helpers
# ------------------------------------------------------------
def meter(label: str, score: float, width: int = 20) -> str:
    """Render *score* as a two-line text "meter" suitable for a Textbox.

    The first line is *label*. The second is a bar of *width* cells (filled
    cells are U+2588 FULL BLOCK, empty cells U+2591 LIGHT SHADE), followed by
    the score formatted to two decimals.

    Args:
        label: Caption shown above the bar.
        score: Value to display, expected in [0, 1]. Out-of-range values are
            clamped for the bar only; the printed number is left unchanged.
        width: Number of cells in the bar (default 20).

    Returns:
        The formatted meter string.
    """
    # Escapes keep the glyphs correct regardless of how this file is decoded.
    full, empty = "\u2588", "\u2591"
    score = float(score)
    # Clamp so the bar always has exactly `width` cells.
    filled = int(round(min(max(score, 0.0), 1.0) * width))
    return f"{label}\n{full * filled}{empty * (width - filled)} {score:.2f}"
# ------------------------------------------------------------
# Tasks
# ------------------------------------------------------------
def run_sentiment(text, model_choice):
    """Classify *text* as positive or negative and format it for the Sentiment tab.

    Args:
        text: Sentence to analyse.
        model_choice: Dropdown label; must be one of the keys of
            ``models_by_label`` below (other values raise ``KeyError``).

    Returns:
        ``(headline, meter_text, details_df)`` for the tab's three outputs.
    """
    # Dropdown label -> model id; None lets get_pipe use the default model.
    models_by_label = {
        "Fast (default)": None,
        "DistilBERT (SST-2)": "distilbert-base-uncased-finetuned-sst-2-english",
    }
    classifier = get_pipe("sentiment-analysis", models_by_label[model_choice])
    top = classifier(text)[0]
    label, score = top["label"], top["score"]
    # Any label containing "POS" (case-insensitive) counts as positive.
    is_positive = "POS" in label.upper()
    headline = f"{'😊' if is_positive else '😞'} Prediction: {label}"
    meter_text = meter("Confidence", score)
    details = pd.DataFrame([{"label": label, "confidence": score}])
    return headline, meter_text, details
def run_qa(context, question):
    """Answer *question* using the paragraph in *context*.

    Args:
        context: Paragraph the answer is taken from.
        question: Question to ask about *context*.

    Returns:
        ``(answer_text, meter_text, details_df)`` for the Q&A tab. If either
        input is blank, returns a warning, an empty meter and an empty table
        without calling the model.
    """
    # Validate before get_pipe: a blank box gives a friendly hint (as the
    # fill-mask and zero-shot tabs do) and never triggers a model load.
    if not (context or "").strip() or not (question or "").strip():
        return "⚠️ Please type both a paragraph and a question.", "", pd.DataFrame()
    pipe = get_pipe("question-answering", None)
    r = pipe(question=question, context=context)
    answer = r["answer"]
    score = float(r["score"])
    return (
        f"✅ Answer: {answer}",
        meter("Confidence", score),
        pd.DataFrame([{"answer": answer, "confidence": score}]),
    )
def run_summary(text, length_mode):
    """Summarize *text* using a length preset chosen in the UI.

    Args:
        text: Text to summarize.
        length_mode: "Short" or "Medium"; any other value uses the "Long" preset.

    Returns:
        The summary string, or a warning message if *text* is blank.
    """
    # Validate before get_pipe so a blank box never loads the model.
    if not (text or "").strip():
        return "⚠️ Please type some text to summarize."
    # Preset -> (max_length, min_length) for generation; default is "Long".
    presets = {"Short": (60, 20), "Medium": (90, 30)}
    max_len, min_len = presets.get(length_mode, (130, 40))
    pipe = get_pipe("summarization", None)
    # truncation=True cuts over-long input down to the model's input limit
    # instead of passing more tokens than the model accepts.
    r = pipe(
        text,
        max_length=max_len,
        min_length=min_len,
        do_sample=False,
        truncation=True,
    )[0]
    return r["summary_text"]
def run_translate(text, direction):
    """Translate *text* in the direction chosen on the Translation tab.

    Only two directions are supported (more can be added). A *direction*
    label that starts with "English" selects English -> French; any other
    value selects French -> English.

    Args:
        text: Text to translate.
        direction: Radio label for the translation direction.

    Returns:
        The translated string, or the raw pipeline result converted to a
        string if it has no ``translation_text`` key.
    """
    # Match on the source-language word rather than the whole label. The
    # arrow character in the label is easy to mis-encode, and with an
    # exact-string comparison any mismatch silently selected French -> English.
    if (direction or "").strip().startswith("English"):
        pipe = get_pipe("translation_en_to_fr", None)
    else:
        pipe = get_pipe("translation_fr_to_en", "Helsinki-NLP/opus-mt-fr-en")
    r = pipe(text)[0]
    # The result key may differ by pipeline type; fall back to the raw result.
    return r.get("translation_text", str(r))
def run_generate(prompt, style, max_new_tokens, temperature):
    """Continue *prompt* with GPT-2, optionally adding a kid-friendly style prefix.

    Args:
        prompt: User text to continue (surrounding whitespace is stripped).
        style: Radio label. Labels starting with "Story" prepend
            "Once upon a time, ", labels starting with "Robot" prepend
            "[Robot voice] ", and any other label uses the prompt as-is.
        max_new_tokens: Generation length; converted with ``int()``.
        temperature: Sampling temperature; converted with ``float()``.

    Returns:
        The ``generated_text`` of the single sampled sequence.
    """
    # GPT-2 is lightweight and common; great for demos.
    pipe = get_pipe("text-generation", "gpt2")
    text = (prompt or "").strip()
    # Match on the word before the emoji, so the choice still works if the
    # emoji in the Radio label is re-encoded or swapped.
    style_name = (style or "").strip()
    if style_name.startswith("Story"):
        text = f"Once upon a time, {text}"
    elif style_name.startswith("Robot"):
        text = f"[Robot voice] {text}"
    outputs = pipe(
        text,
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        temperature=float(temperature),
        num_return_sequences=1,
    )
    return outputs[0]["generated_text"]
def run_fill_mask(text):
    """Guess the word hidden behind a single ``[MASK]`` token in *text*.

    Args:
        text: Sentence containing exactly one ``[MASK]``.

    Returns:
        ``(message, table)``. On success, the table has one row per candidate
        (at most 10) with ``prediction`` (the completed sentence) and
        ``score`` columns. On invalid input, a warning and an empty DataFrame.
    """
    text = text or ""
    # Validate before get_pipe so invalid input never loads the BERT model.
    mask_count = text.count("[MASK]")
    if mask_count == 0:
        return "⚠️ Please include [MASK] in the text.", pd.DataFrame()
    if mask_count > 1:
        # With several masks the fill-mask pipeline returns one candidate
        # list per mask (a list of lists). The row-building code below
        # expects a flat list of dicts and would crash on that.
        return "⚠️ Please use only one [MASK] at a time.", pd.DataFrame()
    pipe = get_pipe("fill-mask", "bert-base-uncased")
    results = pipe(text)
    df = pd.DataFrame(
        [{"prediction": r["sequence"], "score": float(r["score"])} for r in results[:10]]
    )
    return "✅ Top predictions shown below", df
def run_zero_shot(text, labels):
    """Score *text* against comma-separated candidate labels typed by the user.

    Args:
        text: Text to classify.
        labels: Comma-separated label names. Blanks are dropped and duplicates
            removed, keeping the first occurrence's position.

    Returns:
        ``(message, table)`` with ``label``/``score`` columns in the order the
        pipeline returns them, or a warning and an empty DataFrame when no
        usable labels are given.
    """
    # Split, trim, drop blanks and de-duplicate (order preserved). Without
    # this, a repeated label shows up twice in the table and competes with
    # its own copy for score.
    label_list = list(dict.fromkeys(
        part.strip() for part in (labels or "").split(",") if part.strip()
    ))
    # Validate before get_pipe so an empty label list never loads the model.
    if not label_list:
        return "⚠️ Please type labels separated by commas.", pd.DataFrame()
    pipe = get_pipe("zero-shot-classification", None)
    r = pipe(text, candidate_labels=label_list)
    df = pd.DataFrame({"label": r["labels"], "score": r["scores"]})
    return "✅ Sorted scores (bigger = more likely)", df
def run_ner(text):
    """Find named entities in *text* and list them by confidence.

    Args:
        text: Text to scan for entities.

    Returns:
        ``(message, table)``. The table has ``text``, ``type`` and ``score``
        columns sorted by descending score. If the pipeline returns nothing,
        the message is "No entities found." and the table is empty.
    """
    pipe = get_pipe("ner", None)
    # aggregation_strategy="simple" groups word pieces into whole entities.
    # It is the documented replacement for the deprecated grouped_entities=True.
    ents = pipe(text, aggregation_strategy="simple")
    if not ents:
        return "No entities found.", pd.DataFrame()
    rows = [
        {
            "text": e.get("word", ""),
            # Grouped results use "entity_group"; fall back to per-token "entity".
            "type": e.get("entity_group", e.get("entity", "")),
            "score": float(e.get("score", 0.0)),
        }
        for e in ents
    ]
    df = pd.DataFrame(rows).sort_values("score", ascending=False)
    return "✅ Entities found", df
# ------------------------------------------------------------
# UI
# ------------------------------------------------------------
# Shared visual theme for the whole app.
THEME = gr.themes.Soft(
    primary_hue="indigo",
    secondary_hue="pink",
    neutral_hue="slate",
)
# Page layout: an intro, an overview card, then one tab per task. Each tab's
# button is wired to the matching run_* function via .click(fn, inputs, outputs).
with gr.Blocks(theme=THEME, title="🤗 Transformers Playground (Kid Friendly)", css="""
#title {text-align:center}
.bigcard {border-radius: 18px; padding: 18px; background: white}
""") as demo:
    # Markdown placed inside an HTML tag is only rendered as Markdown when
    # blank lines separate it from the opening and closing tags; otherwise
    # the "#" headings would show up as raw text.
    gr.Markdown("""
<div id="title">

# 🤗 Transformers Superpowers Playground
### Same library, many amazing language powers ✨

</div>

**How to use this app (students):**
1. Pick a tab (Sentiment, Q&A, Summary, Translate, etc.)
2. Change the text ✍️
3. Click the big button 🚀
4. Observe what the Transformer can do 🧠
""")
    with gr.Row():
        gr.Markdown("""
<div class="bigcard">

## What can Transformers do?
- 😊 Detect feelings (Sentiment)
- ❓ Answer questions (Q&A)
- 📝 Summarize long text
- 🌍 Translate languages
- ✍️ Continue stories (Generation)
- 🧩 Fill missing words ([MASK])
- 🏷️ Classify topics (Zero-shot)
- 👀 Find names/places (NER)

</div>
""")
    with gr.Tabs():
        # ------------------ Sentiment ------------------
        with gr.Tab("😊 Sentiment"):
            gr.Markdown("### Detect if text feels **positive** or **negative**.")
            with gr.Row():
                sent_text = gr.Textbox(
                    label="Type a sentence",
                    value="I love this game! It is so fun and exciting!",
                    lines=3,
                )
                with gr.Column():
                    # These labels are the keys run_sentiment looks up.
                    sent_model = gr.Dropdown(
                        ["Fast (default)", "DistilBERT (SST-2)"],
                        value="Fast (default)",
                        label="Model choice",
                    )
                    sent_btn = gr.Button("🚀 Analyze Sentiment", variant="primary")
            sent_out1 = gr.Textbox(label="Result", lines=1)
            sent_out2 = gr.Textbox(label="Confidence Meter", lines=2)
            sent_table = gr.Dataframe(label="Details", interactive=False)
            gr.Examples(
                examples=[
                    ["This movie was amazing! I want to watch it again!"],
                    ["This is the worst day ever. I feel upset."],
                    ["It was okay, not great, not bad."],
                ],
                inputs=sent_text,
                label="Try examples",
            )
            sent_btn.click(run_sentiment, [sent_text, sent_model], [sent_out1, sent_out2, sent_table])
        # ------------------ Q&A ------------------
        with gr.Tab("❓ Question Answering"):
            gr.Markdown("### Ask a question using a paragraph as the “book”.")
            qa_context = gr.Textbox(
                label="Context (the paragraph)",
                value="Paris is the capital of France. It is famous for the Eiffel Tower and beautiful museums.",
                lines=5,
            )
            qa_question = gr.Textbox(label="Question", value="What is the capital of France?")
            qa_btn = gr.Button("🔎 Find Answer", variant="primary")
            qa_out1 = gr.Textbox(label="Answer", lines=1)
            qa_out2 = gr.Textbox(label="Confidence Meter", lines=2)
            qa_table = gr.Dataframe(label="Details", interactive=False)
            qa_btn.click(run_qa, [qa_context, qa_question], [qa_out1, qa_out2, qa_table])
        # ------------------ Summarization ------------------
        with gr.Tab("📝 Summarization"):
            gr.Markdown("### Make long text short (like a mini version).")
            sum_text = gr.Textbox(
                label="Long text",
                value=("Artificial intelligence is a field of computer science. "
                       "It tries to make machines smart. AI can help with images, language, and robots. "
                       "Some AI systems learn from data and improve over time."),
                lines=6,
            )
            sum_mode = gr.Radio(["Short", "Medium", "Long"], value="Short", label="Summary size")
            sum_btn = gr.Button("✨ Summarize", variant="primary")
            sum_out = gr.Textbox(label="Summary", lines=4)
            sum_btn.click(run_summary, [sum_text, sum_mode], sum_out)
        # ------------------ Translation ------------------
        with gr.Tab("🌍 Translation"):
            gr.Markdown("### Translate between languages.")
            tr_text = gr.Textbox(label="Text", value="I love learning AI.", lines=3)
            # NOTE(review): run_translate interprets these choice values, so they
            # are kept byte-for-byte as before. The arrow looks mis-encoded (it
            # is probably meant to be a right arrow); change both places together.
            tr_dir = gr.Radio(["English β†’ French", "French β†’ English"], value="English β†’ French", label="Direction")
            tr_btn = gr.Button("🌟 Translate", variant="primary")
            tr_out = gr.Textbox(label="Translation", lines=3)
            tr_btn.click(run_translate, [tr_text, tr_dir], tr_out)
        # ------------------ Text Generation ------------------
        with gr.Tab("✍️ Text Generation"):
            gr.Markdown("### Let the model continue your writing.")
            gen_prompt = gr.Textbox(
                label="Start a sentence / story",
                value="a brave kid builds a friendly robot that helps at school",
                lines=3,
            )
            with gr.Row():
                # NOTE(review): run_generate interprets these choice values, so
                # they are kept byte-for-byte as before. The emoji look
                # mis-encoded; change both places together.
                gen_style = gr.Radio(["Story πŸ“–", "Normal ✨", "Robot πŸ€–"], value="Story πŸ“–", label="Style")
                gen_tokens = gr.Slider(20, 150, value=60, step=5, label="How long?")
                gen_temp = gr.Slider(0.2, 1.5, value=0.9, step=0.1, label="Creativity (temperature)")
            gen_btn = gr.Button("🚀 Generate", variant="primary")
            gen_out = gr.Textbox(label="Generated text", lines=10)
            gen_btn.click(run_generate, [gen_prompt, gen_style, gen_tokens, gen_temp], gen_out)
        # ------------------ Fill Mask ------------------
        with gr.Tab("🧩 Fill Missing Word"):
            gr.Markdown("### Put **[MASK]** and the model guesses the missing word.")
            fm_text = gr.Textbox(
                label="Text with [MASK]",
                value="I love to play [MASK] with my friends.",
                lines=3,
            )
            fm_btn = gr.Button("🧠 Predict Missing Word", variant="primary")
            fm_msg = gr.Textbox(label="Message", lines=1)
            fm_table = gr.Dataframe(label="Top predictions", interactive=False)
            fm_btn.click(run_fill_mask, fm_text, [fm_msg, fm_table])
        # ------------------ Zero-shot classification ------------------
        with gr.Tab("🏷️ Classify Topics"):
            gr.Markdown("### Classify text using labels you invent (no training needed).")
            zs_text = gr.Textbox(
                label="Text",
                value="I love playing football after school and practicing with my team.",
                lines=4,
            )
            zs_labels = gr.Textbox(
                label="Labels (comma separated)",
                value="sports, school, food, music, games",
            )
            zs_btn = gr.Button("🎯 Classify", variant="primary")
            zs_msg = gr.Textbox(label="Message", lines=1)
            zs_table = gr.Dataframe(label="Scores", interactive=False)
            zs_btn.click(run_zero_shot, [zs_text, zs_labels], [zs_msg, zs_table])
        # ------------------ NER ------------------
        with gr.Tab("👀 Find Names & Places"):
            gr.Markdown("### Find **people, places, and organizations** in text.")
            ner_text = gr.Textbox(
                label="Text",
                value="Elon Musk founded SpaceX in the United States and talked about Mars.",
                lines=4,
            )
            ner_btn = gr.Button("🔍 Detect Entities", variant="primary")
            ner_msg = gr.Textbox(label="Message", lines=1)
            ner_table = gr.Dataframe(label="Entities", interactive=False)
            ner_btn.click(run_ner, ner_text, [ner_msg, ner_table])
    gr.Markdown("""
---
## ⭐ Teacher / Demo Tips
- Start with **Sentiment** (instant “wow”).
- Then **Q&A** (shows understanding).
- Then **Translate** (feels magical).
- Then **Generation** (kids LOVE it).
- For a challenge: ask students to write examples that “trick” the model.
""")
demo.launch()