Spaces:
Sleeping
Sleeping
| import json | |
| import time | |
| from typing import List, Tuple | |
| import gradio as gr | |
| import spaces | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
# Hugging Face model identifier for the instruction-tuned backbone used
# for zero-shot classification via prompting.
MODEL_ID = "Qwen/Qwen3-0.6B"
# Hard cap on input length (characters) to keep inference time bounded.
MAX_TEXT_CHARS = 4000

# Load tokenizer and model once at import time. device_map="auto" places the
# fp16 weights on GPU when available, otherwise CPU.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto",
)

# Ready-made candidate-label sets selectable from the UI dropdown.
# Values are comma-separated strings in the same format the labels
# textbox expects (see normalize_labels).
PRESET_LABELS = {
    "Sentiment": "positive, negative, neutral",
    "Topic": "tech, business, sports, science, politics",
    "Intent": "question, statement, request, complaint",
    "Tone": "formal, casual, urgent, friendly",
}
def normalize_labels(raw_labels: str) -> List[str]:
    """Parse a comma-separated label string into a clean, ordered list.

    Whitespace around each label is stripped, empty entries are dropped, and
    duplicates are removed case-insensitively while keeping the casing of the
    first occurrence.
    """
    result: List[str] = []
    seen_lower = set()
    for part in raw_labels.split(","):
        candidate = part.strip()
        if not candidate:
            continue
        folded = candidate.lower()
        if folded not in seen_lower:
            seen_lower.add(folded)
            result.append(candidate)
    return result
def truncate_text(text: str) -> Tuple[str, bool]:
    """Strip surrounding whitespace and cap length at MAX_TEXT_CHARS.

    Returns the (possibly shortened) text and a flag indicating whether any
    truncation actually happened. None is treated as empty input.
    """
    stripped = (text or "").strip()
    was_cut = len(stripped) > MAX_TEXT_CHARS
    if was_cut:
        stripped = stripped[:MAX_TEXT_CHARS]
    return stripped, was_cut
def build_bar_chart(labels: List[str], scores: List[float]) -> str:
    """Render per-label confidence scores as an inline-styled HTML bar chart.

    Bars are scaled relative to the highest score so the top label always
    spans the full width; the percentage text shows the absolute score.
    """
    if not labels:
        return "<p>No labels to display.</p>"
    top = max(scores) if scores else 1.0
    bars = []
    for name, value in zip(labels, scores):
        percent = value * 100
        bar_width = (value / top) * 100 if top > 0 else 0
        bars.append(
            f"""
<div style='margin: 8px 0;'>
<div style='display:flex; justify-content:space-between; font-size:14px;'>
<span><b>{name}</b></span>
<span>{percent:.2f}%</span>
</div>
<div style='background:#e5e7eb; border-radius:8px; overflow:hidden;'>
<div style='width:{bar_width:.2f}%; background:#2563eb; color:white; padding:4px 8px; font-size:12px;'>
{percent:.2f}%
</div>
</div>
</div>
"""
        )
    return "<div>" + "\n".join(bars) + "</div>"
def apply_preset(preset_name: str) -> str:
    """Return the comma-separated label string for a preset, or '' if unknown."""
    try:
        return PRESET_LABELS[preset_name]
    except KeyError:
        return ""
def build_classification_prompt(text: str, labels: List[str], multi_label: bool) -> str:
    """Build the zero-shot classification instruction prompt.

    Args:
        text: The (already cleaned/truncated) text to classify.
        labels: Candidate label names; the first two seed the JSON example.
        multi_label: If True, instruct the model that labels are independent;
            otherwise ask for a single best label with scores summing to 1.

    Returns:
        A prompt string asking the model to answer with a JSON object mapping
        each label to a confidence score.
    """
    labels_str = ", ".join(f'"{l}"' for l in labels)
    mode_instruction = (
        "Multiple labels can apply simultaneously. For each label, assign a confidence score between 0 and 1."
        if multi_label
        else "Choose the single best label. Assign confidence scores that sum to 1."
    )
    # Build the JSON example OUTSIDE the f-string: the original inlined
    # f'\"{l}\": 0.5' inside a replacement field, and backslashes in f-string
    # expression parts are a SyntaxError on Python < 3.12. The escapes were
    # unnecessary anyway (single-quoted inner string).
    example_pairs = ", ".join(f'"{l}": 0.5' for l in labels[:2])
    return (
        f"Classify the following text into these categories: {labels_str}\n\n"
        f"{mode_instruction}\n\n"
        f"Text: \"{text}\"\n\n"
        f"Respond with ONLY a JSON object mapping each label to its confidence score. "
        f"Example: {{{example_pairs}}}\n"
        f"JSON:"
    )
def parse_scores(output: str, labels: List[str]) -> dict:
    """Extract per-label confidence scores from model output.

    Finds the outermost {...} span in `output`, parses it as JSON, and maps
    each requested label to a float score — matching keys exactly first, then
    case-insensitively, defaulting to 0.0 for missing labels.

    Falls back to a uniform distribution over `labels` when no parseable JSON
    object with numeric values is found.
    """
    if not labels:
        # Guard: the uniform fallback below would divide by zero.
        return {}
    output = output.strip()
    start = output.find("{")
    end = output.rfind("}")
    if start != -1 and end != -1 and end > start:
        json_str = output[start : end + 1]
        try:
            parsed = json.loads(json_str)
            # Build the case-insensitive lookup once, not per label.
            lower_map = {k.lower(): v for k, v in parsed.items()}
            scores = {}
            for label in labels:
                if label in parsed:
                    scores[label] = float(parsed[label])
                else:
                    scores[label] = float(lower_map.get(label.lower(), 0.0))
            return scores
        # TypeError: float(None)/float(list)/float(dict) when the model emits
        # non-numeric values. AttributeError: top-level JSON is not an object
        # (no .items()). Both previously escaped the handler and crashed.
        except (json.JSONDecodeError, ValueError, TypeError, AttributeError):
            pass
    # Fallback: uniform scores across all labels.
    return {label: 1.0 / len(labels) for label in labels}
def run_classification(text: str, candidate_labels: str, multi_label: bool):
    """Classify `text` against user-supplied labels with the loaded model.

    Returns a (chart_html, summary) tuple for the Gradio outputs.
    Raises gr.Error when the text is empty or fewer than 2 labels are given.
    """
    cleaned, truncated = truncate_text(text)
    label_list = normalize_labels(candidate_labels)

    # Validation errors surface as Gradio toasts.
    if not cleaned:
        raise gr.Error("Please enter text to classify.")
    if len(label_list) < 2:
        raise gr.Error("Please provide at least 2 labels, separated by commas.")

    prompt = build_classification_prompt(cleaned, label_list, multi_label)
    chat = [
        {"role": "system", "content": "You are a precise text classifier. Respond only with valid JSON."},
        {"role": "user", "content": prompt},
    ]
    rendered = tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )

    t0 = time.perf_counter()
    encoded = tokenizer(rendered, return_tensors="pt").to(model.device)
    with torch.no_grad():
        generation = model.generate(
            **encoded,
            max_new_tokens=256,
            temperature=0.1,
            do_sample=True,
            top_p=0.9,
        )
    # Decode only the newly generated tokens, skipping the prompt prefix.
    prompt_len = encoded["input_ids"].shape[1]
    completion = tokenizer.decode(generation[0][prompt_len:], skip_special_tokens=True)
    elapsed = time.perf_counter() - t0

    raw_scores = parse_scores(completion, label_list)

    # Renormalize so scores sum to 1; uniform fallback if everything is zero.
    total = sum(raw_scores.values())
    if total > 0:
        normalized = {name: value / total for name, value in raw_scores.items()}
    else:
        normalized = {name: 1.0 / len(label_list) for name in label_list}

    ranked = sorted(normalized.items(), key=lambda pair: pair[1], reverse=True)
    ranked_labels = [pair[0] for pair in ranked]
    ranked_scores = [pair[1] for pair in ranked]

    chart_html = build_bar_chart(ranked_labels, ranked_scores)

    truncation_note = ""
    if truncated:
        truncation_note = (
            f" Input was truncated to {MAX_TEXT_CHARS} characters for stable inference."
        )
    summary = (
        f"Top prediction: {ranked_labels[0]} ({ranked_scores[0] * 100:.2f}%). "
        f"Model: Qwen3-0.6B. "
        f"Mode: {'multi-label' if multi_label else 'single-label'}. "
        f"Inference time: {elapsed:.3f}s.{truncation_note}"
    )
    return chart_html, summary
# --- Gradio UI -------------------------------------------------------------
# Declarative layout; event wiring at the bottom connects the two buttons to
# apply_preset and run_classification.
with gr.Blocks(theme=gr.themes.Soft(), title="Zero-Shot Text Classifier") as demo:
    gr.Markdown("# Zero-Shot Text Classifier")
    gr.Markdown(
        "Classify any text into custom labels using **Qwen3-0.6B** with zero-shot instruction prompting. "
        "No fine-tuning needed: define your own categories and classify instantly."
    )
    # Preset picker + button that copies the preset into the labels textbox.
    with gr.Row():
        preset = gr.Dropdown(
            choices=list(PRESET_LABELS.keys()),
            value="Sentiment",
            label="Preset Label Sets",
            info="Pick a preset to auto-fill candidate labels.",
        )
        apply_btn = gr.Button("Apply Preset")
    # Main inputs: free text, comma-separated labels, and the mode toggle.
    text_input = gr.Textbox(
        label="Text to Classify",
        placeholder="Paste a sentence, paragraph, or document...",
        lines=8,
    )
    labels_input = gr.Textbox(
        label="Candidate Labels (comma-separated)",
        value=PRESET_LABELS["Sentiment"],
        placeholder="positive, negative, neutral",
        lines=2,
    )
    multi_label_input = gr.Checkbox(
        label="Multi-label mode",
        value=False,
        info="If enabled, multiple labels can be true at the same time.",
    )
    classify_btn = gr.Button("Classify", variant="primary")
    # Outputs: HTML bar chart plus a one-line textual summary.
    chart_output = gr.HTML(label="Label Scores")
    summary_output = gr.Textbox(label="Summary", interactive=False)
    # Event wiring.
    apply_btn.click(apply_preset, inputs=preset, outputs=labels_input)
    classify_btn.click(
        run_classification,
        inputs=[text_input, labels_input, multi_label_input],
        outputs=[chart_output, summary_output],
    )
    gr.Markdown(
        "Built by [Xavier Fuentes](https://huggingface.co/xavier-fuentes) @ "
        "[AI Enablement Academy](https://enablement.academy) | "
        "[Buy me a coffee ☕](https://ko-fi.com/xavierfuentes)"
    )

# Queue requests so concurrent users share the single loaded model.
if __name__ == "__main__":
    demo.queue().launch()