File size: 5,442 Bytes
9614331
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
"""
Aesthetic scoring — simple inference interface.

Usage:
    from predict import AestheticScorer

    scorer = AestheticScorer.from_pretrained("somepago/aes26")
    score  = scorer.rate("photo.jpg")          # float 1-10
    scores = scorer.rate(["a.jpg", "b.jpg"])   # list of floats

Or with a local checkpoint:
    scorer = AestheticScorer.from_local("checkpoints/.../best.pt")
"""

from __future__ import annotations

import sys
from pathlib import Path
from typing import Union

import torch
import torch.nn.functional as F
from PIL import Image

# ---------------------------------------------------------------------------
# Allow running from repo root or after `pip install` via HF snapshot
# ---------------------------------------------------------------------------
# Directory that contains this file.
_HERE = Path(__file__).parent
# Prepend it to the import search path exactly once, so the `naflex` and
# `model` imports that follow are looked up next to this file first.
if sys.path.count(str(_HERE)) == 0:
    sys.path.insert(0, str(_HERE))

from naflex import preprocess_image, naflex_collate
from model import AestheticModel


class AestheticScorer:
    """Scores images on a 1-10 aesthetic scale.

    Holds an ``AestheticModel`` together with the device it was moved to.
    Build an instance with :meth:`from_pretrained` (Hugging Face Hub download)
    or :meth:`from_local` (checkpoint file on disk), then call :meth:`rate`.
    """

    def __init__(self, model: AestheticModel, device: torch.device):
        """Store the already-loaded *model* and the *device* inputs are sent to."""
        self.model = model
        self.device = device

    # ------------------------------------------------------------------
    # Constructors
    # ------------------------------------------------------------------

    @classmethod
    def from_pretrained(
        cls,
        repo_id: str = "somepago/aes26",
        filename: str = "best.pt",
        device: str | None = None,
    ) -> "AestheticScorer":
        """Download weights from Hugging Face Hub and load model.

        Parameters
        ----------
        repo_id : Hub repository to fetch the checkpoint from
        filename : name of the checkpoint file inside that repository
        device : forwarded to :meth:`from_local` (``None`` = auto-select)

        Returns
        -------
        A scorer built from the downloaded checkpoint.
        """
        # Imported lazily so that local-checkpoint use does not require
        # huggingface_hub to be installed.
        from huggingface_hub import hf_hub_download

        ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
        return cls.from_local(ckpt_path, device=device)

    @classmethod
    def from_local(
        cls,
        ckpt_path: str,
        device: str | None = None,
    ) -> "AestheticScorer":
        """Load model from a local checkpoint path.

        Parameters
        ----------
        ckpt_path : path to a checkpoint readable by ``torch.load``; it must
            contain a ``"config"`` entry and either ``"ema_state_dict"`` or
            ``"model_state_dict"``
        device : device string; when ``None``, ``"cuda"`` is used if
            available, otherwise ``"cpu"``

        Returns
        -------
        A scorer whose model is in eval mode on the selected device.

        Raises
        ------
        KeyError if the checkpoint lacks ``"config"`` or both state-dict keys.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        dev = torch.device(device)

        # SECURITY: weights_only=False makes torch.load run the full pickle
        # machinery, which can execute code embedded in the file. Only load
        # checkpoints from sources you trust (this includes Hub downloads).
        # NOTE(review): weights_only=True would be safer, but whether
        # ckpt["config"] can be deserialized that way is not visible here.
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        config = ckpt["config"]

        # Support checkpoints that saved EMA weights under ema_state_dict;
        # when present they take precedence over model_state_dict.
        state_key = "ema_state_dict" if "ema_state_dict" in ckpt else "model_state_dict"

        model = AestheticModel(config)
        model.load_state_dict(ckpt[state_key])
        model.eval().to(dev)

        return cls(model, dev)

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------

    @staticmethod
    def _load_rgb(img) -> Image.Image:
        """Return *img* as an RGB PIL image, opening it first if it is a path."""
        if isinstance(img, Image.Image):
            # Caller-owned image: convert to a new image, never close theirs.
            return img.convert("RGB")
        # Image.open keeps the underlying file open; the context manager
        # releases it once convert() has produced an independent RGB copy.
        with Image.open(img) as opened:
            return opened.convert("RGB")

    @torch.inference_mode()
    def rate(
        self,
        images: Union[str, Path, Image.Image, list, tuple],
        batch_size: int = 32,
    ) -> Union[float, list[float]]:
        """Score one or more images.

        Parameters
        ----------
        images : path, PIL Image, or list/tuple of either
        batch_size : how many images to process at once (must be >= 1)

        Returns
        -------
        float if a single image was passed, list[float] for a list or tuple;
        every score is rounded to 2 decimals. An empty list/tuple gives ``[]``.

        Raises
        ------
        ValueError if ``batch_size`` is smaller than 1.
        """
        # A non-positive step would either crash inside range() or silently
        # skip every image, so reject it up front with a clear message.
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        single = not isinstance(images, (list, tuple))
        if single:
            images = [images]

        scores: list[float] = []
        for start in range(0, len(images), batch_size):
            items = []
            for img in images[start : start + batch_size]:
                patches, grid = preprocess_image(self._load_rgb(img))
                # "score" is a constant placeholder at inference time;
                # presumably naflex_collate expects the key -- TODO confirm.
                items.append({"patches": patches, "grid": grid, "score": 0.0})

            collated = naflex_collate(items)
            # bfloat16 autocast is only switched on when running on CUDA.
            with torch.amp.autocast("cuda", dtype=torch.bfloat16, enabled=self.device.type == "cuda"):
                logits = self.model(
                    collated["patches"].to(self.device),
                    collated["spatial_shapes"].to(self.device),
                    collated["attention_mask"].to(self.device),
                )
            batch_scores = self.model.logits_to_score(logits).cpu().tolist()
            # tolist() on a 0-d tensor yields a bare float rather than a list.
            if isinstance(batch_scores, float):
                batch_scores = [batch_scores]
            scores.extend(batch_scores)

        return round(scores[0], 2) if single else [round(s, 2) for s in scores]


# ---------------------------------------------------------------------------
# CLI: python predict.py image1.jpg image2.jpg ...
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from argparse import ArgumentParser

    # Command-line surface: one or more image paths plus optional overrides.
    cli = ArgumentParser(description="Score images aesthetically (1-10)")
    cli.add_argument("images", nargs="+", help="Image paths to score")
    cli.add_argument("--repo", default="somepago/aes26", help="HF repo or local checkpoint")
    cli.add_argument("--device", default=None, help="cuda / cpu")
    opts = cli.parse_args()

    # A value that exists on disk is treated as a local checkpoint;
    # anything else is handed to the Hub loader as a repo id.
    on_disk = Path(opts.repo).exists()
    build = AestheticScorer.from_local if on_disk else AestheticScorer.from_pretrained
    scorer = build(opts.repo, device=opts.device)

    results = scorer.rate(opts.images)
    # Normalise to a list so the printing loop has a single shape to handle.
    if not isinstance(results, list):
        results = [results]

    # One line per image: score with two decimals, two spaces, then the path.
    for image_path, value in zip(opts.images, results):
        print(f"{value:.2f}  {image_path}")