"""hf-model-runner / app.py: multi-task model runner."""
import gradio as gr
import torch
from transformers import pipeline as hf_pipeline
from huggingface_hub import HfApi
# ═══════════════════════════════════════════════════════════════
# Task configuration — default models and input/output types
# ═══════════════════════════════════════════════════════════════
TASK_CONFIG = {
"text-generation": {
"default": "openai-community/gpt2",
"description": "Generate text from a prompt",
"placeholder": "Once upon a time in a land far away...",
},
"text-classification": {
"default": "distilbert/distilbert-base-uncased-finetuned-sst-2-english",
"description": "Classify text into categories (sentiment, topic, etc.)",
"placeholder": "I absolutely loved this movie! The acting was superb.",
},
"summarization": {
"default": "facebook/bart-large-cnn",
"description": "Summarize long text into a concise version",
"placeholder": "Paste a long article or paragraph here to get a summary...",
},
"translation_en_to_fr": {
"default": "Helsinki-NLP/opus-mt-en-fr",
"description": "Translate English text to French",
"placeholder": "Hello, how are you today?",
},
"question-answering": {
"default": "deepset/roberta-base-squad2",
"description": "Answer a question given a context passage",
"placeholder": "What is the capital of France?",
},
"zero-shot-classification": {
"default": "facebook/bart-large-mnli",
"description": "Classify text into custom categories (no training needed)",
"placeholder": "The new iPhone has an amazing camera and great battery life.",
},
"token-classification": {
"default": "dslim/bert-base-NER",
"description": "Named Entity Recognition β€” find people, places, organizations",
"placeholder": "Elon Musk founded SpaceX in Hawthorne, California.",
},
}
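# Adding a task is a single new entry here; a hypothetical fill-mask example:
#     "fill-mask": {
#         "default": "google-bert/bert-base-uncased",
#         "description": "Predict the masked token in a sentence",
#         "placeholder": "Paris is the capital of [MASK].",
#     },
# Tasks without a dedicated branch in run_inference() fall through to the
# generic str(result) formatter at the bottom of that function.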
# ═══════════════════════════════════════════════════════════════
# Pipeline cache — avoids reloading the same model
# ═══════════════════════════════════════════════════════════════
_cache = {}
def get_pipeline(task, model_id):
"""Load or retrieve a cached pipeline."""
key = f"{task}::{model_id}"
if key not in _cache:
# Clear old cache to save memory (keep only 1 model loaded)
_cache.clear()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
device = 0 if torch.cuda.is_available() else -1
kwargs = {"task": task, "model": model_id, "device": device}
        # Use bfloat16 on GPU for memory efficiency (native bf16 needs Ampere or newer)
if torch.cuda.is_available():
kwargs["torch_dtype"] = torch.bfloat16
_cache[key] = hf_pipeline(**kwargs)
return _cache[key]
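# Sketch of direct usage (assumes the checkpoint downloads successfully):
#     pipe = get_pipeline("text-generation", "openai-community/gpt2")
#     pipe("Hello", max_new_tokens=5)
# The single-slot cache keeps memory bounded; the trade-off is that switching
# models always pays the full reload cost.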
# ═══════════════════════════════════════════════════════════════
# Fetch top models for a task from Hub
# ═══════════════════════════════════════════════════════════════
_model_cache = {}
def get_models_for_task(task):
"""Fetch popular models for a given task from the Hub."""
if task in _model_cache:
return _model_cache[task]
try:
        api = HfApi()
        # Hub pipeline tags carry no language pair, so "translation_en_to_fr"
        # must be queried under the generic "translation" tag.
        tag = "translation" if task.startswith("translation") else task
        models = list(api.list_models(pipeline_tag=tag, sort="downloads", limit=15))
model_ids = [m.id for m in models]
# Ensure default model is always first
default = TASK_CONFIG.get(task, {}).get("default", "")
if default in model_ids:
model_ids.remove(default)
model_ids.insert(0, default)
_model_cache[task] = model_ids
return model_ids
except Exception:
default = TASK_CONFIG.get(task, {}).get("default", "gpt2")
return [default]
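# Example: get_models_for_task("summarization") queries the Hub for the 15
# most-downloaded summarization models and pins "facebook/bart-large-cnn"
# to the front so the dropdown default is always present; on any API error
# it falls back to just the default model.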
# ═══════════════════════════════════════════════════════════════
# Main inference function
# ═══════════════════════════════════════════════════════════════
def run_inference(task, model_id, text_input, context, labels, max_tokens, temperature):
"""Run inference on the selected task and model."""
if not text_input or not text_input.strip():
return "⚠️ Please enter some input text.", ""
if not model_id or not model_id.strip():
return "⚠️ Please select or enter a model ID.", ""
try:
status = f"⏳ Loading **{model_id}** for `{task}`..."
pipe = get_pipeline(task, model_id.strip())
status = f"βœ… Running **{model_id}** β€” `{task}`"
# ── Text Generation ──
if task == "text-generation":
result = pipe(
text_input,
max_new_tokens=int(max_tokens),
temperature=float(temperature),
do_sample=True,
)
output = result[0]["generated_text"]
# ── Text Classification ──
elif task == "text-classification":
result = pipe(text_input)
lines = []
for r in result:
bar = "β–ˆ" * int(r["score"] * 30)
lines.append(f"**{r['label']}**: {r['score']:.4f} {bar}")
output = "\n".join(lines)
# ── Summarization ──
elif task == "summarization":
            # guard: min_length must stay below max_length or small slider values error
            result = pipe(text_input, max_length=int(max_tokens), min_length=min(30, int(max_tokens) // 2))
output = result[0]["summary_text"]
# ── Translation ──
elif "translation" in task:
result = pipe(text_input, max_length=int(max_tokens))
output = result[0]["translation_text"]
# ── Question Answering ──
elif task == "question-answering":
if not context or not context.strip():
return "⚠️ Question Answering requires a **Context** passage. Please fill it in below.", ""
result = pipe(question=text_input, context=context)
output = f"**Answer:** {result['answer']}\n**Confidence:** {result['score']:.4f}"
# ── Zero-Shot Classification ──
elif task == "zero-shot-classification":
if not labels or not labels.strip():
return "⚠️ Zero-Shot Classification requires **Candidate Labels**. Enter them comma-separated below.", ""
candidate_labels = [l.strip() for l in labels.split(",") if l.strip()]
if len(candidate_labels) < 2:
return "⚠️ Please enter at least 2 comma-separated labels.", ""
result = pipe(text_input, candidate_labels=candidate_labels)
lines = []
for label, score in zip(result["labels"], result["scores"]):
bar = "β–ˆ" * int(score * 30)
lines.append(f"**{label}**: {score:.4f} {bar}")
output = "\n".join(lines)
# ── Token Classification (NER) ──
elif task == "token-classification":
result = pipe(text_input)
if not result:
output = "No entities found."
else:
lines = []
for ent in result:
lines.append(f"🏷️ **{ent['word']}** β†’ `{ent['entity']}` (score: {ent['score']:.3f})")
output = "\n".join(lines)
else:
result = pipe(text_input)
output = str(result)
return output, status
except Exception as e:
error_msg = str(e)
if "gated" in error_msg.lower() or "access" in error_msg.lower():
return (
f"πŸ”’ **Gated Model**: `{model_id}` requires access approval.\n\n"
f"1. Go to [https://huggingface.co/{model_id}](https://huggingface.co/{model_id})\n"
f"2. Accept the license agreement\n"
f"3. Add your HF_TOKEN to this Space's secrets\n\n"
f"Error: {error_msg}",
"❌ Access denied",
)
return f"❌ **Error**: {error_msg}", "❌ Failed"
# ═══════════════════════════════════════════════════════════════
# UI Event Handlers
# ═══════════════════════════════════════════════════════════════
def on_task_change(task):
"""When task changes: update model list, placeholder, and show/hide extra fields."""
models = get_models_for_task(task)
config = TASK_CONFIG.get(task, {})
placeholder = config.get("placeholder", "Enter text here...")
description = config.get("description", "")
show_context = task == "question-answering"
show_labels = task == "zero-shot-classification"
return (
gr.Dropdown(choices=models, value=models[0] if models else ""),
gr.Textbox(placeholder=placeholder),
gr.Textbox(visible=show_context),
gr.Textbox(visible=show_labels),
f"πŸ“‹ *{description}*",
)
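# Returning freshly constructed components (gr.Dropdown(...), gr.Textbox(...))
# is the Gradio 4.x pattern for in-place updates: only the keyword arguments
# passed here change; all other properties of the live components are kept.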
# ═══════════════════════════════════════════════════════════════
# Build the Gradio UI
# ═══════════════════════════════════════════════════════════════
css = """
.gradio-container { max-width: 960px !important; margin: 0 auto !important; }
.output-box { min-height: 120px; }
"""
with gr.Blocks(css=css, title="🚀 HF Model Runner", theme=gr.themes.Soft()) as demo:
gr.Markdown(
"""
        # 🚀 HF Model Runner
        **Run any Hugging Face model** — pick a task, choose a model, get results instantly.
"""
)
# ── Task & Model Selection ──
with gr.Row():
task_dropdown = gr.Dropdown(
choices=list(TASK_CONFIG.keys()),
value="text-generation",
label="🎯 Task",
scale=2,
)
model_dropdown = gr.Dropdown(
choices=get_models_for_task("text-generation"),
value=TASK_CONFIG["text-generation"]["default"],
label="πŸ€– Model (type any model ID from the Hub)",
allow_custom_value=True,
scale=3,
)
    task_description = gr.Markdown(f"📋 *{TASK_CONFIG['text-generation']['description']}*")
# ── Input ──
text_input = gr.Textbox(
label="πŸ“ Input",
placeholder=TASK_CONFIG["text-generation"]["placeholder"],
lines=4,
)
# ── Conditional inputs ──
context_input = gr.Textbox(
label="πŸ“– Context (required for Question Answering)",
placeholder="Paste the context passage here. The model will find the answer within this text.",
lines=3,
visible=False,
)
labels_input = gr.Textbox(
label="🏷️ Candidate Labels (required for Zero-Shot, comma-separated)",
placeholder="technology, sports, politics, science, entertainment",
visible=False,
)
# ── Advanced Options ──
with gr.Accordion("βš™οΈ Advanced Options", open=False):
with gr.Row():
max_tokens = gr.Slider(
minimum=10,
maximum=512,
value=100,
step=10,
label="Max New Tokens",
)
temperature = gr.Slider(
minimum=0.1,
maximum=2.0,
value=0.7,
step=0.1,
label="Temperature (creativity)",
)
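    # Note: the max_tokens slider doubles as max_new_tokens for generation and
    # as max_length for summarization/translation, so treat it as approximate.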
# ── Buttons ──
with gr.Row():
        run_btn = gr.Button("▶️ Run Model", variant="primary", scale=3)
        clear_btn = gr.Button("🗑️ Clear", variant="secondary", scale=1)
# ── Output ──
output_box = gr.Textbox(
label="πŸ“€ Output",
interactive=False,
lines=8,
show_copy_button=True,
elem_classes=["output-box"],
)
status_md = gr.Markdown("")
# ── GPU Info ──
gpu_info = "πŸ–₯️ GPU: " + (
f"**{torch.cuda.get_device_name(0)}** ({torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB)"
if torch.cuda.is_available()
else "**CPU only** β€” models will run slower. Duplicate this Space with GPU for faster inference."
)
gr.Markdown(gpu_info)
# ── Wire events ──
task_dropdown.change(
fn=on_task_change,
inputs=[task_dropdown],
outputs=[model_dropdown, text_input, context_input, labels_input, task_description],
)
gr.on(
triggers=[run_btn.click, text_input.submit],
fn=run_inference,
inputs=[
task_dropdown,
model_dropdown,
text_input,
context_input,
labels_input,
max_tokens,
temperature,
],
outputs=[output_box, status_md],
)
clear_btn.click(
fn=lambda: ("", "", ""),
outputs=[text_input, output_box, status_md],
)
# ── Examples ──
gr.Markdown("### πŸ’‘ Quick Examples")
gr.Examples(
examples=[
["text-generation", "openai-community/gpt2", "The future of artificial intelligence is", "", "", 100, 0.7],
["text-classification", "distilbert/distilbert-base-uncased-finetuned-sst-2-english", "I had a wonderful experience at this restaurant!", "", "", 100, 0.7],
["summarization", "facebook/bart-large-cnn", "Artificial intelligence (AI) is intelligence demonstrated by machines, as opposed to natural intelligence displayed by animals including humans. AI research has been defined as the field of study of intelligent agents, which refers to any system that perceives its environment and takes actions that maximize its chance of achieving its goals.", "", "", 150, 0.7],
["token-classification", "dslim/bert-base-NER", "Elon Musk founded SpaceX in Hawthorne, California in 2002.", "", "", 100, 0.7],
["zero-shot-classification", "facebook/bart-large-mnli", "The stock market saw a significant rally today.", "", "finance, sports, politics, technology", 100, 0.7],
],
inputs=[task_dropdown, model_dropdown, text_input, context_input, labels_input, max_tokens, temperature],
outputs=[output_box, status_md],
fn=run_inference,
cache_examples=False,
)
if __name__ == "__main__":
demo.queue(max_size=20).launch()
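# To run outside of Spaces (assuming a recent gradio/torch/transformers stack):
#     pip install gradio torch transformers huggingface_hub
#     python app.py
# queue(max_size=20) rejects new requests once 20 are pending, which protects
# the single cached model from being overwhelmed.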