""" Aesthetic scoring — simple inference interface. Usage: from predict import AestheticScorer scorer = AestheticScorer.from_pretrained("somepago/aes26") score = scorer.rate("photo.jpg") # float 1-10 scores = scorer.rate(["a.jpg", "b.jpg"]) # list of floats Or with a local checkpoint: scorer = AestheticScorer.from_local("checkpoints/.../best.pt") """ from __future__ import annotations import sys from pathlib import Path from typing import Union import torch import torch.nn.functional as F from PIL import Image # --------------------------------------------------------------------------- # Allow running from repo root or after `pip install` via HF snapshot # --------------------------------------------------------------------------- _HERE = Path(__file__).parent if str(_HERE) not in sys.path: sys.path.insert(0, str(_HERE)) from naflex import preprocess_image, naflex_collate from model import AestheticModel class AestheticScorer: """Scores images on a 1-10 aesthetic scale.""" def __init__(self, model: AestheticModel, device: torch.device): self.model = model self.device = device # ------------------------------------------------------------------ # Constructors # ------------------------------------------------------------------ @classmethod def from_pretrained( cls, repo_id: str = "somepago/aes26", filename: str = "best.pt", device: str | None = None, ) -> "AestheticScorer": """Download weights from Hugging Face Hub and load model.""" from huggingface_hub import hf_hub_download ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename) return cls.from_local(ckpt_path, device=device) @classmethod def from_local( cls, ckpt_path: str, device: str | None = None, ) -> "AestheticScorer": """Load model from a local checkpoint path.""" if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" dev = torch.device(device) ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False) config = ckpt["config"] # Support checkpoints that saved EMA 
weights under ema_state_dict state_key = "ema_state_dict" if "ema_state_dict" in ckpt else "model_state_dict" model = AestheticModel(config) model.load_state_dict(ckpt[state_key]) model.eval().to(dev) return cls(model, dev) # ------------------------------------------------------------------ # Inference # ------------------------------------------------------------------ @torch.inference_mode() def rate( self, images: Union[str, Path, Image.Image, list], batch_size: int = 32, ) -> Union[float, list[float]]: """Score one or more images. Parameters ---------- images : path, PIL Image, or list of either batch_size : how many images to process at once Returns ------- float if a single image was passed, list[float] for a list """ single = not isinstance(images, list) if single: images = [images] scores: list[float] = [] for i in range(0, len(images), batch_size): batch_imgs = images[i : i + batch_size] items = [] for img in batch_imgs: if not isinstance(img, Image.Image): img = Image.open(img).convert("RGB") else: img = img.convert("RGB") patches, grid = preprocess_image(img) items.append({"patches": patches, "grid": grid, "score": 0.0}) collated = naflex_collate(items) with torch.amp.autocast("cuda", dtype=torch.bfloat16, enabled=self.device.type == "cuda"): logits = self.model( collated["patches"].to(self.device), collated["spatial_shapes"].to(self.device), collated["attention_mask"].to(self.device), ) batch_scores = self.model.logits_to_score(logits).cpu().tolist() if isinstance(batch_scores, float): batch_scores = [batch_scores] scores.extend(batch_scores) return round(scores[0], 2) if single else [round(s, 2) for s in scores] # --------------------------------------------------------------------------- # CLI: python predict.py image1.jpg image2.jpg ... 
# ---------------------------------------------------------------------------


def main(argv: list[str] | None = None) -> None:
    """Command-line entry point: score images and print one line per image.

    Each output line is ``<score to 2 decimals> <path>``, in input order.

    Parameters
    ----------
    argv : argument list to parse; None means ``sys.argv[1:]``.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Score images aesthetically (1-10)")
    parser.add_argument("images", nargs="+", help="Image paths to score")
    parser.add_argument("--repo", default="somepago/aes26", help="HF repo or local checkpoint")
    parser.add_argument("--device", default=None, help="cuda / cpu")
    parser.add_argument("--batch-size", type=int, default=32, help="Images per forward pass")
    args = parser.parse_args(argv)

    # An existing filesystem path is treated as a local checkpoint file;
    # anything else is assumed to be a Hugging Face Hub repo id.
    if Path(args.repo).exists():
        scorer = AestheticScorer.from_local(args.repo, device=args.device)
    else:
        scorer = AestheticScorer.from_pretrained(args.repo, device=args.device)

    # args.images is always a list (nargs="+"), so rate() returns a list.
    scores = scorer.rate(args.images, batch_size=args.batch_size)
    for path, score in zip(args.images, scores):
        print(f"{score:.2f} {path}")


if __name__ == "__main__":
    main()