File size: 24,777 Bytes
045cc5e
 
 
 
 
 
47a106f
4a7bcc0
38dd8ce
47a106f
 
045cc5e
 
910b7ba
 
045cc5e
47a106f
38dd8ce
 
 
 
47a106f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38dd8ce
 
 
 
 
 
 
 
 
045cc5e
 
 
 
 
 
38dd8ce
045cc5e
7e9a2b9
045cc5e
3ee9006
38dd8ce
 
3ee9006
045cc5e
38dd8ce
 
 
3ee9006
 
38dd8ce
 
3ee9006
 
 
 
 
 
 
 
 
 
 
 
 
 
045cc5e
38dd8ce
 
47a106f
 
045cc5e
3ee9006
38dd8ce
 
 
 
 
 
 
 
 
 
 
 
3ee9006
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ea6665
 
 
 
 
3ee9006
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e9a2b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
045cc5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47a106f
045cc5e
 
 
47a106f
045cc5e
47a106f
045cc5e
 
 
 
 
 
 
 
 
 
 
47a106f
045cc5e
 
 
4a7bcc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
045cc5e
38dd8ce
 
 
 
 
 
 
 
 
8abf62f
56e8164
38dd8ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a1adae2
38dd8ce
 
a1adae2
38dd8ce
 
a1adae2
38dd8ce
 
a1adae2
38dd8ce
 
a1adae2
38dd8ce
bbf7050
 
 
 
 
 
 
38dd8ce
 
045cc5e
38dd8ce
 
 
 
 
 
 
 
 
 
 
 
045cc5e
 
4a7bcc0
 
 
 
 
 
045cc5e
4a7bcc0
045cc5e
 
 
 
 
47a106f
4a7bcc0
47a106f
 
 
045cc5e
 
38dd8ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47a106f
 
4a7bcc0
47a106f
 
 
 
 
 
 
4a7bcc0
47a106f
 
 
 
 
4a7bcc0
47a106f
 
 
 
 
 
4a7bcc0
47a106f
 
 
 
 
 
 
4a7bcc0
47a106f
 
 
 
 
 
4a7bcc0
47a106f
 
 
 
 
 
4a7bcc0
47a106f
 
 
 
 
 
 
 
4a7bcc0
47a106f
 
 
 
 
 
4a7bcc0
47a106f
 
 
 
 
 
 
 
 
 
 
 
4a7bcc0
47a106f
 
 
 
 
 
4a7bcc0
47a106f
 
 
4a7bcc0
47a106f
 
 
045cc5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a7bcc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38dd8ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
045cc5e
 
47a106f
38dd8ce
 
 
 
47a106f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
045cc5e
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
import datetime
import gradio as gr
import requests
import concurrent.futures
import boto3
import os
from dotenv import load_dotenv
from styles import predefined_styles
from images_and_poses import default_images, default_poses

# Load variables from a .env file (python-dotenv) before the credentials
# below are read from the environment.
load_dotenv()

# Endpoint that execute_instantid_request() POSTs generation requests to.
INSTANT_ID_URL = "https://europe-west1-mdevcamp-ai.cloudfunctions.net/instantid"
# Credentials handed to the boto3 S3 client in upload_image_to_s3().
# os.getenv returns None when a variable is not set.
# NOTE(review): the secret is read from AWS_ACCESS_SECRET, not the conventional
# AWS_SECRET_ACCESS_KEY - confirm the deployment sets this exact name.
AWS_ACCESS_KEY_ID = os.getenv('AWS_ACCESS_KEY_ID')
AWS_ACCESS_SECRET = os.getenv('AWS_ACCESS_SECRET')

def process_images(
    person_images_defaults,
    person_images_custom,
    pose_images_defaults,
    pose_images_custom,
    prompt,
    negative_prompt,
    num_steps,
    identity_strength_ration,
    adapter_strength_ration,
    pose_strength_ration,
    canny_strength_ration,
    depth_strength_ration,
    guidance_strength_ration,
    generations_repeat_count,
    controlnet_selection,
    scheduler,
    enable_lcm,
    enhance_face_region,
):
    """Generate images for every selected face/pose combination, streaming progress.

    This is a generator: each ``yield`` is a ``(gallery_items, status_text)``
    pair matching the two outputs the handler is wired to.

    Steps:
      1. Upload the custom face and pose files with ``upload_images_concurrently``.
      2. Combine the uploaded URLs with the selected default images/poses.
      3. Build the request payloads with ``generate_requests_data``.
      4. Send them one at a time with ``execute_instantid_request`` and yield
         the gallery after every response. Not-yet-finished slots show a
         loading placeholder, failed ones an error placeholder.

    Args:
        person_images_defaults: Selected default face images (may be None).
        person_images_custom: Uploaded face files (may be None).
        pose_images_defaults: Selected default pose images (may be None).
        pose_images_custom: Uploaded pose files (may be None).
        prompt: Prompt text; generation is refused when it is empty.
        negative_prompt, num_steps, identity_strength_ration,
        adapter_strength_ration, pose_strength_ration, canny_strength_ration,
        depth_strength_ration, guidance_strength_ration, controlnet_selection,
        scheduler, enable_lcm, enhance_face_region: Passed through unchanged
            into every request payload.
        generations_repeat_count: How many times each combination is generated.

    Yields:
        Tuples of (list of gallery items, status message).
    """
    # Empty inputs may arrive as None. Normalise all four to lists so they can
    # be measured and concatenated with the uploaded URL lists below.
    person_images_custom = person_images_custom if person_images_custom is not None else []
    person_images_defaults = list(person_images_defaults) if person_images_defaults is not None else []

    pose_images_custom = pose_images_custom if pose_images_custom is not None else []
    pose_images_defaults = list(pose_images_defaults) if pose_images_defaults is not None else []

    if not person_images_defaults and not person_images_custom:
        gr.Warning('No person images set')
        return
    if not prompt:
        gr.Warning('No prompt set')
        return

    yield [], "Uploading images"

    uploaded_person_image_urls, uploaded_pose_image_urls = upload_images_concurrently(
        person_images_paths=person_images_custom,
        pose_images_paths=pose_images_custom,
    )

    # upload_image_to_s3 reports a failed upload as None. Such entries are not
    # usable image URLs, so drop them instead of sending them to the backend.
    usable_person_urls = [url for url in uploaded_person_image_urls if url is not None]
    usable_pose_urls = [url for url in uploaded_pose_image_urls if url is not None]
    skipped_uploads = (
        len(uploaded_person_image_urls) - len(usable_person_urls)
        + len(uploaded_pose_image_urls) - len(usable_pose_urls)
    )
    if skipped_uploads:
        gr.Warning(f"{skipped_uploads} image(s) failed to upload and were skipped")

    person_images_urls = usable_person_urls + person_images_defaults
    poses_images_urls = usable_pose_urls + pose_images_defaults

    if not person_images_urls:
        # Only custom faces were given and every upload failed.
        gr.Warning('No person images could be uploaded')
        yield [], "Uploading images failed"
        return

    requests_data = generate_requests_data(
        generations_repeat_count=generations_repeat_count,
        uploaded_person_image_urls=person_images_urls,
        uploaded_pose_image_urls=poses_images_urls,
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_steps=num_steps,
        identity_strength_ration=identity_strength_ration,
        adapter_strength_ration=adapter_strength_ration,
        pose_strength_ration=pose_strength_ration,
        canny_strength_ration=canny_strength_ration,
        depth_strength_ration=depth_strength_ration,
        controlnet_selection=controlnet_selection,
        guidance_strength_ration=guidance_strength_ration,
        scheduler=scheduler,
        enable_lcm=enable_lcm,
        enhance_face_region=enhance_face_region,
    )

    print(requests_data)

    total_requests = len(requests_data)
    yield [], f"Generating images 0/{total_requests}"

    # Placeholder images shown for failed and still-pending generations.
    error_image = "https://cdn.pixabay.com/photo/2017/02/12/21/29/false-2061132_640.png"
    loading_image = "https://t4.ftcdn.net/jpg/03/16/15/47/360_F_316154790_pnHGQkERUumMbzAjkgQuRvDgzjAHkFaQ.jpg"

    gallery_items = []
    error_count = 0
    for req_data in requests_data:
        # Requests are sent sequentially; None signals a failed generation.
        response = execute_instantid_request(req_data)
        if response is not None:
            gallery_items.append((response, "Caption"))
        else:
            gallery_items.append((error_image, "Caption"))
            error_count += 1

        # Pad the gallery so it always shows one tile per requested image.
        images = gallery_items + [loading_image] * (total_requests - len(gallery_items))
        yield images, f"Generating images {len(gallery_items)}/{total_requests} (Failed: {error_count})"
 
def generate_requests_data(
    generations_repeat_count,
    uploaded_person_image_urls,
    uploaded_pose_image_urls,
    prompt,
    negative_prompt,
    num_steps,
    identity_strength_ration,
    adapter_strength_ration,
    pose_strength_ration,
    canny_strength_ration,
    depth_strength_ration,
    controlnet_selection,
    guidance_strength_ration,
    scheduler,
    enable_lcm,
    enhance_face_region,
):
    """Build one request payload per (repeat, person image, pose image) combination.

    Payloads are ordered by repeat, then by person image, then by pose image.
    When no pose images are given, each person image is used as its own pose.

    Args:
        generations_repeat_count: Number of times every combination is repeated.
            Converted with ``int()`` so a float such as ``2.0`` is accepted.
        uploaded_person_image_urls: Face image URLs.
        uploaded_pose_image_urls: Pose image URLs; may be empty.
        prompt, negative_prompt, num_steps, identity_strength_ration,
        adapter_strength_ration, pose_strength_ration, canny_strength_ration,
        depth_strength_ration, controlnet_selection, guidance_strength_ration,
        scheduler, enable_lcm, enhance_face_region: Copied verbatim into every
            payload (``negative_prompt`` under the key ``"n_prompt"``).

    Returns:
        A list of dicts, each an independent object.
    """
    # range() only accepts integers, while the count may arrive as a float.
    repeat_count = int(generations_repeat_count)

    # Settings that are identical for every generated request.
    shared_fields = {
        "prompt": prompt,
        "n_prompt": negative_prompt,
        "num_steps": num_steps,
        "identity_strength_ration": identity_strength_ration,
        "adapter_strength_ration": adapter_strength_ration,
        "pose_strength_ration": pose_strength_ration,
        "canny_strength_ration": canny_strength_ration,
        "depth_strength_ration": depth_strength_ration,
        "controlnet_selection": controlnet_selection,
        "guidance_strength_ration": guidance_strength_ration,
        "scheduler": scheduler,
        "enable_lcm": enable_lcm,
        "enhance_face_region": enhance_face_region,
    }

    requests_data = []
    for _ in range(repeat_count):
        for person_image_url in uploaded_person_image_urls:
            # Use the person image itself as the pose if no poses are available.
            poses = uploaded_pose_image_urls if uploaded_pose_image_urls else [person_image_url]

            for pose_image_url in poses:
                requests_data.append(
                    {
                        "faceImageUrl": person_image_url,
                        "poseImageUrl": pose_image_url,
                        **shared_fields,
                    }
                )
    return requests_data

def upload_images_concurrently(person_images_paths, pose_images_paths):
    """
    Uploads person and pose images in parallel threads.

    Returns a tuple of lists: (list of person image URLs, list of pose image URLs).
    Each list is in the same order as the corresponding input list. An entry is
    whatever upload_image_to_s3 returned for that path, which is None when that
    upload failed.
    """
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # Submit everything up front so person and pose uploads overlap.
        person_futures = [
            executor.submit(upload_image_to_s3, image_path)
            for image_path in person_images_paths
        ]
        pose_futures = [
            executor.submit(upload_image_to_s3, image_path)
            for image_path in pose_images_paths
        ]

        # Read results in submission order rather than completion order, so the
        # returned URLs line up with the inputs and are stable between runs.
        uploaded_person_image_urls = [future.result() for future in person_futures]
        uploaded_pose_image_urls = [future.result() for future in pose_futures]

    return uploaded_person_image_urls, uploaded_pose_image_urls

def upload_image_to_s3(image_path) -> str | None:
    """Upload a local image file to the S3 bucket and return its URL.

    Args:
        image_path: Path of the local file to upload.

    Returns:
        The ``https://<bucket>.s3.amazonaws.com/<key>`` URL of the uploaded
        object, or None when the upload fails (the error is printed).
    """
    # Imported locally: only this function needs them.
    import uuid
    from urllib.parse import quote

    s3_client = boto3.client('s3', aws_access_key_id=AWS_ACCESS_KEY_ID, aws_secret_access_key=AWS_ACCESS_SECRET)
    bucket_name = 'mdevcamp-ai-upload-script'
    image_file_name = os.path.basename(image_path)
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    # The timestamp has one-second resolution and uploads run concurrently, so
    # two files with the same base name would otherwise share a key and
    # overwrite each other. A random component keeps every key distinct.
    unique_part = uuid.uuid4().hex[:8]
    image_key = f"images/{timestamp}-{unique_part}-{image_file_name}"

    print(f"Uploading started: {image_key}")

    try:
        with open(image_path, 'rb') as image:
            s3_client.upload_fileobj(image, bucket_name, image_key)

        print(f"Uploading finished: {image_key}")

        # Percent-encode the key so file names containing spaces or other
        # reserved characters still yield a valid URL ("/" is kept as-is).
        image_url = f"https://{bucket_name}.s3.amazonaws.com/{quote(image_key)}"
        return image_url
    except Exception as e:
        # Best-effort boundary: callers treat None as "this upload failed".
        # gr.Error is an exception class and does nothing unless raised, so the
        # call form gr.Warning is used to report without aborting.
        # NOTE(review): this runs in worker threads - confirm the warning is
        # surfaced to the user from there.
        gr.Warning("Uploading finished with error")
        print(f"Uploading finished with error: {e}")
        return None

def execute_instantid_request(data: dict, timeout: float | tuple[float, float] | None = (10, 600)) -> str | None:
    """Send one generation request to the InstantID endpoint.

    Args:
        data: Request payload; it is wrapped as ``{"instances": [data]}``.
        timeout: Passed to ``requests.post`` as (connect, read) seconds, or None
            to wait indefinitely.
            NOTE(review): the default read timeout is a guess - confirm it
            exceeds the backend's longest generation time.

    Returns:
        The response body decoded as UTF-8 when the status code is 2xx,
        otherwise None. None is also returned when the request itself fails
        (connection error, timeout), so one bad request does not abort the
        caller's loop.
    """
    payload = {
        "instances": [data]
    }

    print(f"InstantID started: {payload}")

    try:
        response = requests.post(INSTANT_ID_URL, json=payload, timeout=timeout)
    except requests.RequestException as e:
        # gr.Error is an exception class and does nothing unless raised, so the
        # call form gr.Warning is used to report without aborting.
        gr.Warning("InstantID finished with error")
        print(f"InstatntID finished with error: {e}")
        return None

    print(f"InstatntID finished: {response.status_code}")

    if 200 <= response.status_code < 300:
        return response.content.decode('utf-8')

    gr.Warning("InstantID finished with error")
    print(f"InstatntID finished: {response.__dict__}")
    return None

def update_gradion_elements_with_style(style):
    """Return the widget values belonging to a predefined style.

    Args:
        style: Key into ``predefined_styles``. An unknown key falls back to
            the first style in the mapping.

    Returns:
        A 14-tuple in the order: prompt, negative prompt, num steps, identity /
        adapter / pose / canny / depth / guidance strength, generations repeat
        count, controlnet selection, scheduler, enable LCM, enhance face region.
    """
    if style in predefined_styles:
        style_obj = predefined_styles[style]
    else:
        # predefined_styles is looked up by style name, so indexing it with 0
        # would raise KeyError. Take the first defined style instead.
        style_obj = next(iter(predefined_styles.values()))

    # Note: the style attributes are spelled "_ratio".
    return (
        style_obj.prompt,
        style_obj.negative_prompt,
        style_obj.num_steps,
        style_obj.identity_strength_ratio,
        style_obj.adapter_strength_ratio,
        style_obj.pose_strength_ratio,
        style_obj.canny_strength_ratio,
        style_obj.depth_strength_ratio,
        style_obj.guidance_strength_ratio,
        style_obj.generations_repeat_count,
        style_obj.controlnet_selection,
        style_obj.scheduler,
        style_obj.enable_lcm,
        style_obj.enhance_face_region,
    )

# Style whose values pre-fill the form widgets when the UI below is built.
default_style = predefined_styles["Sculpture"]

with gr.Blocks() as demo:
    gr.Markdown("""
    
        # mDevCamp app tester
                    
        Tento nástroj slouží k experimentování s vytvářením stylů a póz, které budeme využívat v rámci mDevCamp aplikace. Uživatel nahraje selfie a následně bude odemykat různé styly a pózy, které bude moci kombinovat.
        Od vás bychom potřebovali, abyste byli trochu kreativní a zkusili vymyslet styly a pózy, které by byly použitelné. Kreativitě se meze nekladou.
                
        Pokud si myslíte, že se vám podařil nějaký zajímavý styl, tak zkopírujte prompty, pošlete pózu a screenshot všech nastavení do #ai-avatars kanálu na Slacku.
                
        Pro generování promptů můžete použít například [mDevCamp prompt generátor](https://chat.futured.app/?client=mdevcamppromptgenerator) nebo [Art Style Explorer](https://chat.openai.com/g/g-669XwyKQz-art-style-explorer), pokud máte zaplacený ChatGPT+.
                
        **Našim cieľom je mať nastavenia ktoré sú dobre zreprodukovatelné a ponúkajú dobrý výsledok aj pri viacnásobnom opakovaní.**
                
        **HW, na kterém to běží, dokáže zpracovat pouze jednu fotku naráz, tak se snažte nepouštět generování pro desítky nebo stovky obrázků, protože budete blokovat ostatní**

        **HW, na kterém to běží, je celkem drahý, proto ho budeme zapínat a vypínat pouze v určité hodiny**
                
        **Pokud příliš dlouho čekáte a pak se vám vracejí errory, tak je možné, že to právě používá i někdo jiný, škálování řešení se ještě neimplementuje**
    """)      

    with gr.Accordion("Návod", open=False):
        gr.Markdown(""" 
        ### Jak to vlastně funguje?

        ##### Výběr fotky
        Můžete vybrat jednu z přednastavených nebo nahrát svou vlastní
                    
        ##### Výběr stylu
        Pro inspiraci máme několik stylů, které už existují, můžete se u nich inspirovat tím, jak nastavit jednotlivé parametry, styl po výběru můžete volně upravovat
                    
        #### Prompt
        Zde popíšete, jak bude vypadat výsledný obrázek, prompt nemá velký vliv na to, v jaké poloze například obrázek bude, spíše popisuje, v jakém stylu to bude vytvořeno. Inspirujte se existujícími.
                    
        #### Negative prompt
        Zde popište, jak nechcete, aby obrázek vypadal, pokud dostanete rozmazaný obrázek, tak tam chcete přidat blurry například.
                    
        #### Pózy
        Zde vyberte nějakou pózu z předdefinovaných nebo nahrajte svou vlastní. Nemusíte tuto hodnotu nastavovat. Výběr kvalitní pózy je celkem náročný, některé mohou skončit errorem. Ideál je póza, která zachycuje horní část těla a je relativně blízko. Inspirace může být "We Can Do It" póza. Při výběru více póz se kombinuje každá póza s každým avatarem. Tak opatrně.
                    
        ### Pózy

        #### Generations for each image
        Počet generací každého obrázku. Pokud máš jednu fotku a sem nastavíš 5, tak vygeneruješ 5 fotek.
                    
        #### Num steps
        Určuje počet kroků, které model provede při generování obrázku. Vyšší počet kroků vede k detailnějším výsledkům, ale prodlužuje čas generování.

        #### Identity strength ration    
        Ovlivňuje míru, do jaké se zachová identita vstupního obrázku. Vyšší hodnota znamená, že výsledný obrázek bude více podobný vstupu.

        #### Adapter strength ration
        Ovlivňuje míru, do jaké se nastaví podobnost usazení tváře na referenční pozici.

        #### Pose strength ration
        Určuje míru, do jaké se aplikuje informace o póze ze vstupního obrázku. Vyšší hodnota znamená, že výsledná póza bude více podobná té nastavené.
                    
        #### Canny strength ration
        Určuje míru, do jaké se aplikuje informace o jednotlivých detailech ze vstupního obrázku. Vyšší hodnota znamená, že generovaný obrázek bude kopírovat i věci jako logo na tričku například.

        #### Depth strength ration
        Kontroluje míru, do jaké se aplikuje informace o hloubce ze vstupního obrázku. Vyšší hodnota znamená silnější vliv hloubkové mapy na generovaný obrázek.

        #### Guidance strength ration
        Vcelku pokročilá věc, asi nemusíš měnit nebo dohledej online.
                    
        #### Scheduler
        Vcelku pokročilá věc, asi nemusíš měnit nebo dohledej online.
                    
        #### Controlnet selection
        Určuje, to či sa má použiť póza (pose), detaily (canny) alebo hĺbka (depth) z obrázku.

        #### Enable LCM
        Vcelku pokročilá věc, asi nemusíš měnit nebo dohledej online.

        #### Enhance Face Region
        Vcelku pokročilá věc, asi nemusíš měnit nebo dohledej online.

        ### Usage tips
        - If you're not satisfied with the similarity, try to increase the weight of "IdentityNet Strength" and "Adapter Strength".
        - If you feel that the saturation is too high, first decrease the Adapter strength. If it is still too high, then decrease the IdentityNet strength.
        - If you find that text control is not as expected, decrease Adapter strength.
        - If you find that realistic style is not good enough, go for our Github repo and use a more realistic base model.

        ![manual](https://instantid.github.io/static/documents/editbility.jpg)
        """)

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Faces")
            # Predefined faces: each checkbox is labelled with the dict key and
            # submits the dict value.
            default_faces_selection = gr.CheckboxGroup(
                choices=[(key, value) for key, value in default_images.items()],
                label="Select some familiar faces",
                type="value",
            )
            custom_faces_upload = gr.File(
                label="Or upload custom selfies",
                file_count="multiple",
                # File extensions need a leading dot; a bare word is read as a
                # file category such as "image".
                file_types=[".jpg", ".jpeg", ".png", ".webp"],
            )


    # Style picker: choosing an entry overwrites the prompt fields and the
    # advanced settings (see the style_dropdown.change registration).
    style_dropdown = gr.Dropdown(
        choices=list(predefined_styles.keys()),
        value=default_style.style_name,
        label="Predefined style",
    )

    # Prompt text, pre-filled from the default style.
    prompt = gr.Textbox(
        value=default_style.prompt,
        label="Prompt",
        lines=5,
        placeholder="Superman",
    )

    # Negative prompt text, pre-filled from the default style.
    negative_prompt = gr.Textbox(
        value=default_style.negative_prompt,
        label="Negative prompt",
        lines=2,
        placeholder="Blurry",
    )

    with gr.Accordion("Poses", open=False):
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Poses")
                # Predefined poses: each checkbox is labelled with the dict key
                # and submits the dict value.
                default_poses_selection = gr.CheckboxGroup(
                    choices=[(key, value) for key, value in default_poses.items()],
                    label="Select default poses",
                    type="value",
                )
                # Preview of the poses currently ticked above.
                default_poses_previews = gr.Gallery(
                    label="Defaul poses gallery",
                    columns=5,
                    show_label=True,
                    allow_preview=False,
                )

                def update_default_poses_previews(selected_poses):
                    """Return (image, caption) gallery items for the ticked default poses."""
                    # Treat a missing selection as "nothing selected".
                    selected_poses = selected_poses if selected_poses is not None else []
                    return [
                        (pose, key)
                        for key, pose in default_poses.items()
                        if pose in selected_poses
                    ]

                default_poses_selection.change(
                    fn=update_default_poses_previews,
                    inputs=[default_poses_selection],
                    outputs=[default_poses_previews],
                )

                custom_poses_upload = gr.File(
                    label="Or upload custom poses",
                    file_count="multiple",
                    # File extensions need a leading dot; a bare word is read
                    # as a file category such as "image".
                    file_types=[".jpg", ".jpeg", ".png", ".webp"],
                )

    with gr.Accordion("Advanced", open=False):
        generations_repeat_count = gr.Number(
            value=default_style.generations_repeat_count,
            label="Generations for each image (how much images will be generated from each image)",
            minimum=1,
            maximum=10,
            # Whole numbers only: the value is used as a loop count.
            precision=0,
        )

        with gr.Row():
            num_steps = gr.Number(
                value=default_style.num_steps,
                label="Num steps",
                minimum=1,
                maximum=100,
                # A step count is a whole number.
                precision=0,
            )
            identity_strength_ration = gr.Number(
                value=default_style.identity_strength_ratio,
                label="Identity strength ration",
                minimum=0.0,
                maximum=2.0,
                step=0.01,
            )
            adapter_strength_ration = gr.Number(
                value=default_style.adapter_strength_ratio,
                label="Adapter strength ration",
                minimum=0.0,
                maximum=2.0,
                step=0.01,
            )
        with gr.Row():
            pose_strength_ration = gr.Number(
                value=default_style.pose_strength_ratio,
                label="Pose strength ration",
                minimum=0.0,
                maximum=2.0,
                step=0.01,
            )
            canny_strength_ration = gr.Number(
                value=default_style.canny_strength_ratio,
                label="Canny strength ration",
                minimum=0.0,
                maximum=2.0,
                step=0.01,
            )
            depth_strength_ration = gr.Number(
                value=default_style.depth_strength_ratio,
                label="Depth strength ration",
                minimum=0.0,
                maximum=2.0,
                step=0.01,
            )

        with gr.Row():
            guidance_strength_ration = gr.Number(
                value=default_style.guidance_strength_ratio,
                label="Guidance strength ration",
                minimum=0.0,
                maximum=50.0,
                step=0.01,
            )
            scheduler = gr.Dropdown(
                value=default_style.scheduler,
                choices=[
                    "DEISMultistepScheduler",
                    "HeunDiscreteScheduler",
                    "EulerDiscreteScheduler",
                    "DPMSolverMultistepScheduler",
                    "DPMSolverMultistepScheduler-Karras",
                    "DPMSolverMultistepScheduler-Karras-SDE",
                ],
                label="Scheduler",
            )
        with gr.Row():
            controlnet_selection = gr.CheckboxGroup(
                value=default_style.controlnet_selection,
                choices=["pose", "canny", "depth"],
                label="Controlnet selection",
            )

        with gr.Row():
            enable_lcm = gr.Checkbox(
                value=default_style.enable_lcm,
                label="Enable LCM",
            )
            enhance_face_region = gr.Checkbox(
                value=default_style.enhance_face_region,
                label="Enhance Face Region",
            )

    # Starts generation; its label is later replaced with the expected image
    # count by update_button_text.
    btn = gr.Button(
        value="Generate",
    )

    # Receives the progress text yielded by process_images.
    status = gr.Markdown()

    # Receives the gallery items yielded by process_images.
    output_gallery = gr.Gallery(
        label="Results",
        show_label=False,
        elem_id="gallery",
        columns=[3],
        rows=[1],
        object_fit="contain",
        height="auto",
    )

    # Selecting a predefined style overwrites the prompt fields and all advanced
    # settings. The order of `outputs` must match the tuple returned by
    # update_gradion_elements_with_style.
    style_dropdown.change(
        fn=update_gradion_elements_with_style,
        inputs=[style_dropdown],
        outputs=[
            prompt,
            negative_prompt,
            num_steps,
            identity_strength_ration, 
            adapter_strength_ration,  
            pose_strength_ration,     
            canny_strength_ration,    
            depth_strength_ration,    
            guidance_strength_ration,
            generations_repeat_count,
            controlnet_selection,
            scheduler,
            enable_lcm,
            enhance_face_region,
        ],
    )

    def calculate_total_images(person_images_defaults, person_images_custom, pose_images_defaults, pose_images_custom, generations_repeat_count):
        """Return how many images one click on the Generate button would request.

        The result is (person images) * (pose images, at least 1) * (repeat count).
        Default selections are only counted when they are known values of
        ``default_images`` / ``default_poses``. Missing inputs (None) count as
        empty, and a missing repeat count counts as 0.
        """
        # Any input may be None while the form is empty; treat it as "nothing".
        person_images_defaults = person_images_defaults if person_images_defaults is not None else []
        person_images_custom = person_images_custom if person_images_custom is not None else []
        pose_images_defaults = pose_images_defaults if pose_images_defaults is not None else []
        pose_images_custom = pose_images_custom if pose_images_custom is not None else []

        person_images_count = (
            sum(1 for image in person_images_defaults if image in default_images.values())
            + len(person_images_custom)
        )
        pose_images_count = (
            sum(1 for pose in pose_images_defaults if pose in default_poses.values())
            + len(pose_images_custom)
        )

        if pose_images_count == 0:
            pose_images_count = 1  # Use person image if no poses are available

        # int() keeps the label free of values like "4.0" when the count
        # arrives as a float.
        repeat_count = int(generations_repeat_count) if generations_repeat_count is not None else 0

        return person_images_count * pose_images_count * repeat_count
    
    def update_button_text(person_images_defaults, person_images_custom, pose_images_defaults, pose_images_custom, generations_repeat_count):
        """Build the Generate button label, including the number of images it would produce."""
        image_count = calculate_total_images(
            person_images_defaults,
            person_images_custom,
            pose_images_defaults,
            pose_images_custom,
            generations_repeat_count,
        )
        return f"Generate ({image_count} images)"

    # Every widget that influences the number of generated images refreshes the
    # Generate button label when its value changes. All of them feed the same
    # five values into update_button_text.
    image_count_inputs = [
        default_faces_selection,
        custom_faces_upload,
        default_poses_selection,
        custom_poses_upload,
        generations_repeat_count,
    ]
    for count_source in image_count_inputs:
        count_source.change(
            fn=update_button_text,
            # A fresh list per registration, so no listener shares a mutable
            # list with another.
            inputs=list(image_count_inputs),
            outputs=[
                btn
            ],
        )

    # Run the generation when the button is clicked. The order of `inputs` must
    # match the parameter order of process_images; each (gallery, status) pair
    # it yields updates the two outputs.
    btn.click(
        fn=process_images,
        inputs=[
            default_faces_selection,
            custom_faces_upload,
            default_poses_selection,
            custom_poses_upload,
            prompt,
            negative_prompt,
            num_steps,
            identity_strength_ration,
            adapter_strength_ration,
            pose_strength_ration,
            canny_strength_ration,
            depth_strength_ration,
            guidance_strength_ration,
            generations_repeat_count,
            controlnet_selection,
            scheduler,
            enable_lcm,
            enhance_face_region,
        ],
        outputs=[output_gallery, status]
    )

# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()