| | import os |
| | import gc |
| | import time |
| | import shutil |
| | import logging |
| | from pathlib import Path |
| | from huggingface_hub import WebhooksServer, WebhookPayload |
| | from datasets import Dataset, load_dataset, disable_caching |
| | from fastapi import BackgroundTasks, Response, status |
| |
|
| |
|
def _resolve_datasets_cache_dir():
    """Return the directory the ``datasets`` library uses for its cache.

    Follows the library's lookup order: ``HF_DATASETS_CACHE`` if set, otherwise
    ``$HF_HOME/datasets``, otherwise ``$XDG_CACHE_HOME/huggingface/datasets``
    (with ``~/.cache`` standing in when ``XDG_CACHE_HOME`` is unset). With none
    of these variables set this is ``~/.cache/huggingface/datasets``.
    """
    explicit = os.getenv("HF_DATASETS_CACHE")
    if explicit:
        return Path(explicit).expanduser()
    hf_home = os.getenv("HF_HOME")
    if hf_home:
        return Path(hf_home).expanduser() / "datasets"
    xdg_cache = os.getenv("XDG_CACHE_HOME")
    base = Path(xdg_cache).expanduser() if xdg_cache else Path.home() / ".cache"
    return base / "huggingface" / "datasets"


def clear_huggingface_cache(cache_dir=None):
    """Delete the local Hugging Face ``datasets`` cache directory.

    Args:
        cache_dir: Directory to remove (str or path-like). When omitted, the
            cache location is resolved from the environment, honouring
            ``HF_DATASETS_CACHE`` / ``HF_HOME`` instead of assuming the
            default home-directory path.

    Returns:
        True if a directory was removed; False if there was nothing to remove
        (the path is missing or is not a directory).
    """
    log = logging.getLogger("basic_logger")
    target = Path(cache_dir) if cache_dir is not None else _resolve_datasets_cache_dir()

    # is_dir() is False for missing paths too, so one check covers both cases.
    if not target.is_dir():
        log.info("Cache directory does not exist.")
        return False

    try:
        shutil.rmtree(target)
    except FileNotFoundError:
        # Something else deleted the tree (or part of it) between the check
        # above and the removal; treat it as already gone rather than crash.
        log.info("Cache directory does not exist.")
        return False

    log.info("Removed cache directory: %s", target)
    return True
| |
|
| |
|
| | |
# Turn off the `datasets` library's global caching switch for this process.
# NOTE(review): this switch is documented as governing the fingerprint cache
# used by dataset transforms (e.g. map/filter); it may not stop
# Dataset.from_generator from reusing a previously built cache directory,
# which would explain the explicit cache wipe done elsewhere — confirm.
disable_caching()
| |
|
| | |
# Named logger used for the status messages emitted by this module.
logger = logging.getLogger("basic_logger")
logger.setLevel(logging.INFO)

# Console handler producing timestamped, levelled output.
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
console_handler.setFormatter(formatter)

# Loggers are process-global: if this module is executed again in the same
# process (e.g. an auto-reload), unconditionally adding a handler would stack
# another one and print every message multiple times.
if not logger.handlers:
    logger.addHandler(console_handler)
| |
|
| | |
# Hub dataset that is streamed and re-processed whenever the webhook fires.
DS_NAME = "amaye15/object-segmentation"

# Local working directory.
# NOTE(review): not referenced anywhere in the visible code — confirm whether
# it is still needed.
DATA_DIR = Path("data")

# Hub repository that receives the processed dataset.
TARGET_REPO = "amaye15/object-segmentation-processed"

# Shared secret for the webhook server, taken from the environment
# (None when HF_WEBHOOK_SECRET is not set).
WEBHOOK_SECRET = os.environ.get("HF_WEBHOOK_SECRET")
| |
|
| |
|
def get_data():
    """Yield every row of the ``train`` split of ``DS_NAME``, one at a time.

    The dataset is opened with ``streaming=True`` so rows are produced
    incrementally instead of the whole dataset being loaded into memory first.
    """
    streamed_splits = load_dataset(DS_NAME, streaming=True)
    yield from streamed_splits["train"]
| |
|
| |
|
def process_and_push_data():
    """Rebuild the processed dataset from ``DS_NAME`` and push it to ``TARGET_REPO``.

    Rows from ``get_data`` are materialised with ``Dataset.from_generator``.
    The result is uploaded with ``push_to_hub`` in shards of at most 1GB.

    Each call builds into a brand-new temporary ``cache_dir``. ``get_data``
    takes no arguments, so with the shared default cache a later call can pick
    up the Arrow files produced by an earlier run instead of re-reading the
    updated source dataset. The temporary directory is removed once the upload
    has finished, whether or not it succeeded.
    """
    import tempfile

    build_dir = tempfile.mkdtemp(prefix="ds-rebuild-")
    try:
        ds_processed = Dataset.from_generator(get_data, cache_dir=build_dir)
        ds_processed.push_to_hub(TARGET_REPO, max_shard_size="1GB")
    finally:
        # Best-effort cleanup: the pushed data now lives on the Hub.
        shutil.rmtree(build_dir, ignore_errors=True)

    logger.info("Data processed and pushed to the hub.")
| |
|
| |
|
| | |
# Webhook server on which the endpoint handler is registered via
# @app.add_webhook.
# NOTE(review): WEBHOOK_SECRET comes from os.getenv and is None when
# HF_WEBHOOK_SECRET is unset — confirm how WebhooksServer behaves in that case
# (incoming requests may then be accepted without secret verification).
app = WebhooksServer(webhook_secret=WEBHOOK_SECRET)
| |
|
| |
|
# Delay before rebuilding after a webhook arrives.
# NOTE(review): presumably gives the Hub time to settle the new commit —
# confirm the value is still needed.
_WEBHOOK_SETTLE_SECONDS = 15


def _settle_then_process():
    """Wait, clear the local datasets cache, then rebuild and push the dataset.

    Scheduled as a FastAPI background task. Being a plain (sync) function, it
    runs in a worker thread rather than on the event loop. The blocking sleep
    and the directory removal therefore no longer stall the server.
    """
    time.sleep(_WEBHOOK_SETTLE_SECONDS)
    clear_huggingface_cache()
    _process_webhook()


@app.add_webhook("/dataset_repo")
async def handle_repository_changes(
    payload: WebhookPayload, task_queue: BackgroundTasks
):
    """Webhook endpoint that schedules a dataset rebuild when the repo changes.

    Returns 202 immediately. The waiting, cache clearing and processing all
    happen in a background task scheduled after the response is sent.
    Previously a ``time.sleep`` inside this coroutine blocked the whole event
    loop for every request.
    """
    logger.info(
        "Webhook received from %s indicating a repo %s",
        payload.repo.name,
        payload.event.action,
    )
    # NOTE(review): every event delivered to this endpoint triggers a full
    # rebuild; consider filtering on payload.event.scope if non-content events
    # (e.g. discussions) are also subscribed.
    task_queue.add_task(_settle_then_process)
    return Response("Task scheduled.", status_code=status.HTTP_202_ACCEPTED)
| |
|
| |
|
def _process_webhook():
    """Rebuild and push the processed dataset after a webhook was received.

    Runs as a background task after the webhook response has been sent, so
    the caller never sees a failure. Errors are therefore logged here with a
    traceback and then re-raised unchanged.
    """
    # Loading happens inside process_and_push_data, so no separate
    # "loading"/"loaded" messages are emitted here.
    logger.info("Processing and updating dataset...")
    try:
        process_and_push_data()
    except Exception:
        logger.exception("Processing and updating dataset failed.")
        raise
    logger.info("Processing and updating dataset completed!")
| |
|
| |
|
if __name__ == "__main__":
    # Listen on all interfaces at port 7860, with show_error enabled.
    launch_options = dict(server_name="0.0.0.0", server_port=7860, show_error=True)
    app.launch(**launch_options)
| |
|