Jet-12138 committed on
Commit
0749e03
·
verified ·
1 Parent(s): e9007ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -53
app.py CHANGED
@@ -1,12 +1,11 @@
1
- import torch
2
- import torch.nn.functional as F
3
  from transformers import BertTokenizer
4
  import gradio as gr
5
- import json
6
 
7
- from model import CommentMTLModel
8
 
9
- # Set device, including MPS
10
  if torch.backends.mps.is_available():
11
  device = torch.device("mps")
12
  elif torch.cuda.is_available():
@@ -14,67 +13,92 @@ elif torch.cuda.is_available():
14
  else:
15
  device = torch.device("cpu")
16
 
17
- # Load tokenizer
18
  tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
19
 
20
- # Load config values manually
21
- with open("config.json", "r") as f:
22
- config_data = json.load(f)
23
 
24
- # Create model
25
  model = CommentMTLModel(
26
  model_name="bert-base-uncased",
27
- num_sentiment_labels=config_data["num_sentiment_labels"],
28
- num_toxicity_labels=config_data["num_toxicity_labels"],
29
- dropout_prob=config_data.get("dropout_prob", 0.1)
30
  )
31
  model.load_state_dict(torch.load("pytorch_model.bin", map_location=device))
32
- model.to(device)
33
- model.eval()
34
 
35
- # Define labels
36
  sentiment_labels = ["Negative", "Neutral", "Positive"]
37
- toxicity_labels = ["Toxic", "Severe Toxic", "Obscene", "Threat", "Insult", "Identity Hate"]
38
-
39
- # Define the prediction function
40
- def analyse_comment(comment):
41
- inputs = tokenizer(comment, return_tensors="pt", truncation=True, padding=True, max_length=512)
42
- inputs = {k: v.to(device) for k, v in inputs.items() if k in ['input_ids', 'attention_mask', 'token_type_ids']}
43
-
44
- with torch.no_grad():
45
- outputs = model(**inputs)
46
-
47
- sentiment_logits = outputs["sentiment_logits"]
48
- toxicity_logits = outputs["toxicity_logits"]
49
-
50
- # Process sentiment (multi-class classification)
51
- sentiment_probs = F.softmax(sentiment_logits, dim=1).squeeze(0) # shape: (3,)
52
- sentiment_predictions = {}
53
-
54
- for idx, label in enumerate(sentiment_labels):
55
- prob = sentiment_probs[idx].item()
56
- sentiment_predictions[label] = round(prob, 4)
57
-
58
- # Process toxicity (multi-label classification)
59
- toxicity_probs = torch.sigmoid(toxicity_logits).squeeze(0) # shape: (6,)
60
- toxicity_predictions = {}
61
-
62
- for idx, label in enumerate(toxicity_labels):
63
- prob = toxicity_probs[idx].item()
64
- toxicity_predictions[label] = round(prob, 4)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  return {
67
- "Sentiment Probabilities": sentiment_predictions,
68
- "Toxicity Probabilities": toxicity_predictions
 
69
  }
70
 
71
- # Create Gradio interface
72
  iface = gr.Interface(
73
- fn=analyse_comment,
74
- inputs=gr.Textbox(lines=3, placeholder="Please enter a comment for analysis..."),
75
- outputs=gr.JSON(label="Prediction Results"),
76
- title="Comment Sentiment and Toxicity Classifier",
77
- description="This tool classifies the sentiment and the most probable type of toxicity in a given comment. It utilises a custom multi-task learning BERT model. Developed for academic demonstration purposes in Australia."
 
 
 
 
78
  )
79
 
80
- iface.launch()
 
 
1
# Standard library
import json
import math
from typing import Dict, List

# Third-party
import torch
import torch.nn.functional as F
import gradio as gr
from transformers import BertTokenizer

# Local: project-defined multi-task BERT wrapper (sentiment + toxicity heads)
from model import CommentMTLModel
8
# ------------ Device -----------------------------------------------------------------
# Prefer Apple-Silicon MPS, then CUDA, then fall back to CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
15
 
16
# ------------ Model / tokenizer ------------------------------------------------------
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Hyper-parameters live in a plain JSON file shipped next to the weights.
with open("config.json", encoding="utf-8") as f:
    cfg = json.load(f)

model = CommentMTLModel(
    model_name="bert-base-uncased",
    num_sentiment_labels=cfg["num_sentiment_labels"],
    num_toxicity_labels=cfg["num_toxicity_labels"],
    dropout_prob=cfg.get("dropout_prob", 0.1),
)
# NOTE(review): torch.load unpickles arbitrary objects — acceptable for a trusted
# local checkpoint; consider weights_only=True (torch>=2.0) if the file could be
# untrusted.
model.load_state_dict(torch.load("pytorch_model.bin", map_location=device))
model.to(device).eval()  # inference only: disables dropout
 
30
 
 
31
# Output-head label names. Order matters: index i of each list corresponds to
# logit column i of the matching model head — presumably fixed at training time
# (verify against the training config if labels are ever reordered).
sentiment_labels = list(("Negative", "Neutral", "Positive"))
toxicity_labels = list(
    ("Toxic", "Severe Toxic", "Obscene", "Threat", "Insult", "Identity Hate")
)
34
# ------------ Core inference function ------------------------------------------------
@torch.inference_mode()
def analyse_batch(comments: List[str], tox_threshold: float = 0.30) -> Dict:
    """Classify a batch of raw comment strings and return aggregated statistics.

    Args:
        comments: list of raw comment strings. The UI advertises <=100, but any
            length is handled — the forward pass is chunked into mini-batches.
        tox_threshold: per-label probability above which a toxicity label is
            counted (default 0.30, matching the UI description).

    Returns:
        dict with:
            "sentiment_counts":           {label: number of comments predicted as label}
            "toxicity_counts":            {label: number of comments with prob > threshold}
            "comments_with_any_toxicity": number of comments with at least one
                                          toxicity label over the threshold
    """
    sent_counts = {lab: 0 for lab in sentiment_labels}
    tox_counts = {lab: 0 for lab in toxicity_labels}
    comments_with_any_tox = 0

    # Empty input: the tokenizer raises on an empty batch, so short-circuit
    # with all-zero statistics instead of crashing.
    if not comments:
        return {
            "sentiment_counts": sent_counts,
            "toxicity_counts": tox_counts,
            "comments_with_any_toxicity": 0,
        }

    # ---- encode all comments as one padded batch -------------------------------
    enc = tokenizer(
        list(comments),
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    enc = {k: v.to(device) for k, v in enc.items()}
    token_type_ids = enc.get("token_type_ids")  # may be absent; slice only if present

    # ---- forward pass, chunked so a large batch still fits in memory -----------
    batch_size = 32
    n = enc["input_ids"].shape[0]
    for i in range(0, n, batch_size):
        sl = slice(i, i + batch_size)
        out = model(
            input_ids=enc["input_ids"][sl],
            attention_mask=enc["attention_mask"][sl],
            token_type_ids=None if token_type_ids is None else token_type_ids[sl],
        )

        # Sentiment: multi-class. argmax directly on logits — softmax is
        # monotonic, so applying it first cannot change the argmax.
        sent_pred = out["sentiment_logits"].argmax(dim=1)  # (b,)
        for idx in sent_pred.tolist():
            sent_counts[sentiment_labels[idx]] += 1

        # Toxicity: multi-label — independent sigmoid per label.
        tox_probs = out["toxicity_logits"].sigmoid()       # (b, num_toxicity_labels)
        toxic_mask = tox_probs > tox_threshold
        comments_with_any_tox += toxic_mask.any(dim=1).sum().item()
        for lab_idx, lab in enumerate(toxicity_labels):
            tox_counts[lab] += toxic_mask[:, lab_idx].sum().item()

    return {
        "sentiment_counts": sent_counts,
        "toxicity_counts": tox_counts,
        "comments_with_any_toxicity": int(comments_with_any_tox),
    }
89
 
90
# ------------ Gradio interface -------------------------------------------------------
# Components are built first, then wired into the Interface, so each piece can
# be tweaked independently.
_in_component = gr.JSON(label="List of comments (max 100)")
_out_component = gr.JSON(label="Aggregated statistics")
_description = (
    "Send up to 100 raw comment strings and receive counts of Positive/Neutral/Negative "
    "comments plus counts of toxicity labels where probability > 0.30."
)

iface = gr.Interface(
    fn=analyse_batch,
    inputs=_in_component,
    outputs=_out_component,
    title="YouTube Comment Sentiment & Toxicity Batch API",
    description=_description,
    allow_flagging="never",
)

# Launch only when executed as a script, not when imported.
if __name__ == "__main__":
    iface.launch()