ImageDeduper / app /services /image_selector.py
basilbenny1002's picture
Upload 9 files
dd7bc5d verified
from __future__ import annotations

import logging
import os
import shutil
import time
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Dict, Any

import numpy as np
import torch
from PIL import Image
from torchvision import models, transforms

from app.core.config import settings
from app.repositories.embeddings import EmbeddingsRepository
@dataclass
class SelectionResult:
    """Outcome of one de-duplication run, as two lists of image paths."""

    # Paths of the images that were selected as the best of their group.
    kept: List[str]
    # Paths of the other group members that were deleted from disk.
    removed: List[str]
class ImageSelectorService:
    """Group visually similar images and keep one image per group.

    Images are embedded with a ResNet50 feature extractor and L2-normalised,
    so the dot product of two embeddings is their cosine similarity.
    Embeddings are stored in a per-user ``EmbeddingsRepository``.  When
    aesthetics scoring is enabled, the highest-scoring member of each group is
    the one that is kept.
    """

    _log = logging.getLogger(__name__)

    def __init__(self, db_repo: Optional[EmbeddingsRepository] = None, device: Optional[str] = None):
        """Load the feature extractor and initialise per-user progress state.

        Args:
            db_repo: Accepted for backward compatibility but ignored; the
                repository is user-scoped and created inside ``choose_best``
                to avoid accidental cross-user sharing.
            device: Torch device string. Defaults to ``"cuda"`` when available,
                otherwise ``"cpu"``.
        """
        self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))

        # ResNet50 with its last child module (the classifier) dropped.
        base = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        self.feature_extractor = torch.nn.Sequential(*list(base.children())[:-1]).to(self.device)
        self.feature_extractor.eval()

        self.transform = transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

        # The aesthetics model is loaded lazily so its dependencies are only
        # required when scoring is actually used.
        self._predictor = None
        self._processor = None

        # Progress tracking per user:
        # { user_id: { stage: int, percentage: int, eta_seconds: Optional[int], status: str, ... } }
        self._progress: Dict[str, Dict[str, Any]] = {}

    def _ensure_aesthetics(self):
        """Load the aesthetics predictor and its processor on first use."""
        if self._predictor is not None and self._processor is not None:
            return
        from transformers import CLIPProcessor
        from aesthetics_predictor import AestheticsPredictorV1

        model_id = "shunk031/aesthetics-predictor-v1-vit-large-patch14"
        self._predictor = AestheticsPredictorV1.from_pretrained(model_id).to(self.device)
        self._processor = CLIPProcessor.from_pretrained(model_id)

    def embed_image(self, image_path: Path) -> np.ndarray:
        """Return the L2-normalised float32 embedding of the image at *image_path*."""
        # Close the file handle deterministically: callers unlink these files
        # shortly afterwards, which an open handle can block on Windows.
        with Image.open(image_path) as handle:
            rgb = handle.convert("RGB")
        batch = self.transform(rgb).unsqueeze(0).to(self.device)
        with torch.no_grad():
            features = self.feature_extractor(batch)
        emb = features.squeeze().detach().cpu().numpy().astype(np.float32)
        norm = np.linalg.norm(emb)
        # The floor avoids division by zero for an all-zero feature vector.
        return emb / max(norm, 1e-8)

    def add_image(self, image_path: Path, repo: EmbeddingsRepository) -> None:
        """Embed *image_path* and upsert the raw embedding bytes into *repo*."""
        emb = self.embed_image(image_path)
        repo.upsert(str(image_path), emb.tobytes())

    def predict_aesthetic(self, image_path: Path) -> float:
        """Return the aesthetics predictor's score for the image at *image_path*."""
        self._ensure_aesthetics()
        with Image.open(image_path) as handle:
            rgb = handle.convert("RGB")
        inputs = self._processor(images=rgb, return_tensors="pt")
        inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()}
        with torch.no_grad():
            outputs = self._predictor(**inputs)
        return float(outputs.logits[0].item())

    def find_similar(self, query_image: Path, threshold: float, repo: EmbeddingsRepository) -> List[str]:
        """Return stored paths whose embedding is similar to *query_image*.

        An entry matches when the dot product of its stored float32 embedding
        with the query embedding is at least *threshold*.  Results keep the
        order in which ``repo.list_all()`` yields its ``(path, blob)`` pairs.
        """
        query = self.embed_image(query_image)
        return [
            path
            for path, blob in repo.list_all()
            if float(np.dot(query, np.frombuffer(blob, dtype=np.float32))) >= threshold
        ]

    def choose_best(
        self,
        user_id: str,
        input_dir: Path,
        output_dir: Path,
        similarity: float = 0.87,
        use_aesthetics: bool = True,
    ) -> SelectionResult:
        """De-duplicate the images in *input_dir*, writing keepers to *output_dir*.

        Stage 1 embeds every regular file in *input_dir* into a user-scoped
        repository.  Stage 2 walks the same file list, groups each file with
        the stored images whose similarity is at least *similarity*, copies the
        best member of the group to *output_dir* and deletes the group's files
        from disk.  Copies of each group's members are also written to
        ``input_dir/<n>`` for inspection, where ``n`` is the 1-based position
        of the file that started the group.

        This method is destructive: grouped files are removed from
        *input_dir*.  Progress is published per *user_id* and can be read with
        ``get_progress``.

        Args:
            user_id: Key for the progress entry and the user-scoped repository.
            input_dir: Directory whose regular files are processed.
            output_dir: Directory receiving the kept images (created if needed).
            similarity: Minimum dot-product similarity for two images to be
                grouped together.
            use_aesthetics: When true, the member with the highest aesthetics
                score is kept; otherwise the first member of the group is kept.

        Returns:
            A ``SelectionResult`` listing the kept and the removed paths.
        """
        input_dir = Path(input_dir)
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        # Snapshot the file list once; both stages iterate this same list.
        files = [input_dir / name for name in os.listdir(input_dir)]
        files = [fp for fp in files if fp.is_file()]
        total = len(files)

        kept: List[str] = []
        removed: List[str] = []

        # Create user-scoped embeddings repository
        repo = EmbeddingsRepository(settings.user_db_path(user_id))
        try:
            self._progress[user_id] = {
                "stage": 1,
                "percentage": 0,
                "eta_seconds": None,
                "status": "indexing",
                "total_stage1": total,
                "processed_stage1": 0,
                "total_stage2": total,
                "processed_stage2": 0,
            }

            # Stage 1: embeddings indexing
            stage_start = time.monotonic()
            for done, fp in enumerate(files, start=1):
                try:
                    self.add_image(fp, repo)
                except Exception:
                    # Best effort: files that cannot be embedded are skipped.
                    self._log.warning("Could not index %s; skipping", fp, exc_info=True)
                self._report_progress(user_id, 1, done, total, stage_start)

            # Stage 2: selection
            self._progress[user_id].update(
                {"stage": 2, "percentage": 0, "eta_seconds": None, "status": "selecting"}
            )
            stage_start = time.monotonic()
            for done, fp in enumerate(files, start=1):
                group = self._claim_group(fp, similarity, repo)
                # Only groups with members get an inspection directory, so
                # files consumed by an earlier group leave no empty folders.
                if group:
                    self._resolve_group(
                        group, input_dir / str(done), output_dir, use_aesthetics, kept, removed
                    )
                self._report_progress(user_id, 2, done, total, stage_start)

            # Completed
            self._progress[user_id].update(
                {"stage": 2, "percentage": 100, "eta_seconds": 0, "status": "completed"}
            )
        finally:
            # Always close the DB, even on failure, so the file can be deleted
            # on Windows.
            try:
                repo.close()
            except Exception:
                self._log.warning("Could not close embeddings repository", exc_info=True)

        return SelectionResult(kept=kept, removed=removed)

    def _report_progress(self, user_id: str, stage: int, processed: int, total: int, started: float) -> None:
        """Publish percentage and ETA for *stage* from a ``time.monotonic()`` start."""
        elapsed = max(time.monotonic() - started, 1e-6)
        rate = processed / elapsed
        remaining = max(total - processed, 0)
        self._progress[user_id].update(
            {
                "stage": stage,
                "percentage": int((processed / max(total, 1)) * 100),
                "eta_seconds": int(remaining / rate) if rate > 0 else None,
                f"processed_stage{stage}": processed,
            }
        )

    def _claim_group(self, fp: Path, similarity: float, repo: EmbeddingsRepository) -> List[str]:
        """Return the stored paths similar to *fp* and remove them from *repo*."""
        if not fp.exists():
            # Already consumed by an earlier group; nothing left to embed.
            return []
        try:
            group = self.find_similar(fp, threshold=similarity, repo=repo)
        except Exception:
            self._log.warning("Could not search for images similar to %s", fp, exc_info=True)
            return []
        if group:
            # Remove found images from the DB immediately to avoid regrouping
            # them in later iterations.
            try:
                repo.delete_many(group)
            except Exception:
                self._log.warning("Could not remove grouped entries from repository", exc_info=True)
        return group

    def _resolve_group(
        self,
        group: List[str],
        inspect_dir: Optional[Path],
        output_dir: Path,
        use_aesthetics: bool,
        kept: List[str],
        removed: List[str],
    ) -> None:
        """Keep the best image of *group* in *output_dir* and delete the group.

        Appends to *kept* and *removed* in place.  Nothing is deleted unless
        the keeper was copied to *output_dir* successfully.
        """
        try:
            inspect_dir.mkdir(exist_ok=True)
        except OSError:
            # E.g. a regular file already has that name; skip inspection copies.
            self._log.warning("Could not create inspection directory %s", inspect_dir, exc_info=True)
            inspect_dir = None

        best_score = -1e9
        best_path: Optional[str] = None
        for path in group:
            src = Path(path)
            try:
                score = self.predict_aesthetic(src) if use_aesthetics else 0.0
            except Exception:
                score = 0.0
            if score > best_score:
                best_score, best_path = score, path
            if inspect_dir is None:
                continue
            # copy to group folder for inspection
            try:
                dest = inspect_dir / src.name
                if not dest.exists():
                    shutil.copyfile(src, dest)
            except OSError:
                self._log.warning("Could not copy %s for inspection", src, exc_info=True)

        if not best_path:
            # No member could be ranked (e.g. every score was NaN); leave the
            # files alone rather than deleting a group with no keeper.
            return

        keeper = Path(best_path)
        try:
            shutil.copyfile(keeper, output_dir / keeper.name)
        except OSError:
            # Also covers output_dir == input_dir (SameFileError): without a
            # saved keeper, deleting the group would lose the images.
            self._log.warning(
                "Could not copy %s to %s; leaving its group in place", keeper, output_dir, exc_info=True
            )
            return
        kept.append(best_path)

        # The keeper now lives in output_dir, so drop it from the input.
        try:
            keeper.unlink()
        except OSError:
            self._log.warning("Could not delete %s", keeper, exc_info=True)

        # remove the rest
        for path in group:
            if path == best_path:
                continue
            try:
                Path(path).unlink()
            except OSError:
                self._log.warning("Could not delete %s", path, exc_info=True)
            else:
                removed.append(path)

    def get_progress(self, user_id: str) -> Dict[str, Any]:
        """Return a copy of *user_id*'s progress entry, or an idle placeholder."""
        idle = {"stage": 0, "percentage": 0, "eta_seconds": None, "status": "idle"}
        return dict(self._progress.get(user_id, idle))