Update README.md
Browse files
README.md
CHANGED
|
@@ -61,11 +61,36 @@ classification of news categories politics, society and conflicts.
|
|
| 61 |
Example of how to use the model:
|
| 62 |
|
| 63 |
```python
|
|
|
|
|
|
|
| 64 |
import torch
|
| 65 |
from transformers import AutoTokenizer
|
| 66 |
from huggingface_hub import hf_hub_download
|
| 67 |
|
| 68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
# News category names, listed alphabetically. Presumably position i matches
# output index i of the classifier -- confirm against the training labels.
categories = [
    'climate',
    'conflicts',
    'culture',
    'economy',
    'gloss',
    'health',
    'politics',
    'science',
    'society',
    'sports',
    'travel',
]
|
| 71 |
|
|
|
|
| 61 |
Example of how to use the model:
|
| 62 |
|
| 63 |
```python
|
| 64 |
+
import torch.nn as nn
|
| 65 |
+
from transformers import BertModel
|
| 66 |
import torch
|
| 67 |
from transformers import AutoTokenizer
|
| 68 |
from huggingface_hub import hf_hub_download
|
| 69 |
|
| 70 |
|
| 71 |
+
class BiLSTMClassifier(nn.Module):
    """Sequence classifier: multilingual BERT embeddings -> BiLSTM -> linear head.

    Token-level outputs of ``bert-base-multilingual-cased`` are passed through
    a bidirectional LSTM, averaged over the sequence dimension and projected
    to ``output_dim`` logits.
    """

    def __init__(self, hidden_dim, output_dim, n_layers, dropout):
        """Build the BERT encoder, the BiLSTM and the classification head.

        Args:
            hidden_dim: Hidden size of each LSTM direction.
            output_dim: Size of the logits vector (number of classes).
            n_layers: Number of stacked LSTM layers.
            dropout: Dropout probability applied before the linear head and,
                when ``n_layers > 1``, between the stacked LSTM layers.
        """
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-multilingual-cased")
        # nn.LSTM only applies dropout between stacked layers; passing a
        # non-zero value with a single layer has no effect and triggers a
        # UserWarning, so it is disabled in that case.
        lstm_dropout = dropout if n_layers > 1 else 0.0
        self.lstm = nn.LSTM(
            self.bert.config.hidden_size,
            hidden_dim,
            num_layers=n_layers,
            bidirectional=True,
            dropout=lstm_dropout,
            batch_first=True,
        )
        # The forward and backward LSTM outputs are concatenated, hence
        # hidden_dim * 2 input features for the head.
        self.fc = nn.Linear(hidden_dim * 2, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, attention_mask, labels=None):
        """Compute class logits and, when ``labels`` is given, the loss.

        Args:
            input_ids: Token ids forwarded to BERT (presumably shaped
                (batch, seq_len) -- confirm against the tokenizer call).
            attention_mask: Attention mask forwarded to BERT with the ids.
            labels: Optional targets for ``nn.CrossEntropyLoss``.

        Returns:
            A dict ``{"loss": ..., "logits": ...}`` when ``labels`` is not
            ``None``; otherwise the logits tensor alone.
        """
        # BERT runs without gradient tracking, so no gradients flow into the
        # encoder from this forward pass.
        with torch.no_grad():
            embedded = self.bert(input_ids=input_ids, attention_mask=attention_mask)[0]
        lstm_out, _ = self.lstm(embedded)
        # Mean over dim=1 averages every sequence position; attention_mask is
        # not applied here, so any padded positions are included as well.
        pooled = torch.mean(lstm_out, dim=1)
        logits = self.fc(self.dropout(pooled))

        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
            return {"loss": loss, "logits": logits}  # return a dict
        return logits  # return bare logits when no labels were passed
|
| 92 |
+
|
| 93 |
+
|
| 94 |
# News category names, listed alphabetically. Presumably position i matches
# output index i of the classifier -- confirm against the training labels.
categories = [
    'climate',
    'conflicts',
    'culture',
    'economy',
    'gloss',
    'health',
    'politics',
    'science',
    'society',
    'sports',
    'travel',
]
|
| 96 |
|