# ============================================================================
# ConfliBERT - Conflict & Political Violence NLP Toolkit
# University of Texas at Dallas | Event Data Lab
# ============================================================================
import os

# Disable oneDNN custom TF ops for reproducible results across machines.
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

# Patch gradio_client bug: bool JSON sub-schemas crash schema parsing.
try:
    from gradio_client import utils as _gc_utils

    _original_get_type = _gc_utils.get_type

    def _patched_get_type(schema):
        """Return "Any" for bare-bool JSON schemas instead of crashing."""
        if isinstance(schema, bool):
            return "Any"
        return _original_get_type(schema)

    _gc_utils.get_type = _patched_get_type

    _original_json_schema = _gc_utils._json_schema_to_python_type

    def _patched_json_schema(schema, defs=None):
        """Short-circuit bool schemas before the buggy recursive parser."""
        if isinstance(schema, bool):
            return "Any"
        return _original_json_schema(schema, defs)

    _gc_utils._json_schema_to_python_type = _patched_json_schema
except Exception:
    # Best-effort monkey-patch: if gradio_client internals change, run unpatched.
    pass

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback,
    TrainerCallback,
)

# QA model uses TensorFlow (transformers <5) or PyTorch fallback (transformers >=5)
_USE_TF_QA = False
try:
    import tensorflow as tf
    import tf_keras  # noqa: F401
    import keras  # noqa: F401
    from transformers import TFAutoModelForQuestionAnswering
    _USE_TF_QA = True
except (ImportError, ModuleNotFoundError):
    from transformers import AutoModelForQuestionAnswering

import gradio as gr
import numpy as np
import pandas as pd
import re
import csv
import tempfile
from sklearn.metrics import (
    accuracy_score as sk_accuracy,
    precision_score as sk_precision,
    recall_score as sk_recall,
    f1_score as sk_f1,
    roc_curve,
    auc as sk_auc,
)
from sklearn.preprocessing import label_binarize
from torch.utils.data import Dataset as TorchDataset
import gc

# LoRA / QLoRA support (optional)
try:
    from peft import LoraConfig, get_peft_model, TaskType
    PEFT_AVAILABLE = True
except ImportError:
    PEFT_AVAILABLE = False

try:
    from transformers import BitsAndBytesConfig
    BNB_AVAILABLE = True
except ImportError:
    BNB_AVAILABLE = False

# ============================================================================
# CONFIGURATION
# ============================================================================
if torch.cuda.is_available():
    device = torch.device('cuda')
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

MAX_TOKEN_LENGTH = 512


def get_system_info():
    """Build an HTML string describing the user's compute environment."""
    import platform
    lines = []
    # Device
    if device.type == 'cuda':
        gpu_name = torch.cuda.get_device_name(0)
        vram = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
        lines.append(f"GPU: {gpu_name} ({vram:.1f} GB VRAM)")
        lines.append("FP16 training: supported")
    elif device.type == 'mps':
        lines.append("GPU: Apple Silicon (MPS)")
        lines.append("FP16 training: not supported on MPS")
    else:
        lines.append("GPU: None detected (using CPU)")
        lines.append("FP16 training: not supported on CPU")
    # CPU / RAM (os is already imported at module level)
    cpu_count = os.cpu_count() or 1
    lines.append(f"CPU cores: {cpu_count}")
    try:
        import psutil
        ram_gb = psutil.virtual_memory().total / (1024 ** 3)
        lines.append(f"RAM: {ram_gb:.1f} GB")
    except ImportError:
        pass  # psutil is optional; omit RAM info if unavailable
    lines.append(f"Platform: {platform.system()} {platform.machine()}")
    lines.append(f"PyTorch: {torch.__version__}")
    return " ยท ".join(lines)


FINETUNE_MODELS = {
    "ConfliBERT (recommended for conflict/political text)": "snowood1/ConfliBERT-scr-uncased",
    "BERT Base Uncased": "bert-base-uncased",
    "BERT Base Cased": "bert-base-cased",
    "RoBERTa Base": "roberta-base",
    "ModernBERT Base": "answerdotai/ModernBERT-base",
    "DeBERTa v3 Base": "microsoft/deberta-v3-base",
    "DistilBERT Base Uncased": "distilbert-base-uncased",
}

# Entity type -> highlight color for NER HTML output.
NER_LABELS = {
    'Organisation': '#3b82f6',
    'Person': '#ef4444',
    'Location': '#10b981',
    'Quantity': '#ff6b35',
    'Weapon': '#8b5cf6',
    'Nationality': '#06b6d4',
    'Temporal': '#ec4899',
    'DocumentReference': '#92400e',
    'MilitaryPlatform': '#f59e0b',
    'Money': '#f472b6',
}

CLASS_NAMES = ['Negative', 'Positive']
MULTI_CLASS_NAMES = ["Armed Assault", "Bombing or Explosion", "Kidnapping", "Other"]

# ============================================================================
# PRETRAINED MODEL LOADING
# ============================================================================
qa_model_name = 'salsarra/ConfliBERT-QA'
if _USE_TF_QA:
    qa_model = TFAutoModelForQuestionAnswering.from_pretrained(qa_model_name)
else:
    # Checkpoint is TF-format; convert on load for the PyTorch fallback.
    qa_model = AutoModelForQuestionAnswering.from_pretrained(qa_model_name, from_tf=True)
qa_tokenizer = AutoTokenizer.from_pretrained(qa_model_name)

ner_model_name = 'eventdata-utd/conflibert-named-entity-recognition'
ner_model = AutoModelForTokenClassification.from_pretrained(ner_model_name).to(device)
ner_tokenizer = AutoTokenizer.from_pretrained(ner_model_name)

clf_model_name = 'eventdata-utd/conflibert-binary-classification'
clf_model = AutoModelForSequenceClassification.from_pretrained(clf_model_name).to(device)
clf_tokenizer = AutoTokenizer.from_pretrained(clf_model_name)

multi_clf_model_name = 'eventdata-utd/conflibert-satp-relevant-multilabel'
multi_clf_model = AutoModelForSequenceClassification.from_pretrained(multi_clf_model_name).to(device)
multi_clf_tokenizer = AutoTokenizer.from_pretrained(multi_clf_model_name)


# ============================================================================
# UTILITY FUNCTIONS
# ============================================================================
def get_path(f):
    """Get file path from Gradio file component output."""
    if f is None:
        return None
    return f if isinstance(f, str) else getattr(f, 'name', str(f))


def truncate_text(text, tokenizer, max_length=MAX_TOKEN_LENGTH):
    """Truncate text to max_length tokens, keeping a trailing [SEP] token."""
    tokens = tokenizer.encode(text, truncation=False)
    if len(tokens) > max_length:
        # Reserve the final slot for the separator token.
        tokens = tokens[:max_length - 1] + [tokenizer.sep_token_id]
        return tokenizer.decode(tokens, skip_special_tokens=True)
    return text


def info_callout(text):
    """Wrap markdown text in a styled callout div to avoid Gradio double-border.

    NOTE(review): the original inline style attributes were lost in extraction;
    a minimal div wrapper is reconstructed here — confirm against the app CSS.
    """
    return f"<div class='info-callout'>\n\n{text}\n\n</div>"


def handle_error(e, default_limit=512):
    """Translate common token-length exceptions into user-friendly messages.

    Falls back to the raw exception text if no known pattern matches.
    """
    msg = str(e)
    match = re.search(
        r"The size of tensor a \((\d+)\) must match the size of tensor b \((\d+)\)",
        msg,
    )
    if match:
        return (
            f"Error: Input ({match.group(1)} tokens) exceeds model limit "
            f"({match.group(2)})"
        )
    match_qa = re.search(r"indices\[0,(\d+)\] = \d+ is not in \[0, (\d+)\)", msg)
    if match_qa:
        return f"Error: Input too long for model (limit: {match_qa.group(2)} tokens)"
    return f"Error: {msg}"


# ============================================================================
# INFERENCE FUNCTIONS
# ============================================================================
def question_answering(context, question):
    """Extractive QA: return the answer span from context for the question."""
    if not context or not question:
        return "Please provide both context and question."
    try:
        if _USE_TF_QA:
            inputs = qa_tokenizer(question, context, return_tensors='tf', truncation=True)
            outputs = qa_model(inputs)
            start = tf.argmax(outputs.start_logits, axis=1).numpy()[0]
            end = tf.argmax(outputs.end_logits, axis=1).numpy()[0] + 1
            tokens = qa_tokenizer.convert_ids_to_tokens(
                inputs['input_ids'].numpy()[0][start:end]
            )
        else:
            inputs = qa_tokenizer(question, context, return_tensors='pt', truncation=True)
            with torch.no_grad():
                outputs = qa_model(**inputs)
            start = torch.argmax(outputs.start_logits, dim=1).item()
            end = torch.argmax(outputs.end_logits, dim=1).item() + 1
            tokens = qa_tokenizer.convert_ids_to_tokens(
                inputs['input_ids'][0][start:end]
            )
        answer = qa_tokenizer.convert_tokens_to_string(tokens)
        return f"{answer}"
    except Exception as e:
        return handle_error(e)


def named_entity_recognition(text, output_format='html'):
    """Run NER over text.

    output_format='html' returns highlighted HTML with a legend;
    output_format='csv' returns a compact "Label: tok | tok || ..." string.
    """
    if not text:
        return "Please provide text for analysis."
    try:
        inputs = ner_tokenizer(text, return_tensors='pt', truncation=True)
        with torch.no_grad():
            outputs = ner_model(**inputs)
        results = outputs.logits.argmax(dim=2).squeeze().tolist()
        tokens = ner_tokenizer.convert_ids_to_tokens(inputs['input_ids'].squeeze().tolist())
        # [UNK] most often stands in for apostrophes in this tokenizer's vocab.
        tokens = [t.replace('[UNK]', "'") for t in tokens]

        entities = []
        seen_labels = set()
        current_entity = []   # CSV mode: tokens of the in-progress entity span
        current_label = None  # CSV mode: label of the in-progress span

        for token, pred_id in zip(tokens, results):
            # Strip the B-/I- prefix; keep only the entity type.
            label = ner_model.config.id2label[pred_id].split('-')[-1]
            if token.startswith('##'):
                # WordPiece continuation: glue onto the previous token.
                # NOTE(review): original indentation was ambiguous in the
                # garbled source; this reading appends to the HTML token list
                # in html mode and to the current CSV span otherwise.
                if output_format == 'html' and entities:
                    entities[-1][0] += token[2:]
                elif current_entity:
                    current_entity[-1] = current_entity[-1] + token[2:]
            elif output_format == 'csv':
                if label != 'O':
                    if label == current_label:
                        current_entity.append(token)
                    else:
                        # Label changed: flush the previous span, start a new one.
                        if current_entity:
                            entities.append([' '.join(current_entity), current_label])
                        current_entity = [token]
                        current_label = label
                else:
                    # Outside any entity: flush whatever span was open.
                    if current_entity:
                        entities.append([' '.join(current_entity), current_label])
                    current_entity = []
                    current_label = None
            else:
                entities.append([token, label])
                if label != 'O':
                    seen_labels.add(label)

        # Flush a span left open at end-of-text (CSV mode only).
        if output_format == 'csv' and current_entity:
            entities.append([' '.join(current_entity), current_label])

        if output_format == 'csv':
            grouped = {}
            for token, label in entities:
                if label != 'O':
                    grouped.setdefault(label, []).append(token)
            parts = []
            for label, toks in grouped.items():
                unique = list(dict.fromkeys(toks))  # dedupe, keep first-seen order
                parts.append(f"{label}: {' | '.join(unique)}")
            return ' || '.join(parts)

        # HTML output.
        # NOTE(review): the original span markup was stripped in extraction;
        # minimal color-coded spans are reconstructed here.
        highlighted = ""
        for token, label in entities:
            color = NER_LABELS.get(label, 'inherit')
            if label != 'O':
                highlighted += (
                    f"<span style='color:{color}; font-weight:bold'>{token}</span> "
                )
            else:
                highlighted += f"{token} "

        if seen_labels:
            legend_items = ""
            for label in sorted(seen_labels):
                color = NER_LABELS.get(label, '#666')
                legend_items += (
                    f"<span style='color:{color}'>&nbsp;&bull;&nbsp;{label}</span> "
                )
            legend = (
                f"<div><b>Entities found:</b> {legend_items}</div>"
            )
            return f"<div>{highlighted}</div>{legend}"
        else:
            return (
                f"<div>{highlighted}</div>"
                f"<div>No entities detected.</div>"
            )
    except Exception as e:
        return handle_error(e)


def predict_with_model(text, model, tokenizer):
    """Run inference with an arbitrary classification model.

    Returns an HTML string with one per-class probability line, marking
    the argmax class as "(predicted)".
    """
    model.eval()
    dev = next(model.parameters()).device
    inputs = tokenizer(
        text, return_tensors='pt', truncation=True, padding=True, max_length=512
    )
    inputs = {k: v.to(dev) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    probs = torch.softmax(outputs.logits, dim=1).squeeze()
    predicted = torch.argmax(probs).item()
    # squeeze() yields a 0-d tensor for a single-class head; guard for that.
    num_classes = probs.shape[0] if probs.dim() > 0 else 1
    lines = []
    for i in range(num_classes):
        p = probs[i].item() * 100 if probs.dim() > 0 else probs.item() * 100
        if i == predicted:
            lines.append(f"<b>Class {i}: {p:.2f}% (predicted)</b>")
        else:
            lines.append(f"Class {i}: {p:.2f}%")
    # NOTE(review): join separator markup was stripped in extraction;
    # <br> reconstructed for the Gradio HTML component.
    return "<br>".join(lines)


def text_classification(text, custom_model=None, custom_tokenizer=None):
    """Binary conflict-relevance classification (or a loaded custom model)."""
    if not text:
        return "Please provide text for classification."
    try:
        # Use custom model if loaded
        if custom_model is not None and custom_tokenizer is not None:
            return predict_with_model(text, custom_model, custom_tokenizer)

        # Pretrained binary classifier
        inputs = clf_tokenizer(
            text, return_tensors='pt', truncation=True, padding=True
        ).to(device)
        with torch.no_grad():
            outputs = clf_model(**inputs)
        predicted = torch.argmax(outputs.logits, dim=1).item()
        confidence = torch.softmax(outputs.logits, dim=1).max().item() * 100
        if predicted == 1:
            return (
                f"Positive -- Related to conflict, violence, or politics. "
                f"(Confidence: {confidence:.1f}%)"
            )
        else:
            return (
                f"Negative -- Not related to conflict, violence, or politics. "
                f"(Confidence: {confidence:.1f}%)"
            )
    except Exception as e:
        return handle_error(e)


def multilabel_classification(text, custom_model=None, custom_tokenizer=None):
    """Multi-label event-type classification (sigmoid per class, 0.5 cutoff)."""
    if not text:
        return "Please provide text for classification."
    try:
        # Use custom model if loaded
        if custom_model is not None and custom_tokenizer is not None:
            return predict_with_model(text, custom_model, custom_tokenizer)

        inputs = multi_clf_tokenizer(
            text, return_tensors='pt', truncation=True, padding=True
        ).to(device)
        with torch.no_grad():
            outputs = multi_clf_model(**inputs)
        # Independent per-class probabilities: sigmoid, not softmax.
        probs = torch.sigmoid(outputs.logits).squeeze().tolist()
        results = []
        for i, p in enumerate(probs):
            conf = p * 100
            if p >= 0.5:
                results.append(f"<b>{MULTI_CLASS_NAMES[i]}: {conf:.1f}%</b>")
            else:
                results.append(f"{MULTI_CLASS_NAMES[i]}: {conf:.1f}%")
        return "<br>".join(results)
    except Exception as e:
        return handle_error(e)


# ============================================================================
# CSV BATCH PROCESSING
# ============================================================================
def _save_df_csv(df, suffix):
    """Write df to a uniquely-named temp CSV and return its path.

    Closes the handle before pandas writes so this also works on Windows,
    where an open NamedTemporaryFile cannot be reopened by name.
    """
    out = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
    out.close()
    df.to_csv(out.name, index=False)
    return out.name


def _strip_html(html):
    """Flatten an HTML result string to plain text for CSV cells."""
    return re.sub(r'<[^>]+>', '', html).strip()


def process_csv_ner(file):
    """Batch NER over a CSV with a 'text' column; returns output CSV path."""
    path = get_path(file)
    if path is None:
        return None
    df = pd.read_csv(path)
    if 'text' not in df.columns:
        raise ValueError("CSV must contain a 'text' column")
    entities = []
    for text in df['text']:
        if pd.isna(text):
            entities.append("")
        else:
            entities.append(named_entity_recognition(str(text), output_format='csv'))
    df['entities'] = entities
    return _save_df_csv(df, '_ner_results.csv')


def process_csv_binary(file, custom_model=None, custom_tokenizer=None):
    """Batch binary classification over a CSV with a 'text' column."""
    path = get_path(file)
    if path is None:
        return None
    df = pd.read_csv(path)
    if 'text' not in df.columns:
        raise ValueError("CSV must contain a 'text' column")
    results = []
    for text in df['text']:
        if pd.isna(text):
            results.append("")
        else:
            html = text_classification(str(text), custom_model, custom_tokenizer)
            results.append(_strip_html(html))
    df['classification_results'] = results
    return _save_df_csv(df, '_classification_results.csv')


def process_csv_multilabel(file):
    """Batch multi-label classification over a CSV with a 'text' column."""
    path = get_path(file)
    if path is None:
        return None
    df = pd.read_csv(path)
    if 'text' not in df.columns:
        raise ValueError("CSV must contain a 'text' column")
    results = []
    for text in df['text']:
        if pd.isna(text):
            results.append("")
        else:
            results.append(_strip_html(multilabel_classification(str(text))))
    df['multilabel_results'] = results
    return _save_df_csv(df, '_multilabel_results.csv')


def process_csv_qa(file):
    """Batch QA over a CSV with 'context' and 'question' columns."""
    path = get_path(file)
    if path is None:
        return None
    df = pd.read_csv(path)
    if 'context' not in df.columns or 'question' not in df.columns:
        raise ValueError("CSV must contain 'context' and 'question' columns")
    answers = []
    for _, row in df.iterrows():
        if pd.isna(row['context']) or pd.isna(row['question']):
            answers.append("")
        else:
            html = question_answering(str(row['context']), str(row['question']))
            answers.append(_strip_html(html))
    df['answer'] = answers
    return _save_df_csv(df, '_qa_results.csv')


# ============================================================================
# FINETUNING
# ============================================================================
class TextClassificationDataset(TorchDataset):
    """PyTorch Dataset for text classification with HuggingFace tokenizers.

    Tokenizes the whole corpus eagerly at construction time; items are
    materialized as tensors on access.
    """

    def __init__(self, texts, labels, tokenizer, max_length=512):
        self.encodings = tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=max_length,
            return_tensors=None,
        )
        self.labels = labels

    def __getitem__(self, idx):
        item = {k: torch.tensor(v[idx]) for k, v in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)
        return item

    def __len__(self):
        return len(self.labels)


def parse_data_file(file_path):
    """Parse a TSV/CSV data file.

    Expected format: text<TAB>label (or text,label), no header row.
    Labels must be integers. Returns (texts, labels, num_labels), where
    num_labels is inferred as max(label) + 1 (labels assumed 0-based).
    """
    path = get_path(file_path)
    texts, labels = [], []
    # Detect delimiter from first line
    with open(path, 'r', encoding='utf-8') as f:
        first_line = f.readline()
    delimiter = '\t' if '\t' in first_line else ','
    with open(path, 'r', encoding='utf-8') as f:
        reader = csv.reader(f, delimiter=delimiter, quotechar='"')
        for row in reader:
            if len(row) < 2:
                continue
            try:
                label = int(row[-1].strip())
                # Re-join any extra columns into the text field.
                text = row[0].strip() if len(row) == 2 else delimiter.join(row[:-1]).strip()
                if text:
                    texts.append(text)
                    labels.append(label)
            except (ValueError, IndexError):
                continue  # skip header or malformed rows
    if not texts:
        raise ValueError(
            "No valid data rows found. Expected format: text<TAB>label (no header row)"
        )
    num_labels = max(labels) + 1
    return texts, labels, num_labels


class LogCallback(TrainerCallback):
    """Captures training logs for display in the UI."""

    def __init__(self):
        self.entries = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs:
            self.entries.append({**logs})

    def format(self):
        """Render captured log entries as plain text, one step per line."""
        lines = []
        # Summary keys reported once at the end; shown in the footer instead.
        skip_keys = {
            'total_flos',
            'train_runtime',
            'train_samples_per_second',
            'train_steps_per_second',
            'train_loss',
        }
        for entry in self.entries:
            parts = []
            for k, v in sorted(entry.items()):
                if k in skip_keys:
                    continue
                if isinstance(v, float):
                    parts.append(f"{k}: {v:.4f}")
                elif isinstance(v, (int, np.integer)):
                    parts.append(f"{k}: {v}")
            if parts:
                lines.append("  ".join(parts))
        return "\n".join(lines)


def make_compute_metrics(task_type):
    """Factory for compute_metrics function based on task type."""
    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        preds = np.argmax(logits, axis=-1)
        acc = sk_accuracy(labels, preds)
        if task_type == "Binary":
            return {
                'accuracy': acc,
                'precision': sk_precision(labels, preds, zero_division=0),
                'recall': sk_recall(labels, preds, zero_division=0),
                'f1': sk_f1(labels, preds, zero_division=0),
            }
        else:
            return {
                'accuracy': acc,
                'f1_macro': sk_f1(labels, preds, average='macro', zero_division=0),
                'f1_micro': sk_f1(labels, preds, average='micro', zero_division=0),
                'precision_macro': sk_precision(
                    labels, preds, average='macro', zero_division=0
                ),
                'precision_micro': sk_precision(
                    labels, preds, average='micro', zero_division=0
                ),
                'recall_macro': sk_recall(
                    labels, preds, average='macro', zero_division=0
                ),
                'recall_micro': sk_recall(
                    labels, preds, average='micro', zero_division=0
                ),
            }
    return compute_metrics


def run_finetuning(
    train_file, dev_file, test_file,
    task_type, model_display_name,
    epochs, batch_size, lr, weight_decay, warmup_ratio,
    max_seq_len, grad_accum, fp16, patience, scheduler,
    use_lora, lora_rank, lora_alpha, use_qlora,
    progress=gr.Progress(track_tqdm=True),
):
    """Main finetuning function.

    Returns a 7-tuple consumed by the Gradio UI:
    (log_text, metrics_df, trained_model, tokenizer, num_labels,
     results_column_update, model_column_update).
    On failure returns an error log plus None states and hidden columns;
    exceptions never propagate to Gradio.
    """
    try:
        # Validate inputs
        if train_file is None or dev_file is None or test_file is None:
            raise ValueError("Please upload all three data files (train, dev, test).")
        epochs = int(epochs)
        batch_size = int(batch_size)
        max_seq_len = int(max_seq_len)
        grad_accum = int(grad_accum)
        patience = int(patience)

        # Parse data files
        train_texts, train_labels, n_train = parse_data_file(train_file)
        dev_texts, dev_labels, n_dev = parse_data_file(dev_file)
        test_texts, test_labels, n_test = parse_data_file(test_file)
        # A split may be missing the highest class; take the max over all three.
        num_labels = max(n_train, n_dev, n_test)
        if task_type == "Binary" and num_labels > 2:
            raise ValueError(
                f"Binary task selected but found {num_labels} label classes in data. "
                f"Use Multiclass instead."
            )
        if task_type == "Binary":
            num_labels = 2

        # Load model and tokenizer
        model_id = FINETUNE_MODELS[model_display_name]
        tokenizer = AutoTokenizer.from_pretrained(model_id)

        lora_active = False
        if use_qlora:
            if not (PEFT_AVAILABLE and BNB_AVAILABLE and torch.cuda.is_available()):
                raise ValueError(
                    "QLoRA requires a CUDA GPU and the peft + bitsandbytes packages."
                )
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
            )
            model = AutoModelForSequenceClassification.from_pretrained(
                model_id,
                num_labels=num_labels,
                quantization_config=bnb_config,
            )
        else:
            model = AutoModelForSequenceClassification.from_pretrained(
                model_id,
                num_labels=num_labels,
            )

        if use_lora or use_qlora:
            if not PEFT_AVAILABLE:
                raise ValueError(
                    "LoRA requires the 'peft' package. Install: pip install peft"
                )
            lora_config = LoraConfig(
                task_type=TaskType.SEQ_CLS,
                r=int(lora_rank),
                lora_alpha=int(lora_alpha),
                lora_dropout=0.1,
                bias="none",
            )
            # Needed so gradients flow into the (frozen/quantized) base inputs.
            model.enable_input_require_grads()
            model = get_peft_model(model, lora_config)
            lora_active = True

        # Create datasets
        train_ds = TextClassificationDataset(
            train_texts, train_labels, tokenizer, max_seq_len
        )
        dev_ds = TextClassificationDataset(
            dev_texts, dev_labels, tokenizer, max_seq_len
        )
        test_ds = TextClassificationDataset(
            test_texts, test_labels, tokenizer, max_seq_len
        )

        # Output directory
        output_dir = tempfile.mkdtemp(prefix='conflibert_ft_')

        # Training arguments
        best_metric = 'f1' if task_type == 'Binary' else 'f1_macro'
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=epochs,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size * 2,
            learning_rate=lr,
            weight_decay=weight_decay,
            warmup_ratio=warmup_ratio,
            gradient_accumulation_steps=grad_accum,
            fp16=fp16 and torch.cuda.is_available(),
            eval_strategy='epoch',
            save_strategy='epoch',
            load_best_model_at_end=True,
            metric_for_best_model=best_metric,
            greater_is_better=True,
            logging_steps=10,
            save_total_limit=2,
            lr_scheduler_type=scheduler,
            report_to='none',
            seed=42,
        )

        # Callbacks
        log_callback = LogCallback()
        callbacks = [log_callback]
        if patience > 0:
            callbacks.append(EarlyStoppingCallback(early_stopping_patience=patience))

        # Create Trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_ds,
            eval_dataset=dev_ds,
            compute_metrics=make_compute_metrics(task_type),
            callbacks=callbacks,
        )

        # Train
        train_result = trainer.train()

        # Evaluate on test set
        test_results = trainer.evaluate(test_ds, metric_key_prefix='test')

        # Build log text
        lora_info = ""
        if lora_active:
            method = "QLoRA (4-bit)" if use_qlora else "LoRA"
            lora_info = f"PEFT: {method} r={int(lora_rank)} alpha={int(lora_alpha)}\n"
        header = (
            f"=== Configuration ===\n"
            f"Model: {model_display_name}\n"
            f"  {model_id}\n"
            f"Task: {task_type} Classification ({num_labels} classes)\n"
            f"Data: {len(train_texts)} train / {len(dev_texts)} dev / {len(test_texts)} test\n"
            f"Epochs: {epochs}  Batch: {batch_size}  LR: {lr}  Scheduler: {scheduler}\n"
            f"{lora_info}"
            f"\n=== Training Log ===\n"
        )
        runtime = train_result.metrics.get('train_runtime', 0)
        footer = (
            f"\n=== Training Complete ===\n"
            f"Time: {runtime:.1f}s ({runtime / 60:.1f} min)\n"
        )
        log_text = header + log_callback.format() + footer

        # Build metrics DataFrame
        metrics_data = []
        for k, v in sorted(test_results.items()):
            if isinstance(v, (int, float, np.floating, np.integer)) and k != 'test_epoch':
                name = k.replace('test_', '').replace('_', ' ').title()
                metrics_data.append([name, f"{float(v):.4f}"])
        metrics_df = pd.DataFrame(metrics_data, columns=['Metric', 'Score'])

        # Merge LoRA weights back into base model for clean save/inference
        trained_model = trainer.model
        if lora_active and hasattr(trained_model, 'merge_and_unload'):
            trained_model = trained_model.merge_and_unload()
        trained_model = trained_model.cpu()
        trained_model.eval()

        return (
            log_text,
            metrics_df,
            trained_model,
            tokenizer,
            num_labels,
            gr.Column(visible=True),
            gr.Column(visible=True),
        )
    except Exception as e:
        error_log = f"Training failed:\n{str(e)}"
        empty_df = pd.DataFrame(columns=['Metric', 'Score'])
        return (
            error_log,
            empty_df,
            None,
            None,
            None,
            gr.Column(visible=False),
            gr.Column(visible=False),
        )


# ============================================================================
# MODEL MANAGEMENT (predict, save, load)
# ============================================================================
def predict_finetuned(text, model_state, tokenizer_state, num_labels_state):
    """Run prediction with the finetuned model stored in gr.State.

    Returns a user-facing message when input or model is missing; otherwise
    delegates to predict_with_model for the per-class probability output.
    """
    if not text:
        return "Please enter some text."
    if model_state is None:
        return "No model available. Please train a model first."
    return predict_with_model(text, model_state, tokenizer_state)


def save_finetuned_model(model_state, tokenizer_state):
    """Save the finetuned model as a downloadable zip file.

    Returns (zip_path, status_message); zip_path is None on failure.
    The archive is written under a fresh mkdtemp directory so concurrent
    sessions cannot overwrite each other's downloads (the previous fixed
    path in the shared temp dir raced between users).
    """
    if model_state is None:
        return None, "No model to save. Please train a model first."
    try:
        save_dir = tempfile.mkdtemp(prefix='conflibert_save_')
        model_state.save_pretrained(save_dir)
        tokenizer_state.save_pretrained(save_dir)
        import shutil
        # Unique parent dir per call -> no cross-session filename collisions.
        zip_path = os.path.join(
            tempfile.mkdtemp(prefix='conflibert_zip_'), 'finetuned_model'
        )
        shutil.make_archive(zip_path, 'zip', save_dir)
        return zip_path + '.zip', "Model ready for download."
    except Exception as e:
        return None, f"Error saving model: {str(e)}"


def load_custom_model(path):
    """Load a finetuned classification model from disk.

    Returns (model, tokenizer, status_message); model and tokenizer are
    None when the path is invalid or loading fails.
    """
    if not path or not os.path.isdir(path):
        return None, None, "Invalid path. Please enter a valid model directory."
    try:
        tokenizer = AutoTokenizer.from_pretrained(path)
        model = AutoModelForSequenceClassification.from_pretrained(path)
        model.eval()
        n = model.config.num_labels
        return model, tokenizer, f"Loaded model with {n} classes from: {path}"
    except Exception as e:
        return None, None, f"Error loading model: {str(e)}"


def reset_custom_model():
    """Reset to the pretrained ConfliBERT binary classifier."""
    return None, None, "Reset to pretrained ConfliBERT binary classifier."
def batch_predict_finetuned(file, model_state, tokenizer_state, num_labels_state):
    """Run batch predictions on a CSV using the finetuned model."""
    if model_state is None:
        return None
    path = get_path(file)
    if path is None:
        return None
    df = pd.read_csv(path)
    if 'text' not in df.columns:
        raise ValueError("CSV must contain a 'text' column")
    model_state.eval()
    # Run on whatever device the model currently lives on.
    dev = next(model_state.parameters()).device
    predictions, confidences = [], []
    for text in df['text']:
        # Keep row alignment: blanks for missing text rather than skipping.
        if pd.isna(text):
            predictions.append("")
            confidences.append("")
            continue
        inputs = tokenizer_state(
            str(text),
            return_tensors='pt',
            truncation=True,
            padding=True,
            max_length=512,
        )
        inputs = {k: v.to(dev) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = model_state(**inputs)
        probs = torch.softmax(outputs.logits, dim=1).squeeze()
        pred = torch.argmax(probs).item()
        conf = probs[pred].item() * 100
        predictions.append(str(pred))
        confidences.append(f"{conf:.1f}%")
    df['predicted_class'] = predictions
    df['confidence'] = confidences
    # NOTE(review): the NamedTemporaryFile handle stays open while to_csv
    # rewrites the same path — works on POSIX, may fail on Windows; confirm.
    out = tempfile.NamedTemporaryFile(suffix='_predictions.csv', delete=False)
    df.to_csv(out.name, index=False)
    return out.name


# Bundled example datasets live next to this script.
EXAMPLES_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "examples")


def load_example_binary():
    """Load the binary classification example dataset."""
    # Returns (train_path, dev_path, test_path, task_type) for the UI inputs.
    return (
        os.path.join(EXAMPLES_DIR, "binary", "train.tsv"),
        os.path.join(EXAMPLES_DIR, "binary", "dev.tsv"),
        os.path.join(EXAMPLES_DIR, "binary", "test.tsv"),
        "Binary",
    )


def load_example_multiclass():
    """Load the multiclass classification example dataset."""
    # Returns (train_path, dev_path, test_path, task_type) for the UI inputs.
    return (
        os.path.join(EXAMPLES_DIR, "multiclass", "train.tsv"),
        os.path.join(EXAMPLES_DIR, "multiclass", "dev.tsv"),
        os.path.join(EXAMPLES_DIR, "multiclass", "test.tsv"),
        "Multiclass",
    )


# ============================================================================
# ACTIVE LEARNING
# ============================================================================
def parse_pool_file(file_path):
    """Parse an unlabeled text pool.
    Accepts CSV with 'text' column, or one text per line."""
    path = get_path(file_path)
    # Try CSV/TSV with 'text' column first
    # NOTE(review): read_csv uses the default comma separator, so a TSV pool
    # only parses here if its 'text' header survives — otherwise it falls
    # through to the line-per-text branch below; confirm that is intended.
    try:
        df = pd.read_csv(path)
        if 'text' in df.columns:
            texts = [str(t) for t in df['text'].dropna().tolist()]
            if texts:
                return texts
    except Exception:
        pass
    # Fallback: one text per line
    texts = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line:
                texts.append(line)
    if not texts:
        raise ValueError("No texts found in pool file.")
    return texts


def compute_uncertainty(model, tokenizer, texts, strategy='entropy', max_seq_len=512, batch_size=32):
    """Compute uncertainty scores for unlabeled texts. Higher = more uncertain."""
    model.eval()
    dev = next(model.parameters()).device
    scores = []
    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i + batch_size]
        inputs = tokenizer(
            batch_texts,
            return_tensors='pt',
            truncation=True,
            padding=True,
            max_length=max_seq_len,
        )
        inputs = {k: v.to(dev) for k, v in inputs.items()}
        with torch.no_grad():
            logits = model(**inputs).logits
        probs = torch.softmax(logits, dim=1).cpu().numpy()
        if strategy == 'entropy':
            # Predictive entropy; epsilon guards log(0).
            s = -np.sum(probs * np.log(probs + 1e-10), axis=1)
        elif strategy == 'margin':
            # Negated top-1 / top-2 gap: small margin -> high uncertainty.
            sorted_p = np.sort(probs, axis=1)
            s = -(sorted_p[:, -1] - sorted_p[:, -2])
        else:  # least_confidence
            # Negated max probability: low confidence -> high score.
            s = -np.max(probs, axis=1)
        scores.extend(s.tolist())
    return scores


def _build_al_metrics_chart(metrics_history, task_type):
    """Build a Plotly chart of active-learning metrics across rounds."""
    import plotly.graph_objects as go
    if not metrics_history:
        return None
    rounds = [m['round'] for m in metrics_history]
    train_sizes = [m.get('train_size', 0) for m in metrics_history]
    metric_keys = (['f1', 'accuracy', 'precision', 'recall']
                   if task_type == 'Binary' else ['f1_macro', 'accuracy'])
    fig = go.Figure()
    colors = ['#ff6b35', '#3b82f6', '#10b981', '#8b5cf6']
    for i, key in enumerate(metric_keys):
        values = [m.get(key) for m in metrics_history]
        # Only plot metrics that were actually recorded in some round.
        if any(v is not None for v in values):
            fig.add_trace(go.Scatter(
                x=rounds, y=values, mode='lines+markers',
                name=key.replace('_', ' ').title(),
                line=dict(color=colors[i % len(colors)], width=2),
            ))
    # Training-set size as background bars on a secondary axis.
    fig.add_trace(go.Bar(
        x=rounds, y=train_sizes, name='Train Size',
        marker_color='rgba(200,200,200,0.4)', yaxis='y2',
    ))
    fig.update_layout(
        xaxis_title='Round', yaxis_title='Score', yaxis_range=[0, 1.05],
        yaxis2=dict(title='Train Size', overlaying='y', side='right'),
        template='plotly_white',
        legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1),
        height=350, margin=dict(t=40, b=40),
    )
    return fig


def _train_al_model(texts, labels, num_labels, dev_texts, dev_labels, task_type,
                    model_id, epochs, batch_size, lr, max_seq_len,
                    use_lora, lora_rank, lora_alpha):
    """Train a model for one active-learning round. Returns (model, tokenizer, eval_metrics)."""
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_id, num_labels=num_labels,
    )
    if use_lora and PEFT_AVAILABLE:
        lora_cfg = LoraConfig(
            task_type=TaskType.SEQ_CLS,
            r=int(lora_rank),
            lora_alpha=int(lora_alpha),
            lora_dropout=0.1,
            bias="none",
        )
        model.enable_input_require_grads()
        model = get_peft_model(model, lora_cfg)
    train_ds = TextClassificationDataset(texts, labels, tokenizer, max_seq_len)
    dev_ds = None
    if dev_texts is not None:
        dev_ds = TextClassificationDataset(dev_texts, dev_labels, tokenizer, max_seq_len)
    output_dir = tempfile.mkdtemp(prefix='conflibert_al_')
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size * 2,
        learning_rate=lr,
        weight_decay=0.01,
        warmup_ratio=0.1,
        # Only evaluate per-epoch when a dev set was supplied.
        eval_strategy='epoch' if dev_ds else 'no',
        save_strategy='no',
        logging_steps=10,
        report_to='none',
        seed=42,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=dev_ds,
        compute_metrics=make_compute_metrics(task_type) if dev_ds else None,
    )
    trainer.train()
    eval_metrics = {}
    if dev_ds:
        results = trainer.evaluate()
        for k, v in results.items():
            if isinstance(v, (int, float, np.floating)):
                # Strip the 'eval_' prefix so keys match chart metric names.
                eval_metrics[k.replace('eval_', '')] = round(float(v), 4)
    trained_model = trainer.model
    # Fold LoRA adapters back into the base weights for plain inference.
    if use_lora and PEFT_AVAILABLE and hasattr(trained_model, 'merge_and_unload'):
        trained_model = trained_model.merge_and_unload()
    return trained_model, tokenizer, eval_metrics


def al_initialize(
    seed_file, pool_file, dev_file, task_type, model_display_name,
    query_strategy, query_size, epochs, batch_size, lr, max_seq_len,
    use_lora, lora_rank, lora_alpha,
    progress=gr.Progress(track_tqdm=True),
):
    """Initialize active learning: train on seed data, query first uncertain batch."""
    try:
        if seed_file is None or pool_file is None:
            raise ValueError("Upload both a labeled seed file and an unlabeled pool file.")
        seed_texts, seed_labels, num_labels = parse_data_file(seed_file)
        pool_texts = parse_pool_file(pool_file)
        dev_texts, dev_labels = None, None
        if dev_file is not None:
            dev_texts, dev_labels, _ = parse_data_file(dev_file)
        if task_type == "Binary":
            num_labels = 2
        query_size = int(query_size)
        model_id = FINETUNE_MODELS[model_display_name]
        trained_model, tokenizer, eval_metrics = _train_al_model(
            seed_texts, seed_labels, num_labels, dev_texts, dev_labels,
            task_type, model_id, int(epochs), int(batch_size), lr,
            int(max_seq_len), use_lora, lora_rank, lora_alpha,
        )
        # Build round-0 metrics
        round_metrics = {'round': 0, 'train_size': len(seed_texts)}
        round_metrics.update(eval_metrics)
        # Query uncertain samples from pool
        scores = compute_uncertainty(
            trained_model, tokenizer, pool_texts, query_strategy, int(max_seq_len),
        )
        # Most-uncertain first: take the top query_size scores, descending.
        top_indices = np.argsort(scores)[-query_size:][::-1].tolist()
        query_texts_batch = [pool_texts[i] for i in top_indices]
        annotation_df = pd.DataFrame({
            'Text': query_texts_batch,
            'Label': [''] * len(query_texts_batch),
        })
        # Whole session lives in this dict, carried through gr.State.
        al_state = {
            'labeled_texts': list(seed_texts),
            'labeled_labels': list(seed_labels),
            'pool_texts': pool_texts,
            'pool_available': [i for i in range(len(pool_texts)) if i not in set(top_indices)],
            'current_query_indices': top_indices,
            'dev_texts': dev_texts,
            'dev_labels': dev_labels,
            'num_labels': num_labels,
            'round': 1,
            'metrics_history': [round_metrics],
            'model_id': model_id,
            'model_display_name': model_display_name,
            'task_type': task_type,
            'query_strategy': query_strategy,
            'query_size': query_size,
            'epochs': int(epochs),
            'batch_size': int(batch_size),
            'lr': lr,
            'max_seq_len': int(max_seq_len),
            'use_lora': use_lora,
            'lora_rank': int(lora_rank) if use_lora else 8,
            'lora_alpha': int(lora_alpha) if use_lora else 16,
        }
        # Park the model on CPU between rounds to free GPU memory.
        trained_model = trained_model.cpu()
        trained_model.eval()
        log_text = (
            f"=== Active Learning Initialized ===\n"
            f"Seed: {len(seed_texts)} labeled | Pool: {len(pool_texts)} unlabeled\n"
            f"Model: {model_display_name}\n"
            f"Strategy: {query_strategy} | Samples/round: {query_size}\n\n"
            f"--- Round 0 (seed) ---\n"
            f"Train size: {len(seed_texts)}\n"
        )
        for k, v in eval_metrics.items():
            log_text += f" {k}: {v}\n"
        log_text += (
            f"\n--- Round 1: {len(query_texts_batch)} samples queried ---\n"
            f"Label the samples below, then click 'Submit Labels & Next Round'.\n"
        )
        chart = _build_al_metrics_chart([round_metrics], task_type)
        return (
            al_state, trained_model, tokenizer, annotation_df,
            log_text, chart, gr.Column(visible=True),
        )
    except Exception as e:
        return (
            {}, None, None, pd.DataFrame(columns=['Text', 'Label']),
            f"Initialization failed:\n{str(e)}", None, gr.Column(visible=False),
        )


def al_submit_and_continue(
    annotation_df, al_state, al_model, al_tokenizer, prev_log,
    progress=gr.Progress(track_tqdm=True),
):
    """Accept user labels, retrain, query next uncertain batch."""
    try:
        if not al_state or al_model is None:
            raise ValueError("No active session. Initialize first.")
        new_texts = annotation_df['Text'].tolist()
        new_labels = []
        for i, raw in enumerate(annotation_df['Label'].tolist()):
            s = str(raw).strip()
            # Empty cell round-trips through pandas as NaN -> 'nan'.
            if s in ('', 'nan'):
                raise ValueError(f"Row {i + 1} has no label. Label all samples first.")
            new_labels.append(int(s))
        num_labels = al_state['num_labels']
        for l in new_labels:
            if l < 0 or l >= num_labels:
                raise ValueError(f"Label {l} out of range [0, {num_labels - 1}].")
        # Add newly labeled samples
        al_state['labeled_texts'].extend(new_texts)
        al_state['labeled_labels'].extend(new_labels)
        # Remove the just-labeled items from the available pool.
        queried_set = set(al_state['current_query_indices'])
        al_state['pool_available'] = [
            i for i in al_state['pool_available'] if i not in queried_set
        ]
        current_round = al_state['round']
        # Retrain on all labeled data
        trained_model, tokenizer, eval_metrics = _train_al_model(
            al_state['labeled_texts'], al_state['labeled_labels'], num_labels,
            al_state['dev_texts'], al_state['dev_labels'],
            al_state['task_type'], al_state['model_id'],
            al_state['epochs'], al_state['batch_size'], al_state['lr'],
            al_state['max_seq_len'], al_state['use_lora'],
            al_state['lora_rank'], al_state['lora_alpha'],
        )
        round_metrics = {
            'round': current_round,
            'train_size': len(al_state['labeled_texts']),
        }
        round_metrics.update(eval_metrics)
        al_state['metrics_history'].append(round_metrics)
        # Query next batch from remaining pool
        remaining_pool = al_state['pool_available']
        remaining_texts = [al_state['pool_texts'][i] for i in remaining_pool]
        log_add = (
            f"\n--- Round {current_round} complete ---\n"
            f"Added {len(new_labels)} labels | "
            f"Total train: {len(al_state['labeled_texts'])}\n"
        )
        for k, v in eval_metrics.items():
            log_add += f" {k}: {v}\n"
        if remaining_texts:
            scores = compute_uncertainty(
                trained_model, tokenizer, remaining_texts,
                al_state['query_strategy'], al_state['max_seq_len'],
            )
            q = min(al_state['query_size'], len(remaining_texts))
            # top_local indexes into remaining_texts; map back to pool indices.
            top_local = np.argsort(scores)[-q:][::-1].tolist()
            top_pool_indices = [remaining_pool[i] for i in top_local]
            query_texts = [al_state['pool_texts'][i] for i in top_pool_indices]
            al_state['current_query_indices'] = top_pool_indices
            al_state['round'] = current_round + 1
            annotation_out = pd.DataFrame({
                'Text': query_texts,
                'Label': [''] * len(query_texts),
            })
            pool_left = len(remaining_pool) - len(top_pool_indices)
            log_add += (
                f"Pool remaining: {pool_left}\n"
                f"\n--- Round {current_round + 1}: {len(query_texts)} samples queried ---\n"
            )
        else:
            annotation_out = pd.DataFrame(columns=['Text', 'Label'])
            al_state['current_query_indices'] = []
            al_state['round'] = current_round + 1
            log_add += "\nPool exhausted. Active learning complete!\n"
        trained_model = trained_model.cpu()
        trained_model.eval()
        chart = _build_al_metrics_chart(al_state['metrics_history'], al_state['task_type'])
        log_text = prev_log + log_add
        return (
            al_state, trained_model, tokenizer, annotation_out, log_text, chart,
        )
    except Exception as e:
        # Keep the existing session intact on failure; just append the error.
        return (
            al_state, al_model, al_tokenizer,
            pd.DataFrame(columns=['Text', 'Label']),
            prev_log + f"\nError: {str(e)}\n", None,
        )


def al_save_model(al_model, al_tokenizer):
    """Save the active-learning model as a downloadable zip file."""
    if al_model is None:
        return None, "No model to save. Run at least one round first."
    try:
        save_dir = tempfile.mkdtemp(prefix='conflibert_al_save_')
        al_model.save_pretrained(save_dir)
        al_tokenizer.save_pretrained(save_dir)
        import shutil
        zip_path = os.path.join(tempfile.gettempdir(), 'al_model')
        shutil.make_archive(zip_path, 'zip', save_dir)
        return zip_path + '.zip', "Model ready for download."
    except Exception as e:
        return None, f"Error saving model: {str(e)}"


def load_example_active_learning():
    """Load the active learning example dataset."""
    # Returns (seed_path, pool_path, dev_path, task_type) for the UI inputs.
    return (
        os.path.join(EXAMPLES_DIR, "active_learning", "seed.tsv"),
        os.path.join(EXAMPLES_DIR, "active_learning", "pool.txt"),
        os.path.join(EXAMPLES_DIR, "binary", "dev.tsv"),
        "Binary",
    )


def run_comparison(
    train_file, dev_file, test_file, task_type, selected_models,
    epochs, batch_size, lr, cmp_use_lora, cmp_lora_rank, cmp_lora_alpha,
    progress=gr.Progress(track_tqdm=True),
):
    """Train multiple models on the same data and compare performance + ROC curves."""
    import plotly.graph_objects as go
    from plotly.subplots import make_subplots
    # Template for the 5-tuple returned to the UI on early exit.
    empty = ("", None, None, None, gr.Column(visible=False))
    try:
        if not selected_models or len(selected_models) < 2:
            return ("Select at least 2 models to compare.",) + empty[1:]
        if train_file is None or dev_file is None or test_file is None:
            return ("Upload all 3 data files first.",) + empty[1:]
        epochs = int(epochs)
        batch_size = int(batch_size)
        train_texts, train_labels, n_train = parse_data_file(train_file)
        dev_texts, dev_labels, n_dev = parse_data_file(dev_file)
        test_texts, test_labels, n_test = parse_data_file(test_file)
        num_labels = max(n_train, n_dev, n_test)
        if task_type == "Binary":
            num_labels = 2
        # Only keep these metrics for the table and bar chart
        if task_type == "Binary":
            keep_metrics = {'Accuracy', 'Precision', 'Recall', 'F1'}
        else:
            keep_metrics = {
                'Accuracy', 'F1 Macro', 'F1 Micro',
                'Precision Macro', 'Recall Macro',
            }
        results = []
        roc_data = {}  # model_name -> (true_labels, probabilities)
        log_lines = []
        for i, model_display_name in enumerate(selected_models):
            model_id = FINETUNE_MODELS[model_display_name]
            # Short label for charts: text before the first " (".
            short_name = model_display_name.split(" (")[0]
            log_lines.append(f"[{i + 1}/{len(selected_models)}] Training {short_name}...")
            # One model failing should not abort the whole comparison.
            try:
                tokenizer = AutoTokenizer.from_pretrained(model_id)
                model = AutoModelForSequenceClassification.from_pretrained(
                    model_id, num_labels=num_labels,
                )
                cmp_lora_active = False
                if cmp_use_lora and PEFT_AVAILABLE:
                    lora_cfg = LoraConfig(
                        task_type=TaskType.SEQ_CLS,
                        r=int(cmp_lora_rank),
                        lora_alpha=int(cmp_lora_alpha),
                        lora_dropout=0.1,
                        bias="none",
                    )
                    model.enable_input_require_grads()
                    model = get_peft_model(model, lora_cfg)
                    cmp_lora_active = True
                train_ds = TextClassificationDataset(train_texts, train_labels, tokenizer, 512)
                dev_ds = TextClassificationDataset(dev_texts, dev_labels, tokenizer, 512)
                test_ds = TextClassificationDataset(test_texts, test_labels, tokenizer, 512)
                output_dir = tempfile.mkdtemp(prefix='conflibert_cmp_')
                best_metric = 'f1' if task_type == 'Binary' else 'f1_macro'
                training_args = TrainingArguments(
                    output_dir=output_dir,
                    num_train_epochs=epochs,
                    per_device_train_batch_size=batch_size,
                    per_device_eval_batch_size=batch_size * 2,
                    learning_rate=lr,
                    weight_decay=0.01,
                    warmup_ratio=0.1,
                    eval_strategy='epoch',
                    save_strategy='epoch',
                    load_best_model_at_end=True,
                    metric_for_best_model=best_metric,
                    greater_is_better=True,
                    logging_steps=50,
                    save_total_limit=1,
                    report_to='none',
                    seed=42,
                )
                trainer = Trainer(
                    model=model,
                    args=training_args,
                    train_dataset=train_ds,
                    eval_dataset=dev_ds,
                    compute_metrics=make_compute_metrics(task_type),
                )
                train_result = trainer.train()
                # Merge LoRA weights before prediction
                if cmp_lora_active and hasattr(trainer.model, 'merge_and_unload'):
                    trainer.model = trainer.model.merge_and_unload()
                # Get predictions for ROC curves
                pred_output = trainer.predict(test_ds)
                logits = pred_output.predictions
                true_labels = pred_output.label_ids
                probs = torch.softmax(torch.tensor(logits), dim=1).numpy()
                roc_data[short_name] = (true_labels, probs)
                # Collect classification metrics only
                test_results = trainer.evaluate(test_ds, metric_key_prefix='test')
                row = {'Model': short_name}
                for k, v in sorted(test_results.items()):
                    if not isinstance(v, (int, float, np.floating, np.integer)):
                        continue
                    name = k.replace('test_', '').replace('_', ' ').title()
                    if name in keep_metrics:
                        row[name] = round(float(v), 4)
                results.append(row)
                runtime = train_result.metrics.get('train_runtime', 0)
                log_lines.append(f" Done in {runtime:.1f}s")
                # Free per-model memory before the next candidate trains.
                del model, trainer, tokenizer, train_ds, dev_ds, test_ds
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            except Exception as e:
                log_lines.append(f" Failed: {str(e)}")
        log_lines.append(f"\nComparison complete. {len(results)} models evaluated.")
        log_text = "\n".join(log_lines)
        if not results:
            return log_text, None, None, None, gr.Column(visible=False)
        comparison_df = pd.DataFrame(results)
        # --- Bar chart: classification metrics only ---
        metric_cols = [c for c in comparison_df.columns if c in keep_metrics]
        colors = ['#ff6b35', '#3b82f6', '#10b981', '#8b5cf6', '#f59e0b']
        fig_bar = go.Figure()
        for j, metric in enumerate(metric_cols):
            fig_bar.add_trace(go.Bar(
                name=metric,
                x=comparison_df['Model'],
                y=comparison_df[metric],
                text=comparison_df[metric].apply(
                    lambda x: f'{x:.3f}' if isinstance(x, float) else ''
                ),
                textposition='auto',
                marker_color=colors[j % len(colors)],
            ))
        fig_bar.update_layout(
            barmode='group',
            yaxis_title='Score',
            yaxis_range=[0, 1.05],
            template='plotly_white',
            legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1),
            height=400,
            margin=dict(t=40, b=40),
        )
        # --- ROC curves ---
        model_colors = ['#ff6b35', '#3b82f6', '#10b981', '#8b5cf6', '#f59e0b', '#ec4899', '#06b6d4']
        fig_roc = go.Figure()
        for j, (model_name, (labels, probs)) in enumerate(roc_data.items()):
            color = model_colors[j % len(model_colors)]
            if num_labels == 2:
                # Binary: ROC on the positive-class probability column.
                fpr, tpr, _ = roc_curve(labels, probs[:, 1])
                roc_auc_val = sk_auc(fpr, tpr)
                fig_roc.add_trace(go.Scatter(
                    x=fpr, y=tpr, mode='lines',
                    name=f'{model_name} (AUC = {roc_auc_val:.3f})',
                    line=dict(color=color, width=2),
                ))
            else:
                # Macro-average ROC for multiclass
                labels_bin = label_binarize(labels, classes=list(range(num_labels)))
                all_fpr = np.linspace(0, 1, 200)
                mean_tpr = np.zeros_like(all_fpr)
                for c in range(num_labels):
                    fpr_c, tpr_c, _ = roc_curve(labels_bin[:, c], probs[:, c])
                    # Interpolate each per-class curve onto a common FPR grid.
                    mean_tpr += np.interp(all_fpr, fpr_c, tpr_c)
                mean_tpr /= num_labels
                roc_auc_val = sk_auc(all_fpr, mean_tpr)
                fig_roc.add_trace(go.Scatter(
                    x=all_fpr, y=mean_tpr, mode='lines',
                    name=f'{model_name} (macro AUC = {roc_auc_val:.3f})',
                    line=dict(color=color, width=2),
                ))
        # Diagonal reference line
        fig_roc.add_trace(go.Scatter(
            x=[0, 1], y=[0, 1], mode='lines',
            line=dict(dash='dash', color='#ccc', width=1),
            showlegend=False,
        ))
        fig_roc.update_layout(
            xaxis_title='False Positive Rate',
            yaxis_title='True Positive Rate',
            template='plotly_white',
            legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1),
            height=400,
            margin=dict(t=40, b=40),
        )
        return log_text, comparison_df, fig_bar, fig_roc, gr.Column(visible=True)
    except Exception as e:
        return f"Comparison failed: {str(e)}", None, None, None, gr.Column(visible=False)


# ============================================================================
# THEME & CSS
# ============================================================================
# UT Dallas orange palette for the Gradio Soft theme.
utd_orange = gr.themes.Color(
    c50="#fff7f3", c100="#ffead9", c200="#ffd4b3", c300="#ffb380",
    c400="#ff8c52", c500="#ff6b35", c600="#e8551f", c700="#c2410c",
    c800="#9a3412", c900="#7c2d12", c950="#431407",
)
theme = gr.themes.Soft(
    primary_hue=utd_orange,
    secondary_hue="neutral",
    font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
)
custom_css = """
/* Top accent bar */
.gradio-container::before { content: ''; display: block; height: 4px; background: linear-gradient(90deg, #ff6b35, #ff9f40, #ff6b35); position: fixed; top: 0; left: 0; right: 0; z-index: 1000; }
/* Active tab styling */
.tab-nav button.selected { border-bottom-color: #ff6b35 !important; color: #ff6b35 !important; font-weight: 600 !important; }
/* Log output - monospace */
.log-output textarea { font-family: 'JetBrains Mono', 'Fira Code', 'Consolas', monospace !important; font-size: 0.8rem !important; line-height: 1.5 !important; }
/* Dark mode: info callout adjustment */
.dark
.info-callout-inner { background: rgba(255, 107, 53, 0.1) !important; color: #ffead9 !important; } /* Clean container width */ .gradio-container { max-width: 1200px !important; } /* Smooth transitions */ .gradio-container * { transition: background-color 0.2s ease, border-color 0.2s ease !important; } """ # ============================================================================ # GRADIO UI # ============================================================================ with gr.Blocks(theme=theme, css=custom_css, title="ConfliBERT") as demo: # ---- HEADER ---- gr.Markdown( "
    " "

    " "ConfliBERT

    " "

    " "A Pretrained Language Model for Conflict and Political Violence

    " ) with gr.Tabs(): # ================================================================ # HOME TAB # ================================================================ with gr.Tab("Home"): gr.Markdown( "## Welcome to ConfliBERT\n\n" "ConfliBERT is a pretrained language model built specifically for " "conflict and political violence text. This application lets you " "run inference with ConfliBERT's pretrained models and fine-tune " "your own classifiers on custom data. Use the tabs above to get started." ) with gr.Row(equal_height=True): with gr.Column(): gr.Markdown( "### Inference\n\n" "Run pretrained ConfliBERT models on your text. " "Each task has its own tab with single-text analysis " "and CSV batch processing.\n\n" "**Named Entity Recognition**\n" "Identify persons, organizations, locations, weapons, " "and other entities in text. Results are color-coded " "by entity type.\n\n" "**Binary Classification**\n" "Determine whether text is related to conflict, violence, " "or politics (positive) or not (negative). You can also " "load a custom fine-tuned model here.\n\n" "**Multilabel Classification**\n" "Score text against four event categories: Armed Assault, " "Bombing/Explosion, Kidnapping, and Other. Each category " "is scored independently.\n\n" "**Question Answering**\n" "Provide a context passage and ask a question. The model " "extracts the most relevant answer span from the text." ) with gr.Column(): gr.Markdown( "### Fine-tuning\n\n" "Train your own binary or multiclass text classifier " "on custom labeled data, all within the browser.\n\n" "**Workflow:**\n" "1. Upload your training, validation, and test data as " "TSV files (or load a built-in example dataset)\n" "2. Pick a base model: ConfliBERT, BERT, RoBERTa, " "ModernBERT, DeBERTa, or DistilBERT\n" "3. Configure training parameters (sensible defaults " "are provided)\n" "4. Train and watch progress in real time\n" "5. Review test-set metrics (accuracy, precision, " "recall, F1)\n" "6. 
Try your model on new text immediately\n" "7. Run batch predictions on a CSV\n" "8. Save the model and load it later in the " "Classification tab\n\n" "**Advanced features:**\n" "- **LoRA / QLoRA** for parameter-efficient training " "(lower VRAM, faster)\n" "- **Active Learning** tab for iterative labeling " "with uncertainty sampling\n" "- Early stopping with configurable patience\n" "- Learning rate schedulers (linear, cosine, constant)\n" "- Mixed precision training (FP16 on CUDA GPUs)\n" "- Gradient accumulation for larger effective batch sizes\n" "- Weight decay regularization" ) gr.Markdown( f"---\n\n" f"**Your system:** {get_system_info()}" ) gr.Markdown( "**Citation:** Brandt, P.T., Alsarra, S., D'Orazio, V., " "Heintze, D., Khan, L., Meher, S., Osorio, J. and Sianan, M., " "2025. Extractive versus Generative Language Models for Political " "Conflict Text Classification. *Political Analysis*, pp.1-29." ) # ================================================================ # NER TAB # ================================================================ with gr.Tab("Named Entity Recognition"): gr.Markdown(info_callout( "Identify entities in text such as **persons**, **organizations**, " "**locations**, **weapons**, and more. Results are color-coded by type." )) with gr.Row(equal_height=True): with gr.Column(): ner_input = gr.Textbox( lines=6, placeholder="Paste or type text to analyze for entities...", label="Input Text", ) ner_btn = gr.Button("Analyze Entities", variant="primary") with gr.Column(): ner_output = gr.HTML(label="Results") with gr.Accordion("Batch Processing (CSV)", open=False): gr.Markdown( "Upload a CSV file with a `text` column to process " "multiple texts at once." 
) with gr.Row(): ner_csv_in = gr.File( label="Upload CSV", file_types=[".csv"], ) ner_csv_out = gr.File(label="Download Results") ner_csv_btn = gr.Button("Process CSV", variant="secondary") # ================================================================ # BINARY CLASSIFICATION TAB # ================================================================ with gr.Tab("Binary Classification"): gr.Markdown(info_callout( "Classify text as **conflict-related** (positive) or " "**not conflict-related** (negative). Uses the pretrained ConfliBERT " "binary classifier by default, or load your own finetuned model below." )) custom_clf_model = gr.State(None) custom_clf_tokenizer = gr.State(None) with gr.Row(equal_height=True): with gr.Column(): clf_input = gr.Textbox( lines=6, placeholder="Paste or type text to classify...", label="Input Text", ) clf_btn = gr.Button("Classify", variant="primary") with gr.Column(): clf_output = gr.HTML(label="Results") with gr.Accordion("Batch Processing (CSV)", open=False): gr.Markdown("Upload a CSV file with a `text` column.") with gr.Row(): clf_csv_in = gr.File(label="Upload CSV", file_types=[".csv"]) clf_csv_out = gr.File(label="Download Results") clf_csv_btn = gr.Button("Process CSV", variant="secondary") with gr.Accordion("Load Custom Model", open=False): gr.Markdown( "Load a finetuned classification model from a local directory " "to use instead of the default pretrained classifier." ) clf_model_path = gr.Textbox( label="Model directory path", placeholder="e.g., ./finetuned_model", ) with gr.Row(): clf_load_btn = gr.Button("Load Model", variant="secondary") clf_reset_btn = gr.Button( "Reset to Pretrained", variant="secondary", ) clf_status = gr.Markdown("") # ================================================================ # MULTILABEL CLASSIFICATION TAB # ================================================================ with gr.Tab("Multilabel Classification"): gr.Markdown(info_callout( "Identify multiple event types in text. 
Each category is scored " "independently: **Armed Assault**, **Bombing/Explosion**, " "**Kidnapping**, **Other**. Categories above 50% confidence " "are highlighted. Load a custom finetuned model below." )) custom_multi_model = gr.State(None) custom_multi_tokenizer = gr.State(None) with gr.Row(equal_height=True): with gr.Column(): multi_input = gr.Textbox( lines=6, placeholder="Paste or type text to classify...", label="Input Text", ) multi_btn = gr.Button("Classify", variant="primary") with gr.Column(): multi_output = gr.HTML(label="Results") with gr.Accordion("Batch Processing (CSV)", open=False): gr.Markdown("Upload a CSV file with a `text` column.") with gr.Row(): multi_csv_in = gr.File(label="Upload CSV", file_types=[".csv"]) multi_csv_out = gr.File(label="Download Results") multi_csv_btn = gr.Button("Process CSV", variant="secondary") with gr.Accordion("Load Custom Model", open=False): gr.Markdown( "Load a finetuned multiclass model from a local directory " "to use instead of the default pretrained classifier." ) multi_model_path = gr.Textbox( label="Model directory path", placeholder="e.g., ./finetuned_model", ) with gr.Row(): multi_load_btn = gr.Button("Load Model", variant="secondary") multi_reset_btn = gr.Button( "Reset to Pretrained", variant="secondary", ) multi_status = gr.Markdown("") # ================================================================ # QUESTION ANSWERING TAB # ================================================================ with gr.Tab("Question Answering"): gr.Markdown(info_callout( "Extract answers from a context passage. Provide a paragraph of " "text and ask a question about it. The model will highlight the " "most relevant span." 
)) with gr.Row(equal_height=True): with gr.Column(): qa_context = gr.Textbox( lines=6, placeholder="Paste the context passage here...", label="Context", ) qa_question = gr.Textbox( lines=2, placeholder="What would you like to know?", label="Question", ) qa_btn = gr.Button("Get Answer", variant="primary") with gr.Column(): qa_output = gr.HTML(label="Answer") with gr.Accordion("Batch Processing (CSV)", open=False): gr.Markdown( "Upload a CSV file with `context` and `question` columns " "to process multiple questions at once." ) with gr.Row(): qa_csv_in = gr.File( label="Upload CSV", file_types=[".csv"], ) qa_csv_out = gr.File(label="Download Results") qa_csv_btn = gr.Button("Process CSV", variant="secondary") # ================================================================ # FINE-TUNE TAB # ================================================================ with gr.Tab("Fine-tune"): gr.Markdown(info_callout( "Fine-tune a binary or multiclass classifier on your own data. " "Upload labeled TSV files, pick a base model, and train. " "Or compare multiple models head-to-head on the same dataset." 
)) # -- Data -- gr.Markdown("### Data") gr.Markdown( "TSV files, no header, format: `text[TAB]label` " "(binary: 0/1, multiclass: 0, 1, 2, ...)" ) with gr.Row(): ft_ex_binary_btn = gr.Button( "Load Example: Binary", variant="secondary", size="sm", ) ft_ex_multi_btn = gr.Button( "Load Example: Multiclass (4 classes)", variant="secondary", size="sm", ) with gr.Row(): ft_train_file = gr.File( label="Train", file_types=[".tsv", ".csv", ".txt"], ) ft_dev_file = gr.File( label="Validation", file_types=[".tsv", ".csv", ".txt"], ) ft_test_file = gr.File( label="Test", file_types=[".tsv", ".csv", ".txt"], ) # -- Configuration -- gr.Markdown("### Configuration") with gr.Row(): ft_task = gr.Radio( ["Binary", "Multiclass"], label="Task Type", value="Binary", ) ft_model = gr.Dropdown( choices=list(FINETUNE_MODELS.keys()), label="Base Model", value=list(FINETUNE_MODELS.keys())[0], ) with gr.Row(): ft_epochs = gr.Number( label="Epochs", value=3, minimum=1, maximum=100, precision=0, ) ft_batch = gr.Number( label="Batch Size", value=8, minimum=1, maximum=128, precision=0, ) ft_lr = gr.Number( label="Learning Rate", value=2e-5, minimum=1e-7, maximum=1e-2, ) with gr.Accordion("Advanced Settings", open=False): with gr.Row(): ft_weight_decay = gr.Number( label="Weight Decay", value=0.01, minimum=0, maximum=1, ) ft_warmup = gr.Number( label="Warmup Ratio", value=0.1, minimum=0, maximum=0.5, ) ft_max_len = gr.Number( label="Max Sequence Length", value=512, minimum=32, maximum=8192, precision=0, ) with gr.Row(): ft_grad_accum = gr.Number( label="Gradient Accumulation", value=1, minimum=1, maximum=64, precision=0, ) ft_fp16 = gr.Checkbox( label="Mixed Precision (FP16)", value=False, ) ft_patience = gr.Number( label="Early Stopping Patience", value=3, minimum=0, maximum=20, precision=0, ) ft_scheduler = gr.Dropdown( ["linear", "cosine", "constant", "constant_with_warmup"], label="LR Scheduler", value="linear", ) gr.Markdown("**Parameter-Efficient Fine-Tuning (PEFT)**") with gr.Row(): 
ft_use_lora = gr.Checkbox( label="Use LoRA", value=False, ) ft_lora_rank = gr.Number( label="LoRA Rank (r)", value=8, minimum=1, maximum=256, precision=0, ) ft_lora_alpha = gr.Number( label="LoRA Alpha", value=16, minimum=1, maximum=512, precision=0, ) ft_use_qlora = gr.Checkbox( label="QLoRA (4-bit, CUDA only)", value=False, ) # -- Train -- ft_train_btn = gr.Button( "Start Training", variant="primary", size="lg", ) # State for the trained model ft_model_state = gr.State(None) ft_tokenizer_state = gr.State(None) ft_num_labels_state = gr.State(None) with gr.Accordion("Training Log", open=False) as ft_log_accordion: ft_log = gr.Textbox( lines=12, interactive=False, elem_classes="log-output", show_label=False, ) # -- Results + Try Model (hidden until training completes) -- with gr.Column(visible=False) as ft_results_col: gr.Markdown("### Results") with gr.Row(equal_height=True): with gr.Column(scale=2): ft_metrics = gr.Dataframe( label="Test Set Metrics", headers=["Metric", "Score"], interactive=False, ) with gr.Column(scale=3): gr.Markdown("**Try your model**") ft_try_input = gr.Textbox( lines=2, label="Input Text", placeholder="Type text to classify...", ) with gr.Row(): ft_try_btn = gr.Button("Predict", variant="primary") ft_try_output = gr.HTML(label="Prediction") # -- Save + Batch (hidden until training completes) -- with gr.Column(visible=False) as ft_actions_col: with gr.Row(equal_height=True): with gr.Column(): gr.Markdown("**Download model**") ft_save_btn = gr.Button("Prepare Download", variant="secondary") ft_save_file = gr.File(label="Download Model (.zip)") ft_save_status = gr.Markdown("") with gr.Column(): gr.Markdown("**Batch predictions**") ft_batch_in = gr.File( label="Upload CSV (needs 'text' column)", file_types=[".csv"], ) ft_batch_btn = gr.Button( "Run Predictions", variant="secondary", ) ft_batch_out = gr.File(label="Download Results") # -- Compare Models -- gr.Markdown("---") with gr.Accordion("Compare Multiple Models", open=False): gr.Markdown( 
"Train the same dataset on different base models and compare " "performance side by side. Uses the data and task type above." ) cmp_models = gr.CheckboxGroup( choices=list(FINETUNE_MODELS.keys()), label="Select models to compare (pick 2 or more)", ) with gr.Row(): cmp_epochs = gr.Number(label="Epochs", value=3, minimum=1, precision=0) cmp_batch = gr.Number(label="Batch Size", value=8, minimum=1, precision=0) cmp_lr = gr.Number(label="Learning Rate", value=2e-5, minimum=1e-7) with gr.Row(): cmp_use_lora = gr.Checkbox(label="Use LoRA", value=False) cmp_lora_rank = gr.Number(label="LoRA Rank", value=8, minimum=1, maximum=256, precision=0) cmp_lora_alpha = gr.Number(label="LoRA Alpha", value=16, minimum=1, maximum=512, precision=0) cmp_btn = gr.Button("Compare Models", variant="primary") cmp_log = gr.Textbox( label="Comparison Log", lines=8, interactive=False, elem_classes="log-output", ) with gr.Column(visible=False) as cmp_results_col: cmp_table = gr.Dataframe( label="Comparison Results", interactive=False, ) cmp_plot = gr.Plot(label="Metrics Comparison") cmp_roc = gr.Plot(label="ROC Curves") # ================================================================ # ACTIVE LEARNING TAB # ================================================================ with gr.Tab("Active Learning"): gr.Markdown(info_callout( "**Active learning** iteratively selects the most uncertain " "samples from an unlabeled pool for you to label, then retrains. " "This lets you build a strong classifier with far fewer labels." )) # -- Data -- gr.Markdown("### Data") gr.Markdown( "**Seed file** โ€” small labeled set (TSV, `text[TAB]label`). \n" "**Pool file** โ€” unlabeled texts (one per line, or CSV with `text` column). \n" "**Dev file** *(optional)* โ€” held-out labeled set to track metrics." 
) al_ex_btn = gr.Button( "Load Example: Binary Active Learning", variant="secondary", size="sm", ) with gr.Row(): al_seed_file = gr.File( label="Labeled Seed (TSV)", file_types=[".tsv", ".csv", ".txt"], ) al_pool_file = gr.File( label="Unlabeled Pool", file_types=[".tsv", ".csv", ".txt"], ) al_dev_file = gr.File( label="Dev / Validation (optional)", file_types=[".tsv", ".csv", ".txt"], ) # -- Configuration -- gr.Markdown("### Configuration") with gr.Row(): al_task = gr.Radio( ["Binary", "Multiclass"], label="Task Type", value="Binary", ) al_model_dd = gr.Dropdown( choices=list(FINETUNE_MODELS.keys()), label="Base Model", value=list(FINETUNE_MODELS.keys())[0], ) with gr.Row(): al_strategy = gr.Dropdown( ["entropy", "margin", "least_confidence"], label="Query Strategy", value="entropy", ) al_query_size = gr.Number( label="Samples per Round", value=20, minimum=1, maximum=500, precision=0, ) with gr.Row(): al_epochs = gr.Number( label="Epochs per Round", value=3, minimum=1, maximum=50, precision=0, ) al_batch_size = gr.Number( label="Batch Size", value=8, minimum=1, maximum=128, precision=0, ) al_lr = gr.Number( label="Learning Rate", value=2e-5, minimum=1e-7, maximum=1e-2, ) with gr.Accordion("Advanced", open=False): with gr.Row(): al_max_len = gr.Number( label="Max Sequence Length", value=512, minimum=32, maximum=8192, precision=0, ) al_use_lora = gr.Checkbox(label="Use LoRA", value=False) al_lora_rank = gr.Number( label="LoRA Rank", value=8, minimum=1, maximum=256, precision=0, ) al_lora_alpha = gr.Number( label="LoRA Alpha", value=16, minimum=1, maximum=512, precision=0, ) al_init_btn = gr.Button( "Initialize Active Learning", variant="primary", size="lg", ) # -- State -- al_state = gr.State({}) al_model_state = gr.State(None) al_tokenizer_state = gr.State(None) with gr.Accordion("Log", open=False): al_log = gr.Textbox( lines=12, interactive=False, elem_classes="log-output", show_label=False, ) # -- Annotation panel (hidden until init) -- with 
gr.Column(visible=False) as al_annotation_col: gr.Markdown("### Label These Samples") gr.Markdown( "Fill in the **Label** column with integer class labels " "(e.g. 0 or 1 for binary). Then click **Submit**." ) al_annotation_df = gr.Dataframe( headers=["Text", "Label"], interactive=True, ) with gr.Row(): al_submit_btn = gr.Button( "Submit Labels & Next Round", variant="primary", ) al_chart = gr.Plot(label="Metrics Across Rounds") gr.Markdown("### Download Model") with gr.Row(): al_save_btn = gr.Button("Prepare Download", variant="secondary") al_save_file = gr.File(label="Download Model (.zip)") al_save_status = gr.Markdown("") # ---- FOOTER ---- gr.Markdown( "
    ---\n\n"
    "Developed by "
    "**Shreyas Meher**"
    "\n\n"
    "If you use ConfliBERT in your research, please cite:"
    "\n\n"
    "Brandt, P.T., Alsarra, S., D'Orazio, V., Heintze, D., Khan, L., "
    "Meher, S., Osorio, J. and Sianan, M., 2025. Extractive versus Generative "
    "Language Models for Political Conflict Text Classification. "
    "Political Analysis, pp.1–29."
    "
    " ) # ==================================================================== # EVENT HANDLERS # ==================================================================== # NER ner_btn.click( fn=named_entity_recognition, inputs=[ner_input], outputs=[ner_output], ) ner_csv_btn.click( fn=process_csv_ner, inputs=[ner_csv_in], outputs=[ner_csv_out], ) # Binary Classification clf_btn.click( fn=text_classification, inputs=[clf_input, custom_clf_model, custom_clf_tokenizer], outputs=[clf_output], ) clf_csv_btn.click( fn=process_csv_binary, inputs=[clf_csv_in, custom_clf_model, custom_clf_tokenizer], outputs=[clf_csv_out], ) clf_load_btn.click( fn=load_custom_model, inputs=[clf_model_path], outputs=[custom_clf_model, custom_clf_tokenizer, clf_status], ) clf_reset_btn.click( fn=reset_custom_model, outputs=[custom_clf_model, custom_clf_tokenizer, clf_status], ) # Multilabel Classification multi_btn.click( fn=multilabel_classification, inputs=[multi_input, custom_multi_model, custom_multi_tokenizer], outputs=[multi_output], ) multi_csv_btn.click( fn=process_csv_multilabel, inputs=[multi_csv_in], outputs=[multi_csv_out], ) multi_load_btn.click( fn=load_custom_model, inputs=[multi_model_path], outputs=[custom_multi_model, custom_multi_tokenizer, multi_status], ) multi_reset_btn.click( fn=reset_custom_model, outputs=[custom_multi_model, custom_multi_tokenizer, multi_status], ) # Question Answering qa_btn.click( fn=question_answering, inputs=[qa_context, qa_question], outputs=[qa_output], ) qa_csv_btn.click( fn=process_csv_qa, inputs=[qa_csv_in], outputs=[qa_csv_out], ) # Fine-tuning: example dataset loaders ft_ex_binary_btn.click( fn=load_example_binary, outputs=[ft_train_file, ft_dev_file, ft_test_file, ft_task], ) ft_ex_multi_btn.click( fn=load_example_multiclass, outputs=[ft_train_file, ft_dev_file, ft_test_file, ft_task], ) # Fine-tuning: training ft_train_btn.click( fn=run_finetuning, inputs=[ ft_train_file, ft_dev_file, ft_test_file, ft_task, ft_model, ft_epochs, ft_batch, 
ft_lr, ft_weight_decay, ft_warmup, ft_max_len, ft_grad_accum, ft_fp16, ft_patience, ft_scheduler, ft_use_lora, ft_lora_rank, ft_lora_alpha, ft_use_qlora, ], outputs=[ ft_log, ft_metrics, ft_model_state, ft_tokenizer_state, ft_num_labels_state, ft_results_col, ft_actions_col, ], concurrency_limit=1, ) # Try finetuned model ft_try_btn.click( fn=predict_finetuned, inputs=[ft_try_input, ft_model_state, ft_tokenizer_state, ft_num_labels_state], outputs=[ft_try_output], ) # Save finetuned model ft_save_btn.click( fn=save_finetuned_model, inputs=[ft_model_state, ft_tokenizer_state], outputs=[ft_save_file, ft_save_status], ) # Batch predictions with finetuned model ft_batch_btn.click( fn=batch_predict_finetuned, inputs=[ft_batch_in, ft_model_state, ft_tokenizer_state, ft_num_labels_state], outputs=[ft_batch_out], ) # Active Learning: example loader al_ex_btn.click( fn=load_example_active_learning, outputs=[al_seed_file, al_pool_file, al_dev_file, al_task], ) # Active Learning al_init_btn.click( fn=al_initialize, inputs=[ al_seed_file, al_pool_file, al_dev_file, al_task, al_model_dd, al_strategy, al_query_size, al_epochs, al_batch_size, al_lr, al_max_len, al_use_lora, al_lora_rank, al_lora_alpha, ], outputs=[ al_state, al_model_state, al_tokenizer_state, al_annotation_df, al_log, al_chart, al_annotation_col, ], concurrency_limit=1, ) al_submit_btn.click( fn=al_submit_and_continue, inputs=[ al_annotation_df, al_state, al_model_state, al_tokenizer_state, al_log, ], outputs=[ al_state, al_model_state, al_tokenizer_state, al_annotation_df, al_log, al_chart, ], concurrency_limit=1, ) al_save_btn.click( fn=al_save_model, inputs=[al_model_state, al_tokenizer_state], outputs=[al_save_file, al_save_status], ) # Model comparison cmp_btn.click( fn=run_comparison, inputs=[ ft_train_file, ft_dev_file, ft_test_file, ft_task, cmp_models, cmp_epochs, cmp_batch, cmp_lr, cmp_use_lora, cmp_lora_rank, cmp_lora_alpha, ], outputs=[cmp_log, cmp_table, cmp_plot, cmp_roc, cmp_results_col], 
concurrency_limit=1, ) # ============================================================================ # LAUNCH # ============================================================================ demo.launch(share=True, ssr_mode=False)