# File: client/server/cloud_storage.py
# Hugging Face file-viewer header (page residue, not code):
#   uploader "P01yH3dr0n", commit "launch" (774fe36), size 8.41 kB.
"""
Cloud storage: save images to local remote_images/ and sync with HuggingFace dataset.
"""
import base64
import json
import os
import threading
from datetime import datetime
from pathlib import Path
from typing import Optional
from huggingface_hub import HfApi, snapshot_download
# Root folder of the local image store; images are written into one
# sub-folder per calendar day below it.
DATA_DIR = Path("remote_images")
# JSON file (kept inside the store) mapping relative image paths to the
# search text recorded for them.
INDEX_FILE = DATA_DIR.joinpath(".search_index.json")
class CloudStorage:
    """Local image store under ``DATA_DIR`` mirrored to a HuggingFace dataset.

    Images are saved as ``DATA_DIR/<YYYY-MM-DD>/<index>-<seed>.png`` where
    ``index`` is a per-day running number.  A small JSON search index
    (``INDEX_FILE``) maps each relative image path to lowercase search text.
    Remote operations are skipped when ``HF_DATASET_REPO`` is not configured.
    """

    def __init__(self):
        """Read configuration from the environment and prepare local state.

        ``HF_DATASET_REPO`` selects the dataset repo (empty disables remote
        calls); ``HF_TOKEN`` is the access token handed to the hub client.
        """
        self.repo_id = os.environ.get("HF_DATASET_REPO", "")
        self.repo_type = "dataset"
        # An unset/empty HF_TOKEN is passed as None rather than "" so the hub
        # client is never handed an empty string as if it were a real token.
        self.api = HfApi(token=os.environ.get("HF_TOKEN") or None)
        self._counter_date = ""  # date string the counter applies to
        self._counter_value = 0  # highest index handed out for that date
        self._lock = threading.Lock()  # guards the two counter fields
        self._search_index = {}  # rel_path -> lowercase search text
        DATA_DIR.mkdir(parents=True, exist_ok=True)
        self._init_counter()
        self._load_search_index()

    # ------------------------------------------------------------------
    # Path validation
    # ------------------------------------------------------------------
    @staticmethod
    def _is_safe_rel_path(rel_path: str) -> bool:
        """Return True if *rel_path* is a plain relative path inside the store.

        The check is purely lexical: the path must be non-empty, must not be
        anchored (absolute path or drive) and must not contain ``..`` parts.
        """
        if not rel_path:
            return False
        candidate = Path(rel_path)
        if candidate.anchor or not candidate.parts:
            return False
        return ".." not in candidate.parts

    def _local_path(self, rel_path: str) -> Optional[Path]:
        """Map *rel_path* to a path under ``DATA_DIR``; None if it is unsafe."""
        if not self._is_safe_rel_path(rel_path):
            return None
        return DATA_DIR / rel_path

    # ------------------------------------------------------------------
    # Per-day index counter
    # ------------------------------------------------------------------
    def _init_counter(self):
        """(Re)initialize the counter for today from the files on disk.

        Runs under the counter lock.  When the counter already refers to
        today it is never moved backwards, so an index that was already
        handed out is not issued a second time.
        """
        today = datetime.now().strftime("%Y-%m-%d")
        with self._lock:
            highest = self._count_files_for_date(today)
            if today == self._counter_date:
                highest = max(highest, self._counter_value)
            self._counter_date = today
            self._counter_value = highest

    def _count_files_for_date(self, date_str: str) -> int:
        """Return the highest numeric index among the PNGs stored for a date.

        The index is the part of the file stem before the first ``-``.
        Files whose stem does not start with an integer are ignored.
        Returns 0 when the date folder is missing or holds no indexed PNG.
        """
        date_dir = DATA_DIR / date_str
        if not date_dir.exists():
            return 0
        max_idx = 0
        for png in date_dir.glob("*.png"):
            try:
                max_idx = max(max_idx, int(png.stem.split("-")[0]))
            except (ValueError, IndexError):
                continue
        return max_idx

    def _next_index(self, date_str: str) -> int:
        """Return the next free index for *date_str* (thread-safe)."""
        with self._lock:
            if date_str != self._counter_date:
                # Date changed (past midnight): restart from what is on disk.
                self._counter_date = date_str
                self._counter_value = self._count_files_for_date(date_str)
            self._counter_value += 1
            return self._counter_value

    # ------------------------------------------------------------------
    # Local files
    # ------------------------------------------------------------------
    def save_image(self, image_b64: str, seed: int) -> Optional[str]:
        """Decode *image_b64*, store it locally and return its relative path.

        Returns None when *image_b64* is empty.  Raises ``binascii.Error``
        (a ``ValueError``) when the payload is not valid base64.
        """
        if not image_b64:
            return None
        # Decode first so a malformed payload raises before a directory is
        # created or an index is consumed.
        image_bytes = base64.b64decode(image_b64)
        date_str = datetime.now().strftime("%Y-%m-%d")
        date_dir = DATA_DIR / date_str
        date_dir.mkdir(parents=True, exist_ok=True)
        filename = f"{self._next_index(date_str):05d}-{seed}.png"
        (date_dir / filename).write_bytes(image_bytes)
        return f"{date_str}/{filename}"

    def get_image_b64(self, rel_path: str) -> Optional[str]:
        """Return the base64 text of a stored image, or None if unavailable.

        None is returned for unsafe paths (see ``_is_safe_rel_path``) and for
        paths that do not name an existing regular file.
        """
        local_path = self._local_path(rel_path)
        if local_path is None or not local_path.is_file():
            return None
        return base64.b64encode(local_path.read_bytes()).decode()

    def list_images(self, date_filter: str = "") -> list:
        """List stored PNGs, newest date folder and highest file name first.

        *date_filter*, when non-empty, keeps only date folders whose name
        contains it as a substring.  Each entry is a dict with the keys
        ``path``, ``name``, ``date``, ``size`` and ``mtime``.
        """
        results = []
        for date_dir in sorted(DATA_DIR.iterdir(), reverse=True):
            if not date_dir.is_dir() or date_dir.name.startswith("."):
                continue
            if date_filter and date_filter not in date_dir.name:
                continue
            for png in sorted(date_dir.glob("*.png"), reverse=True):
                info = png.stat()  # one stat call serves both fields
                results.append({
                    "path": f"{date_dir.name}/{png.name}",
                    "name": png.name,
                    "date": date_dir.name,
                    "size": info.st_size,
                    "mtime": info.st_mtime,
                })
        return results

    def list_dates(self) -> list:
        """List date folders that hold at least one PNG, newest first.

        Each entry is ``{"date": <folder name>, "count": <number of PNGs>}``.
        """
        dates = []
        for date_dir in sorted(DATA_DIR.iterdir(), reverse=True):
            if not date_dir.is_dir() or date_dir.name.startswith("."):
                continue
            count = sum(1 for _ in date_dir.glob("*.png"))
            if count > 0:
                dates.append({"date": date_dir.name, "count": count})
        return dates

    # ------------------------------------------------------------------
    # HuggingFace synchronisation
    # ------------------------------------------------------------------
    def upload_file(self, rel_path: str):
        """Upload one stored file to the configured dataset repo.

        Does nothing when no repo is configured, when *rel_path* is unsafe or
        when the local file is missing.  Upload errors are printed, not
        raised.
        """
        if not self.repo_id:
            return
        local_path = self._local_path(rel_path)
        if local_path is None or not local_path.exists():
            return
        try:
            self.api.upload_file(
                path_or_fileobj=str(local_path),
                path_in_repo=rel_path,
                repo_id=self.repo_id,
                repo_type=self.repo_type,
            )
        except Exception as e:
            print(f"[CloudStorage] Upload failed for {rel_path}: {e}")

    def delete_file(self, rel_path: str):
        """Delete a file locally, drop its search entry, then delete remotely.

        Unsafe paths (absolute or containing ``..``) are rejected so nothing
        outside ``DATA_DIR`` can be removed.  Remote errors are printed, not
        raised.
        """
        local_path = self._local_path(rel_path)
        if local_path is None:
            print(f"[CloudStorage] Refusing to delete unsafe path: {rel_path!r}")
            return
        # Delete local (a file that is already gone is not an error).
        try:
            local_path.unlink()
        except FileNotFoundError:
            pass
        # Remove search index entry
        self.remove_search_entry(rel_path)
        # Delete remote
        if not self.repo_id:
            return
        try:
            self.api.delete_file(
                path_in_repo=rel_path,
                repo_id=self.repo_id,
                repo_type=self.repo_type,
            )
        except Exception as e:
            print(f"[CloudStorage] Remote delete failed for {rel_path}: {e}")

    def sync_download(self):
        """Download the dataset snapshot into ``DATA_DIR`` and refresh state.

        After a successful download the day counter is re-derived from disk
        and missing search-index entries are built.  Errors are printed, not
        raised.  Does nothing when no repo is configured.
        """
        if not self.repo_id:
            return
        try:
            snapshot_download(
                repo_id=self.repo_id,
                repo_type=self.repo_type,
                local_dir=DATA_DIR,
                # Same rule as in __init__: empty/unset token becomes None.
                token=os.environ.get("HF_TOKEN") or None,
            )
            # Reinitialize counter from synced files
            self._init_counter()
            self.build_search_index_from_files()
        except Exception as e:
            print(f"[CloudStorage] Sync download failed: {e}")

    # ------------------------------------------------------------------
    # Search index
    # ------------------------------------------------------------------
    def _load_search_index(self):
        """Load the search index from ``INDEX_FILE`` if it exists.

        An unreadable file, invalid JSON or a non-object top level resets the
        index to empty; non-string values are dropped.
        """
        if not INDEX_FILE.exists():
            return
        try:
            with open(INDEX_FILE, encoding="utf-8") as f:
                loaded = json.load(f)
        except (OSError, ValueError) as e:
            print(f"[CloudStorage] Failed to load search index: {e}")
            loaded = None
        if isinstance(loaded, dict):
            # Searching does substring tests, so only string values are usable.
            self._search_index = {
                key: value for key, value in loaded.items()
                if isinstance(value, str)
            }
        else:
            self._search_index = {}

    def _save_search_index(self):
        """Write the search index to ``INDEX_FILE``; errors are printed."""
        try:
            # Serialize before opening the file so a serialization failure
            # cannot leave a truncated index on disk.
            payload = json.dumps(self._search_index)
            INDEX_FILE.write_text(payload, encoding="utf-8")
        except Exception as e:
            print(f"[CloudStorage] Failed to save search index: {e}")

    def set_search_text(self, rel_path: str, search_text: str):
        """Store (lowercased) search text for an image path and persist it."""
        self._search_index[rel_path] = search_text.lower()
        self._save_search_index()

    def get_search_text(self, rel_path: str) -> str:
        """Return the stored search text for *rel_path*, or ``""`` if none."""
        return self._search_index.get(rel_path, "")

    def remove_search_entry(self, rel_path: str):
        """Remove the search index entry for a path (if any) and persist."""
        self._search_index.pop(rel_path, None)
        self._save_search_index()

    def search_images(self, query: str, date_filter: str = "") -> list:
        """Return the images whose search text contains every query term.

        The query is lowercased and split on whitespace; a blank query
        returns everything ``list_images(date_filter)`` returns.
        """
        all_images = self.list_images(date_filter)
        terms = query.strip().lower().split()
        if not terms:
            return all_images
        return [
            img for img in all_images
            if all(term in self._search_index.get(img["path"], "")
                   for term in terms)
        ]

    def build_search_index_from_files(self):
        """Add index entries for PNGs that have none, using their text chunks.

        Existing entries are left untouched.  Files that cannot be read are
        reported and skipped.  The index is saved once at the end.
        """
        from PIL import Image as PILImage

        for date_dir in DATA_DIR.iterdir():
            if not date_dir.is_dir() or date_dir.name.startswith("."):
                continue
            for png in date_dir.glob("*.png"):
                rel_path = f"{date_dir.name}/{png.name}"
                if rel_path in self._search_index:
                    continue
                try:
                    # The context manager closes the underlying file handle.
                    with PILImage.open(png) as img:
                        parts = (
                            list(img.text.values())
                            if hasattr(img, "text") else []
                        )
                    self._search_index[rel_path] = " ".join(parts).lower()
                except Exception as e:
                    # Best effort: one unreadable file must not stop indexing.
                    print(f"[CloudStorage] Could not index {rel_path}: {e}")
        self._save_search_index()
# Module-level shared instance. Constructing it at import time creates
# DATA_DIR if needed, derives today's counter from disk and loads the
# search index.
cloud_storage = CloudStorage()