from typing import List
import base64
import io
import json
import traceback

import gradio as gr
import jinja2
from PIL import Image

import aws_utils
import parameters
import script_gen
import inout as iowrapper
import openai_wrapper

AWS_BUCKET = parameters.AWS_BUCKET
llm = openai_wrapper.GPT_4O_MINI

# Number of composition prompt/seed slots each frame exposes in the UI
# (load_data, regenerate_compositions, regenerate_images and save_comic_data
# all work with four of them).
NUM_COMPOSITIONS = 4


#### Private helpers shared by the Gradio callbacks below.


def _no_change(count: int) -> tuple:
    """Return ``count`` no-op ``gr.update()`` values.

    Used on failure paths of multi-output handlers so every output keeps its
    current value, instead of the handler returning ``None`` or too few values.
    """
    return tuple(gr.update() for _ in range(count))


def _locate_frame(frame_hash_map: dict, current_frame: int) -> tuple:
    """Map a 1-based, episode-wide frame number to ``(scene_idx, frame_idx)``."""
    location = frame_hash_map[current_frame]
    return location["scene"], location["frame"]


def _get_frame(episode_data: dict, frame_hash_map: dict, current_frame: int) -> dict:
    """Return the frame dict addressed by ``current_frame`` in ``episode_data``."""
    scene_num, frame_num = _locate_frame(frame_hash_map, current_frame)
    return episode_data["scenes"][scene_num]["frames"][frame_num]


def _to_jpeg_buffer(image) -> io.BytesIO:
    """Encode a PIL image as an RGB JPEG into a rewound in-memory buffer."""
    buffer = io.BytesIO()
    image.convert("RGB").save(buffer, "JPEG")
    buffer.seek(0)
    return buffer


def _parse_episode_index(folder: str):
    """Return N for an S3 prefix whose last segment is ``episode-N``, else ``None``."""
    name = folder.rstrip("/").rsplit("/", 1)[-1]
    prefix, _, number = name.partition("-")
    if prefix != "episode" or not number.isdigit():
        return None
    return int(number)


def _composition_fields(compositions: list, count: int = NUM_COMPOSITIONS) -> list:
    """Flatten compositions into ``[prompt_1, seed_1, ..., prompt_n, seed_n]``.

    Frames with fewer than ``count`` compositions are padded with an empty
    prompt and a ``None`` seed, so the caller always gets ``2 * count`` values.
    """
    fields = []
    for index in range(count):
        comp = compositions[index] if index < len(compositions) else {}
        fields.extend([comp.get("prompt", ""), comp.get("seed")])
    return fields


#### Functions are ordered by when they were developed.


def toggle_developer_options(is_developer: bool):
    """Show or hide the developer-options panel.

    Returns a Gradio update whose ``visible`` flag follows the truthiness of
    ``is_developer``.
    """
    return gr.update(visible=bool(is_developer))


def list_current_dir(bucket_name: str, folder_path: str = "") -> list:
    """List the immediate sub-"folders" (common prefixes) under ``folder_path``.

    ``list_objects_v2`` returns results one page at a time, so the
    continuation token is followed until the listing is no longer truncated.

    Returns:
        A list of prefix strings (each ending in "/"). Returns an empty list,
        after logging the error, if any listing call fails.
    """
    folders = []
    request = {"Bucket": bucket_name, "Prefix": folder_path, "Delimiter": "/"}
    try:
        while True:
            response = aws_utils.S3_CLIENT.list_objects_v2(**request)
            folders.extend(
                prefix["Prefix"] for prefix in response.get("CommonPrefixes", [])
            )
            if not response.get("IsTruncated"):
                break
            request["ContinuationToken"] = response["NextContinuationToken"]
    except Exception as e:
        print(f"Failed to list s3://{bucket_name}/{folder_path}: {e}")
        return []
    return folders


def load_metadata_fn(comic_id: str):
    """Load the character roster and the available episode numbers for a comic.

    Returns:
        A 3-tuple ``(episode dropdown update, first episode number or None,
        character dict keyed by name)``. When no episodes exist, or anything
        fails (a warning is shown), it returns
        ``(gr.update(choices=[]), None, {})``.
    """
    try:
        # Characters are read from the same bucket as the episodes; this path
        # used to hard-code one specific bucket name instead of AWS_BUCKET.
        character_path = f"s3://{AWS_BUCKET}/{comic_id}/characters/characters.json"
        character_data = dict(
            json.loads(aws_utils.fetch_from_s3(character_path).decode("utf-8"))
        )

        # Every listed prefix contains "episodes/", so a substring test cannot
        # filter them; parse the trailing "episode-N" segment instead.
        episode_folders = list_current_dir(AWS_BUCKET, f"{comic_id}/episodes/")
        episode_indices = sorted(
            idx
            for idx in map(_parse_episode_index, episode_folders)
            if idx is not None
        )
        if not episode_indices:
            return (gr.update(choices=[]), None, {})

        first_episode = episode_indices[0]
        return (
            gr.update(choices=episode_indices, value=first_episode),
            first_episode,
            character_data,
        )
    except Exception as e:
        gr.Warning(f"Error loading metadata: {e}")
        return (gr.update(choices=[]), None, {})


def load_episode_data(comic_id: str, episode_num: int):
    """Fetch ``episode.json`` for one episode and index its frames.

    Returns:
        ``(episode_data, frame_hash_map)``. ``frame_hash_map`` maps a 1-based
        frame number, counted across all scenes of the episode, to
        ``{"scene": scene_idx, "frame": frame_idx}``. Returns ``({}, {})`` if
        the file cannot be fetched or parsed; the traceback is printed.
    """
    json_path = (
        f"s3://{AWS_BUCKET}/{comic_id}/episodes/episode-{episode_num}/episode.json"
    )
    try:
        print(f"For episode: {episode_num}")
        episode_data = json.loads(aws_utils.fetch_from_s3(json_path).decode("utf-8"))

        frame_hash_map = {}
        frame_number = 1
        for scene_idx, scene in enumerate(episode_data["scenes"]):
            for frame_idx, _ in enumerate(scene["frames"]):
                frame_hash_map[frame_number] = {"scene": scene_idx, "frame": frame_idx}
                frame_number += 1
        return (episode_data, frame_hash_map)
    except Exception:
        print(
            f"Failed to load json dictionary for episode: {episode_num} at path: {json_path}"
        )
        print(traceback.format_exc())
        return {}, {}


def episode_dropdown_effect(comic_id, selected_episode):
    """Load the selected episode and jump to its first frame.

    Returns:
        ``(frame dropdown update, episode number, current frame, episode data,
        frame map)``. If the episode has no loadable frames, a warning is shown,
        the frame dropdown is cleared and the current frame becomes ``None``.
    """
    episode_data, frame_hash_map = load_episode_data(comic_id, selected_episode)
    if not frame_hash_map:
        gr.Warning(f"No frames found for episode {selected_episode}.")
        return (
            gr.update(choices=[], value=None),
            selected_episode,
            None,
            episode_data,
            frame_hash_map,
        )

    current_frame = min(frame_hash_map)
    return (
        gr.update(choices=list(frame_hash_map), value=current_frame),
        selected_episode,
        current_frame,
        episode_data,
        frame_hash_map,
    )


def load_data(episodes_data: dict, current_frame: int, frame_hash_map: dict):
    """Fetch the images and editable fields of the current frame.

    Returns:
        A 15-tuple:
        - the composition images, fetched from S3 (images that fail to load
          are skipped);
        - description, narration, audio cue character, audio cue text,
          location and frame setting;
        - ``NUM_COMPOSITIONS`` (prompt, seed) pairs, padded when the frame has
          fewer compositions.
        On failure a warning is shown and no-op updates are returned for all
        15 outputs.
    """
    try:
        curr_frame = _get_frame(episodes_data, frame_hash_map, current_frame)

        image_list = []
        for comp in curr_frame["compositions"]:
            data = aws_utils.fetch_from_s3(comp["image"])
            if data:
                image_list.append(Image.open(io.BytesIO(data)))
            else:
                print(f"Failed to load image from: {comp['image']}")

        return (
            image_list,  # Displayed in the gallery
            curr_frame["description"],
            curr_frame["narration"],
            curr_frame["audio_cue_character"],
            curr_frame["audio_cue_text"],
            curr_frame["location"],
            curr_frame["frame_setting"],
            *_composition_fields(curr_frame["compositions"]),
        )
    except Exception as e:
        print("Error in load_data:", str(e))  # Debugging the error
        gr.Warning("Failed to load data. Check logs!")
        # Returning None would not match the handler's 15 outputs.
        return _no_change(7 + 2 * NUM_COMPOSITIONS)


def update_characters(
    character_data: dict, current_frame: int, frame_hash_map: dict, episode_data: dict
):
    """Build the character checkbox group for the current frame.

    All known character names are offered as choices. The ones listed in the
    frame's ``characters`` entries (by their ``"name"`` key) are pre-checked.
    """
    curr_frame = _get_frame(episode_data, frame_hash_map, current_frame)
    return gr.CheckboxGroup(
        choices=list(character_data.keys()),
        value=[char["name"] for char in curr_frame["characters"]],
    )


def load_data_next(
    comic_id: str,
    current_episode: int,
    current_frame: int,
    frame_hash_map: dict,
    episode_data: dict,
):
    """Advance to the next frame, rolling over to the next episode's first frame.

    Returns:
        ``(episode dropdown update, frame dropdown update, episode, frame,
        episode data, frame map)``. When there is no next episode, a warning is
        shown and all six outputs are left unchanged.
    """
    # Frame numbers run 1..N. Advance while there is a later frame; the old
    # `current_frame + 1 < N` test skipped frame N of every episode.
    if frame_hash_map and current_frame < max(frame_hash_map):
        current_frame += 1
    else:
        current_episode += 1
        episode_data, frame_hash_map = load_episode_data(comic_id, current_episode)
        if not frame_hash_map:
            gr.Warning("All episodes finished.")
            return _no_change(6)
        current_frame = min(frame_hash_map)

    return (
        gr.update(value=current_episode),
        gr.update(choices=list(frame_hash_map), value=current_frame),
        current_episode,
        current_frame,
        episode_data,
        frame_hash_map,
    )


def load_data_prev(
    comic_id: str,
    current_episode: int,
    current_frame: int,
    frame_hash_map: dict,
    episode_data: dict,
):
    """Step back one frame, rolling over to the previous episode's last frame.

    Returns:
        Same 6-tuple as ``load_data_next``. When there is no previous episode,
        a warning is shown and all six outputs are left unchanged.
    """
    if frame_hash_map and current_frame > min(frame_hash_map):
        current_frame -= 1
    else:
        current_episode -= 1
        episode_data, frame_hash_map = load_episode_data(comic_id, current_episode)
        if not frame_hash_map:
            gr.Warning("No previous episode found.")
            return _no_change(6)
        # Land on the LAST frame of the previous episode so that "prev" is the
        # inverse of "next" (which lands on the first frame).
        current_frame = max(frame_hash_map)

    return (
        gr.update(value=current_episode),
        gr.update(choices=list(frame_hash_map), value=current_frame),
        current_episode,
        current_frame,
        episode_data,
        frame_hash_map,
    )


def save_image(
    selected_image: ..., comic_id: str, current_episode: int, current_frame: int
):
    """Upload the selected gallery image to S3 as ``<current_frame>.jpg``.

    ``selected_image[0]`` is opened with PIL; presumably it is the file path of
    a (path, caption) gallery item — confirm against the caller.
    """
    with Image.open(selected_image[0]) as img:
        img_bytes = _to_jpeg_buffer(img)
    # NOTE(review): unlike every other path in this module, this folder has no
    # "episodes/" segment — confirm the destination is intentional.
    aws_utils.save_to_s3(
        AWS_BUCKET,
        f"{comic_id}/episode-{current_episode}/images",
        img_bytes,
        f"{current_frame}.jpg",
    )
    gr.Info("Saved Image successfully!")


def regenerate_compositions(
    image_description: str,
    narration: str,
    character: str,
    dialouge: str,
    location: str,
    setting: str,
    rel_chars: list,
    current_episode: int,
    current_frame: int,
    episodes_data: dict,
    frame_hash_map: dict,
    character_data: dict,
):
    """Ask the LLM for fresh composition prompts for the current frame.

    The edited frame fields, the selected characters' data and the previous
    frame of the same scene (for location/setting continuity) are rendered into
    ``script_gen``'s composition prompt.

    Returns:
        A list of ``NUM_COMPOSITIONS`` prompt strings. On failure a warning is
        shown and the same number of no-op updates is returned, so the
        existing prompts are kept.
    """
    try:
        print(
            f"Regenerating composition data for episode {current_episode}, frame {current_frame}"
        )
        scene_num, frame_num = _locate_frame(frame_hash_map, current_frame)

        # frame_num is a 0-based index, so every frame after the first has a
        # predecessor. The old `frame_num - 1 > 0` test skipped index 0.
        prev_frame = {}
        if frame_num > 0:
            prev_frame = episodes_data["scenes"][scene_num]["frames"][frame_num - 1]

        related_chars = [character_data[char] for char in rel_chars]
        prompt_dict = {
            "system": script_gen.generate_image_compositions_instruction,
            "user": jinja2.Template(
                script_gen.generate_image_compositions_user_prompt
            ).render(
                {
                    "FRAME": {
                        "description": image_description,
                        "narration": narration,
                        "audio_cue_text": dialouge,
                        "audio_cue_character": character,
                        "location": location,
                        "frame_setting": setting,
                        "characters": json.dumps(related_chars),
                    },
                    "LOCATION_DESCRIPTION": prev_frame.get("location", ""),
                    "frame_settings": prev_frame.get("frame_setting", ""),
                }
            ),
        }
        print("Generating compositions using LLM")
        comps = llm.generate_valid_json_response(prompt_dict)["compositions"]
        prompts = [comps[i]["prompt"] for i in range(NUM_COMPOSITIONS)]
        print("Composition data regenerated successfully.")
        return prompts
    except Exception as e:
        print(f"Error in regenerate_composition_data: {e}")
        gr.Warning("Failed to regenerate compositions!")
        # Match the success path's output count (it used to return 8 values).
        return [gr.update() for _ in range(NUM_COMPOSITIONS)]


def regenerate_images(
    current_episode: int,
    current_frame: int,
    visual_style: str,
    height: int,
    width: int,
    character_data: dict,
    rel_chars: list,
    prompt_1: str,
    seed_1: str,
    prompt_2: str,
    seed_2: str,
    prompt_3: str,
    seed_3: str,
    prompt_4: str,
    seed_4: str,
):
    """Render one image per composition prompt via the model server.

    Each prompt, with the literal token "NOCHAR" removed, is POSTed to
    ``{MODEL_SERVER_URL}generate_image`` together with the selected
    characters' ``profile_image`` values and the size/style/seed parameters.
    The response's base64 ``"image"`` field is decoded into a PIL image.

    Returns:
        The list of generated images. A composition that fails is logged and
        skipped, so the list may be shorter than four. If anything fails
        outside the per-composition loop, a warning is shown and ``[]`` is
        returned.
    """
    image_list = []
    try:
        print(
            f"Regenerating data for episode {current_episode}, and frame {current_frame}"
        )
        related_chars = [character_data[ch]["profile_image"] for ch in rel_chars]
        new_compositions = [
            {"prompt": prompt_1, "seed": seed_1},
            {"prompt": prompt_2, "seed": seed_2},
            {"prompt": prompt_3, "seed": seed_3},
            {"prompt": prompt_4, "seed": seed_4},
        ]

        for i, composition in enumerate(new_compositions):
            try:
                print(f"Generating image for composition {i}")
                payload = {
                    "prompt": composition["prompt"].replace("NOCHAR", ""),
                    "characters": related_chars,
                    "parameters": {
                        "height": height,
                        "width": width,
                        "visual_style": visual_style,
                        "seed": composition["seed"],
                    },
                }
                data = iowrapper.get_valid_post_response(
                    url=f"{parameters.MODEL_SERVER_URL}generate_image",
                    payload=payload,
                )
                image_list.append(
                    Image.open(io.BytesIO(base64.b64decode(data["image"])))
                )
            except Exception as e:
                print(f"Error processing composition {i}: {e}")
                continue

        print(f"Generated new images for episode: {current_episode} and frame: {current_frame}")
        print(f"Length of image list: {len(image_list)}")
        return image_list
    except Exception as e:
        print(f"Error in regenerate_data: {e}")
        gr.Warning("Failed to generate new images!")
        return []


def save_comic_data(
    current_episode: int,
    current_frame: int,
    episode_data: dict,
    comic_id: str,
    image_description: str,
    narration: str,
    character: str,
    dialogue: str,
    location: str,
    setting: str,
    prompt_1: str,
    prompt_2: str,
    prompt_3: str,
    prompt_4: str,
    frame_hash_map: dict,
    rel_chars: list,
    character_data: dict,
    images: list,
):
    """Write the edited frame back into ``episode_data`` and persist it to S3.

    Steps:
    1. Each composition's prompt is replaced.
    2. The matching gallery image is uploaded as
       ``.../compositions/scene-<s>/frame-<f>/<i>.jpg``.
    3. The frame's text fields and characters are updated.
    4. The whole episode is saved as ``episode.json``.

    ``episode_data`` is mutated in place. Success or failure is reported via
    ``gr.Info`` / ``gr.Warning``.
    """
    try:
        scene_num, frame_num = _locate_frame(frame_hash_map, current_frame)
        curr_frame = episode_data["scenes"][scene_num]["frames"][frame_num]
        print(f"Saving comic data for episode {current_episode}, frame {frame_num}")

        prompts_list = [prompt_1, prompt_2, prompt_3, prompt_4]
        images = images or []
        for i, (comp, prompt) in enumerate(zip(curr_frame["compositions"], prompts_list)):
            comp["prompt"] = prompt

            # regenerate_images skips failed compositions, so the gallery may
            # hold fewer images than there are compositions.
            if i >= len(images):
                print(f"No image available for composition {i}; keeping the stored one")
                continue
            # Presumably images[i][0] is the file path of a gallery item — confirm.
            with Image.open(images[i][0]) as img:
                img_bytes = _to_jpeg_buffer(img)
            aws_utils.save_to_s3(
                parameters.AWS_BUCKET,
                f"{comic_id}/episodes/episode-{current_episode}/compositions/scene-{scene_num}/frame-{frame_num}",
                img_bytes,
                f"{i}.jpg",
            )

        # "frame_setting" is the key load_data and regenerate_compositions read;
        # writing it under "setting" meant edits never showed up on reload.
        curr_frame.update(
            {
                "description": image_description,
                "narration": narration,
                "audio_cue_text": dialogue,
                "location": location,
                "frame_setting": setting,
                "audio_cue_character": character,
                "characters": [character_data[char] for char in rel_chars],
            }
        )

        print(f"Saving updated episode {current_episode} to S3")
        aws_utils.save_to_s3(
            bucket_name=parameters.AWS_BUCKET,
            folder_name=f"{comic_id}/episodes/episode-{current_episode}",
            content=episode_data,
            file_name="episode.json",
        )
        gr.Info("Comic data saved successfully!")
    except Exception as e:
        print(f"Error in saving comic data: {e}")
        gr.Warning("Failed to save data for the comic!")