---
license: mit
datasets:
- fancyzhx/ag_news
language:
- en
metrics:
- accuracy
base_model:
- google-t5/t5-large
pipeline_tag: text-classification
tags:
- ag
- news
- document
- classification
---
This model was fine-tuned on the AG News dataset for 2 epochs using 120,000 training samples and evaluated on the test set, with the metrics listed below.

Test Loss: 0.1629

Accuracy: 0.9521

F1 Score: 0.9521

Precision: 0.9522

Recall: 0.9522

```python
| | # Import necessary libraries |
| | import torch |
| | import torch.nn as nn |
| | from transformers import T5Tokenizer, T5ForConditionalGeneration |
| | |
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Define the model class (same structure as used during training)
class CustomT5Model(nn.Module):
    """T5 encoder followed by a linear classification head.

    Only the encoder of the wrapped ``T5ForConditionalGeneration`` is used in
    ``forward``. The full seq2seq model is kept (rather than an encoder-only
    model) so the module layout matches the structure used during training;
    presumably the saved checkpoint contains ``t5.*`` keys for both encoder and
    decoder, so swapping the backbone type would break ``load_state_dict``.

    Args:
        model_name: Model id passed to ``T5ForConditionalGeneration.from_pretrained``.
        num_labels: Number of output classes (4 for AG News).
    """

    def __init__(self, model_name="t5-large", num_labels=4):
        super().__init__()
        self.t5 = T5ForConditionalGeneration.from_pretrained(model_name)
        # Size the head from the backbone config instead of hard-coding 1024
        # (t5-large's d_model), so other T5 sizes get a matching head.
        self.classifier = nn.Linear(self.t5.config.d_model, num_labels)

    def forward(self, input_ids, attention_mask=None):
        """Return unnormalized class logits of shape (batch_size, num_labels)."""
        encoder_outputs = self.t5.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        hidden_states = encoder_outputs.last_hidden_state  # (batch_size, seq_len, d_model)
        # T5 has no [CLS] token: this is simply the encoder state at the first
        # position (the first token of the "classify: ..." prompt). The
        # classifier head operates on that vector.
        return self.classifier(hidden_states[:, 0, :])

# Initialize the model: build the architecture, then move it (in place) to
# the selected device.
model = CustomT5Model()
model.to(device)

# Load the fine-tuned weights published on the Hugging Face Hub.
# NOTE(review): load_state_dict_from_url deserializes the .pth file with
# torch.load (pickle); only load checkpoints from a source you trust, or pass
# weights_only=True on PyTorch versions that support it.
model_path = "https://huggingface.co/Vijayendra/T5-large-docClassification/resolve/main/best_model.pth"
model.load_state_dict(
    torch.hub.load_state_dict_from_url(model_path, map_location=device)
)
model.eval()  # switch to evaluation mode for inference

# Load the tokenizer that matches the T5 backbone.
tokenizer = T5Tokenizer.from_pretrained("t5-large")

# Inference function
# Class index -> label name used by this model card (AG News categories).
AG_NEWS_LABELS = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}


def infer(model, tokenizer, text, max_length=99, device=None):
    """Classify a single piece of text into one of the AG News categories.

    Args:
        model: Module returning logits of shape (batch_size, num_labels) when
            called with ``input_ids`` and ``attention_mask`` (e.g. ``CustomT5Model``).
        tokenizer: Tokenizer matching the model; called with a one-element list
            and ``return_tensors="pt"``.
        text: Raw input text; it is prefixed with ``"classify: "`` before tokenizing.
        max_length: Token length the input is truncated/padded to. The default
            of 99 is presumably the length used during fine-tuning — TODO confirm.
        device: Device to place the inputs on. Defaults to the device of the
            model's parameters (CPU if the model has no parameters).

    Returns:
        The predicted label name from ``AG_NEWS_LABELS``.
    """
    if device is None:
        # Follow the model instead of a module-level global so the function
        # keeps working wherever the model has been moved.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    with torch.no_grad():
        inputs = tokenizer(
            [f"classify: {text}"],
            max_length=max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)

        logits = model(input_ids=input_ids, attention_mask=attention_mask)
        pred = torch.argmax(logits, dim=-1)

    # .item() assumes a single-example batch, which is what we built above.
    return AG_NEWS_LABELS[pred.item()]

# Example usage: run the classifier on one headline and report its label.
text = "NASA announces new mission to study asteroids"
result = infer(model=model, tokenizer=tokenizer, text=text)
print(f"Predicted category: {result}")