# dataset_manager.py
"""
Dataset Manager for NanoBanana Image Generator
Handles saving generated images to Hugging Face Dataset Repository
"""
import os
import json
import logging
from typing import Optional, Dict, Any, List
from datetime import datetime
from pathlib import Path
from PIL import Image
from io import BytesIO
import tempfile
from huggingface_hub import HfApi, upload_file, upload_folder, create_repo, list_repo_files
logger = logging.getLogger(__name__)
class DatasetManager:
"""
Manages image uploads to Hugging Face Dataset Repository
with custom folder structure support
"""
def __init__(self, repo_id: str, token: str):
"""
Initialize Dataset Manager
Args:
repo_id: Hugging Face dataset repository ID (e.g., "username/dataset-name")
token: Hugging Face API token with write access
"""
self.repo_id = repo_id
self.token = token
self.api = HfApi(token=token)
# Ensure repository exists
self._ensure_repo_exists()
logger.info(f"DatasetManager initialized for repository: {repo_id}")
def _ensure_repo_exists(self):
"""Create dataset repository if it doesn't exist"""
try:
# Try to get repository info
self.api.repo_info(repo_id=self.repo_id, repo_type="dataset")
logger.info(f"Dataset repository {self.repo_id} already exists")
except Exception as e:
# Repository doesn't exist, create it
try:
self.api.create_repo(
repo_id=self.repo_id,
repo_type="dataset",
private=False,
exist_ok=True
)
logger.info(f"Created new dataset repository: {self.repo_id}")
# Create initial README
readme_content = f"""# NanoBanana Generated Images Dataset
This dataset contains images generated by the NanoBanana Gemini Image Generator.
## Structure
- `images/` - Generated images organized by date or custom folders
- `YYYY_MM_DD/` - Images generated on specific dates
- Custom folders as specified during generation
## Metadata
Each image is accompanied by a JSON metadata file containing:
- Generation prompt
- Model used
- Generation parameters
- Timestamp
## Usage
You can load this dataset using the Hugging Face `datasets` library:
```python
from datasets import load_dataset
dataset = load_dataset("{self.repo_id}")
```
---
Generated with 🍌 NanoBanana Image Generator
"""
self._upload_readme(readme_content)
except Exception as create_error:
logger.warning(f"Could not create repository: {create_error}")
def _upload_readme(self, content: str):
"""Upload README.md to repository"""
try:
with tempfile.NamedTemporaryFile(mode='w', suffix='.md', delete=False) as f:
f.write(content)
temp_path = f.name
self.api.upload_file(
path_or_fileobj=temp_path,
path_in_repo="README.md",
repo_id=self.repo_id,
repo_type="dataset"
)
# Clean up temp file
os.unlink(temp_path)
except Exception as e:
logger.error(f"Failed to upload README: {e}")
def save_image(
self,
image: Image.Image,
prompt: str,
folder_name: Optional[str] = None,
filename: Optional[str] = None, # 新しいパラメータ追加
metadata: Optional[Dict[str, Any]] = None
) -> Optional[str]:
"""
Save image to dataset repository
Args:
image: PIL Image object to save
prompt: Generation prompt used
folder_name: Optional folder name. If None, uses YYYY_MM_DD format
filename: Optional custom filename (without extension). If None, auto-generates
metadata: Optional additional metadata to save
Returns:
URL to the uploaded image in the dataset repository, or None if failed
"""
try:
# Determine folder name
if not folder_name:
folder_name = datetime.now().strftime("%Y_%m_%d")
# Generate or use custom filename
if filename:
# Sanitize filename - remove extension if provided
clean_filename = os.path.splitext(filename)[0]
# Replace invalid characters
clean_filename = "".join(c if c.isalnum() or c in '-_' else '_' for c in clean_filename)
# Ensure filename is not empty
if not clean_filename:
clean_filename = f"image_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')[:-3]}"
image_filename = f"{clean_filename}.png"
metadata_filename = f"{clean_filename}.json"
else:
# Auto-generate filename
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3] # Milliseconds
image_filename = f"image_{timestamp}.png"
metadata_filename = f"image_{timestamp}.json"
# Paths in repository
image_path_in_repo = f"images/{folder_name}/{image_filename}"
metadata_path_in_repo = f"images/{folder_name}/{metadata_filename}"
# Create metadata
image_metadata = {
"prompt": prompt,
"timestamp": datetime.now().isoformat(),
"folder": folder_name,
"filename": image_filename,
"custom_filename": filename is not None, # Track if filename was custom
"size": {
"width": image.width,
"height": image.height
}
}
# Add additional metadata if provided
if metadata:
image_metadata.update(metadata)
# Save image to temporary file
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as img_temp:
image.save(img_temp, format='PNG')
img_temp_path = img_temp.name
# Save metadata to temporary file
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as meta_temp:
json.dump(image_metadata, meta_temp, indent=2)
meta_temp_path = meta_temp.name
# Upload image
self.api.upload_file(
path_or_fileobj=img_temp_path,
path_in_repo=image_path_in_repo,
repo_id=self.repo_id,
repo_type="dataset",
commit_message=f"Add generated image: {image_filename}"
)
# Upload metadata
self.api.upload_file(
path_or_fileobj=meta_temp_path,
path_in_repo=metadata_path_in_repo,
repo_id=self.repo_id,
repo_type="dataset",
commit_message=f"Add metadata for: {image_filename}"
)
# Clean up temp files
os.unlink(img_temp_path)
os.unlink(meta_temp_path)
# Return URL to the image
dataset_url = f"https://huggingface.co/datasets/{self.repo_id}/blob/main/{image_path_in_repo}"
logger.info(f"Successfully saved image to dataset: {image_path_in_repo}")
return dataset_url
except Exception as e:
logger.error(f"Failed to save image to dataset: {e}")
return None
def get_folders(self) -> List[str]:
"""
Get list of folders in the dataset repository
Returns:
List of folder names
"""
try:
files = self.api.list_repo_files(
repo_id=self.repo_id,
repo_type="dataset"
)
# Extract folder names from paths
folders = set()
for file in files:
if file.startswith("images/"):
parts = file.split("/")
if len(parts) > 1:
folder = parts[1]
folders.add(folder)
return sorted(list(folders))
except Exception as e:
logger.error(f"Failed to get folders: {e}")
return []
def get_images_in_folder(self, folder_name: str) -> List[Dict[str, str]]:
"""
Get list of images in a specific folder
Args:
folder_name: Name of the folder
Returns:
List of dictionaries containing image info
"""
try:
files = self.api.list_repo_files(
repo_id=self.repo_id,
repo_type="dataset"
)
images = []
folder_prefix = f"images/{folder_name}/"
for file in files:
if file.startswith(folder_prefix) and file.endswith(".png"):
# Get corresponding metadata file
metadata_file = file.replace(".png", ".json")
image_info = {
"filename": os.path.basename(file),
"path": file,
"url": f"https://huggingface.co/datasets/{self.repo_id}/resolve/main/{file}",
"metadata_url": f"https://huggingface.co/datasets/{self.repo_id}/resolve/main/{metadata_file}"
}
images.append(image_info)
return sorted(images, key=lambda x: x["filename"], reverse=True)
except Exception as e:
logger.error(f"Failed to get images in folder {folder_name}: {e}")
return []
def batch_save_images(
self,
images_data: List[Dict[str, Any]],
folder_name: Optional[str] = None
) -> List[Optional[str]]:
"""
Save multiple images in batch
Args:
images_data: List of dicts with 'image', 'prompt', and optional 'metadata'
folder_name: Optional folder name for all images
Returns:
List of URLs for uploaded images
"""
urls = []
for data in images_data:
url = self.save_image(
image=data['image'],
prompt=data['prompt'],
folder_name=folder_name,
metadata=data.get('metadata')
)
urls.append(url)
return urls