---
license: cc-by-nc-sa-4.0
datasets:
- Blablablab/ALOE
---
| | |
### Model Description

The model classifies whether two *appraisals* are aligned or not, and is trained on the [ALOE](https://huggingface.co/datasets/Blablablab/ALOE) dataset.
| |
|
**Input:** two appraisals (see the `forward` function in the `SNN` class)

**Output:** cosine similarity between the two appraisal embeddings

**Model architecture:** Siamese Network + `all-mpnet-base-v2`

**Developed by:** Jiamin Yang
| |
|
### Model Performance

| F1   | Recall | Precision |
| :--: | :----: | :-------: |
| 0.46 | 0.45   | 0.46      |
| |
|
### Getting Started

```python
| | import torch |
| | from torch import nn |
| | from transformers import AutoModel, AutoTokenizer |
| | |
class SNN(nn.Module):
    """Siamese network scoring the alignment between two appraisals.

    Both inputs are encoded by one shared pretrained transformer
    (e.g. ``all-mpnet-base-v2``), mean-pooled over non-padding tokens,
    and compared with cosine similarity.
    """

    def __init__(self, model_name, device=None):
        """Load the shared sentence encoder.

        Args:
            model_name: Hugging Face model id or local path of the encoder.
            device: torch device for the encoder. Defaults to CUDA when
                available, otherwise CPU (the original card hard-coded
                "cuda", which crashes on CPU-only machines).
        """
        super().__init__()
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        # NOTE: kept in train mode as in the original card; callers should
        # invoke .eval() before inference (the usage example does).
        self.model = AutoModel.from_pretrained(model_name).to(device).train()
        self.cos = torch.nn.CosineSimilarity(dim=1, eps=1e-4)

    def mean_pooling(self, token_embeddings, attention_mask):
        """Average token embeddings while ignoring padding positions.

        Args:
            token_embeddings: (batch, seq_len, hidden) token states.
            attention_mask: (batch, seq_len) mask — 1 for real tokens, 0 for pad.

        Returns:
            (batch, hidden) tensor of masked mean embeddings.
        """
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        # clamp guards against division by zero for all-padding rows
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def forward(self, input_ids_a, attention_a, input_ids_b, attention_b):
        """Return cosine similarity between the two encoded appraisals."""
        # [0] selects the last hidden state: all token embeddings
        encoding1 = self.model(input_ids_a, attention_mask=attention_a)[0]
        encoding2 = self.model(input_ids_b, attention_mask=attention_b)[0]

        meanPooled1 = self.mean_pooling(encoding1, attention_a)
        meanPooled2 = self.mean_pooling(encoding2, attention_b)

        pred = self.cos(meanPooled1, meanPooled2)
        return pred
| | |
# Pick the device once and reuse it for the model and all input tensors.
device = "cuda" if torch.cuda.is_available() else "cpu"

checkpoint_path = 'your_path_to/empathy-appraisal-alignment.pt'

tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
model = SNN('sentence-transformers/all-mpnet-base-v2').to(device)

# map_location lets a GPU-saved checkpoint load on CPU-only machines.
checkpoint = torch.load(checkpoint_path, map_location=device)
state_dict = checkpoint['model_state_dict']

# Some torch/transformers versions register 'position_ids' as a buffer and
# some do not; pop() (instead of del) tolerates both.
state_dict.pop('model.embeddings.position_ids', None)

model.load_state_dict(state_dict)

# Use the model: score alignment between a target's and an observer's appraisal.
target = ["I'm so sad that my cat died yesterday."]
observer = ["It's ok to feel sad."]

target_encodings = tokenizer(target, padding=True, truncation=True)
target_input_ids = torch.LongTensor(target_encodings['input_ids']).to(device)
target_attention_mask = torch.LongTensor(target_encodings['attention_mask']).to(device)
observer_encodings = tokenizer(observer, padding=True, truncation=True)
observer_input_ids = torch.LongTensor(observer_encodings['input_ids']).to(device)
observer_attention_mask = torch.LongTensor(observer_encodings['attention_mask']).to(device)

model.eval()
with torch.no_grad():  # inference only — no gradients needed
    output = model(target_input_ids, target_attention_mask, observer_input_ids, observer_attention_mask)
print(output)  # tensor([0.5755])
```
| |
|