Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -7,24 +7,6 @@ import gradio as gr
|
|
| 7 |
from PIL import Image
|
| 8 |
from huggingface_hub import hf_hub_download
|
| 9 |
|
| 10 |
-
# ββ Model β exactly as in the Pony V7 Captioner notebook βββββββββββββββββββββββ
|
| 11 |
-
class AestheticScorer(nn.Module):
|
| 12 |
-
def __init__(self, input_size: int = 768):
|
| 13 |
-
super().__init__()
|
| 14 |
-
self.model = nn.Sequential(
|
| 15 |
-
nn.Linear(input_size, 1024),
|
| 16 |
-
nn.ReLU(),
|
| 17 |
-
nn.Dropout(0.5),
|
| 18 |
-
nn.Linear(1024, 512),
|
| 19 |
-
nn.ReLU(),
|
| 20 |
-
nn.Dropout(0.3),
|
| 21 |
-
nn.Linear(512, 1),
|
| 22 |
-
)
|
| 23 |
-
|
| 24 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 25 |
-
return self.model(x)
|
| 26 |
-
|
| 27 |
-
|
| 28 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 29 |
print(f"[info] device: {DEVICE}")
|
| 30 |
|
|
@@ -39,35 +21,39 @@ ckpt_path = hf_hub_download(
|
|
| 39 |
)
|
| 40 |
checkpoint_data = torch.load(ckpt_path, map_location=DEVICE)
|
| 41 |
state_dict = checkpoint_data["state_dict"]
|
| 42 |
-
# Strip
|
| 43 |
state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
|
| 44 |
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
aesthetic_model.load_state_dict(state_dict)
|
| 47 |
aesthetic_model.eval()
|
| 48 |
print("[info] Model ready.")
|
| 49 |
|
| 50 |
|
| 51 |
-
# ββ Scoring β identical to notebook ββββββββββββββββββββββββββββββββββββββββββββ
|
| 52 |
@torch.no_grad()
|
| 53 |
def get_score(image: Image.Image) -> float:
|
| 54 |
-
"""Returns raw float score (typically 0-1 range)."""
|
| 55 |
image_tensor = preprocess(image.convert("RGB")).unsqueeze(0).to(DEVICE)
|
| 56 |
features = clip_model.encode_image(image_tensor).cpu().numpy()
|
| 57 |
norm = np.linalg.norm(features, axis=1, keepdims=True)
|
| 58 |
norm[norm == 0] = 1
|
| 59 |
features = features / norm
|
| 60 |
features_t = torch.tensor(features, dtype=torch.float32, device=DEVICE)
|
| 61 |
-
|
| 62 |
-
return raw
|
| 63 |
|
| 64 |
|
| 65 |
def raw_to_pony(raw: float) -> int:
|
| 66 |
-
"""Convert raw score to pony score_0...score_9 (same formula as notebook)."""
|
| 67 |
return int(max(0.0, min(0.99, raw)) * 10)
|
| 68 |
|
| 69 |
|
| 70 |
-
# ββ Colour palette βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 71 |
SCORE_COLOURS = [
|
| 72 |
"#c0392b", "#e74c3c", "#e67e22", "#f39c12", "#d4ac0d",
|
| 73 |
"#27ae60", "#1e8449", "#148f77", "#0e6655", "#0a4f42",
|
|
@@ -75,7 +61,7 @@ SCORE_COLOURS = [
|
|
| 75 |
|
| 76 |
|
| 77 |
def build_html(raw: float) -> str:
|
| 78 |
-
pony
|
| 79 |
colour = SCORE_COLOURS[pony]
|
| 80 |
|
| 81 |
tiles_html = ""
|
|
@@ -86,12 +72,13 @@ def build_html(raw: float) -> str:
|
|
| 86 |
weight = "700" if active else "400"
|
| 87 |
scale = "scale(1.12)" if active else "scale(1)"
|
| 88 |
opac = "1" if active else "0.45"
|
| 89 |
-
tiles_html +=
|
| 90 |
-
|
| 91 |
-
|
|
|
|
|
|
|
| 92 |
|
| 93 |
bar_w = min(raw, 1.0) * 100
|
| 94 |
-
|
| 95 |
return f"""
|
| 96 |
<div style="font-family:'Inter',sans-serif;padding:8px 0;">
|
| 97 |
<div style="text-align:center;margin-bottom:20px;">
|
|
@@ -99,7 +86,7 @@ def build_html(raw: float) -> str:
|
|
| 99 |
padding:14px 36px;font-size:2rem;font-weight:800;letter-spacing:.04em;
|
| 100 |
box-shadow:0 4px 20px {colour}66;">score_{pony}</div>
|
| 101 |
<div style="color:#aaa;font-size:.85rem;margin-top:8px;">
|
| 102 |
-
raw
|
| 103 |
</div>
|
| 104 |
</div>
|
| 105 |
<div style="display:grid;grid-template-columns:repeat(10,1fr);gap:6px;margin-bottom:16px;">
|
|
@@ -108,7 +95,7 @@ def build_html(raw: float) -> str:
|
|
| 108 |
<div style="background:rgba(255,255,255,.1);border-radius:6px;height:8px;overflow:hidden;">
|
| 109 |
<div style="width:{bar_w:.1f}%;height:100%;
|
| 110 |
background:linear-gradient(90deg,#c0392b,#f39c12,#27ae60);
|
| 111 |
-
border-radius:6px;
|
| 112 |
</div>
|
| 113 |
<div style="display:flex;justify-content:space-between;font-size:.72rem;color:#777;margin-top:4px;">
|
| 114 |
<span>score_0</span><span>score_9</span>
|
|
@@ -119,26 +106,24 @@ def build_html(raw: float) -> str:
|
|
| 119 |
def classify(image):
|
| 120 |
if image is None:
|
| 121 |
return "<p style='color:#888;text-align:center'>Upload an image to score it.</p>"
|
| 122 |
-
|
| 123 |
-
return build_html(raw)
|
| 124 |
|
| 125 |
|
| 126 |
-
# ββ Gradio UI βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 127 |
with gr.Blocks(
|
| 128 |
title="Aesthetic Classifier - PurpleSmartAI",
|
| 129 |
theme=gr.themes.Soft(primary_hue="purple"),
|
| 130 |
css=".gradio-container{max-width:860px!important;margin:auto} #title{text-align:center} #sub{text-align:center;color:#888;font-size:.9rem;margin-bottom:1.5rem}",
|
| 131 |
) as demo:
|
| 132 |
-
gr.Markdown("# Aesthetic Classifier", elem_id="title")
|
| 133 |
gr.Markdown(
|
| 134 |
"CLIP ViT-L/14 regression model by **PurpleSmartAI** for Pony V7 captioning. "
|
| 135 |
-
"Outputs a **score_0
|
| 136 |
elem_id="sub",
|
| 137 |
)
|
| 138 |
with gr.Row():
|
| 139 |
with gr.Column(scale=1):
|
| 140 |
img_input = gr.Image(type="pil", label="Input Image", height=340)
|
| 141 |
-
run_btn = gr.Button("Score image", variant="primary", size="lg")
|
| 142 |
with gr.Column(scale=1):
|
| 143 |
out_html = gr.HTML(
|
| 144 |
value="<p style='color:#888;text-align:center;padding:40px 0'>Upload an image to see its score.</p>",
|
|
|
|
| 7 |
from PIL import Image
|
| 8 |
from huggingface_hub import hf_hub_download
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 11 |
print(f"[info] device: {DEVICE}")
|
| 12 |
|
|
|
|
| 21 |
)
|
| 22 |
checkpoint_data = torch.load(ckpt_path, map_location=DEVICE)
|
| 23 |
state_dict = checkpoint_data["state_dict"]
|
| 24 |
+
# Strip "model." prefix β keys become "0.weight", "3.weight", "6.weight"
|
| 25 |
state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
|
| 26 |
|
| 27 |
+
# Build Sequential directly so keys match ("0.weight", "3.weight", "6.weight")
|
| 28 |
+
aesthetic_model = nn.Sequential(
|
| 29 |
+
nn.Linear(768, 1024),
|
| 30 |
+
nn.ReLU(),
|
| 31 |
+
nn.Dropout(0.5),
|
| 32 |
+
nn.Linear(1024, 512),
|
| 33 |
+
nn.ReLU(),
|
| 34 |
+
nn.Dropout(0.3),
|
| 35 |
+
nn.Linear(512, 1),
|
| 36 |
+
).to(DEVICE)
|
| 37 |
aesthetic_model.load_state_dict(state_dict)
|
| 38 |
aesthetic_model.eval()
|
| 39 |
print("[info] Model ready.")
|
| 40 |
|
| 41 |
|
|
|
|
| 42 |
@torch.no_grad()
|
| 43 |
def get_score(image: Image.Image) -> float:
|
|
|
|
| 44 |
image_tensor = preprocess(image.convert("RGB")).unsqueeze(0).to(DEVICE)
|
| 45 |
features = clip_model.encode_image(image_tensor).cpu().numpy()
|
| 46 |
norm = np.linalg.norm(features, axis=1, keepdims=True)
|
| 47 |
norm[norm == 0] = 1
|
| 48 |
features = features / norm
|
| 49 |
features_t = torch.tensor(features, dtype=torch.float32, device=DEVICE)
|
| 50 |
+
return aesthetic_model(features_t).item()
|
|
|
|
| 51 |
|
| 52 |
|
| 53 |
def raw_to_pony(raw: float) -> int:
|
|
|
|
| 54 |
return int(max(0.0, min(0.99, raw)) * 10)
|
| 55 |
|
| 56 |
|
|
|
|
| 57 |
SCORE_COLOURS = [
|
| 58 |
"#c0392b", "#e74c3c", "#e67e22", "#f39c12", "#d4ac0d",
|
| 59 |
"#27ae60", "#1e8449", "#148f77", "#0e6655", "#0a4f42",
|
|
|
|
| 61 |
|
| 62 |
|
| 63 |
def build_html(raw: float) -> str:
|
| 64 |
+
pony = raw_to_pony(raw)
|
| 65 |
colour = SCORE_COLOURS[pony]
|
| 66 |
|
| 67 |
tiles_html = ""
|
|
|
|
| 72 |
weight = "700" if active else "400"
|
| 73 |
scale = "scale(1.12)" if active else "scale(1)"
|
| 74 |
opac = "1" if active else "0.45"
|
| 75 |
+
tiles_html += (
|
| 76 |
+
f'<div style="background:{bg};border:{border};border-radius:8px;'
|
| 77 |
+
f'padding:10px 0;text-align:center;font-size:.82rem;font-weight:{weight};color:#fff;'
|
| 78 |
+
f'transform:{scale};opacity:{opac};transition:all .2s;user-select:none;">score_{i}</div>'
|
| 79 |
+
)
|
| 80 |
|
| 81 |
bar_w = min(raw, 1.0) * 100
|
|
|
|
| 82 |
return f"""
|
| 83 |
<div style="font-family:'Inter',sans-serif;padding:8px 0;">
|
| 84 |
<div style="text-align:center;margin-bottom:20px;">
|
|
|
|
| 86 |
padding:14px 36px;font-size:2rem;font-weight:800;letter-spacing:.04em;
|
| 87 |
box-shadow:0 4px 20px {colour}66;">score_{pony}</div>
|
| 88 |
<div style="color:#aaa;font-size:.85rem;margin-top:8px;">
|
| 89 |
+
raw: <code style="color:#ddd">{raw:.4f}</code>
|
| 90 |
</div>
|
| 91 |
</div>
|
| 92 |
<div style="display:grid;grid-template-columns:repeat(10,1fr);gap:6px;margin-bottom:16px;">
|
|
|
|
| 95 |
<div style="background:rgba(255,255,255,.1);border-radius:6px;height:8px;overflow:hidden;">
|
| 96 |
<div style="width:{bar_w:.1f}%;height:100%;
|
| 97 |
background:linear-gradient(90deg,#c0392b,#f39c12,#27ae60);
|
| 98 |
+
border-radius:6px;"></div>
|
| 99 |
</div>
|
| 100 |
<div style="display:flex;justify-content:space-between;font-size:.72rem;color:#777;margin-top:4px;">
|
| 101 |
<span>score_0</span><span>score_9</span>
|
|
|
|
| 106 |
def classify(image):
|
| 107 |
if image is None:
|
| 108 |
return "<p style='color:#888;text-align:center'>Upload an image to score it.</p>"
|
| 109 |
+
return build_html(get_score(image))
|
|
|
|
| 110 |
|
| 111 |
|
|
|
|
| 112 |
with gr.Blocks(
|
| 113 |
title="Aesthetic Classifier - PurpleSmartAI",
|
| 114 |
theme=gr.themes.Soft(primary_hue="purple"),
|
| 115 |
css=".gradio-container{max-width:860px!important;margin:auto} #title{text-align:center} #sub{text-align:center;color:#888;font-size:.9rem;margin-bottom:1.5rem}",
|
| 116 |
) as demo:
|
| 117 |
+
gr.Markdown("# π¨ Aesthetic Classifier", elem_id="title")
|
| 118 |
gr.Markdown(
|
| 119 |
"CLIP ViT-L/14 regression model by **PurpleSmartAI** for Pony V7 captioning. "
|
| 120 |
+
"Outputs a **score_0β¦score_9** tag used directly in training captions.",
|
| 121 |
elem_id="sub",
|
| 122 |
)
|
| 123 |
with gr.Row():
|
| 124 |
with gr.Column(scale=1):
|
| 125 |
img_input = gr.Image(type="pil", label="Input Image", height=340)
|
| 126 |
+
run_btn = gr.Button("β¨ Score image", variant="primary", size="lg")
|
| 127 |
with gr.Column(scale=1):
|
| 128 |
out_html = gr.HTML(
|
| 129 |
value="<p style='color:#888;text-align:center;padding:40px 0'>Upload an image to see its score.</p>",
|