File size: 8,044 Bytes
3ef75e6
6b7e476
 
 
 
 
3ef75e6
 
6b7e476
3ef75e6
6b7e476
 
3ef75e6
 
 
 
 
 
6b7e476
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ef75e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b7e476
 
 
 
 
 
 
 
 
 
3ef75e6
 
 
 
 
 
 
 
 
 
 
 
6b7e476
3ef75e6
 
 
 
 
 
 
 
 
 
 
 
6b7e476
 
3ef75e6
 
 
 
 
 
 
 
 
 
6b7e476
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ef75e6
6b7e476
3ef75e6
6b7e476
 
 
 
 
 
 
 
3ef75e6
 
6b7e476
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
import json
import time
from typing import List, Tuple

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hugging Face model identifier used for zero-shot classification.
MODEL_ID = "Qwen/Qwen3-0.6B"
# Inputs longer than this are truncated before inference (see truncate_text).
MAX_TEXT_CHARS = 4000

# Load tokenizer and model once at import time. device_map="auto" lets
# transformers place weights on the available device(s) automatically.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto",
)

# Ready-made label sets selectable from the UI dropdown; each value is the
# comma-separated string later parsed by normalize_labels().
PRESET_LABELS = {
    "Sentiment": "positive, negative, neutral",
    "Topic": "tech, business, sports, science, politics",
    "Intent": "question, statement, request, complaint",
    "Tone": "formal, casual, urgent, friendly",
}


def normalize_labels(raw_labels: str) -> List[str]:
    """Split a comma-separated label string into a de-duplicated list.

    Duplicates are detected case-insensitively; the first spelling seen wins
    and original order is preserved. Blank entries are dropped.
    """
    result: List[str] = []
    seen_keys = set()
    for candidate in raw_labels.split(","):
        candidate = candidate.strip()
        if not candidate:
            continue
        folded = candidate.lower()
        if folded in seen_keys:
            continue
        seen_keys.add(folded)
        result.append(candidate)
    return result


def truncate_text(text: str) -> Tuple[str, bool]:
    """Strip *text* and cap it at MAX_TEXT_CHARS.

    Returns (possibly-truncated text, was_truncated flag). None is treated
    as an empty string.
    """
    stripped = (text or "").strip()
    was_cut = len(stripped) > MAX_TEXT_CHARS
    if was_cut:
        return stripped[:MAX_TEXT_CHARS], True
    return stripped, False


def build_bar_chart(labels: List[str], scores: List[float]) -> str:
    """Render parallel label/score lists as an inline-styled HTML bar chart.

    Bar widths are scaled relative to the highest score; the percentage text
    shows the absolute score. Returns a placeholder paragraph when *labels*
    is empty.
    """
    if not labels:
        return "<p>No labels to display.</p>"

    top = max(scores) if scores else 1.0

    def render_row(label: str, score: float) -> str:
        # One bar: header line (label + percent), then the filled bar itself.
        pct = score * 100
        width = (score / top) * 100 if top > 0 else 0
        return f"""
            <div style='margin: 8px 0;'>
              <div style='display:flex; justify-content:space-between; font-size:14px;'>
                <span><b>{label}</b></span>
                <span>{pct:.2f}%</span>
              </div>
              <div style='background:#e5e7eb; border-radius:8px; overflow:hidden;'>
                <div style='width:{width:.2f}%; background:#2563eb; color:white; padding:4px 8px; font-size:12px;'>
                  {pct:.2f}%
                </div>
              </div>
            </div>
            """

    body = "\n".join(render_row(lbl, val) for lbl, val in zip(labels, scores))
    return "<div>" + body + "</div>"


def apply_preset(preset_name: str) -> str:
    """Return the comma-separated label string for *preset_name*.

    Unknown preset names yield an empty string.
    """
    try:
        return PRESET_LABELS[preset_name]
    except KeyError:
        return ""


def build_classification_prompt(text: str, labels: List[str], multi_label: bool) -> str:
    """Build the instruction prompt asking the model for JSON label scores.

    Args:
        text: The (already truncated) text to classify.
        labels: Candidate label names.
        multi_label: If True, labels may overlap; otherwise scores must sum to 1.

    Returns:
        A prompt string instructing the model to emit only a JSON object
        mapping each label to a confidence score.
    """
    labels_str = ", ".join(f'"{l}"' for l in labels)
    mode_instruction = (
        "Multiple labels can apply simultaneously. For each label, assign a confidence score between 0 and 1."
        if multi_label
        else "Choose the single best label. Assign confidence scores that sum to 1."
    )

    # Build the JSON example outside the f-string: a backslash-escaped quote
    # inside an f-string replacement field is a SyntaxError on Python < 3.12
    # (PEP 701 lifted that restriction only in 3.12).
    example = ", ".join('"{}": 0.5'.format(l) for l in labels[:2])

    return (
        f"Classify the following text into these categories: {labels_str}\n\n"
        f"{mode_instruction}\n\n"
        f"Text: \"{text}\"\n\n"
        f"Respond with ONLY a JSON object mapping each label to its confidence score. "
        f"Example: {{{example}}}\n"
        f"JSON:"
    )


def parse_scores(output: str, labels: List[str]) -> dict:
    """Extract per-label confidence scores from raw model output.

    Looks for the outermost ``{...}`` span in *output* and parses it as JSON.
    Labels are matched exactly first, then case-insensitively; labels the
    model omitted get 0.0. If no parseable JSON object with numeric values is
    found, falls back to a uniform distribution over *labels*.

    Args:
        output: Decoded model generation (may contain text around the JSON).
        labels: Candidate label names to score.

    Returns:
        Mapping of each label in *labels* to a non-negative float score.
    """
    output = output.strip()

    # Isolate the outermost JSON object: first '{' through last '}'.
    start = output.find("{")
    end = output.rfind("}")
    if start != -1 and end > start:
        try:
            parsed = json.loads(output[start : end + 1])
            lower_map = {k.lower(): v for k, v in parsed.items()}
            scores = {}
            for label in labels:
                value = parsed.get(label, lower_map.get(label.lower(), 0.0))
                # Clamp negatives: a negative score would corrupt the
                # caller's sum-normalization.
                scores[label] = max(0.0, float(value))
            return scores
        # TypeError: non-numeric JSON values (lists/objects/null).
        # AttributeError: top-level JSON array instead of an object.
        except (json.JSONDecodeError, ValueError, TypeError, AttributeError):
            pass

    # Fallback: uniform scores (guarding against an empty label list).
    count = len(labels)
    return {label: 1.0 / count for label in labels} if count else {}


@spaces.GPU
def run_classification(text: str, candidate_labels: str, multi_label: bool):
    """Classify *text* against *candidate_labels*; return (chart_html, summary).

    Raises gr.Error for empty text or fewer than two labels.
    """
    cleaned, truncated = truncate_text(text)
    label_list = normalize_labels(candidate_labels)

    if not cleaned:
        raise gr.Error("Please enter text to classify.")
    if len(label_list) < 2:
        raise gr.Error("Please provide at least 2 labels, separated by commas.")

    user_prompt = build_classification_prompt(cleaned, label_list, multi_label)
    chat = [
        {"role": "system", "content": "You are a precise text classifier. Respond only with valid JSON."},
        {"role": "user", "content": user_prompt},
    ]
    rendered = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True,
        enable_thinking=False,
    )

    t0 = time.perf_counter()

    encoded = tokenizer(rendered, return_tensors="pt").to(model.device)
    with torch.no_grad():
        generation = model.generate(
            **encoded,
            max_new_tokens=256,
            temperature=0.1,
            do_sample=True,
            top_p=0.9,
        )

    # Decode only the newly generated tokens (strip the prompt prefix).
    prompt_len = encoded["input_ids"].shape[1]
    completion = tokenizer.decode(generation[0][prompt_len:], skip_special_tokens=True)
    elapsed = time.perf_counter() - t0

    raw_scores = parse_scores(completion, label_list)

    # Renormalize so scores sum to 1; uniform fallback when everything is 0.
    total = sum(raw_scores.values())
    if total > 0:
        normalized = {name: value / total for name, value in raw_scores.items()}
    else:
        normalized = {name: 1.0 / len(label_list) for name in label_list}

    ranked = sorted(normalized.items(), key=lambda pair: pair[1], reverse=True)
    ranked_names = [name for name, _ in ranked]
    ranked_values = [value for _, value in ranked]

    chart_html = build_bar_chart(ranked_names, ranked_values)

    truncation_note = (
        f" Input was truncated to {MAX_TEXT_CHARS} characters for stable inference."
        if truncated
        else ""
    )

    summary = (
        f"Top prediction: {ranked_names[0]} ({ranked_values[0] * 100:.2f}%). "
        f"Model: Qwen3-0.6B. "
        f"Mode: {'multi-label' if multi_label else 'single-label'}. "
        f"Inference time: {elapsed:.3f}s.{truncation_note}"
    )

    return chart_html, summary


# --- Gradio UI wiring ------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(), title="Zero-Shot Text Classifier") as demo:
    gr.Markdown("# Zero-Shot Text Classifier")
    gr.Markdown(
        "Classify any text into custom labels using **Qwen3-0.6B** with zero-shot instruction prompting. "
        "No fine-tuning needed: define your own categories and classify instantly."
    )

    # Preset picker: the button copies a predefined label set into the
    # candidate-labels textbox via apply_preset().
    with gr.Row():
        preset = gr.Dropdown(
            choices=list(PRESET_LABELS.keys()),
            value="Sentiment",
            label="Preset Label Sets",
            info="Pick a preset to auto-fill candidate labels.",
        )
        apply_btn = gr.Button("Apply Preset")

    text_input = gr.Textbox(
        label="Text to Classify",
        placeholder="Paste a sentence, paragraph, or document...",
        lines=8,
    )

    labels_input = gr.Textbox(
        label="Candidate Labels (comma-separated)",
        value=PRESET_LABELS["Sentiment"],
        placeholder="positive, negative, neutral",
        lines=2,
    )

    multi_label_input = gr.Checkbox(
        label="Multi-label mode",
        value=False,
        info="If enabled, multiple labels can be true at the same time.",
    )

    classify_btn = gr.Button("Classify", variant="primary")

    # Outputs: HTML bar chart of label scores plus a one-line text summary.
    chart_output = gr.HTML(label="Label Scores")
    summary_output = gr.Textbox(label="Summary", interactive=False)

    # Event handlers.
    apply_btn.click(apply_preset, inputs=preset, outputs=labels_input)
    classify_btn.click(
        run_classification,
        inputs=[text_input, labels_input, multi_label_input],
        outputs=[chart_output, summary_output],
    )

    gr.Markdown(
        "Built by [Xavier Fuentes](https://huggingface.co/xavier-fuentes) @ "
        "[AI Enablement Academy](https://enablement.academy) | "
        "[Buy me a coffee ☕](https://ko-fi.com/xavierfuentes)"
    )


if __name__ == "__main__":
    # queue() enables request queuing, needed for @spaces.GPU workloads.
    demo.queue().launch()