# text-classifier / app.py
# Uploaded by xavier-fuentes ("Upload folder using huggingface_hub"),
# commit 3ef75e6 (verified).
import html
import json
import math
import time
from typing import List, Tuple

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Hub identifier of the causal language model used as the classifier.
MODEL_ID = "Qwen/Qwen3-0.6B"
# Upper bound, in characters, on the text that truncate_text() lets through.
MAX_TEXT_CHARS = 4000

# The tokenizer and model are loaded once, at import time, and are then used
# as module-level globals by run_classification().
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,  # load the weights in half precision
    device_map="auto",  # leave device placement to the library
)
# Preset name -> comma-separated candidate labels.
# The keys populate the "Preset Label Sets" dropdown; the value for the chosen
# key is what apply_preset() returns to fill the candidate-labels textbox.
PRESET_LABELS = {
    "Sentiment": "positive, negative, neutral",
    "Topic": "tech, business, sports, science, politics",
    "Intent": "question, statement, request, complaint",
    "Tone": "formal, casual, urgent, friendly",
}
def normalize_labels(raw_labels: str) -> List[str]:
    """Split a comma-separated label string into a clean, de-duplicated list.

    Each piece is stripped of surrounding whitespace and empty pieces are
    dropped. Duplicates are detected case-insensitively; the first spelling
    encountered is the one that is kept, and the original order is preserved.

    Args:
        raw_labels: Comma-separated labels, e.g. ``"positive, negative"``.

    Returns:
        The unique, non-empty labels in order of first appearance.
    """
    # Maps lower-cased label -> first spelling seen. Dicts keep insertion
    # order, so the values come back in order of first appearance.
    first_spelling = {}
    for piece in raw_labels.split(","):
        candidate = piece.strip()
        if candidate:
            first_spelling.setdefault(candidate.lower(), candidate)
    return list(first_spelling.values())
def truncate_text(text: str) -> Tuple[str, bool]:
    """Strip the input and cap it at ``MAX_TEXT_CHARS`` characters.

    Args:
        text: Raw user text; a falsy value (``None`` or ``""``) is treated as
            an empty string.

    Returns:
        A ``(text, was_truncated)`` pair: the stripped and possibly shortened
        text, and whether any characters were cut off.
    """
    cleaned = text.strip() if text else ""
    was_truncated = len(cleaned) > MAX_TEXT_CHARS
    if was_truncated:
        cleaned = cleaned[:MAX_TEXT_CHARS]
    return cleaned, was_truncated
def build_bar_chart(labels: List[str], scores: List[float]) -> str:
    """Render label scores as an HTML horizontal bar chart.

    Each row shows the label, its score as a percentage, and a bar whose
    width is scaled relative to the highest score (so the best label fills
    the whole track).

    Args:
        labels: Label names; they are HTML-escaped before being embedded
            because they originate from free-form user input.
        scores: Scores paired positionally with ``labels`` (fractions, where
            1.0 is rendered as 100%).

    Returns:
        An HTML fragment, or a short placeholder paragraph when ``labels``
        is empty.
    """
    if not labels:
        return "<p>No labels to display.</p>"

    max_score = max(scores) if scores else 1.0
    rows = []
    for label, score in zip(labels, scores):
        pct = score * 100
        width = (score / max_score) * 100 if max_score > 0 else 0
        # Keep the CSS width valid even if a score is negative.
        width = min(max(width, 0), 100)
        # Escape the label so user-provided text cannot inject markup.
        safe_label = html.escape(str(label))
        rows.append(
            f"""
<div style='margin: 8px 0;'>
  <div style='display:flex; justify-content:space-between; font-size:14px;'>
    <span><b>{safe_label}</b></span>
    <span>{pct:.2f}%</span>
  </div>
  <div style='background:#e5e7eb; border-radius:8px; overflow:hidden;'>
    <div style='width:{width:.2f}%; background:#2563eb; color:white; padding:4px 8px; font-size:12px;'>
      {pct:.2f}%
    </div>
  </div>
</div>
"""
        )
    return "<div>" + "\n".join(rows) + "</div>"
def apply_preset(preset_name: str) -> str:
    """Return the comma-separated labels registered for a preset.

    Args:
        preset_name: A key of ``PRESET_LABELS``.

    Returns:
        The preset's label string, or an empty string for an unknown preset.
    """
    try:
        return PRESET_LABELS[preset_name]
    except KeyError:
        return ""
def build_classification_prompt(text: str, labels: List[str], multi_label: bool) -> str:
    """Build the user prompt that asks the model to score ``text`` per label.

    Args:
        text: The text to classify; embedded verbatim between double quotes.
        labels: Candidate labels; each is shown double-quoted in the prompt.
        multi_label: When true the instruction allows several labels to
            apply at once; otherwise it asks for one best label with scores
            summing to 1.

    Returns:
        The prompt, ending in ``"JSON:"`` so the model continues with a JSON
        object mapping labels to confidence scores.
    """
    labels_str = ", ".join(f'"{label}"' for label in labels)
    mode_instruction = (
        "Multiple labels can apply simultaneously. For each label, assign a confidence score between 0 and 1."
        if multi_label
        else "Choose the single best label. Assign confidence scores that sum to 1."
    )
    # Built outside the f-string below: backslashes inside an f-string
    # replacement field are a SyntaxError before Python 3.12.
    example = ", ".join(f'"{label}": 0.5' for label in labels[:2])
    return (
        f"Classify the following text into these categories: {labels_str}\n\n"
        f"{mode_instruction}\n\n"
        f"Text: \"{text}\"\n\n"
        f"Respond with ONLY a JSON object mapping each label to its confidence score. "
        f"Example: {{{example}}}\n"
        f"JSON:"
    )
def parse_scores(output: str, labels: List[str]) -> dict:
    """Extract per-label scores from the model's raw text output.

    The span from the first ``{`` to the last ``}`` is parsed as JSON. Each
    label is looked up by exact key first and case-insensitively second; a
    label missing from the JSON scores 0.0.

    Args:
        output: Text generated by the model.
        labels: Labels to extract scores for.

    Returns:
        A dict mapping every label to a float score. If no JSON object can
        be found or parsed, or any requested score is not a finite number,
        every label gets the same score ``1 / len(labels)``.
    """
    output = output.strip()
    start = output.find("{")
    end = output.rfind("}")
    if start != -1 and end > start:
        try:
            parsed = json.loads(output[start : end + 1])
            # Built once, not once per label, for the case-insensitive lookup.
            lower_map = {key.lower(): value for key, value in parsed.items()}
            scores = {}
            for label in labels:
                if label in parsed:
                    raw_value = parsed[label]
                else:
                    raw_value = lower_map.get(label.lower(), 0.0)
                value = float(raw_value)
                # json.loads accepts NaN/Infinity; such values would corrupt
                # any later normalisation, so treat them as unparseable.
                if not math.isfinite(value):
                    raise ValueError(f"non-finite score for label {label!r}")
                scores[label] = value
            return scores
        except (json.JSONDecodeError, TypeError, ValueError, OverflowError):
            # Deliberate best effort: malformed JSON, null/list/dict values
            # (TypeError) or oversized ints (OverflowError) all fall through
            # to the uniform fallback below instead of crashing the request.
            pass
    # Fallback: equal scores
    return {label: 1.0 / len(labels) for label in labels}
@spaces.GPU
def run_classification(text: str, candidate_labels: str, multi_label: bool):
    """Classify ``text`` against user-supplied labels with the loaded model.

    The text is truncated, the labels are normalised, a chat prompt is built
    and sent to the model, and the JSON scores in the reply are parsed and
    rendered.

    Args:
        text: Raw text from the UI.
        candidate_labels: Comma-separated label string from the UI.
        multi_label: When false, scores are normalised to sum to 1. When
            true, each score is kept as an independent confidence capped to
            the 0-1 range, matching what the multi-label prompt asks for.

    Returns:
        A ``(chart_html, summary)`` tuple: the HTML bar chart of the scores
        sorted in descending order, and a one-line text summary.

    Raises:
        gr.Error: If the text is empty or fewer than two labels are given.
    """
    clean_text, was_truncated = truncate_text(text)
    labels = normalize_labels(candidate_labels)
    if not clean_text:
        raise gr.Error("Please enter text to classify.")
    if len(labels) < 2:
        raise gr.Error("Please provide at least 2 labels, separated by commas.")

    prompt = build_classification_prompt(clean_text, labels, multi_label)
    messages = [
        {"role": "system", "content": "You are a precise text classifier. Respond only with valid JSON."},
        {"role": "user", "content": prompt},
    ]
    input_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True,
        enable_thinking=False,
    )

    # The timer covers tokenisation, generation and decoding.
    start = time.perf_counter()
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            temperature=0.1,
            do_sample=True,
            top_p=0.9,
        )
    # Decode only the newly generated tokens, skipping the echoed prompt.
    prompt_length = inputs["input_ids"].shape[1]
    generated = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    elapsed = time.perf_counter() - start

    scores = parse_scores(generated, labels)

    # Negative, NaN or infinite scores are meaningless as confidences and
    # would produce nonsensical percentages; count them as zero. (The chained
    # comparison is false for NaN as well as for out-of-range values.)
    scores = {
        label: score if 0.0 < score < float("inf") else 0.0
        for label, score in scores.items()
    }
    if multi_label:
        # Labels are independent here, so they must not be forced to compete
        # for a shared total of 1; only cap each confidence at 100%.
        scores = {label: min(score, 1.0) for label, score in scores.items()}
    else:
        # Single-label scores are presented as a distribution summing to 1.
        total = sum(scores.values())
        if total > 0:
            scores = {k: v / total for k, v in scores.items()}
        else:
            scores = {k: 1.0 / len(labels) for k in labels}

    sorted_pairs = sorted(scores.items(), key=lambda x: x[1], reverse=True)
    sorted_labels = [x[0] for x in sorted_pairs]
    sorted_scores = [x[1] for x in sorted_pairs]
    chart_html = build_bar_chart(sorted_labels, sorted_scores)

    top_label = sorted_labels[0]
    top_score = sorted_scores[0] * 100
    truncation_note = (
        f" Input was truncated to {MAX_TEXT_CHARS} characters for stable inference."
        if was_truncated
        else ""
    )
    summary = (
        f"Top prediction: {top_label} ({top_score:.2f}%). "
        f"Model: Qwen3-0.6B. "
        f"Mode: {'multi-label' if multi_label else 'single-label'}. "
        f"Inference time: {elapsed:.3f}s.{truncation_note}"
    )
    return chart_html, summary
# Gradio UI definition. Components are created in the order they should
# appear on the page, and the event wiring lives inside the same Blocks
# context.
with gr.Blocks(theme=gr.themes.Soft(), title="Zero-Shot Text Classifier") as demo:
    gr.Markdown("# Zero-Shot Text Classifier")
    gr.Markdown(
        "Classify any text into custom labels using **Qwen3-0.6B** with zero-shot instruction prompting. "
        "No fine-tuning needed: define your own categories and classify instantly."
    )
    # Preset picker and its apply button.
    # NOTE(review): the Row is assumed to hold only these two components --
    # confirm against the deployed layout.
    with gr.Row():
        preset = gr.Dropdown(
            choices=list(PRESET_LABELS.keys()),
            value="Sentiment",
            label="Preset Label Sets",
            info="Pick a preset to auto-fill candidate labels.",
        )
        apply_btn = gr.Button("Apply Preset")
    # Inputs passed to run_classification().
    text_input = gr.Textbox(
        label="Text to Classify",
        placeholder="Paste a sentence, paragraph, or document...",
        lines=8,
    )
    labels_input = gr.Textbox(
        label="Candidate Labels (comma-separated)",
        value=PRESET_LABELS["Sentiment"],  # matches the dropdown's default preset
        placeholder="positive, negative, neutral",
        lines=2,
    )
    multi_label_input = gr.Checkbox(
        label="Multi-label mode",
        value=False,
        info="If enabled, multiple labels can be true at the same time.",
    )
    classify_btn = gr.Button("Classify", variant="primary")
    # Outputs filled from run_classification()'s (chart_html, summary) return value.
    chart_output = gr.HTML(label="Label Scores")
    summary_output = gr.Textbox(label="Summary", interactive=False)
    # "Apply Preset" copies the selected preset's labels into the labels textbox.
    apply_btn.click(apply_preset, inputs=preset, outputs=labels_input)
    # "Classify" runs the model on the current inputs.
    classify_btn.click(
        run_classification,
        inputs=[text_input, labels_input, multi_label_input],
        outputs=[chart_output, summary_output],
    )
    # Footer with author credits.
    gr.Markdown(
        "Built by [Xavier Fuentes](https://huggingface.co/xavier-fuentes) @ "
        "[AI Enablement Academy](https://enablement.academy) | "
        "[Buy me a coffee ☕](https://ko-fi.com/xavierfuentes)"
    )
# Launch the app, with request queuing enabled, only when this file is run as
# a script rather than imported.
if __name__ == "__main__":
    demo.queue().launch()