Spaces:
Runtime error
Runtime error
File size: 4,291 Bytes
f48fd56 018d244 5df6d8e 0749e03 5df6d8e 0749e03 5df6d8e f48fd56 018d244 5df6d8e 0749e03 2416c4a 5df6d8e 0749e03 1d71490 0749e03 1d71490 018d244 0749e03 018d244 0749e03 f48fd56 0749e03 f48fd56 0749e03 f48fd56 0749e03 64ce917 018d244 0749e03 018d244 0749e03 018d244 0749e03 f48fd56 0749e03 f48fd56 0749e03 018d244 0749e03 f48fd56 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import torch
import json
import torch.nn.functional as F
from transformers import BertTokenizer
import gradio as gr
from typing import List, Dict
from model import CommentMTLModel # your class
# ------------ Device optimisation -----------------------------------------------------------------
def _pick_device() -> torch.device:
    """Return the fastest available backend: Apple MPS, then CUDA, then CPU."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

device = _pick_device()
# ------------ Model / tokenizer ------------------------------------------------------
TOKENIZER_DIR = "/app/bert-base-uncased"  # tokenizer files baked into the image

# local_files_only forces a fully offline load from the local directory.
tokenizer = BertTokenizer.from_pretrained(TOKENIZER_DIR, local_files_only=True)

# Hyper-parameters for the multi-task head come from the exported config.
with open("config.json") as f:
    cfg = json.load(f)

model = CommentMTLModel(
    model_name="bert-base-uncased",
    num_sentiment_labels=cfg["num_sentiment_labels"],
    num_toxicity_labels=cfg["num_toxicity_labels"],
    dropout_prob=cfg.get("dropout_prob", 0.1),
)

# Load the fine-tuned weights onto the selected device and freeze for inference.
state_dict = torch.load("pytorch_model.bin", map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()

sentiment_labels = ["Negative", "Neutral", "Positive"]
toxicity_labels = ["Toxic", "Severe Toxic", "Obscene", "Threat", "Insult", "Identity Hate"]
# ------------ Core inference function ------------------------------------------------
# ------------ Core inference function ------------------------------------------------
@torch.inference_mode()
def analyse_batch(comments_text: str, tox_threshold: float = 0.30) -> Dict:
    """Run batched sentiment + toxicity inference over pasted comments.

    Args:
        comments_text: multiline string; each non-blank line is one comment.
            At most the first 100 non-blank lines are processed.
        tox_threshold: per-label probability above which a comment counts as
            carrying that toxicity label (default 0.30, matching the UI text).

    Returns:
        Dict of aggregated statistics:
            - "sentiment_counts": count per sentiment class (3 classes)
            - "toxicity_counts": count per toxicity label (6 labels)
            - "comments_with_any_toxicity": comments with at least one label
    """
    # Split input into list of comments, remove blank lines, cap at 100.
    comments: List[str] = [line for line in comments_text.splitlines() if line.strip()]
    comments = comments[:100]

    # Zero-initialised counters — also the correct answer for empty input.
    sent_counts = {lab: 0 for lab in sentiment_labels}
    tox_counts = {lab: 0 for lab in toxicity_labels}
    comments_with_any_tox = 0

    # Guard: the tokenizer raises on an empty batch, so return zeros early
    # instead of surfacing a runtime error in the UI.
    if not comments:
        return {
            "sentiment_counts": sent_counts,
            "toxicity_counts": tox_counts,
            "comments_with_any_toxicity": 0,
        }

    # ---- encode all comments as one padded batch ----------
    enc = tokenizer(
        comments,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    enc = {k: v.to(device) for k, v in enc.items()}
    # Some tokenizers do not produce token_type_ids; resolve once up front.
    token_type_ids = enc.get("token_type_ids")

    # ---- forward pass, split into mini-batches in case 100 is too big ----
    batch_size = 32
    n = enc["input_ids"].shape[0]
    for i in range(0, n, batch_size):
        sl = slice(i, i + batch_size)
        out = model(
            input_ids=enc["input_ids"][sl],
            attention_mask=enc["attention_mask"][sl],
            token_type_ids=token_type_ids[sl] if token_type_ids is not None else None,
        )

        # ----- sentiment (softmax, pick max) ----------------------------
        sent_logits = out["sentiment_logits"]                  # (b, 3)
        sent_pred = sent_logits.softmax(dim=1).argmax(dim=1)   # (b,)
        for idx in sent_pred.tolist():
            sent_counts[sentiment_labels[idx]] += 1

        # ----- toxicity (sigmoid, multi-label) --------------------------
        tox_probs = out["toxicity_logits"].sigmoid()           # (b, 6)
        toxic_mask = tox_probs > tox_threshold                 # boolean mask
        comments_with_any_tox += toxic_mask.any(dim=1).sum().item()
        # add per-label counts
        for lab_idx, lab in enumerate(toxicity_labels):
            tox_counts[lab] += toxic_mask[:, lab_idx].sum().item()

    return {
        "sentiment_counts": sent_counts,
        "toxicity_counts": tox_counts,
        "comments_with_any_toxicity": int(comments_with_any_tox),
    }
# ------------ Gradio interface -------------------------------------------------------
# Input widget: one comment per line, mirroring the 100-comment cap enforced
# inside analyse_batch.
_comments_input = gr.Textbox(
    label="YouTube comments (max 100, one per line)",
    placeholder="Paste up to 100 comments, each on its own line.",
    lines=20,
    max_lines=100,
)

iface = gr.Interface(
    fn=analyse_batch,
    inputs=_comments_input,
    outputs=gr.JSON(label="Aggregated statistics"),
    title="YouTube Comment Sentiment & Toxicity Batch API",
    description=(
        "Paste up to 100 raw comment strings, each on a new line, "
        "then click Analyze to receive counts of Positive/Neutral/Negative comments "
        "plus counts of toxicity labels where probability > 0.30."
    ),
    allow_flagging="never",  # no flagging UI for this app
)
if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    iface.launch()