| | import sys |
| | import time |
| | import os |
| |
|
| | import pandas as pd |
| | import requests |
| | from datasets import load_dataset, concatenate_datasets |
| |
|
| | import argilla as rg |
| | from argilla.listeners import listener |
| |
|
| | |
| |
|
| | |
# Hugging Face token used to push the validated dataset to the Hub.
# If unset, validated records are kept locally and a warning is printed.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Hub dataset that provides the raw records to annotate.
SOURCE_DATASET = "LEL-A/translated_german_alpaca"

# Name of the dataset inside the Argilla server.
RG_DATASET_NAME = "translated-german-alpaca"

# Hub repo where validated records are synced back to;
# defaults to "<SOURCE_DATASET>_validation" when the env var is unset.
HUB_DATASET_NAME = os.environ.get('HUB_DATASET_NAME', f"{SOURCE_DATASET}_validation")

# Label schema offered to annotators for each record.
LABELS = ["BAD INSTRUCTION", "INAPPROPRIATE", "ALL GOOD", "NOT SURE", "WRONG LANGUAGE"]
| |
|
@listener(
    dataset=RG_DATASET_NAME,
    query="status:Validated",
    execution_interval_in_seconds=1200,
)
def save_validated_to_hub(records, ctx):
    """Periodically push records marked "Validated" in Argilla to the Hub.

    Runs every 1200 seconds via the Argilla listener machinery. Requires
    HF_TOKEN to be set; otherwise only prints a reminder.
    """
    if not records:
        print("NO RECORDS found")
        return
    if not HF_TOKEN:
        print("SET HF_TOKEN and HUB_DATASET_NAME TO SYNC YOUR DATASET!!!")
        return
    validated = rg.DatasetForTextClassification(records=records).to_datasets()
    print("Pushing the dataset")
    print(validated)
    validated.push_to_hub(HUB_DATASET_NAME, token=HF_TOKEN)
| |
|
class LoadDatasets:
    """Load the source dataset into Argilla and start the Hub sync listener."""

    def __init__(self, api_key, workspace="team"):
        """Initialize the Argilla client.

        Args:
            api_key: Argilla server API key.
            workspace: Argilla workspace to operate in.
        """
        rg.init(api_key=api_key, workspace=workspace)

    @staticmethod
    def load_somos(workspace="team"):
        """Load the translated Alpaca dataset into Argilla and begin syncing.

        Args:
            workspace: Argilla workspace the dataset is configured in.
                Previously hard-coded to "team" (inconsistent with the
                configurable workspace in ``__init__``); kept as the default
                so existing callers are unaffected.
        """
        # Try to resume from already-validated records on the Hub so earlier
        # annotation work is not lost when the service restarts.
        try:
            print(f"Trying to sync with {HUB_DATASET_NAME}")
            old_ds = load_dataset(HUB_DATASET_NAME, split="train")
        except Exception as e:
            print(f"Not possible to sync with {HUB_DATASET_NAME}")
            print(e)
            old_ds = None

        print(f"Loading dataset: {SOURCE_DATASET}")
        dataset = load_dataset(SOURCE_DATASET, split="train")

        if old_ds:
            print("Concatenating datasets")
            dataset = concatenate_datasets([dataset, old_ds])
            print("Concatenated dataset is:")
            print(dataset)

        # NOTE(review): assumes a "metrics" column is present on the (merged)
        # dataset — remove_columns raises if it is missing; confirm both the
        # source dataset and the synced validation dataset carry it.
        dataset = dataset.remove_columns("metrics")
        records = rg.DatasetForTextClassification.from_datasets(dataset)

        settings = rg.TextClassificationSettings(label_schema=LABELS)

        print(f"Configuring dataset: {RG_DATASET_NAME}")
        rg.configure_dataset(name=RG_DATASET_NAME, settings=settings, workspace=workspace)

        print(f"Logging dataset: {RG_DATASET_NAME}")
        rg.log(
            records,
            name=RG_DATASET_NAME,
            tags={"description": "Alpaca dataset to clean up"},
            batch_size=200,
        )

        # Start the background listener that pushes validated records to the Hub.
        save_validated_to_hub.start()
| |
|
if __name__ == "__main__":
    # Fail with a usage message instead of a bare IndexError when args are
    # missing.
    if len(sys.argv) < 3:
        raise SystemExit("Usage: python <script> <ARGILLA_API_KEY> <LOAD_DATASETS|none>")
    API_KEY = sys.argv[1]
    LOAD_DATASETS = sys.argv[2]

    if LOAD_DATASETS.lower() == "none":
        print("No datasets being loaded")
    else:
        # Poll the local Argilla server until it answers, then load the data.
        while True:
            try:
                response = requests.get("http://0.0.0.0:6900/")
                if response.status_code == 200:
                    ld = LoadDatasets(API_KEY)
                    ld.load_somos()
                    break
            except requests.exceptions.ConnectionError:
                # Server not up yet — retry after the sleep below.
                pass
            except Exception as e:
                print(e)
                time.sleep(10)

            time.sleep(5)

    # Keep the process alive so the background listener keeps running.
    while True:
        time.sleep(60)