Jet-12138 commited on
Commit
64ce917
·
verified ·
1 Parent(s): f771085

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -8
app.py CHANGED
@@ -39,7 +39,7 @@ toxicity_labels = ["Toxic", "Severe Toxic", "Obscene", "Threat", "Insult", "Iden
39
  # Define the prediction function
40
  def analyse_comment(comment):
41
  inputs = tokenizer(comment, return_tensors="pt", truncation=True, padding=True, max_length=512)
42
- inputs = {k: v.to(device) for k, v in inputs.items()}
43
 
44
  with torch.no_grad():
45
  outputs = model(**inputs)
@@ -47,19 +47,22 @@ def analyse_comment(comment):
47
  sentiment_logits = outputs["sentiment_logits"]
48
  toxicity_logits = outputs["toxicity_logits"]
49
 
50
- # Process sentiment
51
  sentiment_probs = F.softmax(sentiment_logits, dim=1)
52
  sentiment_idx = torch.argmax(sentiment_probs, dim=1).item()
53
  sentiment_prediction = sentiment_labels[sentiment_idx]
54
 
55
- # Process toxicity
56
- toxicity_probs = F.softmax(toxicity_logits, dim=1)
57
- toxicity_idx = torch.argmax(toxicity_probs, dim=1).item()
58
- toxicity_prediction = toxicity_labels[toxicity_idx]
59
-
 
 
 
60
  return {
61
  "Sentiment": sentiment_prediction,
62
- "Toxicity": toxicity_prediction
63
  }
64
 
65
  # Create Gradio interface
 
39
  # Define the prediction function
40
  def analyse_comment(comment):
41
  inputs = tokenizer(comment, return_tensors="pt", truncation=True, padding=True, max_length=512)
42
+ inputs = {k: v.to(device) for k, v in inputs.items() if k in ['input_ids', 'attention_mask', 'token_type_ids']}
43
 
44
  with torch.no_grad():
45
  outputs = model(**inputs)
 
47
  sentiment_logits = outputs["sentiment_logits"]
48
  toxicity_logits = outputs["toxicity_logits"]
49
 
50
+ # Process sentiment (single label classification)
51
  sentiment_probs = F.softmax(sentiment_logits, dim=1)
52
  sentiment_idx = torch.argmax(sentiment_probs, dim=1).item()
53
  sentiment_prediction = sentiment_labels[sentiment_idx]
54
 
55
+ # Process toxicity (multi-label classification)
56
+ toxicity_probs = torch.sigmoid(toxicity_logits).squeeze(0) # shape: (6,)
57
+ toxicity_predictions = {}
58
+
59
+ for idx, label in enumerate(toxicity_labels):
60
+ prob = toxicity_probs[idx].item()
61
+ toxicity_predictions[label] = round(prob, 2)
62
+
63
  return {
64
  "Sentiment": sentiment_prediction,
65
+ "Toxicity Probabilities": toxicity_predictions
66
  }
67
 
68
  # Create Gradio interface