Spaces:
Sleeping
Sleeping
Ng Huong Duyen commited on
Commit ·
0c036ef
1
Parent(s): ba26ab7
update dropout to 0.1 for all task heads
Browse files- handler.py +4 -4
handler.py
CHANGED
|
@@ -20,10 +20,10 @@ class EndpointHandler:
|
|
| 20 |
self.model.eval()
|
| 21 |
|
| 22 |
self.task_heads = torch.nn.ModuleDict({
|
| 23 |
-
'sentiment': TaskClassificationHead(self.model.config.hidden_size, 4, 0.
|
| 24 |
-
'topic': TaskClassificationHead(self.model.config.hidden_size, 10, 0.
|
| 25 |
-
'hate_speech': TaskClassificationHead(self.model.config.hidden_size, 5, 0.
|
| 26 |
-
'clickbait': TaskClassificationHead(self.model.config.hidden_size, 2, 0.
|
| 27 |
})
|
| 28 |
self.log_vars = torch.nn.ParameterDict({
|
| 29 |
task: torch.nn.Parameter(torch.zeros(1)) for task in self.task_heads
|
|
|
|
| 20 |
self.model.eval()
|
| 21 |
|
| 22 |
self.task_heads = torch.nn.ModuleDict({
|
| 23 |
+
'sentiment': TaskClassificationHead(self.model.config.hidden_size, 4, 0.1),
|
| 24 |
+
'topic': TaskClassificationHead(self.model.config.hidden_size, 10, 0.1),
|
| 25 |
+
'hate_speech': TaskClassificationHead(self.model.config.hidden_size, 5, 0.1),
|
| 26 |
+
'clickbait': TaskClassificationHead(self.model.config.hidden_size, 2, 0.1)
|
| 27 |
})
|
| 28 |
self.log_vars = torch.nn.ParameterDict({
|
| 29 |
task: torch.nn.Parameter(torch.zeros(1)) for task in self.task_heads
|