Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import os | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import List, Optional, Dict, Any | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from torchvision import models, transforms | |
| from app.repositories.embeddings import EmbeddingsRepository | |
| from app.core.config import settings | |
@dataclass
class SelectionResult:
    """Outcome of a de-duplication run over an image directory.

    The ``@dataclass`` decorator is required: ``choose_best`` constructs this
    with keyword arguments (``SelectionResult(kept=..., removed=...)``), which
    raises ``TypeError`` without a generated ``__init__``.

    Attributes:
        kept: paths of the winning images copied to the output directory.
        removed: paths of near-duplicate images deleted from the input.
    """

    kept: List[str]
    removed: List[str]
class ImageSelectorService:
    """Selects the best image among groups of near-duplicate photos.

    Images are embedded with a headless ResNet50, grouped by cosine
    similarity, and optionally ranked with a CLIP-based aesthetics predictor.
    """

    def __init__(self, db_repo: Optional[EmbeddingsRepository] = None, device: Optional[str] = None):
        """Build the feature extractor and preprocessing pipeline.

        Args:
            db_repo: ignored; kept only for backward compatibility. A
                user-scoped repository is created inside ``choose_best`` so
                embeddings are never accidentally shared across users.
            device: explicit torch device string; when falsy, CUDA is used
                if available, otherwise CPU.
        """
        self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))

        # ResNet50 with its final classifier layer stripped -> pooled features.
        backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        trunk = list(backbone.children())[:-1]
        self.feature_extractor = torch.nn.Sequential(*trunk).to(self.device)
        self.feature_extractor.eval()

        # Standard ImageNet preprocessing (resize + normalize).
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        # The aesthetics predictor is heavy; load it lazily on first use.
        self._predictor = None
        self._processor = None

        # Per-user progress snapshots:
        # { user_id: { stage: int, percentage: int, eta_seconds: Optional[int], status: str, ... } }
        self._progress: Dict[str, Dict[str, Any]] = {}
| def _ensure_aesthetics(self): | |
| if self._predictor is None or self._processor is None: | |
| from transformers import CLIPProcessor | |
| from aesthetics_predictor import AestheticsPredictorV1 | |
| model_id = "shunk031/aesthetics-predictor-v1-vit-large-patch14" | |
| self._predictor = AestheticsPredictorV1.from_pretrained(model_id).to(self.device) | |
| self._processor = CLIPProcessor.from_pretrained(model_id) | |
| def embed_image(self, image_path: Path) -> np.ndarray: | |
| img = Image.open(image_path).convert("RGB") | |
| tensor = self.transform(img).unsqueeze(0).to(self.device) | |
| with torch.no_grad(): | |
| emb = self.feature_extractor(tensor).squeeze().detach().cpu().numpy().astype(np.float32) | |
| norm = np.linalg.norm(emb) | |
| return emb / max(norm, 1e-8) | |
| def add_image(self, image_path: Path, repo: EmbeddingsRepository) -> None: | |
| emb = self.embed_image(image_path) | |
| repo.upsert(str(image_path), emb.tobytes()) | |
| def predict_aesthetic(self, image_path: Path) -> float: | |
| self._ensure_aesthetics() | |
| img = Image.open(image_path).convert("RGB") | |
| inputs = self._processor(images=img, return_tensors="pt") | |
| inputs = {k: v.to(self.device) for k, v in inputs.items()} | |
| with torch.no_grad(): | |
| outputs = self._predictor(**inputs) | |
| return float(outputs.logits[0].item()) | |
| def find_similar(self, query_image: Path, threshold: float, repo: EmbeddingsRepository) -> List[str]: | |
| q = self.embed_image(query_image) | |
| entries = repo.list_all() | |
| similar = [] | |
| for path, emb_blob in entries: | |
| emb = np.frombuffer(emb_blob, dtype=np.float32) | |
| sim = float(np.dot(q, emb)) | |
| if sim >= threshold: | |
| similar.append(path) | |
| return similar | |
    def choose_best(self, user_id: str, input_dir: Path, output_dir: Path, similarity: float = 0.87, use_aesthetics: bool = True) -> SelectionResult:
        """Pick the best image out of every near-duplicate group in *input_dir*.

        Two stages, both publishing progress to ``self._progress[user_id]``:

        1. Index every file in *input_dir* into a per-user embeddings DB.
        2. For each file, fetch its similarity group, copy the highest
           aesthetics-scoring member to *output_dir*, and delete the rest
           (and the winner) from *input_dir*.

        Args:
            user_id: key for the per-user DB path and progress entry.
            input_dir: directory scanned (non-recursively) for image files.
            output_dir: destination for winning images; created if missing.
            similarity: inclusive cosine threshold for grouping duplicates.
            use_aesthetics: when False, every group member scores 0.0, so the
                first member of the group wins.

        Returns:
            SelectionResult listing kept and removed file paths.
        """
        input_dir = Path(input_dir)
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        # Prepare file list once; entries deleted mid-run are handled below.
        files = [input_dir / f for f in os.listdir(input_dir)]
        files = [fp for fp in files if fp.is_file()]
        total = len(files)
        # Create user-scoped embeddings repository
        repo = EmbeddingsRepository(settings.user_db_path(user_id))
        # Stage 1: embeddings indexing
        import time as _time
        stage1_start = _time.time()
        processed1 = 0
        self._progress[user_id] = {
            "stage": 1,
            "percentage": 0,
            "eta_seconds": None,
            "status": "indexing",
            "total_stage1": total,
            "processed_stage1": 0,
            "total_stage2": total,
            "processed_stage2": 0,
        }
        for fp in files:
            try:
                self.add_image(fp, repo)
            except Exception:
                # Best-effort: unreadable/corrupt files are skipped silently.
                pass
            processed1 += 1
            # ETA is a simple throughput extrapolation over this stage.
            elapsed = max(_time.time() - stage1_start, 1e-6)
            rate = processed1 / elapsed
            remaining = max(total - processed1, 0)
            eta = int(remaining / rate) if rate > 0 else None
            self._progress[user_id].update(
                {
                    "stage": 1,
                    "percentage": int((processed1 / max(total, 1)) * 100),
                    "eta_seconds": eta,
                    "processed_stage1": processed1,
                }
            )
        kept: List[str] = []
        removed: List[str] = []
        # Stage 2: selection
        stage2_start = _time.time()
        processed2 = 0
        self._progress[user_id].update({"stage": 2, "percentage": 0, "eta_seconds": None, "status": "selecting"})
        i = 0
        for fp in files:
            try:
                similar = self.find_similar(fp, threshold=similarity, repo=repo)
            except Exception:
                # fp may already have been deleted as part of an earlier group.
                similar = []
            # Remove found images from DB immediately to avoid regrouping in later iterations
            try:
                if similar:
                    repo.delete_many(similar)
            except Exception:
                pass
            i += 1
            best_score = -1e9
            best_path: Optional[str] = None
            # NOTE(review): this numbered group folder is created even when
            # `similar` is empty, leaving empty dirs inside input_dir.
            temp_dir = input_dir / str(i)
            temp_dir.mkdir(exist_ok=True)
            for path in similar:
                path_p = Path(path)
                try:
                    score = self.predict_aesthetic(path_p) if use_aesthetics else 0.0
                except Exception:
                    # Scoring failure demotes the image but keeps it in the group.
                    score = 0.0
                if score > best_score:
                    best_score = score
                    best_path = path
                # copy to group folder for inspection
                try:
                    dest = temp_dir / path_p.name
                    if not dest.exists():
                        dest.write_bytes(Path(path).read_bytes())
                except Exception:
                    pass
            if best_path:
                # copy best to output and delete from input
                try:
                    bp = Path(best_path)
                    (output_dir / bp.name).write_bytes(bp.read_bytes())
                    kept.append(best_path)
                    try:
                        bp.unlink()
                    except Exception:
                        pass
                except Exception:
                    pass
            # remove the rest
            for path in similar:
                if path != best_path:
                    try:
                        Path(path).unlink()
                        removed.append(path)
                    except Exception:
                        pass
            # Update progress for stage 2
            processed2 += 1
            elapsed2 = max(_time.time() - stage2_start, 1e-6)
            rate2 = processed2 / elapsed2
            remaining2 = max(total - processed2, 0)
            eta2 = int(remaining2 / rate2) if rate2 > 0 else None
            self._progress[user_id].update(
                {
                    "stage": 2,
                    "percentage": int((processed2 / max(total, 1)) * 100),
                    "eta_seconds": eta2,
                    "processed_stage2": processed2,
                }
            )
        # Completed
        self._progress[user_id].update({"stage": 2, "percentage": 100, "eta_seconds": 0, "status": "completed"})
        # Ensure DB is closed before returning so the file can be deleted on Windows
        try:
            repo.close()
        except Exception:
            pass
        return SelectionResult(kept=kept, removed=removed)
| def get_progress(self, user_id: str) -> Dict[str, Any]: | |
| return dict( | |
| self._progress.get( | |
| user_id, | |
| {"stage": 0, "percentage": 0, "eta_seconds": None, "status": "idle"}, | |
| ) | |
| ) | |