# CommentResponse / app.py
# Last commit by Jet-12138: "Update app.py" (2416c4a, verified)
import torch
import json
import torch.nn.functional as F
from transformers import BertTokenizer
import gradio as gr
from typing import List, Dict
from model import CommentMTLModel # your class
# ------------ Device optimisation -----------------------------------------------------------------
def _select_device() -> torch.device:
    """Return the torch device to run inference on.

    Preference order: Apple-silicon MPS, then CUDA, then plain CPU.
    """
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


device = _select_device()
# ------------ Model / tokenizer ------------------------------------------------------
# Tokenizer files are expected under TOKENIZER_DIR; local_files_only=True makes
# from_pretrained fail rather than attempt a download from the Hugging Face Hub.
TOKENIZER_DIR = "/app/bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(
    TOKENIZER_DIR,
    local_files_only=True,  # force offline loading
)

# Head sizes (and optional dropout) the checkpoint was built with.
# JSON is UTF-8 by spec, so don't depend on the platform's default encoding.
with open("config.json", encoding="utf-8") as f:
    cfg = json.load(f)

# Human-readable label names; analyse_batch maps model output column i to entry i.
sentiment_labels = ["Negative", "Neutral", "Positive"]
toxicity_labels = ["Toxic", "Severe Toxic", "Obscene", "Threat", "Insult", "Identity Hate"]

# Fail fast at start-up when the label lists and the configured head sizes
# disagree; otherwise every request would hit an IndexError or silently miscount.
if cfg["num_sentiment_labels"] != len(sentiment_labels):
    raise ValueError(
        f"config.json num_sentiment_labels={cfg['num_sentiment_labels']} but "
        f"{len(sentiment_labels)} sentiment label names are defined: {sentiment_labels}"
    )
if cfg["num_toxicity_labels"] != len(toxicity_labels):
    raise ValueError(
        f"config.json num_toxicity_labels={cfg['num_toxicity_labels']} but "
        f"{len(toxicity_labels)} toxicity label names are defined: {toxicity_labels}"
    )

# NOTE(review): the backbone is referenced by hub id while the tokenizer is loaded
# offline from TOKENIZER_DIR; if CommentMTLModel fetches "bert-base-uncased" from
# the Hub, start-up needs network access or a warm cache — confirm.
model = CommentMTLModel(
    model_name="bert-base-uncased",
    num_sentiment_labels=cfg["num_sentiment_labels"],
    num_toxicity_labels=cfg["num_toxicity_labels"],
    dropout_prob=cfg.get("dropout_prob", 0.1),
)
# The freshly built model still lives on the CPU, and load_state_dict copies the
# checkpoint tensors into its existing parameters. Loading straight to the
# accelerator would only bounce the weights device -> host before .to(device)
# moves them again, so load on the CPU and move the model once.
# NOTE(review): torch.load unpickles the file; it ships with the app, but
# weights_only=True (torch >= 1.13) would harden this if the torch pin allows.
model.load_state_dict(torch.load("pytorch_model.bin", map_location="cpu"))
model.to(device).eval()  # eval() disables dropout for inference
# ------------ Core inference function ------------------------------------------------
MAX_COMMENTS = 100  # per-request cap on analysed comments (matches the UI text)
TOXICITY_THRESHOLD = 0.30  # sigmoid probability above which a toxicity label counts
BATCH_SIZE = 32  # comments per forward pass; bounds peak activation memory


@torch.inference_mode()
def analyse_batch(comments_text: str) -> Dict:
    """
    Classify a block of comments for sentiment and toxicity and aggregate counts.

    comments_text: multiline string, one comment per line. Blank / whitespace-only
        lines are skipped and only the first MAX_COMMENTS remaining lines are
        analysed. An empty string (or None) yields all-zero counts.

    Returns a dict with:
        "sentiment_counts": {sentiment label: number of comments whose
            highest-scoring sentiment class is that label}
        "toxicity_counts": {toxicity label: number of comments whose sigmoid
            probability for that label exceeds TOXICITY_THRESHOLD}; labels are
            independent, so these may sum to more than the number of comments
        "comments_with_any_toxicity": number of comments with at least one
            toxicity label above the threshold
    """
    comments: List[str] = [
        line for line in (comments_text or "").splitlines() if line.strip()
    ][:MAX_COMMENTS]

    sent_counts = {lab: 0 for lab in sentiment_labels}
    tox_counts = {lab: 0 for lab in toxicity_labels}
    comments_with_any_tox = 0

    # Padding an empty batch raises inside the tokenizer, which would surface as
    # an error in the UI for an empty textbox; answer with zero counts instead.
    if not comments:
        return {
            "sentiment_counts": sent_counts,
            "toxicity_counts": tox_counts,
            "comments_with_any_toxicity": 0,
        }

    # Encode everything in one call so every mini-batch shares the same padded length.
    enc = tokenizer(
        comments,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    enc = {k: v.to(device) for k, v in enc.items()}
    token_type_ids = enc.get("token_type_ids")  # may be absent for some tokenizers

    # Forward pass in mini-batches so a full request doesn't run as one huge batch.
    for start in range(0, len(comments), BATCH_SIZE):
        sl = slice(start, start + BATCH_SIZE)
        out = model(
            input_ids=enc["input_ids"][sl],
            attention_mask=enc["attention_mask"][sl],
            token_type_ids=None if token_type_ids is None else token_type_ids[sl],
        )

        # ----- sentiment: single-label, pick the top class ----------------
        # softmax is monotonic, so argmax over raw logits gives the same class.
        sent_pred = out["sentiment_logits"].argmax(dim=1)  # (b,)
        for idx in sent_pred.tolist():
            sent_counts[sentiment_labels[idx]] += 1

        # ----- toxicity: multi-label, threshold each sigmoid independently -
        toxic_mask = out["toxicity_logits"].sigmoid() > TOXICITY_THRESHOLD  # (b, n_labels)
        comments_with_any_tox += int(toxic_mask.any(dim=1).sum().item())
        per_label_hits = toxic_mask.sum(dim=0).tolist()  # hits per label column
        for lab_idx, lab in enumerate(toxicity_labels):
            tox_counts[lab] += int(per_label_hits[lab_idx])

    return {
        "sentiment_counts": sent_counts,
        "toxicity_counts": tox_counts,
        "comments_with_any_toxicity": int(comments_with_any_tox),
    }
# ------------ Gradio interface -------------------------------------------------------
# One multiline textbox in, the aggregated-statistics dict from analyse_batch out.
# No submit_btn is configured, so the run button keeps Gradio's default "Submit"
# label; the description must refer to that label.
iface = gr.Interface(
    fn=analyse_batch,
    inputs=gr.Textbox(
        label="YouTube comments (max 100, one per line)",
        placeholder="Paste up to 100 comments, each on its own line.",
        lines=20,
        max_lines=100,
    ),
    outputs=gr.JSON(label="Aggregated statistics"),
    title="YouTube Comment Sentiment & Toxicity Batch API",
    description=(
        "Paste up to 100 raw comment strings, each on a new line, "
        "then click Submit to receive counts of Positive/Neutral/Negative comments "
        "plus counts of toxicity labels where probability > 0.30. "
        "Blank lines are skipped and comments beyond the first 100 are ignored."
    ),
    allow_flagging="never",  # no flag button / flagged-data collection
)

if __name__ == "__main__":
    iface.launch()