Shivangsinha commited on
Commit
bd3823c
·
verified ·
1 Parent(s): b199529

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +83 -0
README.md ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+
4
+ import torch
5
+ from transformers import BertTokenizer, BertModel
6
+ import torch.nn as nn
7
+ import os
8
+
9
+ import gc
10
+ torch.cuda.empty_cache()
11
+ gc.collect()
12
+ # Set the CUDA device
13
+ os.environ["CUDA_VISIBLE_DEVICES"] = "3"
14
+
15
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+ print(f"Using device: {device}")
17
+
18
+ # Initialize BERT tokenizer
19
+ tokenizer = BertTokenizer.from_pretrained('bert-base-german-cased')
20
+
21
+ # Define the Multi-Task Model
22
+ class MultiTaskModel(nn.Module):
23
+ def __init__(self):
24
+ super(MultiTaskModel, self).__init__()
25
+ self.bert = BertModel.from_pretrained('bert-base-german-cased')
26
+ self.dropout = nn.Dropout(0.3)
27
+ self.fc_fake_news = nn.Linear(self.bert.config.hidden_size, 1)
28
+ self.fc_hate_speech = nn.Linear(self.bert.config.hidden_size, 1)
29
+ self.fc_toxicity = nn.Linear(self.bert.config.hidden_size, 1)
30
+
31
+ def forward(self, input_ids, attention_mask):
32
+ outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
33
+ pooled_output = outputs[1] # Get the pooled output
34
+ pooled_output = self.dropout(pooled_output)
35
+ fake_news_output = self.fc_fake_news(pooled_output)
36
+ hate_speech_output = self.fc_hate_speech(pooled_output)
37
+ toxicity_output = self.fc_toxicity(pooled_output)
38
+ return fake_news_output, hate_speech_output, toxicity_output
39
+
40
+ # Function to load the model
41
+ def load_model():
42
+ model = MultiTaskModel().to(device) # Initialize the model
43
+ model.load_state_dict(torch.load('../../media/data/multiTaskTWONB1/multi_task_model.pt')) # Alternatively Load the saved state from hugging face as well
44
+ model.eval() # Set the model to evaluation mode
45
+ return model
46
+
47
+ # Function to make predictions
48
+ def predict(text, model):
49
+ # Tokenize and encode the input text
50
+ encoding = tokenizer(text, return_tensors='pt', padding='max_length', truncation=True, max_length=128)
51
+
52
+ # Move input tensors to the same device as the model
53
+ input_ids = encoding['input_ids'].to(device)
54
+ attention_mask = encoding['attention_mask'].to(device)
55
+
56
+ # Make predictions
57
+ with torch.no_grad():
58
+ outputs_fake_news, outputs_hate_speech, outputs_toxicity = model(input_ids, attention_mask)
59
+
60
+ # Apply sigmoid to get probabilities and round to get binary predictions
61
+ preds_fake_news = torch.sigmoid(outputs_fake_news).squeeze().round().cpu().numpy()
62
+ preds_hate_speech = torch.sigmoid(outputs_hate_speech).squeeze().round().cpu().numpy()
63
+ preds_toxicity = torch.sigmoid(outputs_toxicity).squeeze().round().cpu().numpy()
64
+
65
+ return preds_fake_news, preds_hate_speech, preds_toxicity
66
+
67
+ # Load the model
68
+ model = load_model()
69
+
70
+ # Example text input for prediction
71
+ text_input = "Mir fallen nur Steuervorteile durch Gender Pay gap ein."
72
+
73
+ # Make predictions
74
+ predictions = predict(text_input, model)
75
+
76
+ # Print the predictions
77
+ print(f"Fake News Prediction: {predictions[0]}")
78
+ print(f"Hate Speech Prediction: {predictions[1]}")
79
+ print(f"Toxicity Prediction: {predictions[2]}")
80
+
81
+
82
+ torch.cuda.empty_cache()
83
+ gc.collect()