| """ |
| Aesthetic scoring — simple inference interface. |
| |
| Usage: |
| from predict import AestheticScorer |
| |
| scorer = AestheticScorer.from_pretrained("somepago/aes26") |
| score = scorer.rate("photo.jpg") # float 1-10 |
| scores = scorer.rate(["a.jpg", "b.jpg"]) # list of floats |
| |
| Or with a local checkpoint: |
| scorer = AestheticScorer.from_local("checkpoints/.../best.pt") |
| """ |
|
|
| from __future__ import annotations |
|
|
| import sys |
| from pathlib import Path |
| from typing import Union |
|
|
| import torch |
| import torch.nn.functional as F |
| from PIL import Image |
|
|
| |
| |
| |
_HERE = Path(__file__).parent

# Prepend this file's directory to the import path (only if it is not already
# listed) so the sibling modules imported below, ``naflex`` and ``model``, can
# be found regardless of the caller's working directory.
if str(_HERE) not in sys.path:
    sys.path[:0] = [str(_HERE)]
|
|
| from naflex import preprocess_image, naflex_collate |
| from model import AestheticModel |
|
|
|
|
class AestheticScorer:
    """Scores images on a 1-10 aesthetic scale.

    Holds an ``AestheticModel`` and the device its inputs are sent to. Build
    one with :meth:`from_pretrained` (Hugging Face Hub) or :meth:`from_local`
    (checkpoint on disk), then call :meth:`rate`.
    """

    def __init__(self, model: AestheticModel, device: torch.device):
        """Wrap an already-loaded model.

        The model is stored as-is. ``from_local`` switches it to eval mode and
        moves it to ``device`` before calling this constructor; callers that
        construct directly are responsible for doing the same.
        """
        self.model = model
        self.device = device

    @classmethod
    def from_pretrained(
        cls,
        repo_id: str = "somepago/aes26",
        filename: str = "best.pt",
        device: str | None = None,
    ) -> "AestheticScorer":
        """Download weights from Hugging Face Hub and load model.

        Parameters
        ----------
        repo_id : Hub repository to download from.
        filename : checkpoint file inside the repository.
        device : torch device string (e.g. ``"cuda"``, ``"cpu"``); ``None``
            picks CUDA when available, otherwise CPU.
        """
        # Imported lazily so huggingface_hub is only needed for Hub loading.
        from huggingface_hub import hf_hub_download

        ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
        return cls.from_local(ckpt_path, device=device)

    @classmethod
    def from_local(
        cls,
        ckpt_path: str,
        device: str | None = None,
    ) -> "AestheticScorer":
        """Load model from a local checkpoint path.

        The checkpoint is expected to be a mapping with a ``"config"`` entry
        (passed to ``AestheticModel``) and either ``"ema_state_dict"`` or
        ``"model_state_dict"``; the EMA weights are used when present.

        Parameters
        ----------
        ckpt_path : path to the checkpoint file.
        device : torch device string; ``None`` picks CUDA when available,
            otherwise CPU.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        dev = torch.device(device)

        # SECURITY: weights_only=False unpickles arbitrary Python objects, and
        # unpickling can execute code. Presumably required because "config" is
        # not a plain tensor container (TODO confirm). Only load checkpoints
        # from sources you trust -- this includes files fetched from the Hub.
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        config = ckpt["config"]

        # Prefer the EMA weights when the checkpoint carries them.
        state_key = "ema_state_dict" if "ema_state_dict" in ckpt else "model_state_dict"

        model = AestheticModel(config)
        model.load_state_dict(ckpt[state_key])
        model.eval().to(dev)

        return cls(model, dev)

    @staticmethod
    def _load_rgb(img) -> Image.Image:
        """Return *img* as an RGB PIL image, opening it first if it is a path."""
        if isinstance(img, Image.Image):
            return img.convert("RGB")
        # Close the file deterministically instead of relying on garbage
        # collection; convert() fully loads the pixels into a new image first.
        with Image.open(img) as opened:
            return opened.convert("RGB")

    @torch.inference_mode()
    def rate(
        self,
        images: Union[str, Path, Image.Image, list, tuple],
        batch_size: int = 32,
    ) -> Union[float, list[float]]:
        """Score one or more images.

        Parameters
        ----------
        images : path, PIL Image, or a list/tuple of either
        batch_size : how many images to process at once (must be >= 1)

        Returns
        -------
        float if a single image was passed, list[float] (in input order) for a
        list/tuple. Scores are rounded to 2 decimal places. An empty list or
        tuple yields an empty list.

        Raises
        ------
        ValueError
            If ``batch_size`` is less than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        single = not isinstance(images, (list, tuple))
        pending = [images] if single else list(images)

        scores: list[float] = []
        for start in range(0, len(pending), batch_size):
            items = []
            for img in pending[start : start + batch_size]:
                patches, grid = preprocess_image(self._load_rgb(img))
                # "score" is a dummy label; presumably naflex_collate expects the
                # training-time key to be present (TODO confirm).
                items.append({"patches": patches, "grid": grid, "score": 0.0})

            collated = naflex_collate(items)
            # bf16 autocast is enabled only on CUDA devices; on CPU it is off.
            with torch.amp.autocast("cuda", dtype=torch.bfloat16, enabled=self.device.type == "cuda"):
                logits = self.model(
                    collated["patches"].to(self.device),
                    collated["spatial_shapes"].to(self.device),
                    collated["attention_mask"].to(self.device),
                )
            batch_scores = self.model.logits_to_score(logits).cpu().tolist()
            # tolist() on a 0-d tensor yields a bare float rather than a list.
            if isinstance(batch_scores, float):
                batch_scores = [batch_scores]
            scores.extend(batch_scores)

        return round(scores[0], 2) if single else [round(s, 2) for s in scores]
|
|
|
|
| |
| |
| |
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Score images aesthetically (1-10)")
    parser.add_argument("images", nargs="+", help="Image paths to score")
    parser.add_argument(
        "--repo",
        default="somepago/aes26",
        help="HF repo id, local checkpoint file, or local directory containing --filename",
    )
    parser.add_argument(
        "--filename",
        default="best.pt",
        help="checkpoint file name inside the HF repo or local directory",
    )
    parser.add_argument("--batch-size", type=int, default=32, help="images per forward pass")
    parser.add_argument("--device", default=None, help="cuda / cpu")
    args = parser.parse_args()

    repo_path = Path(args.repo)
    if repo_path.is_dir():
        # A local directory: load the named checkpoint file from inside it.
        scorer = AestheticScorer.from_local(str(repo_path / args.filename), device=args.device)
    elif repo_path.exists():
        scorer = AestheticScorer.from_local(args.repo, device=args.device)
    else:
        scorer = AestheticScorer.from_pretrained(
            args.repo, filename=args.filename, device=args.device
        )

    # args.images is always a list (nargs="+"), so rate() returns a list.
    scores = scorer.rate(args.images, batch_size=args.batch_size)

    for path, score in zip(args.images, scores):
        print(f"{score:.2f} {path}")
|
|