Ng Huong Duyen commited on
Commit
0c036ef
·
1 Parent(s): ba26ab7

update dropout to 0.1 for all task heads

Browse files
Files changed (1) hide show
  1. handler.py +4 -4
handler.py CHANGED
@@ -20,10 +20,10 @@ class EndpointHandler:
20
  self.model.eval()
21
 
22
  self.task_heads = torch.nn.ModuleDict({
23
- 'sentiment': TaskClassificationHead(self.model.config.hidden_size, 4, 0.2),
24
- 'topic': TaskClassificationHead(self.model.config.hidden_size, 10, 0.2),
25
- 'hate_speech': TaskClassificationHead(self.model.config.hidden_size, 5, 0.2),
26
- 'clickbait': TaskClassificationHead(self.model.config.hidden_size, 2, 0.2)
27
  })
28
  self.log_vars = torch.nn.ParameterDict({
29
  task: torch.nn.Parameter(torch.zeros(1)) for task in self.task_heads
 
20
  self.model.eval()
21
 
22
  self.task_heads = torch.nn.ModuleDict({
23
+ 'sentiment': TaskClassificationHead(self.model.config.hidden_size, 4, 0.1),
24
+ 'topic': TaskClassificationHead(self.model.config.hidden_size, 10, 0.1),
25
+ 'hate_speech': TaskClassificationHead(self.model.config.hidden_size, 5, 0.1),
26
+ 'clickbait': TaskClassificationHead(self.model.config.hidden_size, 2, 0.1)
27
  })
28
  self.log_vars = torch.nn.ParameterDict({
29
  task: torch.nn.Parameter(torch.zeros(1)) for task in self.task_heads