Upload 2 files
Browse files- moderation_model.pth +3 -0
- test.py +71 -0
moderation_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cb5f1cbf7c576b2d1ea0ee801d90178b2d392ba37256573bc8454c38ae521854
|
| 3 |
+
size 204952
|
test.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.optim as optim
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch.utils.data import Dataset, DataLoader
|
| 8 |
+
|
| 9 |
+
from transformers import AutoTokenizer, AutoModel
|
| 10 |
+
|
| 11 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 12 |
+
|
| 13 |
+
tokenizer_embeddings = AutoTokenizer.from_pretrained('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
|
| 14 |
+
model_embeddings = AutoModel.from_pretrained('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2').to(device)
|
| 15 |
+
|
| 16 |
+
class ModerationModel(nn.Module):
|
| 17 |
+
def __init__(self):
|
| 18 |
+
input_size = 384
|
| 19 |
+
hidden_size = 128
|
| 20 |
+
output_size = 11
|
| 21 |
+
super(ModerationModel, self).__init__()
|
| 22 |
+
self.fc1 = nn.Linear(input_size, hidden_size)
|
| 23 |
+
self.fc2 = nn.Linear(hidden_size, output_size)
|
| 24 |
+
|
| 25 |
+
def forward(self, x):
|
| 26 |
+
x = F.relu(self.fc1(x))
|
| 27 |
+
x = self.fc2(x)
|
| 28 |
+
return x
|
| 29 |
+
|
| 30 |
+
def mean_pooling(model_output, attention_mask):
|
| 31 |
+
token_embeddings = model_output[0]
|
| 32 |
+
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
| 33 |
+
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
| 34 |
+
|
| 35 |
+
def getEmbeddings(sentences):
|
| 36 |
+
encoded_input = tokenizer_embeddings(sentences, padding=True, truncation=True, return_tensors='pt').to(device)
|
| 37 |
+
with torch.no_grad():
|
| 38 |
+
model_output = model_embeddings(**encoded_input)
|
| 39 |
+
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
|
| 40 |
+
return sentence_embeddings.cpu()
|
| 41 |
+
|
| 42 |
+
def getEmb(text):
|
| 43 |
+
sentences = [text]
|
| 44 |
+
sentence_embeddings = getEmbeddings(sentences)
|
| 45 |
+
return sentence_embeddings.tolist()[0]
|
| 46 |
+
|
| 47 |
+
def predict(model, embeddings):
|
| 48 |
+
model.eval()
|
| 49 |
+
with torch.no_grad():
|
| 50 |
+
embeddings_tensor = torch.tensor(embeddings, dtype=torch.float)
|
| 51 |
+
outputs = model(embeddings_tensor.unsqueeze(0))
|
| 52 |
+
predicted_scores = torch.sigmoid(outputs)
|
| 53 |
+
predicted_scores = predicted_scores.squeeze(0).tolist()
|
| 54 |
+
category_names = ["harassment", "harassment-threatening", "hate", "hate-threatening", "self-harm", "self-harm-instructions", "self-harm-intent", "sexual", "sexual-minors", "violence", "violence-graphic"]
|
| 55 |
+
|
| 56 |
+
result = {category: score for category, score in zip(category_names, predicted_scores)}
|
| 57 |
+
detected = {category: score > 0.5 for category, score in zip(category_names, predicted_scores)}
|
| 58 |
+
detect_value = any(detected.values())
|
| 59 |
+
|
| 60 |
+
return {"category_scores": result, 'detect': detected, 'detected': detect_value}
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
print('Load model')
|
| 64 |
+
moderation = ModerationModel()
|
| 65 |
+
moderation.load_state_dict(torch.load('moderation_model.pth'))
|
| 66 |
+
|
| 67 |
+
text = "I want to kill them."
|
| 68 |
+
|
| 69 |
+
embeddings_for_prediction = getEmb(text)
|
| 70 |
+
prediction = predict(moderation, embeddings_for_prediction)
|
| 71 |
+
print(json.dumps(prediction,indent=4))
|