import gradio as gr
import torch
from transformers import pipeline as hf_pipeline
from huggingface_hub import HfApi

# ═══════════════════════════════════════════════════════════════
# Task configuration: default model, description, and input placeholder per task
# ═══════════════════════════════════════════════════════════════
TASK_CONFIG = {
    "text-generation": {
        "default": "openai-community/gpt2",
        "description": "Generate text from a prompt",
        "placeholder": "Once upon a time in a land far away...",
    },
    "text-classification": {
        "default": "distilbert/distilbert-base-uncased-finetuned-sst-2-english",
        "description": "Classify text into categories (sentiment, topic, etc.)",
        "placeholder": "I absolutely loved this movie! The acting was superb.",
    },
    "summarization": {
        "default": "facebook/bart-large-cnn",
        "description": "Summarize long text into a concise version",
        "placeholder": "Paste a long article or paragraph here to get a summary...",
    },
    "translation_en_to_fr": {
        "default": "Helsinki-NLP/opus-mt-en-fr",
        "description": "Translate English text to French",
        "placeholder": "Hello, how are you today?",
    },
    "question-answering": {
        "default": "deepset/roberta-base-squad2",
        "description": "Answer a question given a context passage",
        "placeholder": "What is the capital of France?",
    },
    "zero-shot-classification": {
        "default": "facebook/bart-large-mnli",
        "description": "Classify text into custom categories (no training needed)",
        "placeholder": "The new iPhone has an amazing camera and great battery life.",
    },
    "token-classification": {
        "default": "dslim/bert-base-NER",
        "description": "Named Entity Recognition β€” find people, places, organizations",
        "placeholder": "Elon Musk founded SpaceX in Hawthorne, California.",
    },
}
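
# Extending the table is a one-entry change. A hypothetical sketch, kept as a
# comment (the model ID and placeholder are illustrative, not verified here):
#
# TASK_CONFIG["fill-mask"] = {
#     "default": "google-bert/bert-base-uncased",
#     "description": "Predict the masked word in a sentence",
#     "placeholder": "Paris is the [MASK] of France.",
# }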

# ═══════════════════════════════════════════════════════════════
# Pipeline cache: avoids reloading the same model
# ═══════════════════════════════════════════════════════════════
_cache = {}


def get_pipeline(task, model_id):
    """Load or retrieve a cached pipeline."""
    key = f"{task}::{model_id}"
    if key not in _cache:
        # Clear old cache to save memory (keep only 1 model loaded)
        _cache.clear()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        device = 0 if torch.cuda.is_available() else -1
        kwargs = {"task": task, "model": model_id, "device": device}

        # Use bfloat16 on GPU for memory efficiency
        if torch.cuda.is_available():
            kwargs["torch_dtype"] = torch.bfloat16

        _cache[key] = hf_pipeline(**kwargs)
    return _cache[key]
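
# Usage sketch, kept as a comment so nothing loads at import time (the example
# output is illustrative):
#
#   pipe = get_pipeline("text-classification",
#                       TASK_CONFIG["text-classification"]["default"])
#   pipe("Great movie!")  # e.g. [{'label': 'POSITIVE', 'score': 0.99...}]
#
# A second get_pipeline() call with the same task/model hits _cache and skips
# the reload; a different model evicts the previous one first.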


# ═══════════════════════════════════════════════════════════════
# Fetch top models for a task from the Hub
# ═══════════════════════════════════════════════════════════════
_model_cache = {}


def get_models_for_task(task):
    """Fetch popular models for a given task from the Hub."""
    if task in _model_cache:
        return _model_cache[task]
    try:
        api = HfApi()
        # The Hub tags translation models with the generic "translation"
        # pipeline tag, not pair-specific tags like "translation_en_to_fr"
        # (so results may include other language pairs)
        hub_tag = "translation" if task.startswith("translation") else task
        models = list(api.list_models(pipeline_tag=hub_tag, sort="downloads", limit=15))
        model_ids = [m.id for m in models]
        # Ensure the default model is always first (and skip an empty default)
        default = TASK_CONFIG.get(task, {}).get("default", "")
        if default:
            if default in model_ids:
                model_ids.remove(default)
            model_ids.insert(0, default)
        _model_cache[task] = model_ids
        return model_ids
    except Exception:
        default = TASK_CONFIG.get(task, {}).get("default", "gpt2")
        return [default]
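
# Sketch of the expected behavior (illustrative, not run at import time):
#
#   get_models_for_task("summarization")
#   # -> ["facebook/bart-large-cnn", ...up to 15 popular Hub models]
#
# If the Hub is unreachable, the dropdown still gets a one-item fallback list.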


# ═══════════════════════════════════════════════════════════════
# Main inference function
# ═══════════════════════════════════════════════════════════════
def run_inference(task, model_id, text_input, context, labels, max_tokens, temperature):
    """Run inference on the selected task and model."""
    if not text_input or not text_input.strip():
        return "⚠️ Please enter some input text.", ""

    if not model_id or not model_id.strip():
        return "⚠️ Please select or enter a model ID.", ""

    try:
        status = f"⏳ Loading **{model_id}** for `{task}`..."
        pipe = get_pipeline(task, model_id.strip())
        status = f"βœ… Running **{model_id}** β€” `{task}`"

        # ── Text Generation ──
        if task == "text-generation":
            result = pipe(
                text_input,
                max_new_tokens=int(max_tokens),
                temperature=float(temperature),
                do_sample=True,
            )
            output = result[0]["generated_text"]

        # ── Text Classification ──
        elif task == "text-classification":
            # top_k=None returns scores for every label, not just the top one,
            # so the bars below show the full distribution
            result = pipe(text_input, top_k=None)
            lines = []
            for r in result:
                bar = "█" * int(r["score"] * 30)
                lines.append(f"**{r['label']}**: {r['score']:.4f}  {bar}")
            output = "\n".join(lines)

        # ── Summarization ──
        elif task == "summarization":
            # Clamp min_length so it can't exceed max_length at small slider values
            result = pipe(
                text_input,
                max_length=int(max_tokens),
                min_length=min(30, int(max_tokens) // 2),
            )
            output = result[0]["summary_text"]

        # ── Translation ──
        elif "translation" in task:
            result = pipe(text_input, max_length=int(max_tokens))
            output = result[0]["translation_text"]

        # ── Question Answering ──
        elif task == "question-answering":
            if not context or not context.strip():
                return "⚠️ Question Answering requires a **Context** passage. Please fill it in below.", ""
            result = pipe(question=text_input, context=context)
            output = f"**Answer:** {result['answer']}\n**Confidence:** {result['score']:.4f}"

        # ── Zero-Shot Classification ──
        elif task == "zero-shot-classification":
            if not labels or not labels.strip():
                return "⚠️ Zero-Shot Classification requires **Candidate Labels**. Enter them comma-separated below.", ""
            candidate_labels = [l.strip() for l in labels.split(",") if l.strip()]
            if len(candidate_labels) < 2:
                return "⚠️ Please enter at least 2 comma-separated labels.", ""
            result = pipe(text_input, candidate_labels=candidate_labels)
            lines = []
            for label, score in zip(result["labels"], result["scores"]):
                bar = "β–ˆ" * int(score * 30)
                lines.append(f"**{label}**: {score:.4f}  {bar}")
            output = "\n".join(lines)

        # ── Token Classification (NER) ──
        elif task == "token-classification":
            result = pipe(text_input)
            if not result:
                output = "No entities found."
            else:
                lines = []
                for ent in result:
                    lines.append(f"🏷️ **{ent['word']}** β†’ `{ent['entity']}` (score: {ent['score']:.3f})")
                output = "\n".join(lines)

        else:
            result = pipe(text_input)
            output = str(result)

        return output, status

    except Exception as e:
        error_msg = str(e)
        if "gated" in error_msg.lower() or "access" in error_msg.lower():
            return (
                f"πŸ”’ **Gated Model**: `{model_id}` requires access approval.\n\n"
                f"1. Go to [https://huggingface.co/{model_id}](https://huggingface.co/{model_id})\n"
                f"2. Accept the license agreement\n"
                f"3. Add your HF_TOKEN to this Space's secrets\n\n"
                f"Error: {error_msg}",
                "❌ Access denied",
            )
        return f"❌ **Error**: {error_msg}", "❌ Failed"


# ═══════════════════════════════════════════════════════════════
# UI Event Handlers
# ═══════════════════════════════════════════════════════════════
def on_task_change(task):
    """When task changes: update model list, placeholder, and show/hide extra fields."""
    models = get_models_for_task(task)
    config = TASK_CONFIG.get(task, {})
    placeholder = config.get("placeholder", "Enter text here...")
    description = config.get("description", "")

    show_context = task == "question-answering"
    show_labels = task == "zero-shot-classification"

    return (
        gr.Dropdown(choices=models, value=models[0] if models else ""),
        gr.Textbox(placeholder=placeholder),
        gr.Textbox(visible=show_context),
        gr.Textbox(visible=show_labels),
        f"πŸ“‹ *{description}*",
    )
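
# Sketch of one concrete update (illustrative): switching to
# "question-answering" returns, in output order,
#   (Dropdown(choices=[QA models...], value="deepset/roberta-base-squad2"),
#    Textbox(placeholder="What is the capital of France?"),
#    Textbox(visible=True),    # context box appears
#    Textbox(visible=False),   # labels box stays hidden
#    "📋 *Answer a question given a context passage*")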


# ═══════════════════════════════════════════════════════════════
# Build the Gradio UI
# ═══════════════════════════════════════════════════════════════
css = """
.gradio-container { max-width: 960px !important; margin: 0 auto !important; }
.output-box { min-height: 120px; }
"""

with gr.Blocks(css=css, title="🚀 HF Model Runner", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🚀 HF Model Runner
        **Run any Hugging Face model**: pick a task, choose a model, get results instantly.
        """
    )

    # ── Task & Model Selection ──
    with gr.Row():
        task_dropdown = gr.Dropdown(
            choices=list(TASK_CONFIG.keys()),
            value="text-generation",
            label="🎯 Task",
            scale=2,
        )
        model_dropdown = gr.Dropdown(
            choices=get_models_for_task("text-generation"),
            value=TASK_CONFIG["text-generation"]["default"],
            label="πŸ€– Model (type any model ID from the Hub)",
            allow_custom_value=True,
            scale=3,
        )

    task_description = gr.Markdown(f"📋 *{TASK_CONFIG['text-generation']['description']}*")

    # ── Input ──
    text_input = gr.Textbox(
        label="πŸ“ Input",
        placeholder=TASK_CONFIG["text-generation"]["placeholder"],
        lines=4,
    )

    # ── Conditional inputs ──
    context_input = gr.Textbox(
        label="πŸ“– Context (required for Question Answering)",
        placeholder="Paste the context passage here. The model will find the answer within this text.",
        lines=3,
        visible=False,
    )
    labels_input = gr.Textbox(
        label="🏷️ Candidate Labels (required for Zero-Shot, comma-separated)",
        placeholder="technology, sports, politics, science, entertainment",
        visible=False,
    )

    # ── Advanced Options ──
    with gr.Accordion("βš™οΈ Advanced Options", open=False):
        with gr.Row():
            max_tokens = gr.Slider(
                minimum=10,
                maximum=512,
                value=100,
                step=10,
                label="Max New Tokens",
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.7,
                step=0.1,
                label="Temperature (creativity)",
            )

    # ── Buttons ──
    with gr.Row():
        run_btn = gr.Button("▢️ Run Model", variant="primary", scale=3)
        clear_btn = gr.Button("πŸ—‘οΈ Clear", variant="secondary", scale=1)

    # ── Output ──
    output_box = gr.Textbox(
        label="πŸ“€ Output",
        interactive=False,
        lines=8,
        show_copy_button=True,
        elem_classes=["output-box"],
    )
    status_md = gr.Markdown("")

    # ── GPU Info ──
    gpu_info = "πŸ–₯️ GPU: " + (
        f"**{torch.cuda.get_device_name(0)}** ({torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB)"
        if torch.cuda.is_available()
        else "**CPU only** β€” models will run slower. Duplicate this Space with GPU for faster inference."
    )
    gr.Markdown(gpu_info)

    # ── Wire events ──
    task_dropdown.change(
        fn=on_task_change,
        inputs=[task_dropdown],
        outputs=[model_dropdown, text_input, context_input, labels_input, task_description],
    )

    gr.on(
        triggers=[run_btn.click, text_input.submit],
        fn=run_inference,
        inputs=[
            task_dropdown,
            model_dropdown,
            text_input,
            context_input,
            labels_input,
            max_tokens,
            temperature,
        ],
        outputs=[output_box, status_md],
    )

    clear_btn.click(
        fn=lambda: ("", "", ""),
        outputs=[text_input, output_box, status_md],
    )

    # ── Examples ──
    gr.Markdown("### πŸ’‘ Quick Examples")
    gr.Examples(
        examples=[
            ["text-generation", "openai-community/gpt2", "The future of artificial intelligence is", "", "", 100, 0.7],
            ["text-classification", "distilbert/distilbert-base-uncased-finetuned-sst-2-english", "I had a wonderful experience at this restaurant!", "", "", 100, 0.7],
            ["summarization", "facebook/bart-large-cnn", "Artificial intelligence (AI) is intelligence demonstrated by machines, as opposed to natural intelligence displayed by animals including humans. AI research has been defined as the field of study of intelligent agents, which refers to any system that perceives its environment and takes actions that maximize its chance of achieving its goals.", "", "", 150, 0.7],
            ["token-classification", "dslim/bert-base-NER", "Elon Musk founded SpaceX in Hawthorne, California in 2002.", "", "", 100, 0.7],
            ["zero-shot-classification", "facebook/bart-large-mnli", "The stock market saw a significant rally today.", "", "finance, sports, politics, technology", 100, 0.7],
        ],
        inputs=[task_dropdown, model_dropdown, text_input, context_input, labels_input, max_tokens, temperature],
        outputs=[output_box, status_md],
        fn=run_inference,
        cache_examples=False,
    )


if __name__ == "__main__":
    demo.queue(max_size=20).launch()
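    # Local-debugging variant using standard gradio launch() parameters:
    #   demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=7860)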