Shreyas Meher committed
Commit 2be141f · 1 Parent(s): c0557bb

Add LoRA/QLoRA fine-tuning and active learning

- LoRA/QLoRA support in Fine-tune tab (via PEFT) for parameter-efficient
training with lower VRAM usage; also available in model comparison
- New Active Learning tab with iterative uncertainty-based labeling:
entropy, margin, and least-confidence query strategies, round-by-round
metrics chart, and example dataset
- Add peft and bitsandbytes to requirements.txt
- Update README with new features and quick-start guides
- Fix .gitignore to exclude conflibertr/ directory
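
The three query strategies named above all rank pool samples by model uncertainty over softmax probabilities. As an illustrative sketch only (the `uncertainty` helper below is not the app's actual function, just a stand-in mirroring the strategies the commit adds):

```python
import numpy as np

def uncertainty(probs: np.ndarray, strategy: str = "entropy") -> np.ndarray:
    """Score each row of class probabilities; higher = more uncertain."""
    if strategy == "entropy":
        # Predictive entropy: peaks for uniform distributions
        return -np.sum(probs * np.log(probs + 1e-10), axis=1)
    if strategy == "margin":
        # Negative gap between top-2 classes: small margin -> high score
        sorted_p = np.sort(probs, axis=1)
        return -(sorted_p[:, -1] - sorted_p[:, -2])
    # least confidence: negative top-class probability
    return -np.max(probs, axis=1)

probs = np.array([[0.5, 0.5], [0.99, 0.01]])
scores = uncertainty(probs, "entropy")
# The 50/50 row scores higher (more uncertain) than the 99/1 row,
# so it would be queried for labeling first.
```

All three strategies agree on this toy input; they differ mainly in how they treat mass spread across more than two classes.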

.gitignore CHANGED
@@ -6,6 +6,8 @@ env/
 __pycache__/
 *.pyc
 conflibertR/
+conflibertr/
+al_model/
 screenshots/
 finetuned_model/
 ft_output/
README.md CHANGED
@@ -41,7 +41,7 @@ Provide a context passage and a question. The model extracts the most relevant a
 
 ### Fine-tuning
 
-Train your own binary or multiclass classifier directly in the browser. Upload data (or load a built-in example), pick a base model, configure training, and go. After training, results and a "Try Your Model" panel appear side by side. You can also save the model and run batch predictions.
+Train your own binary or multiclass classifier directly in the browser. Upload data (or load a built-in example), pick a base model, configure training, and go. Supports **LoRA** and **QLoRA** for parameter-efficient training with lower VRAM usage. After training, results and a "Try Your Model" panel appear side by side. You can also save the model and run batch predictions.
 
 ### Model Comparison
 
@@ -50,7 +50,9 @@ Compare multiple base model architectures on the same dataset. The comparison pr
 <!-- Take a screenshot of the Fine-tune tab and save as screenshots/finetune.png -->
 ![Fine-tune](./screenshots/finetune.png)
 
+### Active Learning
+
+Iteratively build a strong classifier with fewer labels. Start with a small labeled seed set and a pool of unlabeled text. The model identifies the most uncertain samples for you to label, retrains, and repeats. Supports entropy, margin, and least-confidence query strategies.
 
 ## Supported Models
 
@@ -144,7 +146,8 @@ Opens at `http://localhost:7860` and generates a public shareable link. The firs
 | Binary Classification | Conflict vs. non-conflict, supports custom models |
 | Multilabel Classification | Multi-event-type scoring |
 | Question Answering | Extract answers from a context passage |
-| Fine-tune | Train classifiers, compare models, ROC curves |
+| Fine-tune | Train classifiers with optional LoRA/QLoRA, compare models, ROC curves |
+| Active Learning | Iterative uncertainty-based labeling and retraining |
 
 ### Fine-tuning Quick Start
 
@@ -154,6 +157,13 @@ Opens at `http://localhost:7860` and generates a public shareable link. The firs
 4. Review metrics and try your model on new text
 5. Save the model and load it in the **Binary Classification** tab
 
+### LoRA / QLoRA Fine-tuning
+
+1. Go to the **Fine-tune** tab
+2. Open **Advanced Settings** and check **Use LoRA** (optionally enable **QLoRA** for 4-bit quantization on CUDA GPUs)
+3. Adjust LoRA rank and alpha as needed (defaults of r=8, alpha=16 work well)
+4. Train as usual — LoRA weights are merged back automatically so the saved model works like any other
+
 ### Model Comparison Quick Start
 
 1. Upload data (or load an example) in the **Fine-tune** tab
@@ -162,6 +172,16 @@ Opens at `http://localhost:7860` and generates a public shareable link. The firs
 4. Click **"Compare Models"**
 5. View the metrics table, bar chart, and ROC-AUC curves
 
+### Active Learning Quick Start
+
+1. Go to the **Active Learning** tab
+2. Click **"Load Example: Binary Active Learning"** (or upload your own seed + pool)
+3. Configure the query strategy and samples per round
+4. Click **"Initialize Active Learning"**
+5. Label the uncertain samples shown in the table (fill in 0 or 1)
+6. Click **"Submit Labels & Next Round"** to retrain and get the next batch
+7. Repeat until satisfied, then save the model
+
 ### Data Format
 
 Tab-separated values (TSV), no header row. Each line: `text<TAB>label`
@@ -209,10 +229,16 @@ conflibert-gui/
     train.tsv              # 0=Diplomacy, 1=Armed Conflict,
     dev.tsv                # 2=Protest, 3=Humanitarian
     test.tsv
+  active_learning/         # Example active learning dataset
+    seed.tsv               # 20 labeled seed samples
+    pool.txt               # 61 unlabeled pool texts
+    pool_with_labels.tsv   # Ground truth for pool (cheat sheet)
 ```
 
 ## Training Features
 
+- **LoRA / QLoRA** parameter-efficient fine-tuning (via [PEFT](https://github.com/huggingface/peft))
+- **Active learning** with entropy, margin, and least-confidence query strategies
 - Early stopping with configurable patience
 - Learning rate schedulers: linear, cosine, constant, constant with warmup
 - Mixed precision training (FP16) on CUDA GPUs
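
The "defaults of r=8, alpha=16" recommended in the README can be sanity-checked with back-of-envelope arithmetic. The sketch below uses hypothetical dimensions for one BERT-sized 768×768 projection (not measured from this repo) to show why a rank-r update `B @ A` is so much cheaper than a dense weight delta:

```python
import numpy as np

d_out, d_in, r, alpha = 768, 768, 8, 16
full_params = d_out * d_in           # dense update: 589,824 params per matrix
lora_params = r * (d_out + d_in)     # B (d_out x r) + A (r x d_in): 12,288 params

B = np.zeros((d_out, r))             # B starts at zero in LoRA,
A = np.random.randn(r, d_in)         # A starts random,
delta = (alpha / r) * (B @ A)        # so the initial update is exactly zero
assert np.allclose(delta, 0.0)       # training begins from the base weights

print(f"trainable: {lora_params:,} vs {full_params:,} "
      f"({100 * lora_params / full_params:.1f}% per adapted matrix)")
```

Roughly 2% of the dense parameter count per adapted matrix, which is where the "lower VRAM usage" claim comes from: only the adapters need optimizer state.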
app.py CHANGED
@@ -45,6 +45,19 @@ from sklearn.preprocessing import label_binarize
45
  from torch.utils.data import Dataset as TorchDataset
46
  import gc
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  # ============================================================================
50
  # CONFIGURATION
@@ -613,6 +626,7 @@ def run_finetuning(
613
  train_file, dev_file, test_file, task_type, model_display_name,
614
  epochs, batch_size, lr, weight_decay, warmup_ratio, max_seq_len,
615
  grad_accum, fp16, patience, scheduler,
 
616
  progress=gr.Progress(track_tqdm=True),
617
  ):
618
  """Main finetuning function. Returns logs, metrics, model state, and visibility updates."""
@@ -644,9 +658,42 @@ def run_finetuning(
644
  # Load model and tokenizer
645
  model_id = FINETUNE_MODELS[model_display_name]
646
  tokenizer = AutoTokenizer.from_pretrained(model_id)
647
- model = AutoModelForSequenceClassification.from_pretrained(
648
- model_id, num_labels=num_labels
649
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
650
 
651
  # Create datasets
652
  train_ds = TextClassificationDataset(
@@ -709,6 +756,10 @@ def run_finetuning(
709
  test_results = trainer.evaluate(test_ds, metric_key_prefix='test')
710
 
711
  # Build log text
 
 
 
 
712
  header = (
713
  f"=== Configuration ===\n"
714
  f"Model: {model_display_name}\n"
@@ -716,6 +767,7 @@ def run_finetuning(
716
  f"Task: {task_type} Classification ({num_labels} classes)\n"
717
  f"Data: {len(train_texts)} train / {len(dev_texts)} dev / {len(test_texts)} test\n"
718
  f"Epochs: {epochs} Batch: {batch_size} LR: {lr} Scheduler: {scheduler}\n"
 
719
  f"\n=== Training Log ===\n"
720
  )
721
  runtime = train_result.metrics.get('train_runtime', 0)
@@ -733,8 +785,11 @@ def run_finetuning(
733
  metrics_data.append([name, f"{float(v):.4f}"])
734
  metrics_df = pd.DataFrame(metrics_data, columns=['Metric', 'Score'])
735
 
736
- # Move trained model to CPU for inference
737
- trained_model = trainer.model.cpu()
 
 
 
738
  trained_model.eval()
739
 
740
  return (
@@ -864,9 +919,412 @@ def load_example_multiclass():
864
  )
865
 
866
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
867
  def run_comparison(
868
  train_file, dev_file, test_file, task_type, selected_models,
869
- epochs, batch_size, lr,
870
  progress=gr.Progress(track_tqdm=True),
871
  ):
872
  """Train multiple models on the same data and compare performance + ROC curves."""
@@ -913,6 +1371,18 @@ def run_comparison(
913
  model = AutoModelForSequenceClassification.from_pretrained(
914
  model_id, num_labels=num_labels,
915
  )
 
 
 
 
 
 
 
 
 
 
 
 
916
  train_ds = TextClassificationDataset(train_texts, train_labels, tokenizer, 512)
917
  dev_ds = TextClassificationDataset(dev_texts, dev_labels, tokenizer, 512)
918
  test_ds = TextClassificationDataset(test_texts, test_labels, tokenizer, 512)
@@ -949,6 +1419,10 @@ def run_comparison(
949
 
950
  train_result = trainer.train()
951
 
 
 
 
 
952
  # Get predictions for ROC curves
953
  pred_output = trainer.predict(test_ds)
954
  logits = pred_output.predictions
@@ -1194,6 +1668,10 @@ with gr.Blocks(theme=theme, css=custom_css, title="ConfliBERT") as demo:
1194
  "8. Save the model and load it later in the "
1195
  "Classification tab\n\n"
1196
  "**Advanced features:**\n"
 
 
 
 
1197
  "- Early stopping with configurable patience\n"
1198
  "- Learning rate schedulers (linear, cosine, constant)\n"
1199
  "- Mixed precision training (FP16 on CUDA GPUs)\n"
@@ -1429,6 +1907,22 @@ with gr.Blocks(theme=theme, css=custom_css, title="ConfliBERT") as demo:
1429
  ["linear", "cosine", "constant", "constant_with_warmup"],
1430
  label="LR Scheduler", value="linear",
1431
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1432
 
1433
  # -- Train --
1434
  ft_train_btn = gr.Button(
@@ -1502,6 +1996,10 @@ with gr.Blocks(theme=theme, css=custom_css, title="ConfliBERT") as demo:
1502
  cmp_epochs = gr.Number(label="Epochs", value=3, minimum=1, precision=0)
1503
  cmp_batch = gr.Number(label="Batch Size", value=8, minimum=1, precision=0)
1504
  cmp_lr = gr.Number(label="Learning Rate", value=2e-5, minimum=1e-7)
 
 
 
 
1505
  cmp_btn = gr.Button("Compare Models", variant="primary")
1506
  cmp_log = gr.Textbox(
1507
  label="Comparison Log", lines=8,
@@ -1514,6 +2012,135 @@ with gr.Blocks(theme=theme, css=custom_css, title="ConfliBERT") as demo:
1514
  cmp_plot = gr.Plot(label="Metrics Comparison")
1515
  cmp_roc = gr.Plot(label="ROC Curves")
1516
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1517
  # ---- FOOTER ----
1518
  gr.Markdown(
1519
  "<div style='text-align: center; padding: 1rem 0; margin-top: 0.5rem; "
@@ -1600,6 +2227,7 @@ with gr.Blocks(theme=theme, css=custom_css, title="ConfliBERT") as demo:
1600
  ft_epochs, ft_batch, ft_lr,
1601
  ft_weight_decay, ft_warmup, ft_max_len,
1602
  ft_grad_accum, ft_fp16, ft_patience, ft_scheduler,
 
1603
  ],
1604
  outputs=[
1605
  ft_log, ft_metrics,
@@ -1630,12 +2258,55 @@ with gr.Blocks(theme=theme, css=custom_css, title="ConfliBERT") as demo:
1630
  outputs=[ft_batch_out],
1631
  )
1632
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1633
  # Model comparison
1634
  cmp_btn.click(
1635
  fn=run_comparison,
1636
  inputs=[
1637
  ft_train_file, ft_dev_file, ft_test_file,
1638
  ft_task, cmp_models, cmp_epochs, cmp_batch, cmp_lr,
 
1639
  ],
1640
  outputs=[cmp_log, cmp_table, cmp_plot, cmp_roc, cmp_results_col],
1641
  concurrency_limit=1,
 
45
  from torch.utils.data import Dataset as TorchDataset
46
  import gc
47
 
48
+ # LoRA / QLoRA support (optional)
49
+ try:
50
+ from peft import LoraConfig, get_peft_model, TaskType
51
+ PEFT_AVAILABLE = True
52
+ except ImportError:
53
+ PEFT_AVAILABLE = False
54
+
55
+ try:
56
+ from transformers import BitsAndBytesConfig
57
+ BNB_AVAILABLE = True
58
+ except ImportError:
59
+ BNB_AVAILABLE = False
60
+
61
 
62
  # ============================================================================
63
  # CONFIGURATION
 
626
  train_file, dev_file, test_file, task_type, model_display_name,
627
  epochs, batch_size, lr, weight_decay, warmup_ratio, max_seq_len,
628
  grad_accum, fp16, patience, scheduler,
629
+ use_lora, lora_rank, lora_alpha, use_qlora,
630
  progress=gr.Progress(track_tqdm=True),
631
  ):
632
  """Main finetuning function. Returns logs, metrics, model state, and visibility updates."""
 
658
  # Load model and tokenizer
659
  model_id = FINETUNE_MODELS[model_display_name]
660
  tokenizer = AutoTokenizer.from_pretrained(model_id)
661
+
662
+ lora_active = False
663
+ if use_qlora:
664
+ if not (PEFT_AVAILABLE and BNB_AVAILABLE and torch.cuda.is_available()):
665
+ raise ValueError(
666
+ "QLoRA requires a CUDA GPU and the peft + bitsandbytes packages."
667
+ )
668
+ bnb_config = BitsAndBytesConfig(
669
+ load_in_4bit=True,
670
+ bnb_4bit_quant_type="nf4",
671
+ bnb_4bit_compute_dtype=torch.float16,
672
+ bnb_4bit_use_double_quant=True,
673
+ )
674
+ model = AutoModelForSequenceClassification.from_pretrained(
675
+ model_id, num_labels=num_labels, quantization_config=bnb_config,
676
+ )
677
+ else:
678
+ model = AutoModelForSequenceClassification.from_pretrained(
679
+ model_id, num_labels=num_labels,
680
+ )
681
+
682
+ if use_lora or use_qlora:
683
+ if not PEFT_AVAILABLE:
684
+ raise ValueError(
685
+ "LoRA requires the 'peft' package. Install: pip install peft"
686
+ )
687
+ lora_config = LoraConfig(
688
+ task_type=TaskType.SEQ_CLS,
689
+ r=int(lora_rank),
690
+ lora_alpha=int(lora_alpha),
691
+ lora_dropout=0.1,
692
+ bias="none",
693
+ )
694
+ model.enable_input_require_grads()
695
+ model = get_peft_model(model, lora_config)
696
+ lora_active = True
697
 
698
  # Create datasets
699
  train_ds = TextClassificationDataset(
 
756
  test_results = trainer.evaluate(test_ds, metric_key_prefix='test')
757
 
758
  # Build log text
759
+ lora_info = ""
760
+ if lora_active:
761
+ method = "QLoRA (4-bit)" if use_qlora else "LoRA"
762
+ lora_info = f"PEFT: {method} r={int(lora_rank)} alpha={int(lora_alpha)}\n"
763
  header = (
764
  f"=== Configuration ===\n"
765
  f"Model: {model_display_name}\n"
 
767
  f"Task: {task_type} Classification ({num_labels} classes)\n"
768
  f"Data: {len(train_texts)} train / {len(dev_texts)} dev / {len(test_texts)} test\n"
769
  f"Epochs: {epochs} Batch: {batch_size} LR: {lr} Scheduler: {scheduler}\n"
770
+ f"{lora_info}"
771
  f"\n=== Training Log ===\n"
772
  )
773
  runtime = train_result.metrics.get('train_runtime', 0)
 
785
  metrics_data.append([name, f"{float(v):.4f}"])
786
  metrics_df = pd.DataFrame(metrics_data, columns=['Metric', 'Score'])
787
 
788
+ # Merge LoRA weights back into base model for clean save/inference
789
+ trained_model = trainer.model
790
+ if lora_active and hasattr(trained_model, 'merge_and_unload'):
791
+ trained_model = trained_model.merge_and_unload()
792
+ trained_model = trained_model.cpu()
793
  trained_model.eval()
794
 
795
  return (
 
919
  )
920
 
921
 
922
+ # ============================================================================
923
+ # ACTIVE LEARNING
924
+ # ============================================================================
925
+
926
+ def parse_pool_file(file_path):
927
+ """Parse an unlabeled text pool. Accepts CSV with 'text' column, or one text per line."""
928
+ path = get_path(file_path)
929
+ # Try CSV/TSV with 'text' column first
930
+ try:
931
+ df = pd.read_csv(path)
932
+ if 'text' in df.columns:
933
+ texts = [str(t) for t in df['text'].dropna().tolist()]
934
+ if texts:
935
+ return texts
936
+ except Exception:
937
+ pass
938
+ # Fallback: one text per line
939
+ texts = []
940
+ with open(path, 'r', encoding='utf-8') as f:
941
+ for line in f:
942
+ line = line.strip()
943
+ if line:
944
+ texts.append(line)
945
+ if not texts:
946
+ raise ValueError("No texts found in pool file.")
947
+ return texts
948
+
949
+
950
+ def compute_uncertainty(model, tokenizer, texts, strategy='entropy',
951
+ max_seq_len=512, batch_size=32):
952
+ """Compute uncertainty scores for unlabeled texts. Higher = more uncertain."""
953
+ model.eval()
954
+ dev = next(model.parameters()).device
955
+ scores = []
956
+
957
+ for i in range(0, len(texts), batch_size):
958
+ batch_texts = texts[i:i + batch_size]
959
+ inputs = tokenizer(
960
+ batch_texts, return_tensors='pt', truncation=True,
961
+ padding=True, max_length=max_seq_len,
962
+ )
963
+ inputs = {k: v.to(dev) for k, v in inputs.items()}
964
+ with torch.no_grad():
965
+ logits = model(**inputs).logits
966
+ probs = torch.softmax(logits, dim=1).cpu().numpy()
967
+
968
+ if strategy == 'entropy':
969
+ s = -np.sum(probs * np.log(probs + 1e-10), axis=1)
970
+ elif strategy == 'margin':
971
+ sorted_p = np.sort(probs, axis=1)
972
+ s = -(sorted_p[:, -1] - sorted_p[:, -2])
973
+ else: # least_confidence
974
+ s = -np.max(probs, axis=1)
975
+ scores.extend(s.tolist())
976
+
977
+ return scores
978
+
979
+
980
+ def _build_al_metrics_chart(metrics_history, task_type):
981
+ """Build a Plotly chart of active-learning metrics across rounds."""
982
+ import plotly.graph_objects as go
983
+
984
+ if not metrics_history:
985
+ return None
986
+
987
+ rounds = [m['round'] for m in metrics_history]
988
+ train_sizes = [m.get('train_size', 0) for m in metrics_history]
989
+
990
+ metric_keys = (['f1', 'accuracy', 'precision', 'recall']
991
+ if task_type == 'Binary'
992
+ else ['f1_macro', 'accuracy'])
993
+
994
+ fig = go.Figure()
995
+ colors = ['#ff6b35', '#3b82f6', '#10b981', '#8b5cf6']
996
+
997
+ for i, key in enumerate(metric_keys):
998
+ values = [m.get(key) for m in metrics_history]
999
+ if any(v is not None for v in values):
1000
+ fig.add_trace(go.Scatter(
1001
+ x=rounds, y=values, mode='lines+markers',
1002
+ name=key.replace('_', ' ').title(),
1003
+ line=dict(color=colors[i % len(colors)], width=2),
1004
+ ))
1005
+
1006
+ fig.add_trace(go.Bar(
1007
+ x=rounds, y=train_sizes, name='Train Size',
1008
+ marker_color='rgba(200,200,200,0.4)', yaxis='y2',
1009
+ ))
1010
+
1011
+ fig.update_layout(
1012
+ xaxis_title='Round', yaxis_title='Score', yaxis_range=[0, 1.05],
1013
+ yaxis2=dict(title='Train Size', overlaying='y', side='right'),
1014
+ template='plotly_white',
1015
+ legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1),
1016
+ height=350, margin=dict(t=40, b=40),
1017
+ )
1018
+ return fig
1019
+
1020
+
1021
+ def _train_al_model(texts, labels, num_labels, dev_texts, dev_labels,
1022
+ task_type, model_id, epochs, batch_size, lr, max_seq_len,
1023
+ use_lora, lora_rank, lora_alpha):
1024
+ """Train a model for one active-learning round. Returns (model, tokenizer, eval_metrics)."""
1025
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
1026
+ model = AutoModelForSequenceClassification.from_pretrained(
1027
+ model_id, num_labels=num_labels,
1028
+ )
1029
+
1030
+ if use_lora and PEFT_AVAILABLE:
1031
+ lora_cfg = LoraConfig(
1032
+ task_type=TaskType.SEQ_CLS,
1033
+ r=int(lora_rank), lora_alpha=int(lora_alpha),
1034
+ lora_dropout=0.1, bias="none",
1035
+ )
1036
+ model.enable_input_require_grads()
1037
+ model = get_peft_model(model, lora_cfg)
1038
+
1039
+ train_ds = TextClassificationDataset(texts, labels, tokenizer, max_seq_len)
1040
+ dev_ds = None
1041
+ if dev_texts is not None:
1042
+ dev_ds = TextClassificationDataset(dev_texts, dev_labels, tokenizer, max_seq_len)
1043
+
1044
+ output_dir = tempfile.mkdtemp(prefix='conflibert_al_')
1045
+ training_args = TrainingArguments(
1046
+ output_dir=output_dir,
1047
+ num_train_epochs=epochs,
1048
+ per_device_train_batch_size=batch_size,
1049
+ per_device_eval_batch_size=batch_size * 2,
1050
+ learning_rate=lr,
1051
+ weight_decay=0.01,
1052
+ warmup_ratio=0.1,
1053
+ eval_strategy='epoch' if dev_ds else 'no',
1054
+ save_strategy='no',
1055
+ logging_steps=10,
1056
+ report_to='none',
1057
+ seed=42,
1058
+ )
1059
+
1060
+ trainer = Trainer(
1061
+ model=model,
1062
+ args=training_args,
1063
+ train_dataset=train_ds,
1064
+ eval_dataset=dev_ds,
1065
+ compute_metrics=make_compute_metrics(task_type) if dev_ds else None,
1066
+ )
1067
+ trainer.train()
1068
+
1069
+ eval_metrics = {}
1070
+ if dev_ds:
1071
+ results = trainer.evaluate()
1072
+ for k, v in results.items():
1073
+ if isinstance(v, (int, float, np.floating)):
1074
+ eval_metrics[k.replace('eval_', '')] = round(float(v), 4)
1075
+
1076
+ trained_model = trainer.model
1077
+ if use_lora and PEFT_AVAILABLE and hasattr(trained_model, 'merge_and_unload'):
1078
+ trained_model = trained_model.merge_and_unload()
1079
+
1080
+ return trained_model, tokenizer, eval_metrics
1081
+
1082
+
1083
+ def al_initialize(
1084
+ seed_file, pool_file, dev_file, task_type, model_display_name,
1085
+ query_strategy, query_size, epochs, batch_size, lr, max_seq_len,
1086
+ use_lora, lora_rank, lora_alpha,
1087
+ progress=gr.Progress(track_tqdm=True),
1088
+ ):
1089
+ """Initialize active learning: train on seed data, query first uncertain batch."""
1090
+ try:
1091
+ if seed_file is None or pool_file is None:
1092
+ raise ValueError("Upload both a labeled seed file and an unlabeled pool file.")
1093
+
1094
+ seed_texts, seed_labels, num_labels = parse_data_file(seed_file)
1095
+ pool_texts = parse_pool_file(pool_file)
1096
+
1097
+ dev_texts, dev_labels = None, None
1098
+ if dev_file is not None:
1099
+ dev_texts, dev_labels, _ = parse_data_file(dev_file)
1100
+
1101
+ if task_type == "Binary":
1102
+ num_labels = 2
1103
+
1104
+ query_size = int(query_size)
1105
+ model_id = FINETUNE_MODELS[model_display_name]
1106
+
1107
+ trained_model, tokenizer, eval_metrics = _train_al_model(
1108
+ seed_texts, seed_labels, num_labels, dev_texts, dev_labels,
1109
+ task_type, model_id, int(epochs), int(batch_size), lr,
1110
+ int(max_seq_len), use_lora, lora_rank, lora_alpha,
1111
+ )
1112
+
1113
+ # Build round-0 metrics
1114
+ round_metrics = {'round': 0, 'train_size': len(seed_texts)}
1115
+ round_metrics.update(eval_metrics)
1116
+
1117
+ # Query uncertain samples from pool
1118
+ scores = compute_uncertainty(
1119
+ trained_model, tokenizer, pool_texts, query_strategy, int(max_seq_len),
1120
+ )
1121
+ top_indices = np.argsort(scores)[-query_size:][::-1].tolist()
1122
+ query_texts_batch = [pool_texts[i] for i in top_indices]
1123
+
1124
+ annotation_df = pd.DataFrame({
1125
+ 'Text': query_texts_batch,
1126
+ 'Label': [''] * len(query_texts_batch),
1127
+ })
1128
+
1129
+ al_state = {
1130
+ 'labeled_texts': list(seed_texts),
1131
+ 'labeled_labels': list(seed_labels),
1132
+ 'pool_texts': pool_texts,
1133
+ 'pool_available': [i for i in range(len(pool_texts)) if i not in set(top_indices)],
1134
+ 'current_query_indices': top_indices,
1135
+ 'dev_texts': dev_texts,
1136
+ 'dev_labels': dev_labels,
1137
+ 'num_labels': num_labels,
1138
+ 'round': 1,
1139
+ 'metrics_history': [round_metrics],
1140
+ 'model_id': model_id,
1141
+ 'model_display_name': model_display_name,
1142
+ 'task_type': task_type,
1143
+ 'query_strategy': query_strategy,
1144
+ 'query_size': query_size,
1145
+ 'epochs': int(epochs),
1146
+ 'batch_size': int(batch_size),
1147
+ 'lr': lr,
1148
+ 'max_seq_len': int(max_seq_len),
1149
+ 'use_lora': use_lora,
1150
+ 'lora_rank': int(lora_rank) if use_lora else 8,
1151
+ 'lora_alpha': int(lora_alpha) if use_lora else 16,
1152
+ }
1153
+
1154
+ trained_model = trained_model.cpu()
1155
+ trained_model.eval()
1156
+
1157
+ log_text = (
1158
+ f"=== Active Learning Initialized ===\n"
1159
+ f"Seed: {len(seed_texts)} labeled | Pool: {len(pool_texts)} unlabeled\n"
1160
+ f"Model: {model_display_name}\n"
1161
+ f"Strategy: {query_strategy} | Samples/round: {query_size}\n\n"
1162
+ f"--- Round 0 (seed) ---\n"
1163
+ f"Train size: {len(seed_texts)}\n"
1164
+ )
1165
+ for k, v in eval_metrics.items():
1166
+ log_text += f" {k}: {v}\n"
1167
+ log_text += (
1168
+ f"\n--- Round 1: {len(query_texts_batch)} samples queried ---\n"
1169
+ f"Label the samples below, then click 'Submit Labels & Next Round'.\n"
1170
+ )
1171
+
1172
+ chart = _build_al_metrics_chart([round_metrics], task_type)
1173
+
1174
+ return (
1175
+ al_state, trained_model, tokenizer,
1176
+ annotation_df, log_text, chart,
1177
+ gr.Column(visible=True),
1178
+ )
1179
+
1180
+ except Exception as e:
1181
+ return (
1182
+ {}, None, None,
1183
+ pd.DataFrame(columns=['Text', 'Label']),
1184
+ f"Initialization failed:\n{str(e)}",
1185
+ None,
1186
+ gr.Column(visible=False),
1187
+ )
1188
+
1189
+
1190
+ def al_submit_and_continue(
1191
+ annotation_df, al_state, al_model, al_tokenizer, prev_log,
1192
+ progress=gr.Progress(track_tqdm=True),
1193
+ ):
1194
+ """Accept user labels, retrain, query next uncertain batch."""
1195
+ try:
1196
+ if not al_state or al_model is None:
1197
+ raise ValueError("No active session. Initialize first.")
1198
+
1199
+ new_texts = annotation_df['Text'].tolist()
1200
+ new_labels = []
1201
+ for i, raw in enumerate(annotation_df['Label'].tolist()):
1202
+ s = str(raw).strip()
1203
+ if s in ('', 'nan'):
1204
+ raise ValueError(f"Row {i + 1} has no label. Label all samples first.")
1205
+ new_labels.append(int(s))
1206
+
1207
+ num_labels = al_state['num_labels']
1208
+ for l in new_labels:
1209
+ if l < 0 or l >= num_labels:
1210
+ raise ValueError(f"Label {l} out of range [0, {num_labels - 1}].")
1211
+
1212
+ # Add newly labeled samples
1213
+ al_state['labeled_texts'].extend(new_texts)
1214
+ al_state['labeled_labels'].extend(new_labels)
1215
+
1216
+ queried_set = set(al_state['current_query_indices'])
1217
+ al_state['pool_available'] = [
1218
+ i for i in al_state['pool_available'] if i not in queried_set
1219
+ ]
1220
+
1221
+ current_round = al_state['round']
1222
+
1223
+ # Retrain on all labeled data
1224
+ trained_model, tokenizer, eval_metrics = _train_al_model(
1225
+ al_state['labeled_texts'], al_state['labeled_labels'],
1226
+ num_labels, al_state['dev_texts'], al_state['dev_labels'],
1227
+ al_state['task_type'], al_state['model_id'],
1228
+ al_state['epochs'], al_state['batch_size'], al_state['lr'],
1229
+ al_state['max_seq_len'], al_state['use_lora'],
1230
+ al_state['lora_rank'], al_state['lora_alpha'],
1231
+ )
1232
+
1233
+ round_metrics = {
1234
+ 'round': current_round,
1235
+ 'train_size': len(al_state['labeled_texts']),
1236
+ }
1237
+ round_metrics.update(eval_metrics)
1238
+ al_state['metrics_history'].append(round_metrics)
1239
+
1240
+ # Query next batch from remaining pool
1241
+ remaining_pool = al_state['pool_available']
1242
+ remaining_texts = [al_state['pool_texts'][i] for i in remaining_pool]
1243
+
1244
+ log_add = (
1245
+ f"\n--- Round {current_round} complete ---\n"
1246
+ f"Added {len(new_labels)} labels | "
1247
+ f"Total train: {len(al_state['labeled_texts'])}\n"
1248
+ )
1249
+ for k, v in eval_metrics.items():
1250
+ log_add += f" {k}: {v}\n"
1251
+
1252
+ if remaining_texts:
1253
+ scores = compute_uncertainty(
1254
+ trained_model, tokenizer, remaining_texts,
1255
+ al_state['query_strategy'], al_state['max_seq_len'],
1256
+ )
1257
+ q = min(al_state['query_size'], len(remaining_texts))
1258
+ top_local = np.argsort(scores)[-q:][::-1].tolist()
1259
+ top_pool_indices = [remaining_pool[i] for i in top_local]
1260
+ query_texts = [al_state['pool_texts'][i] for i in top_pool_indices]
1261
+
1262
+ al_state['current_query_indices'] = top_pool_indices
1263
+ al_state['round'] = current_round + 1
1264
+
1265
+ annotation_out = pd.DataFrame({
1266
+ 'Text': query_texts,
1267
+ 'Label': [''] * len(query_texts),
1268
+ })
1269
+ pool_left = len(remaining_pool) - len(top_pool_indices)
1270
+ log_add += (
1271
+ f"Pool remaining: {pool_left}\n"
1272
+ f"\n--- Round {current_round + 1}: {len(query_texts)} samples queried ---\n"
1273
+ )
1274
+ else:
1275
+ annotation_out = pd.DataFrame(columns=['Text', 'Label'])
1276
+ al_state['current_query_indices'] = []
1277
+ al_state['round'] = current_round + 1
1278
+ log_add += "\nPool exhausted. Active learning complete!\n"
1279
+
1280
+ trained_model = trained_model.cpu()
1281
+ trained_model.eval()
1282
+
1283
+ chart = _build_al_metrics_chart(al_state['metrics_history'], al_state['task_type'])
1284
+ log_text = prev_log + log_add
1285
+
1286
+ return (
1287
+ al_state, trained_model, tokenizer,
1288
+ annotation_out, log_text, chart,
1289
+ )
1290
+
1291
+ except Exception as e:
1292
+ return (
1293
+ al_state, al_model, al_tokenizer,
1294
+ pd.DataFrame(columns=['Text', 'Label']),
1295
+ prev_log + f"\nError: {str(e)}\n",
1296
+ None,
1297
+ )
1298
+
1299
+
1300
+ def al_save_model(save_path, al_model, al_tokenizer):
1301
+ """Save the active-learning model to disk."""
1302
+ if al_model is None:
1303
+ return "No model to save. Run at least one round first."
1304
+ if not save_path:
1305
+ return "Please specify a save directory."
1306
+ try:
1307
+ os.makedirs(save_path, exist_ok=True)
1308
+ al_model.save_pretrained(save_path)
1309
+ al_tokenizer.save_pretrained(save_path)
1310
+ return f"Model saved to: {save_path}"
1311
+ except Exception as e:
1312
+ return f"Error saving model: {str(e)}"
1313
+
1314
+
1315
+ def load_example_active_learning():
1316
+ """Load the active learning example dataset."""
1317
+ return (
1318
+ os.path.join(EXAMPLES_DIR, "active_learning", "seed.tsv"),
1319
+ os.path.join(EXAMPLES_DIR, "active_learning", "pool.txt"),
1320
+ os.path.join(EXAMPLES_DIR, "binary", "dev.tsv"),
1321
+ "Binary",
1322
+ )
1323
+
1324
+
1325
  def run_comparison(
1326
  train_file, dev_file, test_file, task_type, selected_models,
1327
+ epochs, batch_size, lr, cmp_use_lora, cmp_lora_rank, cmp_lora_alpha,
1328
  progress=gr.Progress(track_tqdm=True),
1329
  ):
1330
  """Train multiple models on the same data and compare performance + ROC curves."""
 
  model = AutoModelForSequenceClassification.from_pretrained(
      model_id, num_labels=num_labels,
  )
+
+ cmp_lora_active = False
+ if cmp_use_lora and PEFT_AVAILABLE:
+     lora_cfg = LoraConfig(
+         task_type=TaskType.SEQ_CLS,
+         r=int(cmp_lora_rank), lora_alpha=int(cmp_lora_alpha),
+         lora_dropout=0.1, bias="none",
+     )
+     model.enable_input_require_grads()
+     model = get_peft_model(model, lora_cfg)
+     cmp_lora_active = True
+
  train_ds = TextClassificationDataset(train_texts, train_labels, tokenizer, 512)
  dev_ds = TextClassificationDataset(dev_texts, dev_labels, tokenizer, 512)
  test_ds = TextClassificationDataset(test_texts, test_labels, tokenizer, 512)

  train_result = trainer.train()

+ # Merge LoRA weights before prediction
+ if cmp_lora_active and hasattr(trainer.model, 'merge_and_unload'):
+     trainer.model = trainer.model.merge_and_unload()
+
  # Get predictions for ROC curves
  pred_output = trainer.predict(test_ds)
  logits = pred_output.predictions
 
      "8. Save the model and load it later in the "
      "Classification tab\n\n"
      "**Advanced features:**\n"
+     "- **LoRA / QLoRA** for parameter-efficient training "
+     "(lower VRAM, faster)\n"
+     "- **Active Learning** tab for iterative labeling "
+     "with uncertainty sampling\n"
      "- Early stopping with configurable patience\n"
      "- Learning rate schedulers (linear, cosine, constant)\n"
      "- Mixed precision training (FP16 on CUDA GPUs)\n"
 
      ["linear", "cosine", "constant", "constant_with_warmup"],
      label="LR Scheduler", value="linear",
  )
+ gr.Markdown("**Parameter-Efficient Fine-Tuning (PEFT)**")
+ with gr.Row():
+     ft_use_lora = gr.Checkbox(
+         label="Use LoRA", value=False,
+     )
+     ft_lora_rank = gr.Number(
+         label="LoRA Rank (r)", value=8,
+         minimum=1, maximum=256, precision=0,
+     )
+     ft_lora_alpha = gr.Number(
+         label="LoRA Alpha", value=16,
+         minimum=1, maximum=512, precision=0,
+     )
+     ft_use_qlora = gr.Checkbox(
+         label="QLoRA (4-bit, CUDA only)", value=False,
+     )

  # -- Train --
  ft_train_btn = gr.Button(
 
  cmp_epochs = gr.Number(label="Epochs", value=3, minimum=1, precision=0)
  cmp_batch = gr.Number(label="Batch Size", value=8, minimum=1, precision=0)
  cmp_lr = gr.Number(label="Learning Rate", value=2e-5, minimum=1e-7)
+ with gr.Row():
+     cmp_use_lora = gr.Checkbox(label="Use LoRA", value=False)
+     cmp_lora_rank = gr.Number(label="LoRA Rank", value=8, minimum=1, maximum=256, precision=0)
+     cmp_lora_alpha = gr.Number(label="LoRA Alpha", value=16, minimum=1, maximum=512, precision=0)
  cmp_btn = gr.Button("Compare Models", variant="primary")
  cmp_log = gr.Textbox(
      label="Comparison Log", lines=8,

  cmp_plot = gr.Plot(label="Metrics Comparison")
  cmp_roc = gr.Plot(label="ROC Curves")

+ # ================================================================
+ # ACTIVE LEARNING TAB
+ # ================================================================
+ with gr.Tab("Active Learning"):
+     gr.Markdown(info_callout(
+         "**Active learning** iteratively selects the most uncertain "
+         "samples from an unlabeled pool for you to label, then retrains. "
+         "This lets you build a strong classifier with far fewer labels."
+     ))
+
+     # -- Data --
+     gr.Markdown("### Data")
+     gr.Markdown(
+         "**Seed file** — small labeled set (TSV, `text[TAB]label`). \n"
+         "**Pool file** — unlabeled texts (one per line, or CSV with `text` column). \n"
+         "**Dev file** *(optional)* — held-out labeled set to track metrics."
+     )
+     al_ex_btn = gr.Button(
+         "Load Example: Binary Active Learning",
+         variant="secondary", size="sm",
+     )
+     with gr.Row():
+         al_seed_file = gr.File(
+             label="Labeled Seed (TSV)",
+             file_types=[".tsv", ".csv", ".txt"],
+         )
+         al_pool_file = gr.File(
+             label="Unlabeled Pool",
+             file_types=[".tsv", ".csv", ".txt"],
+         )
+         al_dev_file = gr.File(
+             label="Dev / Validation (optional)",
+             file_types=[".tsv", ".csv", ".txt"],
+         )
+
+     # -- Configuration --
+     gr.Markdown("### Configuration")
+     with gr.Row():
+         al_task = gr.Radio(
+             ["Binary", "Multiclass"],
+             label="Task Type", value="Binary",
+         )
+         al_model_dd = gr.Dropdown(
+             choices=list(FINETUNE_MODELS.keys()),
+             label="Base Model",
+             value=list(FINETUNE_MODELS.keys())[0],
+         )
+     with gr.Row():
+         al_strategy = gr.Dropdown(
+             ["entropy", "margin", "least_confidence"],
+             label="Query Strategy", value="entropy",
+         )
+         al_query_size = gr.Number(
+             label="Samples per Round", value=20,
+             minimum=1, maximum=500, precision=0,
+         )
+     with gr.Row():
+         al_epochs = gr.Number(
+             label="Epochs per Round", value=3,
+             minimum=1, maximum=50, precision=0,
+         )
+         al_batch_size = gr.Number(
+             label="Batch Size", value=8,
+             minimum=1, maximum=128, precision=0,
+         )
+         al_lr = gr.Number(
+             label="Learning Rate", value=2e-5,
+             minimum=1e-7, maximum=1e-2,
+         )
+     with gr.Accordion("Advanced", open=False):
+         with gr.Row():
+             al_max_len = gr.Number(
+                 label="Max Sequence Length", value=512,
+                 minimum=32, maximum=8192, precision=0,
+             )
+             al_use_lora = gr.Checkbox(label="Use LoRA", value=False)
+             al_lora_rank = gr.Number(
+                 label="LoRA Rank", value=8,
+                 minimum=1, maximum=256, precision=0,
+             )
+             al_lora_alpha = gr.Number(
+                 label="LoRA Alpha", value=16,
+                 minimum=1, maximum=512, precision=0,
+             )
+
+     al_init_btn = gr.Button(
+         "Initialize Active Learning", variant="primary", size="lg",
+     )
+
+     # -- State --
+     al_state = gr.State({})
+     al_model_state = gr.State(None)
+     al_tokenizer_state = gr.State(None)
+
+     with gr.Accordion("Log", open=False):
+         al_log = gr.Textbox(
+             lines=12, interactive=False, elem_classes="log-output",
+             show_label=False,
+         )
+
+     # -- Annotation panel (hidden until init) --
+     with gr.Column(visible=False) as al_annotation_col:
+         gr.Markdown("### Label These Samples")
+         gr.Markdown(
+             "Fill in the **Label** column with integer class labels "
+             "(e.g. 0 or 1 for binary). Then click **Submit**."
+         )
+         al_annotation_df = gr.Dataframe(
+             headers=["Text", "Label"],
+             interactive=True,
+             wrap=True,
+             row_count=(1, "dynamic"),
+         )
+         with gr.Row():
+             al_submit_btn = gr.Button(
+                 "Submit Labels & Next Round",
+                 variant="primary",
+             )
+
+     al_chart = gr.Plot(label="Metrics Across Rounds")
+
+     gr.Markdown("### Save Model")
+     with gr.Row():
+         al_save_path = gr.Textbox(
+             label="Save Directory", value="./al_model",
+         )
+         al_save_btn = gr.Button("Save", variant="secondary")
+     al_save_status = gr.Markdown("")
+
  # ---- FOOTER ----
  gr.Markdown(
      "<div style='text-align: center; padding: 1rem 0; margin-top: 0.5rem; "
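The three options in the `al_strategy` dropdown above (entropy, margin, least_confidence) are standard ways of scoring a model's softmax output for uncertainty sampling. A minimal stdlib sketch of how such scores are typically computed and used to pick the next sample to label — an illustration of the technique, not the app's actual implementation:

```python
import math

def uncertainty(probs, strategy="entropy"):
    """Score one class-probability distribution; higher = more uncertain."""
    if strategy == "entropy":
        # Shannon entropy of the predicted distribution.
        return -sum(p * math.log(p + 1e-12) for p in probs)
    if strategy == "margin":
        # Small gap between the top two classes -> high uncertainty.
        top2 = sorted(probs)[-2:]
        return 1.0 - (top2[1] - top2[0])
    if strategy == "least_confidence":
        # Low top-class probability -> high uncertainty.
        return 1.0 - max(probs)
    raise ValueError(f"unknown strategy: {strategy}")

# Hypothetical pool predictions for a binary classifier.
pool_probs = [[0.5, 0.5], [0.9, 0.1], [0.7, 0.3]]
scores = [uncertainty(p, "entropy") for p in pool_probs]

# Query the single most uncertain sample for human labeling.
query = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)[:1]
print(query)   # [0] -- the 50/50 sample is queried first
```

All three strategies agree on this toy pool; they diverge mainly in the multiclass case, where entropy uses the whole distribution while margin and least confidence look only at the top classes.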
 
      ft_epochs, ft_batch, ft_lr,
      ft_weight_decay, ft_warmup, ft_max_len,
      ft_grad_accum, ft_fp16, ft_patience, ft_scheduler,
+     ft_use_lora, ft_lora_rank, ft_lora_alpha, ft_use_qlora,
  ],
  outputs=[
      ft_log, ft_metrics,

  outputs=[ft_batch_out],
  )

+ # Active Learning: example loader
+ al_ex_btn.click(
+     fn=load_example_active_learning,
+     outputs=[al_seed_file, al_pool_file, al_dev_file, al_task],
+ )
+
+ # Active Learning
+ al_init_btn.click(
+     fn=al_initialize,
+     inputs=[
+         al_seed_file, al_pool_file, al_dev_file,
+         al_task, al_model_dd, al_strategy, al_query_size,
+         al_epochs, al_batch_size, al_lr, al_max_len,
+         al_use_lora, al_lora_rank, al_lora_alpha,
+     ],
+     outputs=[
+         al_state, al_model_state, al_tokenizer_state,
+         al_annotation_df, al_log, al_chart,
+         al_annotation_col,
+     ],
+     concurrency_limit=1,
+ )
+
+ al_submit_btn.click(
+     fn=al_submit_and_continue,
+     inputs=[
+         al_annotation_df, al_state, al_model_state, al_tokenizer_state,
+         al_log,
+     ],
+     outputs=[
+         al_state, al_model_state, al_tokenizer_state,
+         al_annotation_df, al_log, al_chart,
+     ],
+     concurrency_limit=1,
+ )
+
+ al_save_btn.click(
+     fn=al_save_model,
+     inputs=[al_save_path, al_model_state, al_tokenizer_state],
+     outputs=[al_save_status],
+ )
+
  # Model comparison
  cmp_btn.click(
      fn=run_comparison,
      inputs=[
          ft_train_file, ft_dev_file, ft_test_file,
          ft_task, cmp_models, cmp_epochs, cmp_batch, cmp_lr,
+         cmp_use_lora, cmp_lora_rank, cmp_lora_alpha,
      ],
      outputs=[cmp_log, cmp_table, cmp_plot, cmp_roc, cmp_results_col],
      concurrency_limit=1,
examples/active_learning/pool.txt ADDED
@@ -0,0 +1,61 @@
+ A car bomb exploded near a military checkpoint killing at least twelve soldiers
+ The oceanographic institute published research on coral reef restoration
+ Annual tourism numbers reached an all-time high at the coastal resorts
+ Gunmen opened fire on a convoy of government officials killing two bodyguards
+ Cross-border shelling between the two nations continued for the third consecutive day
+ Public transit ridership increased following improvements to the subway system
+ Armed bandits attacked a refugee camp displacing thousands of people
+ The guerrilla fighters ambushed a supply convoy on the main highway
+ The bakery chain announced plans to expand into twelve new locations
+ The city hosted a successful international food and wine festival
+ The university announced a new scholarship program for students in engineering
+ A new study found that regular exercise significantly reduces heart disease risk
+ The national football team secured a convincing victory in the qualifying match
+ The rebel forces captured a strategic town after weeks of intense battles
+ The solar energy project is expected to power thousands of homes by year end
+ Insurgents attacked a police station in the capital overnight leaving several officers wounded
+ Military helicopters were deployed to support ground troops fighting in the eastern region
+ The opposition forces breached the defensive perimeter around the government compound
+ A mortar attack on the military base resulted in significant casualties
+ The technology company unveiled its latest smartphone with improved camera capabilities
+ Government aircraft bombed suspected rebel strongholds in the mountainous region
+ Security forces conducted raids targeting suspected members of the armed opposition
+ A suicide bomber detonated explosives at a crowded marketplace injuring dozens of civilians
+ Local farmers reported an excellent harvest this season due to favorable weather
+ The agricultural ministry launched a program to support organic farming
+ An airstrike destroyed a weapons depot used by the insurgent group
+ The orchestra performed a sold-out concert of works by contemporary composers
+ The opposing forces exchanged heavy gunfire throughout the night
+ The gaming company released a new title that quickly became a bestseller
+ The swimming team broke the national record at the regional championships
+ The government declared a state of emergency following widespread political violence
+ The city council approved plans for a new public park in the downtown area
+ Government forces launched an offensive against rebel positions in the northern province early this morning
+ Researchers published findings on a promising treatment for a rare disorder
+ A popular streaming service announced an original series based on the classic novel
+ A drone strike targeted a meeting of senior militant commanders
+ The winter ski season opened early due to heavy snowfall in the mountains
+ The museum opened a new exhibition showcasing contemporary sculpture and painting
+ The separatist movement launched coordinated attacks on government installations
+ An improvised explosive device was found near the parliament building
+ The film festival announced its lineup featuring works from emerging directors
+ Temperatures are expected to reach record highs this weekend according to forecasters
+ Scientists discovered a new species of deep-sea fish in the Pacific Ocean
+ Two soldiers were killed when their vehicle struck a landmine on a rural road
+ The cycling tour attracted international competitors to the coastal route
+ The pharmaceutical company received approval for a new vaccine formulation
+ A grenade attack on a busy intersection killed four people and wounded many more
+ Ethnic tensions erupted into open violence as rival communities clashed in the market
+ The hospital inaugurated a state-of-the-art wing dedicated to pediatric care
+ The cookbook featuring traditional regional recipes became an unexpected bestseller
+ Heavy fighting broke out between rival armed factions in the disputed border region
+ The automotive company revealed plans to launch three new electric vehicle models
+ Armed men attacked a village killing several residents and burning homes
+ The annual science fair showcased innovative projects by high school students
+ A roadside bomb targeted a military patrol wounding three soldiers
+ An explosion at a government building was attributed to opposition fighters
+ Armed opposition forces shelled the outskirts of the capital city
+ Fighting between government troops and rebels displaced thousands of families
+ Coalition forces conducted a night raid capturing several high-value targets
+ The militant group claimed responsibility for the ambush on a military convoy
+ The armed group kidnapped aid workers operating in the conflict zone
examples/active_learning/pool_with_labels.tsv ADDED
@@ -0,0 +1,61 @@
+ A car bomb exploded near a military checkpoint killing at least twelve soldiers	1
+ The oceanographic institute published research on coral reef restoration	0
+ Annual tourism numbers reached an all-time high at the coastal resorts	0
+ Gunmen opened fire on a convoy of government officials killing two bodyguards	1
+ Cross-border shelling between the two nations continued for the third consecutive day	1
+ Public transit ridership increased following improvements to the subway system	0
+ Armed bandits attacked a refugee camp displacing thousands of people	1
+ The guerrilla fighters ambushed a supply convoy on the main highway	1
+ The bakery chain announced plans to expand into twelve new locations	0
+ The city hosted a successful international food and wine festival	0
+ The university announced a new scholarship program for students in engineering	0
+ A new study found that regular exercise significantly reduces heart disease risk	0
+ The national football team secured a convincing victory in the qualifying match	0
+ The rebel forces captured a strategic town after weeks of intense battles	1
+ The solar energy project is expected to power thousands of homes by year end	0
+ Insurgents attacked a police station in the capital overnight leaving several officers wounded	1
+ Military helicopters were deployed to support ground troops fighting in the eastern region	1
+ The opposition forces breached the defensive perimeter around the government compound	1
+ A mortar attack on the military base resulted in significant casualties	1
+ The technology company unveiled its latest smartphone with improved camera capabilities	0
+ Government aircraft bombed suspected rebel strongholds in the mountainous region	1
+ Security forces conducted raids targeting suspected members of the armed opposition	1
+ A suicide bomber detonated explosives at a crowded marketplace injuring dozens of civilians	1
+ Local farmers reported an excellent harvest this season due to favorable weather	0
+ The agricultural ministry launched a program to support organic farming	0
+ An airstrike destroyed a weapons depot used by the insurgent group	1
+ The orchestra performed a sold-out concert of works by contemporary composers	0
+ The opposing forces exchanged heavy gunfire throughout the night	1
+ The gaming company released a new title that quickly became a bestseller	0
+ The swimming team broke the national record at the regional championships	0
+ The government declared a state of emergency following widespread political violence	1
+ The city council approved plans for a new public park in the downtown area	0
+ Government forces launched an offensive against rebel positions in the northern province early this morning	1
+ Researchers published findings on a promising treatment for a rare disorder	0
+ A popular streaming service announced an original series based on the classic novel	0
+ A drone strike targeted a meeting of senior militant commanders	1
+ The winter ski season opened early due to heavy snowfall in the mountains	0
+ The museum opened a new exhibition showcasing contemporary sculpture and painting	0
+ The separatist movement launched coordinated attacks on government installations	1
+ An improvised explosive device was found near the parliament building	1
+ The film festival announced its lineup featuring works from emerging directors	0
+ Temperatures are expected to reach record highs this weekend according to forecasters	0
+ Scientists discovered a new species of deep-sea fish in the Pacific Ocean	0
+ Two soldiers were killed when their vehicle struck a landmine on a rural road	1
+ The cycling tour attracted international competitors to the coastal route	0
+ The pharmaceutical company received approval for a new vaccine formulation	0
+ A grenade attack on a busy intersection killed four people and wounded many more	1
+ Ethnic tensions erupted into open violence as rival communities clashed in the market	1
+ The hospital inaugurated a state-of-the-art wing dedicated to pediatric care	0
+ The cookbook featuring traditional regional recipes became an unexpected bestseller	0
+ Heavy fighting broke out between rival armed factions in the disputed border region	1
+ The automotive company revealed plans to launch three new electric vehicle models	0
+ Armed men attacked a village killing several residents and burning homes	1
+ The annual science fair showcased innovative projects by high school students	0
+ A roadside bomb targeted a military patrol wounding three soldiers	1
+ An explosion at a government building was attributed to opposition fighters	1
+ Armed opposition forces shelled the outskirts of the capital city	1
+ Fighting between government troops and rebels displaced thousands of families	1
+ Coalition forces conducted a night raid capturing several high-value targets	1
+ The militant group claimed responsibility for the ambush on a military convoy	1
+ The armed group kidnapped aid workers operating in the conflict zone	1
examples/active_learning/seed.tsv ADDED
@@ -0,0 +1,20 @@
+ Astronomers observed a rare celestial event visible from the southern hemisphere	0
+ Sniper fire killed two civilians in the besieged neighborhood	1
+ A major software update was released improving performance and adding new features	0
+ The airline announced new direct flights connecting the capital with European cities	0
+ Security operations intensified after a series of bombings in the commercial district	1
+ Protesters clashed violently with police during demonstrations against the military regime	1
+ A popular author released the highly anticipated sequel to her bestselling novel	0
+ Artillery shells struck residential areas as the conflict between the two sides intensified	1
+ Archaeologists uncovered ancient pottery at a dig site near the monument	0
+ The military junta deployed tanks to suppress the growing resistance movement	1
+ The electric vehicle charging network expanded to cover all major highways	0
+ The marathon attracted over twenty thousand runners from across the country	0
+ The ongoing civil war has resulted in thousands of casualties and widespread destruction	1
+ The dairy industry adopted new standards for sustainable milk production	0
+ Paramilitary groups carried out targeted assassinations of political opponents	1
+ A local nonprofit organized a community cleanup event at the riverside park	0
+ The construction of the new high-speed rail line is ahead of schedule	0
+ Stock markets rallied on news of stronger than expected economic growth	0
+ The tech startup raised significant funding in its latest investment round	0
+ A militia group took control of a key oil facility in the contested region	1
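The example files above follow the formats the Active Learning tab documents: the seed is tab-separated `text[TAB]label` pairs with no header row, and the pool is one unlabeled text per line. A minimal sketch of parsing them with the stdlib `csv` module (the file contents below are short stand-ins copied from the examples; this is illustrative, not the app's actual loader):

```python
import csv
import io

# Stand-ins for examples/active_learning/seed.tsv and pool.txt contents.
seed_tsv = (
    "Sniper fire killed two civilians in the besieged neighborhood\t1\n"
    "Archaeologists uncovered ancient pottery at a dig site near the monument\t0\n"
)
pool_txt = (
    "A car bomb exploded near a military checkpoint "
    "killing at least twelve soldiers\n"
)

# Seed: tab-separated text/label rows, no header in the example file.
seed = [
    (text, int(label))
    for text, label in csv.reader(io.StringIO(seed_tsv), delimiter="\t")
]

# Pool: one unlabeled text per line, blank lines skipped.
pool = [line.strip() for line in io.StringIO(pool_txt) if line.strip()]

print(len(seed), seed[0][1], len(pool))   # 2 1 1
```

With real files, the `io.StringIO` stand-ins would simply be replaced by `open(path, newline="")` handles.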
requirements.txt CHANGED
@@ -10,3 +10,5 @@ accelerate
  scikit-learn
  pandas
  plotly
+ peft>=0.6
+ bitsandbytes