File size: 5,249 Bytes
f9cb218
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
"""
Gradio App for HF Space Deployment
Comment Classification Skill powered by Qwen2.5-1.5B fine-tuned model.
"""

import os
import json
import time
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# ============================================================
# Load Model
# ============================================================
# Model repo is overridable via the MODEL_ID env var (useful for
# pointing a local run at a different checkpoint).
MODEL_ID = os.environ.get("MODEL_ID", "jovincia/qwen25-comment-classifier")

print(f"Loading model from {MODEL_ID}...")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,  # CPU-safe for HF Spaces free tier
        trust_remote_code=True,
    )
    model.eval()  # inference mode (disables dropout); no training happens here
    # index -> label-name mapping from the model config; consumed by
    # classify_comment below to name the output probabilities.
    id2label = model.config.id2label
    print(f"Model loaded. Labels: {id2label}")
except Exception as e:
    # Fail fast at startup: the Space is unusable without the model,
    # so surface the error and exit instead of serving a broken UI.
    print(f"ERROR: Failed to load model from '{MODEL_ID}': {e}")
    print("Make sure the model has been trained (02_finetune.py) or the HF repo exists.")
    raise SystemExit(1)

# Color mapping for labels
# NOTE(review): LABEL_COLORS is not referenced anywhere in this file —
# presumably intended for custom UI styling; confirm before removing.
LABEL_COLORS = {
    "positive": "#4CAF50",
    "negative": "#F44336",
    "neutral": "#9E9E9E",
    "ambiguous": "#FF9800",
}

def classify_comment(text: str) -> dict:
    """Run the classifier on one comment.

    Returns a mapping of label name -> probability (rounded to 4 decimal
    places). Empty or whitespace-only input short-circuits to all-zero
    probabilities without touching the model.
    """
    if not text or not text.strip():
        return {label: 0.0 for label in id2label.values()}

    t0 = time.perf_counter()

    encoded = tokenizer(
        text,
        padding=True,
        truncation=True,
        max_length=256,
        return_tensors="pt",
    )

    # Inference only — no gradients needed.
    with torch.no_grad():
        logits = model(**encoded).logits
        scores = torch.softmax(logits, dim=-1)[0]

    latency_ms = (time.perf_counter() - t0) * 1000

    # id2label keys may be strings or ints depending on how the config was
    # serialized, so try the string form first, then the int, then fall
    # back to a generic placeholder name.
    result = {}
    for idx, prob in enumerate(scores.tolist()):
        if str(idx) in id2label:
            label = id2label[str(idx)]
        else:
            label = id2label.get(idx, f"class_{idx}")
        result[label] = round(prob, 4)

    # Log each prediction with its latency so Space logs double as a
    # lightweight monitoring trail.
    predicted = max(result, key=result.get)
    print(f"[{latency_ms:.1f}ms] '{text[:50]}...' -> {predicted} ({result[predicted]:.3f})")

    return result


def batch_classify(texts: str) -> str:
    """Classify several comments at once, one comment per input line.

    Blank lines are skipped. Each output line has the form
    "[LABEL] (confidence%) original comment", joined by newlines.
    """
    if not texts or not texts.strip():
        return "Please enter at least one comment."

    comments = [c.strip() for c in texts.strip().split("\n") if c.strip()]

    formatted = []
    for comment in comments:
        scores = classify_comment(comment)
        best = max(scores, key=scores.get)
        formatted.append(f"[{best.upper()}] ({scores[best]:.1%}) {comment}")

    return "\n".join(formatted)


# ============================================================
# Gradio Interface
# ============================================================
# Two-tab UI: a single-comment classifier with per-label probability
# bars, and a batch mode that processes one comment per line.
with gr.Blocks(
    title="Comment Classification Skill",
    theme=gr.themes.Soft(),
) as demo:
    gr.Markdown(
        """
        # Comment Classification Skill
        **Fine-tuned Qwen2.5-1.5B** for 4-class comment sentiment classification.

        Classes: **positive** | **negative** | **neutral** | **ambiguous**
        """
    )

    with gr.Tab("Single Comment"):
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(
                    label="Enter a comment",
                    placeholder="Type your comment here...",
                    lines=3,
                )
                classify_btn = gr.Button("Classify", variant="primary")
            with gr.Column():
                # gr.Label renders a dict of label -> probability as bars;
                # num_top_classes=4 shows all four classes, not just the top one.
                output_label = gr.Label(label="Classification Result", num_top_classes=4)

        classify_btn.click(
            fn=classify_comment,
            inputs=input_text,
            outputs=output_label,
        )

        # Clickable examples intended to span the four classes.
        gr.Examples(
            examples=[
                ["This product is amazing! Best purchase I've ever made."],
                ["Terrible quality. Broke after one day of use."],
                ["It arrived on time. Standard packaging."],
                ["I'm not sure if this is working correctly or not..."],
                ["The customer service was incredibly helpful and kind!"],
                ["What a waste of money. Never buying from here again."],
            ],
            inputs=input_text,
        )

    with gr.Tab("Batch Classification"):
        gr.Markdown("Enter one comment per line for batch processing.")
        batch_input = gr.Textbox(
            label="Comments (one per line)",
            placeholder="Comment 1\nComment 2\nComment 3",
            lines=8,
        )
        batch_btn = gr.Button("Classify All", variant="primary")
        # Read-only output box; results are plain text, one line per comment.
        batch_output = gr.Textbox(label="Results", lines=10, interactive=False)

        batch_btn.click(
            fn=batch_classify,
            inputs=batch_input,
            outputs=batch_output,
        )

    gr.Markdown(
        """
        ---
        **Model:** Qwen2.5-1.5B fine-tuned with LoRA on GoEmotions dataset (58k+ comments)
        **Task:** 4-class comment sentiment classification
        """
    )

if __name__ == "__main__":
    # 0.0.0.0:7860 is the standard binding expected by HF Spaces.
    demo.launch(server_name="0.0.0.0", server_port=7860)