# label_tools/label_tool.py
# (uploaded by DuongTrongChi, commit bc13e62)
"""
Argilla Data Labeling Tool
This tool helps you create a labeling interface for data with the following structure:
{
"text": "sample text",
"region_label": "Central/North/South",
"task_label": clean/offensive/hate
}
Usage:
1. First, set up and run Argilla (see README.md)
2. Run this script to create a dataset and upload your data
"""
import argparse
import json
from pathlib import Path
import argilla as rg
from typing import List, Dict, Optional
# Label mappings for converting between string labels and numeric IDs

# Toxicity labels for the "task" question (the `task_label` field in raw data).
LABEL_MAP = {
    "clean": 0,
    "offensive": 1,
    "hate": 2
}

# Region labels for the "region" question (the `region_label` field in raw data).
REGION_MAP = {
    "North": 0,
    "Central": 1,
    "South": 2,
    "Unknown": 3
}

# Reverse mappings (numeric ID -> string name), used to normalize numeric input labels
LABEL_ID_TO_NAME = {v: k for k, v in LABEL_MAP.items()}
REGION_ID_TO_NAME = {v: k for k, v in REGION_MAP.items()}
def load_jsonl(file_path: str) -> List[Dict]:
    """Parse a JSONL file and return its records, skipping blank lines."""
    with open(file_path, 'r', encoding='utf-8') as handle:
        return [json.loads(raw) for raw in map(str.strip, handle) if raw]
def create_dataset(
    client: rg.Argilla,
    dataset_name: str,
    workspace: str = "argilla"
) -> rg.Dataset:
    """
    Fetch, or create if missing, an Argilla dataset configured for text
    classification with region and task labels.

    Args:
        client: Authenticated Argilla client
        dataset_name: Name for the dataset
        workspace: Workspace name (default: "argilla")

    Returns:
        The existing or newly created Argilla dataset

    Note:
        In the Argilla v2 SDK, client.datasets(...) returns None rather than
        raising when the dataset does not exist, so existence is checked with
        an explicit None test (the old try/except alone never took the
        "create" branch and could return None).
    """
    # Reuse the dataset if it already exists.
    try:
        dataset = client.datasets(name=dataset_name, workspace=workspace)
    except Exception:
        dataset = None
    if dataset is not None:
        print(f"Dataset '{dataset_name}' already exists. Using existing dataset.")
        return dataset

    # Define the settings for the new dataset
    settings = rg.Settings(
        # Fields: the data to be displayed and annotated
        fields=[
            rg.TextField(
                name="text",
                title="Text",
                required=True,
                use_markdown=False
            ),
        ],
        # Questions: what we want to annotate
        questions=[
            # Region labeling (multi-class)
            rg.LabelQuestion(
                name="region",
                title="Region Label",
                labels=["North", "Central", "South", "Unknown"],
                required=True
            ),
            # Task labeling (multi-class: toxicity level)
            rg.LabelQuestion(
                name="task",
                title="Task Label",
                labels=["clean", "offensive", "hate"],
                required=True
            ),
        ],
        # Metadata: additional info to display
        metadata_properties=[
            rg.TermsMetadataProperty(
                name="original_region",
                title="Original Region Label"
            ),
            # Terms (string) property: upload_data stores the original task
            # label stringified, which an IntegerMetadataProperty would reject.
            rg.TermsMetadataProperty(
                name="original_task",
                title="Original Task Label"
            ),
        ]
    )

    # Create the new dataset. In the v2 SDK datasets are created by
    # instantiating rg.Dataset and calling .create(), not via client.datasets().
    print(f"Creating new dataset '{dataset_name}'...")
    dataset = rg.Dataset(
        name=dataset_name,
        workspace=workspace,
        settings=settings,
        client=client
    )
    dataset.create()
    print(f"Dataset '{dataset_name}' created successfully!")
    return dataset
def normalize_label(label, id_to_name_map, label_map):
    """Return the canonical string name for *label*.

    Accepts three input shapes: a numeric ID (int), a canonical string name,
    or a numeric string such as "2". Unmapped numeric inputs yield "Unknown";
    an unrecognized non-numeric string is returned unchanged; any other type
    yields "Unknown".
    """
    if isinstance(label, int):
        return id_to_name_map.get(label, "Unknown")
    if not isinstance(label, str):
        return "Unknown"
    if label in label_map:
        # Already a valid canonical name.
        return label
    try:
        # Numeric string, e.g. "1" -> ID lookup.
        return id_to_name_map.get(int(label), "Unknown")
    except ValueError:
        return label
def upload_data(
    dataset: rg.Dataset,
    data: List[Dict],
    batch_size: int = 100
):
    """
    Upload data to the Argilla dataset.

    Args:
        dataset: Argilla dataset
        data: List of records with text, region_label, task_label
        batch_size: Number of records to upload at once

    Note:
        Accepts both string labels (e.g., "clean", "North") and
        numeric IDs (e.g., 0, 1, 2) for region_label and task_label.
    """
    print(f"Preparing to upload {len(data)} records...")

    records = []
    for item in data:
        # Normalize labels (handle both string and numeric input)
        region_name = normalize_label(
            item.get("region_label", "Unknown"),
            REGION_ID_TO_NAME,
            REGION_MAP
        )
        task_name = normalize_label(
            item.get("task_label", "clean"),
            LABEL_ID_TO_NAME,
            LABEL_MAP
        )

        rec = rg.Record(
            fields={"text": item["text"]},
            metadata={
                "original_region": str(item.get("region_label", "")),
                "original_task": str(item.get("task_label", ""))
            }
        )

        # Attach known labels as pre-annotation suggestions ("Unknown" is
        # treated as "no label known" and skipped).
        pre_labels = [
            rg.Suggestion(question_name=question, value=value, agent="pre-label")
            for question, value in (("region", region_name), ("task", task_name))
            if value != "Unknown"
        ]
        if pre_labels:
            rec.suggestions = pre_labels
        records.append(rec)

    # Upload records in batches
    total_batches = -(-len(records) // batch_size)  # ceiling division
    for start in range(0, len(records), batch_size):
        batch = records[start:start + batch_size]
        dataset.records.log(batch)
        batch_num = start // batch_size + 1
        print(f"Uploaded batch {batch_num}/{total_batches} ({len(batch)} records)")
    print(f"Successfully uploaded {len(records)} records to dataset!")
def export_labeled_data(
    dataset: rg.Dataset,
    output_path: str,
    format: str = "jsonl"
):
    """
    Export labeled data from Argilla dataset to JSONL file with numeric IDs.

    Only records that have at least one human response are exported; records
    with suggestions but no responses are skipped.

    Args:
        dataset: Argilla dataset
        output_path: Path to output file
        format: Output format ("jsonl" or "json")
    """
    print(f"Exporting labeled data from '{dataset.name}'...")

    # NOTE(review): in the Argilla v2 SDK, dataset.records.to_list() returns
    # plain dicts, while the loop below uses attribute access
    # (record.responses, record.fields) — confirm which SDK version this
    # script targets.
    records = dataset.records.to_list()
    print(f"Found {len(records)} records")

    output_data = []
    for record in records:
        # Get responses (annotations)
        responses = record.responses
        if not responses:
            # No annotations yet, skip or use suggestions
            continue
        # Use the latest response
        latest_response = responses[-1]
        # Get labels from response
        # NOTE(review): assumes each response exposes a `values` mapping keyed
        # by question name — verify against the installed argilla version.
        region_name = latest_response.values.get("region")
        task_name = latest_response.values.get("task")
        # Convert to numeric IDs
        region_id = REGION_MAP.get(region_name, 3)  # Default to Unknown
        task_id = LABEL_MAP.get(task_name, 0)  # Default to clean
        output_data.append({
            "text": record.fields["text"],
            "region_label": region_name,
            "region_id": region_id,
            "task_label": task_name,
            "task_id": task_id
        })

    # Write to file
    if format == "jsonl":
        with open(output_path, 'w', encoding='utf-8') as f:
            for item in output_data:
                f.write(json.dumps(item, ensure_ascii=False) + '\n')
    else:  # json
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, ensure_ascii=False, indent=2)
    print(f"Exported {len(output_data)} labeled records to '{output_path}'")
def _add_connection_args(subparser):
    """Attach the connection/dataset options shared by both subcommands."""
    subparser.add_argument(
        "--api-url",
        type=str,
        default=None,
        help="Argilla API URL (e.g., http://localhost:6900 or https://your-space.hf.space)"
    )
    subparser.add_argument(
        "--api-key",
        type=str,
        default=None,
        help="Argilla API key (found in My Settings page)"
    )
    subparser.add_argument(
        "--dataset-name",
        type=str,
        default="text_labeling_dataset",
        help="Name for the Argilla dataset (default: text_labeling_dataset)"
    )
    subparser.add_argument(
        "--workspace",
        type=str,
        default="argilla",
        help="Argilla workspace name (default: argilla)"
    )


def main():
    """CLI entry point: upload data to Argilla for labeling, or export labeled data.

    Returns:
        Process exit code: 0 on success, 1 on error (missing file,
        connection failure, or missing dataset).
    """
    parser = argparse.ArgumentParser(
        description="Upload data to Argilla for labeling or export labeled data"
    )
    subparsers = parser.add_subparsers(dest="mode", help="Mode: upload or export")

    # Upload mode
    upload_parser = subparsers.add_parser("upload", help="Upload data to Argilla")
    upload_parser.add_argument(
        "jsonl_file",
        type=str,
        help="Path to the JSONL file containing data to label"
    )
    _add_connection_args(upload_parser)

    # Export mode
    export_parser = subparsers.add_parser("export", help="Export labeled data from Argilla")
    export_parser.add_argument(
        "output_file",
        type=str,
        help="Path to output file (e.g., labeled_data.jsonl)"
    )
    _add_connection_args(export_parser)
    export_parser.add_argument(
        "--format",
        type=str,
        default="jsonl",
        choices=["jsonl", "json"],
        help="Output format (default: jsonl)"
    )

    args = parser.parse_args()

    # Default to upload mode if not specified (backward compatibility)
    if args.mode is None:
        import sys
        # Re-parse with the upload subcommand injected before the positionals
        sys.argv.insert(1, "upload")
        args = parser.parse_args()

    # Connect to Argilla
    if args.api_url and args.api_key:
        print(f"Connecting to Argilla at {args.api_url}...")
        client = rg.Argilla(
            api_url=args.api_url,
            api_key=args.api_key
        )
    else:
        print("Connecting to Argilla (using default or env variables)...")
        print("Set ARGILLA_API_URL and ARGILLA_API_KEY environment variables if not using defaults")
        client = rg.Argilla()

    # Verify connection. In the Argilla v2 SDK `me` is a property, not a
    # method, so it must not be called.
    try:
        me = client.me
        print(f"Connected as: {me.username}")
    except Exception as e:
        print(f"Error connecting to Argilla: {e}")
        return 1

    if args.mode == "upload":
        # Validate input file
        jsonl_path = Path(args.jsonl_file)
        if not jsonl_path.exists():
            print(f"Error: File '{args.jsonl_file}' not found!")
            return 1

        # Load data
        print(f"Loading data from '{args.jsonl_file}'...")
        data = load_jsonl(args.jsonl_file)
        print(f"Loaded {len(data)} records")

        # Show sample
        if data:
            print(f"\nSample record:")
            print(json.dumps(data[0], ensure_ascii=False, indent=2))

        # Create dataset
        dataset = create_dataset(
            client=client,
            dataset_name=args.dataset_name,
            workspace=args.workspace
        )

        # Upload data
        upload_data(dataset, data)

        print(f"\n{'='*60}")
        print(f"Done! You can now label your data at:")
        print(f"{client.api_url}/datasets/{args.workspace}/{args.dataset_name}")
        print(f"{'='*60}")

    elif args.mode == "export":
        # Get dataset. client.datasets(...) may return None (rather than
        # raising) when the dataset does not exist, so check both paths.
        try:
            dataset = client.datasets(args.dataset_name, workspace=args.workspace)
        except Exception as e:
            print(f"Error: Dataset '{args.dataset_name}' not found: {e}")
            return 1
        if dataset is None:
            print(f"Error: Dataset '{args.dataset_name}' not found in workspace '{args.workspace}'")
            return 1

        # Export data
        export_labeled_data(dataset, args.output_file, args.format)

        print(f"\nLabel mappings:")
        print(f" Task: {LABEL_MAP}")
        print(f" Region: {REGION_MAP}")

    return 0
if __name__ == "__main__":
    # Use SystemExit directly: the `exit()` builtin is injected by the
    # `site` module for interactive use and is not guaranteed in scripts.
    raise SystemExit(main())