# Zero-Shot Text Classifier — Hugging Face Spaces app (Qwen3-0.6B + Gradio).
import html
import json
import time
from typing import List, Tuple

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Hub repo of the model; a small instruct model suitable for a Spaces GPU worker.
MODEL_ID = "Qwen/Qwen3-0.6B"
# Hard cap on input length (characters) to keep inference latency bounded.
MAX_TEXT_CHARS = 4000

# Loaded once at import time so every request reuses the same weights.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,  # half precision: ~half the memory of fp32
    device_map="auto",  # let accelerate place the model on the available device
)

# Ready-made label sets selectable from the UI dropdown (value is the
# comma-separated string fed straight into the labels textbox).
PRESET_LABELS = {
    "Sentiment": "positive, negative, neutral",
    "Topic": "tech, business, sports, science, politics",
    "Intent": "question, statement, request, complaint",
    "Tone": "formal, casual, urgent, friendly",
}
def normalize_labels(raw_labels: str) -> List[str]:
    """Split a comma-separated label string into a de-duplicated list.

    Duplicates are detected case-insensitively; the first spelling wins
    and the original order is preserved. Blank entries are dropped.
    """
    # dict preserves insertion order, so it doubles as an ordered set
    # keyed on the lowercased label.
    deduped = {}
    for part in raw_labels.split(","):
        candidate = part.strip()
        if candidate and candidate.lower() not in deduped:
            deduped[candidate.lower()] = candidate
    return list(deduped.values())
def truncate_text(text: str, max_chars: int = None) -> Tuple[str, bool]:
    """Strip *text* and cap its length.

    Args:
        text: Raw input; ``None`` is treated as empty.
        max_chars: Character cap; defaults to the module-level
            MAX_TEXT_CHARS (resolved at call time, so the parameter is a
            backward-compatible generalization of the old hard-coded cap).

    Returns:
        Tuple of (possibly shortened text, True if truncation occurred).
    """
    if max_chars is None:
        max_chars = MAX_TEXT_CHARS
    text = (text or "").strip()
    if len(text) <= max_chars:
        return text, False
    return text[:max_chars], True
def build_bar_chart(labels: List[str], scores: List[float]) -> str:
    """Render label/score pairs as an HTML horizontal bar chart.

    Labels are user-supplied free text, so they are HTML-escaped before
    interpolation into the markup (the original injected them verbatim,
    allowing markup/script injection into the gr.HTML output). Bars are
    scaled relative to the highest score so the top label fills the width.
    """
    if not labels:
        return "<p>No labels to display.</p>"
    rows = []
    max_score = max(scores) if scores else 1.0
    for label, score in zip(labels, scores):
        safe_label = html.escape(label)  # neutralize <, >, &, quotes
        pct = score * 100
        # Guard against an all-zero score vector (division by zero).
        width = (score / max_score) * 100 if max_score > 0 else 0
        rows.append(
            f"""
            <div style='margin: 8px 0;'>
                <div style='display:flex; justify-content:space-between; font-size:14px;'>
                    <span><b>{safe_label}</b></span>
                    <span>{pct:.2f}%</span>
                </div>
                <div style='background:#e5e7eb; border-radius:8px; overflow:hidden;'>
                    <div style='width:{width:.2f}%; background:#2563eb; color:white; padding:4px 8px; font-size:12px;'>
                        {pct:.2f}%
                    </div>
                </div>
            </div>
            """
        )
    return "<div>" + "\n".join(rows) + "</div>"
def apply_preset(preset_name: str) -> str:
    """Return the comma-separated label string for a preset, or "" if unknown."""
    try:
        return PRESET_LABELS[preset_name]
    except KeyError:
        return ""
def build_classification_prompt(text: str, labels: List[str], multi_label: bool) -> str:
    """Build the instruction prompt sent to the model.

    Args:
        text: Already-cleaned input text to classify.
        labels: Candidate label names (quoted into the prompt).
        multi_label: Selects independent-scores vs. sum-to-1 instructions.

    Returns:
        A prompt asking the model to answer with a JSON score object only.
    """
    labels_str = ", ".join(f'"{l}"' for l in labels)
    mode_instruction = (
        "Multiple labels can apply simultaneously. For each label, assign a confidence score between 0 and 1."
        if multi_label
        else "Choose the single best label. Assign confidence scores that sum to 1."
    )
    # Build the JSON example outside the f-string: the original embedded
    # backslash escapes (\") inside an f-string replacement field, which is
    # a SyntaxError on every Python before 3.12 (PEP 701 lifted that rule).
    example = ", ".join('"{}": 0.5'.format(l) for l in labels[:2])
    return (
        f"Classify the following text into these categories: {labels_str}\n\n"
        f"{mode_instruction}\n\n"
        f"Text: \"{text}\"\n\n"
        f"Respond with ONLY a JSON object mapping each label to its confidence score. "
        f"Example: {{{example}}}\n"
        f"JSON:"
    )
def parse_scores(output: str, labels: List[str]) -> dict:
    """Extract per-label confidence scores from raw model output.

    Finds the outermost ``{...}`` span in *output* and parses it as JSON.
    Labels are matched exactly first, then case-insensitively; labels the
    model omitted score 0.0. If no JSON can be recovered — or a value is
    not numeric — falls back to a uniform distribution over *labels*.
    """
    if not labels:
        # Avoid ZeroDivisionError in the uniform fallback.
        return {}
    output = output.strip()
    start = output.find("{")
    end = output.rfind("}")
    if start != -1 and end > start:
        try:
            parsed = json.loads(output[start : end + 1])
            # Build the case-insensitive map once, outside the loop (the
            # original rebuilt it for every label missing an exact match).
            lower_map = {k.lower(): v for k, v in parsed.items()}
            scores = {}
            for label in labels:
                value = parsed.get(label, lower_map.get(label.lower(), 0.0))
                scores[label] = float(value)
            return scores
        except (json.JSONDecodeError, ValueError, TypeError):
            # TypeError covers non-numeric JSON values (null, lists,
            # nested objects) that float() rejects — the original only
            # caught ValueError, so e.g. '{"a": null}' crashed the request.
            pass
    # Fallback: no parseable scores — spread confidence evenly.
    return {label: 1.0 / len(labels) for label in labels}
@spaces.GPU
def run_classification(text: str, candidate_labels: str, multi_label: bool):
    """Classify *text* against user-supplied labels via instruction prompting.

    Args:
        text: Raw input text; stripped and truncated to MAX_TEXT_CHARS.
        candidate_labels: Comma-separated label names (at least 2 required).
        multi_label: If True, each label is scored independently in [0, 1];
            otherwise scores are renormalized into a distribution.

    Returns:
        (chart_html, summary) tuple feeding the two Gradio outputs.

    Raises:
        gr.Error: On empty text or fewer than two labels.
    """
    clean_text, was_truncated = truncate_text(text)
    labels = normalize_labels(candidate_labels)
    if not clean_text:
        raise gr.Error("Please enter text to classify.")
    if len(labels) < 2:
        raise gr.Error("Please provide at least 2 labels, separated by commas.")
    prompt = build_classification_prompt(clean_text, labels, multi_label)
    messages = [
        {"role": "system", "content": "You are a precise text classifier. Respond only with valid JSON."},
        {"role": "user", "content": prompt},
    ]
    # enable_thinking=False stops Qwen3 from emitting a <think> block that
    # would pollute the JSON-only response parsed below.
    input_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True,
        enable_thinking=False,
    )
    start = time.perf_counter()
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            temperature=0.1,  # near-greedy: classification wants determinism
            do_sample=True,
            top_p=0.9,
        )
    # Decode only the newly generated tokens (skip the echoed prompt).
    generated = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    elapsed = time.perf_counter() - start
    scores = parse_scores(generated, labels)
    if multi_label:
        # FIX: the prompt asks for *independent* per-label confidences in
        # [0, 1], but the original renormalized both modes to sum to 1,
        # contradicting multi-label semantics. Clamp each score instead.
        scores = {k: min(max(v, 0.0), 1.0) for k, v in scores.items()}
    else:
        # Single-label mode: present scores as a probability distribution.
        total = sum(scores.values())
        if total > 0:
            scores = {k: v / total for k, v in scores.items()}
        else:
            # Degenerate all-zero output: fall back to a uniform spread.
            scores = {k: 1.0 / len(labels) for k in labels}
    sorted_pairs = sorted(scores.items(), key=lambda x: x[1], reverse=True)
    sorted_labels = [x[0] for x in sorted_pairs]
    sorted_scores = [x[1] for x in sorted_pairs]
    chart_html = build_bar_chart(sorted_labels, sorted_scores)
    top_label = sorted_labels[0]
    top_score = sorted_scores[0] * 100
    truncation_note = (
        f" Input was truncated to {MAX_TEXT_CHARS} characters for stable inference."
        if was_truncated
        else ""
    )
    summary = (
        f"Top prediction: {top_label} ({top_score:.2f}%). "
        f"Model: Qwen3-0.6B. "
        f"Mode: {'multi-label' if multi_label else 'single-label'}. "
        f"Inference time: {elapsed:.3f}s.{truncation_note}"
    )
    return chart_html, summary
# --- UI layout: declared at module import so Spaces can pick up `demo`. ---
with gr.Blocks(theme=gr.themes.Soft(), title="Zero-Shot Text Classifier") as demo:
    gr.Markdown("# Zero-Shot Text Classifier")
    gr.Markdown(
        "Classify any text into custom labels using **Qwen3-0.6B** with zero-shot instruction prompting. "
        "No fine-tuning needed: define your own categories and classify instantly."
    )
    # Preset picker row: choosing a preset and clicking Apply fills the
    # candidate-labels textbox below.
    with gr.Row():
        preset = gr.Dropdown(
            choices=list(PRESET_LABELS.keys()),
            value="Sentiment",
            label="Preset Label Sets",
            info="Pick a preset to auto-fill candidate labels.",
        )
        apply_btn = gr.Button("Apply Preset")
    text_input = gr.Textbox(
        label="Text to Classify",
        placeholder="Paste a sentence, paragraph, or document...",
        lines=8,
    )
    labels_input = gr.Textbox(
        label="Candidate Labels (comma-separated)",
        value=PRESET_LABELS["Sentiment"],
        placeholder="positive, negative, neutral",
        lines=2,
    )
    multi_label_input = gr.Checkbox(
        label="Multi-label mode",
        value=False,
        info="If enabled, multiple labels can be true at the same time.",
    )
    classify_btn = gr.Button("Classify", variant="primary")
    chart_output = gr.HTML(label="Label Scores")
    summary_output = gr.Textbox(label="Summary", interactive=False)
    # Event wiring: preset auto-fill, then the main classification action.
    apply_btn.click(apply_preset, inputs=preset, outputs=labels_input)
    classify_btn.click(
        run_classification,
        inputs=[text_input, labels_input, multi_label_input],
        outputs=[chart_output, summary_output],
    )
    gr.Markdown(
        "Built by [Xavier Fuentes](https://huggingface.co/xavier-fuentes) @ "
        "[AI Enablement Academy](https://enablement.academy) | "
        "[Buy me a coffee ☕](https://ko-fi.com/xavierfuentes)"
    )

if __name__ == "__main__":
    # queue() enables request queuing, needed for the @spaces.GPU worker.
    demo.queue().launch()