import io
import json
import os
from dataclasses import dataclass, field
from typing import Any, List

from huggingface_hub import HfApi, create_repo
from loguru import logger
|
|
|
|
@dataclass
class DreamboothPreprocessor:
    """Upload DreamBooth concept images and their prompt to a Hub dataset repo.

    Constructing an instance immediately creates a private dataset repository
    named ``<username>/autotrain-data-<project_name>``; ``prepare()`` then
    uploads every concept image under ``concept1/`` and writes ``prompts.json``.

    Attributes:
        concept_images: Items to upload. Each item must expose a ``.name``
            attribute holding a local file path (presumably open file or
            temp-file objects -- confirm against the caller).
        concept_name: Prompt text stored under the ``"concept1"`` key of
            ``prompts.json``.
        username: Owner (namespace) part of the repository id.
        project_name: Project part of the repository id.
        token: Access token passed to every Hub call.
    """

    concept_images: List[Any]
    concept_name: str
    username: str
    project_name: str
    # Excluded from the generated repr so the secret is not printed when the
    # instance is logged or shown in a traceback.
    token: str = field(repr=False)

    def __post_init__(self):
        """Create the private dataset repository backing this project.

        Raises:
            ValueError: If the repository cannot be created for any reason,
                including that it already exists (``exist_ok=False``). The
                underlying exception is attached as ``__cause__``.
        """
        self.repo_name = f"{self.username}/autotrain-data-{self.project_name}"
        try:
            create_repo(
                repo_id=self.repo_name,
                repo_type="dataset",
                token=self.token,
                private=True,
                exist_ok=False,
            )
        except Exception as err:
            logger.error("Error creating repo")
            # Chain the original error so the real failure reason stays
            # visible to whoever handles the ValueError.
            raise ValueError("Error creating repo") from err

    def _upload_concept_images(self, file, api):
        """Upload one image to ``concept1/<basename>`` in the dataset repo.

        Args:
            file: Object whose ``.name`` attribute is the local path to upload.
            api: Client exposing ``upload_file`` (an ``HfApi`` in ``prepare``).
        """
        logger.info(f"Uploading {file} to concept1")
        # basename handles the platform's path separator, not only "/".
        filename = os.path.basename(file.name)
        api.upload_file(
            path_or_fileobj=file.name,
            path_in_repo=f"concept1/{filename}",
            repo_id=self.repo_name,
            repo_type="dataset",
            token=self.token,
        )

    def _upload_concept_prompts(self, api):
        """Upload ``prompts.json`` mapping ``"concept1"`` to the concept name.

        Args:
            api: Client exposing ``upload_file`` (an ``HfApi`` in ``prepare``).
        """
        payload = json.dumps({"concept1": self.concept_name}).encode("utf-8")
        # The JSON is sent from an in-memory buffer; no temporary file is written.
        api.upload_file(
            path_or_fileobj=io.BytesIO(payload),
            path_in_repo="prompts.json",
            repo_id=self.repo_name,
            repo_type="dataset",
            token=self.token,
        )

    def prepare(self):
        """Upload all concept images, then the prompts file."""
        api = HfApi()
        for image in self.concept_images:
            self._upload_concept_images(image, api)
        self._upload_concept_prompts(api)
|
|