"""
Cloud storage: save images to local remote_images/ and sync with HuggingFace dataset.
"""
import base64
import json
import os
import threading
from datetime import datetime
from pathlib import Path
from typing import Optional

from huggingface_hub import HfApi, snapshot_download

DATA_DIR = Path("remote_images")
INDEX_FILE = DATA_DIR / ".search_index.json"


class CloudStorage:
    """Store images under ``DATA_DIR/<YYYY-MM-DD>/`` and mirror them to a HuggingFace repo.

    Files are named ``<index:05d>-<seed>.png`` where ``index`` is a per-day
    counter. Remote operations are skipped when the ``HF_DATASET_REPO``
    environment variable is empty. A JSON file (``INDEX_FILE``) maps each
    relative image path to lower-cased search text.
    """

    def __init__(self):
        self.repo_id = os.environ.get("HF_DATASET_REPO", "")
        self.repo_type = "dataset"
        self.api = HfApi(token=self._hf_token())
        self._counter_date = ""  # date string (YYYY-MM-DD) the counter belongs to
        self._counter_value = 0  # highest index handed out for that date
        self._lock = threading.Lock()  # guards the two counter fields
        self._index_lock = threading.RLock()  # guards _search_index and its file
        DATA_DIR.mkdir(parents=True, exist_ok=True)
        self._init_counter()
        self._search_index = {}  # rel_path -> lower-cased search text
        self._load_search_index()

    @staticmethod
    def _hf_token() -> Optional[str]:
        """Return the HF_TOKEN environment value, or None when unset or empty.

        An empty string would be passed on as an explicit (empty) credential;
        None leaves token resolution to huggingface_hub.
        """
        return os.environ.get("HF_TOKEN") or None

    @staticmethod
    def _local_path(rel_path: str) -> Optional[Path]:
        """Map ``rel_path`` to a path below DATA_DIR, or None if it would escape it.

        The check is purely lexical (``..`` segments and absolute paths are
        normalised, symlinks are not followed), so symlinked files inside
        DATA_DIR remain reachable.
        """
        if not rel_path:
            return None
        base = os.path.abspath(DATA_DIR)
        target = os.path.abspath(os.path.join(base, rel_path))
        try:
            inside = os.path.commonpath([base, target]) == base
        except ValueError:
            # commonpath refuses to compare e.g. paths on different drives
            inside = False
        if not inside or target == base:
            return None
        return DATA_DIR / rel_path

    def _init_counter(self):
        """Initialize counter from existing files for today."""
        today = datetime.now().strftime("%Y-%m-%d")
        # Same lock as _next_index, so a re-init (after a sync) cannot
        # interleave with an index being handed out.
        with self._lock:
            self._counter_date = today
            self._counter_value = self._count_files_for_date(today)

    def _count_files_for_date(self, date_str: str) -> int:
        """Return the highest numeric filename prefix among a date folder's PNGs.

        Returns 0 when the folder does not exist or holds no matching files.
        """
        date_dir = DATA_DIR / date_str
        if not date_dir.exists():
            return 0
        max_idx = 0
        for png in date_dir.glob("*.png"):
            try:
                idx = int(png.stem.split('-')[0])
            except ValueError:
                # Not one of our "<index>-<seed>.png" files; ignore it
                continue
            max_idx = max(max_idx, idx)
        return max_idx

    def _next_index(self, date_str: str) -> int:
        """Reserve and return the next image index for ``date_str``."""
        with self._lock:
            if date_str != self._counter_date:
                # Date changed (past midnight): reset counter
                self._counter_date = date_str
                self._counter_value = self._count_files_for_date(date_str)
            self._counter_value += 1
            return self._counter_value

    def save_image(self, image_b64: str, seed: int) -> Optional[str]:
        """Save image locally and return relative path.

        Args:
            image_b64: Base64-encoded image data.
            seed: Value embedded in the file name after the index.

        Returns:
            ``"<date>/<index>-<seed>.png"`` relative to DATA_DIR, or None when
            ``image_b64`` is empty.

        Raises:
            binascii.Error: If ``image_b64`` is not valid base64.
        """
        if not image_b64:
            return None
        # Decode before reserving an index so malformed input does not
        # consume a counter value or leave a gap in the numbering.
        image_bytes = base64.b64decode(image_b64)
        date_str = datetime.now().strftime("%Y-%m-%d")
        date_dir = DATA_DIR / date_str
        date_dir.mkdir(parents=True, exist_ok=True)
        index = self._next_index(date_str)
        filename = f"{index:05d}-{seed}.png"
        (date_dir / filename).write_bytes(image_bytes)
        return f"{date_str}/{filename}"

    def upload_file(self, rel_path: str):
        """Upload a single file to HuggingFace.

        Does nothing when no repo is configured, when ``rel_path`` points
        outside DATA_DIR, or when the local file does not exist. Upload
        errors are printed, not raised.
        """
        if not self.repo_id:
            return
        local_path = self._local_path(rel_path)
        if local_path is None or not local_path.is_file():
            return
        try:
            self.api.upload_file(
                path_or_fileobj=str(local_path),
                path_in_repo=rel_path,
                repo_id=self.repo_id,
                repo_type=self.repo_type,
            )
        except Exception as e:
            print(f"[CloudStorage] Upload failed for {rel_path}: {e}")

    def delete_file(self, rel_path: str):
        """Delete a file from HuggingFace and locally.

        Paths that resolve outside DATA_DIR are refused outright. Remote
        delete errors are printed, not raised.
        """
        local_path = self._local_path(rel_path)
        if local_path is None:
            print(f"[CloudStorage] Refusing to delete path outside data dir: {rel_path}")
            return
        # Delete local
        if local_path.is_file():
            try:
                local_path.unlink()
            except FileNotFoundError:
                # Removed by someone else between the check and the unlink
                pass
        # Remove search index entry
        self.remove_search_entry(rel_path)
        # Delete remote
        if not self.repo_id:
            return
        try:
            self.api.delete_file(
                path_in_repo=rel_path,
                repo_id=self.repo_id,
                repo_type=self.repo_type,
            )
        except Exception as e:
            print(f"[CloudStorage] Remote delete failed for {rel_path}: {e}")

    def sync_download(self):
        """Download snapshot from HuggingFace to local.

        On success the day counter and the search index are refreshed from
        the downloaded files. Errors are printed, not raised.
        """
        if not self.repo_id:
            return
        try:
            snapshot_download(
                repo_id=self.repo_id,
                repo_type=self.repo_type,
                local_dir=DATA_DIR,
                token=self._hf_token(),
            )
            # Reinitialize counter from synced files
            self._init_counter()
            self.build_search_index_from_files()
        except Exception as e:
            print(f"[CloudStorage] Sync download failed: {e}")

    def _load_search_index(self):
        """Load the search index from INDEX_FILE; fall back to empty if unusable."""
        if not INDEX_FILE.exists():
            return
        try:
            with open(INDEX_FILE, encoding="utf-8") as f:
                loaded = json.load(f)
        except (OSError, ValueError):
            # Unreadable or malformed JSON (JSONDecodeError is a ValueError)
            loaded = None
        with self._index_lock:
            # Anything but a JSON object would break the dict lookups later
            self._search_index = loaded if isinstance(loaded, dict) else {}

    def _save_search_index(self):
        """Persist the search index to INDEX_FILE; failures are printed."""
        tmp_path = INDEX_FILE.with_name(INDEX_FILE.name + ".tmp")
        with self._index_lock:
            try:
                # Write to a sibling file and swap it in, so an interrupted
                # write never leaves a truncated index behind.
                with open(tmp_path, 'w', encoding="utf-8") as f:
                    json.dump(self._search_index, f)
                os.replace(tmp_path, INDEX_FILE)
            except (OSError, TypeError, ValueError) as e:
                print(f"[CloudStorage] Failed to save search index: {e}")

    def set_search_text(self, rel_path: str, search_text: str):
        """Store search text for an image path."""
        with self._index_lock:
            self._search_index[rel_path] = search_text.lower()
            self._save_search_index()

    def get_search_text(self, rel_path: str) -> str:
        """Return the stored search text for a path, or "" if there is none."""
        return self._search_index.get(rel_path, "")

    def search_images(self, query: str, date_filter: str = "") -> list:
        """Search images by query against stored search text.

        The query is split on whitespace; an image matches when every term
        is a substring of its stored text. A blank query returns every image
        that passes ``date_filter``.
        """
        all_images = self.list_images(date_filter)
        terms = query.strip().lower().split()
        if not terms:
            return all_images
        return [
            img
            for img in all_images
            if all(t in self._search_index.get(img["path"], "") for t in terms)
        ]

    def remove_search_entry(self, rel_path: str):
        """Remove search index entry for a path."""
        with self._index_lock:
            self._search_index.pop(rel_path, None)
            self._save_search_index()

    def build_search_index_from_files(self):
        """Add index entries for PNG files that do not have one yet.

        The search text is the space-joined values of the image's ``text``
        mapping, lower-cased. Existing entries are left untouched; files that
        cannot be read are reported and skipped.
        """
        # Imported lazily so the module loads without Pillow installed
        from PIL import Image as PILImage

        new_entries = {}
        for date_dir in DATA_DIR.iterdir():
            if not date_dir.is_dir() or date_dir.name.startswith('.'):
                continue
            for png in date_dir.glob("*.png"):
                rel_path = f"{date_dir.name}/{png.name}"
                if rel_path in self._search_index:
                    continue
                try:
                    # Context manager closes the underlying file handle
                    with PILImage.open(png) as img:
                        chunks = img.text if hasattr(img, 'text') else {}
                        new_entries[rel_path] = " ".join(chunks.values()).lower()
                except Exception as e:
                    print(f"[CloudStorage] Could not read PNG metadata for {rel_path}: {e}")
        with self._index_lock:
            for rel_path, text in new_entries.items():
                # setdefault: never overwrite text stored while we were scanning
                self._search_index.setdefault(rel_path, text)
            self._save_search_index()

    def list_images(self, date_filter: str = "") -> list:
        """List all images, newest date folder first.

        Args:
            date_filter: When non-empty, only folders whose name contains
                this string are included (substring match).

        Returns:
            A list of dicts with keys ``path``, ``name``, ``date``, ``size``
            and ``mtime``.
        """
        results = []
        for date_dir in sorted(DATA_DIR.iterdir(), reverse=True):
            if not date_dir.is_dir() or date_dir.name.startswith('.'):
                continue
            if date_filter and date_filter not in date_dir.name:
                continue
            for png in sorted(date_dir.glob("*.png"), reverse=True):
                info = png.stat()  # one stat call serves both size and mtime
                results.append({
                    "path": f"{date_dir.name}/{png.name}",
                    "name": png.name,
                    "date": date_dir.name,
                    "size": info.st_size,
                    "mtime": info.st_mtime,
                })
        return results

    def get_image_b64(self, rel_path: str) -> Optional[str]:
        """Get base64 of an image by relative path.

        Returns None when the path points outside DATA_DIR or is not an
        existing file.
        """
        local_path = self._local_path(rel_path)
        if local_path is None or not local_path.is_file():
            return None
        return base64.b64encode(local_path.read_bytes()).decode()

    def list_dates(self) -> list:
        """List all date folders that contain at least one PNG, newest first.

        Returns:
            A list of ``{"date": <folder name>, "count": <number of PNGs>}``.
        """
        dates = []
        for d in sorted(DATA_DIR.iterdir(), reverse=True):
            if not d.is_dir() or d.name.startswith('.'):
                continue
            count = sum(1 for _ in d.glob("*.png"))
            if count > 0:
                dates.append({"date": d.name, "count": count})
        return dates


cloud_storage = CloudStorage()