| import torch |
| import torch.nn as nn |
| import lightning as L |
| import torchmetrics as tm |
| from tokenizers import Tokenizer |
| import gradio as gr |
| from huggingface_hub import hf_hub_download |
|
|
# Coarse-grained TREC question categories. List position is the class id the
# model predicts, and each display string embeds that same id,
# e.g. "HUM (3): Human being".
COARSE_LABELS = [
    f"{tag} ({class_id}): {description}"
    for class_id, (tag, description) in enumerate(
        (
            ("ABBR", "Abbreviation"),
            ("ENTY", "Entity"),
            ("DESC", "Description and abstract concept"),
            ("HUM", "Human being"),
            ("LOC", "Location"),
            ("NUM", "Numeric value"),
        )
    )
]
|
|
# Fine-grained TREC question categories, grouped under their coarse tag.
# List position is the class id the model predicts, and each display string
# embeds that same id, e.g. "LOC (32): City".
FINE_LABELS = [
    f"{tag} ({class_id}): {description}"
    for class_id, (tag, description) in enumerate(
        (
            ("ABBR", "Abbreviation"),
            ("ABBR", "Expression abbreviated"),
            ("ENTY", "Animal"),
            ("ENTY", "Organ of body"),
            ("ENTY", "Color"),
            ("ENTY", "Invention, book and other creative piece"),
            ("ENTY", "Currency name"),
            ("ENTY", "Disease and medicine"),
            ("ENTY", "Event"),
            ("ENTY", "Food"),
            ("ENTY", "Musical instrument"),
            ("ENTY", "Language"),
            ("ENTY", "Letter like a-z"),
            ("ENTY", "Other entity"),
            ("ENTY", "Plant"),
            ("ENTY", "Product"),
            ("ENTY", "Religion"),
            ("ENTY", "Sport"),
            ("ENTY", "Element and substance"),
            ("ENTY", "Symbols and sign"),
            ("ENTY", "Techniques and method"),
            ("ENTY", "Equivalent term"),
            ("ENTY", "Vehicle"),
            ("ENTY", "Word with a special property"),
            ("DESC", "Definition of something"),
            ("DESC", "Description of something"),
            ("DESC", "Manner of an action"),
            ("DESC", "Reason"),
            ("HUM", "Group or organization of persons"),
            ("HUM", "Individual"),
            ("HUM", "Title of a person"),
            ("HUM", "Description of a person"),
            ("LOC", "City"),
            ("LOC", "Country"),
            ("LOC", "Mountain"),
            ("LOC", "Other location"),
            ("LOC", "State"),
            ("NUM", "Postcode or other code"),
            ("NUM", "Number of something"),
            ("NUM", "Date"),
            ("NUM", "Distance, linear measure"),
            ("NUM", "Price"),
            ("NUM", "Order, rank"),
            ("NUM", "Other number"),
            ("NUM", "Lasting time of something"),
            ("NUM", "Percent, fraction"),
            ("NUM", "Speed"),
            ("NUM", "Temperature"),
            ("NUM", "Size, area and volume"),
            ("NUM", "Weight"),
        )
    )
]
|
|
|
|
class Classifier:
    """Inference wrapper pairing a trained tokenizer with the LSTM+attention model.

    Args:
        tokenizer_ckpt_path: Path to a serialized ``tokenizers`` JSON file.
        model_ckpt_path: Path to a Lightning checkpoint of
            ``LSTMWithAttentionClassifier``; it is loaded onto the CPU.
    """

    def __init__(self, tokenizer_ckpt_path, model_ckpt_path):
        self.tokenizer = Tokenizer.from_file(tokenizer_ckpt_path)
        self.model = LSTMWithAttentionClassifier.load_from_checkpoint(
            model_ckpt_path,
            map_location="cpu",
        )
        # A freshly constructed nn.Module is in training mode. Switch to eval so
        # any configured LSTM dropout is disabled and predictions are
        # deterministic.
        self.model.eval()

    def predict(self, text):
        """Classify one question and return a ``{label: probability}`` dict.

        The label set is FINE_LABELS when the loaded model was trained on fine
        labels (``model.fine``), otherwise COARSE_LABELS. Labels are paired
        with the softmax probabilities in class-id order.

        Raises:
            ValueError: if the tokenizer produces no token ids for ``text``
                (e.g. an empty question), since there is nothing to classify.
        """
        encoding = self.tokenizer.encode(text)
        if not encoding.ids:
            raise ValueError("Input text produced no tokens; please enter a question.")
        ids = torch.tensor([encoding.ids])  # a batch containing one sequence
        # Pure inference: no autograd graph is needed.
        with torch.inference_mode():
            logits, _ = self.model(ids)
        # Take the single row explicitly rather than squeeze(), which would also
        # collapse the class dimension if it happened to have size 1.
        probs = torch.softmax(logits[0], dim=-1).tolist()
        labels = FINE_LABELS if self.model.fine else COARSE_LABELS
        return dict(zip(labels, probs))
|
|
|
|
class Attention(nn.Module):
    """Additive self-attention pooling over dimension 1 of the input.

    Each position is scored as ``WValue(tanh(WQuery(x)) + tanh(WKey(x)))``.
    The scores are softmax-normalised along dim 1 and used to take a weighted
    sum of ``x`` along that dimension.

    Args:
        hidden_dim: Size of the last dimension of ``x``.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        # These attribute names are state_dict keys in saved checkpoints; keep them.
        self.WQuery = nn.Linear(hidden_dim, hidden_dim)
        self.WKey = nn.Linear(hidden_dim, hidden_dim)
        # Projects each position to a single unnormalised attention score.
        self.WValue = nn.Linear(hidden_dim, 1)

    def forward(self, x, mask=None):
        """Pool ``x`` over dim 1 using learned attention weights.

        Args:
            x: Tensor whose last dim is ``hidden_dim``. Assumed to be
                (batch, seq, hidden_dim), as produced by a batch_first LSTM.
            mask: Optional (batch, seq) tensor. Truthy entries mark positions
                that may be attended to; the rest get zero weight. ``None``
                (the default) attends to every position. A row with no truthy
                entries yields NaN weights.

        Returns:
            ``(pooled, attention_weights)``: ``pooled`` is (batch, hidden_dim)
            and ``attention_weights`` is (batch, seq, 1).
        """
        query = torch.tanh(self.WQuery(x))
        key = torch.tanh(self.WKey(x))
        scores = self.WValue(query + key)
        if mask is not None:
            # -inf becomes exactly 0 after softmax, removing e.g. pad positions.
            scores = scores.masked_fill(~mask.bool().unsqueeze(-1), float("-inf"))
        attention_weights = torch.softmax(scores, dim=1)
        return (attention_weights * x).sum(dim=1), attention_weights
|
|
|
|
class LSTMWithAttentionClassifier(L.LightningModule):
    """Question classifier: embedding -> LSTM -> attention pooling -> linear.

    The ``fine`` flag selects which target each batch is trained and
    validated against: ``batch["fine"]`` when true, otherwise
    ``batch["coarse"]``.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Embedding width (LSTM input size).
        hidden_dim: LSTM hidden size per direction.
        num_classes: Number of output logits.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
        num_layers: Number of stacked LSTM layers.
        bidirectional: Whether the LSTM runs in both directions.
        dropout: LSTM inter-layer dropout.
        padding_idx: Passed to ``nn.Embedding``. Presumably the tokenizer's
            pad id (TODO confirm).
        fine: Train on fine labels instead of coarse labels.
        **kwargs: Accepted and recorded in the hyperparameters, otherwise unused.
    """

    def __init__(
        self,
        vocab_size,
        embedding_dim,
        hidden_dim,
        num_classes,
        lr=1e-3,
        weight_decay=1e-2,
        num_layers=1,
        bidirectional=False,
        dropout=0.0,
        padding_idx=3,
        fine=False,
        **kwargs,
    ):
        super().__init__()
        # Stores the constructor arguments in self.hparams / the checkpoint so
        # load_from_checkpoint can rebuild the model from a path alone.
        self.save_hyperparameters()
        self.lr = lr
        self.weight_decay = weight_decay
        self.fine = fine

        # Width of each LSTM output step: doubled when bidirectional.
        lstm_out_dim = hidden_dim * (1 + bidirectional)

        # Submodule attribute names are state_dict keys; keep them stable so
        # existing checkpoints still load.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=dropout,
        )
        # NOTE(review): pad positions are not masked in the LSTM or attention.
        self.attention = Attention(lstm_out_dim)
        self.fc = nn.Linear(lstm_out_dim, num_classes)

        self.criteria = nn.CrossEntropyLoss()
        self.accuracy = tm.Accuracy(task="multiclass", num_classes=num_classes)

    def forward(self, input_ids):
        """Return ``(logits, attention_weights)`` for a batch of token ids."""
        embedded = self.embedding(input_ids)
        lstm_out, _ = self.lstm(embedded)
        pooled, attention_weights = self.attention(lstm_out)
        return self.fc(pooled), attention_weights

    def _forward_with_loss(self, batch):
        """Run the model on ``batch`` and return ``(logits, target, loss)``.

        Both label keys are read even though only one is used as the target.
        """
        input_ids = batch["input_ids"]
        coarse, fine = batch["coarse"], batch["fine"]
        logits, _ = self(input_ids)
        target = fine if self.fine else coarse
        return logits, target, self.criteria(logits, target)

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy training loss."""
        _, _, loss = self._forward_with_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss and update/log multiclass accuracy."""
        logits, target, loss = self._forward_with_loss(batch)
        self.log("val_loss", loss)
        self.accuracy(logits.argmax(dim=1), target)
        self.log("val_acc", self.accuracy, prog_bar=True)

    def configure_optimizers(self):
        """Use AdamW over all parameters with the configured lr and weight decay."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
        )
        return optimizer
|
|
|
|
# Hugging Face Hub repository holding the trained tokenizer and model checkpoint.
HF_REPO_ID = "SatwikKambham/trec-classifier"

# hf_hub_download returns local file paths for the requested repo files.
tokenizer_ckpt_path = hf_hub_download(
    repo_id=HF_REPO_ID,
    filename="tokenizer.json",
)
model_ckpt_path = hf_hub_download(
    repo_id=HF_REPO_ID,
    filename="lstm_attention.ckpt",
)
classifier = Classifier(tokenizer_ckpt_path, model_ckpt_path)

# Text-in / top-3-labels-out demo UI around Classifier.predict.
interface = gr.Interface(
    fn=classifier.predict,
    inputs=gr.components.Textbox(
        label="Question",
        placeholder="Enter a question here...",
    ),
    outputs=gr.components.Label(
        label="Predicted class",
        num_top_classes=3,
    ),
    examples=[
        ["What does LOL mean?"],
        ["What is the meaning of life?"],
        ["How long does it take for light from the sun to reach the earth?"],
        ["When is friendship day?"],
    ],
)

if __name__ == "__main__":
    # Start the (blocking) web server only when run as a script, not on import.
    interface.launch()
|
|