| import torch | |
| import torch.nn as nn | |
| import lightning as L | |
| import torchmetrics as tm | |
| from tokenizers import Tokenizer | |
| import gradio as gr | |
| from huggingface_hub import hf_hub_download | |
# TREC coarse-grained question categories (6 classes).
# List order matches the model's output logit indices (the "(n)" in each label).
COARSE_LABELS = [
    "ABBR (0): Abbreviation",
    "ENTY (1): Entity",
    "DESC (2): Description and abstract concept",
    "HUM (3): Human being",
    "LOC (4): Location",
    "NUM (5): Numeric value",
]
# TREC fine-grained question categories (50 classes).
# List order matches the model's output logit indices (the "(n)" in each label);
# the leading tag (ABBR/ENTY/...) is the parent coarse category.
FINE_LABELS = [
    "ABBR (0): Abbreviation",
    "ABBR (1): Expression abbreviated",
    "ENTY (2): Animal",
    "ENTY (3): Organ of body",
    "ENTY (4): Color",
    "ENTY (5): Invention, book and other creative piece",
    "ENTY (6): Currency name",
    "ENTY (7): Disease and medicine",
    "ENTY (8): Event",
    "ENTY (9): Food",
    "ENTY (10): Musical instrument",
    "ENTY (11): Language",
    "ENTY (12): Letter like a-z",
    "ENTY (13): Other entity",
    "ENTY (14): Plant",
    "ENTY (15): Product",
    "ENTY (16): Religion",
    "ENTY (17): Sport",
    "ENTY (18): Element and substance",
    "ENTY (19): Symbols and sign",
    "ENTY (20): Techniques and method",
    "ENTY (21): Equivalent term",
    "ENTY (22): Vehicle",
    "ENTY (23): Word with a special property",
    "DESC (24): Definition of something",
    "DESC (25): Description of something",
    "DESC (26): Manner of an action",
    "DESC (27): Reason",
    "HUM (28): Group or organization of persons",
    "HUM (29): Individual",
    "HUM (30): Title of a person",
    "HUM (31): Description of a person",
    "LOC (32): City",
    "LOC (33): Country",
    "LOC (34): Mountain",
    "LOC (35): Other location",
    "LOC (36): State",
    "NUM (37): Postcode or other code",
    "NUM (38): Number of something",
    "NUM (39): Date",
    "NUM (40): Distance, linear measure",
    "NUM (41): Price",
    "NUM (42): Order, rank",
    "NUM (43): Other number",
    "NUM (44): Lasting time of something",
    "NUM (45): Percent, fraction",
    "NUM (46): Speed",
    "NUM (47): Temperature",
    "NUM (48): Size, area and volume",
    "NUM (49): Weight",
]
class Classifier:
    """Inference wrapper: a trained tokenizer plus the LSTM question classifier.

    Loads both artifacts from local checkpoint files and exposes a single
    ``predict`` method suitable for use as a Gradio ``fn``.
    """

    def __init__(self, tokenizer_ckpt_path, model_ckpt_path):
        """Load the tokenizer and model checkpoints (CPU only).

        Args:
            tokenizer_ckpt_path: path to a serialized ``tokenizers`` JSON file.
            model_ckpt_path: path to a Lightning ``.ckpt`` checkpoint.
        """
        self.tokenizer = Tokenizer.from_file(tokenizer_ckpt_path)
        self.model = LSTMWithAttentionClassifier.load_from_checkpoint(
            model_ckpt_path,
            map_location="cpu",
        )
        # Inference only: switch off dropout / other train-mode behavior.
        # load_from_checkpoint does not call eval() for us.
        self.model.eval()

    def predict(self, text):
        """Classify a question string.

        Returns:
            dict mapping each label string to its softmax probability;
            fine or coarse labels are chosen by the model's ``fine`` flag.
        """
        encoding = self.tokenizer.encode(text)
        ids = torch.tensor([encoding.ids])
        # no_grad: inference needs no autograd graph.
        with torch.no_grad():
            logits, _ = self.model(ids)
        # (1, num_classes) -> (num_classes,) before converting to a list.
        probs = torch.softmax(logits, dim=1).squeeze(0).tolist()
        labels = FINE_LABELS if self.model.fine else COARSE_LABELS
        return dict(zip(labels, probs))
class Attention(nn.Module):
    """Additive-style attention pooling over a sequence of hidden states.

    Scores every timestep from tanh-squashed query/key projections and
    returns the attention-weighted sum of the inputs along with the
    per-timestep weights.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.WQuery = nn.Linear(hidden_dim, hidden_dim)
        self.WKey = nn.Linear(hidden_dim, hidden_dim)
        self.WValue = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Pool ``x`` of shape (batch, seq, hidden) down to (batch, hidden).

        Returns:
            (pooled, weights): the weighted sum over the sequence axis and
            the softmax attention weights of shape (batch, seq, 1).
        """
        q = torch.tanh(self.WQuery(x))
        k = torch.tanh(self.WKey(x))
        # One scalar score per timestep, normalized across the sequence.
        scores = self.WValue(q + k)
        weights = torch.softmax(scores, dim=1)
        pooled = torch.sum(weights * x, dim=1)
        return pooled, weights
class LSTMWithAttentionClassifier(L.LightningModule):
    """LSTM encoder with attention pooling for TREC question classification.

    Predicts either coarse (6-way) or fine (50-way) labels depending on the
    ``fine`` flag. Hyperparameters are saved into the checkpoint so the model
    can be restored with ``load_from_checkpoint``.
    """

    def __init__(
        self,
        vocab_size,
        embedding_dim,
        hidden_dim,
        num_classes,
        lr=1e-3,
        weight_decay=1e-2,
        num_layers=1,
        bidirectional=False,
        dropout=0.0,
        padding_idx=3,
        fine=False,
        **kwargs,
    ):
        super().__init__()
        # Persist ctor args in the checkpoint for load_from_checkpoint.
        self.save_hyperparameters()
        self.lr = lr
        self.weight_decay = weight_decay
        self.fine = fine
        self.embedding = nn.Embedding(
            vocab_size,
            embedding_dim,
            padding_idx=padding_idx,
        )
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=dropout,
        )
        # A bidirectional LSTM concatenates both directions, doubling the
        # feature dimension seen by the attention and output layers.
        feature_dim = hidden_dim * (2 if bidirectional else 1)
        self.attention = Attention(feature_dim)
        self.fc = nn.Linear(feature_dim, num_classes)
        self.criteria = nn.CrossEntropyLoss()
        self.accuracy = tm.Accuracy(
            task="multiclass",
            num_classes=num_classes,
        )

    def forward(self, input_ids):
        """Return (logits, attention_weights) for a batch of token ids."""
        x = self.embedding(input_ids)
        x, _ = self.lstm(x)
        x, attention_weights = self.attention(x)
        return self.fc(x), attention_weights

    def _shared_step(self, batch):
        """Compute (loss, logits, targets) for one batch.

        Shared by training and validation; targets come from the "fine" or
        "coarse" batch key according to ``self.fine``.
        """
        targets = batch["fine"] if self.fine else batch["coarse"]
        logits, _ = self(batch["input_ids"])
        loss = self.criteria(logits, targets)
        return loss, logits, targets

    def training_step(self, batch, batch_idx):
        loss, _, _ = self._shared_step(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, logits, targets = self._shared_step(batch)
        self.log("val_loss", loss)
        self.accuracy(logits.argmax(dim=1), targets)
        self.log("val_acc", self.accuracy, prog_bar=True)

    def configure_optimizers(self):
        """AdamW over all parameters with the configured lr / weight decay."""
        return torch.optim.AdamW(
            self.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
        )
def main():
    """Download checkpoints from the Hub, build the classifier, serve Gradio UI."""
    tokenizer_ckpt_path = hf_hub_download(
        repo_id="SatwikKambham/trec-classifier",
        filename="tokenizer.json",
    )
    model_ckpt_path = hf_hub_download(
        repo_id="SatwikKambham/trec-classifier",
        filename="lstm_attention.ckpt",
    )
    classifier = Classifier(tokenizer_ckpt_path, model_ckpt_path)
    interface = gr.Interface(
        fn=classifier.predict,
        inputs=gr.components.Textbox(
            label="Question",
            placeholder="Enter a question here...",
        ),
        outputs=gr.components.Label(
            label="Predicted class",
            num_top_classes=3,
        ),
        examples=[
            ["What does LOL mean?"],
            ["What is the meaning of life?"],
            ["How long does it take for light from the sun to reach the earth?"],
            ["When is friendship day?"],
        ],
    )
    interface.launch()


# Guard the entry point so importing this module does not trigger network
# downloads or start the web server.
if __name__ == "__main__":
    main()