"""Zero-shot text classifier Gradio app built on Qwen3-0.6B.

Users supply free text plus a comma-separated list of candidate labels;
the model is prompted to emit a JSON object of per-label confidence
scores, which are parsed, normalized, and rendered as an HTML bar chart.
"""

import json
import time
from typing import List, Tuple

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "Qwen/Qwen3-0.6B"
# Hard cap on input length to keep inference latency/memory predictable.
MAX_TEXT_CHARS = 4000

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto",
)

# Preset name -> comma-separated candidate labels (used to pre-fill the UI).
PRESET_LABELS = {
    "Sentiment": "positive, negative, neutral",
    "Topic": "tech, business, sports, science, politics",
    "Intent": "question, statement, request, complaint",
    "Tone": "formal, casual, urgent, friendly",
}


def normalize_labels(raw_labels: str) -> List[str]:
    """Split a comma-separated string into unique, stripped labels.

    Deduplication is case-insensitive; the first spelling encountered
    is the one preserved in the output order.
    """
    labels = [label.strip() for label in raw_labels.split(",") if label.strip()]
    unique_labels = []
    seen = set()
    for label in labels:
        key = label.lower()
        if key not in seen:
            seen.add(key)
            unique_labels.append(label)
    return unique_labels


def truncate_text(text: str) -> Tuple[str, bool]:
    """Strip *text* and cap it at MAX_TEXT_CHARS.

    Returns:
        (clean_text, was_truncated) — the second element is True only
        when the input exceeded the cap and was cut.
    """
    text = (text or "").strip()
    if len(text) <= MAX_TEXT_CHARS:
        return text, False
    return text[:MAX_TEXT_CHARS], True


def build_bar_chart(labels: List[str], scores: List[float]) -> str:
    """Render per-label scores as an inline-styled HTML bar chart.

    Bars are scaled relative to the maximum score so the top label
    always fills the full track width; the percentage shown is the
    absolute (normalized) score.

    NOTE(review): the HTML tags in the original source were garbled by
    an extraction step; the markup below is a faithful reconstruction
    from the surviving text fragments — confirm against the rendered UI.
    """
    if not labels:
        return (
            "<div style='padding:12px;color:#888;'>"
            "No labels to display."
            "</div>"
        )
    rows = []
    max_score = max(scores) if scores else 1.0
    for label, score in zip(labels, scores):
        pct = score * 100
        # Relative width so the best label spans 100% of the track.
        width = (score / max_score) * 100 if max_score > 0 else 0
        rows.append(
            f"""<div style="margin:8px 0;">
  <div style="display:flex;justify-content:space-between;font-size:14px;margin-bottom:2px;">
    <span>{label}</span><span>{pct:.2f}%</span>
  </div>
  <div style="background:#eee;border-radius:4px;overflow:hidden;height:18px;">
    <div style="width:{width:.2f}%;background:#4f8ef7;height:100%;color:#fff;font-size:12px;line-height:18px;padding-left:6px;box-sizing:border-box;">{pct:.2f}%</div>
  </div>
</div>"""
        )
    return "<div style='padding:8px;'>" + "\n".join(rows) + "</div>"


def apply_preset(preset_name: str) -> str:
    """Return the label string for *preset_name*, or '' if unknown."""
    return PRESET_LABELS.get(preset_name, "")


def build_classification_prompt(text: str, labels: List[str], multi_label: bool) -> str:
    """Build the instruction prompt asking the model for JSON label scores.

    Args:
        text: The (already truncated) text to classify.
        labels: Normalized candidate labels.
        multi_label: If True, labels are scored independently; otherwise
            the model is told scores should sum to 1.
    """
    labels_str = ", ".join(f'"{l}"' for l in labels)
    mode_instruction = (
        "Multiple labels can apply simultaneously. For each label, assign a confidence score between 0 and 1."
        if multi_label
        else "Choose the single best label. Assign confidence scores that sum to 1."
    )
    # Hoisted out of the f-string below: backslash escapes inside an
    # f-string expression are a SyntaxError before Python 3.12.
    example = ", ".join(f'"{l}": 0.5' for l in labels[:2])
    return (
        f"Classify the following text into these categories: {labels_str}\n\n"
        f"{mode_instruction}\n\n"
        f"Text: \"{text}\"\n\n"
        f"Respond with ONLY a JSON object mapping each label to its confidence score. "
        f"Example: {{{example}}}\n"
        f"JSON:"
    )


def parse_scores(output: str, labels: List[str]) -> dict:
    """Extract label scores from model output, with fallback parsing.

    Looks for the outermost {...} span in *output* and parses it as
    JSON. Labels are matched exactly first, then case-insensitively;
    missing labels score 0.0. If no valid JSON is found, every label
    gets an equal 1/len(labels) score.
    """
    # Try to find JSON in the output.
    output = output.strip()
    # Find the first { and last }.
    start = output.find("{")
    end = output.rfind("}")
    if start != -1 and end != -1 and end > start:
        json_str = output[start : end + 1]
        try:
            parsed = json.loads(json_str)
            scores = {}
            for label in labels:
                # Try exact match, then case-insensitive.
                if label in parsed:
                    scores[label] = float(parsed[label])
                else:
                    lower_map = {k.lower(): v for k, v in parsed.items()}
                    scores[label] = float(lower_map.get(label.lower(), 0.0))
            return scores
        except (json.JSONDecodeError, ValueError):
            pass
    # Fallback: equal scores.
    return {label: 1.0 / len(labels) for label in labels}


@spaces.GPU
def run_classification(text: str, candidate_labels: str, multi_label: bool):
    """Classify *text* against *candidate_labels* and return (chart_html, summary).

    Raises:
        gr.Error: If the text is empty or fewer than 2 labels are given.
    """
    clean_text, was_truncated = truncate_text(text)
    labels = normalize_labels(candidate_labels)
    if not clean_text:
        raise gr.Error("Please enter text to classify.")
    if len(labels) < 2:
        raise gr.Error("Please provide at least 2 labels, separated by commas.")

    prompt = build_classification_prompt(clean_text, labels, multi_label)
    messages = [
        {"role": "system", "content": "You are a precise text classifier. Respond only with valid JSON."},
        {"role": "user", "content": prompt},
    ]
    # enable_thinking=False suppresses Qwen3's reasoning block so the
    # reply is just the JSON object we asked for.
    input_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )

    start = time.perf_counter()
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            temperature=0.1,
            do_sample=True,
            top_p=0.9,
        )
    # Decode only the newly generated tokens, not the prompt.
    generated = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    elapsed = time.perf_counter() - start

    scores = parse_scores(generated, labels)
    # Normalize scores so they sum to 1 (the model's raw scores may not).
    total = sum(scores.values())
    if total > 0:
        scores = {k: v / total for k, v in scores.items()}
    else:
        scores = {k: 1.0 / len(labels) for k in labels}

    sorted_pairs = sorted(scores.items(), key=lambda x: x[1], reverse=True)
    sorted_labels = [x[0] for x in sorted_pairs]
    sorted_scores = [x[1] for x in sorted_pairs]

    chart_html = build_bar_chart(sorted_labels, sorted_scores)
    top_label = sorted_labels[0]
    top_score = sorted_scores[0] * 100
    truncation_note = (
        f" Input was truncated to {MAX_TEXT_CHARS} characters for stable inference."
        if was_truncated
        else ""
    )
    summary = (
        f"Top prediction: {top_label} ({top_score:.2f}%). "
        f"Model: Qwen3-0.6B. "
        f"Mode: {'multi-label' if multi_label else 'single-label'}. "
        f"Inference time: {elapsed:.3f}s.{truncation_note}"
    )
    return chart_html, summary


with gr.Blocks(theme=gr.themes.Soft(), title="Zero-Shot Text Classifier") as demo:
    gr.Markdown("# Zero-Shot Text Classifier")
    gr.Markdown(
        "Classify any text into custom labels using **Qwen3-0.6B** with zero-shot instruction prompting. "
        "No fine-tuning needed: define your own categories and classify instantly."
    )
    with gr.Row():
        preset = gr.Dropdown(
            choices=list(PRESET_LABELS.keys()),
            value="Sentiment",
            label="Preset Label Sets",
            info="Pick a preset to auto-fill candidate labels.",
        )
        apply_btn = gr.Button("Apply Preset")
    text_input = gr.Textbox(
        label="Text to Classify",
        placeholder="Paste a sentence, paragraph, or document...",
        lines=8,
    )
    labels_input = gr.Textbox(
        label="Candidate Labels (comma-separated)",
        value=PRESET_LABELS["Sentiment"],
        placeholder="positive, negative, neutral",
        lines=2,
    )
    multi_label_input = gr.Checkbox(
        label="Multi-label mode",
        value=False,
        info="If enabled, multiple labels can be true at the same time.",
    )
    classify_btn = gr.Button("Classify", variant="primary")
    chart_output = gr.HTML(label="Label Scores")
    summary_output = gr.Textbox(label="Summary", interactive=False)

    apply_btn.click(apply_preset, inputs=preset, outputs=labels_input)
    classify_btn.click(
        run_classification,
        inputs=[text_input, labels_input, multi_label_input],
        outputs=[chart_output, summary_output],
    )
    gr.Markdown(
        "Built by [Xavier Fuentes](https://huggingface.co/xavier-fuentes) @ "
        "[AI Enablement Academy](https://enablement.academy) | "
        "[Buy me a coffee ☕](https://ko-fi.com/xavierfuentes)"
    )

if __name__ == "__main__":
    demo.queue().launch()