File size: 14,827 Bytes
0926cd3
 
01a3261
0926cd3
 
 
ab8d95b
0926cd3
 
 
6e0fda9
ab8d95b
a77d74a
0abb49c
0926cd3
 
ab8d95b
0926cd3
 
01a3261
 
 
 
 
 
 
 
be9cced
 
0926cd3
a77d74a
 
 
 
 
 
 
 
 
 
 
0926cd3
 
01a3261
a77d74a
01a3261
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a77d74a
01a3261
 
 
a77d74a
 
01a3261
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a77d74a
01a3261
 
 
0926cd3
 
01a3261
 
 
 
 
 
 
 
 
 
 
 
 
a77d74a
01a3261
 
 
 
 
 
a77d74a
01a3261
 
 
 
 
 
 
 
a77d74a
 
01a3261
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a77d74a
 
01a3261
 
 
 
 
 
 
 
 
 
 
 
 
 
0926cd3
 
ab8d95b
01a3261
be9cced
 
01a3261
 
ab8d95b
01a3261
0926cd3
 
01a3261
 
 
 
 
 
0926cd3
 
01a3261
 
 
 
 
0926cd3
 
 
ab8d95b
01a3261
be9cced
 
01a3261
 
ab8d95b
01a3261
0926cd3
 
01a3261
 
 
 
 
 
0926cd3
 
01a3261
 
 
 
 
0926cd3
 
 
01a3261
 
0926cd3
 
 
 
 
 
 
 
 
 
 
 
 
 
01a3261
 
 
 
 
 
 
 
0926cd3
 
 
01a3261
 
0926cd3
a77d74a
c1254e9
 
 
01a3261
 
 
 
 
 
 
a77d74a
c1254e9
01a3261
c1254e9
 
01a3261
 
 
c1254e9
 
 
 
01a3261
 
c1254e9
01a3261
 
 
 
 
a77d74a
 
c1254e9
a77d74a
c1254e9
 
01a3261
c1254e9
 
 
6e0fda9
c1254e9
a77d74a
01a3261
 
 
 
a77d74a
 
c1254e9
a77d74a
0926cd3
 
01a3261
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0926cd3
01a3261
 
c1254e9
01a3261
c1254e9
01a3261
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c1254e9
01a3261
c1254e9
 
01a3261
 
 
 
 
c1254e9
01a3261
c1254e9
 
 
 
 
01a3261
c1254e9
 
6e0fda9
01a3261
 
 
 
 
c1254e9
 
 
 
01a3261
 
 
c1254e9
 
01a3261
 
c1254e9
 
01a3261
c1254e9
 
01a3261
c1254e9
01a3261
 
 
 
 
 
 
 
 
 
 
 
 
 
c1254e9
 
01a3261
 
 
0abb49c
01a3261
c1254e9
01a3261
c1254e9
 
 
 
01a3261
 
 
 
 
 
 
 
 
 
 
 
 
c1254e9
01a3261
 
c1254e9
 
 
 
 
 
 
01a3261
c1254e9
 
 
 
01a3261
c1254e9
 
 
01a3261
c1254e9
 
01a3261
c1254e9
01a3261
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
from typing import List
from PIL import Image
import json
import gradio as gr
import io
import jinja2
import base64
import aws_utils
import parameters
import script_gen
import inout as iowrapper
import openai_wrapper
import json
import base64

# S3 bucket holding all comic assets; aliased from the central config module.
AWS_BUCKET = parameters.AWS_BUCKET
# LLM handle used for regenerating composition prompts (see regenerate_compositions).
llm = openai_wrapper.GPT_4O_MINI


#### Functions ordered by their order of development.
def toggle_developer_options(is_developer: bool):
    """Show or hide the developer-options panel.

    Args:
        is_developer: True when the developer checkbox is ticked.

    Returns:
        A ``gr.update`` whose visibility mirrors ``is_developer``.
    """
    # Both branches of the original differed only in the boolean flag,
    # so a single update call expresses the same behavior.
    return gr.update(visible=is_developer)


def list_current_dir(bucket_name: str, folder_path: str = "") -> list:
    """List the immediate sub-"folders" under an S3 prefix.

    Args:
        bucket_name: Name of the S3 bucket to query.
        folder_path: Prefix to list under; "" lists the bucket root.

    Returns:
        A list of folder-prefix strings. Empty on any error -- callers
        treat a failed listing as "no folders" (deliberate best-effort).
    """
    try:
        response = aws_utils.S3_CLIENT.list_objects_v2(
            Bucket=bucket_name, Prefix=folder_path, Delimiter="/"
        )
        # "CommonPrefixes" is only present when the delimiter matched something.
        return [prefix["Prefix"] for prefix in response.get("CommonPrefixes", [])]
    except Exception as e:
        # Keep the best-effort contract but stop swallowing the error silently.
        print(f"Failed to list s3://{bucket_name}/{folder_path}: {e}")
        return []


def load_metadata_fn(comic_id: str):
    """Load a comic's character data and the list of available episodes.

    Args:
        comic_id: Folder name of the comic inside the S3 bucket.

    Returns:
        A 3-tuple of (episode-dropdown ``gr.update``, first episode index
        or None, character-data dict keyed by character name). On any
        error a warning is shown and empty values are returned.
    """
    try:
        # Load character data.
        # NOTE(review): bucket is hard-coded here while every other path in
        # this file uses AWS_BUCKET -- confirm they refer to the same bucket.
        character_path = f"s3://blix-demo-v0/{comic_id}/characters/characters.json"
        character_data = json.loads(
            aws_utils.fetch_from_s3(character_path).decode("utf-8")
        )
        # (The original rebuilt this dict with an identity comprehension;
        # the parsed JSON object is already the mapping we want.)

        # Collect the numeric indices of "episode-<n>" folders.
        episode_folders = list_current_dir(AWS_BUCKET, f"{comic_id}/episodes/")
        episode_indices = [
            int(folder.split("/")[2].split("-")[-1])
            for folder in episode_folders
            if "episode" in folder
        ]
        if not episode_indices:
            return (gr.update(choices=[]), None, {})

        # Default the dropdown to the earliest episode.
        min_episode = min(episode_indices)
        return (
            gr.update(choices=episode_indices, value=min_episode),
            min_episode,
            character_data,
        )
    except Exception as e:
        gr.Warning(f"Error loading metadata: {e}")
        return (gr.update(choices=[]), None, {})


def load_episode_data(comic_id: str, episode_num: int):
    """Fetch one episode's JSON from S3 and build a flat frame index.

    Args:
        comic_id: Folder name of the comic inside the S3 bucket.
        episode_num: Episode index used in the ``episode-<n>`` folder name.

    Returns:
        ``(episode_data, frame_hash_map)`` where ``frame_hash_map`` maps a
        global 1-based frame counter to ``{"scene": i, "frame": j}``.
        Returns ``({}, {})`` when the episode cannot be loaded.
    """
    # Build the path before the try so the error message below can never
    # reference an unbound name.
    json_path = (
        f"s3://{AWS_BUCKET}/{comic_id}/episodes/episode-{episode_num}/episode.json"
    )
    try:
        print(f"For episode: {episode_num}")
        episode_data = json.loads(aws_utils.fetch_from_s3(json_path).decode("utf-8"))
        # Number every frame across all scenes with one running counter so
        # the UI can step through the whole episode linearly.
        frame_hash_map = {}
        count = 1
        for scene_idx, scene in enumerate(episode_data["scenes"]):
            for frame_idx in range(len(scene["frames"])):
                frame_hash_map[count] = {
                    "scene": scene_idx,
                    "frame": frame_idx,
                }
                count += 1
        return (episode_data, frame_hash_map)
    except Exception:
        print(
            f"Failed to load json dictionary for episode: {episode_num} at path: {json_path}"
        )
        import traceback as tc
        print(tc.format_exc())
        return {}, {}


def episode_dropdown_effect(comic_id, selected_episode):
    """Dropdown callback: load the chosen episode and reset to its first frame.

    Returns:
        A 5-tuple of (frame-dropdown ``gr.update``, selected episode,
        current frame number or None, episode data, frame hash map).
    """
    episode_data, frame_hash_map = load_episode_data(comic_id, selected_episode)
    # Guard: load_episode_data returns empty dicts on failure; min() over an
    # empty mapping would raise ValueError and break the UI callback.
    if not frame_hash_map:
        return (gr.update(choices=[]), selected_episode, None, {}, {})
    current_frame = min(frame_hash_map)
    return (
        gr.update(choices=list(frame_hash_map.keys()), value=current_frame),
        selected_episode,
        current_frame,
        episode_data,
        frame_hash_map,
    )


def load_data(episodes_data: dict, current_frame: int, frame_hash_map: dict):
    """Load the images and editable fields for the current frame.

    Args:
        episodes_data: Parsed episode JSON (see load_episode_data).
        current_frame: Global 1-based frame counter.
        frame_hash_map: Maps the counter to scene/frame indices.

    Returns:
        A 15-tuple: (gallery image list, description, narration,
        audio-cue character, audio-cue text, location, frame setting,
        then prompt and seed for each of the 4 compositions).
    """
    try:
        image_list = []
        scene_num, frame_num = (
            frame_hash_map[current_frame]["scene"],
            frame_hash_map[current_frame]["frame"],
        )
        curr_frame = episodes_data["scenes"][scene_num]["frames"][frame_num]

        for comp in curr_frame["compositions"]:
            # Fetch image from S3; skip (with a log line) anything missing.
            data = aws_utils.fetch_from_s3(comp["image"])
            if data:
                image = Image.open(io.BytesIO(data))
                image_list.append(image)
            else:
                print(f"Failed to load image from: {comp['image']}")

        # The UI is wired for exactly four compositions per frame.
        return (
            image_list,  # Return the image list to be displayed in the gallery
            curr_frame["description"],
            curr_frame["narration"],
            curr_frame["audio_cue_character"],
            curr_frame["audio_cue_text"],
            curr_frame["location"],
            curr_frame["frame_setting"],
            curr_frame["compositions"][0]["prompt"],
            curr_frame["compositions"][0]["seed"],
            curr_frame["compositions"][1]["prompt"],
            curr_frame["compositions"][1]["seed"],
            curr_frame["compositions"][2]["prompt"],
            curr_frame["compositions"][2]["seed"],
            curr_frame["compositions"][3]["prompt"],
            curr_frame["compositions"][3]["seed"],
        )
    except Exception as e:
        print("Error in load_data:", str(e))  # Debugging the error
        gr.Warning("Failed to load data. Check logs!")
        # The original fell through and returned None, which breaks the
        # unpacking of the 15 outputs; return safe defaults instead.
        return ([], "", "", "", "", "", "", "", "", "", "", "", "", "", "")


def update_characters(character_data: dict, current_frame: int, frame_hash_map: dict, episode_data: dict):
    """Rebuild the character checkbox group for the frame being displayed.

    All known characters become choices; the characters recorded on the
    current frame become the pre-selected values.
    """
    frame_location = frame_hash_map[current_frame]
    frame = episode_data["scenes"][frame_location["scene"]]["frames"][frame_location["frame"]]
    selected_names = []
    for entry in frame["characters"]:
        selected_names.append(entry["name"])
    return gr.CheckboxGroup(
        choices=list(character_data.keys()),
        value=selected_names,
    )


def load_data_next(
    comic_id: str,
    current_episode: int,
    current_frame: int,
    frame_hash_map: dict,
    episode_data: dict,
):
    """Advance to the next frame, rolling over to the next episode at the end.

    Returns:
        A 6-tuple of (episode ``gr.update``, frame-dropdown ``gr.update``,
        current episode, current frame, episode data, frame hash map), or
        None (after a warning) when no further episode exists.
    """
    # Advance within the episode whenever the next frame number exists.
    # The previous "< last_key" comparison skipped the final frame of
    # every episode (frame keys run 1..N inclusive).
    if current_frame + 1 in frame_hash_map:
        current_frame += 1
    else:
        current_episode += 1
        episode_data, frame_hash_map = load_episode_data(comic_id, current_episode)
        if len(episode_data) < 1:
            gr.Warning("All episodes finished.")
            return
        current_frame = min(frame_hash_map)
    return (
        gr.update(value=current_episode),
        gr.update(choices=list(frame_hash_map.keys()), value=current_frame),
        current_episode,
        current_frame,
        episode_data,
        frame_hash_map,
    )


def load_data_prev(
    comic_id: str,
    current_episode: int,
    current_frame: int,
    frame_hash_map: dict,
    episode_data: dict,
):
    """Step back one frame, rolling over to the previous episode when needed.

    Returns a 6-tuple of (episode ``gr.update``, frame-dropdown
    ``gr.update``, current episode, current frame, episode data, frame
    hash map), or None (after a warning) when no earlier episode exists.
    """
    previous = current_frame - 1
    if previous >= min(frame_hash_map):
        # Still inside the current episode.
        current_frame = previous
    else:
        # Crossed the start of the episode: load the one before it.
        current_episode -= 1
        episode_data, frame_hash_map = load_episode_data(comic_id, current_episode)
        if not episode_data:
            gr.Warning("No previous episode found.")
            return
        current_frame = min(frame_hash_map)
    return (
        gr.update(value=current_episode),
        gr.update(choices=list(frame_hash_map.keys()), value=current_frame),
        current_episode,
        current_frame,
        episode_data,
        frame_hash_map,
    )


def save_image(
    selected_image: list, comic_id: str, current_episode: int, current_frame: int
):
    """Persist the selected gallery image to S3 as a JPEG.

    Args:
        selected_image: Gallery selection; element 0 is the image file/path
            to save. (Was annotated with a bare ``...`` before.)
        comic_id: Folder name of the comic inside the S3 bucket.
        current_episode: Episode index used in the destination path.
        current_frame: Frame counter; becomes the JPEG file name.
    """
    # NOTE(review): this path lacks the "episodes/" prefix that
    # save_comic_data and load_episode_data use -- confirm it is intended.
    with Image.open(selected_image[0]) as img:
        img_bytes = io.BytesIO()
        img.convert("RGB").save(img_bytes, "JPEG")
        img_bytes.seek(0)
        aws_utils.save_to_s3(
            AWS_BUCKET,
            f"{comic_id}/episode-{current_episode}/images",
            img_bytes,
            f"{current_frame}.jpg",
        )
    gr.Info("Saved Image successfully!")


def regenerate_compositions(
    image_description: str,
    narration: str,
    character: str,
    dialouge: str,
    location: str,
    setting: str,
    rel_chars: list,
    current_episode: int,
    current_frame: int,
    episodes_data: dict,
    frame_hash_map: dict,
    character_data: dict,
):
    """Ask the LLM to regenerate the four composition prompts for a frame.

    Args:
        image_description / narration / character / dialouge / location /
            setting: The edited frame fields from the UI.
        rel_chars: Names of characters present in the frame.
        current_episode, current_frame: Position being edited (logging only).
        episodes_data, frame_hash_map: Episode JSON and frame index.
        character_data: Full character records keyed by name.

    Returns:
        A list of exactly 4 prompt strings, or ``[""] * 4`` on failure.
    """
    try:
        print(
            f"Regenerating composition data for episode {current_episode}, frame {current_frame}"
        )
        scene_num, frame_num = (
            frame_hash_map[current_frame]["scene"],
            frame_hash_map[current_frame]["frame"],
        )
        # Carry the previous frame's location/setting forward for continuity.
        # frame_num is 0-based, so frame index 1 has a valid predecessor at 0;
        # the old "frame_num-1 > 0" check wrongly skipped it.
        prev_frame = {}
        if frame_num - 1 >= 0:
            prev_frame = episodes_data["scenes"][scene_num]["frames"][frame_num - 1]

        try:
            related_chars = [character_data[char] for char in rel_chars]
            prompt_dict = {
                "system": script_gen.generate_image_compositions_instruction,
                "user": jinja2.Template(
                    script_gen.generate_image_compositions_user_prompt
                ).render(
                    {
                        "FRAME": {
                            "description": image_description,
                            "narration": narration,
                            "audio_cue_text": dialouge,
                            "audio_cue_character": character,
                            "location": location,
                            "frame_setting": setting,
                            "characters": json.dumps(related_chars),
                        },
                        "LOCATION_DESCRIPTION": prev_frame.get("location", ""),
                        "frame_settings": prev_frame.get("frame_setting", ""),
                    }
                ),
            }

            print("Generating compositions using LLM")
            compositions = llm.generate_valid_json_response(prompt_dict)
            comps = compositions["compositions"]
        except Exception as e:
            print(f"Error updating frame compositions: {e}")
            raise

        print("Composition data regenerated successfully.")
        return [
            comps[0]["prompt"],
            comps[1]["prompt"],
            comps[2]["prompt"],
            comps[3]["prompt"],
        ]
    except Exception as e:
        print(f"Error in regenerate_composition_data: {e}")
        # The success path yields 4 prompts, so the fallback must match;
        # the original returned 8 empty strings here.
        return [""] * 4


def regenerate_images(
    current_episode: int,
    current_frame: int,
    visual_style: str,
    height: int,
    width: int,
    character_data: dict,
    rel_chars: dict,
    prompt_1: str,
    seed_1: str,
    prompt_2: str,
    seed_2: str,
    prompt_3: str,
    seed_3: str,
    prompt_4: str,
    seed_4: str,
):
    """Call the image-generation server once per (prompt, seed) pair and
    return the resulting PIL images.

    Compositions that fail are logged and skipped, so the returned list
    may contain fewer than four images. Returns [] on a top-level error.
    """
    image_list = []
    try:
        print(
            f"Regenerating data for episode {current_episode}, and frame {current_frame}"
        )
        related_chars = [character_data[ch]["profile_image"] for ch in rel_chars]
        # Pair up the four prompt/seed UI fields so one loop handles them all.
        prompt_seed_pairs = [
            (prompt_1, seed_1),
            (prompt_2, seed_2),
            (prompt_3, seed_3),
            (prompt_4, seed_4),
        ]

        for i, (prompt, seed) in enumerate(prompt_seed_pairs):
            try:
                print(f"Generating image for composition {i}")
                # "NOCHAR" is a sentinel token that must be stripped before
                # the prompt is sent to the model server.
                if "NOCHAR" in prompt:
                    prompt = prompt.replace("NOCHAR", "")
                payload = {
                    "prompt": prompt,
                    "characters": related_chars,
                    "parameters": {
                        "height": height,
                        "width": width,
                        "visual_style": visual_style,
                        "seed": seed,
                    },
                }

                data = iowrapper.get_valid_post_response(
                    url=f"{parameters.MODEL_SERVER_URL}generate_image",
                    payload=payload,
                )
                decoded = base64.b64decode(data["image"])
                image_list.append(Image.open(io.BytesIO(decoded)))
            except Exception as e:
                print(f"Error processing composition {i}: {e}")
                continue

        print(f"Generated new images for episode: {current_episode} and frame: {current_frame}")
        print(f"Length of image list: {len(image_list)}")
        return image_list
    except Exception as e:
        print(f"Error in regenerate_data: {e}")
        gr.Warning("Failed to generate new images!")
        return []


def save_comic_data(
    current_episode: int,
    current_frame: int,
    episode_data: dict,
    comic_id: str,
    image_description: str,
    narration: str,
    character: str,
    dialogue: str,
    location: str,
    setting: str,
    prompt_1: str,
    prompt_2: str,
    prompt_3: str,
    prompt_4: str,
    frame_hash_map: dict,
    rel_chars: list,
    character_data: dict,
    images: list
):
    """Persist the edited frame fields and regenerated images back to S3.

    Updates the in-memory episode JSON (prompts, frame metadata, character
    list), uploads each composition image as a JPEG, then writes the whole
    episode JSON back. Shows a Gradio info/warning on success/failure.
    """
    try:
        scene_num, frame_num = (
            frame_hash_map[current_frame]["scene"],
            frame_hash_map[current_frame]["frame"],
        )
        curr_frame = episode_data["scenes"][scene_num]["frames"][frame_num]
        print(
            f"Saving comic data for episode {current_episode}, frame {frame_num}"
        )

        # Update compositions with prompts and upload the matching image
        # for each composition slot.
        prompts_list = [prompt_1, prompt_2, prompt_3, prompt_4]
        for i, comp in enumerate(curr_frame["compositions"]):
            comp["prompt"] = prompts_list[i]
            # Save new images to S3
            with Image.open(images[i][0]) as img:
                img_bytes = io.BytesIO()
                img.convert("RGB").save(img_bytes, "JPEG")
                img_bytes.seek(0)
                aws_utils.save_to_s3(
                    parameters.AWS_BUCKET,
                    f"{comic_id}/episodes/episode-{current_episode}/compositions/scene-{scene_num}/frame-{frame_num}",
                    img_bytes,
                    f"{i}.jpg",
                )

        # Update frame data. The setting belongs under "frame_setting": that
        # is the key the rest of this file reads (load_data, prompt
        # regeneration); the old code wrote it to an unused "setting" key,
        # so edits to the setting field were silently lost.
        curr_frame.update(
            {
                "description": image_description,
                "narration": narration,
                "audio_cue_text": dialogue,
                "location": location,
                "frame_setting": setting,
                "audio_cue_character": character,
                "characters": [character_data[char] for char in rel_chars],
            }
        )

        # Save the updated episode back to S3
        print(f"Saving updated episode {current_episode} to S3")
        aws_utils.save_to_s3(
            bucket_name=parameters.AWS_BUCKET,
            folder_name=f"{comic_id}/episodes/episode-{current_episode}",
            content=episode_data,
            file_name="episode.json",
        )
        gr.Info("Comic data saved successfully!")
    except Exception as e:
        print(f"Error in saving comic data: {e}")
        gr.Warning("Failed to save data for the comic!")