Spaces:
Running
Running
| """ | |
| Cloud storage: save images to local remote_images/ and sync with HuggingFace dataset. | |
| """ | |
| import base64 | |
| import json | |
| import os | |
| import threading | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Optional | |
| from huggingface_hub import HfApi, snapshot_download | |
# Root folder for locally stored images. One sub-folder per YYYY-MM-DD date;
# the same relative paths are used as paths inside the HF dataset repo.
DATA_DIR = Path("remote_images")
# JSON file (rel_path -> lower-cased search text) kept next to the date folders.
INDEX_FILE = DATA_DIR.joinpath(".search_index.json")
class CloudStorage:
    """Local image store under ``DATA_DIR`` mirrored to a HuggingFace dataset repo.

    Images are written to ``DATA_DIR/<YYYY-MM-DD>/<NNNNN>-<seed>.png`` where
    ``NNNNN`` is a per-day counter guarded by ``_lock``. A separate
    ``rel_path -> lower-cased search text`` index is persisted as JSON in
    ``INDEX_FILE`` and guarded by ``_index_lock``. Remote operations are skipped
    when ``HF_DATASET_REPO`` is unset; remote failures are printed, not raised.
    """

    def __init__(self):
        self.repo_id = os.environ.get("HF_DATASET_REPO", "")
        self.repo_type = "dataset"
        self.api = HfApi(token=os.environ.get("HF_TOKEN", ""))
        self._counter_date = ""  # date string the counter currently tracks
        self._counter_value = 0  # highest index handed out for that date
        self._lock = threading.Lock()  # guards _counter_date/_counter_value
        # Reentrant so callers can mutate the index and persist it under one
        # acquisition (_save_search_index re-acquires it).
        self._index_lock = threading.RLock()
        DATA_DIR.mkdir(parents=True, exist_ok=True)
        self._init_counter()
        self._search_index = {}  # rel_path -> lower-cased search text
        self._load_search_index()

    def _init_counter(self):
        """(Re)initialise today's counter from the files already on disk.

        Called at construction and after a sync. If the counter already tracks
        today it never moves backwards, so an index that was handed out but
        whose file is not written yet cannot be reissued.
        """
        today = datetime.now().strftime("%Y-%m-%d")
        highest = self._count_files_for_date(today)
        with self._lock:
            if self._counter_date == today:
                highest = max(highest, self._counter_value)
            self._counter_date = today
            self._counter_value = highest

    def _count_files_for_date(self, date_str: str) -> int:
        """Return the highest numeric prefix among ``DATA_DIR/<date_str>/*.png``.

        Despite the name this is the maximum ``NNNNN`` of files named
        ``NNNNN-<seed>.png``, not a file count. Files whose prefix is not an
        integer are ignored; returns 0 when the date folder does not exist.
        """
        date_dir = DATA_DIR / date_str
        if not date_dir.exists():
            return 0
        max_idx = 0
        for f in date_dir.glob("*.png"):
            try:
                idx = int(f.stem.split('-')[0])
            except ValueError:
                continue
            max_idx = max(max_idx, idx)
        return max_idx

    def _next_index(self, date_str: str) -> int:
        """Reserve and return the next per-day index for *date_str*.

        When *date_str* differs from the tracked date (e.g. past midnight) the
        counter is re-seeded from the files on disk before incrementing.
        """
        with self._lock:
            if date_str != self._counter_date:
                self._counter_date = date_str
                self._counter_value = self._count_files_for_date(date_str)
            self._counter_value += 1
            return self._counter_value

    @staticmethod
    def _local_path(rel_path: str) -> Optional[Path]:
        """Map a repo-relative path to its location under ``DATA_DIR``.

        Returns ``None`` for empty paths, absolute paths, drive-qualified paths
        and anything containing ``..``, so callers cannot reach files outside
        ``DATA_DIR``. Symlinks are deliberately not resolved because synced
        files may legitimately be links into the HF cache.
        """
        if not rel_path:
            return None
        candidate = Path(rel_path)
        if (not candidate.parts or candidate.is_absolute()
                or candidate.drive or ".." in candidate.parts):
            return None
        return DATA_DIR / candidate

    def save_image(self, image_b64: str, seed: int) -> Optional[str]:
        """Decode *image_b64* and write it as today's next PNG.

        Returns the path relative to ``DATA_DIR`` (``"<date>/<NNNNN>-<seed>.png"``),
        or ``None`` when *image_b64* is empty. Decoding happens before an index
        is reserved, so input that ``base64.b64decode`` rejects (e.g. bad
        padding -> ``binascii.Error``) neither burns an index nor creates a folder.
        """
        if not image_b64:
            return None
        image_bytes = base64.b64decode(image_b64)
        date_str = datetime.now().strftime("%Y-%m-%d")
        date_dir = DATA_DIR / date_str
        date_dir.mkdir(parents=True, exist_ok=True)
        index = self._next_index(date_str)
        filename = f"{index:05d}-{seed}.png"
        (date_dir / filename).write_bytes(image_bytes)
        return f"{date_str}/{filename}"

    def upload_file(self, rel_path: str):
        """Upload ``DATA_DIR/<rel_path>`` to the dataset repo at the same path.

        No-op when no repo is configured, the path is rejected by
        ``_local_path``, or the local file is missing. Errors are printed.
        """
        if not self.repo_id:
            return
        local_path = self._local_path(rel_path)
        if local_path is None or not local_path.exists():
            return
        try:
            self.api.upload_file(
                path_or_fileobj=str(local_path),
                path_in_repo=rel_path,
                repo_id=self.repo_id,
                repo_type=self.repo_type,
            )
        except Exception as e:  # best-effort: network/auth errors must not break callers
            print(f"[CloudStorage] Upload failed for {rel_path}: {e}")

    def delete_file(self, rel_path: str):
        """Delete *rel_path* locally, drop its search entry, then delete it remotely.

        Paths rejected by ``_local_path`` are ignored. The remote delete is
        skipped when no repo is configured; remote errors are printed.
        """
        local_path = self._local_path(rel_path)
        if local_path is None:
            return
        # missing_ok avoids the exists()/unlink() race of a separate check.
        local_path.unlink(missing_ok=True)
        self.remove_search_entry(rel_path)
        if not self.repo_id:
            return
        try:
            self.api.delete_file(
                path_in_repo=rel_path,
                repo_id=self.repo_id,
                repo_type=self.repo_type,
            )
        except Exception as e:  # best-effort remote cleanup
            print(f"[CloudStorage] Remote delete failed for {rel_path}: {e}")

    def sync_download(self):
        """Download the dataset snapshot into ``DATA_DIR``, then refresh state.

        After the download the day counter is re-seeded and the search index
        is extended with any newly present images. Failures are printed.
        """
        if not self.repo_id:
            return
        try:
            snapshot_download(
                repo_id=self.repo_id,
                repo_type=self.repo_type,
                local_dir=DATA_DIR,
                token=os.environ.get("HF_TOKEN", "")
            )
            self._init_counter()
            self.build_search_index_from_files()
        except Exception as e:
            print(f"[CloudStorage] Sync download failed: {e}")

    def _load_search_index(self):
        """Load ``INDEX_FILE`` into ``_search_index``.

        A missing file leaves the current index untouched; an unreadable,
        malformed or non-object file resets it to empty.
        """
        try:
            with open(INDEX_FILE, encoding="utf-8") as f:
                data = json.load(f)
        except FileNotFoundError:
            return
        except (OSError, ValueError) as e:
            print(f"[CloudStorage] Failed to load search index: {e}")
            data = {}
        with self._index_lock:
            self._search_index = data if isinstance(data, dict) else {}

    def _save_search_index(self):
        """Persist ``_search_index`` to ``INDEX_FILE`` (best-effort).

        Serialisation and the write happen under ``_index_lock`` so the dict
        cannot change mid-dump. Data goes to a sibling temp file that is then
        ``os.replace``-d over ``INDEX_FILE``, so a crash mid-write cannot leave
        a truncated index behind.
        """
        tmp_path = INDEX_FILE.with_name(INDEX_FILE.name + ".tmp")
        with self._index_lock:
            try:
                payload = json.dumps(self._search_index)
                tmp_path.write_text(payload, encoding="utf-8")
                os.replace(tmp_path, INDEX_FILE)
            except (OSError, TypeError, ValueError) as e:
                print(f"[CloudStorage] Failed to save search index: {e}")

    def set_search_text(self, rel_path: str, search_text: str):
        """Store lower-cased *search_text* for *rel_path* and persist the index."""
        with self._index_lock:
            self._search_index[rel_path] = search_text.lower()
            self._save_search_index()

    def get_search_text(self, rel_path: str) -> str:
        """Return the stored (lower-cased) search text for *rel_path*, or ``""``."""
        return self._search_index.get(rel_path, "")

    def search_images(self, query: str, date_filter: str = "") -> list:
        """Return ``list_images(date_filter)`` entries matching every query term.

        Terms are whitespace-separated and matched case-insensitively as
        substrings of the stored search text. A blank query returns all images.
        """
        all_images = self.list_images(date_filter)
        if not query.strip():
            return all_images
        terms = query.strip().lower().split()
        return [
            img for img in all_images
            if all(t in self._search_index.get(img["path"], "") for t in terms)
        ]

    def remove_search_entry(self, rel_path: str):
        """Drop the search entry for *rel_path* (if any) and persist the index."""
        with self._index_lock:
            self._search_index.pop(rel_path, None)
            self._save_search_index()

    def build_search_index_from_files(self):
        """Index the PNG text chunks of every image not yet in the search index.

        Each new entry is the lower-cased, space-joined values of the image's
        PNG text chunks. Existing entries are kept. Unreadable images are
        printed and skipped. The index is saved afterwards.
        """
        from PIL import Image as PILImage

        discovered = {}
        for date_dir in DATA_DIR.iterdir():
            if not date_dir.is_dir() or date_dir.name.startswith('.'):
                continue
            for f in date_dir.glob("*.png"):
                rel_path = f"{date_dir.name}/{f.name}"
                if rel_path in self._search_index:
                    continue
                try:
                    # The context manager closes the file handle PIL keeps open.
                    with PILImage.open(f) as img:
                        chunks = getattr(img, "text", {})
                        discovered[rel_path] = " ".join(chunks.values()).lower()
                except Exception as e:  # corrupt/partial files are skipped, not fatal
                    print(f"[CloudStorage] Could not index {rel_path}: {e}")
        with self._index_lock:
            for rel_path, text in discovered.items():
                # setdefault: keep text set concurrently via set_search_text.
                self._search_index.setdefault(rel_path, text)
            self._save_search_index()

    def list_images(self, date_filter: str = "") -> list:
        """List PNGs newest-first, optionally limited to dates containing *date_filter*.

        Each entry is ``{"path", "name", "date", "size", "mtime"}``; dot-prefixed
        and non-directory entries of ``DATA_DIR`` are skipped.
        """
        results = []
        for date_dir in sorted(DATA_DIR.iterdir(), reverse=True):
            if not date_dir.is_dir() or date_dir.name.startswith('.'):
                continue
            if date_filter and date_filter not in date_dir.name:
                continue
            for f in sorted(date_dir.glob("*.png"), reverse=True):
                try:
                    st = f.stat()  # one syscall for both size and mtime
                except FileNotFoundError:
                    continue  # removed between glob() and stat()
                results.append({
                    "path": f"{date_dir.name}/{f.name}",
                    "name": f.name,
                    "date": date_dir.name,
                    "size": st.st_size,
                    "mtime": st.st_mtime,
                })
        return results

    def get_image_b64(self, rel_path: str) -> Optional[str]:
        """Return the base64 text of ``DATA_DIR/<rel_path>``, or ``None`` if unavailable."""
        local_path = self._local_path(rel_path)
        if local_path is None or not local_path.exists():
            return None
        return base64.b64encode(local_path.read_bytes()).decode()

    def list_dates(self) -> list:
        """Return ``[{"date", "count"}]`` for date folders holding at least one PNG, newest first."""
        dates = []
        for d in sorted(DATA_DIR.iterdir(), reverse=True):
            if not d.is_dir() or d.name.startswith('.'):
                continue
            count = sum(1 for _ in d.glob("*.png"))
            if count:
                dates.append({"date": d.name, "count": count})
        return dates
# Module-level singleton. Constructing it creates DATA_DIR, seeds today's
# counter and loads the search index; it does not call sync_download().
cloud_storage: CloudStorage = CloudStorage()