File size: 4,291 Bytes
f48fd56
 
 
018d244
5df6d8e
0749e03
5df6d8e
0749e03
5df6d8e
f48fd56
018d244
 
 
 
 
 
5df6d8e
0749e03
2416c4a
 
 
 
 
5df6d8e
0749e03
 
1d71490
 
 
0749e03
 
 
1d71490
018d244
0749e03
018d244
 
0749e03
 
 
 
f48fd56
0749e03
f48fd56
0749e03
 
f48fd56
 
 
 
 
0749e03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64ce917
018d244
0749e03
 
 
018d244
 
0749e03
018d244
0749e03
f48fd56
 
 
 
 
 
0749e03
 
 
f48fd56
 
 
0749e03
 
018d244
 
0749e03
f48fd56
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import torch
import json
import torch.nn.functional as F
from transformers import BertTokenizer
import gradio as gr
from typing import List, Dict

from model import CommentMTLModel   # your class

# ------------ Device optimisation -----------------------------------------------------------------
# Pick the best available backend: Apple-Silicon MPS first, then CUDA, then CPU.
if torch.backends.mps.is_available():
    _device_name = "mps"
elif torch.cuda.is_available():
    _device_name = "cuda"
else:
    _device_name = "cpu"
device = torch.device(_device_name)

# ------------ Model / tokenizer ------------------------------------------------------
# Tokenizer files are baked into the container image; local_files_only forces an
# offline load so no network call is ever attempted.
TOKENIZER_DIR = "/app/bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(TOKENIZER_DIR, local_files_only=True)

# Head sizes / dropout for the multi-task model come from the exported config.
with open("config.json") as f:
    cfg = json.load(f)

model = CommentMTLModel(
    model_name="bert-base-uncased",
    num_sentiment_labels=cfg["num_sentiment_labels"],
    num_toxicity_labels=cfg["num_toxicity_labels"],
    dropout_prob=cfg.get("dropout_prob", 0.1),
)
state_dict = torch.load("pytorch_model.bin", map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()

# Index order must match the label encoding used at training time.
sentiment_labels = ["Negative", "Neutral", "Positive"]
toxicity_labels  = ["Toxic", "Severe Toxic", "Obscene", "Threat", "Insult", "Identity Hate"]

# ------------ Core inference function ------------------------------------------------
@torch.inference_mode()
def analyse_batch(comments_text: str) -> Dict:
    """
    Run sentiment and toxicity inference over a batch of comments.

    Args:
        comments_text: Multiline string; each non-blank line is one comment.
            Only the first 100 comments are processed.

    Returns:
        Dict with:
            "sentiment_counts": per-label comment counts (argmax of softmax over 3 classes),
            "toxicity_counts": per-label counts where sigmoid probability > 0.30,
            "comments_with_any_toxicity": number of comments with at least one
                toxicity label above the 0.30 threshold.
    """
    # Split input into list of comments, remove blank lines
    comments: List[str] = [line for line in comments_text.splitlines() if line.strip()]

    # Ensure we have at most 100 comments
    comments = comments[:100]

    # counters
    sent_counts = {lab: 0 for lab in sentiment_labels}
    tox_counts  = {lab: 0 for lab in toxicity_labels}
    comments_with_any_tox = 0

    # Guard: the tokenizer raises on an empty batch — return zeroed stats
    # instead of crashing when the user submits only blank lines.
    if not comments:
        return {
            "sentiment_counts": sent_counts,
            "toxicity_counts": tox_counts,
            "comments_with_any_toxicity": 0
        }

    # ---- encode all comments (batched) ----------
    enc = tokenizer(
        comments,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512
    )
    enc = {k: v.to(device) for k, v in enc.items()}

    # BERT tokenizers emit token_type_ids; hoist the lookup once instead of
    # re-testing membership on every mini-batch.
    token_type_ids = enc.get("token_type_ids")

    # ---- forward pass (split to mini-batches in case 100 is too big) ----
    batch_size = 32
    n = enc["input_ids"].shape[0]

    for i in range(0, n, batch_size):
        sl = slice(i, i + batch_size)
        out = model(
            input_ids      = enc["input_ids"][sl],
            attention_mask = enc["attention_mask"][sl],
            token_type_ids = None if token_type_ids is None else token_type_ids[sl]
        )

        # ----- sentiment (softmax, pick max) ----------------------------
        sent_logits = out["sentiment_logits"]          # (b, 3)
        sent_pred   = sent_logits.softmax(dim=1).argmax(dim=1)  # (b,)
        for idx in sent_pred.tolist():
            sent_counts[sentiment_labels[idx]] += 1

        # ----- toxicity (sigmoid, multi-label) --------------------------
        tox_probs = out["toxicity_logits"].sigmoid()   # (b, 6)
        toxic_mask = tox_probs > 0.30                  # boolean mask
        comments_with_any_tox += toxic_mask.any(dim=1).sum().item()

        # add per-label counts
        for lab_idx, lab in enumerate(toxicity_labels):
            tox_counts[lab] += toxic_mask[:, lab_idx].sum().item()

    return {
        "sentiment_counts": sent_counts,
        "toxicity_counts": tox_counts,
        "comments_with_any_toxicity": int(comments_with_any_tox)
    }

# ------------ Gradio interface -------------------------------------------------------
# Build the input widget separately so the Interface construction stays readable.
_comments_input = gr.Textbox(
    label="YouTube comments (max 100, one per line)",
    placeholder="Paste up to 100 comments, each on its own line.",
    lines=20,
    max_lines=100,
)

iface = gr.Interface(
    fn=analyse_batch,
    inputs=_comments_input,
    outputs=gr.JSON(label="Aggregated statistics"),
    title="YouTube Comment Sentiment & Toxicity Batch API",
    description=(
        "Paste up to 100 raw comment strings, each on a new line, "
        "then click Analyze to receive counts of Positive/Neutral/Negative comments "
        "plus counts of toxicity labels where probability > 0.30."
    ),
    allow_flagging="never",
)

if __name__ == "__main__":
    iface.launch()