Spaces:
Sleeping
Sleeping
| from typing import List | |
| from PIL import Image | |
| import json | |
| import gradio as gr | |
| import io | |
| import jinja2 | |
| import base64 | |
| import aws_utils | |
| import parameters | |
| import script_gen | |
| import inout as iowrapper | |
| import openai_wrapper | |
| import json | |
| import base64 | |
# Module configuration: the S3 bucket every read/write below targets, and the
# LLM client used when regenerating frame compositions.
AWS_BUCKET, llm = parameters.AWS_BUCKET, openai_wrapper.GPT_4O_MINI
#### Functions are ordered by the order in which they were developed.
def toggle_developer_options(is_developer: bool):
    """Show or hide the developer-options panel.

    Args:
        is_developer: Whether developer mode is enabled (truthiness is used).

    Returns:
        A ``gr.update`` whose ``visible`` flag is ``True`` when
        ``is_developer`` is truthy and ``False`` otherwise.
    """
    return gr.update(visible=bool(is_developer))
def list_current_dir(bucket_name: str, folder_path: str = "") -> list:
    """List the immediate "sub-folder" prefixes under ``folder_path`` in S3.

    Uses ``list_objects_v2`` with ``Delimiter="/"`` so S3 groups keys into
    ``CommonPrefixes`` (one entry per direct child folder, each ending in
    ``"/"``). Follows ``NextContinuationToken`` while the response reports
    ``IsTruncated`` so folders beyond the first result page are included.

    Args:
        bucket_name: Name of the S3 bucket to list.
        folder_path: Key prefix to list under; "" lists the bucket root.

    Returns:
        The child folder prefixes, in the order S3 returns them. On any
        error the failure is printed and an empty list is returned.
    """
    folders = []
    request = {"Bucket": bucket_name, "Prefix": folder_path, "Delimiter": "/"}
    try:
        while True:
            response = aws_utils.S3_CLIENT.list_objects_v2(**request)
            folders.extend(
                prefix["Prefix"] for prefix in response.get("CommonPrefixes", [])
            )
            # A single ListObjectsV2 call returns at most one page of results.
            if not response.get("IsTruncated"):
                break
            request["ContinuationToken"] = response["NextContinuationToken"]
    except Exception as e:
        # Best-effort listing: callers treat [] as "nothing found", but log why.
        print(f"Failed to list s3://{bucket_name}/{folder_path}: {e}")
        return []
    return folders
def load_metadata_fn(comic_id: str):
    """Load a comic's character roster and the numbers of its episodes.

    Reads ``{comic_id}/characters/characters.json`` and lists the
    ``{comic_id}/episodes/episode-<N>/`` folders, both in ``AWS_BUCKET``.

    Args:
        comic_id: Top-level S3 folder of the comic.

    Returns:
        ``(episode_dropdown_update, first_episode, character_data)``.
        ``episode_dropdown_update`` offers the episode numbers in ascending
        numeric order with the smallest one selected. If no episode folders
        exist, or anything fails (a warning is shown in that case), the
        result is ``(gr.update(choices=[]), None, {})``.
    """
    try:
        # Same bucket as the episode listing below, so characters and episodes
        # are always read from one place.
        character_path = f"s3://{AWS_BUCKET}/{comic_id}/characters/characters.json"
        character_data = dict(
            json.loads(aws_utils.fetch_from_s3(character_path).decode("utf-8"))
        )

        episode_indices = []
        for folder in list_current_dir(AWS_BUCKET, f"{comic_id}/episodes/"):
            # Folder prefixes look like "<comic_id>/episodes/episode-<N>/"; take
            # the last path segment and keep it only if it ends in a number, so
            # a stray folder cannot make int() abort the whole load.
            name = folder.rstrip("/").rsplit("/", 1)[-1]
            head, _, number = name.rpartition("-")
            if head.startswith("episode") and number.isdigit():
                episode_indices.append(int(number))

        if not episode_indices:
            return (gr.update(choices=[]), None, {})

        # S3 lists prefixes lexicographically (episode-10 before episode-2).
        episode_indices.sort()
        first_episode = episode_indices[0]
        return (
            gr.update(choices=episode_indices, value=first_episode),
            first_episode,
            character_data,
        )
    except Exception as e:
        gr.Warning(f"Error loading metadata: {e}")
        return (gr.update(choices=[]), None, {})
def load_episode_data(comic_id: str, episode_num: int):
    """Fetch an episode's JSON from S3 and number its frames sequentially.

    Args:
        comic_id: Top-level S3 folder of the comic.
        episode_num: Episode number; the file read is
            ``s3://{AWS_BUCKET}/{comic_id}/episodes/episode-{episode_num}/episode.json``.

    Returns:
        ``(episode_data, frame_hash_map)``. ``frame_hash_map`` maps a 1-based
        counter, running across all scenes in order, to
        ``{"scene": scene_index, "frame": frame_index_within_scene}``.
        On any failure the path and traceback are printed and ``({}, {})``
        is returned.
    """
    try:
        print(f"For episode: {episode_num}")
        json_path = (
            f"s3://{AWS_BUCKET}/{comic_id}/episodes/episode-{episode_num}/episode.json"
        )
        raw_bytes = aws_utils.fetch_from_s3(json_path)
        episode_data = json.loads(raw_bytes.decode("utf-8"))

        # Flatten (scene, frame) positions in reading order...
        positions = [
            (scene_idx, frame_idx)
            for scene_idx, scene in enumerate(episode_data["scenes"])
            for frame_idx, _ in enumerate(scene["frames"])
        ]
        # ...then key each one by its 1-based position in the episode.
        frame_hash_map = {
            counter: {"scene": scene_idx, "frame": frame_idx}
            for counter, (scene_idx, frame_idx) in enumerate(positions, start=1)
        }
        return (episode_data, frame_hash_map)
    except Exception:
        import traceback

        print(
            f"Failed to load json dictionary for episode: {episode_num} at path: {json_path}"
        )
        print(traceback.format_exc())
        return {}, {}
def episode_dropdown_effect(comic_id, selected_episode):
    """Load the chosen episode and reset the frame selector to its first frame.

    Args:
        comic_id: Top-level S3 folder of the comic.
        selected_episode: Episode number picked in the episode dropdown.

    Returns:
        ``(frame_dropdown_update, selected_episode, current_frame,
        episode_data, frame_hash_map)``. If the episode has no frames or
        failed to load (``load_episode_data`` returns empty dicts), a warning
        is shown, the frame dropdown is emptied and ``current_frame`` is
        ``None``, instead of raising on ``min()`` of an empty sequence.
    """
    episode_data, frame_hash_map = load_episode_data(comic_id, selected_episode)
    if not frame_hash_map:
        gr.Warning(f"No frames found for episode {selected_episode}.")
        return (
            gr.update(choices=[], value=None),
            selected_episode,
            None,
            episode_data,
            frame_hash_map,
        )
    current_frame = min(frame_hash_map)
    return (
        gr.update(choices=list(frame_hash_map.keys()), value=current_frame),
        selected_episode,
        current_frame,
        episode_data,
        frame_hash_map,
    )
def load_data(episodes_data: dict, current_frame: int, frame_hash_map: dict):
    """Load one frame's composition images and editable fields for the editor.

    Args:
        episodes_data: Parsed episode JSON (``{"scenes": [{"frames": [...]}]}``).
        current_frame: 1-based frame counter, a key of ``frame_hash_map``.
        frame_hash_map: Frame counter -> ``{"scene": idx, "frame": idx}``.

    Returns:
        A 15-tuple:
        - the images that could be fetched for the frame's compositions;
        - description, narration, audio-cue character, audio-cue text,
          location and frame setting;
        - ``(prompt, seed)`` for each of four composition slots.
        Slots beyond the frame's actual compositions are filled with
        ``("", None)``. On any error the failure is printed, a warning is
        shown, and 15 no-op ``gr.update()`` values are returned so every
        output keeps its current value.
    """
    num_slots = 4  # composition prompt/seed pairs shown in the UI
    try:
        position = frame_hash_map[current_frame]
        scene_num, frame_num = position["scene"], position["frame"]
        curr_frame = episodes_data["scenes"][scene_num]["frames"][frame_num]
        compositions = curr_frame["compositions"]

        image_list = []
        for comp in compositions:
            # Fetch image from S3; a falsy result means the fetch failed.
            data = aws_utils.fetch_from_s3(comp["image"])
            if data:
                image_list.append(Image.open(io.BytesIO(data)))
            else:
                print(f"Failed to load image from: {comp['image']}")

        slot_values = []
        for slot in range(num_slots):
            if slot < len(compositions):
                slot_values.extend(
                    (compositions[slot]["prompt"], compositions[slot]["seed"])
                )
            else:
                # Pad frames with fewer compositions instead of raising IndexError.
                slot_values.extend(("", None))

        return (
            image_list,  # Return the image list to be displayed in the gallery
            curr_frame["description"],
            curr_frame["narration"],
            curr_frame["audio_cue_character"],
            curr_frame["audio_cue_text"],
            curr_frame["location"],
            curr_frame["frame_setting"],
            *slot_values,
        )
    except Exception as e:
        print("Error in load_data:", str(e))  # Debugging the error
        gr.Warning("Failed to load data. Check logs!")
        # Same arity as the success path so the handler's outputs stay valid.
        return tuple(gr.update() for _ in range(7 + 2 * num_slots))
def update_characters(character_data: dict, current_frame: int, frame_hash_map: dict, episode_data: dict):
    """Build the character checklist for the current frame.

    Args:
        character_data: Character name -> character record; its keys become
            the checklist choices.
        current_frame: 1-based frame counter, a key of ``frame_hash_map``.
        frame_hash_map: Frame counter -> ``{"scene": idx, "frame": idx}``.
        episode_data: Parsed episode JSON containing the frame.

    Returns:
        A ``gr.CheckboxGroup`` offering every known character, with the
        ``"name"`` of each entry in the frame's ``"characters"`` list ticked.
    """
    entry = frame_hash_map[current_frame]
    scene_idx, frame_idx = entry["scene"], entry["frame"]
    frame = episode_data["scenes"][scene_idx]["frames"][frame_idx]
    all_names = list(character_data)
    ticked = [member["name"] for member in frame["characters"]]
    return gr.CheckboxGroup(choices=all_names, value=ticked)
def _nav_no_change() -> tuple:
    """Return six no-op updates so a navigation handler leaves every output as-is."""
    return tuple(gr.update() for _ in range(6))


def _nav_outputs(current_episode, current_frame, episode_data, frame_hash_map) -> tuple:
    """Package the six outputs shared by ``load_data_next`` and ``load_data_prev``."""
    return (
        gr.update(value=current_episode),
        gr.update(choices=list(frame_hash_map.keys()), value=current_frame),
        current_episode,
        current_frame,
        episode_data,
        frame_hash_map,
    )


def load_data_next(
    comic_id: str,
    current_episode: int,
    current_frame: int,
    frame_hash_map: dict,
    episode_data: dict,
):
    """Step to the next frame, rolling over to the next episode after the last.

    Frame counters are the contiguous 1..N keys built by ``load_episode_data``,
    so every frame up to and including the last one is visited before the
    episode changes.

    Returns:
        ``(episode_dropdown_update, frame_dropdown_update, current_episode,
        current_frame, episode_data, frame_hash_map)``. When there is no next
        episode (or it has no frames), a warning is shown and six no-op
        updates are returned so nothing changes.
    """
    if frame_hash_map and current_frame + 1 <= max(frame_hash_map):
        current_frame += 1
    else:
        current_episode += 1
        episode_data, frame_hash_map = load_episode_data(comic_id, current_episode)
        if len(episode_data) < 1 or not frame_hash_map:
            gr.Warning("All episodes finished.")
            return _nav_no_change()
        current_frame = min(frame_hash_map)
    return _nav_outputs(current_episode, current_frame, episode_data, frame_hash_map)


def load_data_prev(
    comic_id: str,
    current_episode: int,
    current_frame: int,
    frame_hash_map: dict,
    episode_data: dict,
):
    """Step to the previous frame, rolling back to the previous episode's end.

    Returns:
        Same 6-tuple as ``load_data_next``. Crossing an episode boundary
        lands on the previous episode's last frame. When there is no
        previous episode (or it has no frames), a warning is shown and six
        no-op updates are returned so nothing changes.
    """
    if frame_hash_map and current_frame - 1 >= min(frame_hash_map):
        current_frame -= 1
    else:
        current_episode -= 1
        episode_data, frame_hash_map = load_episode_data(comic_id, current_episode)
        if len(episode_data) < 1 or not frame_hash_map:
            gr.Warning("No previous episode found.")
            return _nav_no_change()
        # Going backwards, the natural predecessor is the episode's final frame.
        current_frame = max(frame_hash_map)
    return _nav_outputs(current_episode, current_frame, episode_data, frame_hash_map)
def save_image(
    selected_image: ..., comic_id: str, current_episode: int, current_frame: int
):
    """Upload the selected image to S3 as a JPEG named ``<current_frame>.jpg``.

    Args:
        selected_image: Sequence whose first element is handed to
            ``Image.open`` (presumably a gallery ``(filepath, caption)``
            entry; confirm against the caller).
        comic_id: Top-level S3 folder of the comic.
        current_episode: Episode number used in the destination folder.
        current_frame: Frame counter used as the file name.
    """
    # NOTE(review): this folder omits the "episodes/" segment that every other
    # S3 path in this module uses (".../episodes/episode-N/..."); confirm intent.
    destination = f"{comic_id}/episode-{current_episode}/images"
    with Image.open(selected_image[0]) as picture:
        jpeg_buffer = io.BytesIO()
        # JPEG has no alpha channel, hence the RGB conversion.
        picture.convert("RGB").save(jpeg_buffer, "JPEG")
        jpeg_buffer.seek(0)
        aws_utils.save_to_s3(AWS_BUCKET, destination, jpeg_buffer, f"{current_frame}.jpg")
    gr.Info("Saved Image successfully!")
def regenerate_compositions(
    image_description: str,
    narration: str,
    character: str,
    dialouge: str,
    location: str,
    setting: str,
    rel_chars: list,
    current_episode: int,
    current_frame: int,
    episodes_data: dict,
    frame_hash_map: dict,
    character_data: dict,
):
    """Ask the LLM for fresh composition prompts for the current frame.

    The edited frame fields and the records of the selected characters are
    rendered into ``script_gen.generate_image_compositions_user_prompt``,
    with ``script_gen.generate_image_compositions_instruction`` as the system
    prompt. The render also includes the location and frame setting of the
    preceding frame in the same scene, when there is one.

    Args:
        image_description, narration, character, dialouge, location, setting:
            Current values of the frame's editable text fields (``dialouge``
            keeps its original spelling for caller compatibility).
        rel_chars: Names of the characters selected for this frame; each must
            be a key of ``character_data``.
        current_episode: Episode number, used only for logging.
        current_frame: 1-based frame counter, a key of ``frame_hash_map``.
        episodes_data: Parsed episode JSON (``{"scenes": [{"frames": [...]}]}``).
        frame_hash_map: Frame counter -> ``{"scene": idx, "frame": idx}``.
        character_data: Character name -> character record.

    Returns:
        A list with the prompts of the first four generated compositions. On
        any failure the error is printed, a warning is shown, and four no-op
        ``gr.update()`` values are returned, the same arity as success, so the
        prompt boxes keep their current text.
    """
    try:
        print(
            f"Regenerating composition data for episode {current_episode}, frame {current_frame}"
        )
        position = frame_hash_map[current_frame]
        scene_num, frame_num = position["scene"], position["frame"]

        # Frame indices are 0-based, so any frame_num >= 1 has a predecessor in
        # the same scene; it gives the LLM continuity context.
        prev_frame = {}
        if frame_num >= 1:
            prev_frame = episodes_data["scenes"][scene_num]["frames"][frame_num - 1]

        related_chars = [character_data[char] for char in rel_chars]
        prompt_dict = {
            "system": script_gen.generate_image_compositions_instruction,
            "user": jinja2.Template(
                script_gen.generate_image_compositions_user_prompt
            ).render(
                {
                    "FRAME": {
                        "description": image_description,
                        "narration": narration,
                        "audio_cue_text": dialouge,
                        "audio_cue_character": character,
                        "location": location,
                        "frame_setting": setting,
                        "characters": json.dumps(related_chars),
                    },
                    "LOCATION_DESCRIPTION": prev_frame.get("location", ""),
                    "frame_settings": prev_frame.get("frame_setting", ""),
                }
            ),
        }
        print("Generating compositions using LLM")
        comps = llm.generate_valid_json_response(prompt_dict)["compositions"]
        # The UI has exactly four prompt slots; fewer compositions raise
        # IndexError and are reported by the except branch below.
        prompts = [comps[slot]["prompt"] for slot in range(4)]
        print("Composition data regenerated successfully.")
        return prompts
    except Exception as e:
        print(f"Error in regenerate_composition_data: {e}")
        gr.Warning("Failed to regenerate compositions!")
        return [gr.update() for _ in range(4)]
def regenerate_images(
    current_episode: int,
    current_frame: int,
    visual_style: str,
    height: int,
    width: int,
    character_data: dict,
    rel_chars: dict,
    prompt_1: str,
    seed_1: str,
    prompt_2: str,
    seed_2: str,
    prompt_3: str,
    seed_3: str,
    prompt_4: str,
    seed_4: str,
):
    """Generate one image per composition prompt via the model server.

    Each ``(prompt, seed)`` pair is POSTed to
    ``{parameters.MODEL_SERVER_URL}generate_image`` together with the
    ``profile_image`` of every selected character and the size/style
    parameters. The base64 ``"image"`` field of each response is decoded
    into a PIL image.

    Args:
        rel_chars: Iterable of character names (keys of ``character_data``);
            despite the ``dict`` annotation it is only iterated.
        prompt_N / seed_N: The four composition prompts and their seeds.

    Returns:
        The successfully generated images, in composition order. A failing
        composition is logged and skipped, so the list may hold fewer than
        four images. If setup fails, a warning is shown and ``[]`` is returned.
    """
    image_list = []
    try:
        print(
            f"Regenerating data for episode {current_episode}, and frame {current_frame}"
        )
        profile_images = [character_data[name]["profile_image"] for name in rel_chars]
        prompt_seed_pairs = [
            (prompt_1, seed_1),
            (prompt_2, seed_2),
            (prompt_3, seed_3),
            (prompt_4, seed_4),
        ]
        for index, (raw_prompt, seed) in enumerate(prompt_seed_pairs):
            try:
                print(f"Generating image for composition {index}")
                # NOTE(review): "NOCHAR" is stripped from the prompt, but the
                # character profile images are still sent; confirm whether the
                # marker should also drop them for this composition.
                prompt = (
                    raw_prompt.replace("NOCHAR", "")
                    if "NOCHAR" in raw_prompt
                    else raw_prompt
                )
                request_body = {
                    "prompt": prompt,
                    "characters": profile_images,
                    "parameters": {
                        "height": height,
                        "width": width,
                        "visual_style": visual_style,
                        "seed": seed,
                    },
                }
                response = iowrapper.get_valid_post_response(
                    url=f"{parameters.MODEL_SERVER_URL}generate_image",
                    payload=request_body,
                )
                decoded = base64.b64decode(response["image"])
                image_list.append(Image.open(io.BytesIO(decoded)))
            except Exception as e:
                print(f"Error processing composition {index}: {e}")
        print(f"Generated new images for episode: {current_episode} and frame: {current_frame}")
        print(f"Length of image list: {len(image_list)}")
        return image_list
    except Exception as e:
        print(f"Error in regenerate_data: {e}")
        gr.Warning("Failed to generate new images!")
        return []
def save_comic_data(
    current_episode: int,
    current_frame: int,
    episode_data: dict,
    comic_id: str,
    image_description: str,
    narration: str,
    character: str,
    dialogue: str,
    location: str,
    setting: str,
    prompt_1: str,
    prompt_2: str,
    prompt_3: str,
    prompt_4: str,
    frame_hash_map: dict,
    rel_chars: list,
    character_data: dict,
    images: list
):
    """Persist the edited frame: composition prompts, images and frame fields.

    For each of the frame's compositions, up to four:
    - its prompt is replaced by the matching ``prompt_N``;
    - ``images[i][0]`` is re-encoded as JPEG and uploaded to
      ``{comic_id}/episodes/episode-{current_episode}/compositions/scene-{s}/frame-{f}/{i}.jpg``.

    The frame's text fields and character records are then updated in
    ``episode_data``, and the whole episode is written back to
    ``episode.json``. ``episode_data`` is mutated in place. Any failure is
    printed and surfaced as a warning.

    Args:
        current_frame: 1-based frame counter, a key of ``frame_hash_map``.
        setting: Stored under ``"frame_setting"``, the key ``load_data`` reads.
        rel_chars: Names of the selected characters (keys of ``character_data``).
        images: Gallery entries; element ``[i][0]`` is opened with
            ``Image.open`` (presumably a file path; confirm against caller).
    """
    try:
        position = frame_hash_map[current_frame]
        scene_num, frame_num = position["scene"], position["frame"]
        curr_frame = episode_data["scenes"][scene_num]["frames"][frame_num]
        print(
            f"Saving comic data for episode {current_episode}, frame {frame_num}"
        )

        prompts_list = [prompt_1, prompt_2, prompt_3, prompt_4]
        # zip stops at the shorter sequence, so a frame with more than four
        # compositions no longer raises IndexError on prompts_list.
        for i, (comp, prompt) in enumerate(zip(curr_frame["compositions"], prompts_list)):
            comp["prompt"] = prompt
            # NOTE(review): images are paired with compositions by position, and
            # comp["image"] is not rewritten to this key; confirm it already
            # points at the uploaded location.
            with Image.open(images[i][0]) as img:
                img_bytes = io.BytesIO()
                img.convert("RGB").save(img_bytes, "JPEG")
                img_bytes.seek(0)
                aws_utils.save_to_s3(
                    parameters.AWS_BUCKET,
                    f"{comic_id}/episodes/episode-{current_episode}/compositions/scene-{scene_num}/frame-{frame_num}",
                    img_bytes,
                    f"{i}.jpg",
                )

        curr_frame.update(
            {
                "description": image_description,
                "narration": narration,
                "audio_cue_text": dialogue,
                "location": location,
                # Must match the key read back by load_data/regenerate_compositions.
                "frame_setting": setting,
                "audio_cue_character": character,
                "characters": [character_data[char] for char in rel_chars],
            }
        )

        print(f"Saving updated episode {current_episode} to S3")
        aws_utils.save_to_s3(
            bucket_name=parameters.AWS_BUCKET,
            folder_name=f"{comic_id}/episodes/episode-{current_episode}",
            content=episode_data,
            file_name="episode.json",
        )
        gr.Info("Comic data saved successfully!")
    except Exception as e:
        print(f"Error in saving comic data: {e}")
        gr.Warning("Failed to save data for the comic!")