VOIDER committed on
Commit
73b6364
Β·
verified Β·
1 Parent(s): 23584f6

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +162 -0
  2. requirements.txt +10 -0
app.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import clip
4
+ import gradio as gr
5
+ import numpy as np
6
+ from PIL import Image
7
+ from huggingface_hub import hf_hub_download
8
+ import os
9
+
10
# ── Labels ─────────────────────────────────────────────────────────────────────
# The nine aesthetic buckets used by Pony V7 captioning, ordered worst → best.
# Order matters: index i of the model's output logits maps to LABELS[i].
LABELS = [
    "worst quality",
    "very bad quality",
    "bad quality",
    "low quality",
    "normal quality",
    "good quality",
    "high quality",
    "best quality",
    "masterpiece",
]

# Colour gradient (red → yellow → green), one colour per bucket.
# NOTE(review): COLOURS is defined but not referenced anywhere in this file.
COLOURS = [
    "#e74c3c", "#e67e22", "#f39c12",
    "#d4ac0d", "#a9cce3", "#27ae60",
    "#1e8449", "#148f77", "#0e6655",
]
30
+
31
# ── Model ───────────────────────────────────────────────────────────────────────
class AestheticHead(nn.Module):
    """Small MLP head that sits on top of frozen CLIP image features.

    Layer stack (Sequential indices 0-6, so checkpoint keys are
    ``layers.0.weight`` … ``layers.6.bias``):
    Linear(in_features→1024) → ReLU → Dropout(0.2) →
    Linear(1024→128) → ReLU → Dropout(0.2) → Linear(128→num_classes).
    """

    def __init__(self, in_features: int = 768, num_classes: int = 9):
        super().__init__()
        modules = []
        for dim_in, dim_out in ((in_features, 1024), (1024, 128)):
            modules.extend([nn.Linear(dim_in, dim_out), nn.ReLU(), nn.Dropout(0.2)])
        modules.append(nn.Linear(128, num_classes))
        self.layers = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map CLIP image features (…, in_features) to raw class logits."""
        return self.layers(x)
48
+
49
+
50
# ── Model loading (runs once at import time) ───────────────────────────────────
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[info] device: {DEVICE}")

# Load CLIP backbone
print("[info] Loading CLIP ViT-L/14 …")
clip_model, preprocess = clip.load("ViT-L/14", device=DEVICE)
clip_model.eval()

# Load aesthetic head
print("[info] Downloading aesthetic-classifier checkpoint …")
ckpt_path = hf_hub_download(
    repo_id="purplesmartai/aesthetic-classifier",
    filename="v2.ckpt",
)
# NOTE(review): torch.load unpickles arbitrary objects. The checkpoint comes
# from a pinned HF repo, but consider weights_only=True (torch>=2.0) here.
state_dict = torch.load(ckpt_path, map_location=DEVICE)

# Lightning-style .ckpt files nest the actual weights under a 'state_dict'
# key; unwrap it so the key-prefix fixup below sees tensor keys, not metadata.
if isinstance(state_dict, dict) and isinstance(state_dict.get("state_dict"), dict):
    state_dict = state_dict["state_dict"]

# If no key starts with 'layers' it is a flat state dict from a bare
# Sequential — re-prefix so it maps onto AestheticHead.layers.
if isinstance(state_dict, dict) and not any(k.startswith("layers") for k in state_dict):
    state_dict = {
        k if k.startswith("layers") else "layers." + k: v
        for k, v in state_dict.items()
    }

# Detect input feature size from the first 2-D weight tensor
# (default 768 matches ViT-L/14 image embeddings).
in_feat = 768
for key, tensor in state_dict.items():
    if "weight" in key and tensor.dim() == 2:
        in_feat = tensor.shape[1]
        break

num_classes = len(LABELS)
model = AestheticHead(in_features=in_feat, num_classes=num_classes).to(DEVICE)
try:
    model.load_state_dict(state_dict, strict=True)
    print("[info] Checkpoint loaded (strict).")
except RuntimeError:
    # Fall back to partial loading so minor architecture drift in the
    # checkpoint doesn't hard-crash the Space at startup.
    model.load_state_dict(state_dict, strict=False)
    print("[warn] Checkpoint loaded (non-strict).")
model.eval()
90
+
91
+
92
# ── Inference ───────────────────────────────────────────────────────────────────
@torch.no_grad()
def classify(image: Image.Image) -> dict:
    """Score *image* and return ``{label: probability}`` over the 9 tiers.

    Returns an empty dict when *image* is None (Gradio passes None when the
    image input is cleared).
    """
    if image is None:
        return {}

    # Preprocess & encode with CLIP; L2-normalise the embedding (the head is
    # presumably trained on unit-norm CLIP features — confirm against training).
    tensor = preprocess(image).unsqueeze(0).to(DEVICE)
    features = clip_model.encode_image(tensor).float()
    features = features / features.norm(dim=-1, keepdim=True)

    # Run the MLP head and convert logits to a probability distribution.
    # (Removed an unused `top_idx = np.argmax(probs)` local — dead code.)
    logits = model(features)
    probs = torch.softmax(logits, dim=-1)[0].cpu().numpy()

    return {label: float(prob) for label, prob in zip(LABELS, probs)}
112
+
113
+
114
# ── Gradio UI ───────────────────────────────────────────────────────────────────
# Collect example images shipped alongside the app. sorted() makes the gallery
# order deterministic — os.listdir order is filesystem-dependent.
EXAMPLES = []
examples_dir = "examples"
if os.path.isdir(examples_dir):
    EXAMPLES = [
        [os.path.join(examples_dir, f)]
        for f in sorted(os.listdir(examples_dir))
        if f.lower().endswith((".jpg", ".jpeg", ".png", ".webp"))
    ]

with gr.Blocks(
    title="Aesthetic Classifier — PurpleSmartAI",
    theme=gr.themes.Soft(primary_hue="purple"),
    css="""
    .gradio-container { max-width: 900px !important; margin: auto; }
    #title { text-align: center; margin-bottom: 0.5rem; }
    #subtitle { text-align: center; color: #888; margin-bottom: 1.5rem; font-size: 0.95rem; }
    """,
) as demo:
    gr.Markdown("# 🎨 Aesthetic Classifier", elem_id="title")
    gr.Markdown(
        "CLIP-based aesthetic quality classifier by **PurpleSmartAI** — "
        "originally developed for [Pony V7](https://huggingface.co/purplesmartai/aesthetic-classifier) captioning.\n\n"
        "Upload an image and get a probability distribution across 9 quality tiers.",
        elem_id="subtitle",
    )

    with gr.Row():
        with gr.Column(scale=1):
            img_input = gr.Image(type="pil", label="Input Image", height=340)
            run_btn = gr.Button("✨ Classify", variant="primary", size="lg")

        with gr.Column(scale=1):
            label_output = gr.Label(
                num_top_classes=9,
                label="Aesthetic Score Distribution",
            )

    if EXAMPLES:
        gr.Examples(examples=EXAMPLES, inputs=img_input, label="Example images")

    gr.Markdown(
        "---\n"
        "**Model:** [`purplesmartai/aesthetic-classifier`](https://huggingface.co/purplesmartai/aesthetic-classifier) · "
        "**Backbone:** OpenAI CLIP ViT-L/14"
    )

    # Both hooks run classify: the button for explicit clicks, and on-change
    # so a fresh upload is scored automatically without a click.
    run_btn.click(fn=classify, inputs=img_input, outputs=label_output)
    img_input.change(fn=classify, inputs=img_input, outputs=label_output)

if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ torch>=2.0.0
3
+ torchvision>=0.15.0
4
+ ftfy
5
+ regex
6
+ tqdm
7
+ git+https://github.com/openai/CLIP.git
8
+ huggingface_hub>=0.20.0
9
+ Pillow>=9.0.0
10
+ numpy>=1.24.0