File size: 5,357 Bytes
1e3b83e
73b6364
 
1e3b83e
73b6364
 
 
 
 
 
 
 
1e3b83e
73b6364
 
 
1e3b83e
73b6364
 
 
 
1e3b83e
 
 
73b6364
ee62efe
 
 
 
 
 
 
 
 
1e3b83e
 
 
73b6364
 
1e3b83e
 
d500931
 
 
 
 
1e3b83e
 
 
 
 
 
d500931
1e3b83e
 
 
73b6364
 
1e3b83e
ee62efe
d500931
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee62efe
1e3b83e
d500931
1e3b83e
 
d500931
 
1e3b83e
d500931
1e3b83e
d500931
ee62efe
1e3b83e
 
d500931
 
 
 
1e3b83e
 
ee62efe
1e3b83e
d500931
1e3b83e
 
d500931
1e3b83e
 
 
 
 
d500931
ee62efe
73b6364
 
 
d500931
73b6364
d500931
 
73b6364
ee62efe
73b6364
1e3b83e
ee62efe
1e3b83e
73b6364
 
 
 
ee62efe
73b6364
d500931
 
 
73b6364
 
d500931
 
1e3b83e
73b6364
1e3b83e
 
73b6364
 
1e3b83e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import os
import torch
import torch.nn as nn
import numpy as np
import clip
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download

# Pick the compute device once; every tensor/model below is placed on it.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[info] device: {DEVICE}")

# CLIP ViT-L/14 produces the 768-dim image embeddings the scorer consumes.
print("[info] Loading CLIP ViT-L/14 ...")
clip_model, preprocess = clip.load("ViT-L/14", device=DEVICE)
clip_model.eval()

print("[info] Downloading aesthetic-classifier checkpoint ...")
ckpt_path = hf_hub_download(
    repo_id="purplesmartai/aesthetic-classifier",
    filename="v2.ckpt",
)
# NOTE(review): torch.load unpickles arbitrary objects; the repo id is pinned
# here, but weights_only=True would be safer on torch versions that support it.
checkpoint_data = torch.load(ckpt_path, map_location=DEVICE)
state_dict = checkpoint_data["state_dict"]
# Strip the Lightning-style "model." wrapper from each key so the names match
# the plain nn.Sequential below.  Fix: strip only the *prefix* — the previous
# str.replace would also mangle a "model." occurring mid-key.
state_dict = {
    (k[len("model."):] if k.startswith("model.") else k): v
    for k, v in state_dict.items()
}

# MLP regression head: 768-dim CLIP embedding -> single raw aesthetic score.
aesthetic_model = nn.Sequential(
    nn.Linear(768, 1024),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(1024, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, 1),
).to(DEVICE)
aesthetic_model.load_state_dict(state_dict)
aesthetic_model.eval()
print("[info] Model ready.")


@torch.no_grad()
def get_score(image: Image.Image) -> float:
    """Embed *image* with CLIP, L2-normalise the feature, and return the raw score.

    The embedding is pulled to CPU/NumPy as float32, unit-normalised (with a
    guard against an all-zero vector), then fed through the regression head.
    """
    pixels = preprocess(image.convert("RGB")).unsqueeze(0).to(DEVICE)
    embedding = clip_model.encode_image(pixels).cpu().numpy().astype("float32")
    length = np.linalg.norm(embedding, axis=1, keepdims=True)
    # Avoid division by zero for a degenerate (all-zero) embedding.
    unit = embedding / np.where(length == 0, 1, length)
    return aesthetic_model(torch.tensor(unit, device=DEVICE)).item()


def raw_to_pony(raw: float) -> int:
    """Map a raw regression output onto the discrete score_0..score_9 bucket."""
    clamped = min(max(raw, 0.0), 0.99)  # keep the product strictly below 10
    return int(clamped * 10)


# One colour per score bucket, red (score_0) through deep green (score_9);
# indexed by raw_to_pony()'s return value.
COLOURS = [
    "#c0392b", "#e74c3c", "#e67e22", "#f39c12", "#d4ac0d",
    "#27ae60", "#1e8449", "#148f77", "#0e6655", "#0a4f42",
]


def build_html(raw: float) -> str:
    """Render the score-card HTML for a raw model output.

    Shows the winning score_N badge, a 2x5 grid of all ten tags with the
    active one highlighted, and a 0-1 progress bar for the raw value.
    """
    bucket = raw_to_pony(raw)
    accent = COLOURS[bucket]

    def cell(idx: int) -> str:
        # Highlight the active bucket; dim the other nine.
        if idx == bucket:
            bg = COLOURS[idx]
            border = f"2px solid {COLOURS[idx]}"
            weight, scale, opac = "700", "scale(1.08)", "1"
        else:
            bg = "rgba(255,255,255,0.07)"
            border = "2px solid rgba(255,255,255,0.12)"
            weight, scale, opac = "400", "scale(1)", "0.5"
        return (
            f'<div style="background:{bg};border:{border};border-radius:8px;'
            f'padding:9px 4px;text-align:center;font-size:.78rem;font-weight:{weight};'
            f'color:#fff;transform:{scale};opacity:{opac};transition:all .2s;'
            f'user-select:none;white-space:nowrap;">'
            f"score_{idx}</div>"
        )

    # Two rows of 5 so the grid never overflows
    rows = [
        '<div style="display:grid;grid-template-columns:repeat(5,1fr);gap:5px;margin-bottom:5px;">'
        + "".join(cell(i) for i in range(start, start + 5))
        + "</div>"
        for start in (0, 5)
    ]

    bar_w = min(max(raw, 0.0), 1.0) * 100
    return f"""
<div style="font-family:'Inter',sans-serif;padding:8px 0;">

  <div style="text-align:center;margin-bottom:18px;">
    <div style="display:inline-block;background:{accent};color:#fff;border-radius:12px;
        padding:12px 32px;font-size:1.9rem;font-weight:800;letter-spacing:.04em;
        box-shadow:0 4px 20px {accent}66;">score_{bucket}</div>
    <div style="color:#aaa;font-size:.82rem;margin-top:7px;">
      raw: <code style="color:#ddd">{raw:.4f}</code>
    </div>
  </div>

  {"".join(rows)}

  <div style="background:rgba(255,255,255,.1);border-radius:6px;height:7px;overflow:hidden;margin-top:8px;">
    <div style="width:{bar_w:.1f}%;height:100%;
        background:linear-gradient(90deg,#c0392b,#f39c12,#27ae60);
        border-radius:6px;"></div>
  </div>
  <div style="display:flex;justify-content:space-between;font-size:.7rem;color:#666;margin-top:4px;">
    <span>score_0</span><span>score_9</span>
  </div>

</div>"""


def classify(image):
    """Gradio callback: turn an uploaded PIL image into the score-card HTML."""
    if image is not None:
        return build_html(get_score(image))
    # No upload yet — show a neutral placeholder instead of scoring.
    return "<p style='color:#888;text-align:center;padding:40px 0'>Upload an image to score it.</p>"


# --- Gradio UI -------------------------------------------------------------
# Two-column layout: image input + button on the left, HTML score card on
# the right.  Scoring runs on button press and on every image change.
with gr.Blocks(
    title="Aesthetic Classifier — PurpleSmartAI",
    theme=gr.themes.Soft(primary_hue="purple"),
    css=".gradio-container{max-width:860px!important;margin:auto}"
        " #title{text-align:center} #sub{text-align:center;color:#888;font-size:.9rem;margin-bottom:1.4rem}",
) as demo:
    gr.Markdown("# 🎨 Aesthetic Classifier", elem_id="title")
    gr.Markdown(
        "CLIP ViT-L/14 regression model by **PurpleSmartAI** for Pony V7 captioning. "
        "Outputs a **score_0…score_9** tag used directly in training captions.",
        elem_id="sub",
    )
    with gr.Row():
        with gr.Column(scale=1):
            img_input = gr.Image(type="pil", label="Input Image", height=340)
            run_btn   = gr.Button("✨ Score image", variant="primary", size="lg")
        with gr.Column(scale=1):
            # Placeholder until the first classification replaces it.
            out_html = gr.HTML(
                value="<p style='color:#888;text-align:center;padding:40px 0'>"
                      "Upload an image to see its score.</p>",
            )
    gr.Markdown(
        "---\n**Model:** [`purplesmartai/aesthetic-classifier`]"
        "(https://huggingface.co/purplesmartai/aesthetic-classifier)"
        " · **Backbone:** OpenAI CLIP ViT-L/14"
    )
    # Both events route through the same classify() callback.
    run_btn.click(fn=classify, inputs=img_input, outputs=out_html)
    img_input.change(fn=classify, inputs=img_input, outputs=out_html)

if __name__ == "__main__":
    demo.launch()