---
license: apache-2.0
tags:
- bert
- deberta
- text-classification
- fine-tuned
- databricks-dolly
- prompt-category
language: en
datasets:
- databricks/databricks-dolly-15k
base_model:
- microsoft/deberta-v3-base
---

# 🧠 DeBERTa-v3 Base - Prompt Category Classifier (Fine-tuned)

This model is a fine-tuned version of [`microsoft/deberta-v3-base`](https://huggingface.co/microsoft/deberta-v3-base) on the [databricks-dolly-15k](https://huggingface.co/datasets/databricks/databricks-dolly-15k) dataset. It has been trained to classify the **prompt category** based solely on the **response** text.

## 🗂️ Task

**Text Classification**

**Input**: Response text
**Output**: One of the following predefined categories:

- `brainstorming`
- `classification`
- `closed_qa`
- `creative_writing`
- `general_qa`
- `information_extraction`
- `open_qa`
- `summarization`

## 📊 Evaluation

The model was evaluated on a balanced version of the dataset. Here are the results:

- **Validation Accuracy**: ~85.5%
- **F1 Score**: ~85.0%
- Best performance on: `creative_writing`, `classification`, `summarization`
- Room for improvement on: `open_qa`

## 🧪 How to Use

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

model = AutoModelForSequenceClassification.from_pretrained("mariadg/deberta-v3-prompt-recognition")
tokenizer = AutoTokenizer.from_pretrained("mariadg/deberta-v3-prompt-recognition")

text = "The mitochondria is known as the powerhouse of the cell."
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)

with torch.no_grad():
    outputs = model(**inputs)

pred = torch.argmax(outputs.logits, dim=1).item()
print(pred)  # Map this index back to a label if needed
```

## 📦 Label Mapping

The model outputs a numerical label corresponding to a prompt category.
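A minimal sketch of decoding the raw class index from the example above into a category name and a softmax confidence score. The `ID2LABEL` dict and `decode_prediction` helper are illustrative names introduced here; the dict mirrors the label list published in this card rather than being read from the model config:

```python
import torch

# Label mapping as published in this model card (assumed fixed at training time).
ID2LABEL = {
    0: "brainstorming",
    1: "classification",
    2: "closed_qa",
    3: "creative_writing",
    4: "general_qa",
    5: "information_extraction",
    6: "open_qa",
    7: "summarization",
}

def decode_prediction(logits: torch.Tensor) -> tuple[str, float]:
    """Turn a (1, num_labels) logits tensor into (category, confidence)."""
    probs = torch.softmax(logits, dim=-1)
    conf, idx = probs.max(dim=-1)
    return ID2LABEL[idx.item()], conf.item()

# Example with dummy logits favoring index 3:
label, conf = decode_prediction(torch.tensor([[0.1, 0.2, 0.0, 4.0, 0.3, 0.1, 0.0, 0.2]]))
print(label)  # creative_writing
```

If the repository's `config.json` was saved with label names, `model.config.id2label` should expose the same mapping directly.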
Below is the mapping between label IDs and their respective categories:

- `0`: brainstorming
- `1`: classification
- `2`: closed_qa
- `3`: creative_writing
- `4`: general_qa
- `5`: information_extraction
- `6`: open_qa
- `7`: summarization

## 🛠️ Training Details

- **Base model**: `microsoft/deberta-v3-base`
- **Framework**: PyTorch
- **Max length**: 256
- **Batch size**: 16
- **Epochs**: 4
- **Loss function**: `CrossEntropyLoss`

## 🔐 License

Apache 2.0

---

📝 Fine-tuned by [mariadg](https://huggingface.co/mariadg) – for research purposes.