Shreyas Meher committed
Commit 2be141f · Parent(s): c0557bb

Add LoRA/QLoRA fine-tuning and active learning
- LoRA/QLoRA support in Fine-tune tab (via PEFT) for parameter-efficient training with lower VRAM usage; also available in model comparison
- New Active Learning tab with iterative uncertainty-based labeling: entropy, margin, and least-confidence query strategies, round-by-round metrics chart, and example dataset
- Add peft and bitsandbytes to requirements.txt
- Update README with new features and quick-start guides
- Fix .gitignore to exclude conflibertr/ directory
- .gitignore +2 -0
- README.md +28 -2
- app.py +677 -6
- examples/active_learning/pool.txt +61 -0
- examples/active_learning/pool_with_labels.tsv +61 -0
- examples/active_learning/seed.tsv +20 -0
- requirements.txt +2 -0
.gitignore CHANGED

@@ -6,6 +6,8 @@ env/
 __pycache__/
 *.pyc
 conflibertR/
+conflibertr/
+al_model/
 screenshots/
 finetuned_model/
 ft_output/
README.md CHANGED

@@ -41,7 +41,7 @@ Provide a context passage and a question. The model extracts the most relevant a
 
 ### Fine-tuning
 
-Train your own binary or multiclass classifier directly in the browser. Upload data (or load a built-in example), pick a base model, configure training, and go. After training, results and a "Try Your Model" panel appear side by side. You can also save the model and run batch predictions.
+Train your own binary or multiclass classifier directly in the browser. Upload data (or load a built-in example), pick a base model, configure training, and go. Supports **LoRA** and **QLoRA** for parameter-efficient training with lower VRAM usage. After training, results and a "Try Your Model" panel appear side by side. You can also save the model and run batch predictions.
 
 ### Model Comparison
 
@@ -50,7 +50,9 @@ Compare multiple base model architectures on the same dataset. The comparison pr
 <!-- Take a screenshot of the Fine-tune tab and save as screenshots/finetune.png -->
 
+### Active Learning
+
+Iteratively build a strong classifier with fewer labels. Start with a small labeled seed set and a pool of unlabeled text. The model identifies the most uncertain samples for you to label, retrains, and repeats. Supports entropy, margin, and least-confidence query strategies.
 
 ## Supported Models
 
@@ -144,7 +146,8 @@ Opens at `http://localhost:7860` and generates a public shareable link. The firs
 | Binary Classification | Conflict vs. non-conflict, supports custom models |
 | Multilabel Classification | Multi-event-type scoring |
 | Question Answering | Extract answers from a context passage |
-| Fine-tune | Train classifiers, compare models, ROC curves |
+| Fine-tune | Train classifiers with optional LoRA/QLoRA, compare models, ROC curves |
+| Active Learning | Iterative uncertainty-based labeling and retraining |
 
 ### Fine-tuning Quick Start
 
@@ -154,6 +157,13 @@ Opens at `http://localhost:7860` and generates a public shareable link. The firs
 4. Review metrics and try your model on new text
 5. Save the model and load it in the **Binary Classification** tab
 
+### LoRA / QLoRA Fine-tuning
+
+1. Go to the **Fine-tune** tab
+2. Open **Advanced Settings** and check **Use LoRA** (optionally enable **QLoRA** for 4-bit quantization on CUDA GPUs)
+3. Adjust LoRA rank and alpha as needed (defaults of r=8, alpha=16 work well)
+4. Train as usual — LoRA weights are merged back automatically so the saved model works like any other
+
 ### Model Comparison Quick Start
 
 1. Upload data (or load an example) in the **Fine-tune** tab
@@ -162,6 +172,16 @@ Opens at `http://localhost:7860` and generates a public shareable link. The firs
 4. Click **"Compare Models"**
 5. View the metrics table, bar chart, and ROC-AUC curves
 
+### Active Learning Quick Start
+
+1. Go to the **Active Learning** tab
+2. Click **"Load Example: Binary Active Learning"** (or upload your own seed + pool)
+3. Configure the query strategy and samples per round
+4. Click **"Initialize Active Learning"**
+5. Label the uncertain samples shown in the table (fill in 0 or 1)
+6. Click **"Submit Labels & Next Round"** to retrain and get the next batch
+7. Repeat until satisfied, then save the model
+
 ### Data Format
 
 Tab-separated values (TSV), no header row. Each line: `text<TAB>label`
 
@@ -209,10 +229,16 @@ conflibert-gui/
         train.tsv                # 0=Diplomacy, 1=Armed Conflict,
         dev.tsv                  # 2=Protest, 3=Humanitarian
         test.tsv
+    active_learning/             # Example active learning dataset
+        seed.tsv                 # 20 labeled seed samples
+        pool.txt                 # 61 unlabeled pool texts
+        pool_with_labels.tsv     # Ground truth for pool (cheat sheet)
 ```
 
 ## Training Features
 
+- **LoRA / QLoRA** parameter-efficient fine-tuning (via [PEFT](https://github.com/huggingface/peft))
+- **Active learning** with entropy, margin, and least-confidence query strategies
 - Early stopping with configurable patience
 - Learning rate schedulers: linear, cosine, constant, constant with warmup
 - Mixed precision training (FP16) on CUDA GPUs
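The `text<TAB>label` convention above is plain TSV with no header, so the standard library's csv module can write and read it directly. A minimal sketch with two invented example rows (real ConfliBERT data would come from the example files in the repo):

```python
import csv
import io

rows = [
    ("Rebels attacked a convoy near the border.", 1),
    ("The ministers met to discuss trade policy.", 0),
]

# Write: one text<TAB>label pair per line, no header row.
buf = io.StringIO()
writer = csv.writer(buf, delimiter="\t", lineterminator="\n")
writer.writerows(rows)

# Read it back the same way, casting labels to int.
parsed = [(text, int(label)) for text, label in
          csv.reader(io.StringIO(buf.getvalue()), delimiter="\t")]
assert parsed == rows
```

Note that csv handles quoting automatically if a text happens to contain a tab, which a hand-rolled `split("\t")` would not.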
app.py CHANGED

@@ -45,6 +45,19 @@ from sklearn.preprocessing import label_binarize
 from torch.utils.data import Dataset as TorchDataset
 import gc
 
 # ============================================================================
 # CONFIGURATION

@@ -613,6 +626,7 @@ def run_finetuning(
     train_file, dev_file, test_file, task_type, model_display_name,
     epochs, batch_size, lr, weight_decay, warmup_ratio, max_seq_len,
     grad_accum, fp16, patience, scheduler,
     progress=gr.Progress(track_tqdm=True),
 ):
     """Main finetuning function. Returns logs, metrics, model state, and visibility updates."""

@@ -644,9 +658,42 @@ def run_finetuning(
     # Load model and tokenizer
     model_id = FINETUNE_MODELS[model_display_name]
     tokenizer = AutoTokenizer.from_pretrained(model_id)
-
-
-
 
     # Create datasets
     train_ds = TextClassificationDataset(

@@ -709,6 +756,10 @@ def run_finetuning(
     test_results = trainer.evaluate(test_ds, metric_key_prefix='test')
 
     # Build log text
     header = (
         f"=== Configuration ===\n"
         f"Model: {model_display_name}\n"

@@ -716,6 +767,7 @@ def run_finetuning(
         f"Task: {task_type} Classification ({num_labels} classes)\n"
         f"Data: {len(train_texts)} train / {len(dev_texts)} dev / {len(test_texts)} test\n"
         f"Epochs: {epochs} Batch: {batch_size} LR: {lr} Scheduler: {scheduler}\n"
         f"\n=== Training Log ===\n"
     )
     runtime = train_result.metrics.get('train_runtime', 0)

@@ -733,8 +785,11 @@ def run_finetuning(
         metrics_data.append([name, f"{float(v):.4f}"])
     metrics_df = pd.DataFrame(metrics_data, columns=['Metric', 'Score'])
 
-    #
-    trained_model = trainer.model
     trained_model.eval()
 
     return (

@@ -864,9 +919,412 @@ def load_example_multiclass():
     )
 
 
 def run_comparison(
     train_file, dev_file, test_file, task_type, selected_models,
-    epochs, batch_size, lr,
     progress=gr.Progress(track_tqdm=True),
 ):
     """Train multiple models on the same data and compare performance + ROC curves."""

@@ -913,6 +1371,18 @@ def run_comparison(
         model = AutoModelForSequenceClassification.from_pretrained(
             model_id, num_labels=num_labels,
         )
         train_ds = TextClassificationDataset(train_texts, train_labels, tokenizer, 512)
         dev_ds = TextClassificationDataset(dev_texts, dev_labels, tokenizer, 512)
         test_ds = TextClassificationDataset(test_texts, test_labels, tokenizer, 512)

@@ -949,6 +1419,10 @@ def run_comparison(
 
         train_result = trainer.train()
 
         # Get predictions for ROC curves
         pred_output = trainer.predict(test_ds)
         logits = pred_output.predictions

@@ -1194,6 +1668,10 @@ with gr.Blocks(theme=theme, css=custom_css, title="ConfliBERT") as demo:
         "8. Save the model and load it later in the "
         "Classification tab\n\n"
         "**Advanced features:**\n"
         "- Early stopping with configurable patience\n"
         "- Learning rate schedulers (linear, cosine, constant)\n"
         "- Mixed precision training (FP16 on CUDA GPUs)\n"

@@ -1429,6 +1907,22 @@ with gr.Blocks(theme=theme, css=custom_css, title="ConfliBERT") as demo:
         ["linear", "cosine", "constant", "constant_with_warmup"],
         label="LR Scheduler", value="linear",
     )
 
     # -- Train --
     ft_train_btn = gr.Button(

@@ -1502,6 +1996,10 @@ with gr.Blocks(theme=theme, css=custom_css, title="ConfliBERT") as demo:
     cmp_epochs = gr.Number(label="Epochs", value=3, minimum=1, precision=0)
     cmp_batch = gr.Number(label="Batch Size", value=8, minimum=1, precision=0)
     cmp_lr = gr.Number(label="Learning Rate", value=2e-5, minimum=1e-7)
     cmp_btn = gr.Button("Compare Models", variant="primary")
     cmp_log = gr.Textbox(
         label="Comparison Log", lines=8,

@@ -1514,6 +2012,135 @@ with gr.Blocks(theme=theme, css=custom_css, title="ConfliBERT") as demo:
     cmp_plot = gr.Plot(label="Metrics Comparison")
     cmp_roc = gr.Plot(label="ROC Curves")
 
     # ---- FOOTER ----
     gr.Markdown(
         "<div style='text-align: center; padding: 1rem 0; margin-top: 0.5rem; "

@@ -1600,6 +2227,7 @@ with gr.Blocks(theme=theme, css=custom_css, title="ConfliBERT") as demo:
             ft_epochs, ft_batch, ft_lr,
             ft_weight_decay, ft_warmup, ft_max_len,
             ft_grad_accum, ft_fp16, ft_patience, ft_scheduler,
         ],
         outputs=[
             ft_log, ft_metrics,

@@ -1630,12 +2258,55 @@ with gr.Blocks(theme=theme, css=custom_css, title="ConfliBERT") as demo:
         outputs=[ft_batch_out],
     )
 
     # Model comparison
     cmp_btn.click(
         fn=run_comparison,
         inputs=[
             ft_train_file, ft_dev_file, ft_test_file,
             ft_task, cmp_models, cmp_epochs, cmp_batch, cmp_lr,
         ],
         outputs=[cmp_log, cmp_table, cmp_plot, cmp_roc, cmp_results_col],
         concurrency_limit=1,
@@ -45,6 +45,19 @@ from sklearn.preprocessing import label_binarize
 from torch.utils.data import Dataset as TorchDataset
 import gc
 
+# LoRA / QLoRA support (optional)
+try:
+    from peft import LoraConfig, get_peft_model, TaskType
+    PEFT_AVAILABLE = True
+except ImportError:
+    PEFT_AVAILABLE = False
+
+try:
+    from transformers import BitsAndBytesConfig
+    BNB_AVAILABLE = True
+except ImportError:
+    BNB_AVAILABLE = False
+
 
 # ============================================================================
 # CONFIGURATION
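The try/except guards above are the usual pattern for optional dependencies: import once at module load, record a flag, and let later code branch on the flag instead of crashing at startup. A generic sketch of the same idea (the module names here are placeholders for illustration, not real dependencies of the app):

```python
import importlib

def optional_import(name):
    """Return (module, available) without raising when the package is missing."""
    try:
        return importlib.import_module(name), True
    except ImportError:
        return None, False

# Mirrors how PEFT_AVAILABLE / BNB_AVAILABLE are set above.
json_mod, JSON_AVAILABLE = optional_import("json")            # stdlib: present
_, FAKE_AVAILABLE = optional_import("no_such_package_xyz")    # assumed absent
```
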
@@ -613,6 +626,7 @@ def run_finetuning(
     train_file, dev_file, test_file, task_type, model_display_name,
     epochs, batch_size, lr, weight_decay, warmup_ratio, max_seq_len,
     grad_accum, fp16, patience, scheduler,
+    use_lora, lora_rank, lora_alpha, use_qlora,
     progress=gr.Progress(track_tqdm=True),
 ):
     """Main finetuning function. Returns logs, metrics, model state, and visibility updates."""
@@ -644,9 +658,42 @@ def run_finetuning(
     # Load model and tokenizer
     model_id = FINETUNE_MODELS[model_display_name]
     tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+    lora_active = False
+    if use_qlora:
+        if not (PEFT_AVAILABLE and BNB_AVAILABLE and torch.cuda.is_available()):
+            raise ValueError(
+                "QLoRA requires a CUDA GPU and the peft + bitsandbytes packages."
+            )
+        bnb_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=torch.float16,
+            bnb_4bit_use_double_quant=True,
+        )
+        model = AutoModelForSequenceClassification.from_pretrained(
+            model_id, num_labels=num_labels, quantization_config=bnb_config,
+        )
+    else:
+        model = AutoModelForSequenceClassification.from_pretrained(
+            model_id, num_labels=num_labels,
+        )
+
+    if use_lora or use_qlora:
+        if not PEFT_AVAILABLE:
+            raise ValueError(
+                "LoRA requires the 'peft' package. Install: pip install peft"
+            )
+        lora_config = LoraConfig(
+            task_type=TaskType.SEQ_CLS,
+            r=int(lora_rank),
+            lora_alpha=int(lora_alpha),
+            lora_dropout=0.1,
+            bias="none",
+        )
+        model.enable_input_require_grads()
+        model = get_peft_model(model, lora_config)
+        lora_active = True
 
     # Create datasets
     train_ds = TextClassificationDataset(
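For scale: LoRA freezes the base weight W and learns a rank-r update ΔW = B·A, so for each adapted d_out×d_in matrix it trains r·(d_in + d_out) parameters instead of d_in·d_out. A back-of-envelope sketch with the r=8 default above and a 768×768 projection (a typical BERT-base attention matrix, assumed here purely for illustration):

```python
def lora_trainable_params(d_in, d_out, r):
    """Parameters in a rank-r adapter: A is (r, d_in), B is (d_out, r)."""
    return r * d_in + r * d_out

full = 768 * 768                               # 589824 weights in the frozen matrix
adapter = lora_trainable_params(768, 768, 8)   # 12288 trainable weights
ratio = adapter / full                         # roughly 2% per adapted layer
```

This is why LoRA fits in far less VRAM for optimizer state and gradients, as the commit message claims.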
@@ -709,6 +756,10 @@ def run_finetuning(
     test_results = trainer.evaluate(test_ds, metric_key_prefix='test')
 
     # Build log text
+    lora_info = ""
+    if lora_active:
+        method = "QLoRA (4-bit)" if use_qlora else "LoRA"
+        lora_info = f"PEFT: {method} r={int(lora_rank)} alpha={int(lora_alpha)}\n"
     header = (
         f"=== Configuration ===\n"
         f"Model: {model_display_name}\n"

@@ -716,6 +767,7 @@ def run_finetuning(
         f"Task: {task_type} Classification ({num_labels} classes)\n"
         f"Data: {len(train_texts)} train / {len(dev_texts)} dev / {len(test_texts)} test\n"
         f"Epochs: {epochs} Batch: {batch_size} LR: {lr} Scheduler: {scheduler}\n"
+        f"{lora_info}"
         f"\n=== Training Log ===\n"
     )
     runtime = train_result.metrics.get('train_runtime', 0)
@@ -733,8 +785,11 @@ def run_finetuning(
         metrics_data.append([name, f"{float(v):.4f}"])
     metrics_df = pd.DataFrame(metrics_data, columns=['Metric', 'Score'])
 
+    # Merge LoRA weights back into base model for clean save/inference
+    trained_model = trainer.model
+    if lora_active and hasattr(trained_model, 'merge_and_unload'):
+        trained_model = trained_model.merge_and_unload()
+    trained_model = trained_model.cpu()
     trained_model.eval()
 
     return (
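Merging folds the adapter into the base weights, W' = W + (α/r)·B·A, so the saved model needs no PEFT at inference time yet computes exactly what the adapted model did. A small numpy sketch of why the merged matrix reproduces the adapted forward pass (random matrices, standing in for real model weights):

```python
import numpy as np

rng = np.random.default_rng(0)
d, r, alpha = 6, 2, 4
W = rng.normal(size=(d, d))   # frozen base weight
A = rng.normal(size=(r, d))   # LoRA down-projection
B = rng.normal(size=(d, r))   # LoRA up-projection
x = rng.normal(size=d)

scale = alpha / r
adapted = W @ x + scale * (B @ (A @ x))   # base path plus adapter path
merged = (W + scale * (B @ A)) @ x        # single merged matrix

assert np.allclose(adapted, merged)
```
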
@@ -864,9 +919,412 @@ def load_example_multiclass():
     )
 
 
+# ============================================================================
+# ACTIVE LEARNING
+# ============================================================================
+
+def parse_pool_file(file_path):
+    """Parse an unlabeled text pool. Accepts CSV with 'text' column, or one text per line."""
+    path = get_path(file_path)
+    # Try CSV/TSV with 'text' column first
+    try:
+        df = pd.read_csv(path)
+        if 'text' in df.columns:
+            texts = [str(t) for t in df['text'].dropna().tolist()]
+            if texts:
+                return texts
+    except Exception:
+        pass
+    # Fallback: one text per line
+    texts = []
+    with open(path, 'r', encoding='utf-8') as f:
+        for line in f:
+            line = line.strip()
+            if line:
+                texts.append(line)
+    if not texts:
+        raise ValueError("No texts found in pool file.")
+    return texts
+
+
+def compute_uncertainty(model, tokenizer, texts, strategy='entropy',
+                        max_seq_len=512, batch_size=32):
+    """Compute uncertainty scores for unlabeled texts. Higher = more uncertain."""
+    model.eval()
+    dev = next(model.parameters()).device
+    scores = []
+
+    for i in range(0, len(texts), batch_size):
+        batch_texts = texts[i:i + batch_size]
+        inputs = tokenizer(
+            batch_texts, return_tensors='pt', truncation=True,
+            padding=True, max_length=max_seq_len,
+        )
+        inputs = {k: v.to(dev) for k, v in inputs.items()}
+        with torch.no_grad():
+            logits = model(**inputs).logits
+        probs = torch.softmax(logits, dim=1).cpu().numpy()
+
+        if strategy == 'entropy':
+            s = -np.sum(probs * np.log(probs + 1e-10), axis=1)
+        elif strategy == 'margin':
+            sorted_p = np.sort(probs, axis=1)
+            s = -(sorted_p[:, -1] - sorted_p[:, -2])
+        else:  # least_confidence
+            s = -np.max(probs, axis=1)
+        scores.extend(s.tolist())
+
+    return scores
+
|
| 981 |
+
"""Build a Plotly chart of active-learning metrics across rounds."""
|
| 982 |
+
import plotly.graph_objects as go
|
| 983 |
+
|
| 984 |
+
if not metrics_history:
|
| 985 |
+
return None
|
| 986 |
+
|
| 987 |
+
rounds = [m['round'] for m in metrics_history]
|
| 988 |
+
train_sizes = [m.get('train_size', 0) for m in metrics_history]
|
| 989 |
+
|
| 990 |
+
metric_keys = (['f1', 'accuracy', 'precision', 'recall']
|
| 991 |
+
if task_type == 'Binary'
|
| 992 |
+
else ['f1_macro', 'accuracy'])
|
| 993 |
+
|
| 994 |
+
fig = go.Figure()
|
| 995 |
+
colors = ['#ff6b35', '#3b82f6', '#10b981', '#8b5cf6']
|
| 996 |
+
|
| 997 |
+
for i, key in enumerate(metric_keys):
|
| 998 |
+
values = [m.get(key) for m in metrics_history]
|
| 999 |
+
if any(v is not None for v in values):
|
| 1000 |
+
fig.add_trace(go.Scatter(
|
| 1001 |
+
x=rounds, y=values, mode='lines+markers',
|
| 1002 |
+
name=key.replace('_', ' ').title(),
|
| 1003 |
+
line=dict(color=colors[i % len(colors)], width=2),
|
| 1004 |
+
))
|
| 1005 |
+
|
| 1006 |
+
fig.add_trace(go.Bar(
|
| 1007 |
+
x=rounds, y=train_sizes, name='Train Size',
|
| 1008 |
+
marker_color='rgba(200,200,200,0.4)', yaxis='y2',
|
| 1009 |
+
))
|
| 1010 |
+
|
| 1011 |
+
fig.update_layout(
|
| 1012 |
+
xaxis_title='Round', yaxis_title='Score', yaxis_range=[0, 1.05],
|
| 1013 |
+
yaxis2=dict(title='Train Size', overlaying='y', side='right'),
|
| 1014 |
+
template='plotly_white',
|
| 1015 |
+
legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1),
|
| 1016 |
+
height=350, margin=dict(t=40, b=40),
|
| 1017 |
+
)
|
| 1018 |
+
return fig
|
| 1019 |
+
|
| 1020 |
+
|
| 1021 |
+
def _train_al_model(texts, labels, num_labels, dev_texts, dev_labels,
|
| 1022 |
+
task_type, model_id, epochs, batch_size, lr, max_seq_len,
|
| 1023 |
+
use_lora, lora_rank, lora_alpha):
|
| 1024 |
+
"""Train a model for one active-learning round. Returns (model, tokenizer, eval_metrics)."""
|
| 1025 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
| 1026 |
+
model = AutoModelForSequenceClassification.from_pretrained(
|
| 1027 |
+
model_id, num_labels=num_labels,
|
| 1028 |
+
)
|
| 1029 |
+
|
| 1030 |
+
if use_lora and PEFT_AVAILABLE:
|
| 1031 |
+
lora_cfg = LoraConfig(
|
| 1032 |
+
task_type=TaskType.SEQ_CLS,
|
| 1033 |
+
r=int(lora_rank), lora_alpha=int(lora_alpha),
|
| 1034 |
+
lora_dropout=0.1, bias="none",
|
| 1035 |
+
)
|
| 1036 |
+
model.enable_input_require_grads()
|
| 1037 |
+
model = get_peft_model(model, lora_cfg)
|
| 1038 |
+
|
| 1039 |
+
train_ds = TextClassificationDataset(texts, labels, tokenizer, max_seq_len)
|
| 1040 |
+
dev_ds = None
|
| 1041 |
+
if dev_texts is not None:
|
| 1042 |
+
dev_ds = TextClassificationDataset(dev_texts, dev_labels, tokenizer, max_seq_len)
|
| 1043 |
+
|
| 1044 |
+
output_dir = tempfile.mkdtemp(prefix='conflibert_al_')
|
| 1045 |
+
training_args = TrainingArguments(
|
| 1046 |
+
output_dir=output_dir,
|
| 1047 |
+
num_train_epochs=epochs,
|
| 1048 |
+
per_device_train_batch_size=batch_size,
|
| 1049 |
+
per_device_eval_batch_size=batch_size * 2,
|
| 1050 |
+
learning_rate=lr,
|
| 1051 |
+
weight_decay=0.01,
|
| 1052 |
+
warmup_ratio=0.1,
|
| 1053 |
+
eval_strategy='epoch' if dev_ds else 'no',
|
| 1054 |
+
save_strategy='no',
|
| 1055 |
+
logging_steps=10,
|
| 1056 |
+
report_to='none',
|
| 1057 |
+
seed=42,
|
| 1058 |
+
)
|
| 1059 |
+
|
| 1060 |
+
trainer = Trainer(
|
| 1061 |
+
model=model,
|
| 1062 |
+
args=training_args,
|
| 1063 |
+
train_dataset=train_ds,
|
| 1064 |
+
eval_dataset=dev_ds,
|
| 1065 |
+
compute_metrics=make_compute_metrics(task_type) if dev_ds else None,
|
| 1066 |
+
)
|
| 1067 |
+
trainer.train()
|
| 1068 |
+
|
| 1069 |
+
eval_metrics = {}
|
| 1070 |
+
if dev_ds:
|
| 1071 |
+
results = trainer.evaluate()
|
| 1072 |
+
for k, v in results.items():
|
| 1073 |
+
if isinstance(v, (int, float, np.floating)):
|
| 1074 |
+
eval_metrics[k.replace('eval_', '')] = round(float(v), 4)
|
| 1075 |
+
|
| 1076 |
+
trained_model = trainer.model
|
| 1077 |
+
if use_lora and PEFT_AVAILABLE and hasattr(trained_model, 'merge_and_unload'):
|
| 1078 |
+
trained_model = trained_model.merge_and_unload()
|
| 1079 |
+
|
| 1080 |
+
return trained_model, tokenizer, eval_metrics
|
| 1081 |
+
|
| 1082 |
+
|
| 1083 |
+
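The metrics cleanup at the end of `_train_al_model` (stripping the `eval_` prefix the Trainer adds and rounding to four places) is simple enough to exercise on its own. A sketch with made-up values; the numpy-float check from the original is omitted so this stays stdlib-only:

```python
def clean_eval_metrics(results):
    """Strip Trainer's 'eval_' prefix and round numeric values to 4 places."""
    out = {}
    for k, v in results.items():
        if isinstance(v, (int, float)):
            out[k.replace('eval_', '')] = round(float(v), 4)
    return out

metrics = clean_eval_metrics({
    'eval_f1': 0.8125,        # kept as 'f1'
    'eval_loss': 0.25,        # kept as 'loss'
    'epoch': 3.0,             # no prefix, kept unchanged
    'eval_runtime_str': 'n/a' # non-numeric values are dropped
})
```
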
+def al_initialize(
+    seed_file, pool_file, dev_file, task_type, model_display_name,
+    query_strategy, query_size, epochs, batch_size, lr, max_seq_len,
+    use_lora, lora_rank, lora_alpha,
+    progress=gr.Progress(track_tqdm=True),
+):
+    """Initialize active learning: train on seed data, query first uncertain batch."""
+    try:
+        if seed_file is None or pool_file is None:
+            raise ValueError("Upload both a labeled seed file and an unlabeled pool file.")
+
+        seed_texts, seed_labels, num_labels = parse_data_file(seed_file)
+        pool_texts = parse_pool_file(pool_file)
+
+        dev_texts, dev_labels = None, None
+        if dev_file is not None:
+            dev_texts, dev_labels, _ = parse_data_file(dev_file)
+
+        if task_type == "Binary":
+            num_labels = 2
+
+        query_size = int(query_size)
+        model_id = FINETUNE_MODELS[model_display_name]
+
+        trained_model, tokenizer, eval_metrics = _train_al_model(
+            seed_texts, seed_labels, num_labels, dev_texts, dev_labels,
+            task_type, model_id, int(epochs), int(batch_size), lr,
+            int(max_seq_len), use_lora, lora_rank, lora_alpha,
+        )
+
+        # Build round-0 metrics
+        round_metrics = {'round': 0, 'train_size': len(seed_texts)}
+        round_metrics.update(eval_metrics)
+
+        # Query uncertain samples from pool
+        scores = compute_uncertainty(
+            trained_model, tokenizer, pool_texts, query_strategy, int(max_seq_len),
+        )
+        top_indices = np.argsort(scores)[-query_size:][::-1].tolist()
+        query_texts_batch = [pool_texts[i] for i in top_indices]
+
+        annotation_df = pd.DataFrame({
+            'Text': query_texts_batch,
+            'Label': [''] * len(query_texts_batch),
+        })
+
+        al_state = {
+            'labeled_texts': list(seed_texts),
+            'labeled_labels': list(seed_labels),
+            'pool_texts': pool_texts,
+            'pool_available': [i for i in range(len(pool_texts)) if i not in set(top_indices)],
+            'current_query_indices': top_indices,
+            'dev_texts': dev_texts,
+            'dev_labels': dev_labels,
+            'num_labels': num_labels,
+            'round': 1,
+            'metrics_history': [round_metrics],
+            'model_id': model_id,
+            'model_display_name': model_display_name,
+            'task_type': task_type,
+            'query_strategy': query_strategy,
+            'query_size': query_size,
+            'epochs': int(epochs),
+            'batch_size': int(batch_size),
+            'lr': lr,
+            'max_seq_len': int(max_seq_len),
+            'use_lora': use_lora,
+            'lora_rank': int(lora_rank) if use_lora else 8,
+            'lora_alpha': int(lora_alpha) if use_lora else 16,
+        }
+
+        trained_model = trained_model.cpu()
+        trained_model.eval()
+
+        log_text = (
+            f"=== Active Learning Initialized ===\n"
+            f"Seed: {len(seed_texts)} labeled | Pool: {len(pool_texts)} unlabeled\n"
+            f"Model: {model_display_name}\n"
+            f"Strategy: {query_strategy} | Samples/round: {query_size}\n\n"
+            f"--- Round 0 (seed) ---\n"
+            f"Train size: {len(seed_texts)}\n"
+        )
+        for k, v in eval_metrics.items():
+            log_text += f"  {k}: {v}\n"
+        log_text += (
+            f"\n--- Round 1: {len(query_texts_batch)} samples queried ---\n"
+            f"Label the samples below, then click 'Submit Labels & Next Round'.\n"
+        )
+
+        chart = _build_al_metrics_chart([round_metrics], task_type)
+
+        return (
+            al_state, trained_model, tokenizer,
+            annotation_df, log_text, chart,
+            gr.Column(visible=True),
+        )
+
+    except Exception as e:
+        return (
+            {}, None, None,
+            pd.DataFrame(columns=['Text', 'Label']),
+            f"Initialization failed:\n{str(e)}",
+            None,
+            gr.Column(visible=False),
+        )
+
def al_submit_and_continue(
|
| 1191 |
+
annotation_df, al_state, al_model, al_tokenizer, prev_log,
|
| 1192 |
+
progress=gr.Progress(track_tqdm=True),
|
| 1193 |
+
):
|
| 1194 |
+
"""Accept user labels, retrain, query next uncertain batch."""
|
| 1195 |
+
try:
|
| 1196 |
+
if not al_state or al_model is None:
|
| 1197 |
+
raise ValueError("No active session. Initialize first.")
|
| 1198 |
+
|
| 1199 |
+
new_texts = annotation_df['Text'].tolist()
|
| 1200 |
+
new_labels = []
|
| 1201 |
+
for i, raw in enumerate(annotation_df['Label'].tolist()):
|
| 1202 |
+
s = str(raw).strip()
|
| 1203 |
+
if s in ('', 'nan'):
|
| 1204 |
+
raise ValueError(f"Row {i + 1} has no label. Label all samples first.")
|
| 1205 |
+
new_labels.append(int(s))
|
| 1206 |
+
|
| 1207 |
+
num_labels = al_state['num_labels']
|
| 1208 |
+
for l in new_labels:
|
| 1209 |
+
if l < 0 or l >= num_labels:
|
| 1210 |
+
raise ValueError(f"Label {l} out of range [0, {num_labels - 1}].")
|
| 1211 |
+
|
| 1212 |
+
# Add newly labeled samples
|
| 1213 |
+
al_state['labeled_texts'].extend(new_texts)
|
| 1214 |
+
al_state['labeled_labels'].extend(new_labels)
|
| 1215 |
+
|
| 1216 |
+
queried_set = set(al_state['current_query_indices'])
|
| 1217 |
+
al_state['pool_available'] = [
|
| 1218 |
+
i for i in al_state['pool_available'] if i not in queried_set
|
| 1219 |
+
]
|
| 1220 |
+
|
| 1221 |
+
current_round = al_state['round']
|
| 1222 |
+
|
| 1223 |
+
# Retrain on all labeled data
|
| 1224 |
+
trained_model, tokenizer, eval_metrics = _train_al_model(
|
| 1225 |
+
al_state['labeled_texts'], al_state['labeled_labels'],
|
| 1226 |
+
num_labels, al_state['dev_texts'], al_state['dev_labels'],
|
| 1227 |
+
al_state['task_type'], al_state['model_id'],
|
| 1228 |
+
al_state['epochs'], al_state['batch_size'], al_state['lr'],
|
| 1229 |
+
al_state['max_seq_len'], al_state['use_lora'],
|
| 1230 |
+
al_state['lora_rank'], al_state['lora_alpha'],
|
| 1231 |
+
)
|
| 1232 |
+
|
| 1233 |
+
round_metrics = {
|
| 1234 |
+
'round': current_round,
|
| 1235 |
+
'train_size': len(al_state['labeled_texts']),
|
| 1236 |
+
}
|
| 1237 |
+
round_metrics.update(eval_metrics)
|
| 1238 |
+
al_state['metrics_history'].append(round_metrics)
|
| 1239 |
+
|
| 1240 |
+
# Query next batch from remaining pool
|
| 1241 |
+
remaining_pool = al_state['pool_available']
|
| 1242 |
+
remaining_texts = [al_state['pool_texts'][i] for i in remaining_pool]
|
| 1243 |
+
|
| 1244 |
+
log_add = (
|
| 1245 |
+
f"\n--- Round {current_round} complete ---\n"
|
| 1246 |
+
f"Added {len(new_labels)} labels | "
|
| 1247 |
+
f"Total train: {len(al_state['labeled_texts'])}\n"
|
| 1248 |
+
)
|
| 1249 |
+
for k, v in eval_metrics.items():
|
| 1250 |
+
log_add += f" {k}: {v}\n"
|
| 1251 |
+
|
| 1252 |
+
if remaining_texts:
|
| 1253 |
+
scores = compute_uncertainty(
|
| 1254 |
+
trained_model, tokenizer, remaining_texts,
|
| 1255 |
+
al_state['query_strategy'], al_state['max_seq_len'],
|
| 1256 |
+
)
|
| 1257 |
+
q = min(al_state['query_size'], len(remaining_texts))
|
| 1258 |
+
top_local = np.argsort(scores)[-q:][::-1].tolist()
|
| 1259 |
+
top_pool_indices = [remaining_pool[i] for i in top_local]
|
| 1260 |
+
query_texts = [al_state['pool_texts'][i] for i in top_pool_indices]
|
| 1261 |
+
|
| 1262 |
+
al_state['current_query_indices'] = top_pool_indices
|
| 1263 |
+
al_state['round'] = current_round + 1
|
| 1264 |
+
|
| 1265 |
+
annotation_out = pd.DataFrame({
|
| 1266 |
+
'Text': query_texts,
|
| 1267 |
+
'Label': [''] * len(query_texts),
|
| 1268 |
+
})
|
| 1269 |
+
pool_left = len(remaining_pool) - len(top_pool_indices)
|
| 1270 |
+
log_add += (
|
| 1271 |
+
f"Pool remaining: {pool_left}\n"
|
| 1272 |
+
f"\n--- Round {current_round + 1}: {len(query_texts)} samples queried ---\n"
|
| 1273 |
+
)
|
| 1274 |
+
else:
|
| 1275 |
+
annotation_out = pd.DataFrame(columns=['Text', 'Label'])
|
| 1276 |
+
al_state['current_query_indices'] = []
|
| 1277 |
+
al_state['round'] = current_round + 1
|
| 1278 |
+
log_add += "\nPool exhausted. Active learning complete!\n"
|
| 1279 |
+
|
| 1280 |
+
trained_model = trained_model.cpu()
|
| 1281 |
+
trained_model.eval()
|
| 1282 |
+
|
| 1283 |
+
chart = _build_al_metrics_chart(al_state['metrics_history'], al_state['task_type'])
|
| 1284 |
+
log_text = prev_log + log_add
|
| 1285 |
+
|
| 1286 |
+
return (
|
| 1287 |
+
al_state, trained_model, tokenizer,
|
| 1288 |
+
annotation_out, log_text, chart,
|
| 1289 |
+
)
|
| 1290 |
+
|
| 1291 |
+
except Exception as e:
|
| 1292 |
+
return (
|
| 1293 |
+
al_state, al_model, al_tokenizer,
|
| 1294 |
+
pd.DataFrame(columns=['Text', 'Label']),
|
| 1295 |
+
prev_log + f"\nError: {str(e)}\n",
|
| 1296 |
+
None,
|
| 1297 |
+
)
|
| 1298 |
+
|
| 1299 |
+
|
| 1300 |
+
def al_save_model(save_path, al_model, al_tokenizer):
|
| 1301 |
+
"""Save the active-learning model to disk."""
|
| 1302 |
+
if al_model is None:
|
| 1303 |
+
return "No model to save. Run at least one round first."
|
| 1304 |
+
if not save_path:
|
| 1305 |
+
return "Please specify a save directory."
|
| 1306 |
+
try:
|
| 1307 |
+
os.makedirs(save_path, exist_ok=True)
|
| 1308 |
+
al_model.save_pretrained(save_path)
|
| 1309 |
+
al_tokenizer.save_pretrained(save_path)
|
| 1310 |
+
return f"Model saved to: {save_path}"
|
| 1311 |
+
except Exception as e:
|
| 1312 |
+
return f"Error saving model: {str(e)}"
|
| 1313 |
+
|
| 1314 |
+
|
| 1315 |
+
def load_example_active_learning():
|
| 1316 |
+
"""Load the active learning example dataset."""
|
| 1317 |
+
return (
|
| 1318 |
+
os.path.join(EXAMPLES_DIR, "active_learning", "seed.tsv"),
|
| 1319 |
+
os.path.join(EXAMPLES_DIR, "active_learning", "pool.txt"),
|
| 1320 |
+
os.path.join(EXAMPLES_DIR, "binary", "dev.tsv"),
|
| 1321 |
+
"Binary",
|
| 1322 |
+
)
|
| 1323 |
+
|
| 1324 |
+
|
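For reference, the three query strategies exposed in the UI (`entropy`, `margin`, `least_confidence`) can be sketched as standalone NumPy scoring functions. This is an illustrative version only, not the app's actual `compute_uncertainty` implementation (which runs the model to obtain logits); all names here are hypothetical:

```python
import numpy as np

def uncertainty_scores(logits, strategy="entropy"):
    """Score each row of logits; higher means more uncertain."""
    # Softmax over the class dimension (numerically stabilized)
    z = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)
    if strategy == "entropy":
        # Predictive entropy: high when probability mass is spread out
        return -(probs * np.log(probs + 1e-12)).sum(axis=1)
    if strategy == "margin":
        # Negative gap between the top-2 classes: high when nearly tied
        part = np.sort(probs, axis=1)
        return -(part[:, -1] - part[:, -2])
    if strategy == "least_confidence":
        # One minus the top class probability
        return 1.0 - probs.max(axis=1)
    raise ValueError(f"unknown strategy: {strategy}")

logits = np.array([[2.0, -2.0],   # confident prediction -> low uncertainty
                   [0.1, 0.0]])   # near-tied classes -> high uncertainty
scores = uncertainty_scores(logits, "entropy")
```

All three strategies agree on the ranking in the two-class case above; they diverge mainly with three or more classes, which is why the tab offers a choice.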
 def run_comparison(
     train_file, dev_file, test_file, task_type, selected_models,
+    epochs, batch_size, lr, cmp_use_lora, cmp_lora_rank, cmp_lora_alpha,
     progress=gr.Progress(track_tqdm=True),
 ):
     """Train multiple models on the same data and compare performance + ROC curves."""
         model = AutoModelForSequenceClassification.from_pretrained(
             model_id, num_labels=num_labels,
         )
+
+        cmp_lora_active = False
+        if cmp_use_lora and PEFT_AVAILABLE:
+            lora_cfg = LoraConfig(
+                task_type=TaskType.SEQ_CLS,
+                r=int(cmp_lora_rank), lora_alpha=int(cmp_lora_alpha),
+                lora_dropout=0.1, bias="none",
+            )
+            model.enable_input_require_grads()
+            model = get_peft_model(model, lora_cfg)
+            cmp_lora_active = True
+
         train_ds = TextClassificationDataset(train_texts, train_labels, tokenizer, 512)
         dev_ds = TextClassificationDataset(dev_texts, dev_labels, tokenizer, 512)
         test_ds = TextClassificationDataset(test_texts, test_labels, tokenizer, 512)
 
         train_result = trainer.train()
 
+        # Merge LoRA weights before prediction
+        if cmp_lora_active and hasattr(trainer.model, 'merge_and_unload'):
+            trainer.model = trainer.model.merge_and_unload()
+
         # Get predictions for ROC curves
         pred_output = trainer.predict(test_ds)
         logits = pred_output.predictions
         "8. Save the model and load it later in the "
         "Classification tab\n\n"
         "**Advanced features:**\n"
+        "- **LoRA / QLoRA** for parameter-efficient training "
+        "(lower VRAM, faster)\n"
+        "- **Active Learning** tab for iterative labeling "
+        "with uncertainty sampling\n"
         "- Early stopping with configurable patience\n"
         "- Learning rate schedulers (linear, cosine, constant)\n"
         "- Mixed precision training (FP16 on CUDA GPUs)\n"
             ["linear", "cosine", "constant", "constant_with_warmup"],
             label="LR Scheduler", value="linear",
         )
+        gr.Markdown("**Parameter-Efficient Fine-Tuning (PEFT)**")
+        with gr.Row():
+            ft_use_lora = gr.Checkbox(
+                label="Use LoRA", value=False,
+            )
+            ft_lora_rank = gr.Number(
+                label="LoRA Rank (r)", value=8,
+                minimum=1, maximum=256, precision=0,
+            )
+            ft_lora_alpha = gr.Number(
+                label="LoRA Alpha", value=16,
+                minimum=1, maximum=512, precision=0,
+            )
+            ft_use_qlora = gr.Checkbox(
+                label="QLoRA (4-bit, CUDA only)", value=False,
+            )
 
         # -- Train --
         ft_train_btn = gr.Button(
             cmp_epochs = gr.Number(label="Epochs", value=3, minimum=1, precision=0)
             cmp_batch = gr.Number(label="Batch Size", value=8, minimum=1, precision=0)
             cmp_lr = gr.Number(label="Learning Rate", value=2e-5, minimum=1e-7)
+            with gr.Row():
+                cmp_use_lora = gr.Checkbox(label="Use LoRA", value=False)
+                cmp_lora_rank = gr.Number(label="LoRA Rank", value=8, minimum=1, maximum=256, precision=0)
+                cmp_lora_alpha = gr.Number(label="LoRA Alpha", value=16, minimum=1, maximum=512, precision=0)
             cmp_btn = gr.Button("Compare Models", variant="primary")
             cmp_log = gr.Textbox(
                 label="Comparison Log", lines=8,
         cmp_plot = gr.Plot(label="Metrics Comparison")
         cmp_roc = gr.Plot(label="ROC Curves")
 
+        # ================================================================
+        # ACTIVE LEARNING TAB
+        # ================================================================
+        with gr.Tab("Active Learning"):
+            gr.Markdown(info_callout(
+                "**Active learning** iteratively selects the most uncertain "
+                "samples from an unlabeled pool for you to label, then retrains. "
+                "This lets you build a strong classifier with far fewer labels."
+            ))
+
+            # -- Data --
+            gr.Markdown("### Data")
+            gr.Markdown(
+                "**Seed file** — small labeled set (TSV, `text[TAB]label`). \n"
+                "**Pool file** — unlabeled texts (one per line, or CSV with `text` column). \n"
+                "**Dev file** *(optional)* — held-out labeled set to track metrics."
+            )
+            al_ex_btn = gr.Button(
+                "Load Example: Binary Active Learning",
+                variant="secondary", size="sm",
+            )
+            with gr.Row():
+                al_seed_file = gr.File(
+                    label="Labeled Seed (TSV)",
+                    file_types=[".tsv", ".csv", ".txt"],
+                )
+                al_pool_file = gr.File(
+                    label="Unlabeled Pool",
+                    file_types=[".tsv", ".csv", ".txt"],
+                )
+                al_dev_file = gr.File(
+                    label="Dev / Validation (optional)",
+                    file_types=[".tsv", ".csv", ".txt"],
+                )
+
+            # -- Configuration --
+            gr.Markdown("### Configuration")
+            with gr.Row():
+                al_task = gr.Radio(
+                    ["Binary", "Multiclass"],
+                    label="Task Type", value="Binary",
+                )
+                al_model_dd = gr.Dropdown(
+                    choices=list(FINETUNE_MODELS.keys()),
+                    label="Base Model",
+                    value=list(FINETUNE_MODELS.keys())[0],
+                )
+            with gr.Row():
+                al_strategy = gr.Dropdown(
+                    ["entropy", "margin", "least_confidence"],
+                    label="Query Strategy", value="entropy",
+                )
+                al_query_size = gr.Number(
+                    label="Samples per Round", value=20,
+                    minimum=1, maximum=500, precision=0,
+                )
+            with gr.Row():
+                al_epochs = gr.Number(
+                    label="Epochs per Round", value=3,
+                    minimum=1, maximum=50, precision=0,
+                )
+                al_batch_size = gr.Number(
+                    label="Batch Size", value=8,
+                    minimum=1, maximum=128, precision=0,
+                )
+                al_lr = gr.Number(
+                    label="Learning Rate", value=2e-5,
+                    minimum=1e-7, maximum=1e-2,
+                )
+            with gr.Accordion("Advanced", open=False):
+                with gr.Row():
+                    al_max_len = gr.Number(
+                        label="Max Sequence Length", value=512,
+                        minimum=32, maximum=8192, precision=0,
+                    )
+                    al_use_lora = gr.Checkbox(label="Use LoRA", value=False)
+                    al_lora_rank = gr.Number(
+                        label="LoRA Rank", value=8,
+                        minimum=1, maximum=256, precision=0,
+                    )
+                    al_lora_alpha = gr.Number(
+                        label="LoRA Alpha", value=16,
+                        minimum=1, maximum=512, precision=0,
+                    )
+
+            al_init_btn = gr.Button(
+                "Initialize Active Learning", variant="primary", size="lg",
+            )
+
+            # -- State --
+            al_state = gr.State({})
+            al_model_state = gr.State(None)
+            al_tokenizer_state = gr.State(None)
+
+            with gr.Accordion("Log", open=False):
+                al_log = gr.Textbox(
+                    lines=12, interactive=False, elem_classes="log-output",
+                    show_label=False,
+                )
+
+            # -- Annotation panel (hidden until init) --
+            with gr.Column(visible=False) as al_annotation_col:
+                gr.Markdown("### Label These Samples")
+                gr.Markdown(
+                    "Fill in the **Label** column with integer class labels "
+                    "(e.g. 0 or 1 for binary). Then click **Submit**."
+                )
+                al_annotation_df = gr.Dataframe(
+                    headers=["Text", "Label"],
+                    interactive=True,
+                    wrap=True,
+                    row_count=(1, "dynamic"),
+                )
+                with gr.Row():
+                    al_submit_btn = gr.Button(
+                        "Submit Labels & Next Round",
+                        variant="primary",
+                    )
+
+            al_chart = gr.Plot(label="Metrics Across Rounds")
+
+            gr.Markdown("### Save Model")
+            with gr.Row():
+                al_save_path = gr.Textbox(
+                    label="Save Directory", value="./al_model",
+                )
+                al_save_btn = gr.Button("Save", variant="secondary")
+                al_save_status = gr.Markdown("")
+
     # ---- FOOTER ----
     gr.Markdown(
         "<div style='text-align: center; padding: 1rem 0; margin-top: 0.5rem; "
             ft_epochs, ft_batch, ft_lr,
             ft_weight_decay, ft_warmup, ft_max_len,
             ft_grad_accum, ft_fp16, ft_patience, ft_scheduler,
+            ft_use_lora, ft_lora_rank, ft_lora_alpha, ft_use_qlora,
         ],
         outputs=[
             ft_log, ft_metrics,
         outputs=[ft_batch_out],
     )
 
+    # Active Learning: example loader
+    al_ex_btn.click(
+        fn=load_example_active_learning,
+        outputs=[al_seed_file, al_pool_file, al_dev_file, al_task],
+    )
+
+    # Active Learning
+    al_init_btn.click(
+        fn=al_initialize,
+        inputs=[
+            al_seed_file, al_pool_file, al_dev_file,
+            al_task, al_model_dd, al_strategy, al_query_size,
+            al_epochs, al_batch_size, al_lr, al_max_len,
+            al_use_lora, al_lora_rank, al_lora_alpha,
+        ],
+        outputs=[
+            al_state, al_model_state, al_tokenizer_state,
+            al_annotation_df, al_log, al_chart,
+            al_annotation_col,
+        ],
+        concurrency_limit=1,
+    )
+
+    al_submit_btn.click(
+        fn=al_submit_and_continue,
+        inputs=[
+            al_annotation_df, al_state, al_model_state, al_tokenizer_state,
+            al_log,
+        ],
+        outputs=[
+            al_state, al_model_state, al_tokenizer_state,
+            al_annotation_df, al_log, al_chart,
+        ],
+        concurrency_limit=1,
+    )
+
+    al_save_btn.click(
+        fn=al_save_model,
+        inputs=[al_save_path, al_model_state, al_tokenizer_state],
+        outputs=[al_save_status],
+    )
+
     # Model comparison
     cmp_btn.click(
         fn=run_comparison,
         inputs=[
             ft_train_file, ft_dev_file, ft_test_file,
             ft_task, cmp_models, cmp_epochs, cmp_batch, cmp_lr,
+            cmp_use_lora, cmp_lora_rank, cmp_lora_alpha,
         ],
         outputs=[cmp_log, cmp_table, cmp_plot, cmp_roc, cmp_results_col],
         concurrency_limit=1,
examples/active_learning/pool.txt ADDED
@@ -0,0 +1,61 @@
+A car bomb exploded near a military checkpoint killing at least twelve soldiers
+The oceanographic institute published research on coral reef restoration
+Annual tourism numbers reached an all-time high at the coastal resorts
+Gunmen opened fire on a convoy of government officials killing two bodyguards
+Cross-border shelling between the two nations continued for the third consecutive day
+Public transit ridership increased following improvements to the subway system
+Armed bandits attacked a refugee camp displacing thousands of people
+The guerrilla fighters ambushed a supply convoy on the main highway
+The bakery chain announced plans to expand into twelve new locations
+The city hosted a successful international food and wine festival
+The university announced a new scholarship program for students in engineering
+A new study found that regular exercise significantly reduces heart disease risk
+The national football team secured a convincing victory in the qualifying match
+The rebel forces captured a strategic town after weeks of intense battles
+The solar energy project is expected to power thousands of homes by year end
+Insurgents attacked a police station in the capital overnight leaving several officers wounded
+Military helicopters were deployed to support ground troops fighting in the eastern region
+The opposition forces breached the defensive perimeter around the government compound
+A mortar attack on the military base resulted in significant casualties
+The technology company unveiled its latest smartphone with improved camera capabilities
+Government aircraft bombed suspected rebel strongholds in the mountainous region
+Security forces conducted raids targeting suspected members of the armed opposition
+A suicide bomber detonated explosives at a crowded marketplace injuring dozens of civilians
+Local farmers reported an excellent harvest this season due to favorable weather
+The agricultural ministry launched a program to support organic farming
+An airstrike destroyed a weapons depot used by the insurgent group
+The orchestra performed a sold-out concert of works by contemporary composers
+The opposing forces exchanged heavy gunfire throughout the night
+The gaming company released a new title that quickly became a bestseller
+The swimming team broke the national record at the regional championships
+The government declared a state of emergency following widespread political violence
+The city council approved plans for a new public park in the downtown area
+Government forces launched an offensive against rebel positions in the northern province early this morning
+Researchers published findings on a promising treatment for a rare disorder
+A popular streaming service announced an original series based on the classic novel
+A drone strike targeted a meeting of senior militant commanders
+The winter ski season opened early due to heavy snowfall in the mountains
+The museum opened a new exhibition showcasing contemporary sculpture and painting
+The separatist movement launched coordinated attacks on government installations
+An improvised explosive device was found near the parliament building
+The film festival announced its lineup featuring works from emerging directors
+Temperatures are expected to reach record highs this weekend according to forecasters
+Scientists discovered a new species of deep-sea fish in the Pacific Ocean
+Two soldiers were killed when their vehicle struck a landmine on a rural road
+The cycling tour attracted international competitors to the coastal route
+The pharmaceutical company received approval for a new vaccine formulation
+A grenade attack on a busy intersection killed four people and wounded many more
+Ethnic tensions erupted into open violence as rival communities clashed in the market
+The hospital inaugurated a state-of-the-art wing dedicated to pediatric care
+The cookbook featuring traditional regional recipes became an unexpected bestseller
+Heavy fighting broke out between rival armed factions in the disputed border region
+The automotive company revealed plans to launch three new electric vehicle models
+Armed men attacked a village killing several residents and burning homes
+The annual science fair showcased innovative projects by high school students
+A roadside bomb targeted a military patrol wounding three soldiers
+An explosion at a government building was attributed to opposition fighters
+Armed opposition forces shelled the outskirts of the capital city
+Fighting between government troops and rebels displaced thousands of families
+Coalition forces conducted a night raid capturing several high-value targets
+The militant group claimed responsibility for the ambush on a military convoy
+The armed group kidnapped aid workers operating in the conflict zone
examples/active_learning/pool_with_labels.tsv ADDED
@@ -0,0 +1,61 @@
+A car bomb exploded near a military checkpoint killing at least twelve soldiers	1
+The oceanographic institute published research on coral reef restoration	0
+Annual tourism numbers reached an all-time high at the coastal resorts	0
+Gunmen opened fire on a convoy of government officials killing two bodyguards	1
+Cross-border shelling between the two nations continued for the third consecutive day	1
+Public transit ridership increased following improvements to the subway system	0
+Armed bandits attacked a refugee camp displacing thousands of people	1
+The guerrilla fighters ambushed a supply convoy on the main highway	1
+The bakery chain announced plans to expand into twelve new locations	0
+The city hosted a successful international food and wine festival	0
+The university announced a new scholarship program for students in engineering	0
+A new study found that regular exercise significantly reduces heart disease risk	0
+The national football team secured a convincing victory in the qualifying match	0
+The rebel forces captured a strategic town after weeks of intense battles	1
+The solar energy project is expected to power thousands of homes by year end	0
+Insurgents attacked a police station in the capital overnight leaving several officers wounded	1
+Military helicopters were deployed to support ground troops fighting in the eastern region	1
+The opposition forces breached the defensive perimeter around the government compound	1
+A mortar attack on the military base resulted in significant casualties	1
+The technology company unveiled its latest smartphone with improved camera capabilities	0
+Government aircraft bombed suspected rebel strongholds in the mountainous region	1
+Security forces conducted raids targeting suspected members of the armed opposition	1
+A suicide bomber detonated explosives at a crowded marketplace injuring dozens of civilians	1
+Local farmers reported an excellent harvest this season due to favorable weather	0
+The agricultural ministry launched a program to support organic farming	0
+An airstrike destroyed a weapons depot used by the insurgent group	1
+The orchestra performed a sold-out concert of works by contemporary composers	0
+The opposing forces exchanged heavy gunfire throughout the night	1
+The gaming company released a new title that quickly became a bestseller	0
+The swimming team broke the national record at the regional championships	0
+The government declared a state of emergency following widespread political violence	1
+The city council approved plans for a new public park in the downtown area	0
+Government forces launched an offensive against rebel positions in the northern province early this morning	1
+Researchers published findings on a promising treatment for a rare disorder	0
+A popular streaming service announced an original series based on the classic novel	0
+A drone strike targeted a meeting of senior militant commanders	1
+The winter ski season opened early due to heavy snowfall in the mountains	0
+The museum opened a new exhibition showcasing contemporary sculpture and painting	0
+The separatist movement launched coordinated attacks on government installations	1
+An improvised explosive device was found near the parliament building	1
+The film festival announced its lineup featuring works from emerging directors	0
+Temperatures are expected to reach record highs this weekend according to forecasters	0
+Scientists discovered a new species of deep-sea fish in the Pacific Ocean	0
+Two soldiers were killed when their vehicle struck a landmine on a rural road	1
+The cycling tour attracted international competitors to the coastal route	0
+The pharmaceutical company received approval for a new vaccine formulation	0
+A grenade attack on a busy intersection killed four people and wounded many more	1
+Ethnic tensions erupted into open violence as rival communities clashed in the market	1
+The hospital inaugurated a state-of-the-art wing dedicated to pediatric care	0
+The cookbook featuring traditional regional recipes became an unexpected bestseller	0
+Heavy fighting broke out between rival armed factions in the disputed border region	1
+The automotive company revealed plans to launch three new electric vehicle models	0
+Armed men attacked a village killing several residents and burning homes	1
+The annual science fair showcased innovative projects by high school students	0
+A roadside bomb targeted a military patrol wounding three soldiers	1
+An explosion at a government building was attributed to opposition fighters	1
+Armed opposition forces shelled the outskirts of the capital city	1
+Fighting between government troops and rebels displaced thousands of families	1
+Coalition forces conducted a night raid capturing several high-value targets	1
+The militant group claimed responsibility for the ambush on a military convoy	1
+The armed group kidnapped aid workers operating in the conflict zone	1
examples/active_learning/seed.tsv ADDED
@@ -0,0 +1,20 @@
+Astronomers observed a rare celestial event visible from the southern hemisphere	0
+Sniper fire killed two civilians in the besieged neighborhood	1
+A major software update was released improving performance and adding new features	0
+The airline announced new direct flights connecting the capital with European cities	0
+Security operations intensified after a series of bombings in the commercial district	1
+Protesters clashed violently with police during demonstrations against the military regime	1
+A popular author released the highly anticipated sequel to her bestselling novel	0
+Artillery shells struck residential areas as the conflict between the two sides intensified	1
+Archaeologists uncovered ancient pottery at a dig site near the monument	0
+The military junta deployed tanks to suppress the growing resistance movement	1
+The electric vehicle charging network expanded to cover all major highways	0
+The marathon attracted over twenty thousand runners from across the country	0
+The ongoing civil war has resulted in thousands of casualties and widespread destruction	1
+The dairy industry adopted new standards for sustainable milk production	0
+Paramilitary groups carried out targeted assassinations of political opponents	1
+A local nonprofit organized a community cleanup event at the riverside park	0
+The construction of the new high-speed rail line is ahead of schedule	0
+Stock markets rallied on news of stronger than expected economic growth	0
+The tech startup raised significant funding in its latest investment round	0
+A militia group took control of a key oil facility in the contested region	1
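The seed and pool files above follow the formats described in the Active Learning tab: a headerless TSV of `text[TAB]label` pairs for the seed, and one unlabeled text per line for the pool. A minimal loader sketch with pandas (illustrative only; the app's own parsing may differ, and both function names are hypothetical):

```python
import io
import pandas as pd

def load_seed_tsv(path_or_buf):
    # Headerless two-column TSV: text <TAB> integer label
    df = pd.read_csv(path_or_buf, sep="\t", header=None, names=["text", "label"])
    return df["text"].tolist(), df["label"].astype(int).tolist()

def load_pool_txt(path_or_buf):
    # One unlabeled text per line; skip blank lines
    data = path_or_buf.read() if hasattr(path_or_buf, "read") else open(path_or_buf).read()
    return [ln.strip() for ln in data.splitlines() if ln.strip()]

seed = io.StringIO(
    "Sniper fire killed two civilians\t1\n"
    "The marathon attracted runners\t0\n"
)
texts, labels = load_seed_tsv(seed)
```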
requirements.txt CHANGED
@@ -10,3 +10,5 @@ accelerate
 scikit-learn
 pandas
 plotly
+peft>=0.6
+bitsandbytes
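The `peft` and `bitsandbytes` entries added above back the LoRA/QLoRA options. The core LoRA idea — training a low-rank update scaled by `alpha / r` on top of a frozen weight, instead of the full weight matrix — can be sketched in NumPy (a toy illustration of the math, not how PEFT is wired into the app):

```python
import numpy as np

d, k, r, alpha = 768, 768, 8, 16   # hidden dims, LoRA rank, LoRA alpha

rng = np.random.default_rng(0)
W = rng.normal(size=(d, k))          # frozen pretrained weight (not trained)
A = rng.normal(size=(r, k)) * 0.01   # trainable low-rank factor
B = np.zeros((d, r))                 # initialized to zero so the update starts at 0

def lora_forward(x):
    # Base path plus the scaled low-rank update (alpha / r) * B @ A
    return x @ W.T + (alpha / r) * (x @ A.T @ B.T)

full_params = W.size            # parameters a full fine-tune would touch
lora_params = A.size + B.size   # parameters LoRA actually trains
```

With `r=8` on a 768x768 matrix, LoRA trains roughly 2% of the parameters a full fine-tune would, which is where the lower-VRAM claim comes from.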