# AestheticSigLIP / predict.py
# Uploaded by somepago with huggingface_hub (commit 9614331, verified).
"""
Aesthetic scoring — simple inference interface.

Usage:
    from predict import AestheticScorer

    scorer = AestheticScorer.from_pretrained("somepago/aes26")
    score = scorer.rate("photo.jpg")            # float, 1-10
    scores = scorer.rate(["a.jpg", "b.jpg"])    # list of floats

Or with a local checkpoint:
    scorer = AestheticScorer.from_local("checkpoints/.../best.pt")
"""
from __future__ import annotations
import sys
from pathlib import Path
from typing import Union
import torch
import torch.nn.functional as F
from PIL import Image
# ---------------------------------------------------------------------------
# Put this file's directory at the front of the import path so the modules
# imported below resolve whether the script is run from the repo root or
# from a Hugging Face snapshot directory.
# ---------------------------------------------------------------------------
_HERE = Path(__file__).parent
if sys.path.count(str(_HERE)) == 0:
    sys.path.insert(0, str(_HERE))
from naflex import preprocess_image, naflex_collate
from model import AestheticModel
class AestheticScorer:
    """Score images on a 1-10 aesthetic scale.

    Holds a loaded ``AestheticModel`` and the device it runs on, and exposes
    :meth:`rate` as the single inference entry point. Build an instance with
    :meth:`from_pretrained` (Hugging Face Hub) or :meth:`from_local`
    (checkpoint file on disk).
    """

    def __init__(self, model: AestheticModel, device: torch.device):
        """Store an already-constructed model and its device.

        The model is used as given; it is not moved or switched to eval mode
        here (``from_local`` does both before calling this).
        """
        self.model = model
        self.device = device

    # ------------------------------------------------------------------
    # Constructors
    # ------------------------------------------------------------------
    @classmethod
    def from_pretrained(
        cls,
        repo_id: str = "somepago/aes26",
        filename: str = "best.pt",
        device: str | None = None,
    ) -> "AestheticScorer":
        """Download a checkpoint from the Hugging Face Hub and load it.

        Parameters
        ----------
        repo_id : Hub repository that hosts the checkpoint
        filename : checkpoint file name inside that repository
        device : device string such as "cuda" or "cpu"; auto-selected if None

        Returns
        -------
        AestheticScorer ready for inference
        """
        # Imported lazily so huggingface_hub is only needed for Hub loading.
        from huggingface_hub import hf_hub_download

        ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
        return cls.from_local(ckpt_path, device=device)

    @classmethod
    def from_local(
        cls,
        ckpt_path: str | Path,
        device: str | None = None,
    ) -> "AestheticScorer":
        """Load the model from a local checkpoint file.

        Parameters
        ----------
        ckpt_path : path to a checkpoint containing "config" and either
            "ema_state_dict" or "model_state_dict"
        device : device string such as "cuda" or "cpu"; when None, CUDA is
            used if available, otherwise CPU

        Returns
        -------
        AestheticScorer with the model in eval mode on the chosen device

        Raises
        ------
        KeyError if the checkpoint lacks "config" or any state dict
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        dev = torch.device(device)

        # SECURITY: weights_only=False unpickles arbitrary Python objects, so
        # a malicious checkpoint can execute code on load. Only load files
        # from sources you trust. NOTE(review): presumably required because
        # the stored "config" is not a plain tensor/primitive container —
        # confirm, and switch to weights_only=True if it is.
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        config = ckpt["config"]

        # Prefer EMA weights when the checkpoint saved them.
        if "ema_state_dict" in ckpt:
            state_key = "ema_state_dict"
        elif "model_state_dict" in ckpt:
            state_key = "model_state_dict"
        else:
            raise KeyError(
                "checkpoint has neither 'ema_state_dict' nor 'model_state_dict'"
            )

        model = AestheticModel(config)
        model.load_state_dict(ckpt[state_key])
        model.eval().to(dev)
        return cls(model, dev)

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------
    @staticmethod
    def _load_rgb(img: Union[str, Path, Image.Image]) -> Image.Image:
        """Return *img* as an RGB PIL image, opening it first if it is a path."""
        if isinstance(img, Image.Image):
            return img.convert("RGB")
        # convert() fully decodes the pixels into a new image, so the file
        # handle can be released as soon as the conversion is done.
        with Image.open(img) as opened:
            return opened.convert("RGB")

    @torch.inference_mode()
    def rate(
        self,
        images: Union[str, Path, Image.Image, list, tuple],
        batch_size: int = 32,
    ) -> Union[float, list[float]]:
        """Score one or more images.

        Parameters
        ----------
        images : path, PIL Image, or list/tuple of either
        batch_size : how many images to process at once (must be >= 1)

        Returns
        -------
        float if a single image was passed, list[float] for a list or tuple;
        every score is rounded to two decimals

        Raises
        ------
        ValueError if batch_size is smaller than 1
        """
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size!r}"
            )

        # Tuples count as collections too; anything else is one image.
        single = not isinstance(images, (list, tuple))
        if single:
            images = [images]

        # bfloat16 autocast is only enabled when the model runs on CUDA.
        use_autocast = self.device.type == "cuda"

        scores: list[float] = []
        for start in range(0, len(images), batch_size):
            items = []
            for img in images[start : start + batch_size]:
                patches, grid = preprocess_image(self._load_rgb(img))
                # "score" is a dummy value; presumably naflex_collate expects
                # a label field on every item — TODO confirm.
                items.append({"patches": patches, "grid": grid, "score": 0.0})
            collated = naflex_collate(items)

            with torch.amp.autocast("cuda", dtype=torch.bfloat16, enabled=use_autocast):
                logits = self.model(
                    collated["patches"].to(self.device),
                    collated["spatial_shapes"].to(self.device),
                    collated["attention_mask"].to(self.device),
                )

            batch_scores = self.model.logits_to_score(logits).cpu().tolist()
            # tolist() on a 0-dim tensor yields a bare float, not a list.
            if isinstance(batch_scores, float):
                batch_scores = [batch_scores]
            scores.extend(batch_scores)

        return round(scores[0], 2) if single else [round(s, 2) for s in scores]
# ---------------------------------------------------------------------------
# CLI: python predict.py image1.jpg image2.jpg ...
# Prints one "<score> <path>" line per image, in the order given.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Score images aesthetically (1-10)")
    parser.add_argument("images", nargs="+", help="Image paths to score")
    parser.add_argument("--repo", default="somepago/aes26", help="HF repo or local checkpoint")
    parser.add_argument("--device", default=None, help="cuda / cpu")
    args = parser.parse_args()

    # is_file() rather than exists(): a local directory that happens to share
    # the repo id (e.g. a clone at ./somepago/aes26) is not a checkpoint and
    # must not be handed to torch.load.
    if Path(args.repo).is_file():
        scorer = AestheticScorer.from_local(args.repo, device=args.device)
    else:
        scorer = AestheticScorer.from_pretrained(args.repo, device=args.device)

    # nargs="+" always yields a list, so rate() returns a list of scores.
    scores = scorer.rate(args.images)
    for path, score in zip(args.images, scores):
        print(f"{score:.2f} {path}")