Update README.md
Browse filesimport torch
from torch import nn
# Definir la red, lo que en definitiva la hace siamesa es pasar los dos dic
class SBERT(nn.Module):
def __init__(self, base_model, dropout=0.1):
super().__init__()
self.base_model = base_model
self.dropout = nn.Dropout(dropout)
# Recordamos que la salida de la Bert es 768
#self.fc = nn.Linear(768, 3) #Cambio 13/6
self.fc = nn.Linear(768*3, 3)
def forward(self, premise, hypothesis):
out_u = self.base_model(**premise)
out_v = self.base_model(**hypothesis)
pooler_u = out_u.pooler_output
pooler_v = out_v.pooler_output
pooler_u = self.dropout(pooler_u)
pooler_v = self.dropout(pooler_v)
#concatenated = torch.cat([self.fc(pooler_u), self.fc(pooler_v), torch.abs(self.fc(pooler_u) - self.fc(pooler_v))], dim=0)
concatenated = torch.cat([pooler_u, pooler_v, torch.abs(pooler_u -pooler_v)], dim=1)
out=self.fc(concatenated)
return out
|
@@ -8,26 +8,3 @@ pipeline_tag: fill-mask
|
|
| 8 |
tags:
|
| 9 |
- code
|
| 10 |
---
|
| 11 |
-
import torch
|
| 12 |
-
from torch import nn
|
| 13 |
-
# Definir la red, lo que en definitiva la hace siamesa es pasar los dos dic
|
| 14 |
-
class SBERT(nn.Module):
|
| 15 |
-
def __init__(self, base_model, dropout=0.1):
|
| 16 |
-
super().__init__()
|
| 17 |
-
self.base_model = base_model
|
| 18 |
-
self.dropout = nn.Dropout(dropout)
|
| 19 |
-
# Recordamos que la salida de la Bert es 768
|
| 20 |
-
#self.fc = nn.Linear(768, 3) #Cambio 13/6
|
| 21 |
-
self.fc = nn.Linear(768*3, 3)
|
| 22 |
-
|
| 23 |
-
def forward(self, premise, hypothesis):
|
| 24 |
-
out_u = self.base_model(**premise)
|
| 25 |
-
out_v = self.base_model(**hypothesis)
|
| 26 |
-
pooler_u = out_u.pooler_output
|
| 27 |
-
pooler_v = out_v.pooler_output
|
| 28 |
-
pooler_u = self.dropout(pooler_u)
|
| 29 |
-
pooler_v = self.dropout(pooler_v)
|
| 30 |
-
#concatenated = torch.cat([self.fc(pooler_u), self.fc(pooler_v), torch.abs(self.fc(pooler_u) - self.fc(pooler_v))], dim=0)
|
| 31 |
-
concatenated = torch.cat([pooler_u, pooler_v, torch.abs(pooler_u -pooler_v)], dim=1)
|
| 32 |
-
out=self.fc(concatenated)
|
| 33 |
-
return out
|
|
|
|
| 8 |
tags:
|
| 9 |
- code
|
| 10 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|