ag-charalampous committed · verified
Commit f65d961 · Parent(s): 8eff031

Update README.md

Files changed (1): README.md (+93 −4)

README.md CHANGED
@@ -1,9 +1,98 @@
  ---
  tags:
- - model_hub_mixin
- - pytorch_model_hub_mixin
  ---

  This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
- - Library: [More Information Needed]
- - Docs: [More Information Needed]
---
tags:
- target-identification
- argumentation
- contrastive-learning
license: mit
language:
- en
base_model:
- answerdotai/ModernBERT-base
pipeline_tag: text-classification
---

This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration.

---
## Model Description
This is a dual-encoder retrieval model built on top of `answerdotai/ModernBERT-base`. Given a `claim`, it performs target identification by retrieving the most relevant `theses`, along with their associated data.

You can adjust the `top_k`, `num_args`, and `top_level_only` parameters of `retrieve_theses` to control how many theses are returned, how many arguments are kept per thesis, and whether only top-level (thesis-targeting) arguments are included.
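At its core, the retrieval step is nearest-neighbour search in embedding space: the claim and every stored thesis are encoded into vectors, and theses are ranked by cosine similarity. A minimal sketch of that ranking with hypothetical toy vectors (the real model produces ModernBERT CLS-token embeddings):

```python
import numpy as np

# Toy stand-ins for encoder outputs (hypothetical 4-dim embeddings);
# the actual model produces ModernBERT CLS embeddings.
claim_emb = np.array([[0.9, 0.1, 0.0, 0.2]])
thesis_embs = np.array([
    [0.8, 0.2, 0.1, 0.1],   # close to the claim
    [-0.5, 0.9, 0.3, 0.0],  # dissimilar
    [0.7, 0.0, 0.1, 0.3],   # also close
])

# Cosine similarity = dot product of L2-normalized vectors
def normalize(x):
    return x / np.linalg.norm(x, axis=1, keepdims=True)

sims = (normalize(claim_emb) @ normalize(thesis_embs).T)[0]

# Rank theses by similarity and keep the top_k best matches
top_k = 2
top_indices = np.argsort(sims)[::-1][:top_k]
print(top_indices)  # → [0 2]
```

`retrieve_theses` below does exactly this, using `sklearn`'s `cosine_similarity` instead of manual normalization.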
## How to use
You can use this model for inference by loading it with the `transformers` library. The following code demonstrates how to make a prediction:

```python
import pickle

import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModel, AutoTokenizer


class DualEncoderThesisModel(nn.Module, PyTorchModelHubMixin):
    def __init__(self) -> None:
        super(DualEncoderThesisModel, self).__init__()
        self.encoder = AutoModel.from_pretrained("answerdotai/ModernBERT-base")

    def forward(self, input_ids_a, attention_mask_a, input_ids_b, attention_mask_b):
        # Encode arguments (CLS-token embedding)
        output_a = self.encoder(input_ids=input_ids_a, attention_mask=attention_mask_a).last_hidden_state
        emb_a = output_a[:, 0]

        # Encode theses (CLS-token embedding)
        output_b = self.encoder(input_ids=input_ids_b, attention_mask=attention_mask_b).last_hidden_state
        emb_b = output_b[:, 0]

        return emb_a, emb_b


model_name = "ag-charalampous/target-identification"
tokenizer = AutoTokenizer.from_pretrained(model_name)

model = DualEncoderThesisModel.from_pretrained(model_name)
model.eval()

device = "cpu"

# Pre-computed thesis embeddings and their metadata, stored in the model repo
embeddings_path = hf_hub_download(
    repo_id="ag-charalampous/target-identification",
    filename="retrieval_data_random_negatives_10_train_data.pkl",
)

with open(embeddings_path, "rb") as f:
    embeddings_metadata = pickle.load(f)


@torch.no_grad()
def retrieve_theses(claim, top_k=3, num_args=5, top_level_only=True, device="cpu"):
    stored_embeddings = embeddings_metadata["embeddings"]
    metadata = embeddings_metadata["metadata"]

    # Embed the claim with the same encoder (CLS token)
    enc = tokenizer(claim, return_tensors="pt", truncation=True, padding="max_length", max_length=1024).to(device)
    query_embedding = model.encoder(**enc).last_hidden_state[:, 0].cpu().numpy()

    # Rank stored theses by cosine similarity to the claim
    sims = cosine_similarity(query_embedding, stored_embeddings)[0]
    top_indices = np.argsort(sims)[::-1][:top_k]

    results = []
    for idx in top_indices:
        arguments = metadata[idx]["arguments"]
        if top_level_only:
            # Keep only arguments that directly target the thesis
            arguments = [arg for arg in arguments if arg["target_type"] == "thesis"]

        results.append({
            "thesis": metadata[idx]["thesis"],
            "debate_title": metadata[idx]["debate_title"],
            "arguments": arguments[:num_args],
        })

    return results


claim = "A fetus or embryo is not a person; therefore, abortion should not be considered murder."

theses = retrieve_theses(claim)

for thesis in theses:
    print(f"{thesis['thesis']} | {thesis['debate_title']}")
```
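For reference, `retrieve_theses` assumes the downloaded pickle holds an `"embeddings"` matrix aligned index-for-index with a `"metadata"` list. The field set shown here is an illustrative assumption inferred from the retrieval code, not a documented schema:

```python
import numpy as np

# Hypothetical example of the structure retrieve_theses expects
# (schema inferred from the retrieval code; entries are made up).
embeddings_metadata = {
    "embeddings": np.zeros((2, 768), dtype=np.float32),  # one row per thesis
    "metadata": [
        {
            "thesis": "Example thesis A.",
            "debate_title": "Example debate",
            "arguments": [
                {"text": "Argument targeting the thesis.", "target_type": "thesis"},
                {"text": "Reply targeting another argument.", "target_type": "argument"},
            ],
        },
        {
            "thesis": "Example thesis B.",
            "debate_title": "Example debate",
            "arguments": [],
        },
    ],
}

# top_level_only=True keeps only arguments whose target is the thesis itself
args = embeddings_metadata["metadata"][0]["arguments"]
top_level = [a for a in args if a["target_type"] == "thesis"]
print(len(top_level))  # → 1
```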