---
license: mit
pipeline_tag: text-classification
tags:
- argument-detection
- stance-detection
- multi-task-learning
language:
- en
base_model:
- answerdotai/ModernBERT-large
---

This model has been pushed to the Hub using the [PyTorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration.

---
## Model Description
This is a multi-task learning (MTL) model built on top of `answerdotai/ModernBERT-large`. It performs two distinct text classification tasks over a shared feature representation, enhanced by a Mixture-of-Experts (MoE) layer.

The model can be used for:
1. **Argumentativeness classification:** labeling a text as "Argumentative" or "Non-argumentative."
2. **Stance classification:** labeling the relationship between two claims as "Same-side" or "Opposing-side."

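The MoE layer routes each input through only its top-k experts: a gating network scores all experts with a softmax, and the k highest-scoring experts each process the input, with their outputs combined by the gate weights. The routing step can be sketched in isolation (the dimensions below are illustrative toy values, not the model's configuration; the actual model uses 8 experts with `top_k=2` over the encoder's hidden size):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, dim, num_experts, top_k = 4, 16, 8, 2  # toy sizes for illustration

x = torch.randn(batch, dim)                   # one feature vector per input
gating = torch.nn.Linear(dim, num_experts)    # gating network scores every expert

gate_probs = F.softmax(gating(x), dim=-1)     # (batch, num_experts), rows sum to 1
topk_vals, topk_idx = torch.topk(gate_probs, top_k, dim=-1)

# topk_idx[b] names the experts that process input b; topk_vals[b] are
# their mixing weights (a subset of a softmax, so they sum to at most 1).
print(topk_idx.shape)  # torch.Size([4, 2])
```

Only the selected experts run per input, so capacity grows with the number of experts while the per-input compute stays close to that of a single feed-forward block.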
## How to use
Load the model with the `transformers` library and `PyTorchModelHubMixin`. The following code shows how to run a prediction for each task:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from huggingface_hub import PyTorchModelHubMixin


class MoELayer(nn.Module):
    def __init__(self, input_dim, num_experts, top_k=2):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k

        # Each expert is an independent feed-forward network.
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, input_dim * 2),
                nn.ReLU(),
                nn.Linear(input_dim * 2, input_dim),
            )
            for _ in range(num_experts)
        ])

        self.gating_network = nn.Linear(input_dim, num_experts)

    def forward(self, x):
        gate_logits = self.gating_network(x)
        gate_probs = F.softmax(gate_logits, dim=-1)

        # Select the top-k experts for each input.
        topk_vals, topk_indices = torch.topk(gate_probs, self.top_k, dim=-1)

        # Accumulate the weighted contributions of the selected experts.
        output = torch.zeros_like(x)
        for i in range(self.top_k):
            expert_idx = topk_indices[:, i]
            expert_weight = topk_vals[:, i].unsqueeze(-1)

            expert_outputs = torch.stack(
                [self.experts[int(j)](x[b]) for b, j in enumerate(expert_idx)],
                dim=0,
            )
            output += expert_weight * expert_outputs

        return output


class SentenceClassificationMoeMTLModel(nn.Module, PyTorchModelHubMixin):
    def __init__(self) -> None:
        super().__init__()
        self.base_model = AutoModel.from_pretrained("answerdotai/ModernBERT-large")

        self.moe_layer = MoELayer(
            input_dim=self.base_model.config.hidden_size, num_experts=8, top_k=2
        )

        self.task_1_classifier = nn.Sequential(
            nn.Linear(self.base_model.config.hidden_size, 768, bias=False),
            nn.GELU(),
            nn.LayerNorm(768, eps=1e-05, elementwise_affine=True),
            nn.Linear(768, 2),
        )

        self.task_2_classifier = nn.Sequential(
            nn.Linear(self.base_model.config.hidden_size, 768, bias=False),
            nn.GELU(),
            nn.LayerNorm(768, eps=1e-05, elementwise_affine=True),
            nn.Linear(768, 2),
        )

    def forward(self, task, input_ids, attention_mask):
        x = self.base_model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        cls_r = x[:, 0]  # [CLS] token representation

        x = self.moe_layer(cls_r)

        if task == "arg":
            x = self.task_1_classifier(x)
        elif task == "stance":
            x = self.task_2_classifier(x)

        return x, cls_r


model_name = "ag-charalampous/argument-same-side-stance-classification"
tokenizer = AutoTokenizer.from_pretrained(model_name)

model = SentenceClassificationMoeMTLModel.from_pretrained(model_name)
model.eval()

device = "cpu"
model.to(device)


def classify_sequence(seq, task, label_map):
    # The stance task takes a pair of claims; the argument task a single text.
    enc = tokenizer(
        *(seq if task == "stance" else (seq,)),
        return_tensors="pt",
        truncation=True,
        max_length=1024,
    ).to(device)

    with torch.no_grad():
        logits, _ = model(task=task, **enc)
    probs = torch.softmax(logits, dim=-1).squeeze()
    pred_idx = probs.argmax().item()
    confidence = probs[pred_idx].item()

    return label_map[pred_idx], confidence


# Example input for task 1 (argumentativeness)
text = "A fetus or embryo is not a person; therefore, abortion should not be considered murder."

label_map = {0: "Non-argumentative", 1: "Argumentative"}
label, confidence = classify_sequence(text, "arg", label_map)

print(f"Prediction: {label} (Confidence: {confidence:.2f})")

# Example input for task 2 (same-side stance)
claim_1 = "A fetus or embryo is not a person; therefore, abortion should not be considered murder."
claim_2 = "Since death is the intention, such procedures should be considered murder."

label_map = {0: "Same-side", 1: "Opposing-side"}
label, confidence = classify_sequence([claim_1, claim_2], "stance", label_map)

print(f"Prediction: {label} (Confidence: {confidence:.2f})")
```