VOIDER committed on
Commit
ee62efe
·
verified ·
1 Parent(s): 1e3b83e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -39
app.py CHANGED
@@ -7,24 +7,6 @@ import gradio as gr
7
  from PIL import Image
8
  from huggingface_hub import hf_hub_download
9
 
10
- # ── Model — exactly as in the Pony V7 Captioner notebook ───────────────────────
11
- class AestheticScorer(nn.Module):
12
- def __init__(self, input_size: int = 768):
13
- super().__init__()
14
- self.model = nn.Sequential(
15
- nn.Linear(input_size, 1024),
16
- nn.ReLU(),
17
- nn.Dropout(0.5),
18
- nn.Linear(1024, 512),
19
- nn.ReLU(),
20
- nn.Dropout(0.3),
21
- nn.Linear(512, 1),
22
- )
23
-
24
- def forward(self, x: torch.Tensor) -> torch.Tensor:
25
- return self.model(x)
26
-
27
-
28
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
29
  print(f"[info] device: {DEVICE}")
30
 
@@ -39,35 +21,39 @@ ckpt_path = hf_hub_download(
39
  )
40
  checkpoint_data = torch.load(ckpt_path, map_location=DEVICE)
41
  state_dict = checkpoint_data["state_dict"]
42
- # Strip the "model." prefix from keys (same as notebook)
43
  state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
44
 
45
- aesthetic_model = AestheticScorer(input_size=768).to(DEVICE)
 
 
 
 
 
 
 
 
 
46
  aesthetic_model.load_state_dict(state_dict)
47
  aesthetic_model.eval()
48
  print("[info] Model ready.")
49
 
50
 
51
- # ── Scoring — identical to notebook ────────────────────────────────────────────
52
  @torch.no_grad()
53
  def get_score(image: Image.Image) -> float:
54
- """Returns raw float score (typically 0-1 range)."""
55
  image_tensor = preprocess(image.convert("RGB")).unsqueeze(0).to(DEVICE)
56
  features = clip_model.encode_image(image_tensor).cpu().numpy()
57
  norm = np.linalg.norm(features, axis=1, keepdims=True)
58
  norm[norm == 0] = 1
59
  features = features / norm
60
  features_t = torch.tensor(features, dtype=torch.float32, device=DEVICE)
61
- raw = aesthetic_model(features_t).item()
62
- return raw
63
 
64
 
65
  def raw_to_pony(raw: float) -> int:
66
- """Convert raw score to pony score_0...score_9 (same formula as notebook)."""
67
  return int(max(0.0, min(0.99, raw)) * 10)
68
 
69
 
70
- # ── Colour palette ─────────────────────────────────────────────────────────────
71
  SCORE_COLOURS = [
72
  "#c0392b", "#e74c3c", "#e67e22", "#f39c12", "#d4ac0d",
73
  "#27ae60", "#1e8449", "#148f77", "#0e6655", "#0a4f42",
@@ -75,7 +61,7 @@ SCORE_COLOURS = [
75
 
76
 
77
  def build_html(raw: float) -> str:
78
- pony = raw_to_pony(raw)
79
  colour = SCORE_COLOURS[pony]
80
 
81
  tiles_html = ""
@@ -86,12 +72,13 @@ def build_html(raw: float) -> str:
86
  weight = "700" if active else "400"
87
  scale = "scale(1.12)" if active else "scale(1)"
88
  opac = "1" if active else "0.45"
89
- tiles_html += f"""<div style="background:{bg};border:{border};border-radius:8px;
90
- padding:10px 0;text-align:center;font-size:.82rem;font-weight:{weight};color:#fff;
91
- transform:{scale};opacity:{opac};transition:all .2s;user-select:none;">score_{i}</div>"""
 
 
92
 
93
  bar_w = min(raw, 1.0) * 100
94
-
95
  return f"""
96
  <div style="font-family:'Inter',sans-serif;padding:8px 0;">
97
  <div style="text-align:center;margin-bottom:20px;">
@@ -99,7 +86,7 @@ def build_html(raw: float) -> str:
99
  padding:14px 36px;font-size:2rem;font-weight:800;letter-spacing:.04em;
100
  box-shadow:0 4px 20px {colour}66;">score_{pony}</div>
101
  <div style="color:#aaa;font-size:.85rem;margin-top:8px;">
102
- raw score: <code style="color:#ddd">{raw:.4f}</code>
103
  </div>
104
  </div>
105
  <div style="display:grid;grid-template-columns:repeat(10,1fr);gap:6px;margin-bottom:16px;">
@@ -108,7 +95,7 @@ def build_html(raw: float) -> str:
108
  <div style="background:rgba(255,255,255,.1);border-radius:6px;height:8px;overflow:hidden;">
109
  <div style="width:{bar_w:.1f}%;height:100%;
110
  background:linear-gradient(90deg,#c0392b,#f39c12,#27ae60);
111
- border-radius:6px;transition:width .4s;"></div>
112
  </div>
113
  <div style="display:flex;justify-content:space-between;font-size:.72rem;color:#777;margin-top:4px;">
114
  <span>score_0</span><span>score_9</span>
@@ -119,26 +106,24 @@ def build_html(raw: float) -> str:
119
  def classify(image):
120
  if image is None:
121
  return "<p style='color:#888;text-align:center'>Upload an image to score it.</p>"
122
- raw = get_score(image)
123
- return build_html(raw)
124
 
125
 
126
- # ── Gradio UI ───────────────────────────────────────────────────────────────────
127
  with gr.Blocks(
128
  title="Aesthetic Classifier - PurpleSmartAI",
129
  theme=gr.themes.Soft(primary_hue="purple"),
130
  css=".gradio-container{max-width:860px!important;margin:auto} #title{text-align:center} #sub{text-align:center;color:#888;font-size:.9rem;margin-bottom:1.5rem}",
131
  ) as demo:
132
- gr.Markdown("# Aesthetic Classifier", elem_id="title")
133
  gr.Markdown(
134
  "CLIP ViT-L/14 regression model by **PurpleSmartAI** for Pony V7 captioning. "
135
- "Outputs a **score_0...score_9** tag used directly in training captions.",
136
  elem_id="sub",
137
  )
138
  with gr.Row():
139
  with gr.Column(scale=1):
140
  img_input = gr.Image(type="pil", label="Input Image", height=340)
141
- run_btn = gr.Button("Score image", variant="primary", size="lg")
142
  with gr.Column(scale=1):
143
  out_html = gr.HTML(
144
  value="<p style='color:#888;text-align:center;padding:40px 0'>Upload an image to see its score.</p>",
 
7
  from PIL import Image
8
  from huggingface_hub import hf_hub_download
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11
  print(f"[info] device: {DEVICE}")
12
 
 
21
  )
22
  checkpoint_data = torch.load(ckpt_path, map_location=DEVICE)
23
  state_dict = checkpoint_data["state_dict"]
24
+ # Strip "model." prefix — keys become "0.weight", "3.weight", "6.weight"
25
  state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
26
 
27
+ # Build Sequential directly so keys match ("0.weight", "3.weight", "6.weight")
28
+ aesthetic_model = nn.Sequential(
29
+ nn.Linear(768, 1024),
30
+ nn.ReLU(),
31
+ nn.Dropout(0.5),
32
+ nn.Linear(1024, 512),
33
+ nn.ReLU(),
34
+ nn.Dropout(0.3),
35
+ nn.Linear(512, 1),
36
+ ).to(DEVICE)
37
  aesthetic_model.load_state_dict(state_dict)
38
  aesthetic_model.eval()
39
  print("[info] Model ready.")
40
 
41
 
 
42
  @torch.no_grad()
43
  def get_score(image: Image.Image) -> float:
 
44
  image_tensor = preprocess(image.convert("RGB")).unsqueeze(0).to(DEVICE)
45
  features = clip_model.encode_image(image_tensor).cpu().numpy()
46
  norm = np.linalg.norm(features, axis=1, keepdims=True)
47
  norm[norm == 0] = 1
48
  features = features / norm
49
  features_t = torch.tensor(features, dtype=torch.float32, device=DEVICE)
50
+ return aesthetic_model(features_t).item()
 
51
 
52
 
53
  def raw_to_pony(raw: float) -> int:
 
54
  return int(max(0.0, min(0.99, raw)) * 10)
55
 
56
 
 
57
  SCORE_COLOURS = [
58
  "#c0392b", "#e74c3c", "#e67e22", "#f39c12", "#d4ac0d",
59
  "#27ae60", "#1e8449", "#148f77", "#0e6655", "#0a4f42",
 
61
 
62
 
63
  def build_html(raw: float) -> str:
64
+ pony = raw_to_pony(raw)
65
  colour = SCORE_COLOURS[pony]
66
 
67
  tiles_html = ""
 
72
  weight = "700" if active else "400"
73
  scale = "scale(1.12)" if active else "scale(1)"
74
  opac = "1" if active else "0.45"
75
+ tiles_html += (
76
+ f'<div style="background:{bg};border:{border};border-radius:8px;'
77
+ f'padding:10px 0;text-align:center;font-size:.82rem;font-weight:{weight};color:#fff;'
78
+ f'transform:{scale};opacity:{opac};transition:all .2s;user-select:none;">score_{i}</div>'
79
+ )
80
 
81
  bar_w = min(raw, 1.0) * 100
 
82
  return f"""
83
  <div style="font-family:'Inter',sans-serif;padding:8px 0;">
84
  <div style="text-align:center;margin-bottom:20px;">
 
86
  padding:14px 36px;font-size:2rem;font-weight:800;letter-spacing:.04em;
87
  box-shadow:0 4px 20px {colour}66;">score_{pony}</div>
88
  <div style="color:#aaa;font-size:.85rem;margin-top:8px;">
89
+ raw: <code style="color:#ddd">{raw:.4f}</code>
90
  </div>
91
  </div>
92
  <div style="display:grid;grid-template-columns:repeat(10,1fr);gap:6px;margin-bottom:16px;">
 
95
  <div style="background:rgba(255,255,255,.1);border-radius:6px;height:8px;overflow:hidden;">
96
  <div style="width:{bar_w:.1f}%;height:100%;
97
  background:linear-gradient(90deg,#c0392b,#f39c12,#27ae60);
98
+ border-radius:6px;"></div>
99
  </div>
100
  <div style="display:flex;justify-content:space-between;font-size:.72rem;color:#777;margin-top:4px;">
101
  <span>score_0</span><span>score_9</span>
 
106
  def classify(image):
107
  if image is None:
108
  return "<p style='color:#888;text-align:center'>Upload an image to score it.</p>"
109
+ return build_html(get_score(image))
 
110
 
111
 
 
112
  with gr.Blocks(
113
  title="Aesthetic Classifier - PurpleSmartAI",
114
  theme=gr.themes.Soft(primary_hue="purple"),
115
  css=".gradio-container{max-width:860px!important;margin:auto} #title{text-align:center} #sub{text-align:center;color:#888;font-size:.9rem;margin-bottom:1.5rem}",
116
  ) as demo:
117
+ gr.Markdown("# 🎨 Aesthetic Classifier", elem_id="title")
118
  gr.Markdown(
119
  "CLIP ViT-L/14 regression model by **PurpleSmartAI** for Pony V7 captioning. "
120
+ "Outputs a **score_0…score_9** tag used directly in training captions.",
121
  elem_id="sub",
122
  )
123
  with gr.Row():
124
  with gr.Column(scale=1):
125
  img_input = gr.Image(type="pil", label="Input Image", height=340)
126
+ run_btn = gr.Button("✨ Score image", variant="primary", size="lg")
127
  with gr.Column(scale=1):
128
  out_html = gr.HTML(
129
  value="<p style='color:#888;text-align:center;padding:40px 0'>Upload an image to see its score.</p>",