Spaces:
Runtime error
Runtime error
Update model.py
Browse files
model.py
CHANGED
|
@@ -335,8 +335,8 @@ class MultiHeadClassification(nn.Module):
|
|
| 335 |
dropout (float): Dropout rate
|
| 336 |
l2_reg (float): L2 regularization rate
|
| 337 |
"""
|
| 338 |
-
backbone = AutoModel.from_pretrained(os.path.join(model_path, 'pretrained/backbone
|
| 339 |
instance = cls(backbone, head_config, dropout, l2_reg)
|
| 340 |
-
instance.load(os.path.join(model_path, 'pretrained/model.pth'))
|
| 341 |
instance.head_config = {k: v. instance.heads}
|
| 342 |
return instance
|
|
|
|
| 335 |
dropout (float): Dropout rate
|
| 336 |
l2_reg (float): L2 regularization rate
|
| 337 |
"""
|
| 338 |
+
backbone = AutoModel.from_pretrained(os.path.join(model_path, 'pretrained/backbone'))
|
| 339 |
instance = cls(backbone, head_config, dropout, l2_reg)
|
| 340 |
+
instance.load(os.path.join(model_path, 'pretrained/multi-head-sequence-classification-model-model.pth'))
|
| 341 |
instance.head_config = {k: v. instance.heads}
|
| 342 |
return instance
|