RojaKatta committed on
Commit
bbc4eef
·
verified ·
1 Parent(s): 210afcf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +241 -52
app.py CHANGED
@@ -1,54 +1,243 @@
1
- import gradio as gr
 
2
  from PIL import Image
3
- import os
4
- def load_hairstyles():
5
- folder = "hairstyles"
6
- if not os.path.exists(folder):
7
- return []
8
- return [
9
- Image.open(os.path.join(folder, f)).convert("RGBA")
10
- for f in sorted(os.listdir(folder)) if f.endswith(".png")
11
  ]
12
- hairstyles = load_hairstyles()
13
- def apply_hairstyle(user_img, style_index, x_offset, y_offset, scale):
14
- if user_img is None or not hairstyles:
15
- return None
16
- user_img = user_img.convert("RGBA")
17
- base_w, base_h = user_img.size
18
-
19
- hairstyle = hairstyles[style_index]
20
-
21
- # Resize the hairstyle based on scale
22
- new_size = (int(base_w * scale), int(hairstyle.height * (base_w * scale / hairstyle.width)))
23
- hairstyle = hairstyle.resize(new_size)
24
-
25
- # Create a blank transparent image to position the hairstyle
26
- composite = Image.new("RGBA", user_img.size)
27
- paste_x = int((base_w - new_size[0]) / 2 + x_offset)
28
- paste_y = int(y_offset)
29
- composite.paste(hairstyle, (paste_x, paste_y), hairstyle)
30
-
31
- # Overlay it
32
- result = Image.alpha_composite(user_img, composite)
33
- return result.convert("RGB")
34
-
35
- with gr.Blocks() as demo:
36
- gr.Markdown("## πŸ’‡ Salon Virtual Hairstyle Try-On (Adjustable)")
37
- with gr.Row():
38
- with gr.Column():
39
- image_input = gr.Image(type="pil", label="πŸ“· Upload an Image")
40
- style_slider = gr.Slider(0, max(len(hairstyles)-1, 0), step=1, label="🎨 Select Hairstyle")
41
- x_offset = gr.Slider(-200, 200, value=0, step=1, label="β¬…οΈβž‘οΈ Move Left / Right")
42
- y_offset = gr.Slider(-200, 200, value=0, step=1, label="⬆️⬇️ Move Up / Down")
43
- scale = gr.Slider(0.3, 2.0, value=1.0, step=0.05, label="πŸ“ Scale Hairstyle")
44
- apply_btn = gr.Button("✨ Apply Hairstyle")
45
- with gr.Column():
46
- result_output = gr.Image(label="πŸ” Result Preview")
47
-
48
- apply_btn.click(
49
- fn=apply_hairstyle,
50
- inputs=[image_input, style_slider, x_offset, y_offset, scale],
51
- outputs=result_output
52
- )
53
-
54
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import json
import os
import tempfile

import cv2
import gradio as gr
import numpy as np
from PIL import Image

# ---------------------- Paths (hair/ first) ----------------------
# Candidate asset directories, checked in priority order.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
CANDIDATES = [
    os.path.join(BASE_DIR, "hair"),  # <- your folder
    os.path.join(BASE_DIR, "assets", "hairstyles"),
    os.path.join(BASE_DIR, "assets", "Hairstyles"),
    os.path.join(BASE_DIR, "hairstyles"),
]
# First existing candidate wins; otherwise create the canonical "hair/" path.
HAIR_DIR = next((p for p in CANDIDATES if os.path.isdir(p)), None)
if HAIR_DIR is None:  # create canonical path if nothing exists yet
    HAIR_DIR = os.path.join(BASE_DIR, "hair")
os.makedirs(HAIR_DIR, exist_ok=True)

META_PATH = os.path.join(HAIR_DIR, "meta.json")  # optional per-style anchors
23
+
24
# ---------------------- MediaPipe ----------------------
try:
    import mediapipe as mp
except Exception as e:
    # Chain the original exception (`from e`) so the real import failure
    # (missing wheel, ABI mismatch, ...) stays visible in the traceback.
    raise RuntimeError(f"Mediapipe import failed. Check requirements.txt pins. Details: {e}") from e

mp_face_mesh = mp.solutions.face_mesh
mp_selfie_seg = mp.solutions.selfie_segmentation
# FaceMesh landmark indices used as alignment anchors.
LM = {"left_eye_outer": 33, "right_eye_outer": 263, "mid_forehead": 10}
33
+
34
+ # ---------------------- Helpers ----------------------
35
def load_hairstyles():
    """Return the sorted list of .png filenames found in HAIR_DIR.

    Missing directory is treated as "no styles" rather than an error.
    """
    try:
        entries = os.listdir(HAIR_DIR)
    except FileNotFoundError:
        entries = []
    return sorted(name for name in entries if name.lower().endswith(".png"))
42
+
43
HAIR_FILES = load_hairstyles()


def load_meta():
    """Best-effort load of per-style anchor overrides from meta.json.

    Any problem (missing file, bad JSON, non-dict payload) yields {}.
    """
    if not os.path.exists(META_PATH):
        return {}
    try:
        with open(META_PATH, "r") as fh:
            payload = json.load(fh)
    except Exception:
        return {}
    return payload if isinstance(payload, dict) else {}


META = load_meta()
55
+
56
def detect_face_keypoints(img_bgr):
    """Detect the two outer eye corners and mid-forehead point.

    Returns a (3, 2) float32 array in pixel coordinates, or None when no
    face is found.
    """
    height, width = img_bgr.shape[:2]
    with mp_face_mesh.FaceMesh(
        static_image_mode=True,
        max_num_faces=1,
        refine_landmarks=True,
        min_detection_confidence=0.6,
    ) as mesh:
        result = mesh.process(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
    if not result.multi_face_landmarks:
        return None
    landmarks = result.multi_face_landmarks[0].landmark

    def to_px(idx):
        # Landmarks are normalized to [0, 1]; scale to pixel coordinates.
        return np.array([landmarks[idx].x * width, landmarks[idx].y * height], dtype=np.float32)

    return np.stack([to_px(LM["left_eye_outer"]), to_px(LM["right_eye_outer"]), to_px(LM["mid_forehead"])])
68
+
69
def person_mask(img_bgr):
    """Soft person-vs-background mask in [0, 1].

    Thresholds the selfie-segmentation output at 0.5, then feathers the
    edge with a 35x35 Gaussian blur.
    """
    rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    with mp_selfie_seg.SelfieSegmentation(model_selection=1) as segmenter:
        raw = segmenter.process(rgb).segmentation_mask
    hard = (raw > 0.5).astype(np.float32)
    return cv2.GaussianBlur(hard, (35, 35), 0)
76
+
77
def load_hair_png(name):
    """Read a hairstyle PNG from HAIR_DIR as a BGRA array.

    Raises ValueError when the file is missing, unreadable, or has no
    alpha channel (grayscale/RGB PNGs are rejected).
    """
    path = os.path.join(HAIR_DIR, name)
    hair = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # BGRA
    # Check ndim before shape[2]: a grayscale PNG loads as a 2-D array and
    # the original `hair.shape[2]` test raised IndexError instead of this
    # clear ValueError.
    if hair is None or hair.ndim != 3 or hair.shape[2] != 4:
        raise ValueError(f"Invalid hair asset: {name} (must be RGBA PNG)")
    return hair
83
+
84
def hair_reference_points(hair_bgra, filename):
    """Three anchor points on the hair asset matching the face landmarks.

    meta.json may override per file; otherwise fixed fractions of the
    asset size are used (tune via meta.json for perfection).
    """
    h, w = hair_bgra.shape[:2]
    override = META.get(filename)
    if override is not None:
        pts = np.array(override, dtype=np.float32)
        if pts.shape == (3, 2):
            return pts
    # Defaults: left eye, right eye, mid-forehead as fractions of (w, h).
    return np.array(
        [[0.30 * w, 0.60 * h],
         [0.70 * w, 0.60 * h],
         [0.50 * w, 0.40 * h]],
        dtype=np.float32,
    )
95
+
96
def warp_and_alpha_blend(base_bgr, hair_bgra, M, opacity=1.0):
    """Warp the BGRA hair asset by affine M and alpha-blend it over base_bgr.

    `opacity` (0..1) scales the asset's own alpha channel.
    """
    H, W = base_bgr.shape[:2]
    hair_rgb = hair_bgra[:, :, :3]
    hair_a = hair_bgra[:, :, 3] / 255.0
    # BORDER_CONSTANT(0) instead of BORDER_TRANSPARENT: with a freshly
    # allocated destination, BORDER_TRANSPARENT leaves out-of-source pixels
    # *uninitialized*, so garbage alpha could bleed into the composite.
    # Zero-fill makes everything outside the warped asset fully transparent.
    hair_warp = cv2.warpAffine(hair_rgb, M, (W, H), flags=cv2.INTER_LINEAR,
                               borderMode=cv2.BORDER_CONSTANT, borderValue=0)
    a_warp = cv2.warpAffine(hair_a, M, (W, H), flags=cv2.INTER_LINEAR,
                            borderMode=cv2.BORDER_CONSTANT, borderValue=0)
    a = np.clip(a_warp * opacity, 0, 1)[..., None]
    out = (a * hair_warp + (1 - a) * base_bgr).astype(np.uint8)
    return out
105
+
106
def apply_tryon(image, hairstyle, scale_pct, rot_deg, dx, dy, opacity):
    """Align the chosen hairstyle PNG to the face in `image` and composite it.

    Returns a (result, status) pair: `result` is an RGB array (or the input
    image / None on failure) and `status` is a short user-facing message.
    """
    if image is None:
        return None, "Upload a photo or enable webcam."
    if not hairstyle:
        return np.array(image), "Pick a hairstyle first."

    # Gradio delivers a PIL (RGB) image; OpenCV works in BGR.
    img_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    kpts = detect_face_keypoints(img_bgr)
    if kpts is None:
        return image, "No face detected. Try a brighter, front-facing photo."

    hair = load_hair_png(hairstyle)
    hair_pts = hair_reference_points(hair, hairstyle)

    # User nudges shift the face-side target points, not the asset.
    dst = kpts.copy()
    dst[:, 0] += dx
    dst[:, 1] += dy

    # Pre-rotate/scale the asset's anchor points about their centroid so the
    # fit below absorbs the user's manual rotation/scale adjustments.
    center = hair_pts.mean(axis=0)
    theta = np.deg2rad(rot_deg)
    s = max(0.5, scale_pct / 100.0)  # clamp: never shrink below 50 %
    R = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta), np.cos(theta)]], dtype=np.float32)
    hair_pts_adj = (hair_pts - center) @ R.T * s + center

    # Robust rotation+scale+translation fit from asset anchors to face points.
    M, _ = cv2.estimateAffinePartial2D(hair_pts_adj, dst, method=cv2.LMEDS)
    if M is None:
        return image, "Could not compute alignment for this image/style."

    out = warp_and_alpha_blend(img_bgr, hair, M, opacity=opacity)

    # Restrict the overlay to the person: outside the segmented silhouette,
    # blend back toward the original pixels.
    head = person_mask(img_bgr)
    head3 = head[..., None]
    out = (head3 * out + (1 - head3) * img_bgr).astype(np.uint8)

    out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
    return out_rgb, "OK"
144
+
145
def save_png(img):
    """Write the result array to a unique temp PNG and return its path.

    Returns None when there is no image to save.
    """
    if img is None:
        return None
    # Unique file per call: the previous fixed name in the shared tempdir
    # let concurrent users overwrite each other's downloads.
    fd, p = tempfile.mkstemp(prefix="tryon_", suffix=".png")
    os.close(fd)  # mkstemp opens the fd; PIL reopens by path
    Image.fromarray(img).save(p)
    return p
151
+
152
def hair_preview(hairstyle):
    """Render the hair PNG over a checkerboard so transparency is visible."""
    if not hairstyle:
        return None
    hair = load_hair_png(hairstyle)
    h, w = hair.shape[:2]
    tile = 16  # checker square size in pixels
    # Checkerboard via Kronecker product: block-index parity -> gray levels,
    # expanded to tile x tile squares and cropped to the asset size.
    # NOTE(review): 1*64 + 192 = 256 wraps to 0 under uint8, so the squares
    # are 192-gray and *black* — possibly a lighter shade was intended; confirm.
    bg = np.kron(
        ((np.indices((h//tile+1, w//tile+1)).sum(axis=0) % 2) * 64 + 192).astype(np.uint8),
        np.ones((tile, tile), np.uint8)
    )[:h, :w]
    bg_rgb = np.dstack([bg, bg, bg])
    a = (hair[:, :, 3:4] / 255.0)  # alpha in [0, 1], kept 3-D for broadcasting
    comp = (a * hair[:, :, :3] + (1 - a) * bg_rgb).astype(np.uint8)
    comp = cv2.cvtColor(comp, cv2.COLOR_BGR2RGB)  # asset is BGRA; preview is RGB
    return comp
167
+
168
# ---------------------- UI ----------------------
def build_ui():
    """Assemble the Gradio Blocks app: Photo and Webcam try-on tabs."""
    with gr.Blocks(title="Virtual Try-On (FR1–FR8)", css="""
    .gradio-container {max-width: 980px; margin: auto;}
    @media (max-width: 768px){ .gradio-container {padding: 8px;} }
    """) as demo:
        gr.Markdown("## Salon Hairstyle Virtual Try-On\nUpload or use webcam, pick a style from **Select Hairstyle**, adjust, then download.")

        if not HAIR_FILES:
            gr.Markdown("⚠️ **No hairstyle PNGs found.** Upload files into **`hair/`** (or `assets/hairstyles/`) and reload this Space.")

        with gr.Tabs():
            # -------- Photo (FR1,3–7) --------
            with gr.Tab("Photo"):
                with gr.Row():
                    in_img = gr.Image(label="Upload photo (JPEG/PNG)", sources=["upload"], type="pil")
                    hair = gr.Dropdown(
                        choices=HAIR_FILES,
                        value=(HAIR_FILES[0] if HAIR_FILES else None),
                        label="Select Hairstyle (from 'hair/')",
                        interactive=True
                    )
                with gr.Row():
                    preview = gr.Image(label="Hairstyle Preview", height=260)
                hair.change(fn=hair_preview, inputs=[hair], outputs=[preview])

                with gr.Accordion("Alignment Controls", open=True):
                    with gr.Row():
                        scale = gr.Slider(50, 200, value=100, step=1, label="Scale %")
                        rot = gr.Slider(-30, 30, value=0, step=1, label="Rotate (deg)")
                    with gr.Row():
                        dx = gr.Slider(-200, 200, value=0, step=1, label="Horizontal Nudge (px)")
                        dy = gr.Slider(-200, 200, value=0, step=1, label="Vertical Nudge (px)")
                    opacity = gr.Slider(0.2, 1.0, value=1.0, step=0.05, label="Hair Opacity")

                out = gr.Image(label="Result Preview")
                status = gr.Markdown()

                run = gr.Button("Apply (Align & Overlay)")
                # apply_tryon already has the exact (image, hairstyle, scale,
                # rot, dx, dy, opacity) signature — the previous pass-through
                # lambda wrapper was a no-op.
                run.click(
                    fn=apply_tryon,
                    inputs=[in_img, hair, scale, rot, dx, dy, opacity],
                    outputs=[out, status]
                )

                dl = gr.DownloadButton(label="Download Result", file_name="tryon.png")
                dl.click(fn=save_png, inputs=[out], outputs=[dl])

                gr.Markdown("Share this Space link after you make it public (FR-7).")

            # -------- Webcam (FR2–FR6) --------
            with gr.Tab("Webcam"):
                cam = gr.Image(sources=["webcam"], streaming=True, type="pil", label="Enable camera to start")
                hair2 = gr.Dropdown(choices=HAIR_FILES, value=(HAIR_FILES[0] if HAIR_FILES else None), label="Select Hairstyle")
                scale2 = gr.Slider(50, 200, value=100, step=1, label="Scale %")
                rot2 = gr.Slider(-25, 25, value=0, step=1, label="Rotate (deg)")
                dx2 = gr.Slider(-150, 150, value=0, step=1, label="Horizontal Nudge (px)")
                dy2 = gr.Slider(-150, 150, value=0, step=1, label="Vertical Nudge (px)")
                opacity2 = gr.Slider(0.2, 1.0, value=0.95, step=0.05, label="Hair Opacity")
                out2 = gr.Image(label="Live Preview")

                def live(im, h, s, r, dxv, dyv, op):
                    # Drop the status string: the live preview shows only the image.
                    res, _ = apply_tryon(im, h, s, r, dxv, dyv, op)
                    return res

                cam.stream(live, inputs=[cam, hair2, scale2, rot2, dx2, dy2, opacity2], outputs=[out2])

    return demo
236
+
237
# Export for Spaces autostart: Spaces looks for a module-level `demo`.
demo = app = build_ui()

# Local dev entry point
if __name__ == "__main__":
    app.launch()