Spaces:
Runtime error
Runtime error
Make color and structure weights configurable
Browse files
app.py
CHANGED
|
@@ -151,6 +151,8 @@ def run(
|
|
| 151 |
image,
|
| 152 |
style_type: str,
|
| 153 |
style_id: float,
|
|
|
|
|
|
|
| 154 |
dlib_landmark_model,
|
| 155 |
encoder: nn.Module,
|
| 156 |
generator_dict: dict[str, nn.Module],
|
|
@@ -191,7 +193,8 @@ def run(
|
|
| 191 |
truncation=0.7,
|
| 192 |
truncation_latent=0,
|
| 193 |
use_res=True,
|
| 194 |
-
interp_weights=[
|
|
|
|
| 195 |
img_gen = torch.clamp(img_gen.detach(), -1, 1)
|
| 196 |
# deactivate color-related layers by setting w_c = 0
|
| 197 |
img_gen2, _ = generator([instyle],
|
|
@@ -200,7 +203,7 @@ def run(
|
|
| 200 |
truncation=0.7,
|
| 201 |
truncation_latent=0,
|
| 202 |
use_res=True,
|
| 203 |
-
interp_weights=[
|
| 204 |
img_gen2 = torch.clamp(img_gen2.detach(), -1, 1)
|
| 205 |
|
| 206 |
img_rec = postprocess(img_rec[0])
|
|
@@ -249,7 +252,8 @@ def main():
|
|
| 249 |
func = functools.update_wrapper(func, run)
|
| 250 |
|
| 251 |
image_paths = sorted(pathlib.Path('images').glob('*.jpg'))
|
| 252 |
-
examples = [[path.as_posix(), 'cartoon', 26
|
|
|
|
| 253 |
|
| 254 |
gr.Interface(
|
| 255 |
func,
|
|
@@ -262,6 +266,10 @@ def main():
|
|
| 262 |
label='Style Type',
|
| 263 |
),
|
| 264 |
gr.inputs.Number(default=26, label='Style Image Index'),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
],
|
| 266 |
[
|
| 267 |
gr.outputs.Image(type='pil', label='Aligned Face'),
|
|
|
|
| 151 |
image,
|
| 152 |
style_type: str,
|
| 153 |
style_id: float,
|
| 154 |
+
structure_weight: float,
|
| 155 |
+
color_weight: float,
|
| 156 |
dlib_landmark_model,
|
| 157 |
encoder: nn.Module,
|
| 158 |
generator_dict: dict[str, nn.Module],
|
|
|
|
| 193 |
truncation=0.7,
|
| 194 |
truncation_latent=0,
|
| 195 |
use_res=True,
|
| 196 |
+
interp_weights=[structure_weight] * 7 +
|
| 197 |
+
[color_weight] * 11)
|
| 198 |
img_gen = torch.clamp(img_gen.detach(), -1, 1)
|
| 199 |
# deactivate color-related layers by setting w_c = 0
|
| 200 |
img_gen2, _ = generator([instyle],
|
|
|
|
| 203 |
truncation=0.7,
|
| 204 |
truncation_latent=0,
|
| 205 |
use_res=True,
|
| 206 |
+
interp_weights=[structure_weight] * 7 + [0] * 11)
|
| 207 |
img_gen2 = torch.clamp(img_gen2.detach(), -1, 1)
|
| 208 |
|
| 209 |
img_rec = postprocess(img_rec[0])
|
|
|
|
| 252 |
func = functools.update_wrapper(func, run)
|
| 253 |
|
| 254 |
image_paths = sorted(pathlib.Path('images').glob('*.jpg'))
|
| 255 |
+
examples = [[path.as_posix(), 'cartoon', 26, 0.6, 1.0]
|
| 256 |
+
for path in image_paths]
|
| 257 |
|
| 258 |
gr.Interface(
|
| 259 |
func,
|
|
|
|
| 266 |
label='Style Type',
|
| 267 |
),
|
| 268 |
gr.inputs.Number(default=26, label='Style Image Index'),
|
| 269 |
+
gr.inputs.Slider(
|
| 270 |
+
0, 1, step=0.1, default=0.6, label='Structure Weight'),
|
| 271 |
+
gr.inputs.Slider(0, 1, step=0.1, default=1.0,
|
| 272 |
+
label='Color Weight'),
|
| 273 |
],
|
| 274 |
[
|
| 275 |
gr.outputs.Image(type='pil', label='Aligned Face'),
|