import json
from typing import Dict, List

import torch
import gradio as gr
from transformers import BertTokenizer

from model import CommentMTLModel  # project-local multi-task model class
# ------------ Device selection ---------------------------------------------------------
# Prefer Apple-Silicon MPS, then CUDA, then fall back to CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# ------------ Model / tokenizer ---------------------------------------------------------
TOKENIZER_DIR = "/app/bert-base-uncased"  # newly added
tokenizer = BertTokenizer.from_pretrained(
    TOKENIZER_DIR,
    local_files_only=True  # force offline loading
)

with open("config.json") as f:
    cfg = json.load(f)
model = CommentMTLModel(
    model_name="bert-base-uncased",
    num_sentiment_labels=cfg["num_sentiment_labels"],
    num_toxicity_labels=cfg["num_toxicity_labels"],
    dropout_prob=cfg.get("dropout_prob", 0.1)
)
model.load_state_dict(torch.load("pytorch_model.bin", map_location=device))
model.to(device).eval()
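
# Note: CommentMTLModel.forward is expected (judging from the usage in
# analyse_batch below) to accept input_ids / attention_mask / token_type_ids
# and to return a dict with "sentiment_logits" of shape (batch, 3) and
# "toxicity_logits" of shape (batch, 6).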

# Label orders must match the ones used during training.
sentiment_labels = ["Negative", "Neutral", "Positive"]
toxicity_labels = ["Toxic", "Severe Toxic", "Obscene", "Threat", "Insult", "Identity Hate"]

# ------------ Core inference function ---------------------------------------------------
def analyse_batch(comments_text: str) -> Dict:
    """
    comments_text: multiline string, one comment per line (at most 100 lines are used).
    Returns a dict of aggregated statistics.
    """
    # Split the input into a list of comments, dropping blank lines.
    comments: List[str] = [line for line in comments_text.splitlines() if line.strip()]
    if not comments:
        return {"error": "No comments provided."}

    # Keep at most the first 100 comments.
    comments = comments[:100]

    # ---- encode all comments in one padded batch ----
    enc = tokenizer(
        comments,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512
    )
    enc = {k: v.to(device) for k, v in enc.items()}
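    # enc now holds input_ids / attention_mask (and, for BERT, token_type_ids),
    # each a (n_comments, seq_len) tensor already moved to `device`.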

    # ---- forward pass, split into mini-batches in case 100 comments is too many at once ----
    batch_size = 32
    n = enc["input_ids"].shape[0]

    # counters
    sent_counts = {lab: 0 for lab in sentiment_labels}
    tox_counts = {lab: 0 for lab in toxicity_labels}
    comments_with_any_tox = 0

    with torch.no_grad():  # inference only, so skip gradient tracking
        for i in range(0, n, batch_size):
            sl = slice(i, i + batch_size)
            out = model(
                input_ids=enc["input_ids"][sl],
                attention_mask=enc["attention_mask"][sl],
                token_type_ids=enc["token_type_ids"][sl] if "token_type_ids" in enc else None
            )

            # ----- sentiment: softmax, then take the argmax class ------------
            sent_logits = out["sentiment_logits"]                 # (b, 3)
            sent_pred = sent_logits.softmax(dim=1).argmax(dim=1)  # (b,)
            for idx in sent_pred.tolist():
                sent_counts[sentiment_labels[idx]] += 1

            # ----- toxicity: sigmoid, multi-label thresholding ---------------
            tox_probs = out["toxicity_logits"].sigmoid()          # (b, 6)
            toxic_mask = tox_probs > 0.30                         # boolean mask
            comments_with_any_tox += toxic_mask.any(dim=1).sum().item()

            # add per-label counts
            for lab_idx, lab in enumerate(toxicity_labels):
                tox_counts[lab] += toxic_mask[:, lab_idx].sum().item()
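
            # Worked example (made-up probabilities): a row of
            # [0.72, 0.05, 0.41, 0.01, 0.33, 0.02] against the 0.30 threshold
            # gives the mask [1, 0, 1, 0, 1, 0], so Toxic, Obscene, and Insult
            # each gain one count, and the comment is counted once toward
            # comments_with_any_toxicity.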

    return {
        "sentiment_counts": sent_counts,
        "toxicity_counts": tox_counts,
        "comments_with_any_toxicity": int(comments_with_any_tox)
    }
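
# A quick smoke test for analyse_batch, without launching the Gradio server.
# The comment strings below are made up purely for illustration:
#
#   stats = analyse_batch("great video!\nthis is awful\nmeh")
#   print(stats["sentiment_counts"], stats["comments_with_any_toxicity"])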

# ------------ Gradio interface ----------------------------------------------------------
iface = gr.Interface(
    fn=analyse_batch,
    inputs=gr.Textbox(
        label="YouTube comments (max 100, one per line)",
        placeholder="Paste up to 100 comments, each on its own line.",
        lines=20,
        max_lines=100
    ),
    outputs=gr.JSON(label="Aggregated statistics"),
    title="YouTube Comment Sentiment & Toxicity Batch API",
    description=(
        "Paste up to 100 raw comment strings, each on a new line, "
        "then click Submit to receive counts of Positive/Neutral/Negative comments "
        "plus counts of toxicity labels where probability > 0.30."
    ),
    allow_flagging="never"
)

if __name__ == "__main__":
    iface.launch()
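
# A minimal client-side sketch for calling the running app programmatically.
# Assumptions: the app is reachable at Gradio's default local URL, it exposes
# the default "/predict" endpoint, and the separate gradio_client package is
# installed.
#
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860/")
#   result = client.predict("first comment\nsecond comment", api_name="/predict")
#   print(result)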