"""
Aesthetic scoring — simple inference interface.

Usage::

    from predict import AestheticScorer

    scorer = AestheticScorer.from_pretrained("somepago/aes26")
    score = scorer.rate("photo.jpg")            # float, 1-10
    scores = scorer.rate(["a.jpg", "b.jpg"])    # list of floats

Or with a local checkpoint::

    scorer = AestheticScorer.from_local("checkpoints/.../best.pt")
"""
from __future__ import annotations
import sys
from pathlib import Path
from typing import Union
import torch
import torch.nn.functional as F
from PIL import Image
# ---------------------------------------------------------------------------
# Make the sibling modules (naflex, model) importable whether this file is run
# from the repo root or loaded out of a Hugging Face snapshot directory.
# The path is intentionally NOT resolve()-d: HF snapshot files are symlinks,
# and resolving would point at the blob store instead of the snapshot folder.
# ---------------------------------------------------------------------------
_HERE = Path(__file__).parent
_here_entry = str(_HERE)
if _here_entry not in sys.path:
    # Prepend so these local modules shadow any same-named installed ones.
    sys.path[:0] = [_here_entry]
del _here_entry
from naflex import preprocess_image, naflex_collate
from model import AestheticModel
class AestheticScorer:
    """Score images on a 1-10 aesthetic scale.

    Wraps an ``AestheticModel`` together with the device it runs on. Build an
    instance with :meth:`from_pretrained` (Hugging Face Hub) or
    :meth:`from_local` (checkpoint on disk), then call :meth:`rate`.
    """

    def __init__(self, model: AestheticModel, device: torch.device):
        # The classmethod constructors move the model to `device` and call
        # eval() before getting here; direct callers should do the same.
        self.model = model
        self.device = device

    # ------------------------------------------------------------------
    # Constructors
    # ------------------------------------------------------------------
    @classmethod
    def from_pretrained(
        cls,
        repo_id: str = "somepago/aes26",
        filename: str = "best.pt",
        device: str | None = None,
    ) -> "AestheticScorer":
        """Download ``filename`` from the Hub repo ``repo_id`` and load it.

        The download goes through ``huggingface_hub.hf_hub_download``, then
        the file is handed to :meth:`from_local`.
        """
        # Imported lazily so huggingface_hub is only required for Hub loading.
        from huggingface_hub import hf_hub_download

        ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
        return cls.from_local(ckpt_path, device=device)

    @classmethod
    def from_local(
        cls,
        ckpt_path: str | Path,
        device: str | None = None,
    ) -> "AestheticScorer":
        """Load the model from a checkpoint file on disk.

        Parameters
        ----------
        ckpt_path : checkpoint readable by ``torch.load`` holding a ``"config"``
            entry plus ``"ema_state_dict"`` or ``"model_state_dict"``
        device : torch device string; defaults to ``"cuda"`` when available,
            otherwise ``"cpu"``
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        dev = torch.device(device)
        # SECURITY: weights_only=False unpickles arbitrary Python objects, which
        # can execute code. Only load checkpoints from sources you trust.
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        config = ckpt["config"]
        # Prefer EMA weights when the checkpoint saved them.
        state_key = "ema_state_dict" if "ema_state_dict" in ckpt else "model_state_dict"
        model = AestheticModel(config)
        model.load_state_dict(ckpt[state_key])
        model.eval().to(dev)
        return cls(model, dev)

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------
    @torch.inference_mode()
    def rate(
        self,
        images: Union[str, Path, Image.Image, list, tuple],
        batch_size: int = 32,
    ) -> Union[float, list[float]]:
        """Score one or more images.

        Parameters
        ----------
        images : path, PIL Image, or a list/tuple of either
        batch_size : how many images to process per forward pass (>= 1)

        Returns
        -------
        float if a single image was passed, list[float] for a list/tuple
        (empty input gives ``[]``). Scores are rounded to 2 decimals.

        Raises
        ------
        ValueError
            If ``batch_size`` is less than 1.
        """
        # A non-positive step would make range() raise (0) or silently yield
        # nothing (negative), so reject it explicitly.
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        single = not isinstance(images, (list, tuple))
        pending = [images] if single else list(images)

        scores: list[float] = []
        for start in range(0, len(pending), batch_size):
            scores.extend(self._score_batch(pending[start : start + batch_size]))

        if single:
            return round(scores[0], 2)
        return [round(s, 2) for s in scores]

    def _score_batch(self, batch: list) -> list[float]:
        """Run one forward pass over ``batch``; return raw (unrounded) scores."""
        records = []
        for img in batch:
            patches, grid = preprocess_image(self._load_rgb(img))
            # "score" is a placeholder label — presumably naflex_collate expects
            # training-style records; TODO confirm it is unused at inference.
            records.append({"patches": patches, "grid": grid, "score": 0.0})
        collated = naflex_collate(records)

        # bf16 autocast only on CUDA; on CPU autocast is disabled entirely.
        with torch.amp.autocast("cuda", dtype=torch.bfloat16, enabled=self.device.type == "cuda"):
            logits = self.model(
                collated["patches"].to(self.device),
                collated["spatial_shapes"].to(self.device),
                collated["attention_mask"].to(self.device),
            )
        # NOTE(review): under CUDA autocast `logits` may be bfloat16, which
        # limits score precision — consider casting to float first; confirm.
        batch_scores = self.model.logits_to_score(logits).cpu().tolist()
        # tolist() on a 0-d tensor returns a bare float rather than a list.
        if isinstance(batch_scores, float):
            batch_scores = [batch_scores]
        return batch_scores

    @staticmethod
    def _load_rgb(img: str | Path | Image.Image) -> Image.Image:
        """Return ``img`` as an RGB PIL image, opening it first if it is a path."""
        if isinstance(img, Image.Image):
            return img.convert("RGB")
        # The context manager closes the file handle. Pillow only auto-closes
        # single-frame images after load(); multi-frame files (GIF, TIFF) would
        # otherwise stay open. convert() loads the pixels, so the returned
        # image does not depend on the closed file.
        with Image.open(img) as opened:
            return opened.convert("RGB")
# ---------------------------------------------------------------------------
# CLI: python predict.py image1.jpg image2.jpg ...
# ---------------------------------------------------------------------------
def _build_arg_parser():
    """Build the command-line parser for scoring images."""
    # Imported lazily so importing this module as a library stays cheap.
    import argparse

    parser = argparse.ArgumentParser(description="Score images aesthetically (1-10)")
    parser.add_argument("images", nargs="+", help="Image paths to score")
    parser.add_argument("--repo", default="somepago/aes26", help="HF repo or local checkpoint")
    parser.add_argument("--device", default=None, help="cuda / cpu")
    parser.add_argument(
        "--batch-size", type=int, default=32, help="Images per forward pass (>= 1)"
    )
    return parser


def main(argv: list[str] | None = None) -> None:
    """Score each image given on the command line and print ``score path`` lines.

    ``argv`` defaults to ``sys.argv[1:]`` (argparse behaviour).
    """
    parser = _build_arg_parser()
    args = parser.parse_args(argv)
    if args.batch_size < 1:
        parser.error("--batch-size must be >= 1")

    # Only an existing *file* is treated as a local checkpoint. A directory
    # that happens to share the repo id's name falls through to the Hub
    # instead of being fed to torch.load.
    if Path(args.repo).is_file():
        scorer = AestheticScorer.from_local(args.repo, device=args.device)
    else:
        scorer = AestheticScorer.from_pretrained(args.repo, device=args.device)

    # args.images is always a list (nargs="+"), so rate() returns a list here.
    scores = scorer.rate(args.images, batch_size=args.batch_size)
    for path, score in zip(args.images, scores):
        print(f"{score:.2f} {path}")


if __name__ == "__main__":
    main()