# NOTE: a "Spaces: Running" status banner from the Hugging Face web UI was
# pasted above this module; it is not part of the source and was removed.
"""
Dataset Manager for NanoBanana Image Generator
Handles saving generated images to Hugging Face Dataset Repository
"""
import os
import json
import logging
import tempfile
from datetime import datetime
from io import BytesIO
from pathlib import Path
from typing import Optional, Dict, Any, List

from PIL import Image

from huggingface_hub import HfApi, upload_file, upload_folder, create_repo, list_repo_files

logger = logging.getLogger(__name__)
class DatasetManager:
    """
    Manages image uploads to a Hugging Face Dataset Repository
    with custom folder structure support.
    """

    def __init__(self, repo_id: str, token: str):
        """
        Initialize Dataset Manager.

        Args:
            repo_id: Hugging Face dataset repository ID (e.g., "username/dataset-name")
            token: Hugging Face API token with write access
        """
        self.repo_id = repo_id
        self.token = token
        self.api = HfApi(token=token)
        # Ensure repository exists before any upload is attempted
        self._ensure_repo_exists()
        logger.info(f"DatasetManager initialized for repository: {repo_id}")

    def _ensure_repo_exists(self):
        """Create the dataset repository if it doesn't exist."""
        try:
            # Probe for the repository; raises if it is missing or inaccessible
            self.api.repo_info(repo_id=self.repo_id, repo_type="dataset")
            logger.info(f"Dataset repository {self.repo_id} already exists")
        except Exception:
            # Repository doesn't exist (or lookup failed) - try to create it
            try:
                self.api.create_repo(
                    repo_id=self.repo_id,
                    repo_type="dataset",
                    private=False,
                    exist_ok=True
                )
                logger.info(f"Created new dataset repository: {self.repo_id}")
                # Seed the new repository with an initial README
                readme_content = f"""# NanoBanana Generated Images Dataset
This dataset contains images generated by the NanoBanana Gemini Image Generator.
## Structure
- `images/` - Generated images organized by date or custom folders
- `YYYY_MM_DD/` - Images generated on specific dates
- Custom folders as specified during generation
## Metadata
Each image is accompanied by a JSON metadata file containing:
- Generation prompt
- Model used
- Generation parameters
- Timestamp
## Usage
You can load this dataset using the Hugging Face `datasets` library:
```python
from datasets import load_dataset
dataset = load_dataset("{self.repo_id}")
```
---
Generated with 🍌 NanoBanana Image Generator
"""
                self._upload_readme(readme_content)
            except Exception as create_error:
                # Best-effort: a missing README/repo is not fatal for the caller
                logger.warning(f"Could not create repository: {create_error}")

    def _upload_readme(self, content: str):
        """Upload README.md to the repository (best-effort, errors are logged)."""
        temp_path = None
        try:
            with tempfile.NamedTemporaryFile(mode='w', suffix='.md', delete=False) as f:
                f.write(content)
                temp_path = f.name
            self.api.upload_file(
                path_or_fileobj=temp_path,
                path_in_repo="README.md",
                repo_id=self.repo_id,
                repo_type="dataset"
            )
        except Exception as e:
            logger.error(f"Failed to upload README: {e}")
        finally:
            # BUGFIX: remove the temp file even when the upload raises,
            # instead of only on the success path.
            if temp_path and os.path.exists(temp_path):
                os.unlink(temp_path)

    @staticmethod
    def _sanitize_filename(filename: str) -> str:
        """
        Return a repository-safe filename stem.

        Strips any extension, then replaces every character that is not
        alphanumeric, '-' or '_' with '_'. May return an empty string
        (caller falls back to an auto-generated name in that case).
        """
        stem = os.path.splitext(filename)[0]
        return "".join(c if c.isalnum() or c in '-_' else '_' for c in stem)

    def save_image(
        self,
        image: Image.Image,
        prompt: str,
        folder_name: Optional[str] = None,
        filename: Optional[str] = None,  # optional caller-supplied filename
        metadata: Optional[Dict[str, Any]] = None
    ) -> Optional[str]:
        """
        Save image to dataset repository.

        Args:
            image: PIL Image object to save
            prompt: Generation prompt used
            folder_name: Optional folder name. If None, uses YYYY_MM_DD format
            filename: Optional custom filename (without extension). If None, auto-generates
            metadata: Optional additional metadata to save

        Returns:
            URL to the uploaded image in the dataset repository, or None if failed
        """
        img_temp_path = None
        meta_temp_path = None
        try:
            # Default folder is the current date, e.g. "2024_01_31"
            if not folder_name:
                folder_name = datetime.now().strftime("%Y_%m_%d")

            # Generate or use custom filename
            if filename:
                clean_filename = self._sanitize_filename(filename)
                # Fall back to a timestamped name if sanitization emptied it
                if not clean_filename:
                    clean_filename = f"image_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')[:-3]}"
                image_filename = f"{clean_filename}.png"
                metadata_filename = f"{clean_filename}.json"
            else:
                # Auto-generate a millisecond-resolution timestamped name
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]
                image_filename = f"image_{timestamp}.png"
                metadata_filename = f"image_{timestamp}.json"

            # Paths in repository
            image_path_in_repo = f"images/{folder_name}/{image_filename}"
            metadata_path_in_repo = f"images/{folder_name}/{metadata_filename}"

            # Metadata stored alongside the image
            image_metadata = {
                "prompt": prompt,
                "timestamp": datetime.now().isoformat(),
                "folder": folder_name,
                "filename": image_filename,
                "custom_filename": filename is not None,  # Track if filename was custom
                "size": {
                    "width": image.width,
                    "height": image.height
                }
            }
            # Caller-supplied metadata overrides/extends the defaults
            if metadata:
                image_metadata.update(metadata)

            # Serialize image and metadata to temporary files for upload
            with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as img_temp:
                image.save(img_temp, format='PNG')
                img_temp_path = img_temp.name
            with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as meta_temp:
                json.dump(image_metadata, meta_temp, indent=2)
                meta_temp_path = meta_temp.name

            # Upload image
            self.api.upload_file(
                path_or_fileobj=img_temp_path,
                path_in_repo=image_path_in_repo,
                repo_id=self.repo_id,
                repo_type="dataset",
                commit_message=f"Add generated image: {image_filename}"
            )
            # Upload metadata
            self.api.upload_file(
                path_or_fileobj=meta_temp_path,
                path_in_repo=metadata_path_in_repo,
                repo_id=self.repo_id,
                repo_type="dataset",
                commit_message=f"Add metadata for: {image_filename}"
            )

            dataset_url = f"https://huggingface.co/datasets/{self.repo_id}/blob/main/{image_path_in_repo}"
            logger.info(f"Successfully saved image to dataset: {image_path_in_repo}")
            return dataset_url
        except Exception as e:
            logger.error(f"Failed to save image to dataset: {e}")
            return None
        finally:
            # BUGFIX: temp files were previously only removed on the success
            # path, leaking them whenever an upload raised.
            for temp_path in (img_temp_path, meta_temp_path):
                if temp_path and os.path.exists(temp_path):
                    os.unlink(temp_path)

    def get_folders(self) -> List[str]:
        """
        Get list of folders in the dataset repository.

        Returns:
            Sorted list of folder names under "images/"; empty list on error.
        """
        try:
            files = self.api.list_repo_files(
                repo_id=self.repo_id,
                repo_type="dataset"
            )
            folders = set()
            for file in files:
                if file.startswith("images/"):
                    parts = file.split("/")
                    # BUGFIX: require "images/<folder>/<file>" (3+ components)
                    # so a file sitting directly under images/ is not
                    # mistaken for a folder name.
                    if len(parts) > 2:
                        folders.add(parts[1])
            return sorted(folders)
        except Exception as e:
            logger.error(f"Failed to get folders: {e}")
            return []

    def get_images_in_folder(self, folder_name: str) -> List[Dict[str, str]]:
        """
        Get list of images in a specific folder.

        Args:
            folder_name: Name of the folder

        Returns:
            List of dicts with filename, repo path, download URL and
            metadata URL, newest first; empty list on error.
        """
        try:
            files = self.api.list_repo_files(
                repo_id=self.repo_id,
                repo_type="dataset"
            )
            images = []
            folder_prefix = f"images/{folder_name}/"
            for file in files:
                if file.startswith(folder_prefix) and file.endswith(".png"):
                    # BUGFIX: swap only the trailing ".png" for ".json";
                    # str.replace would also rewrite an interior ".png".
                    metadata_file = file[:-len(".png")] + ".json"
                    images.append({
                        "filename": os.path.basename(file),
                        "path": file,
                        "url": f"https://huggingface.co/datasets/{self.repo_id}/resolve/main/{file}",
                        "metadata_url": f"https://huggingface.co/datasets/{self.repo_id}/resolve/main/{metadata_file}"
                    })
            # Timestamped names sort chronologically, so reverse = newest first
            return sorted(images, key=lambda x: x["filename"], reverse=True)
        except Exception as e:
            logger.error(f"Failed to get images in folder {folder_name}: {e}")
            return []

    def batch_save_images(
        self,
        images_data: List[Dict[str, Any]],
        folder_name: Optional[str] = None
    ) -> List[Optional[str]]:
        """
        Save multiple images in batch.

        Args:
            images_data: List of dicts with 'image', 'prompt', and optional 'metadata'
            folder_name: Optional folder name for all images

        Returns:
            List of URLs for uploaded images (None entries for failures)
        """
        urls = []
        for data in images_data:
            url = self.save_image(
                image=data['image'],
                prompt=data['prompt'],
                folder_name=folder_name,
                metadata=data.get('metadata')
            )
            urls.append(url)
        return urls