# RoBERTa-Based Topic Classification Model Using AG News Dataset

This repository hosts a RoBERTa-based transformer model fine-tuned for topic classification on the AG News dataset. The model assigns one of four topics, World, Sports, Business, or Sci/Tech, to a given news text.
## Model Details

- **Model Architecture:** RoBERTa (roberta-base)
- **Task:** Topic Classification
- **Dataset:** AG News (from Hugging Face Datasets)
- **Fine-tuning Framework:** Hugging Face Transformers
## Usage

### Installation

```sh
pip install transformers datasets torch
```

### Loading and Predicting
```python
import re

import torch
from transformers import RobertaTokenizer, RobertaForSequenceClassification

# Load the fine-tuned model and tokenizer
model = RobertaForSequenceClassification.from_pretrained("your-saved-model-directory")
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
model.eval()

def preprocess(text):
    """Apply the same light cleaning used during fine-tuning."""
    text = text.lower()
    text = re.sub(r"http\S+|www\S+|https\S+", "", text)  # strip URLs
    text = re.sub(r"@\w+|#", "", text)                   # strip mentions and hashtags
    text = re.sub(r"[^a-zA-Z0-9\s.,!?']", "", text)      # drop remaining special characters
    return re.sub(r"\s+", " ", text).strip()             # collapse whitespace

def predict_topic(texts, model, tokenizer, device="cpu"):
    """Return the predicted AG News topic for a string or a list of strings."""
    if isinstance(texts, str):
        texts = [texts]
    cleaned_texts = [preprocess(t) for t in texts]
    inputs = tokenizer(cleaned_texts, padding=True, truncation=True, return_tensors="pt").to(device)
    model.to(device)
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
    preds = torch.argmax(outputs.logits, dim=1).tolist()
    label_map = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}
    return [label_map[p] for p in preds]

# Example
sample_texts = [
    "The stock market witnessed a major crash today due to inflation concerns.",
    "The new space telescope has captured unprecedented images of distant galaxies.",
]
results = predict_topic(sample_texts, model, tokenizer)
for text, label in zip(sample_texts, results):
    print(f"Text: {text}\nPredicted Topic: {label}\n")
```
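Beyond the hard argmax label, a softmax over the logits gives a rough confidence score for each prediction. A minimal sketch using a dummy logits tensor, so it runs without the fine-tuned checkpoint; `label_map` repeats the mapping from `predict_topic`, and `labels_with_confidence` is an illustrative helper, not part of the saved model:

```python
import torch

label_map = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}

def labels_with_confidence(logits):
    """Map a (batch, 4) logits tensor to (label, probability) pairs."""
    probs = torch.softmax(logits, dim=1)       # normalize logits to probabilities
    conf, preds = torch.max(probs, dim=1)      # highest probability and its index
    return [(label_map[p], round(c, 3)) for p, c in zip(preds.tolist(), conf.tolist())]

# Dummy logits standing in for model(**inputs).logits
dummy = torch.tensor([[0.1, 0.2, 3.5, 0.3],
                      [0.2, 0.1, 0.4, 4.0]])
print(labels_with_confidence(dummy))  # Business and Sci/Tech dominate their rows
```

In real use, replace `dummy` with `outputs.logits` from inside `predict_topic`.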
## Performance Metrics

- **Accuracy:** ~0.96 on AG News test split
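The figure above is plain classification accuracy: the fraction of test examples whose predicted label matches the gold label. A minimal helper for reproducing it against `predict_topic` output (the function name and sample labels here are illustrative):

```python
def accuracy(predicted, gold):
    """Fraction of predictions that match the reference labels."""
    assert len(predicted) == len(gold)
    return sum(p == g for p, g in zip(predicted, gold)) / len(gold)

# 2 of 3 predictions match the reference labels
print(accuracy(["Sports", "World", "Business"], ["Sports", "World", "Sci/Tech"]))
```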
## Fine-Tuning Details

### Dataset

The dataset is sourced from the AG News dataset available via Hugging Face Datasets.

### Training

- Number of epochs: 3
- Batch size: 8
- Evaluation strategy: epoch
- Learning rate: 2e-5
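These hyperparameters map directly onto Hugging Face `TrainingArguments`. A minimal fine-tuning sketch under that assumption, with output paths chosen to match the repository layout below; the heavy download-and-train step sits behind a `__main__` guard:

```python
# Hyperparameters from the list above, as TrainingArguments keyword arguments
HPARAMS = {
    "num_train_epochs": 3,
    "per_device_train_batch_size": 8,
    "eval_strategy": "epoch",  # named "evaluation_strategy" on older transformers versions
    "learning_rate": 2e-5,
}

def main():
    # Imported lazily: requires `pip install transformers datasets torch`
    from datasets import load_dataset
    from transformers import (RobertaForSequenceClassification,
                              RobertaTokenizer, Trainer, TrainingArguments)

    ds = load_dataset("ag_news")  # train/test splits with "text" and "label" columns
    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

    def tokenize(batch):
        return tokenizer(batch["text"], padding="max_length", truncation=True)

    ds = ds.map(tokenize, batched=True)
    model = RobertaForSequenceClassification.from_pretrained("roberta-base", num_labels=4)

    args = TrainingArguments(output_dir="model/", **HPARAMS)
    trainer = Trainer(model=model, args=args,
                      train_dataset=ds["train"], eval_dataset=ds["test"])
    trainer.train()
    trainer.save_model("model/")
    tokenizer.save_pretrained("tokenizer/")

if __name__ == "__main__":
    main()
```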
## Repository Structure

```
.
├── model/            # Contains the fine-tuned model files
├── tokenizer/        # Tokenizer configuration and vocab
└── README.md         # Model documentation
```
## Limitations

- The model is trained specifically for AG News-style texts and may not generalize well to informal or unrelated domains.

## Contributing

Contributions are welcome! Please open an issue or submit a PR for suggestions or improvements.