wizardoftrap commited on
Commit
a3ad589
·
verified ·
1 Parent(s): af5a4a6

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +88 -3
README.md CHANGED
@@ -1,7 +1,92 @@
 
 
 
 
 
 
 
 
 
1
  # Language-Agnostic Text Classifier
2
 
3
- Trained only on **English** data.
4
- Works on **Hindi** at inference time without retraining.
5
 
6
  **Task:** Sentence-level sentiment classification
7
- **Base model:** bert-base-multilingual-cased
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ datasets:
3
+ - stanfordnlp/imdb
4
+ language:
5
+ - en
6
+ - hi
7
+ base_model:
8
+ - google-bert/bert-base-multilingual-cased
9
+ ---
10
  # Language-Agnostic Text Classifier
11
 
12
+ Trained only on **English** data <br>
13
+ Works on both **English** and **Hindi** at inference time without retraining *(Other langauges not tested)*
14
 
15
  **Task:** Sentence-level sentiment classification
16
+ **Base model:** bert-base-multilingual-cased <br>
17
+ **For more details:** *[Github Repo](https://github.com/wizardoftrap/language_agnostic_classifier)*
18
+ ## Usage
19
+
20
+ ```python
21
+ import torch
22
+ import torch.nn as nn
23
+ from transformers import AutoTokenizer, AutoModel
24
+
25
+ class LanguageAgnosticClassifier(nn.Module):
26
+ def __init__(self, base_model, num_labels):
27
+ super().__init__()
28
+ self.encoder = AutoModel.from_pretrained(base_model)
29
+ hidden = self.encoder.config.hidden_size
30
+ self.classifier = nn.Linear(hidden, num_labels)
31
+
32
+ def mean_pool(self, hidden, mask):
33
+ mask = mask.unsqueeze(-1).float()
34
+ return (hidden * mask).sum(1) / mask.sum(1)
35
+
36
+ def forward(self, input_ids, attention_mask):
37
+ out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
38
+ pooled = self.mean_pool(out.last_hidden_state, attention_mask)
39
+ return self.classifier(pooled)
40
+
41
+ tokenizer = AutoTokenizer.from_pretrained(
42
+ "wizardoftrap/language_agnostic_classifier"
43
+ )
44
+
45
+ model = LanguageAgnosticClassifier(
46
+ base_model="bert-base-multilingual-cased",
47
+ num_labels=2
48
+ )
49
+
50
+ state_dict = torch.hub.load_state_dict_from_url(
51
+ "https://huggingface.co/wizardoftrap/language_agnostic_classifier/resolve/main/bert-language_agnostic-classifier.bin",
52
+ map_location="cpu"
53
+ )
54
+
55
+ model.load_state_dict(state_dict)
56
+ model.eval()
57
+
58
+ def predict(text):
59
+ enc = tokenizer(
60
+ text,
61
+ return_tensors="pt",
62
+ truncation=True,
63
+ padding="max_length",
64
+ max_length=128
65
+ )
66
+ with torch.no_grad():
67
+ logits = model(enc["input_ids"], enc["attention_mask"])
68
+ return logits.argmax(1).item()
69
+
70
+ predict("This movie was amazing")
71
+ predict("This movie was terrible")
72
+ predict("The film was not bad, but not great either")
73
+ predict("Despite good acting, the story failed to impress me")
74
+
75
+ predict("यह फिल्म बहुत शानदार थी")
76
+ predict("यह फिल्म बहुत खराब थी")
77
+ predict("फिल्म बुरी नहीं थी, लेकिन खास भी नहीं लगी")
78
+ predict("अभिनय अच्छा था, पर कहानी कमजोर रह गई")
79
+
80
+ predict("Story अच्छी थी but execution weak था")
81
+ predict("Acting was good लेकिन movie boring लगी")
82
+ predict("Concept अच्छा था but screenplay खराब था")
83
+
84
+ predict("Yeah, this movie was a masterpiece… said no one ever")
85
+ predict("फिल्म इतनी अच्छी थी कि नींद आ गई")
86
+
87
+ predict("The movie was okay")
88
+ predict("फिल्म ठीक-ठाक थी")
89
+
90
+ ```
91
+
92
+ *- Shiv Prakash Verma*