|
|
--- |
|
|
datasets: |
|
|
- stanfordnlp/imdb |
|
|
language: |
|
|
- en |
|
|
- hi |
|
|
base_model: |
|
|
- google-bert/bert-base-multilingual-cased |
|
|
--- |
|
|
# Language-Agnostic Text Classifier |
|
|
|
|
|
Trained only on **English** data <br> |
|
|
Works on both **English** and **Hindi** at inference time without retraining *(Other languages not tested)*
|
|
|
|
|
**Task:** Sentence-level sentiment classification |
|
|
**Base model:** bert-base-multilingual-cased <br> |
|
|
**For more details:** *[Github Repo](https://github.com/wizardoftrap/language_agnostic_classifier)* |
|
|
## Usage |
|
|
|
|
|
```python |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from transformers import AutoTokenizer, AutoModel |
|
|
|
|
|
class LanguageAgnosticClassifier(nn.Module):
    """Sentence classifier: a multilingual BERT encoder plus a linear head.

    Final hidden states are mean-pooled over non-padding tokens and the
    pooled vector is mapped to ``num_labels`` logits.
    """

    def __init__(self, base_model, num_labels):
        """
        Args:
            base_model: Hugging Face model id or local path for the encoder.
            num_labels: number of output classes.
        """
        super().__init__()
        self.encoder = AutoModel.from_pretrained(base_model)
        hidden = self.encoder.config.hidden_size
        self.classifier = nn.Linear(hidden, num_labels)

    def mean_pool(self, hidden, mask):
        """Average token embeddings, counting only positions where mask == 1.

        Args:
            hidden: (batch, seq, dim) token embeddings.
            mask: (batch, seq) attention mask (1 = real token, 0 = padding).

        Returns:
            (batch, dim) mean-pooled sentence embeddings.
        """
        mask = mask.unsqueeze(-1).float()
        # Clamp the token count so an all-padding (all-zero) mask yields
        # zeros instead of a divide-by-zero NaN.
        return (hidden * mask).sum(1) / mask.sum(1).clamp(min=1e-9)

    def forward(self, input_ids, attention_mask):
        """Return (batch, num_labels) classification logits."""
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.mean_pool(out.last_hidden_state, attention_mask)
        return self.classifier(pooled)
|
|
|
|
|
# Tokenizer comes from the fine-tuned repo; the encoder is first rebuilt
# from the multilingual BERT base and its weights overwritten below.
tokenizer = AutoTokenizer.from_pretrained("wizardoftrap/language_agnostic_classifier")

model = LanguageAgnosticClassifier(base_model="bert-base-multilingual-cased", num_labels=2)

# Fetch the fine-tuned checkpoint and load it on CPU.
state_dict = torch.hub.load_state_dict_from_url(
    "https://huggingface.co/wizardoftrap/language_agnostic_classifier/resolve/main/bert-language_agnostic-classifier.bin",
    map_location="cpu",
)
model.load_state_dict(state_dict)
model.eval()  # inference mode: disables dropout
|
|
|
|
|
def predict(text):
    """Classify ``text`` and return the predicted label index as an int."""
    batch = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=128,
    )
    with torch.no_grad():
        scores = model(batch["input_ids"], batch["attention_mask"])
    return scores.argmax(1).item()
|
|
|
|
|
predict("This movie was amazing") |
|
|
predict("This movie was terrible") |
|
|
predict("The film was not bad, but not great either") |
|
|
predict("Despite good acting, the story failed to impress me") |
|
|
|
|
|
predict("यह फिल्म बहुत शानदार थी") |
|
|
predict("यह फिल्म बहुत खराब थी") |
|
|
predict("फिल्म बुरी नहीं थी, लेकिन खास भी नहीं लगी") |
|
|
predict("अभिनय अच्छा था, पर कहानी कमजोर रह गई") |
|
|
|
|
|
predict("Story अच्छी थी but execution weak था") |
|
|
predict("Acting was good लेकिन movie boring लगी") |
|
|
predict("Concept अच्छा था but screenplay खराब था") |
|
|
|
|
|
predict("Yeah, this movie was a masterpiece… said no one ever") |
|
|
predict("फिल्म इतनी अच्छी थी कि नींद आ गई") |
|
|
|
|
|
predict("The movie was okay") |
|
|
predict("फिल्म ठीक-ठाक थी") |
|
|
|
|
|
``` |
|
|
|
|
|
*- Shiv Prakash Verma* |