Spaces:
Sleeping
Sleeping
| """ | |
| Argilla Data Labeling Tool | |
| This tool helps you create a labeling interface for data with the following structure: | |
| { | |
| "text": "sample text", | |
| "region_label": "Central/North/South", | |
| "task_label": clean/offensive/hate | |
| } | |
| Usage: | |
| 1. First, set up and run Argilla (see README.md) | |
| 2. Run this script to create a dataset and upload your data | |
| """ | |
| import argparse | |
| import json | |
| from pathlib import Path | |
| import argilla as rg | |
| from typing import List, Dict, Optional | |
# Label mappings for converting between string labels and numeric IDs
# (task labels describe the toxicity level of the text).
LABEL_MAP = {
    "clean": 0,
    "offensive": 1,
    "hate": 2
}

# Region labels to numeric IDs; "Unknown" is the catch-all bucket and is
# also used as the "no label" sentinel by normalize_label/upload_data.
REGION_MAP = {
    "North": 0,
    "Central": 1,
    "South": 2,
    "Unknown": 3
}

# Reverse mappings (numeric ID -> string name), used when the input data
# carries numeric labels instead of names.
LABEL_ID_TO_NAME = {v: k for k, v in LABEL_MAP.items()}
REGION_ID_TO_NAME = {v: k for k, v in REGION_MAP.items()}
def load_jsonl(file_path: str) -> List[Dict]:
    """Read a JSONL file and return one parsed dict per non-blank line."""
    with open(file_path, 'r', encoding='utf-8') as handle:
        # Strip each line first so trailing newlines and blank lines are ignored.
        return [json.loads(raw) for raw in map(str.strip, handle) if raw]
def create_dataset(
    client: rg.Argilla,
    dataset_name: str,
    workspace: str = "argilla"
) -> rg.Dataset:
    """
    Create an Argilla dataset configured for text classification with
    region and task labels.

    If a dataset with the same name already exists in the workspace it is
    returned unchanged instead of being recreated.

    Args:
        client: Authenticated Argilla client
        dataset_name: Name for the dataset
        workspace: Workspace name (default: "argilla")

    Returns:
        Existing or newly created Argilla dataset
    """
    # Define the settings for the dataset
    settings = rg.Settings(
        # Fields: the data to be displayed and annotated
        fields=[
            rg.TextField(
                name="text",
                title="Text",
                required=True,
                use_markdown=False
            ),
        ],
        # Questions: what we want to annotate
        questions=[
            # Region labeling (multi-class)
            rg.LabelQuestion(
                name="region",
                title="Region Label",
                labels=["North", "Central", "South", "Unknown"],
                required=True
            ),
            # Task labeling (multi-class: toxicity level)
            rg.LabelQuestion(
                name="task",
                title="Task Label",
                labels=["clean", "offensive", "hate"],
                required=True
            ),
        ],
        # Metadata: additional info to display
        metadata_properties=[
            rg.TermsMetadataProperty(
                name="original_region",
                title="Original Region Label"
            ),
            # upload_data() stores this value as str(...), so it must be a
            # terms property, not an integer one (which would reject strings
            # like "clean").
            rg.TermsMetadataProperty(
                name="original_task",
                title="Original Task Label"
            ),
        ]
    )

    # Look up an existing dataset first. In the Argilla 2.x SDK,
    # client.datasets(name) returns None when the dataset does not exist;
    # older SDK versions raised instead, so handle both.
    try:
        dataset = client.datasets(name=dataset_name, workspace=workspace)
    except Exception:
        dataset = None
    if dataset is not None:
        print(f"Dataset '{dataset_name}' already exists. Using existing dataset.")
        return dataset

    # client.datasets(...) is only a read accessor; creation goes through
    # rg.Dataset(...).create() in the 2.x SDK.
    print(f"Creating new dataset '{dataset_name}'...")
    dataset = rg.Dataset(
        name=dataset_name,
        workspace=workspace,
        settings=settings,
        client=client,
    )
    dataset.create()
    print(f"Dataset '{dataset_name}' created successfully!")
    return dataset
def normalize_label(label, id_to_name_map, label_map):
    """
    Convert *label* to its canonical string name.

    Accepts either a numeric ID or a string. Unknown numeric IDs map to
    "Unknown"; non-numeric strings that are not known labels pass through
    unchanged; anything else (None, floats, ...) becomes "Unknown".
    """
    if isinstance(label, int):
        return id_to_name_map.get(label, "Unknown")
    if not isinstance(label, str):
        return "Unknown"
    # Already a valid string label?
    if label in label_map:
        return label
    # Numeric string -> look up the ID; otherwise pass the string through.
    try:
        label_id = int(label)
    except ValueError:
        return label
    return id_to_name_map.get(label_id, "Unknown")
def upload_data(
    dataset: rg.Dataset,
    data: List[Dict],
    batch_size: int = 100
):
    """
    Push records into the Argilla dataset in batches.

    Existing region/task labels are attached as suggestions (agent
    "pre-label") so annotators only confirm or correct them; the raw
    input labels are preserved verbatim in each record's metadata.

    Args:
        dataset: Target Argilla dataset
        data: Records shaped like {"text", "region_label", "task_label"};
            labels may be string names or numeric IDs
        batch_size: Number of records per upload call
    """
    print(f"Preparing to upload {len(data)} records...")

    records = []
    for item in data:
        # Normalize labels (string names and numeric IDs both accepted).
        region_name = normalize_label(
            item.get("region_label", "Unknown"),
            REGION_ID_TO_NAME,
            REGION_MAP
        )
        task_name = normalize_label(
            item.get("task_label", "clean"),
            LABEL_ID_TO_NAME,
            LABEL_MAP
        )
        rec = rg.Record(
            fields={"text": item["text"]},
            metadata={
                "original_region": str(item.get("region_label", "")),
                "original_task": str(item.get("task_label", ""))
            }
        )
        # Pre-fill known labels as suggestions; "Unknown" means "no label".
        prefills = [("region", region_name), ("task", task_name)]
        suggestions = [
            rg.Suggestion(question_name=question, value=value, agent="pre-label")
            for question, value in prefills
            if value != "Unknown"
        ]
        if suggestions:
            rec.suggestions = suggestions
        records.append(rec)

    # Upload in fixed-size chunks so very large datasets don't hit one call.
    total_batches = (len(records) + batch_size - 1) // batch_size
    for batch_num, start in enumerate(range(0, len(records), batch_size), start=1):
        chunk = records[start:start + batch_size]
        dataset.records.log(chunk)
        print(f"Uploaded batch {batch_num}/{total_batches} ({len(chunk)} records)")
    print(f"Successfully uploaded {len(records)} records to dataset!")
def export_labeled_data(
    dataset: rg.Dataset,
    output_path: str,
    format: str = "jsonl"
):
    """
    Export annotated records from an Argilla dataset to a file.

    Records without any responses are skipped. Each exported item carries
    both the string labels and their numeric IDs.

    Args:
        dataset: Argilla dataset
        output_path: Path to the output file
        format: Output format ("jsonl" for one JSON object per line,
            anything else for a single pretty-printed JSON array)
    """
    print(f"Exporting labeled data from '{dataset.name}'...")

    all_records = dataset.records.to_list()
    print(f"Found {len(all_records)} records")

    exported = []
    for rec in all_records:
        responses = rec.responses
        if not responses:
            # Not annotated yet; nothing to export for this record.
            continue
        # NOTE(review): assumes the last response is the most recent one —
        # the SDK does not document the ordering here; confirm if multiple
        # annotators respond to the same record.
        latest = responses[-1]
        region_name = latest.values.get("region")
        task_name = latest.values.get("task")
        exported.append({
            "text": rec.fields["text"],
            "region_label": region_name,
            "region_id": REGION_MAP.get(region_name, 3),  # default: Unknown
            "task_label": task_name,
            "task_id": LABEL_MAP.get(task_name, 0),  # default: clean
        })

    with open(output_path, 'w', encoding='utf-8') as out:
        if format == "jsonl":
            for item in exported:
                out.write(json.dumps(item, ensure_ascii=False) + '\n')
        else:  # json array
            json.dump(exported, out, ensure_ascii=False, indent=2)
    print(f"Exported {len(exported)} labeled records to '{output_path}'")
def main():
    """CLI entry point: dispatch to upload or export mode. Returns exit code."""
    parser = argparse.ArgumentParser(
        description="Upload data to Argilla for labeling or export labeled data"
    )
    subparsers = parser.add_subparsers(dest="mode", help="Mode: upload or export")

    def _add_connection_args(sub):
        # Flags shared by both sub-commands; added in the same order for both
        # so the --help output stays stable.
        sub.add_argument(
            "--api-url",
            type=str,
            default=None,
            help="Argilla API URL (e.g., http://localhost:6900 or https://your-space.hf.space)"
        )
        sub.add_argument(
            "--api-key",
            type=str,
            default=None,
            help="Argilla API key (found in My Settings page)"
        )
        sub.add_argument(
            "--dataset-name",
            type=str,
            default="text_labeling_dataset",
            help="Name for the Argilla dataset (default: text_labeling_dataset)"
        )
        sub.add_argument(
            "--workspace",
            type=str,
            default="argilla",
            help="Argilla workspace name (default: argilla)"
        )

    # Upload mode
    upload_parser = subparsers.add_parser("upload", help="Upload data to Argilla")
    upload_parser.add_argument(
        "jsonl_file",
        type=str,
        help="Path to the JSONL file containing data to label"
    )
    _add_connection_args(upload_parser)

    # Export mode
    export_parser = subparsers.add_parser("export", help="Export labeled data from Argilla")
    export_parser.add_argument(
        "output_file",
        type=str,
        help="Path to output file (e.g., labeled_data.jsonl)"
    )
    _add_connection_args(export_parser)
    export_parser.add_argument(
        "--format",
        type=str,
        default="jsonl",
        choices=["jsonl", "json"],
        help="Output format (default: jsonl)"
    )

    args = parser.parse_args()

    # Backward compatibility: no sub-command given means implicit "upload".
    if args.mode is None:
        import sys
        sys.argv.insert(1, "upload")
        args = parser.parse_args()

    # Connect to Argilla
    if args.api_url and args.api_key:
        print(f"Connecting to Argilla at {args.api_url}...")
        client = rg.Argilla(api_url=args.api_url, api_key=args.api_key)
    else:
        print("Connecting to Argilla (using default or env variables)...")
        print("Set ARGILLA_API_URL and ARGILLA_API_KEY environment variables if not using defaults")
        client = rg.Argilla()

    # Verify the connection before doing any work.
    try:
        me = client.me()
        print(f"Connected as: {me.username}")
    except Exception as exc:
        print(f"Error connecting to Argilla: {exc}")
        return 1

    if args.mode == "upload":
        # Validate input file
        if not Path(args.jsonl_file).exists():
            print(f"Error: File '{args.jsonl_file}' not found!")
            return 1

        print(f"Loading data from '{args.jsonl_file}'...")
        data = load_jsonl(args.jsonl_file)
        print(f"Loaded {len(data)} records")

        # Show one sample so the user can sanity-check the schema.
        if data:
            print(f"\nSample record:")
            print(json.dumps(data[0], ensure_ascii=False, indent=2))

        dataset = create_dataset(
            client=client,
            dataset_name=args.dataset_name,
            workspace=args.workspace
        )
        upload_data(dataset, data)

        banner = '=' * 60
        print(f"\n{banner}")
        print(f"Done! You can now label your data at:")
        print(f"{client.api_url}/datasets/{args.workspace}/{args.dataset_name}")
        print(f"{banner}")

    elif args.mode == "export":
        try:
            dataset = client.datasets(args.dataset_name, workspace=args.workspace)
        except Exception as exc:
            print(f"Error: Dataset '{args.dataset_name}' not found: {exc}")
            return 1

        export_labeled_data(dataset, args.output_file, args.format)
        print(f"\nLabel mappings:")
        print(f"  Task: {LABEL_MAP}")
        print(f"  Region: {REGION_MAP}")

    return 0
if __name__ == "__main__":
    # Raise SystemExit directly: the builtin exit() is injected by the
    # `site` module and is not guaranteed to exist (e.g. under `python -S`).
    raise SystemExit(main())