philhum commited on
Commit
ee4594d
·
1 Parent(s): a5eaf46

Fix forward() to return all multitask head outputs

Browse files
Files changed (1) hide show
  1. app.py +9 -1
app.py CHANGED
@@ -88,7 +88,15 @@ class SpecialistDistilBERT(nn.Module):
88
  pooled = outputs.last_hidden_state[:, 0, :]
89
  pooled = self.dropout(pooled)
90
  shared = self.shared(pooled)
91
- return {"doc_logits": self.doc_classifier(shared)}
 
 
 
 
 
 
 
 
92
 
93
 
94
  @asynccontextmanager
 
88
  pooled = outputs.last_hidden_state[:, 0, :]
89
  pooled = self.dropout(pooled)
90
  shared = self.shared(pooled)
91
+ return {
92
+ "doc_logits": self.doc_classifier(shared),
93
+ "expense_logits": self.expense_classifier(shared),
94
+ "personal_logits": self.personal_classifier(shared),
95
+ "business_logits": self.business_classifier(shared),
96
+ "auto_renew_logits": self.auto_renew_classifier(shared),
97
+ "renewal_logits": self.renewal_period_classifier(shared),
98
+ "cancellation_logits": self.cancellation_classifier(shared),
99
+ }
100
 
101
 
102
  @asynccontextmanager