Shreyas Meher committed on
Commit
522f5f8
·
unverified ·
1 Parent(s): 0adf14d

New version with FT

README.md CHANGED
@@ -1,216 +1,233 @@
1
- ![ConfliBERT GUI](./gui.png)
2
-
3
- # ConfliBERT GUI Application
4
-
5
- A web-based interface for [ConfliBERT](https://github.com/eventdata/ConfliBERT), a BERT-based model specialized in conflict and political event analysis. This application provides multiple Natural Language Processing capabilities including Named Entity Recognition (NER), Text Classification, Multi-label Classification, and Question Answering.
6
-
7
- ## Features
8
-
9
- - **Named Entity Recognition (NER)**
10
- - Identifies and classifies named entities in text
11
- - Entities include: Organizations, Persons, Locations, Quantities, Weapons, Nationalities, Temporal references, and more
12
- - Color-coded visualization of entities in the web interface
13
-
14
- - **Text Classification**
15
- - Binary classification for conflict-related content
16
- - Determines if text is related to conflict, violence, or politics
17
- - Provides confidence scores for classifications
18
-
19
- - **Multi-label Classification**
20
- - Categorizes text into multiple event types
21
- - Categories include: Armed Assault, Bombing or Explosion, Kidnapping, and Other
22
- - Provides confidence scores for each category
23
-
24
- - **Question Answering**
25
- - Extracts answers from provided context based on questions
26
- - Specialized for conflict-related queries
27
-
28
- ## Installation
29
-
30
- ### Requirements
31
-
32
- **Required:**
33
- - Python 3.8+
34
- - Git
35
- - Code editor (VS Code recommended)
36
-
37
- **Optional but recommended:**
38
- - PowerShell 5.0+ (Windows)
39
- - Terminal (Mac)
40
-
41
- ### Installation Steps
42
-
43
- 1. Install Python:
44
- - Download from [python.org](https://www.python.org/downloads/)
45
- - Check installation: `python --version`
46
-
47
- 2. Install Git:
48
- - Windows: Download from [git-scm.com](https://git-scm.com/downloads)
49
- - Mac: `brew install git` or download from [git-scm.com](https://git-scm.com/downloads)
50
- - Check installation: `git --version`
51
-
52
- 3. Clone repository:
53
- ```bash
54
- git clone https://github.com/shreyasmeher/conflibert-gui.git
55
- cd conflibert-gui
56
- ```
57
-
58
- 4. Create and activate virtual environment:
59
- ```bash
60
- # Create environment
61
- python -m venv env
62
-
63
- # Activate environment
64
- # On Windows:
65
- env\Scripts\activate
66
- # On Mac/Linux:
67
- source env/bin/activate
68
- ```
69
-
70
- 5. For Windows users with permission errors, [run PowerShell as Administrator](https://www.javatpoint.com/powershell-run-as-administrator):
71
- ```powershell
72
- Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope LocalMachine
73
- ```
74
-
75
- 6. Install requirements:
76
- ```bash
77
- pip install -r requirements.txt
78
- ```
79
-
80
- ### Package Requirements
81
-
82
- - Python 3.8+
83
- - PyTorch
84
- - TensorFlow
85
- - Transformers
86
- - Gradio
87
- - Pandas
88
-
89
- ## Usage
90
-
91
- ### Running the Application
92
-
93
- 1. Start the application:
94
- ```bash
95
- python app.py
96
- ```
97
-
98
- 2. Open your web browser and navigate to:
99
- ```
100
- http://localhost:7860
101
- ```
102
-
103
- ### Using Different Features
104
-
105
- #### Individual Text Analysis
106
-
107
- 1. Select the desired task from the dropdown menu:
108
- - Named Entity Recognition
109
- - Text Classification
110
- - Multilabel Classification
111
- - Question Answering
112
-
113
- 2. For standard tasks:
114
- - Enter your text in the input box
115
- - Click Submit
116
-
117
- 3. For Question Answering:
118
- - Enter the context in the context box
119
- - Enter your question in the question box
120
- - Click Submit
121
-
122
- #### Batch Processing with CSV
123
-
124
- 1. Prepare a CSV file with a 'text' column containing your texts
125
-
126
- 2. Select the desired task:
127
- - NER
128
- - Text Classification
129
- - Multilabel Classification
130
-
131
- 3. Upload your CSV file using the file upload component
132
-
133
- 4. Click Submit to process the entire file
134
-
135
- 5. Download the results CSV containing the original text and analysis results
136
-
137
- ## Model Information
138
-
139
- ConfliBERT uses several specialized models:
140
-
141
- - **NER Model**: `eventdata-utd/conflibert-named-entity-recognition`
142
- - **Binary Classification**: `eventdata-utd/conflibert-binary-classification`
143
- - **Multi-label Classification**: `eventdata-utd/conflibert-satp-relevant-multilabel`
144
- - **Question Answering**: `salsarra/ConfliBERT-QA`
145
-
146
- ## Output Formats
147
-
148
- ### NER Output
149
- ```
150
- EntityType: Entity1, Entity2 || EntityType2: Entity3 | Entity4
151
- ```
152
-
153
- ### Binary Classification Output
154
- ```
155
- Class (Confidence%)
156
- ```
157
-
158
- ### Multi-label Classification Output
159
- ```
160
- Class1 (Confidence%) | Class2 (Confidence%)
161
- ```
162
-
163
- ## Technical Details
164
-
165
- ### File Structure
166
- ```
167
- conflibert-gui/
168
- ├── app.py # Main application file
169
- ├── requirements.txt # Package dependencies
170
- └── README.md # Documentation
171
- ```
172
-
173
- ### Key Components
174
-
175
- - **UI Components**: Built using Gradio
176
- - **Backend Processing**: PyTorch and TensorFlow
177
- - **Data Processing**: Pandas for CSV handling
178
- - **Model Integration**: Hugging Face Transformers
179
-
180
- ## Contributing
181
-
182
- 1. Fork the repository
183
- 2. Create your feature branch (`git checkout -b feature/AmazingFeature`)
184
- 3. Commit your changes (`git commit -m 'Add some AmazingFeature'`)
185
- 4. Push to the branch (`git push origin feature/AmazingFeature`)
186
- 5. Open a Pull Request
187
-
188
- ## Credits
189
-
190
- Developed by:
191
- - [Sultan Alsarra](https://www.linkedin.com/in/sultan-alsarra-phd-56977a63/)
192
- - [Shreyas Meher](http://shreyasmeher.com)
193
-
194
- ## License
195
-
196
- This project is licensed under the MIT License - see the LICENSE file for details.
197
-
198
- ## Institutional Support
199
-
200
- - [UTD Event Data](https://eventdata.utdallas.edu/)
201
- - [University of Texas at Dallas](https://www.utdallas.edu/)
202
-
203
- ## Citation
204
-
205
- If you use this tool in your research, please cite:
206
-
207
- ```bibtex
208
- @inproceedings{hu2022conflibert,
209
- title={ConfliBERT: A Pre-trained Language Model for Political Conflict and Violence},
210
- author={Hu, Yibo and Hosseini, MohammadSaleh and Parolin, Erick Skorupa and Osorio, Javier and Khan, Latifur and Brandt, Patrick and D’Orazio, Vito},
211
- booktitle={Proceedings of the 2022 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies},
212
- pages={5469--5482},
213
- year={2022}
214
- }
215
- ```
216
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ConfliBERT
2
+
3
+ [ConfliBERT](https://github.com/eventdata/ConfliBERT) is a pretrained language model built specifically for analyzing conflict and political violence text. This application provides a browser-based interface for running inference with ConfliBERT's pretrained models, fine-tuning custom classifiers on your own data, and comparing model performance across architectures.
4
+
5
+ Developed by [Shreyas Meher](http://shreyasmeher.com).
6
+
7
+ ## Screenshots
8
+
9
+ ### Home
10
+
11
+ The landing page shows your system configuration (GPU/CPU, RAM, platform) and an overview of everything the app can do.
12
+
13
+ <!-- Take a screenshot of the Home tab and save as screenshots/home.png -->
14
+ ![Home](./screenshots/home.png)
15
+
16
+ ### Named Entity Recognition
17
+
18
+ Identifies persons, organizations, locations, weapons, and other entity types. Results are color-coded. Supports single text and CSV batch processing.
19
+
20
+ <!-- Take a screenshot of the NER tab with sample output and save as screenshots/ner.png -->
21
+ ![NER](./screenshots/ner.png)
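The color-coded output comes from token-level tags grouped into entity spans. As an illustrative sketch (the BIO tags below are hand-written, not real model output), the grouping works roughly like this:

```python
# Group BIO-style NER tags into (entity, type) spans.
# Tokens and tags here are invented for illustration.
tokens = ["John", "Smith", "visited", "Kabul"]
tags = ["B-Person", "I-Person", "O", "B-Location"]

entities, current, cur_type = [], [], None
for tok, tag in zip(tokens, tags):
    if tag == "O":
        if current:
            entities.append((" ".join(current), cur_type))
        current, cur_type = [], None
    else:
        prefix, etype = tag.split("-")
        # A B- tag or a change of entity type starts a new entity
        if prefix == "B" or etype != cur_type:
            if current:
                entities.append((" ".join(current), cur_type))
            current, cur_type = [tok], etype
        else:
            current.append(tok)
if current:
    entities.append((" ".join(current), cur_type))

print(entities)  # [('John Smith', 'Person'), ('Kabul', 'Location')]
```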
22
+
23
+ ### Binary Classification
24
+
25
+ Classifies text as conflict-related or not. Uses the pretrained ConfliBERT classifier by default, or load your own fine-tuned model.
26
+
27
+ <!-- Take a screenshot of the Classification tab and save as screenshots/classification.png -->
28
+ ![Classification](./screenshots/classification.png)
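For a two-class model like this, the confidence score is typically a softmax over the output logits. A minimal sketch with invented logits (not real model outputs):

```python
import math

# Invented logits for the two classes: [conflict-related, not conflict-related]
logits = [1.2, -0.8]

# Softmax turns logits into probabilities that sum to 1
exps = [math.exp(z) for z in logits]
probs = [e / sum(exps) for e in exps]

label = "Conflict-related" if probs[0] >= probs[1] else "Not conflict-related"
confidence = max(probs) * 100
print(f"{label} ({confidence:.1f}%)")  # Conflict-related (88.1%)
```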
29
+
30
+ ### Multilabel Classification
31
+
32
+ Scores text against four event categories (Armed Assault, Bombing/Explosion, Kidnapping, Other). Each category is scored independently.
33
+
34
+ <!-- Take a screenshot of the Multilabel tab and save as screenshots/multilabel.png -->
35
+ ![Multilabel](./screenshots/multilabel.png)
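"Scored independently" means each category gets its own sigmoid probability rather than competing in a softmax, so any number of categories can cross the 0.5 threshold at once. A sketch with illustrative logits (not real model outputs):

```python
import math

labels = ["Armed Assault", "Bombing or Explosion", "Kidnapping", "Other"]
logits = [2.1, -0.4, -3.0, 0.7]  # illustrative values

# Multilabel: an independent sigmoid per label, no softmax across labels
scores = [1 / (1 + math.exp(-z)) for z in logits]
flagged = [lab for lab, s in zip(labels, scores) if s >= 0.5]
print(flagged)  # ['Armed Assault', 'Other']
```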
36
+
37
+ ### Question Answering
38
+
39
+ Provide a context passage and a question. The model extracts the most relevant answer span.
40
+
41
+ <!-- Take a screenshot of the QA tab and save as screenshots/qa.png -->
42
+ ![QA](./screenshots/qa.png)
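Extractive QA models score every context token as a possible answer start and end; the reported answer is the span between the top-scoring positions. A toy sketch with invented logits:

```python
# Pick the answer span from start/end logits (all values invented)
tokens = ["The", "attack", "occurred", "in", "Mosul", "."]
start_logits = [0.1, 0.3, 0.2, 0.5, 2.4, 0.0]
end_logits   = [0.0, 0.1, 0.2, 0.3, 2.6, 0.4]

start = max(range(len(tokens)), key=lambda i: start_logits[i])
end = max(range(len(tokens)), key=lambda i: end_logits[i])
answer = " ".join(tokens[start:end + 1])
print(answer)  # Mosul
```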
43
+
44
+ ### Fine-tuning
45
+
46
+ Train your own binary or multiclass classifier directly in the browser. Upload data (or load a built-in example), pick a base model, configure training, and go. After training, results and a "Try Your Model" panel appear side by side. You can also save the model and run batch predictions.
47
+
48
+ <!-- Take a screenshot of the Fine-tune tab and save as screenshots/finetune.png -->
49
+ ![Fine-tune](./screenshots/finetune.png)
50
+
51
+ ### Model Comparison
52
+
53
+ Compare multiple base model architectures on the same dataset. The comparison produces a metrics table, a grouped bar chart, and ROC-AUC curves.
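ROC-AUC can be read as the probability that a randomly chosen positive example is scored above a randomly chosen negative one (ties count half). A self-contained sketch with toy labels and scores:

```python
# ROC-AUC as the rank-win rate of positives over negatives
def roc_auc(labels, scores):
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

auc = roc_auc([1, 0, 1, 0], [0.9, 0.2, 0.6, 0.4])
print(auc)  # 1.0 — every positive outscores every negative
```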
54
+
55
+ ## Supported Models
56
+
57
+ ### Pretrained (Inference)
58
+
59
+ | Task | HuggingFace Model |
60
+ |------|-------------------|
61
+ | NER | `eventdata-utd/conflibert-named-entity-recognition` |
62
+ | Binary Classification | `eventdata-utd/conflibert-binary-classification` |
63
+ | Multilabel Classification | `eventdata-utd/conflibert-satp-relevant-multilabel` |
64
+ | Question Answering | `salsarra/ConfliBERT-QA` |
65
+
66
+ ### Fine-tuning (Base Models)
67
+
68
+ | Model | HuggingFace ID | Notes |
69
+ |-------|----------------|-------|
70
+ | ConfliBERT | `snowood1/ConfliBERT-scr-uncased` | Best for conflict/political text |
71
+ | BERT Base Uncased | `bert-base-uncased` | General-purpose baseline |
72
+ | BERT Base Cased | `bert-base-cased` | Case-sensitive variant |
73
+ | RoBERTa Base | `roberta-base` | Improved BERT training |
74
+ | ModernBERT Base | `answerdotai/ModernBERT-base` | Up to 8K token context |
75
+ | DeBERTa v3 Base | `microsoft/deberta-v3-base` | Strong on benchmarks |
76
+ | DistilBERT Base | `distilbert-base-uncased` | Faster, smaller |
77
+
78
+ ## Installation
79
+
80
+ ### Requirements
81
+
82
+ - Python 3.8+
83
+ - Git
84
+
85
+ ### Steps
86
+
87
+ 1. Clone the repository:
88
+
89
+ ```bash
90
+ git clone https://github.com/shreyasmeher/conflibert-gui.git
91
+ cd conflibert-gui
92
+ ```
93
+
94
+ 2. Create and activate a virtual environment:
95
+
96
+ ```bash
97
+ python -m venv env
98
+
99
+ # Mac/Linux:
100
+ source env/bin/activate
101
+
102
+ # Windows:
103
+ env\Scripts\activate
104
+ ```
105
+
106
+ On Windows, if you get a permission error, run PowerShell as Administrator and execute:
107
+
108
+ ```powershell
109
+ Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope LocalMachine
110
+ ```
111
+
112
+ 3. Install dependencies:
113
+
114
+ ```bash
115
+ pip install -r requirements.txt
116
+ ```
117
+
118
+ ## Usage
119
+
120
+ Start the application:
121
+
122
+ ```bash
123
+ python app.py
124
+ ```
125
+
126
+ The app opens at `http://localhost:7860` and generates a public shareable link. The first launch takes a minute or two while the pretrained models download.
127
+
128
+ ### Tabs
129
+
130
+ | Tab | What it does |
131
+ |-----|-------------|
132
+ | Home | System info, feature overview, citation |
133
+ | Named Entity Recognition | Identify entities in text or CSV |
134
+ | Binary Classification | Conflict vs. non-conflict, supports custom models |
135
+ | Multilabel Classification | Multi-event-type scoring |
136
+ | Question Answering | Extract answers from a context passage |
137
+ | Fine-tune | Train classifiers, compare models, ROC curves |
138
+
139
+ ### Fine-tuning Quick Start
140
+
141
+ 1. Go to the **Fine-tune** tab
142
+ 2. Click **"Load Example: Binary"** to load sample data
143
+ 3. Leave defaults and click **"Start Training"**
144
+ 4. Review metrics and try your model on new text
145
+ 5. Save the model and load it in the **Binary Classification** tab
146
+
147
+ ### Model Comparison Quick Start
148
+
149
+ 1. Upload data (or load an example) in the **Fine-tune** tab
150
+ 2. Scroll down and open **"Compare Multiple Models"**
151
+ 3. Check 2 or more models to compare
152
+ 4. Click **"Compare Models"**
153
+ 5. View the metrics table, bar chart, and ROC-AUC curves
154
+
155
+ ### Data Format
156
+
157
+ Tab-separated values (TSV), no header row. Each line: `text<TAB>label`
158
+
159
+ Binary example:
160
+ ```
161
+ The bomb exploded near the market 1
162
+ It was a sunny day at the park 0
163
+ ```
164
+
165
+ Multiclass example (integer labels starting from 0):
166
+ ```
167
+ The president signed the peace treaty 0
168
+ Militants attacked the military base 1
169
+ Thousands marched in the capital 2
170
+ Aid workers delivered food supplies 3
171
+ ```
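As a sketch, a file in this format can be parsed with the standard library (the rows mirror the binary example above); pandas users could equivalently use `read_csv(..., sep="\t", header=None)`:

```python
import csv
import io

# In-memory stand-in for train.tsv: text<TAB>label, no header row
sample = "The bomb exploded near the market\t1\nIt was a sunny day at the park\t0\n"

rows = list(csv.reader(io.StringIO(sample), delimiter="\t"))
texts = [row[0] for row in rows]
labels = [int(row[1]) for row in rows]
print(labels)  # [1, 0]
```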
172
+
173
+ ### CSV Batch Processing
174
+
175
+ Prepare a CSV with a `text` column:
176
+
177
+ ```csv
178
+ text
179
+ "The soldiers advanced toward the border."
180
+ "The festival attracted thousands of visitors."
181
+ ```
182
+
183
+ Upload it in the Batch Processing section of any inference tab.
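A minimal sketch of writing and re-reading such a file with Python's standard `csv` module (using the example sentences above, built in memory for illustration):

```python
import csv
import io

# Build a CSV with the required 'text' column
buf = io.StringIO()
writer = csv.writer(buf)
writer.writerow(["text"])
writer.writerow(["The soldiers advanced toward the border."])
writer.writerow(["The festival attracted thousands of visitors."])

# Read it back the way a batch processor would: one text per row
reader = csv.DictReader(io.StringIO(buf.getvalue()))
texts = [row["text"] for row in reader]
print(len(texts))  # 2
```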
184
+
185
+ ## Project Structure
186
+
187
+ ```
188
+ conflibert-gui/
189
+ ├── app.py               # Main application
190
+ ├── requirements.txt     # Dependencies
191
+ ├── README.md
192
+ ├── screenshots/         # UI screenshots for documentation
193
+ └── examples/
194
+     ├── binary/          # Example binary dataset (conflict vs non-conflict)
195
+     │   ├── train.tsv
196
+     │   ├── dev.tsv
197
+     │   └── test.tsv
198
+     └── multiclass/      # Example multiclass dataset (4 event types)
199
+         ├── train.tsv    # 0=Diplomacy, 1=Armed Conflict,
200
+         ├── dev.tsv      # 2=Protest, 3=Humanitarian
201
+         └── test.tsv
202
+ ```
203
+
204
+ ## Training Features
205
+
206
+ - Early stopping with configurable patience
207
+ - Learning rate schedulers: linear, cosine, constant, constant with warmup
208
+ - Mixed precision training (FP16) on CUDA GPUs
209
+ - Gradient accumulation for larger effective batch sizes
210
+ - Weight decay regularization
211
+ - Automatic system detection (NVIDIA GPU, Apple Silicon MPS, CPU)
212
+ - Model comparison with grouped bar charts and ROC-AUC curves
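The early-stopping rule can be sketched as: stop once the best validation loss is at least `patience` evaluations old. This is an illustrative reimplementation, not the app's exact code:

```python
# Stop when the best validation loss hasn't improved for `patience` evals
def should_stop(val_losses, patience=2):
    best_idx = val_losses.index(min(val_losses))
    return len(val_losses) - 1 - best_idx >= patience

history = [0.9, 0.7, 0.65, 0.66, 0.67]  # illustrative per-epoch losses
print(should_stop(history))  # True: the best loss is 2 evaluations old
```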
213
+
214
+ ## Citation
215
+
216
+ If you use ConfliBERT in your research, please cite:
217
+
218
+ Brandt, P.T., Alsarra, S., D'Orazio, V., Heintze, D., Khan, L., Meher, S., Osorio, J. and Sianan, M., 2025. Extractive versus Generative Language Models for Political Conflict Text Classification. *Political Analysis*, pp.1-29.
219
+
220
+ ```bibtex
221
+ @article{brandt2025extractive,
222
+ title={Extractive versus Generative Language Models for Political Conflict Text Classification},
223
+ author={Brandt, Patrick T and Alsarra, Sultan and D'Orazio, Vito and Heintze, Dagmar and Khan, Latifur and Meher, Shreyas and Osorio, Javier and Sianan, Marcus},
224
+ journal={Political Analysis},
225
+ pages={1--29},
226
+ year={2025},
227
+ publisher={Cambridge University Press}
228
+ }
229
+ ```
230
+
231
+ ## License
232
+
233
+ MIT License. See LICENSE for details.
app.py CHANGED
@@ -1,878 +1,1629 @@
1
- import torch
2
- import tensorflow as tf
3
- from tf_keras import models, layers
4
- from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification, TFAutoModelForQuestionAnswering
5
- import gradio as gr
6
- import re
7
- import pandas as pd
8
- import io
9
- import os
10
- os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
11
- import keras
12
-
13
- # Check if GPU is available and use it if possible
14
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
-
16
- MAX_TOKEN_LENGTH = 512 # Adjust based on your model's limits
17
-
18
- def truncate_text(text, tokenizer, max_length=MAX_TOKEN_LENGTH):
19
- """Truncate text to max token length"""
20
- tokens = tokenizer.encode(text, truncation=False)
21
- if len(tokens) > max_length:
22
- tokens = tokens[:max_length-1] + [tokenizer.sep_token_id]
23
- return tokenizer.decode(tokens, skip_special_tokens=True)
24
- return text
25
-
26
- def safe_process(func, text, tokenizer):
27
- """Safely process text with proper error handling"""
28
- try:
29
- truncated_text = truncate_text(text, tokenizer)
30
- return func(truncated_text)
31
- except Exception as e:
32
- error_msg = str(e)
33
- if 'out of memory' in error_msg.lower():
34
- return "Error: Text too long for processing"
35
- elif 'cuda' in error_msg.lower():
36
- return "Error: GPU processing error"
37
- else:
38
- return f"Error: {error_msg}"
39
-
40
- # Load the models and tokenizers
41
- qa_model_name = 'salsarra/ConfliBERT-QA'
42
- qa_model = TFAutoModelForQuestionAnswering.from_pretrained(qa_model_name)
43
- qa_tokenizer = AutoTokenizer.from_pretrained(qa_model_name)
44
-
45
- ner_model_name = 'eventdata-utd/conflibert-named-entity-recognition'
46
- ner_model = AutoModelForTokenClassification.from_pretrained(ner_model_name).to(device)
47
- ner_tokenizer = AutoTokenizer.from_pretrained(ner_model_name)
48
-
49
- clf_model_name = 'eventdata-utd/conflibert-binary-classification'
50
- clf_model = AutoModelForSequenceClassification.from_pretrained(clf_model_name).to(device)
51
- clf_tokenizer = AutoTokenizer.from_pretrained(clf_model_name)
52
-
53
- multi_clf_model_name = 'eventdata-utd/conflibert-satp-relevant-multilabel'
54
- multi_clf_model = AutoModelForSequenceClassification.from_pretrained(multi_clf_model_name).to(device)
55
- multi_clf_tokenizer = AutoTokenizer.from_pretrained(multi_clf_model_name)
56
-
57
- # Define the class names for text classification
58
- class_names = ['Negative', 'Positive']
59
- multi_class_names = ["Armed Assault", "Bombing or Explosion", "Kidnapping", "Other"] # Updated labels
60
-
61
- # Define the NER labels and colors
62
- ner_labels = {
63
- 'Organisation': 'blue',
64
- 'Person': 'red',
65
- 'Location': 'green',
66
- 'Quantity': 'orange',
67
- 'Weapon': 'purple',
68
- 'Nationality': 'cyan',
69
- 'Temporal': 'magenta',
70
- 'DocumentReference': 'brown',
71
- 'MilitaryPlatform': 'yellow',
72
- 'Money': 'pink'
73
- }
74
-
75
- def handle_error_message(e, default_limit=512):
76
- error_message = str(e)
77
- pattern = re.compile(r"The size of tensor a \((\d+)\) must match the size of tensor b \((\d+)\)")
78
- match = pattern.search(error_message)
79
- if match:
80
- number_1, number_2 = match.groups()
81
- return f"<span style='color: red; font-weight: bold;'>Error: Text Input is over limit where inserted text size {number_1} is larger than model limits of {number_2}</span>"
82
- pattern_qa = re.compile(r"indices\[0,(\d+)\] = \d+ is not in \[0, (\d+)\)")
83
- match_qa = pattern_qa.search(error_message)
84
- if match_qa:
85
- number_1, number_2 = match_qa.groups()
86
- return f"<span style='color: red; font-weight: bold;'>Error: Text Input is over limit where inserted text size {number_1} is larger than model limits of {number_2}</span>"
87
- return f"<span style='color: red; font-weight: bold;'>Error: Text Input is over limit where inserted text size is larger than model limits of {default_limit}</span>"
88
-
89
- # Define the functions for each task
90
- def question_answering(context, question):
91
- try:
92
- inputs = qa_tokenizer(question, context, return_tensors='tf', truncation=True)
93
- outputs = qa_model(inputs)
94
- answer_start = tf.argmax(outputs.start_logits, axis=1).numpy()[0]
95
- answer_end = tf.argmax(outputs.end_logits, axis=1).numpy()[0] + 1
96
- answer = qa_tokenizer.convert_tokens_to_string(qa_tokenizer.convert_ids_to_tokens(inputs['input_ids'].numpy()[0][answer_start:answer_end]))
97
- return f"<span style='color: green; font-weight: bold;'>{answer}</span>"
98
- except Exception as e:
99
- return handle_error_message(e)
100
-
101
- def replace_unk(tokens):
102
- return [token.replace('[UNK]', "'") for token in tokens]
103
-
104
- def named_entity_recognition(text, output_format='html'):
105
- """
106
- Process text for named entity recognition.
107
- output_format: 'html' for GUI display, 'csv' for CSV processing
108
- """
109
- try:
110
- inputs = ner_tokenizer(text, return_tensors='pt', truncation=True)
111
- with torch.no_grad():
112
- outputs = ner_model(**inputs)
113
- ner_results = outputs.logits.argmax(dim=2).squeeze().tolist()
114
- tokens = ner_tokenizer.convert_ids_to_tokens(inputs['input_ids'].squeeze().tolist())
115
- tokens = replace_unk(tokens)
116
-
117
- entities = []
118
- seen_labels = set()
119
- current_entity = []
120
- current_label = None
121
-
122
- # Process tokens and group consecutive entities
123
- for i in range(len(tokens)):
124
- token = tokens[i]
125
- label = ner_model.config.id2label[ner_results[i]].split('-')[-1]
126
-
127
- # Handle subwords
128
- if token.startswith('##'):
129
- if entities:
130
- if output_format == 'html':
131
- entities[-1][0] += token[2:]
132
- elif current_entity:
133
- current_entity[-1] = current_entity[-1] + token[2:]
134
- else:
135
- # For CSV format, group consecutive tokens of same entity type
136
- if output_format == 'csv':
137
- if label != 'O':
138
- if label == current_label:
139
- current_entity.append(token)
140
- else:
141
- if current_entity:
142
- entities.append([' '.join(current_entity), current_label])
143
- current_entity = [token]
144
- current_label = label
145
- else:
146
- if current_entity:
147
- entities.append([' '.join(current_entity), current_label])
148
- current_entity = []
149
- current_label = None
150
- else:
151
- entities.append([token, label])
152
-
153
- if label != 'O':
154
- seen_labels.add(label)
155
-
156
- # Don't forget the last entity for CSV format
157
- if output_format == 'csv' and current_entity:
158
- entities.append([' '.join(current_entity), current_label])
159
-
160
- if output_format == 'csv':
161
- # Group by entity type
162
- grouped_entities = {}
163
- for token, label in entities:
164
- if label != 'O':
165
- if label not in grouped_entities:
166
- grouped_entities[label] = []
167
- grouped_entities[label].append(token)
168
-
169
- # Format the output
170
- result_parts = []
171
- for label, tokens in grouped_entities.items():
172
- unique_tokens = list(dict.fromkeys(tokens)) # Remove duplicates
173
- result_parts.append(f"{label}: {' | '.join(unique_tokens)}")
174
-
175
- return ' || '.join(result_parts)
176
- else:
177
- # Original HTML output
178
- highlighted_text = ""
179
- for token, label in entities:
180
- color = ner_labels.get(label, 'black')
181
- if label != 'O':
182
- highlighted_text += f"<span style='color: {color}; font-weight: bold;'>{token}</span> "
183
- else:
184
- highlighted_text += f"{token} "
185
-
186
- legend = "<div><strong>NER Tags Found:</strong><ul style='list-style-type: disc; padding-left: 20px;'>"
187
- for label in seen_labels:
188
- color = ner_labels.get(label, 'black')
189
- legend += f"<li style='color: {color}; font-weight: bold;'>{label}</li>"
190
- legend += "</ul></div>"
191
-
192
- return f"<div>{highlighted_text}</div>{legend}"
193
-
194
- except Exception as e:
195
- return handle_error_message(e)
196
-
197
- def text_classification(text):
198
- try:
199
- inputs = clf_tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(device)
200
- with torch.no_grad():
201
- outputs = clf_model(**inputs)
202
- logits = outputs.logits.squeeze().tolist()
203
- predicted_class = torch.argmax(outputs.logits, dim=1).item()
204
- confidence = torch.softmax(outputs.logits, dim=1).max().item() * 100
205
-
206
- if predicted_class == 1: # Positive class
207
- result = f"<span style='color: green; font-weight: bold;'>Positive: The text is related to conflict, violence, or politics. (Confidence: {confidence:.2f}%)</span>"
208
- else: # Negative class
209
- result = f"<span style='color: red; font-weight: bold;'>Negative: The text is not related to conflict, violence, or politics. (Confidence: {confidence:.2f}%)</span>"
210
- return result
211
- except Exception as e:
212
- return handle_error_message(e)
213
-
214
- def multilabel_classification(text):
215
- try:
216
- inputs = multi_clf_tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(device)
217
- with torch.no_grad():
218
- outputs = multi_clf_model(**inputs)
219
- predicted_classes = torch.sigmoid(outputs.logits).squeeze().tolist()
220
- if len(predicted_classes) != len(multi_class_names):
221
- return f"Error: Number of predicted classes ({len(predicted_classes)}) does not match number of class names ({len(multi_class_names)})."
222
-
223
- results = []
224
- for i in range(len(predicted_classes)):
225
- confidence = predicted_classes[i] * 100
226
- if predicted_classes[i] >= 0.5:
227
- results.append(f"<span style='color: green; font-weight: bold;'>{multi_class_names[i]} (Confidence: {confidence:.2f}%)</span>")
228
- else:
229
- results.append(f"<span style='color: red; font-weight: bold;'>{multi_class_names[i]} (Confidence: {confidence:.2f}%)</span>")
230
-
231
- return " / ".join(results)
232
- except Exception as e:
233
- return handle_error_message(e)
234
-
235
- def clean_html_tags(text):
236
- """Remove HTML tags and formatting from the output."""
237
- # Remove HTML tags but keep the text content
238
- clean_text = re.sub(r'<[^>]+>', '', text)
239
- # Remove multiple spaces
240
- clean_text = re.sub(r'\s+', ' ', clean_text)
241
- # Remove [CLS] and [SEP] tokens
242
- clean_text = re.sub(r'\[CLS\]|\[SEP\]', '', clean_text)
243
- return clean_text.strip()
244
-
245
- def extract_ner_entities(html_output):
246
- """Extract entities and their types from NER output using a simpler approach."""
247
- # Map colors to entity types
248
- color_to_type = {
249
- 'blue': 'Organisation',
250
- 'red': 'Person',
251
- 'green': 'Location',
252
- 'orange': 'Quantity',
253
- 'purple': 'Weapon',
254
- 'cyan': 'Nationality',
255
- 'magenta': 'Temporal',
256
- 'brown': 'DocumentReference',
257
- 'yellow': 'MilitaryPlatform',
258
- 'pink': 'Money'
259
- }
260
-
261
- # Find all colored spans
262
- pattern = r"<span style='color: ([^']+)[^>]+>([^<]+)</span>"
263
- matches = re.findall(pattern, html_output)
264
-
265
- # Group by entity type
266
- entities = {}
267
-
268
- # Process each match
269
- for color, text in matches:
270
- if color in color_to_type:
271
- entity_type = color_to_type[color]
272
- if entity_type not in entities:
273
- entities[entity_type] = []
274
-
275
- # Clean and store the text
276
- text = text.strip()
277
- if text and not text.isspace():
278
- entities[entity_type].append(text)
279
-
280
- # Join consecutive words for each entity type
281
- result_parts = []
282
- for entity_type, words in entities.items():
283
- # Join consecutive words
284
- phrases = []
285
- current_phrase = []
286
-
287
- for word in words:
288
- if word in [',', '/', ':', '-']: # Skip punctuation
289
- continue
290
- if not current_phrase:
291
- current_phrase.append(word)
292
- else:
293
- # If it's a continuation (e.g., part of a date or name)
294
- if word.startswith(':') or word == 'of' or current_phrase[-1].endswith('/'):
295
- current_phrase.append(word)
296
- else:
297
- # If it's a new entity
298
- phrases.append(' '.join(current_phrase))
299
- current_phrase = [word]
300
-
301
- if current_phrase:
302
- phrases.append(' '.join(current_phrase))
303
-
304
- # Remove duplicates while preserving order
305
- unique_phrases = []
306
- seen = set()
307
- for phrase in phrases:
308
- clean_phrase = phrase.strip()
309
- if clean_phrase and clean_phrase not in seen:
310
- unique_phrases.append(clean_phrase)
311
- seen.add(clean_phrase)
312
-
313
- if unique_phrases:
314
- result_parts.append(f"{entity_type}: {' | '.join(unique_phrases)}")
315
-
316
- return ' || '.join(result_parts)
317
-
318
-
319
- def clean_classification_output(html_output):
320
- """Extract classification results without HTML formatting."""
321
- if "Positive" in html_output:
322
- # Binary classification
323
- match = re.search(r">(Positive|Negative).*?Confidence: ([\d.]+)%", html_output)
324
- if match:
325
- class_name, confidence = match.groups()
326
- return f"{class_name} ({confidence}%)"
327
- else:
328
- # Multilabel classification
329
- results = []
330
- matches = re.finditer(r">([^<]+)\s*\(Confidence:\s*([\d.]+)%\)", html_output)
331
- for match in matches:
332
- class_name, confidence = match.groups()
333
- if float(confidence) >= 50: # Only include classes with confidence >= 50%
334
- results.append(f"{class_name.strip()} ({confidence}%)")
335
- return " | ".join(results) if results else "No classes above 50% confidence"
336
-
337
- return "Unknown"
338
-
339
-
340
- def process_csv_ner(file):
341
- try:
342
- df = pd.read_csv(file.name)
343
-
344
- if 'text' not in df.columns:
345
- return "Error: CSV must contain a 'text' column"
346
-
347
- entities = []
348
- for text in df['text']:
349
- if pd.isna(text):
350
- entities.append("")
351
- continue
352
-
353
- # Use CSV output format
354
- result = named_entity_recognition(str(text), output_format='csv')
355
- entities.append(result)
356
-
357
- df['entities'] = entities
358
-
359
- output_path = "processed_results.csv"
360
- df.to_csv(output_path, index=False)
361
- return output_path
362
- except Exception as e:
363
- return f"Error processing CSV: {str(e)}"
364
-
365
- def process_csv_classification(file, is_multi=False):
366
- try:
367
- df = pd.read_csv(file.name)
368
-
369
- if 'text' not in df.columns:
370
- return "Error: CSV must contain a 'text' column"
371
-
372
- results = []
373
- for text in df['text']:
374
- if pd.isna(text):
375
- results.append("")
376
- continue
377
-
378
- if is_multi:
379
- html_result = multilabel_classification(str(text))
380
- else:
381
- html_result = text_classification(str(text))
382
- results.append(clean_classification_output(html_result))
383
-
384
- result_column = 'multilabel_results' if is_multi else 'classification_results'
385
- df[result_column] = results
386
-
387
- output_path = "processed_results.csv"
388
- df.to_csv(output_path, index=False)
389
- return output_path
390
- except Exception as e:
391
- return f"Error processing CSV: {str(e)}"
392
-
393
-
394
- # Define the Gradio interface
395
- def chatbot(task, text=None, context=None, question=None, file=None):
396
- if file is not None: # Handle CSV file input
397
- if task == "Named Entity Recognition":
398
- return process_csv_ner(file)
399
- elif task == "Text Classification":
400
- return process_csv_classification(file, is_multi=False)
401
- elif task == "Multilabel Classification":
402
- return process_csv_classification(file, is_multi=True)
403
- else:
404
- return "CSV processing is not supported for Question Answering task"
405
-
406
- # Handle regular text input (previous implementation)
407
- if task == "Question Answering":
408
- if context and question:
409
- return question_answering(context, question)
410
- else:
411
- return "Please provide both context and question for the Question Answering task."
412
- elif task == "Named Entity Recognition":
413
- if text:
414
- return named_entity_recognition(text)
415
- else:
416
- return "Please provide text for the Named Entity Recognition task."
417
- elif task == "Text Classification":
418
- if text:
419
- return text_classification(text)
420
- else:
421
- return "Please provide text for the Text Classification task."
422
- elif task == "Multilabel Classification":
423
- if text:
424
- return multilabel_classification(text)
425
- else:
426
- return "Please provide text for the Multilabel Classification task."
427
- else:
428
- return "Please select a valid task."
429
-
430
-
431
-# Custom CSS for modern orange theme
-custom_css = """
-/* CSS Variables for Light and Dark Theme */
-:root {
-    --primary-orange: #ff6b35;
-    --primary-orange-light: #ff8c5a;
-    --primary-orange-dark: #e55a2b;
-    --secondary-orange: #ffa366;
-    --accent-orange: #ff9f40;
-    --background-light: #fefefe;
-    --background-dark: #1a1a1a;
-    --surface-light: #ffffff;
-    --surface-dark: #2d2d2d;
-    --text-primary-light: #2c2c2c;
-    --text-primary-dark: #ffffff;
-    --text-secondary-light: #666666;
-    --text-secondary-dark: #cccccc;
-    --border-light: #e0e0e0;
-    --border-dark: #404040;
-    --shadow-light: rgba(0, 0, 0, 0.1);
-    --shadow-dark: rgba(0, 0, 0, 0.3);
-    --gradient-orange: linear-gradient(135deg, #ff6b35 0%, #ff9f40 100%);
-    --gradient-orange-subtle: linear-gradient(135deg, rgba(255, 107, 53, 0.1) 0%, rgba(255, 159, 64, 0.1) 100%);
-}
-
-/* Dark theme overrides */
-.dark {
-    --background: var(--background-dark);
-    --surface: var(--surface-dark);
-    --text-primary: var(--text-primary-dark);
-    --text-secondary: var(--text-secondary-dark);
-    --border: var(--border-dark);
-    --shadow: var(--shadow-dark);
-}
-
-/* Light theme (default) */
-.light, :root {
-    --background: var(--background-light);
-    --surface: var(--surface-light);
-    --text-primary: var(--text-primary-light);
-    --text-secondary: var(--text-secondary-light);
-    --border: var(--border-light);
-    --shadow: var(--shadow-light);
-}
-
-/* Global Styles */
-* {
-    font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
-}
-
-/* Main Container */
-.gradio-container {
-    background: var(--background) !important;
-    color: var(--text-primary) !important;
-    min-height: 100vh;
-}
-
-/* Header Styling */
-.header-container {
-    background: var(--gradient-orange) !important;
-    padding: 2rem 1rem !important;
-    margin: -1rem -1rem 2rem -1rem !important;
-    border-radius: 0 0 24px 24px !important;
-    box-shadow: 0 8px 32px var(--shadow) !important;
-    position: relative;
-    overflow: hidden;
-}
-
-.header-container::before {
-    content: '';
-    position: absolute;
-    top: 0;
-    left: 0;
-    right: 0;
-    bottom: 0;
-    background: url("data:image/svg+xml,%3Csvg width='60' height='60' viewBox='0 0 60 60' xmlns='http://www.w3.org/2000/svg'%3E%3Cg fill='none' fill-rule='evenodd'%3E%3Cg fill='%23ffffff' fill-opacity='0.05'%3E%3Ccircle cx='30' cy='30' r='2'/%3E%3C/g%3E%3C/g%3E%3C/svg%3E") !important;
-    pointer-events: none;
-}
-
-.header-title-center {
-    text-align: center !important;
-    position: relative;
-    z-index: 1;
-}
-
-.header-title-center a {
-    color: white !important;
-    text-decoration: none !important;
-    font-weight: 900 !important;
-    font-size: 4rem !important;
-    text-shadow: 0 4px 8px rgba(0, 0, 0, 0.2) !important;
-    letter-spacing: -0.02em !important;
-    transition: all 0.3s ease !important;
-}
-
-.header-title-center a:hover {
-    transform: translateY(-2px) !important;
-    text-shadow: 0 6px 16px rgba(0, 0, 0, 0.3) !important;
-}
-
-/* Task Container */
-.task-container {
-    background: var(--surface) !important;
-    border-radius: 16px !important;
-    padding: 2rem !important;
-    box-shadow: 0 4px 24px var(--shadow) !important;
-    border: 1px solid var(--border) !important;
-    margin-bottom: 2rem !important;
-}
-
-/* Input Components */
-.input-text textarea, .input-text input {
-    background: var(--surface) !important;
-    border: 2px solid var(--border) !important;
-    border-radius: 12px !important;
-    padding: 1rem !important;
-    color: var(--text-primary) !important;
-    font-size: 0.95rem !important;
-    line-height: 1.5 !important;
-    transition: all 0.3s ease !important;
-    box-shadow: 0 2px 8px rgba(0, 0, 0, 0.05) !important;
-}
-
-.input-text textarea:focus, .input-text input:focus {
-    border-color: var(--primary-orange) !important;
-    box-shadow: 0 0 0 3px rgba(255, 107, 53, 0.1) !important;
-    outline: none !important;
-    transform: translateY(-1px) !important;
-}
-
-/* Placeholder text styling */
-.input-text textarea::placeholder, .input-text input::placeholder {
-    color: var(--text-secondary) !important;
-    opacity: 0.7 !important;
-}
-
-.input-text textarea::-webkit-input-placeholder, .input-text input::-webkit-input-placeholder {
-    color: var(--text-secondary) !important;
-    opacity: 0.7 !important;
-}
-
-.input-text textarea::-moz-placeholder, .input-text input::-moz-placeholder {
-    color: var(--text-secondary) !important;
-    opacity: 0.7 !important;
-}
-
-.input-text textarea:-ms-input-placeholder, .input-text input:-ms-input-placeholder {
-    color: var(--text-secondary) !important;
-    opacity: 0.7 !important;
-}
-
-/* Dropdown Styling */
-.gr-dropdown {
-    background: var(--surface) !important;
-    border: 2px solid var(--border) !important;
-    border-radius: 12px !important;
-    color: var(--text-primary) !important;
-    transition: all 0.3s ease !important;
-}
-
-.gr-dropdown:focus-within {
-    border-color: var(--primary-orange) !important;
-    box-shadow: 0 0 0 3px rgba(255, 107, 53, 0.1) !important;
-}
-
-/* Button Styling */
-.submit-btn {
-    background: var(--gradient-orange) !important;
-    border: none !important;
-    border-radius: 12px !important;
-    padding: 1rem 2rem !important;
-    color: white !important;
-    font-weight: 600 !important;
-    font-size: 1rem !important;
-    cursor: pointer !important;
-    transition: all 0.3s ease !important;
-    box-shadow: 0 4px 16px rgba(255, 107, 53, 0.3) !important;
-    text-transform: uppercase !important;
-    letter-spacing: 0.5px !important;
-}
-
-.submit-btn:hover {
-    transform: translateY(-2px) !important;
-    box-shadow: 0 6px 24px rgba(255, 107, 53, 0.4) !important;
-    background: linear-gradient(135deg, #ff8c5a 0%, #ffb366 100%) !important;
-}
-
-.submit-btn:active {
-    transform: translateY(0) !important;
-    box-shadow: 0 2px 8px rgba(255, 107, 53, 0.3) !important;
-}
-
-/* File Upload Styling */
-.file-upload {
-    background: var(--gradient-orange-subtle) !important;
-    border: 2px dashed var(--primary-orange) !important;
-    border-radius: 12px !important;
-    padding: 1.5rem !important;
-    text-align: center !important;
-    transition: all 0.3s ease !important;
-}
-
-.file-upload:hover {
-    background: rgba(255, 107, 53, 0.15) !important;
-    border-color: var(--primary-orange-dark) !important;
-}
-
-/* Output Styling */
-.output-html {
-    background: var(--surface) !important;
-    border: 1px solid var(--border) !important;
-    border-radius: 12px !important;
-    padding: 1.5rem !important;
-    margin-top: 1rem !important;
-    box-shadow: 0 2px 12px var(--shadow) !important;
-    min-height: 100px !important;
-}
-
-.output-html div {
-    color: var(--text-primary) !important;
-    line-height: 1.6 !important;
-}
-
-/* Labels */
-label {
-    color: var(--text-primary) !important;
-    font-weight: 600 !important;
-    font-size: 0.9rem !important;
-    margin-bottom: 0.5rem !important;
-    text-transform: uppercase !important;
-    letter-spacing: 0.5px !important;
-}
-
-/* Footer */
-.footer {
-    background: var(--surface) !important;
-    border-top: 1px solid var(--border) !important;
-    padding: 1.5rem !important;
-    margin-top: 2rem !important;
-    text-align: center !important;
-    border-radius: 16px 16px 0 0 !important;
-}
-
-.footer a {
-    color: var(--primary-orange) !important;
-    text-decoration: none !important;
-    font-weight: 500 !important;
-    transition: color 0.3s ease !important;
-}
-
-.footer a:hover {
-    color: var(--primary-orange-dark) !important;
-    text-decoration: underline !important;
-}
-
-/* Responsive Design */
-@media (max-width: 768px) {
-    .header-title-center a {
-        font-size: 2.5rem !important;
-    }
-
-    .task-container {
-        padding: 1.5rem !important;
-        margin: 1rem !important;
-    }
-
-    .header-container {
-        padding: 1.5rem 1rem !important;
-        margin: -1rem -1rem 1rem -1rem !important;
-    }
-}
-
-/* Enhanced NER Output Styling */
-.output-html span[style*="color: blue"] { color: #3b82f6 !important; }
-.output-html span[style*="color: red"] { color: #ef4444 !important; }
-.output-html span[style*="color: green"] { color: #10b981 !important; }
-.output-html span[style*="color: orange"] { color: var(--primary-orange) !important; }
-.output-html span[style*="color: purple"] { color: #8b5cf6 !important; }
-.output-html span[style*="color: cyan"] { color: #06b6d4 !important; }
-.output-html span[style*="color: magenta"] { color: #ec4899 !important; }
-.output-html span[style*="color: brown"] { color: #92400e !important; }
-.output-html span[style*="color: yellow"] { color: #f59e0b !important; }
-.output-html span[style*="color: pink"] { color: #f472b6 !important; }
-
-/* Dark mode specific adjustments */
-@media (prefers-color-scheme: dark) {
-    .gradio-container {
-        background: var(--background-dark) !important;
-        color: var(--text-primary-dark) !important;
-    }
-
-    .task-container, .output-html {
-        background: var(--surface-dark) !important;
-        border-color: var(--border-dark) !important;
-    }
-
-    .input-text textarea, .input-text input, .gr-dropdown {
-        background: var(--surface-dark) !important;
-        border-color: var(--border-dark) !important;
-        color: var(--text-primary-dark) !important;
-    }
-
-    label {
-        color: var(--text-primary-dark) !important;
-    }
-}
-
-/* Smooth transitions for theme switching */
-* {
-    transition: background-color 0.3s ease, border-color 0.3s ease, color 0.3s ease !important;
-}
-"""
-
744
-with gr.Blocks(theme="allenai/gradio-theme", css=custom_css) as demo:
-    with gr.Column():
-        with gr.Row(elem_id="header", elem_classes="header-container"):
-            gr.Markdown("<div class='header-title-center'><a href='https://eventdata.utdallas.edu/conflibert/' style='font-size: 4rem; font-weight: 900;'>ConfliBERT</a></div>")
-
-        with gr.Column(elem_classes="task-container"):
-            gr.Markdown("<h2 style='font-size: 1.25rem; font-weight: 600; margin-bottom: 1.5rem;'>Select a task and provide the necessary inputs:</h2>")
-
-            task = gr.Dropdown(
-                choices=["Question Answering", "Named Entity Recognition", "Text Classification", "Multilabel Classification"],
-                label="Select Task",
-                value="Named Entity Recognition"
-            )
-
-            with gr.Row():
-                text_input = gr.Textbox(
-                    lines=5,
-                    placeholder="Enter the text here...",
-                    label="Text",
-                    elem_classes="input-text"
-                )
-                context_input = gr.Textbox(
-                    lines=5,
-                    placeholder="Enter the context here...",
-                    label="Context",
-                    visible=False,
-                    elem_classes="input-text"
-                )
-                question_input = gr.Textbox(
-                    lines=2,
-                    placeholder="Enter your question here...",
-                    label="Question",
-                    visible=False,
-                    elem_classes="input-text"
-                )
-
-            with gr.Row():
-                file_input = gr.File(
-                    label="Or upload a CSV file (must contain a 'text' column)",
-                    file_types=[".csv"],
-                    elem_classes="file-upload"
-                )
-                file_output = gr.File(
-                    label="Download processed results",
-                    visible=False,
-                    elem_classes="file-download"
-                )
-
-            with gr.Row():
-                submit_button = gr.Button(
-                    "Submit",
-                    elem_id="submit-button",
-                    elem_classes="submit-btn"
-                )
-
-            output = gr.HTML(label="Output", elem_classes="output-html")
-
-        with gr.Row(elem_classes="footer"):
-            gr.Markdown("<a href='https://eventdata.utdallas.edu/'>UTD Event Data</a> | <a href='https://www.utdallas.edu/'>University of Texas at Dallas</a>")
-            gr.Markdown("Developed By: <a href='https://www.linkedin.com/in/sultan-alsarra-phd-56977a63/' target='_blank'>Sultan Alsarra</a> and <a href='http://shreyasmeher.com' target='_blank'>Shreyas Meher</a>")
-
-    def update_inputs(task_name):
-        """Updates the visibility of input components based on the selected task."""
-        if task_name == "Question Answering":
-            return [
-                gr.update(visible=False),
-                gr.update(visible=True),
-                gr.update(visible=True),
-                gr.update(visible=False),
-                gr.update(visible=False)
-            ]
-        else:
-            return [
-                gr.update(visible=True),
-                gr.update(visible=False),
-                gr.update(visible=False),
-                gr.update(visible=True),
-                gr.update(visible=True)
-            ]
-
-    def chatbot_interface(task, text, context, question, file):
-        """Handles both file and text inputs for different tasks."""
-        if file:
-            result = chatbot(task, file=file)
-            if isinstance(result, str) and result.endswith('.csv'):
-                return gr.update(visible=False), gr.update(value=result, visible=True)
-            return gr.update(value=result, visible=True), gr.update(visible=False)
-        else:
-            result = chatbot(task, text, context, question)
-            return gr.update(value=result, visible=True), gr.update(visible=False)
-
-    def chatbot(task, text=None, context=None, question=None, file=None):
-        """Main function to process different types of inputs and tasks."""
-        if file is not None:  # Handle CSV file input
-            if task == "Named Entity Recognition":
-                return process_csv_ner(file)
-            elif task == "Text Classification":
-                return process_csv_classification(file, is_multi=False)
-            elif task == "Multilabel Classification":
-                return process_csv_classification(file, is_multi=True)
-            else:
-                return "CSV processing is not supported for Question Answering task"
-
-        # Handle regular text input
-        if task == "Question Answering":
-            if context and question:
-                return question_answering(context, question)
-            else:
-                return "Please provide both context and question for the Question Answering task."
-        elif task == "Named Entity Recognition":
-            if text:
-                return named_entity_recognition(text)
-            else:
-                return "Please provide text for the Named Entity Recognition task."
-        elif task == "Text Classification":
-            if text:
-                return text_classification(text)
-            else:
-                return "Please provide text for the Text Classification task."
-        elif task == "Multilabel Classification":
-            if text:
-                return multilabel_classification(text)
-            else:
-                return "Please provide text for the Multilabel Classification task."
-        else:
-            return "Please select a valid task."
-
-    task.change(fn=update_inputs, inputs=task, outputs=[text_input, context_input, question_input, file_input, file_output])
-    submit_button.click(
-        fn=chatbot_interface,
-        inputs=[task, text_input, context_input, question_input, file_input],
-        outputs=[output, file_output]
-    )
-
-demo.launch(share=True)
+# ============================================================================
+# ConfliBERT - Conflict & Political Violence NLP Toolkit
+# University of Texas at Dallas | Event Data Lab
+# ============================================================================
+
+import os
+os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
+
+import torch
+import tensorflow as tf
+import tf_keras  # noqa: F401 - needed for TF model loading
+import keras  # noqa: F401 - needed for TF model loading
+from transformers import (
+    AutoTokenizer,
+    AutoModelForSequenceClassification,
+    AutoModelForTokenClassification,
+    TFAutoModelForQuestionAnswering,
+    TrainingArguments,
+    Trainer,
+    EarlyStoppingCallback,
+    TrainerCallback,
+)
+import gradio as gr
+import numpy as np
+import pandas as pd
+import re
+import csv
+import tempfile
+from sklearn.metrics import (
+    accuracy_score as sk_accuracy,
+    precision_score as sk_precision,
+    recall_score as sk_recall,
+    f1_score as sk_f1,
+    roc_curve,
+    auc as sk_auc,
+)
+from sklearn.preprocessing import label_binarize
+from torch.utils.data import Dataset as TorchDataset
+import gc
+
+
+# ============================================================================
+# CONFIGURATION
+# ============================================================================
+
+if torch.cuda.is_available():
+    device = torch.device('cuda')
+elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+    device = torch.device('mps')
+else:
+    device = torch.device('cpu')
+
+MAX_TOKEN_LENGTH = 512
+
+
56
+def get_system_info():
+    """Build an HTML string describing the user's compute environment."""
+    import platform
+    lines = []
+
+    # Device
+    if device.type == 'cuda':
+        gpu_name = torch.cuda.get_device_name(0)
+        vram = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
+        lines.append(f"GPU: {gpu_name} ({vram:.1f} GB VRAM)")
+        lines.append("FP16 training: supported")
+    elif device.type == 'mps':
+        lines.append("GPU: Apple Silicon (MPS)")
+        lines.append("FP16 training: not supported on MPS")
+    else:
+        lines.append("GPU: None detected (using CPU)")
+        lines.append("FP16 training: not supported on CPU")
+
+    # CPU / RAM (os is already imported at module level)
+    cpu_count = os.cpu_count() or 1
+    lines.append(f"CPU cores: {cpu_count}")
+    try:
+        import psutil
+        ram_gb = psutil.virtual_memory().total / (1024 ** 3)
+        lines.append(f"RAM: {ram_gb:.1f} GB")
+    except ImportError:
+        pass
+
+    lines.append(f"Platform: {platform.system()} {platform.machine()}")
+    lines.append(f"PyTorch: {torch.__version__}")
+
+    return " · ".join(lines)
+
90
+FINETUNE_MODELS = {
+    "ConfliBERT (recommended for conflict/political text)": "snowood1/ConfliBERT-scr-uncased",
+    "BERT Base Uncased": "bert-base-uncased",
+    "BERT Base Cased": "bert-base-cased",
+    "RoBERTa Base": "roberta-base",
+    "ModernBERT Base": "answerdotai/ModernBERT-base",
+    "DeBERTa v3 Base": "microsoft/deberta-v3-base",
+    "DistilBERT Base Uncased": "distilbert-base-uncased",
+}
+
+NER_LABELS = {
+    'Organisation': '#3b82f6',
+    'Person': '#ef4444',
+    'Location': '#10b981',
+    'Quantity': '#ff6b35',
+    'Weapon': '#8b5cf6',
+    'Nationality': '#06b6d4',
+    'Temporal': '#ec4899',
+    'DocumentReference': '#92400e',
+    'MilitaryPlatform': '#f59e0b',
+    'Money': '#f472b6',
+}
+
+CLASS_NAMES = ['Negative', 'Positive']
+MULTI_CLASS_NAMES = ["Armed Assault", "Bombing or Explosion", "Kidnapping", "Other"]
+
+
+# ============================================================================
+# PRETRAINED MODEL LOADING
+# ============================================================================
+
+qa_model_name = 'salsarra/ConfliBERT-QA'
+qa_model = TFAutoModelForQuestionAnswering.from_pretrained(qa_model_name)
+qa_tokenizer = AutoTokenizer.from_pretrained(qa_model_name)
+
+ner_model_name = 'eventdata-utd/conflibert-named-entity-recognition'
+ner_model = AutoModelForTokenClassification.from_pretrained(ner_model_name).to(device)
+ner_tokenizer = AutoTokenizer.from_pretrained(ner_model_name)
+
+clf_model_name = 'eventdata-utd/conflibert-binary-classification'
+clf_model = AutoModelForSequenceClassification.from_pretrained(clf_model_name).to(device)
+clf_tokenizer = AutoTokenizer.from_pretrained(clf_model_name)
+
+multi_clf_model_name = 'eventdata-utd/conflibert-satp-relevant-multilabel'
+multi_clf_model = AutoModelForSequenceClassification.from_pretrained(multi_clf_model_name).to(device)
+multi_clf_tokenizer = AutoTokenizer.from_pretrained(multi_clf_model_name)
+
+
138
+# ============================================================================
+# UTILITY FUNCTIONS
+# ============================================================================
+
+def get_path(f):
+    """Get file path from Gradio file component output."""
+    if f is None:
+        return None
+    return f if isinstance(f, str) else getattr(f, 'name', str(f))
+
+
+def truncate_text(text, tokenizer, max_length=MAX_TOKEN_LENGTH):
+    tokens = tokenizer.encode(text, truncation=False)
+    if len(tokens) > max_length:
+        tokens = tokens[:max_length - 1] + [tokenizer.sep_token_id]
+        return tokenizer.decode(tokens, skip_special_tokens=True)
+    return text
+
+
+def info_callout(text):
+    """Wrap markdown text in a styled callout div to avoid Gradio double-border."""
+    return (
+        "<div class='info-callout-inner' style='"
+        "background: #fff7f3; border-left: 3px solid #ff6b35; "
+        "padding: 0.75rem 1rem; border-radius: 0 8px 8px 0; "
+        "font-size: 0.9rem;'>\n\n"
+        f"{text}\n\n</div>"
+    )
+
+
+def handle_error(e, default_limit=512):
+    msg = str(e)
+    match = re.search(
+        r"The size of tensor a \((\d+)\) must match the size of tensor b \((\d+)\)", msg
+    )
+    if match:
+        return (
+            f"<span style='color: #ef4444; font-weight: 600;'>"
+            f"Error: Input ({match.group(1)} tokens) exceeds model limit ({match.group(2)})</span>"
+        )
+    match_qa = re.search(r"indices\[0,(\d+)\] = \d+ is not in \[0, (\d+)\)", msg)
+    if match_qa:
+        return (
+            f"<span style='color: #ef4444; font-weight: 600;'>"
+            f"Error: Input too long for model (limit: {match_qa.group(2)} tokens)</span>"
+        )
+    return f"<span style='color: #ef4444; font-weight: 600;'>Error: {msg}</span>"
+
+
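The first pattern in `handle_error` matches PyTorch's tensor size-mismatch message and pulls out the two sizes. A minimal standalone sketch of that extraction (the error string below is a made-up example of the kind of message being parsed):

```python
import re

# Hypothetical PyTorch error text of the kind handle_error() parses.
msg = "The size of tensor a (687) must match the size of tensor b (512)"

pattern = r"The size of tensor a \((\d+)\) must match the size of tensor b \((\d+)\)"
match = re.search(pattern, msg)

# group(1) is the offending input length, group(2) the model's limit.
length, limit = match.group(1), match.group(2)
```

Parsing exception text is brittle across library versions, which is why the function falls through to a generic error string when neither pattern matches.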
187
+# ============================================================================
+# INFERENCE FUNCTIONS
+# ============================================================================
+
+def question_answering(context, question):
+    if not context or not question:
+        return "Please provide both context and question."
+    try:
+        inputs = qa_tokenizer(question, context, return_tensors='tf', truncation=True)
+        outputs = qa_model(inputs)
+        start = tf.argmax(outputs.start_logits, axis=1).numpy()[0]
+        end = tf.argmax(outputs.end_logits, axis=1).numpy()[0] + 1
+        tokens = qa_tokenizer.convert_ids_to_tokens(
+            inputs['input_ids'].numpy()[0][start:end]
+        )
+        answer = qa_tokenizer.convert_tokens_to_string(tokens)
+        return f"<span style='color: #10b981; font-weight: 600;'>{answer}</span>"
+    except Exception as e:
+        return handle_error(e)
+
+
208
+def named_entity_recognition(text, output_format='html'):
+    if not text:
+        return "Please provide text for analysis."
+    try:
+        # Move inputs to the same device as the model (which was .to(device)'d at load).
+        inputs = ner_tokenizer(text, return_tensors='pt', truncation=True).to(device)
+        with torch.no_grad():
+            outputs = ner_model(**inputs)
+        results = outputs.logits.argmax(dim=2).squeeze().tolist()
+        tokens = ner_tokenizer.convert_ids_to_tokens(inputs['input_ids'].squeeze().tolist())
+        tokens = [t.replace('[UNK]', "'") for t in tokens]
+
+        entities = []
+        seen_labels = set()
+        current_entity = []
+        current_label = None
+
+        for i in range(len(tokens)):
+            token = tokens[i]
+            label = ner_model.config.id2label[results[i]].split('-')[-1]
+
+            if token.startswith('##'):
+                # WordPiece continuation: glue onto the previous token.
+                if entities:
+                    if output_format == 'html':
+                        entities[-1][0] += token[2:]
+                    elif current_entity:
+                        current_entity[-1] = current_entity[-1] + token[2:]
+            else:
+                if output_format == 'csv':
+                    if label != 'O':
+                        if label == current_label:
+                            current_entity.append(token)
+                        else:
+                            if current_entity:
+                                entities.append([' '.join(current_entity), current_label])
+                            current_entity = [token]
+                            current_label = label
+                    else:
+                        if current_entity:
+                            entities.append([' '.join(current_entity), current_label])
+                        current_entity = []
+                        current_label = None
+                else:
+                    entities.append([token, label])
+
+            if label != 'O':
+                seen_labels.add(label)
+
+        if output_format == 'csv' and current_entity:
+            entities.append([' '.join(current_entity), current_label])
+
+        if output_format == 'csv':
+            grouped = {}
+            for token, label in entities:
+                if label != 'O':
+                    grouped.setdefault(label, []).append(token)
+            parts = []
+            for label, toks in grouped.items():
+                unique = list(dict.fromkeys(toks))
+                parts.append(f"{label}: {' | '.join(unique)}")
+            return ' || '.join(parts)
+
+        # HTML output
+        highlighted = ""
+        for token, label in entities:
+            color = NER_LABELS.get(label, 'inherit')
+            if label != 'O':
+                highlighted += (
+                    f"<span style='color: {color}; font-weight: 600;'>{token}</span> "
+                )
+            else:
+                highlighted += f"{token} "
+
+        if seen_labels:
+            legend_items = ""
+            for label in sorted(seen_labels):
+                color = NER_LABELS.get(label, '#666')
+                legend_items += (
+                    f"<li style='color: {color}; font-weight: 600; "
+                    f"background: {color}15; padding: 2px 8px; border-radius: 4px; "
+                    f"font-size: 0.85rem;'>{label}</li>"
+                )
+            legend = (
+                f"<div style='margin-top: 1rem; padding-top: 0.75rem; "
+                f"border-top: 1px solid #e5e7eb;'>"
+                f"<strong>Entities found:</strong>"
+                f"<ul style='list-style: none; padding: 0; display: flex; "
+                f"flex-wrap: wrap; gap: 0.5rem; margin-top: 0.5rem;'>"
+                f"{legend_items}</ul></div>"
+            )
+            return f"<div style='line-height: 1.8;'>{highlighted}</div>{legend}"
+        else:
+            return (
+                f"<div style='line-height: 1.8;'>{highlighted}</div>"
+                f"<div style='color: #888; margin-top: 0.5rem;'>No entities detected.</div>"
+            )
+
+    except Exception as e:
+        return handle_error(e)
+
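The `output_format='csv'` branch flattens entity/label pairs into a `Label: a | b || Label2: c` string, grouping by label and de-duplicating while preserving first-seen order. That grouping step can be sketched on its own (the entity pairs below are made-up sample data, not model output):

```python
# Hypothetical (token, label) pairs as produced by the CSV branch above.
entities = [
    ("John Smith", "Person"),
    ("Kabul", "Location"),
    ("John Smith", "Person"),  # duplicate mention
]

grouped = {}
for token, label in entities:
    if label != 'O':
        grouped.setdefault(label, []).append(token)

parts = []
for label, toks in grouped.items():
    unique = list(dict.fromkeys(toks))  # order-preserving de-duplication
    parts.append(f"{label}: {' | '.join(unique)}")

result = ' || '.join(parts)
```

`dict.fromkeys` relies on Python 3.7+ dicts preserving insertion order, so repeated mentions collapse without reordering the entities.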
307
+
+def predict_with_model(text, model, tokenizer):
+    """Run inference with an arbitrary classification model."""
+    model.eval()
+    dev = next(model.parameters()).device
+    inputs = tokenizer(
+        text, return_tensors='pt', truncation=True, padding=True, max_length=512
+    )
+    inputs = {k: v.to(dev) for k, v in inputs.items()}
+
+    with torch.no_grad():
+        outputs = model(**inputs)
+
+    probs = torch.softmax(outputs.logits, dim=1).squeeze()
+    predicted = torch.argmax(probs).item()
+    num_classes = probs.shape[0] if probs.dim() > 0 else 1
+
+    lines = []
+    for i in range(num_classes):
+        p = probs[i].item() * 100 if probs.dim() > 0 else probs.item() * 100
+        if i == predicted:
+            lines.append(
+                f"<span style='color: #10b981; font-weight: 600;'>"
+                f"Class {i}: {p:.2f}% (predicted)</span>"
+            )
+        else:
+            lines.append(f"<span style='color: #9ca3af;'>Class {i}: {p:.2f}%</span>")
+    return "<br>".join(lines)
+
+
337
+def text_classification(text, custom_model=None, custom_tokenizer=None):
+    if not text:
+        return "Please provide text for classification."
+    try:
+        # Use custom model if loaded
+        if custom_model is not None and custom_tokenizer is not None:
+            return predict_with_model(text, custom_model, custom_tokenizer)
+
+        # Pretrained binary classifier
+        inputs = clf_tokenizer(
+            text, return_tensors='pt', truncation=True, padding=True
+        ).to(device)
+        with torch.no_grad():
+            outputs = clf_model(**inputs)
+        predicted = torch.argmax(outputs.logits, dim=1).item()
+        confidence = torch.softmax(outputs.logits, dim=1).max().item() * 100
+
+        if predicted == 1:
+            return (
+                f"<span style='color: #10b981; font-weight: 600;'>"
+                f"Positive -- Related to conflict, violence, or politics. "
+                f"(Confidence: {confidence:.1f}%)</span>"
+            )
+        else:
+            return (
+                f"<span style='color: #ef4444; font-weight: 600;'>"
+                f"Negative -- Not related to conflict, violence, or politics. "
+                f"(Confidence: {confidence:.1f}%)</span>"
+            )
+    except Exception as e:
+        return handle_error(e)
+
+
370
+def multilabel_classification(text):
+    if not text:
+        return "Please provide text for classification."
+    try:
+        inputs = multi_clf_tokenizer(
+            text, return_tensors='pt', truncation=True, padding=True
+        ).to(device)
+        with torch.no_grad():
+            outputs = multi_clf_model(**inputs)
+        probs = torch.sigmoid(outputs.logits).squeeze().tolist()
+
+        results = []
+        for i in range(len(probs)):
+            conf = probs[i] * 100
+            if probs[i] >= 0.5:
+                results.append(
+                    f"<span style='color: #10b981; font-weight: 600;'>"
+                    f"{MULTI_CLASS_NAMES[i]}: {conf:.1f}%</span>"
+                )
+            else:
+                results.append(
+                    f"<span style='color: #9ca3af;'>"
+                    f"{MULTI_CLASS_NAMES[i]}: {conf:.1f}%</span>"
+                )
+        return "<br>".join(results)
+    except Exception as e:
+        return handle_error(e)
+
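Note that `multilabel_classification` applies a sigmoid per label with a 0.5 threshold rather than a softmax over all labels, so several categories can be active at once. A minimal numeric sketch of that decision rule (the logits are made-up values, not real model output):

```python
import math

MULTI_CLASS_NAMES = ["Armed Assault", "Bombing or Explosion", "Kidnapping", "Other"]

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

# Hypothetical raw logits for one input sentence.
logits = [2.1, -0.3, -4.0, 0.8]
probs = [sigmoid(z) for z in logits]

# Unlike softmax, each label is judged independently, so more than one can pass 0.5.
active = [name for name, p in zip(MULTI_CLASS_NAMES, probs) if p >= 0.5]
```

With these logits, `active` contains both "Armed Assault" and "Other", which a softmax-plus-argmax formulation could never produce.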
398
+
+ # ============================================================================
+ # CSV BATCH PROCESSING
+ # ============================================================================
+
+ def process_csv_ner(file):
+     path = get_path(file)
+     if path is None:
+         return None
+     df = pd.read_csv(path)
+     if 'text' not in df.columns:
+         raise ValueError("CSV must contain a 'text' column")
+
+     entities = []
+     for text in df['text']:
+         if pd.isna(text):
+             entities.append("")
+         else:
+             entities.append(named_entity_recognition(str(text), output_format='csv'))
+     df['entities'] = entities
+
+     out = tempfile.NamedTemporaryFile(suffix='_ner_results.csv', delete=False)
+     df.to_csv(out.name, index=False)
+     return out.name
+
+
+ def process_csv_binary(file, custom_model=None, custom_tokenizer=None):
+     path = get_path(file)
+     if path is None:
+         return None
+     df = pd.read_csv(path)
+     if 'text' not in df.columns:
+         raise ValueError("CSV must contain a 'text' column")
+
+     results = []
+     for text in df['text']:
+         if pd.isna(text):
+             results.append("")
+         else:
+             html = text_classification(str(text), custom_model, custom_tokenizer)
+             results.append(re.sub(r'<[^>]+>', '', html).strip())
+     df['classification_results'] = results
+
+     out = tempfile.NamedTemporaryFile(suffix='_classification_results.csv', delete=False)
+     df.to_csv(out.name, index=False)
+     return out.name
+
+
+ def process_csv_multilabel(file):
+     path = get_path(file)
+     if path is None:
+         return None
+     df = pd.read_csv(path)
+     if 'text' not in df.columns:
+         raise ValueError("CSV must contain a 'text' column")
+
+     results = []
+     for text in df['text']:
+         if pd.isna(text):
+             results.append("")
+         else:
+             html = multilabel_classification(str(text))
+             results.append(re.sub(r'<[^>]+>', '', html).strip())
+     df['multilabel_results'] = results
+
+     out = tempfile.NamedTemporaryFile(suffix='_multilabel_results.csv', delete=False)
+     df.to_csv(out.name, index=False)
+     return out.name
+
+
+ # ============================================================================
+ # FINETUNING
+ # ============================================================================
+
+ class TextClassificationDataset(TorchDataset):
+     """PyTorch Dataset for text classification with HuggingFace tokenizers."""
+
+     def __init__(self, texts, labels, tokenizer, max_length=512):
+         self.encodings = tokenizer(
+             texts, truncation=True, padding=True,
+             max_length=max_length, return_tensors=None,
+         )
+         self.labels = labels
+
+     def __getitem__(self, idx):
+         item = {k: torch.tensor(v[idx]) for k, v in self.encodings.items()}
+         item['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)
+         return item
+
+     def __len__(self):
+         return len(self.labels)
+
+
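`TextClassificationDataset` tokenizes everything eagerly in `__init__` and serves one field dict per index; the HF `Trainer` relies only on the `__len__`/`__getitem__` contract. A toy sketch of that contract with a stand-in "tokenizer" (no torch, illustrative only):

```python
class ToyDataset:
    """Same contract as TextClassificationDataset: eager encoding in
    __init__, a dict of fields per item, length taken from the labels."""

    def __init__(self, texts, labels):
        # Stand-in for tokenizer(): one "input_ids" list per text
        self.encodings = {"input_ids": [[len(t)] for t in texts]}
        self.labels = labels

    def __getitem__(self, idx):
        item = {k: v[idx] for k, v in self.encodings.items()}
        item["labels"] = self.labels[idx]
        return item

    def __len__(self):
        return len(self.labels)

ds = ToyDataset(["a protest", "calm day"], [1, 0])
```

The real class additionally wraps each field in `torch.tensor`, which is what the Trainer's default collator expects.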
+ def parse_data_file(file_path):
+     """Parse a TSV/CSV data file. Expected format: text<separator>label (no header).
+     Labels must be integers. Returns (texts, labels, num_labels)."""
+     path = get_path(file_path)
+     texts, labels = [], []
+
+     # Detect delimiter from first line
+     with open(path, 'r', encoding='utf-8') as f:
+         first_line = f.readline()
+     delimiter = '\t' if '\t' in first_line else ','
+
+     with open(path, 'r', encoding='utf-8') as f:
+         reader = csv.reader(f, delimiter=delimiter, quotechar='"')
+         for row in reader:
+             if len(row) < 2:
+                 continue
+             try:
+                 label = int(row[-1].strip())
+                 text = row[0].strip() if len(row) == 2 else delimiter.join(row[:-1]).strip()
+                 if text:
+                     texts.append(text)
+                     labels.append(label)
+             except (ValueError, IndexError):
+                 continue  # skip header or malformed rows
+
+     if not texts:
+         raise ValueError(
+             "No valid data rows found. Expected format: text<tab>label (no header row)"
+         )
+
+     num_labels = max(labels) + 1
+     return texts, labels, num_labels
+
+
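`parse_data_file` sniffs the delimiter from the first line and silently skips any header row, because a header's label column fails `int()`. A self-contained sketch of the same logic over an in-memory string (`parse_rows` is an illustrative helper, not part of the app):

```python
import csv
import io

def parse_rows(raw):
    """Sniff tab vs comma from the first line, then keep only rows whose
    last field parses as an integer label."""
    delimiter = '\t' if '\t' in raw.splitlines()[0] else ','
    texts, labels = [], []
    for row in csv.reader(io.StringIO(raw), delimiter=delimiter, quotechar='"'):
        if len(row) < 2:
            continue
        try:
            labels.append(int(row[-1].strip()))
        except ValueError:
            continue  # header or malformed row
        texts.append(delimiter.join(row[:-1]).strip())
    return texts, labels, max(labels) + 1

# The header row is dropped because "label" is not an integer
raw = "text\tlabel\nprotest reported\t1\nweather update\t0\n"
texts, labels, n = parse_rows(raw)
```

Note that, as in the real function, `num_labels = max(labels) + 1` assumes labels are 0-indexed and contiguous.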
+ class LogCallback(TrainerCallback):
+     """Captures training logs for display in the UI."""
+
+     def __init__(self):
+         self.entries = []
+
+     def on_log(self, args, state, control, logs=None, **kwargs):
+         if logs:
+             self.entries.append({**logs})
+
+     def format(self):
+         lines = []
+         skip_keys = {
+             'total_flos', 'train_runtime', 'train_samples_per_second',
+             'train_steps_per_second', 'train_loss',
+         }
+         for entry in self.entries:
+             parts = []
+             for k, v in sorted(entry.items()):
+                 if k in skip_keys:
+                     continue
+                 if isinstance(v, float):
+                     parts.append(f"{k}: {v:.4f}")
+                 elif isinstance(v, (int, np.integer)):
+                     parts.append(f"{k}: {v}")
+             if parts:
+                 lines.append(" ".join(parts))
+         return "\n".join(lines)
+
+
+ def make_compute_metrics(task_type):
+     """Factory for compute_metrics function based on task type."""
+
+     def compute_metrics(eval_pred):
+         logits, labels = eval_pred
+         preds = np.argmax(logits, axis=-1)
+         acc = sk_accuracy(labels, preds)
+
+         if task_type == "Binary":
+             return {
+                 'accuracy': acc,
+                 'precision': sk_precision(labels, preds, zero_division=0),
+                 'recall': sk_recall(labels, preds, zero_division=0),
+                 'f1': sk_f1(labels, preds, zero_division=0),
+             }
+         else:
+             return {
+                 'accuracy': acc,
+                 'f1_macro': sk_f1(labels, preds, average='macro', zero_division=0),
+                 'f1_micro': sk_f1(labels, preds, average='micro', zero_division=0),
+                 'precision_macro': sk_precision(
+                     labels, preds, average='macro', zero_division=0
+                 ),
+                 'precision_micro': sk_precision(
+                     labels, preds, average='micro', zero_division=0
+                 ),
+                 'recall_macro': sk_recall(
+                     labels, preds, average='macro', zero_division=0
+                 ),
+                 'recall_micro': sk_recall(
+                     labels, preds, average='micro', zero_division=0
+                 ),
+             }
+
+     return compute_metrics
+
+
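`make_compute_metrics` reports both macro and micro averages for multiclass runs; macro weights every class equally, so a rare class that is always missed drags the score down. A pure-Python illustration with hypothetical labels (not the app's sklearn aliases):

```python
def f1_per_class(y_true, y_pred, cls):
    """One-vs-rest F1 for a single class label."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == p == cls)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t != cls and p == cls)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == cls and p != cls)
    if tp == 0:
        return 0.0
    prec, rec = tp / (tp + fp), tp / (tp + fn)
    return 2 * prec * rec / (prec + rec)

# Hypothetical labels: class 0 is perfect, class 1 is always missed
y_true = [0, 0, 0, 0, 1, 2]
y_pred = [0, 0, 0, 0, 2, 2]
macro = sum(f1_per_class(y_true, y_pred, c) for c in {0, 1, 2}) / 3
# Class 0: F1 = 1.0; class 1: F1 = 0.0; class 2: F1 = 2/3 -> macro = 5/9
```

Micro averaging would instead pool all true/false positives across classes, which for single-label multiclass reduces to plain accuracy.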
+ def run_finetuning(
+     train_file, dev_file, test_file, task_type, model_display_name,
+     epochs, batch_size, lr, weight_decay, warmup_ratio, max_seq_len,
+     grad_accum, fp16, patience, scheduler,
+     progress=gr.Progress(track_tqdm=True),
+ ):
+     """Main finetuning function. Returns logs, metrics, model state, and visibility updates."""
+     try:
+         # Validate inputs
+         if train_file is None or dev_file is None or test_file is None:
+             raise ValueError("Please upload all three data files (train, dev, test).")
+
+         epochs = int(epochs)
+         batch_size = int(batch_size)
+         max_seq_len = int(max_seq_len)
+         grad_accum = int(grad_accum)
+         patience = int(patience)
+
+         # Parse data files
+         train_texts, train_labels, n_train = parse_data_file(train_file)
+         dev_texts, dev_labels, n_dev = parse_data_file(dev_file)
+         test_texts, test_labels, n_test = parse_data_file(test_file)
+
+         num_labels = max(n_train, n_dev, n_test)
+         if task_type == "Binary" and num_labels > 2:
+             raise ValueError(
+                 f"Binary task selected but found {num_labels} label classes in data. "
+                 f"Use Multiclass instead."
+             )
+         if task_type == "Binary":
+             num_labels = 2
+
+         # Load model and tokenizer
+         model_id = FINETUNE_MODELS[model_display_name]
+         tokenizer = AutoTokenizer.from_pretrained(model_id)
+         model = AutoModelForSequenceClassification.from_pretrained(
+             model_id, num_labels=num_labels
+         )
+
+         # Create datasets
+         train_ds = TextClassificationDataset(
+             train_texts, train_labels, tokenizer, max_seq_len
+         )
+         dev_ds = TextClassificationDataset(
+             dev_texts, dev_labels, tokenizer, max_seq_len
+         )
+         test_ds = TextClassificationDataset(
+             test_texts, test_labels, tokenizer, max_seq_len
+         )
+
+         # Output directory
+         output_dir = tempfile.mkdtemp(prefix='conflibert_ft_')
+
+         # Training arguments
+         best_metric = 'f1' if task_type == 'Binary' else 'f1_macro'
+         training_args = TrainingArguments(
+             output_dir=output_dir,
+             num_train_epochs=epochs,
+             per_device_train_batch_size=batch_size,
+             per_device_eval_batch_size=batch_size * 2,
+             learning_rate=lr,
+             weight_decay=weight_decay,
+             warmup_ratio=warmup_ratio,
+             gradient_accumulation_steps=grad_accum,
+             fp16=fp16 and torch.cuda.is_available(),
+             eval_strategy='epoch',
+             save_strategy='epoch',
+             load_best_model_at_end=True,
+             metric_for_best_model=best_metric,
+             greater_is_better=True,
+             logging_steps=10,
+             save_total_limit=2,
+             lr_scheduler_type=scheduler,
+             report_to='none',
+             seed=42,
+         )
+
+         # Callbacks
+         log_callback = LogCallback()
+         callbacks = [log_callback]
+         if patience > 0:
+             callbacks.append(EarlyStoppingCallback(early_stopping_patience=patience))
+
+         # Create Trainer
+         trainer = Trainer(
+             model=model,
+             args=training_args,
+             train_dataset=train_ds,
+             eval_dataset=dev_ds,
+             compute_metrics=make_compute_metrics(task_type),
+             callbacks=callbacks,
+         )
+
+         # Train
+         train_result = trainer.train()
+
+         # Evaluate on test set
+         test_results = trainer.evaluate(test_ds, metric_key_prefix='test')
+
+         # Build log text
+         header = (
+             f"=== Configuration ===\n"
+             f"Model: {model_display_name}\n"
+             f" {model_id}\n"
+             f"Task: {task_type} Classification ({num_labels} classes)\n"
+             f"Data: {len(train_texts)} train / {len(dev_texts)} dev / {len(test_texts)} test\n"
+             f"Epochs: {epochs} Batch: {batch_size} LR: {lr} Scheduler: {scheduler}\n"
+             f"\n=== Training Log ===\n"
+         )
+         runtime = train_result.metrics.get('train_runtime', 0)
+         footer = (
+             f"\n=== Training Complete ===\n"
+             f"Time: {runtime:.1f}s ({runtime / 60:.1f} min)\n"
+         )
+         log_text = header + log_callback.format() + footer
+
+         # Build metrics DataFrame
+         metrics_data = []
+         for k, v in sorted(test_results.items()):
+             if isinstance(v, (int, float, np.floating, np.integer)) and k != 'test_epoch':
+                 name = k.replace('test_', '').replace('_', ' ').title()
+                 metrics_data.append([name, f"{float(v):.4f}"])
+         metrics_df = pd.DataFrame(metrics_data, columns=['Metric', 'Score'])
+
+         # Move trained model to CPU for inference
+         trained_model = trainer.model.cpu()
+         trained_model.eval()
+
+         return (
+             log_text, metrics_df, trained_model, tokenizer, num_labels,
+             gr.Column(visible=True), gr.Column(visible=True),
+         )
+
+     except Exception as e:
+         error_log = f"Training failed:\n{str(e)}"
+         empty_df = pd.DataFrame(columns=['Metric', 'Score'])
+         return (
+             error_log, empty_df, None, None, None,
+             gr.Column(visible=False), gr.Column(visible=False),
+         )
+
+
+ # ============================================================================
+ # MODEL MANAGEMENT (predict, save, load)
+ # ============================================================================
+
+ def predict_finetuned(text, model_state, tokenizer_state, num_labels_state):
+     """Run prediction with the finetuned model stored in gr.State."""
+     if not text:
+         return "Please enter some text."
+     if model_state is None:
+         return "No model available. Please train a model first."
+     return predict_with_model(text, model_state, tokenizer_state)
+
+
+ def save_finetuned_model(save_path, model_state, tokenizer_state):
+     """Save the finetuned model and tokenizer to disk."""
+     if model_state is None:
+         return "No model to save. Please train a model first."
+     if not save_path:
+         return "Please specify a save directory."
+     try:
+         os.makedirs(save_path, exist_ok=True)
+         model_state.save_pretrained(save_path)
+         tokenizer_state.save_pretrained(save_path)
+         return f"Model saved successfully to: {save_path}"
+     except Exception as e:
+         return f"Error saving model: {str(e)}"
+
+
+ def load_custom_model(path):
+     """Load a finetuned classification model from disk."""
+     if not path or not os.path.isdir(path):
+         return None, None, "Invalid path. Please enter a valid model directory."
+     try:
+         tokenizer = AutoTokenizer.from_pretrained(path)
+         model = AutoModelForSequenceClassification.from_pretrained(path)
+         model.eval()
+         n = model.config.num_labels
+         return model, tokenizer, f"Loaded model with {n} classes from: {path}"
+     except Exception as e:
+         return None, None, f"Error loading model: {str(e)}"
+
+
+ def reset_custom_model():
+     """Reset to the pretrained ConfliBERT binary classifier."""
+     return None, None, "Reset to pretrained ConfliBERT binary classifier."
+
+
+ def batch_predict_finetuned(file, model_state, tokenizer_state, num_labels_state):
+     """Run batch predictions on a CSV using the finetuned model."""
+     if model_state is None:
+         return None
+     path = get_path(file)
+     if path is None:
+         return None
+
+     df = pd.read_csv(path)
+     if 'text' not in df.columns:
+         raise ValueError("CSV must contain a 'text' column")
+
+     model_state.eval()
+     dev = next(model_state.parameters()).device
+
+     predictions, confidences = [], []
+     for text in df['text']:
+         if pd.isna(text):
+             predictions.append("")
+             confidences.append("")
+             continue
+
+         inputs = tokenizer_state(
+             str(text), return_tensors='pt', truncation=True,
+             padding=True, max_length=512,
+         )
+         inputs = {k: v.to(dev) for k, v in inputs.items()}
+         with torch.no_grad():
+             outputs = model_state(**inputs)
+         probs = torch.softmax(outputs.logits, dim=1).squeeze()
+         pred = torch.argmax(probs).item()
+         conf = probs[pred].item() * 100
+         predictions.append(str(pred))
+         confidences.append(f"{conf:.1f}%")
+
+     df['predicted_class'] = predictions
+     df['confidence'] = confidences
+
+     out = tempfile.NamedTemporaryFile(suffix='_predictions.csv', delete=False)
+     df.to_csv(out.name, index=False)
+     return out.name
+
+
+ EXAMPLES_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "examples")
+
+
+ def load_example_binary():
+     """Load the binary classification example dataset."""
+     return (
+         os.path.join(EXAMPLES_DIR, "binary", "train.tsv"),
+         os.path.join(EXAMPLES_DIR, "binary", "dev.tsv"),
+         os.path.join(EXAMPLES_DIR, "binary", "test.tsv"),
+         "Binary",
+     )
+
+
+ def load_example_multiclass():
+     """Load the multiclass classification example dataset."""
+     return (
+         os.path.join(EXAMPLES_DIR, "multiclass", "train.tsv"),
+         os.path.join(EXAMPLES_DIR, "multiclass", "dev.tsv"),
+         os.path.join(EXAMPLES_DIR, "multiclass", "test.tsv"),
+         "Multiclass",
+     )
+
+
+ def run_comparison(
+     train_file, dev_file, test_file, task_type, selected_models,
+     epochs, batch_size, lr,
+     progress=gr.Progress(track_tqdm=True),
+ ):
+     """Train multiple models on the same data and compare performance + ROC curves."""
+     import plotly.graph_objects as go
+     from plotly.subplots import make_subplots
+
+     empty = ("", None, None, None, gr.Column(visible=False))
+     try:
+         if not selected_models or len(selected_models) < 2:
+             return ("Select at least 2 models to compare.",) + empty[1:]
+         if train_file is None or dev_file is None or test_file is None:
+             return ("Upload all 3 data files first.",) + empty[1:]
+
+         epochs = int(epochs)
+         batch_size = int(batch_size)
+
+         train_texts, train_labels, n_train = parse_data_file(train_file)
+         dev_texts, dev_labels, n_dev = parse_data_file(dev_file)
+         test_texts, test_labels, n_test = parse_data_file(test_file)
+         num_labels = max(n_train, n_dev, n_test)
+         if task_type == "Binary":
+             num_labels = 2
+
+         # Only keep these metrics for the table and bar chart
+         if task_type == "Binary":
+             keep_metrics = {'Accuracy', 'Precision', 'Recall', 'F1'}
+         else:
+             keep_metrics = {
+                 'Accuracy', 'F1 Macro', 'F1 Micro',
+                 'Precision Macro', 'Recall Macro',
+             }
+
+         results = []
+         roc_data = {}  # model_name -> (true_labels, probabilities)
+         log_lines = []
+
+         for i, model_display_name in enumerate(selected_models):
+             model_id = FINETUNE_MODELS[model_display_name]
+             short_name = model_display_name.split(" (")[0]
+             log_lines.append(f"[{i + 1}/{len(selected_models)}] Training {short_name}...")
+
+             try:
+                 tokenizer = AutoTokenizer.from_pretrained(model_id)
+                 model = AutoModelForSequenceClassification.from_pretrained(
+                     model_id, num_labels=num_labels,
+                 )
+                 train_ds = TextClassificationDataset(train_texts, train_labels, tokenizer, 512)
+                 dev_ds = TextClassificationDataset(dev_texts, dev_labels, tokenizer, 512)
+                 test_ds = TextClassificationDataset(test_texts, test_labels, tokenizer, 512)
+
+                 output_dir = tempfile.mkdtemp(prefix='conflibert_cmp_')
+                 best_metric = 'f1' if task_type == 'Binary' else 'f1_macro'
+
+                 training_args = TrainingArguments(
+                     output_dir=output_dir,
+                     num_train_epochs=epochs,
+                     per_device_train_batch_size=batch_size,
+                     per_device_eval_batch_size=batch_size * 2,
+                     learning_rate=lr,
+                     weight_decay=0.01,
+                     warmup_ratio=0.1,
+                     eval_strategy='epoch',
+                     save_strategy='epoch',
+                     load_best_model_at_end=True,
+                     metric_for_best_model=best_metric,
+                     greater_is_better=True,
+                     logging_steps=50,
+                     save_total_limit=1,
+                     report_to='none',
+                     seed=42,
+                 )
+
+                 trainer = Trainer(
+                     model=model,
+                     args=training_args,
+                     train_dataset=train_ds,
+                     eval_dataset=dev_ds,
+                     compute_metrics=make_compute_metrics(task_type),
+                 )
+
+                 train_result = trainer.train()
+
+                 # Get predictions for ROC curves
+                 pred_output = trainer.predict(test_ds)
+                 logits = pred_output.predictions
+                 true_labels = pred_output.label_ids
+                 probs = torch.softmax(torch.tensor(logits), dim=1).numpy()
+                 roc_data[short_name] = (true_labels, probs)
+
+                 # Collect classification metrics only
+                 test_results = trainer.evaluate(test_ds, metric_key_prefix='test')
+                 row = {'Model': short_name}
+                 for k, v in sorted(test_results.items()):
+                     if not isinstance(v, (int, float, np.floating, np.integer)):
+                         continue
+                     name = k.replace('test_', '').replace('_', ' ').title()
+                     if name in keep_metrics:
+                         row[name] = round(float(v), 4)
+                 results.append(row)
+
+                 runtime = train_result.metrics.get('train_runtime', 0)
+                 log_lines.append(f" Done in {runtime:.1f}s")
+
+                 del model, trainer, tokenizer, train_ds, dev_ds, test_ds
+                 gc.collect()
+                 if torch.cuda.is_available():
+                     torch.cuda.empty_cache()
+
+             except Exception as e:
+                 log_lines.append(f" Failed: {str(e)}")
+
+         log_lines.append(f"\nComparison complete. {len(results)} models evaluated.")
+         log_text = "\n".join(log_lines)
+
+         if not results:
+             return log_text, None, None, None, gr.Column(visible=False)
+
+         comparison_df = pd.DataFrame(results)
+
+         # --- Bar chart: classification metrics only ---
+         metric_cols = [c for c in comparison_df.columns if c in keep_metrics]
+         colors = ['#ff6b35', '#3b82f6', '#10b981', '#8b5cf6', '#f59e0b']
+         fig_bar = go.Figure()
+         for j, metric in enumerate(metric_cols):
+             fig_bar.add_trace(go.Bar(
+                 name=metric,
+                 x=comparison_df['Model'],
+                 y=comparison_df[metric],
+                 text=comparison_df[metric].apply(
+                     lambda x: f'{x:.3f}' if isinstance(x, float) else ''
+                 ),
+                 textposition='auto',
+                 marker_color=colors[j % len(colors)],
+             ))
+         fig_bar.update_layout(
+             barmode='group',
+             yaxis_title='Score', yaxis_range=[0, 1.05],
+             template='plotly_white',
+             legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1),
+             height=400, margin=dict(t=40, b=40),
+         )
+
+         # --- ROC curves ---
+         model_colors = ['#ff6b35', '#3b82f6', '#10b981', '#8b5cf6',
+                         '#f59e0b', '#ec4899', '#06b6d4']
+         fig_roc = go.Figure()
+         for j, (model_name, (labels, probs)) in enumerate(roc_data.items()):
+             color = model_colors[j % len(model_colors)]
+             if num_labels == 2:
+                 fpr, tpr, _ = roc_curve(labels, probs[:, 1])
+                 roc_auc_val = sk_auc(fpr, tpr)
+                 fig_roc.add_trace(go.Scatter(
+                     x=fpr, y=tpr, mode='lines',
+                     name=f'{model_name} (AUC = {roc_auc_val:.3f})',
+                     line=dict(color=color, width=2),
+                 ))
+             else:
+                 # Macro-average ROC for multiclass
+                 labels_bin = label_binarize(labels, classes=list(range(num_labels)))
+                 all_fpr = np.linspace(0, 1, 200)
+                 mean_tpr = np.zeros_like(all_fpr)
+                 for c in range(num_labels):
+                     fpr_c, tpr_c, _ = roc_curve(labels_bin[:, c], probs[:, c])
+                     mean_tpr += np.interp(all_fpr, fpr_c, tpr_c)
+                 mean_tpr /= num_labels
+                 roc_auc_val = sk_auc(all_fpr, mean_tpr)
+                 fig_roc.add_trace(go.Scatter(
+                     x=all_fpr, y=mean_tpr, mode='lines',
+                     name=f'{model_name} (macro AUC = {roc_auc_val:.3f})',
+                     line=dict(color=color, width=2),
+                 ))
+
+         # Diagonal reference line
+         fig_roc.add_trace(go.Scatter(
+             x=[0, 1], y=[0, 1], mode='lines',
+             line=dict(dash='dash', color='#ccc', width=1),
+             showlegend=False,
+         ))
+         fig_roc.update_layout(
+             xaxis_title='False Positive Rate',
+             yaxis_title='True Positive Rate',
+             template='plotly_white',
+             legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1),
+             height=400, margin=dict(t=40, b=40),
+         )
+
+         return log_text, comparison_df, fig_bar, fig_roc, gr.Column(visible=True)
+
+     except Exception as e:
+         return f"Comparison failed: {str(e)}", None, None, None, gr.Column(visible=False)
+
+
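For multiclass runs, `run_comparison` builds a macro-average ROC by interpolating each per-class curve onto a shared FPR grid and averaging the TPRs. A small numpy sketch of that construction with made-up curves (the AUC here is a hand-rolled trapezoid rule standing in for `sk_auc`):

```python
import numpy as np

# Hypothetical per-class ROC curves as (FPR, TPR) arrays -- not model output.
curves = [
    (np.array([0.0, 0.5, 1.0]), np.array([0.0, 0.8, 1.0])),
    (np.array([0.0, 0.25, 1.0]), np.array([0.0, 0.6, 1.0])),
]

# Interpolate every curve onto one shared FPR grid, then average the TPRs.
grid = np.linspace(0, 1, 5)
mean_tpr = np.zeros_like(grid)
for fpr, tpr in curves:
    mean_tpr += np.interp(grid, fpr, tpr)
mean_tpr /= len(curves)

# Trapezoid-rule AUC of the averaged curve
auc = float(((mean_tpr[1:] + mean_tpr[:-1]) / 2 * np.diff(grid)).sum())
```

Interpolating onto a common grid is what makes the per-class TPRs addable: raw `roc_curve` outputs have different threshold points for every class.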
+ # ============================================================================
+ # THEME & CSS
+ # ============================================================================
+
+ utd_orange = gr.themes.Color(
+     c50="#fff7f3", c100="#ffead9", c200="#ffd4b3", c300="#ffb380",
+     c400="#ff8c52", c500="#ff6b35", c600="#e8551f", c700="#c2410c",
+     c800="#9a3412", c900="#7c2d12", c950="#431407",
+ )
+
+ theme = gr.themes.Soft(
+     primary_hue=utd_orange,
+     secondary_hue="neutral",
+     font=gr.themes.GoogleFont("Inter"),
+ )
+
+ custom_css = """
+ /* Top accent bar */
+ .gradio-container::before {
+     content: '';
+     display: block;
+     height: 4px;
+     background: linear-gradient(90deg, #ff6b35, #ff9f40, #ff6b35);
+     position: fixed;
+     top: 0;
+     left: 0;
+     right: 0;
+     z-index: 1000;
+ }
+
+ /* Active tab styling */
+ .tab-nav button.selected {
+     border-bottom-color: #ff6b35 !important;
+     color: #ff6b35 !important;
+     font-weight: 600 !important;
+ }
+
+ /* Log output - monospace */
+ .log-output textarea {
+     font-family: 'JetBrains Mono', 'Fira Code', 'Consolas', monospace !important;
+     font-size: 0.8rem !important;
+     line-height: 1.5 !important;
+ }
+
+ /* Dark mode: info callout adjustment */
+ .dark .info-callout-inner {
+     background: rgba(255, 107, 53, 0.1) !important;
+     color: #ffead9 !important;
+ }
+
+ /* Clean container width */
+ .gradio-container {
+     max-width: 1200px !important;
+ }
+
+ /* Smooth transitions */
+ .gradio-container * {
+     transition: background-color 0.2s ease, border-color 0.2s ease !important;
+ }
+ """
+
+
+ # ============================================================================
+ # GRADIO UI
+ # ============================================================================
+
+ with gr.Blocks(theme=theme, css=custom_css, title="ConfliBERT") as demo:
+
+     # ---- HEADER ----
+     gr.Markdown(
+         "<div style='text-align: center; padding: 1.5rem 0 0.5rem;'>"
+         "<h1 style='font-size: 2.5rem; font-weight: 800; margin: 0;'>"
+         "<a href='https://eventdata.utdallas.edu/conflibert/' target='_blank' "
+         "style='color: #ff6b35; text-decoration: none;'>ConfliBERT</a></h1>"
+         "<p style='color: #888; font-size: 0.95rem; margin: 0.25rem 0 0;'>"
+         "A Pretrained Language Model for Conflict and Political Violence</p></div>"
+     )
+
+     with gr.Tabs():
+
+         # ================================================================
+         # HOME TAB
+         # ================================================================
+         with gr.Tab("Home"):
+             gr.Markdown(
+                 "## Welcome to ConfliBERT\n\n"
+                 "ConfliBERT is a pretrained language model built specifically for "
+                 "conflict and political violence text. This application lets you "
+                 "run inference with ConfliBERT's pretrained models and fine-tune "
+                 "your own classifiers on custom data. Use the tabs above to get started."
+             )
+
+             with gr.Row(equal_height=True):
+                 with gr.Column():
+                     gr.Markdown(
+                         "### Inference\n\n"
+                         "Run pretrained ConfliBERT models on your text. "
+                         "Each task has its own tab with single-text analysis "
+                         "and CSV batch processing.\n\n"
+                         "**Named Entity Recognition**\n"
+                         "Identify persons, organizations, locations, weapons, "
+                         "and other entities in text. Results are color-coded "
+                         "by entity type.\n\n"
+                         "**Binary Classification**\n"
+                         "Determine whether text is related to conflict, violence, "
+                         "or politics (positive) or not (negative). You can also "
+                         "load a custom fine-tuned model here.\n\n"
+                         "**Multilabel Classification**\n"
+                         "Score text against four event categories: Armed Assault, "
+                         "Bombing/Explosion, Kidnapping, and Other. Each category "
+                         "is scored independently.\n\n"
+                         "**Question Answering**\n"
+                         "Provide a context passage and ask a question. The model "
+                         "extracts the most relevant answer span from the text."
+                     )
+                 with gr.Column():
+                     gr.Markdown(
+                         "### Fine-tuning\n\n"
+                         "Train your own binary or multiclass text classifier "
+                         "on custom labeled data, all within the browser.\n\n"
+                         "**Workflow:**\n"
+                         "1. Upload your training, validation, and test data as "
+                         "TSV files (or load a built-in example dataset)\n"
+                         "2. Pick a base model: ConfliBERT, BERT, RoBERTa, "
+                         "ModernBERT, DeBERTa, or DistilBERT\n"
+                         "3. Configure training parameters (sensible defaults "
+                         "are provided)\n"
+                         "4. Train and watch progress in real time\n"
+                         "5. Review test-set metrics (accuracy, precision, "
+                         "recall, F1)\n"
+                         "6. Try your model on new text immediately\n"
+                         "7. Run batch predictions on a CSV\n"
+                         "8. Save the model and load it later in the "
+                         "Classification tab\n\n"
+                         "**Advanced features:**\n"
+                         "- Early stopping with configurable patience\n"
+                         "- Learning rate schedulers (linear, cosine, constant)\n"
+                         "- Mixed precision training (FP16 on CUDA GPUs)\n"
+                         "- Gradient accumulation for larger effective batch sizes\n"
+                         "- Weight decay regularization"
+                     )
+
+             gr.Markdown(
+                 f"---\n\n"
+                 f"**Your system:** {get_system_info()}"
+             )
+
+             gr.Markdown(
+                 "**Citation:** Brandt, P.T., Alsarra, S., D'Orazio, V., "
+                 "Heintze, D., Khan, L., Meher, S., Osorio, J. and Sianan, M., "
+                 "2025. Extractive versus Generative Language Models for Political "
+                 "Conflict Text Classification. *Political Analysis*, pp.1-29."
+             )
+
+         # ================================================================
+         # NER TAB
+         # ================================================================
+         with gr.Tab("Named Entity Recognition"):
+             gr.Markdown(info_callout(
+                 "Identify entities in text such as **persons**, **organizations**, "
+                 "**locations**, **weapons**, and more. Results are color-coded by type."
+             ))
+             with gr.Row(equal_height=True):
+                 with gr.Column():
+                     ner_input = gr.Textbox(
+                         lines=6,
+                         placeholder="Paste or type text to analyze for entities...",
+                         label="Input Text",
+                     )
+                     ner_btn = gr.Button("Analyze Entities", variant="primary")
+                 with gr.Column():
+                     ner_output = gr.HTML(label="Results")
+
+             with gr.Accordion("Batch Processing (CSV)", open=False):
+                 gr.Markdown(
+                     "Upload a CSV file with a `text` column to process "
+                     "multiple texts at once."
+                 )
+                 with gr.Row():
+                     ner_csv_in = gr.File(
+                         label="Upload CSV", file_types=[".csv"],
+                     )
+                     ner_csv_out = gr.File(label="Download Results")
+                 ner_csv_btn = gr.Button("Process CSV", variant="secondary")
+
+         # ================================================================
+         # BINARY CLASSIFICATION TAB
+         # ================================================================
+         with gr.Tab("Binary Classification"):
+             gr.Markdown(info_callout(
+                 "Classify text as **conflict-related** (positive) or "
+                 "**not conflict-related** (negative). Uses the pretrained ConfliBERT "
+                 "binary classifier by default, or load your own finetuned model below."
+             ))
+
+             custom_clf_model = gr.State(None)
+             custom_clf_tokenizer = gr.State(None)
+
+             with gr.Row(equal_height=True):
+                 with gr.Column():
+                     clf_input = gr.Textbox(
+                         lines=6,
+                         placeholder="Paste or type text to classify...",
+                         label="Input Text",
+                     )
+                     clf_btn = gr.Button("Classify", variant="primary")
+                 with gr.Column():
+                     clf_output = gr.HTML(label="Results")
+
+             with gr.Accordion("Batch Processing (CSV)", open=False):
+                 gr.Markdown("Upload a CSV file with a `text` column.")
+                 with gr.Row():
+                     clf_csv_in = gr.File(label="Upload CSV", file_types=[".csv"])
+                     clf_csv_out = gr.File(label="Download Results")
+                 clf_csv_btn = gr.Button("Process CSV", variant="secondary")
+
+             with gr.Accordion("Load Custom Model", open=False):
+                 gr.Markdown(
+                     "Load a finetuned classification model from a local directory "
+                     "to use instead of the default pretrained classifier."
+                 )
+                 clf_model_path = gr.Textbox(
+                     label="Model directory path",
+                     placeholder="e.g., ./finetuned_model",
+                 )
+                 with gr.Row():
+                     clf_load_btn = gr.Button("Load Model", variant="secondary")
+                     clf_reset_btn = gr.Button(
+                         "Reset to Pretrained", variant="secondary",
+                     )
+                 clf_status = gr.Markdown("")
+
+         # ================================================================
+         # MULTILABEL CLASSIFICATION TAB
+         # ================================================================
+         with gr.Tab("Multilabel Classification"):
+             gr.Markdown(info_callout(
+                 "Identify multiple event types in text. Each category is scored "
+                 "independently: **Armed Assault**, **Bombing/Explosion**, "
+                 "**Kidnapping**, **Other**. Categories above 50% confidence "
+                 "are highlighted."
+             ))
+             with gr.Row(equal_height=True):
+                 with gr.Column():
+                     multi_input = gr.Textbox(
+                         lines=6,
+                         placeholder="Paste or type text to classify...",
+                         label="Input Text",
+                     )
+                     multi_btn = gr.Button("Classify", variant="primary")
+                 with gr.Column():
+                     multi_output = gr.HTML(label="Results")
+
+             with gr.Accordion("Batch Processing (CSV)", open=False):
+                 gr.Markdown("Upload a CSV file with a `text` column.")
+                 with gr.Row():
+                     multi_csv_in = gr.File(label="Upload CSV", file_types=[".csv"])
+                     multi_csv_out = gr.File(label="Download Results")
+                 multi_csv_btn = gr.Button("Process CSV", variant="secondary")
+
+         # ================================================================
+         # QUESTION ANSWERING TAB
+         # ================================================================
+         with gr.Tab("Question Answering"):
+             gr.Markdown(info_callout(
+                 "Extract answers from a context passage. Provide a paragraph of "
+                 "text and ask a question about it. The model will highlight the "
+                 "most relevant span."
+             ))
+             with gr.Row(equal_height=True):
+                 with gr.Column():
+                     qa_context = gr.Textbox(
+                         lines=6,
+                         placeholder="Paste the context passage here...",
+                         label="Context",
+                     )
+                     qa_question = gr.Textbox(
+                         lines=2,
+                         placeholder="What would you like to know?",
+                         label="Question",
+                     )
+                     qa_btn = gr.Button("Get Answer", variant="primary")
+                 with gr.Column():
+                     qa_output = gr.HTML(label="Answer")
+
1327
+ # ================================================================
1328
+ # FINE-TUNE TAB
1329
+ # ================================================================
1330
+ with gr.Tab("Fine-tune"):
1331
+ gr.Markdown(info_callout(
1332
+ "Fine-tune a binary or multiclass classifier on your own data. "
1333
+ "Upload labeled TSV files, pick a base model, and train. "
1334
+ "Or compare multiple models head-to-head on the same dataset."
1335
+ ))
1336
+
+        # -- Data --
+        gr.Markdown("### Data")
+        gr.Markdown(
+            "TSV files, no header, format: `text[TAB]label` "
+            "(binary: 0/1, multiclass: 0, 1, 2, ...)"
+        )
+        with gr.Row():
+            ft_ex_binary_btn = gr.Button(
+                "Load Example: Binary", variant="secondary", size="sm",
+            )
+            ft_ex_multi_btn = gr.Button(
+                "Load Example: Multiclass (4 classes)", variant="secondary", size="sm",
+            )
+        with gr.Row():
+            ft_train_file = gr.File(
+                label="Train", file_types=[".tsv", ".csv", ".txt"],
+            )
+            ft_dev_file = gr.File(
+                label="Validation", file_types=[".tsv", ".csv", ".txt"],
+            )
+            ft_test_file = gr.File(
+                label="Test", file_types=[".tsv", ".csv", ".txt"],
+            )
+
+        # -- Configuration --
+        gr.Markdown("### Configuration")
+        with gr.Row():
+            ft_task = gr.Radio(
+                ["Binary", "Multiclass"],
+                label="Task Type", value="Binary",
+            )
+            ft_model = gr.Dropdown(
+                choices=list(FINETUNE_MODELS.keys()),
+                label="Base Model",
+                value=list(FINETUNE_MODELS.keys())[0],
+            )
+        with gr.Row():
+            ft_epochs = gr.Number(
+                label="Epochs", value=3, minimum=1, maximum=100, precision=0,
+            )
+            ft_batch = gr.Number(
+                label="Batch Size", value=8, minimum=1, maximum=128, precision=0,
+            )
+            ft_lr = gr.Number(
+                label="Learning Rate", value=2e-5, minimum=1e-7, maximum=1e-2,
+            )
+
+        with gr.Accordion("Advanced Settings", open=False):
+            with gr.Row():
+                ft_weight_decay = gr.Number(
+                    label="Weight Decay", value=0.01, minimum=0, maximum=1,
+                )
+                ft_warmup = gr.Number(
+                    label="Warmup Ratio", value=0.1, minimum=0, maximum=0.5,
+                )
+                ft_max_len = gr.Number(
+                    label="Max Sequence Length", value=512,
+                    minimum=32, maximum=8192, precision=0,
+                )
+            with gr.Row():
+                ft_grad_accum = gr.Number(
+                    label="Gradient Accumulation", value=1,
+                    minimum=1, maximum=64, precision=0,
+                )
+                ft_fp16 = gr.Checkbox(
+                    label="Mixed Precision (FP16)", value=False,
+                )
+                ft_patience = gr.Number(
+                    label="Early Stopping Patience", value=3,
+                    minimum=0, maximum=20, precision=0,
+                )
+                ft_scheduler = gr.Dropdown(
+                    ["linear", "cosine", "constant", "constant_with_warmup"],
+                    label="LR Scheduler", value="linear",
+                )
+
+        # -- Train --
+        ft_train_btn = gr.Button(
+            "Start Training", variant="primary", size="lg",
+        )
+
+        # State for the trained model
+        ft_model_state = gr.State(None)
+        ft_tokenizer_state = gr.State(None)
+        ft_num_labels_state = gr.State(None)
+
+        with gr.Accordion("Training Log", open=False) as ft_log_accordion:
+            ft_log = gr.Textbox(
+                lines=12, interactive=False, elem_classes="log-output",
+                show_label=False,
+            )
+
+        # -- Results + Try Model (hidden until training completes) --
+        with gr.Column(visible=False) as ft_results_col:
+            gr.Markdown("### Results")
+            with gr.Row(equal_height=True):
+                with gr.Column(scale=2):
+                    ft_metrics = gr.Dataframe(
+                        label="Test Set Metrics",
+                        headers=["Metric", "Score"],
+                        interactive=False,
+                    )
+                with gr.Column(scale=3):
+                    gr.Markdown("**Try your model**")
+                    ft_try_input = gr.Textbox(
+                        lines=2, label="Input Text",
+                        placeholder="Type text to classify...",
+                    )
+                    with gr.Row():
+                        ft_try_btn = gr.Button("Predict", variant="primary")
+                    ft_try_output = gr.HTML(label="Prediction")
+
+        # -- Save + Batch (hidden until training completes) --
+        with gr.Column(visible=False) as ft_actions_col:
+            with gr.Row(equal_height=True):
+                with gr.Column():
+                    gr.Markdown("**Save model**")
+                    ft_save_path = gr.Textbox(
+                        label="Save Directory", value="./finetuned_model",
+                    )
+                    ft_save_btn = gr.Button("Save", variant="secondary")
+                    ft_save_status = gr.Markdown("")
+                with gr.Column():
+                    gr.Markdown("**Batch predictions**")
+                    ft_batch_in = gr.File(
+                        label="Upload CSV (needs 'text' column)",
+                        file_types=[".csv"],
+                    )
+                    ft_batch_btn = gr.Button(
+                        "Run Predictions", variant="secondary",
+                    )
+                    ft_batch_out = gr.File(label="Download Results")
+
+        # -- Compare Models --
+        gr.Markdown("---")
+        with gr.Accordion("Compare Multiple Models", open=False):
+            gr.Markdown(
+                "Train different base models on the same dataset and compare "
+                "performance side by side. Uses the data and task type above."
+            )
+            cmp_models = gr.CheckboxGroup(
+                choices=list(FINETUNE_MODELS.keys()),
+                label="Select models to compare (pick 2 or more)",
+            )
+            with gr.Row():
+                cmp_epochs = gr.Number(label="Epochs", value=3, minimum=1, precision=0)
+                cmp_batch = gr.Number(label="Batch Size", value=8, minimum=1, precision=0)
+                cmp_lr = gr.Number(label="Learning Rate", value=2e-5, minimum=1e-7)
+            cmp_btn = gr.Button("Compare Models", variant="primary")
+            cmp_log = gr.Textbox(
+                label="Comparison Log", lines=8,
+                interactive=False, elem_classes="log-output",
+            )
+            with gr.Column(visible=False) as cmp_results_col:
+                cmp_table = gr.Dataframe(
+                    label="Comparison Results", interactive=False,
+                )
+                cmp_plot = gr.Plot(label="Metrics Comparison")
+                cmp_roc = gr.Plot(label="ROC Curves")
+
+    # ---- FOOTER ----
+    gr.Markdown(
+        "<div style='text-align: center; padding: 1rem 0; margin-top: 0.5rem; "
+        "border-top: 1px solid #e5e7eb;'>"
+        "<p style='color: #888; font-size: 0.85rem; margin: 0;'>"
+        "Developed by "
+        "<a href='http://shreyasmeher.com' target='_blank' "
+        "style='color: #ff6b35; text-decoration: none;'>Shreyas Meher</a>"
+        "</p>"
+        "<p style='color: #999; font-size: 0.75rem; margin: 0.5rem 0 0; "
+        "max-width: 700px; margin-left: auto; margin-right: auto; line-height: 1.4;'>"
+        "If you use ConfliBERT in your research, please cite:<br>"
+        "<em>Brandt, P.T., Alsarra, S., D'Orazio, V., Heintze, D., Khan, L., "
+        "Meher, S., Osorio, J. and Sianan, M., 2025. Extractive versus Generative "
+        "Language Models for Political Conflict Text Classification. "
+        "Political Analysis, pp.1&ndash;29.</em>"
+        "</p></div>"
+    )
+
+    # ====================================================================
+    # EVENT HANDLERS
+    # ====================================================================
+
+    # NER
+    ner_btn.click(
+        fn=named_entity_recognition, inputs=[ner_input], outputs=[ner_output],
+    )
+    ner_csv_btn.click(
+        fn=process_csv_ner, inputs=[ner_csv_in], outputs=[ner_csv_out],
+    )
+
+    # Binary Classification
+    clf_btn.click(
+        fn=text_classification,
+        inputs=[clf_input, custom_clf_model, custom_clf_tokenizer],
+        outputs=[clf_output],
+    )
+    clf_csv_btn.click(
+        fn=process_csv_binary,
+        inputs=[clf_csv_in, custom_clf_model, custom_clf_tokenizer],
+        outputs=[clf_csv_out],
+    )
+    clf_load_btn.click(
+        fn=load_custom_model,
+        inputs=[clf_model_path],
+        outputs=[custom_clf_model, custom_clf_tokenizer, clf_status],
+    )
+    clf_reset_btn.click(
+        fn=reset_custom_model,
+        outputs=[custom_clf_model, custom_clf_tokenizer, clf_status],
+    )
+
+    # Multilabel Classification
+    multi_btn.click(
+        fn=multilabel_classification, inputs=[multi_input], outputs=[multi_output],
+    )
+    multi_csv_btn.click(
+        fn=process_csv_multilabel, inputs=[multi_csv_in], outputs=[multi_csv_out],
+    )
+
+    # Question Answering
+    qa_btn.click(
+        fn=question_answering,
+        inputs=[qa_context, qa_question],
+        outputs=[qa_output],
+    )
+
+    # Fine-tuning: example dataset loaders
+    ft_ex_binary_btn.click(
+        fn=load_example_binary,
+        outputs=[ft_train_file, ft_dev_file, ft_test_file, ft_task],
+    )
+    ft_ex_multi_btn.click(
+        fn=load_example_multiclass,
+        outputs=[ft_train_file, ft_dev_file, ft_test_file, ft_task],
+    )
+
+    # Fine-tuning: training
+    ft_train_btn.click(
+        fn=run_finetuning,
+        inputs=[
+            ft_train_file, ft_dev_file, ft_test_file,
+            ft_task, ft_model,
+            ft_epochs, ft_batch, ft_lr,
+            ft_weight_decay, ft_warmup, ft_max_len,
+            ft_grad_accum, ft_fp16, ft_patience, ft_scheduler,
+        ],
+        outputs=[
+            ft_log, ft_metrics,
+            ft_model_state, ft_tokenizer_state, ft_num_labels_state,
+            ft_results_col, ft_actions_col,
+        ],
+        concurrency_limit=1,
+    )
+
+    # Try finetuned model
+    ft_try_btn.click(
+        fn=predict_finetuned,
+        inputs=[ft_try_input, ft_model_state, ft_tokenizer_state, ft_num_labels_state],
+        outputs=[ft_try_output],
+    )
+
+    # Save finetuned model
+    ft_save_btn.click(
+        fn=save_finetuned_model,
+        inputs=[ft_save_path, ft_model_state, ft_tokenizer_state],
+        outputs=[ft_save_status],
+    )
+
+    # Batch predictions with finetuned model
+    ft_batch_btn.click(
+        fn=batch_predict_finetuned,
+        inputs=[ft_batch_in, ft_model_state, ft_tokenizer_state, ft_num_labels_state],
+        outputs=[ft_batch_out],
+    )
+
+    # Model comparison
+    cmp_btn.click(
+        fn=run_comparison,
+        inputs=[
+            ft_train_file, ft_dev_file, ft_test_file,
+            ft_task, cmp_models, cmp_epochs, cmp_batch, cmp_lr,
+        ],
+        outputs=[cmp_log, cmp_table, cmp_plot, cmp_roc, cmp_results_col],
+        concurrency_limit=1,
+    )
+
+
+# ============================================================================
+# LAUNCH
+# ============================================================================
+
+demo.launch(share=True)
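The Fine-tune tab's settings (epochs, batch size, gradient accumulation, warmup ratio) combine into the usual training-loop quantities. A minimal sketch of that arithmetic, using the UI defaults above and the standard `transformers`/`Trainer` conventions (steps per epoch rounded up, warmup as a fraction of total steps) as an assumption rather than anything read from this app's training code:

```python
import math

# Defaults mirrored from the Fine-tune tab above; num_examples matches
# examples/binary/train.tsv. The formulas are the common Trainer convention,
# assumed here for illustration.
num_examples = 81
epochs = 3          # ft_epochs
batch_size = 8      # ft_batch
grad_accum = 1      # ft_grad_accum
warmup_ratio = 0.1  # ft_warmup

effective_batch = batch_size * grad_accum
steps_per_epoch = math.ceil(num_examples / effective_batch)
total_steps = steps_per_epoch * epochs
warmup_steps = int(total_steps * warmup_ratio)

print(effective_batch, steps_per_epoch, total_steps, warmup_steps)
# -> 8 11 33 3
```

Raising Gradient Accumulation trades memory for a larger effective batch at the cost of fewer optimizer steps, which is why the warmup count shrinks with it.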
examples/binary/dev.tsv ADDED
@@ -0,0 +1,20 @@
+ The peacekeeping mission reported increased hostilities along the ceasefire line 1
+ The warlord's militia seized control of the strategically important bridge 1
+ The naval forces intercepted a shipment of weapons destined for the rebels 1
+ The terrorist organization released a video threatening further attacks 1
+ Armed groups continue to recruit child soldiers in violation of international law 1
+ Ethnic cleansing campaigns have forced entire communities to flee their lands 1
+ The arms embargo was violated as new weapons flowed into the conflict zone 1
+ Soldiers exchanged fire with suspected militants near the border crossing 1
+ Multiple explosions were reported near the presidential palace overnight 1
+ The rebel commander announced a new offensive targeting government supply routes 1
+ The fashion designer presented a stunning collection at the annual style week 0
+ The telecommunications company expanded its fiber network to rural areas 0
+ The basketball league announced changes to the playoff format next season 0
+ A new species of butterfly was documented during a biodiversity survey 0
+ The pottery exhibition drew visitors from across the region 0
+ The national library digitized thousands of historical manuscripts 0
+ The airline reported record passenger numbers during the holiday season 0
+ The botanical garden opened a new section dedicated to tropical plants 0
+ Researchers developed a more efficient method for recycling plastic waste 0
+ The chess tournament attracted grandmasters from over thirty countries 0
examples/binary/test.tsv ADDED
@@ -0,0 +1,20 @@
+ The siege of the coastal city entered its fourth month with supplies running low 1
+ Rebel fighters launched rockets at the airport disrupting all flights 1
+ The militia carried out reprisal attacks against civilians in the border towns 1
+ Security forces clashed with armed protesters near the government district 1
+ A landmine explosion killed three children walking to school in the rural area 1
+ The insurgent group released hostages after weeks of tense negotiations 1
+ Mortar rounds struck a hospital compound wounding patients and medical staff 1
+ The warring parties rejected the latest ceasefire proposal from international mediators 1
+ Government troops advanced into rebel-held territory following an intense bombardment 1
+ An assassination attempt on the defense minister was foiled by security services 1
+ The renewable energy conference attracted investors from around the world 0
+ A new underwater cable was laid connecting the island to the mainland grid 0
+ The jazz festival featured performances by legendary musicians and new artists 0
+ The national team qualified for the world championships for the first time 0
+ The organic food market experienced double-digit growth for the third straight year 0
+ The space agency successfully launched a satellite to monitor ocean temperatures 0
+ The vintage car rally drew enthusiasts from neighboring countries 0
+ The children's hospital received a generous donation for its expansion project 0
+ The urban farming initiative transformed abandoned lots into productive gardens 0
+ The documentary about traditional boat-building won the top prize at the festival 0
examples/binary/train.tsv ADDED
@@ -0,0 +1,81 @@
+ Government forces launched an offensive against rebel positions in the northern province early this morning 1
+ A car bomb exploded near a military checkpoint killing at least twelve soldiers 1
+ Insurgents attacked a police station in the capital overnight leaving several officers wounded 1
+ The militant group claimed responsibility for the ambush on a military convoy 1
+ Heavy fighting broke out between rival armed factions in the disputed border region 1
+ Security forces conducted raids targeting suspected members of the armed opposition 1
+ A suicide bomber detonated explosives at a crowded marketplace injuring dozens of civilians 1
+ The rebel forces captured a strategic town after weeks of intense battles 1
+ Artillery shells struck residential areas as the conflict between the two sides intensified 1
+ An airstrike destroyed a weapons depot used by the insurgent group 1
+ The government declared a state of emergency following widespread political violence 1
+ Armed men attacked a village killing several residents and burning homes 1
+ Protesters clashed violently with police during demonstrations against the military regime 1
+ A roadside bomb targeted a military patrol wounding three soldiers 1
+ The armed group kidnapped aid workers operating in the conflict zone 1
+ Sniper fire killed two civilians in the besieged neighborhood 1
+ Military helicopters were deployed to support ground troops fighting in the eastern region 1
+ An explosion at a government building was attributed to opposition fighters 1
+ Cross-border shelling between the two nations continued for the third consecutive day 1
+ Armed bandits attacked a refugee camp displacing thousands of people 1
+ The guerrilla fighters ambushed a supply convoy on the main highway 1
+ The opposing forces exchanged heavy gunfire throughout the night 1
+ A mortar attack on the military base resulted in significant casualties 1
+ Paramilitary groups carried out targeted assassinations of political opponents 1
+ Government aircraft bombed suspected rebel strongholds in the mountainous region 1
+ Two soldiers were killed when their vehicle struck a landmine on a rural road 1
+ The separatist movement launched coordinated attacks on government installations 1
+ Ethnic tensions erupted into open violence as rival communities clashed in the market 1
+ Armed opposition forces shelled the outskirts of the capital city 1
+ A grenade attack on a busy intersection killed four people and wounded many more 1
+ The military junta deployed tanks to suppress the growing resistance movement 1
+ Fighting between government troops and rebels displaced thousands of families 1
+ Security operations intensified after a series of bombings in the commercial district 1
+ A militia group took control of a key oil facility in the contested region 1
+ An improvised explosive device was found near the parliament building 1
+ Coalition forces conducted a night raid capturing several high-value targets 1
+ The ongoing civil war has resulted in thousands of casualties and widespread destruction 1
+ A drone strike targeted a meeting of senior militant commanders 1
+ The opposition forces breached the defensive perimeter around the government compound 1
+ Gunmen opened fire on a convoy of government officials killing two bodyguards 1
+ The national football team secured a convincing victory in the qualifying match 0
+ Temperatures are expected to reach record highs this weekend according to forecasters 0
+ The technology company unveiled its latest smartphone with improved camera capabilities 0
+ Stock markets rallied on news of stronger than expected economic growth 0
+ The film festival announced its lineup featuring works from emerging directors 0
+ Scientists discovered a new species of deep-sea fish in the Pacific Ocean 0
+ The university announced a new scholarship program for students in engineering 0
+ Local farmers reported an excellent harvest this season due to favorable weather 0
+ The city council approved plans for a new public park in the downtown area 0
+ A major software update was released improving performance and adding new features 0
+ The marathon attracted over twenty thousand runners from across the country 0
+ Researchers published findings on a promising treatment for a rare disorder 0
+ The airline announced new direct flights connecting the capital with European cities 0
+ A popular author released the highly anticipated sequel to her bestselling novel 0
+ The automotive company revealed plans to launch three new electric vehicle models 0
+ Annual tourism numbers reached an all-time high at the coastal resorts 0
+ The construction of the new high-speed rail line is ahead of schedule 0
+ Astronomers observed a rare celestial event visible from the southern hemisphere 0
+ The bakery chain announced plans to expand into twelve new locations 0
+ The tech startup raised significant funding in its latest investment round 0
+ The swimming team broke the national record at the regional championships 0
+ A new study found that regular exercise significantly reduces heart disease risk 0
+ The city hosted a successful international food and wine festival 0
+ Archaeologists uncovered ancient pottery at a dig site near the monument 0
+ The pharmaceutical company received approval for a new vaccine formulation 0
+ A local nonprofit organized a community cleanup event at the riverside park 0
+ The solar energy project is expected to power thousands of homes by year end 0
+ The gaming company released a new title that quickly became a bestseller 0
+ Public transit ridership increased following improvements to the subway system 0
+ The annual science fair showcased innovative projects by high school students 0
+ The dairy industry adopted new standards for sustainable milk production 0
+ The orchestra performed a sold-out concert of works by contemporary composers 0
+ The hospital inaugurated a state-of-the-art wing dedicated to pediatric care 0
+ The cycling tour attracted international competitors to the coastal route 0
+ The agricultural ministry launched a program to support organic farming 0
+ A popular streaming service announced an original series based on the classic novel 0
+ The winter ski season opened early due to heavy snowfall in the mountains 0
+ The electric vehicle charging network expanded to cover all major highways 0
+ The oceanographic institute published research on coral reef restoration 0
+ The cookbook featuring traditional regional recipes became an unexpected bestseller 0
+ The museum opened a new exhibition showcasing contemporary sculpture and painting 0
examples/multiclass/dev.tsv ADDED
@@ -0,0 +1,20 @@
+ The peace envoy presented a revised framework for territorial compromise 0
+ Regional leaders convened an emergency session to address the border crisis 0
+ Negotiations on the arms limitation treaty entered their final round 0
+ The joint commission agreed to establish demilitarized buffer zones 0
+ A new diplomatic initiative aimed at ending the decades-long standoff was announced 0
+ Militia forces overran a government checkpoint killing the defenders 1
+ An ambush on the supply column resulted in the loss of critical equipment 1
+ Fighter jets conducted precision strikes on command and control centers 1
+ The battle for the provincial capital intensified with street-to-street fighting 1
+ Landmines planted along the withdrawal route caused additional military casualties 1
+ Pro-democracy activists organized a candlelight vigil outside the detention center 2
+ Transport workers walked off the job paralyzing the rail network 2
+ Demonstrators erected barricades across main roads in defiance of the curfew 2
+ The environmental movement staged protests at industrial sites across the country 2
+ Police fired rubber bullets at stone-throwing youths during the confrontation 2
+ Emergency medical supplies were rushed to the hospital overwhelmed with casualties 3
+ The displaced population established makeshift camps along the roadside 3
+ Aid agencies warned of an impending water crisis in the drought-stricken region 3
+ Rescue teams searched through rubble for survivors after the earthquake 3
+ The nutrition program reached over ten thousand malnourished children this month 3
examples/multiclass/test.tsv ADDED
@@ -0,0 +1,20 @@
+ The multilateral agreement established protocols for maritime dispute resolution 0
+ Both delegations expressed optimism following the fourth round of peace talks 0
+ The international community welcomed the signing of the normalization agreement 0
+ Economic sanctions were partially lifted following compliance with treaty obligations 0
+ The mediation team proposed a phased withdrawal plan accepted by both parties 0
+ The armored column advanced through the valley under heavy enemy fire 1
+ Rockets struck the airfield destroying several military aircraft on the ground 1
+ Special operations forces carried out a raid deep behind enemy lines 1
+ The naval blockade prevented resupply of the besieged coastal garrison 1
+ Anti-aircraft fire downed a reconnaissance drone over the contested territory 1
+ Massive crowds filled the boulevard demanding free and fair elections 2
+ The dock workers union expanded its strike to include all major ports 2
+ Indigenous communities organized roadblocks to protest land seizure by corporations 2
+ Thousands of women marched demanding an end to gender-based violence 2
+ Student groups staged walkouts across dozens of universities nationwide 2
+ The refugee crisis deepened as thousands more fled across the border overnight 3
+ Field hospitals operated at capacity treating both military and civilian wounded 3
+ The humanitarian airlift delivered critical medical supplies to the isolated town 3
+ Sanitation conditions in the overcrowded camp raised fears of disease outbreaks 3
+ International rescue teams deployed to assist with flood evacuation efforts 3
examples/multiclass/train.tsv ADDED
@@ -0,0 +1,80 @@
+ The two nations signed a bilateral trade agreement during the summit meeting 0
+ Foreign ministers met to discuss the terms of the proposed peace deal 0
+ The United Nations Security Council passed a resolution imposing new sanctions 0
+ Diplomatic envoys were dispatched to mediate between the conflicting parties 0
+ The ambassador presented new proposals for resolving the territorial dispute 0
+ A ceasefire agreement was brokered by regional mediators after weeks of talks 0
+ The peace conference concluded with a joint declaration of mutual cooperation 0
+ International observers praised the diplomatic progress made at the negotiations 0
+ The two governments established formal diplomatic relations for the first time 0
+ Trade negotiations between the economic bloc and the developing nation resumed 0
+ The foreign affairs committee approved the new bilateral defense cooperation pact 0
+ A high-level delegation arrived in the capital for talks on nuclear disarmament 0
+ The treaty on maritime boundaries was ratified by both nations parliaments 0
+ International mediators proposed a roadmap for political transition and elections 0
+ Leaders of the rival factions agreed to power-sharing arrangements at the talks 0
+ The alliance pledged continued diplomatic support for the peace process 0
+ An arms control agreement was reached limiting missile deployments in the region 0
+ The special envoy shuttled between capitals seeking a breakthrough in negotiations 0
+ Both sides agreed to exchange prisoners as a confidence-building measure 0
+ The summit produced a framework for addressing cross-border resource disputes 0
+ Government forces launched a major offensive against rebel positions in the east 1
+ A car bomb exploded near the military headquarters killing at least eight people 1
+ Insurgents ambushed a military convoy destroying several armored vehicles 1
+ The air force conducted airstrikes on militant camps in the mountain region 1
+ Heavy fighting erupted between government troops and separatist fighters 1
+ A roadside bomb killed five soldiers on patrol near the contested border area 1
+ The armed group captured a military outpost after a fierce overnight battle 1
+ Artillery barrages devastated residential neighborhoods in the besieged city 1
+ Coalition warplanes struck weapons storage facilities operated by the militia 1
+ Two helicopter gunships were shot down during combat operations in the valley 1
+ The rebel offensive resulted in the capture of three strategic hilltop positions 1
+ A suicide attack on the army base left dozens of soldiers dead or wounded 1
+ Snipers targeted civilians attempting to flee the fighting in the urban center 1
+ Naval forces engaged enemy vessels in a brief but intense exchange of fire 1
+ The battalion suffered heavy losses during the assault on the fortified position 1
+ Ground forces advanced under cover of sustained aerial bombardment 1
+ Mortar fire struck the refugee camp adjacent to the front lines 1
+ Tank divisions moved into position along the disputed ceasefire line 1
+ The garrison surrendered after running out of ammunition during the prolonged siege 1
+ Drone surveillance identified enemy troop movements ahead of the counterattack 1
+ Thousands of demonstrators gathered in the central square demanding government reform 2
+ Riot police used tear gas to disperse crowds outside the parliament building 2
+ Workers across the industrial sector launched a nationwide general strike 2
+ Student protesters occupied the university administration building for three days 2
+ The opposition organized mass rallies in cities across the country 2
+ Police arrested dozens of activists during an unauthorized march through downtown 2
+ Labor unions called for indefinite strikes to protest proposed austerity measures 2
+ Demonstrators blocked major highways disrupting transportation and commerce 2
+ Anti-government protests entered their second week with no sign of subsiding 2
+ The youth movement organized sit-ins at government offices across the region 2
+ Protesters set fire to government vehicles during clashes in the capital 2
+ Civil society groups staged peaceful vigils demanding the release of political prisoners 2
+ Tens of thousands marched against corruption in the largest demonstration in years 2
+ The teachers union voted to strike over pay cuts and deteriorating school conditions 2
+ Farmers drove tractors into the city center to protest agricultural subsidy reductions 2
+ Activists chained themselves to the gates of the energy ministry 2
+ The pro-democracy movement announced plans for a week of sustained civil disobedience 2
+ Shopkeepers shuttered their businesses in solidarity with the striking workers 2
+ University students clashed with security forces during a campus demonstration 2
+ Residents organized neighborhood protests against the proposed construction project 2
+ International aid organizations delivered food supplies to the displaced population 3
+ The refugee camp expanded rapidly as thousands fled the advancing front lines 3
+ Medical teams set up field hospitals to treat civilians injured in the crossfire 3
+ The humanitarian corridor allowed evacuation of wounded from the conflict zone 3
+ Food shortages reached critical levels as supply routes remained blocked 3
+ Emergency shelters were established for families displaced by the flooding 3
+ Aid workers distributed clean water and sanitation supplies to the affected areas 3
+ The World Health Organization launched a vaccination campaign in the crisis region 3
+ Thousands of refugees crossed the border seeking safety in neighboring countries 3
+ The famine early warning system indicated severe food insecurity in the southern region 3
+ Humanitarian agencies appealed for additional funding to support relief operations 3
+ Displaced families struggled to find shelter as winter temperatures dropped sharply 3
+ The Red Cross established a blood donation drive to support overwhelmed hospitals 3
+ Emergency food rations were airlifted to communities cut off by the fighting 3
+ Child protection agencies reported a surge in unaccompanied minors at border crossings 3
+ The cholera outbreak in the camp prompted an emergency public health response 3
+ International donors pledged millions in reconstruction aid at the conference 3
+ Volunteer groups organized clothing and supply drives for the disaster survivors 3
+ The malnutrition rate among children under five reached alarming levels 3
+ Mobile clinics provided medical care to remote communities affected by the crisis 3
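The example files above follow the format the Fine-tune tab asks for: TSV, no header, one example per line as `text[TAB]label`. A small stdlib-only sketch of writing and reading such a file (the sample rows are shortened paraphrases of the data above, for illustration only):

```python
import csv
import os
import tempfile

# Two rows in the Fine-tune tab's expected shape: text<TAB>label, no header.
rows = [
    ("Government forces launched an offensive against rebel positions", 1),
    ("The museum opened a new exhibition of contemporary sculpture", 0),
]

# Write the TSV the way the app expects to receive it.
path = os.path.join(tempfile.mkdtemp(), "train.tsv")
with open(path, "w", newline="", encoding="utf-8") as f:
    csv.writer(f, delimiter="\t").writerows(rows)

# Read it back the way a loader might, coercing labels to ints.
with open(path, encoding="utf-8") as f:
    data = [(text, int(label)) for text, label in csv.reader(f, delimiter="\t")]
print(data[0][1], data[1][1])  # -> 1 0
```

For multiclass data the only change is the label range (0, 1, 2, ...), matching the four-class example files.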
requirements.txt CHANGED
@@ -1,5 +1,9 @@
- torch
- tensorflow
- transformers
- gradio
- tf-keras
+ torch
+ tensorflow
+ transformers
+ gradio
+ tf-keras
+ accelerate
+ scikit-learn
+ pandas
+ plotly
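This commit adds `accelerate`, `scikit-learn`, `pandas`, and `plotly` to support the fine-tuning and comparison features. A quick stdlib-only sketch for checking that the new dependencies are importable in the current environment (note that the scikit-learn distribution imports as `sklearn`; `tf-keras` is omitted here since its import name differs from the package name):

```python
import importlib.util

# Import-time module names for the packages in requirements.txt;
# the mapping from distribution name to import name is an assumption
# noted above, not something the app checks itself.
required = ["torch", "tensorflow", "transformers", "gradio",
            "accelerate", "sklearn", "pandas", "plotly"]

missing = [mod for mod in required if importlib.util.find_spec(mod) is None]
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All fine-tuning dependencies are available.")
```

Running this before launching the app surfaces a missing dependency as a clear message instead of an import traceback at startup.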
screenshots/classification.png ADDED

Git LFS Details

  • SHA256: 34a798a8b7e06780d86b9c5db54b8e52d36ad4ac890f6463b57268e9cf18f17c
  • Pointer size: 131 Bytes
  • Size of remote file: 329 kB
screenshots/finetune.png ADDED

Git LFS Details

  • SHA256: 8c23081d43a14e90a77e2a201ac87b4b38216dda1dd54c6ba74b4540862e8ee7
  • Pointer size: 131 Bytes
  • Size of remote file: 488 kB
screenshots/home.png ADDED

Git LFS Details

  • SHA256: 2f2441498b30bc603f53b132539740780e6e830dd578a0b1d9e9e12693d07022
  • Pointer size: 131 Bytes
  • Size of remote file: 806 kB
screenshots/multilabel.png ADDED

Git LFS Details

  • SHA256: a665f7339d5f87a82d1211c44eabc8d291a0f6d3b4c6a8c105cc6d4af58e8ce0
  • Pointer size: 131 Bytes
  • Size of remote file: 321 kB
screenshots/ner.png ADDED

Git LFS Details

  • SHA256: c7d632e3c808d7615cf80205399d751d3664d836ab32ab4f5589d925f2aaa370
  • Pointer size: 131 Bytes
  • Size of remote file: 305 kB
screenshots/qa.png ADDED

Git LFS Details

  • SHA256: ba5742b7fccd00dae57e487181b005abf6a18c3d31a44c935e57b224471786cd
  • Pointer size: 131 Bytes
  • Size of remote file: 310 kB