Spaces:
Sleeping
Sleeping
| from typing import List | |
| from PIL import Image | |
| import gradio as gr | |
| import dataclasses | |
| import io | |
| import jinja2 | |
| import base64 | |
| import aws_utils | |
| import parameters | |
| import script_gen | |
| import inout as iowrapper | |
| import openai_wrapper | |
| import json | |
| from dataclasses import asdict | |
# Bucket name taken from the project's `parameters` module; used when listing
# and saving objects in this file.
AWS_BUCKET = parameters.AWS_BUCKET

# LLM client object; regenerate_composition_data calls its
# generate_valid_json_response() method.
llm = openai_wrapper.GPT_4O_MINI
@dataclasses.dataclass
class Composition:
    """One composition (prompt + rendered image) belonging to a comic frame.

    Instances are created from the per-frame ``compositions`` dicts with
    ``Composition(**comp)``, so the class must be a dataclass: without the
    decorator the annotations below do not generate an ``__init__`` and that
    call raises ``TypeError``.
    """

    prompt: str  # text prompt sent to the image model
    shot_type: str  # stored and written back unchanged; not interpreted in this file
    seed: int  # forwarded to the image model; regeneration may store "" when unknown
    image: str  # location handed to aws_utils.fetch_from_s3 to load the picture
@dataclasses.dataclass
class ComicFrame:
    """In-memory representation of one frame of an episode.

    Built by ``load_metadata_fn`` from the frame dicts of ``episode.json``.
    The ``@dataclass`` decorator is required: the class is constructed with
    keyword arguments and ``compositions`` relies on ``dataclasses.field``.
    """

    description: str  # from frame["description"]
    narration: str  # from frame["narration"]
    character_dilouge: str  # from frame["audio_cue_text"] (spelling kept: callers use this name)
    character: str  # from frame["audio_cue_character"]
    location: str  # from frame["location"]
    setting: str  # from frame["frame_setting"]
    all_characters: list  # names of the characters listed on the frame
    # Quoted so the annotation is not evaluated when the class is created;
    # every instance gets its own empty list by default.
    compositions: List["Composition"] = dataclasses.field(default_factory=list)
def list_current_dir(bucket_name: str, folder_path: str = "") -> list:
    """Return the "/"-delimited common prefixes directly under ``folder_path``.

    Args:
        bucket_name: Bucket passed as ``Bucket`` to ``list_objects_v2``.
        folder_path: Key prefix to list under (``""`` lists the bucket root).

    Returns:
        The ``Prefix`` value of every ``CommonPrefixes`` entry, across all
        result pages. On any error the problem is printed and ``[]`` is
        returned, so callers never see an exception from this function.
    """
    request = {"Bucket": bucket_name, "Prefix": folder_path, "Delimiter": "/"}
    folders = []
    try:
        while True:
            response = aws_utils.S3_CLIENT.list_objects_v2(**request)
            folders.extend(
                entry["Prefix"] for entry in response.get("CommonPrefixes", [])
            )
            # A single call returns a bounded page; keep following the
            # continuation token until the listing is complete.
            if not response.get("IsTruncated"):
                return folders
            request["ContinuationToken"] = response["NextContinuationToken"]
    except Exception as e:
        print(f"Error listing '{folder_path}' in bucket '{bucket_name}': {e}")
        return []
def load_data_inner(
    episodes_data: dict, current_episode: int, current_frame: int, is_developer: bool
):
    """Collect the 19 output values that describe one frame.

    Args:
        episodes_data: Mapping of episode index to its list of frames, as
            built by ``load_metadata_fn``.
        current_episode: Key into ``episodes_data``.
        current_frame: Index into that episode's frame list.
        is_developer: Passed as ``interactive`` to the five text boxes.

    Returns:
        A 19-tuple: images, ``episodes_data``, ``current_episode``,
        ``current_frame``, five ``gr.Textbox`` objects (description,
        narration, character, dialogue, location), the frame setting, four
        (prompt, seed) pairs flattened into eight values, and the frame's
        character list. Frames with fewer than four compositions get ``""``
        in the unused prompt/seed positions instead of failing. If anything
        goes wrong the error is printed and an all-empty tuple of the same
        length is returned.
    """
    prompt_slots = 4  # the output tuple always carries four prompt/seed pairs
    try:
        curr_frame = episodes_data[current_episode][current_frame]

        images = []
        for comp in curr_frame.compositions:
            # regenerate_composition_data stores "" for compositions that have
            # no rendered image yet; there is nothing to fetch for those.
            if not comp.image:
                continue
            data = aws_utils.fetch_from_s3(comp.image)
            images.append(Image.open(io.BytesIO(data)))

        prompt_seed_values = []
        for comp in curr_frame.compositions[:prompt_slots]:
            prompt_seed_values.extend((comp.prompt, comp.seed))
        # Pad so the tuple length stays fixed even with fewer compositions.
        prompt_seed_values.extend(
            [""] * (2 * prompt_slots - len(prompt_seed_values))
        )

        return (
            images,
            episodes_data,
            current_episode,
            current_frame,
            gr.Textbox(value=curr_frame.description, interactive=is_developer),
            gr.Textbox(value=curr_frame.narration, interactive=is_developer),
            gr.Textbox(value=curr_frame.character, interactive=is_developer),
            gr.Textbox(value=curr_frame.character_dilouge, interactive=is_developer),
            gr.Textbox(value=curr_frame.location, interactive=is_developer),
            curr_frame.setting,
            *prompt_seed_values,
            curr_frame.all_characters,
        )
    except Exception as e:
        print(
            f"Error loading episode {current_episode}, frame {current_frame}: {e}"
        )
        return (
            [],
            episodes_data,
            current_episode,
            current_frame,
            gr.Textbox(),
            gr.Textbox(),
            gr.Textbox(),
            gr.Textbox(),
            gr.Textbox(),
            "",
            *[""] * (2 * prompt_slots),
            [],
        )
def _parse_s3_payload(raw: bytes):
    """Decode a UTF-8 payload fetched from S3 into Python data without executing it."""
    import ast  # local import: only needed for the literal fallback below

    text = raw.decode("utf-8")
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        # This code previously called eval() on the payload, which would run
        # arbitrary code stored in the bucket. literal_eval still accepts
        # Python-literal payloads (single quotes, True/None) but nothing else.
        return ast.literal_eval(text)


def load_metadata_fn(comic_id: str):
    """Load character and episode metadata for ``comic_id`` from S3.

    Reads ``characters/characters.json`` and every ``episode.json`` found
    under ``{comic_id}/episodes/``, and converts each frame into a
    ``ComicFrame``.

    Returns:
        An 8-tuple: update for the episode dropdown, update for the frame
        dropdown, current episode (the lowest episode index), current frame
        (``0``), ``episodes_data`` (episode index -> list of ``ComicFrame``),
        ``character_data`` (character name -> ``profile_image``),
        ``details`` (episode index -> {scene number -> cumulative frame count
        up to and including that scene}), and a ``gr.Checkbox`` whose
        visibility reflects success. On failure the error is printed and a
        tuple of the same length with empty values is returned.
    """
    try:
        # NOTE(review): this path hard-codes the bucket while the episode
        # files below use AWS_BUCKET — confirm both name the same bucket.
        character_path = f"s3://blix-demo-v0/{comic_id}/characters/characters.json"
        char_data = _parse_s3_payload(aws_utils.fetch_from_s3(source=character_path))
        character_data = {
            name: char["profile_image"] for name, char in char_data.items()
        }

        episodes_data = {}
        details = {}
        for folder in list_current_dir(AWS_BUCKET, f"{comic_id}/episodes/"):
            if "episode" not in folder:
                continue
            json_path = f"s3://{AWS_BUCKET}/{folder}episode.json"
            # The third path component is parsed as "<something>-<index>".
            idx = int(folder.split("/")[2].split("-")[-1])
            data = _parse_s3_payload(aws_utils.fetch_from_s3(source=json_path))

            comic_frames = []
            scene_ends = {}
            for scene_num, scene in enumerate(data["scenes"]):
                for frame in scene["frames"]:
                    comic_frames.append(
                        ComicFrame(
                            description=frame["description"],
                            narration=frame["narration"],
                            character=frame["audio_cue_character"],
                            character_dilouge=frame["audio_cue_text"],
                            compositions=[
                                Composition(**comp)
                                for comp in frame["compositions"]
                            ],
                            location=frame["location"],
                            setting=frame["frame_setting"],
                            all_characters=[
                                char["name"] for char in frame["characters"]
                            ],
                        )
                    )
                # Running total of frames once this scene has been added.
                scene_ends[scene_num] = len(comic_frames)
            episodes_data[idx] = comic_frames
            details[idx] = scene_ends

        if not episodes_data:
            raise ValueError(f"No episodes found for comic '{comic_id}'")

        # Numeric order, so the dropdown and the selected value always agree.
        episode_idx = sorted(episodes_data)
        current_episode, current_frame = episode_idx[0], 0
        return (
            gr.update(choices=episode_idx, value=current_episode),
            gr.update(
                choices=list(range(len(episodes_data[current_episode]))),
                value=current_frame,
            ),
            current_episode,
            current_frame,
            episodes_data,
            character_data,
            details,
            gr.Checkbox(visible=True),
        )
    except Exception as e:
        print(f"Error in load_metadata_fn: {e}")
        # Same number of values as the success path so every output is filled.
        return (
            gr.update(choices=[]),
            gr.update(choices=[]),
            0,
            0,
            {},
            {},
            {},
            gr.Checkbox(visible=False),
        )
def load_data_next(
    episodes_data: dict, current_episode: int, current_frame: int, is_developer: bool
):
    """Advance to the next frame, rolling over to the next episode if needed.

    ``episodes_data`` is keyed by episode index, which need not start at 0 or
    be contiguous, so the following episode is found by comparing keys rather
    than by comparing against ``len(episodes_data)``.

    Returns:
        Two dropdown updates (episode, frame) followed by the 19 values from
        ``load_data_inner``. When already on the last frame of the last
        episode the position is left unchanged and the same shape is
        returned, so the number of outputs never varies.
    """
    frames_in_episode = len(episodes_data.get(current_episode, ()))
    if current_frame + 1 < frames_in_episode:
        current_frame += 1
    else:
        later_episodes = [ep for ep in episodes_data if ep > current_episode]
        if later_episodes:
            current_episode = min(later_episodes)
            current_frame = 0
        # else: end of the comic — stay where we are.
    return (
        gr.update(value=current_episode),
        gr.update(value=current_frame),
        *load_data_inner(episodes_data, current_episode, current_frame, is_developer),
    )
def load_data_prev(
    episodes_data: dict, current_episode: int, current_frame: int, is_developer: bool
):
    """Step back one frame, moving to the previous episode if needed.

    The previous episode is the largest key below ``current_episode`` (the
    earlier ``> min(keys)`` comparison could never reach the first episode).
    When crossing an episode boundary the position lands on the last frame of
    the previous episode, so that stepping back and then forward returns to
    the starting frame.

    Returns:
        Two dropdown updates (episode, frame) followed by the 19 values from
        ``load_data_inner``. On the first frame of the first episode the
        position is left unchanged and the same shape is returned.
    """
    if current_frame - 1 >= 0:
        current_frame -= 1
    else:
        earlier_episodes = [ep for ep in episodes_data if ep < current_episode]
        if earlier_episodes:
            current_episode = max(earlier_episodes)
            current_frame = max(len(episodes_data[current_episode]) - 1, 0)
        # else: start of the comic — stay where we are.
    return (
        gr.update(value=current_episode),
        gr.update(value=current_frame),
        *load_data_inner(episodes_data, current_episode, current_frame, is_developer),
    )
def load_from_dropdown(
    episodes_data: dict, selected_episode: int, selected_frame: int, is_developer: bool
):
    """Load the frame chosen via the episode/frame dropdowns.

    Returns two dropdown updates (episode, frame) followed by everything
    ``load_data_inner`` returns for the selected position.
    """
    episode_update = gr.update(value=selected_episode)
    frame_update = gr.update(value=selected_frame)
    frame_outputs = load_data_inner(
        episodes_data, selected_episode, selected_frame, is_developer
    )
    return (episode_update, frame_update, *frame_outputs)
def load_dropdown_fn(selected_episode):
    """Select ``selected_episode`` and reset the frame position to 0.

    Returns an update for the episode value, an update setting the frame
    value to 0, then the raw episode and frame (0) values.
    """
    first_frame = 0
    episode_update = gr.update(value=selected_episode)
    frame_update = gr.update(value=first_frame)
    return (episode_update, frame_update, selected_episode, first_frame)
def load_dropdown_fn_v2(selected_frame):
    """Pass the selected frame through unchanged."""
    return selected_frame
def save_image(selected_image, comic_id: str, current_episode: int, current_frame: int):
    """Convert the selected image to an RGB JPEG and upload it to S3.

    ``selected_image[0]`` is opened with PIL (presumably a path or file
    object coming from the gallery selection — confirm against the caller).
    The JPEG is stored as ``{current_frame}.jpg`` under
    ``{comic_id}/episode-{current_episode}/images`` and a success toast is
    shown afterwards. Errors are not caught here.
    """
    # NOTE(review): other writers in this file save under
    # "{comic_id}/episodes/episode-N/..." — confirm this different layout is intended.
    target_folder = f"{comic_id}/episode-{current_episode}/images"
    target_name = f"{current_frame}.jpg"
    jpeg_buffer = io.BytesIO()
    with Image.open(selected_image[0]) as source:
        rgb_image = source.convert("RGB")
        rgb_image.save(jpeg_buffer, "JPEG")
        # Rewind so the upload reads the JPEG from its first byte.
        jpeg_buffer.seek(0)
        aws_utils.save_to_s3(AWS_BUCKET, target_folder, jpeg_buffer, target_name)
        gr.Info("Saved Image successfully!")
def toggle_developer_options(
    is_developer: bool, prompt_1, prompt_2, prompt_3, prompt_4, setting
):
    """Show or hide the developer options.

    Returns a visibility update (visible exactly when ``is_developer`` is
    truthy) followed by the four prompts and the setting, passed through
    unchanged.
    """
    visibility_update = gr.update(visible=bool(is_developer))
    return (visibility_update, prompt_1, prompt_2, prompt_3, prompt_4, setting)
def regenerate_composition_data(
    image_description,
    narration,
    character,
    dialouge,
    location,
    setting,
    chars,
    current_episode: int,
    current_frame: int,
    episodes_data: dict,
):
    """Ask the LLM for fresh composition prompts for the current frame.

    The edited frame fields are rendered into the composition prompt
    template, sent to ``llm``, and the returned compositions replace
    ``frame.compositions`` in ``episodes_data``. For each position the
    existing ``seed`` and ``image`` are carried over (``""`` when the frame
    had no composition at that position); these override any ``seed`` /
    ``image`` keys the LLM may have echoed back.

    Returns:
        A list of eight values: prompt and seed for up to four compositions,
        padded with ``""`` when the LLM returned fewer than four. On any
        error the problem is printed and ``[""] * 8`` is returned; the frame
        is left untouched unless the LLM response was usable.
    """
    prompt_slots = 4  # the UI shows four prompt/seed pairs
    try:
        print(
            f"Regenerating composition data for episode {current_episode}, frame {current_frame}"
        )
        frame = episodes_data[current_episode][current_frame]

        print("Creating prompt template for composition generation")
        prompt_template = jinja2.Template(
            script_gen.generate_image_compositions_user_prompt
        )

        print("Rendering prompt with frame details")
        prompt_dict = {
            "system": script_gen.generate_image_compositions_instruction,
            "user": prompt_template.render(
                {
                    "FRAME": {
                        "description": image_description,
                        "narration": narration,
                        "character_dilouge": dialouge,
                        "character": character,
                        "location": location,
                        "setting": setting,
                        "all_characters": chars,
                    }
                }
            ),
        }

        print("Generating compositions using LLM")
        compositions = llm.generate_valid_json_response(prompt_dict)

        print("Updating frame compositions")
        generated = compositions["compositions"]
        if not generated:
            # Refuse to wipe the frame's existing compositions.
            raise ValueError("LLM returned no compositions")

        previous = frame.compositions
        updated = []
        for idx, comp in enumerate(generated):
            fields = dict(comp)
            kept = previous[idx] if idx < len(previous) else None
            # Assign (rather than pass as extra keywords) so a "seed"/"image"
            # key already present in the LLM output cannot cause a
            # duplicate-keyword TypeError.
            fields["seed"] = kept.seed if kept is not None else ""
            fields["image"] = kept.image if kept is not None else ""
            updated.append(Composition(**fields))
        frame.compositions = updated

        print("Composition data regenerated successfully.")
        values = []
        for comp in updated[:prompt_slots]:
            values.extend((comp.prompt, comp.seed))
        values.extend([""] * (2 * prompt_slots - len(values)))
        return values
    except Exception as e:
        print(f"Error in regenerate_composition_data: {e}")
        return [""] * 8
def regenerate_data(
    comic_id,
    current_episode,
    current_scene,
    current_frame,
    episodes_data,
    character_data,
    visual_style,
    height,
    width,
):
    """Render an image for every composition of the current frame.

    For each composition a request is posted to the model server's
    ``generate_image`` endpoint and the base64 ``image`` field of the
    response is decoded. ``comic_id`` is accepted for interface
    compatibility but not used; ``current_scene`` is only printed.

    Returns:
        ``(images, image_data_b64)``: the opened PIL images and, in the same
        order, a fresh ``io.BytesIO`` holding each image's raw bytes,
        positioned at the start so it can be uploaded later. Compositions
        that fail are skipped in both lists. If the frame or one of its
        characters cannot be looked up, ``([], [])`` is returned.
    """
    images = []
    image_data_b64 = []
    try:
        print(
            f"Regenerating data for episode {current_episode}, scene {current_scene}, frame {current_frame}"
        )
        frame = episodes_data[current_episode][current_frame]
        related_chars = [character_data[ch] for ch in frame.all_characters]

        for i, composition in enumerate(frame.compositions):
            try:
                print(f"Generating image for composition {i}")
                payload = {
                    "prompt": composition.prompt,
                    "characters": related_chars,
                    "parameters": {
                        "height": height,
                        "width": width,
                        "visual_style": visual_style,
                        "seed": composition.seed,
                    },
                }
                print(f"Sending request to generate image for composition {i}")
                data = iowrapper.get_valid_post_response(
                    url=f"{parameters.MODEL_SERVER_URL}generate_image",
                    payload=payload,
                )
                print(f"Image generated for composition {i}. Decoding image data.")
                raw_bytes = base64.b64decode(data["image"])
                # PIL reads from (and moves the position of) the stream it is
                # given, so it gets its own buffer; the one kept for upload
                # stays at position 0.
                image = Image.open(io.BytesIO(raw_bytes))
            except Exception as e:
                print(f"Error generating image for composition {i}: {e}")
                continue
            # Append to both lists together so they can never get out of step.
            images.append(image)
            image_data_b64.append(io.BytesIO(raw_bytes))

        print("Data regeneration completed.")
        return images, image_data_b64
    except Exception as e:
        print(f"Error in regenerate_data: {e}")
        return [], []
def save_image_compositions(
    current_episode: int,
    current_frame: int,
    details: dict,
    comic_id: str,
    image_description,
    narration,
    character,
    dialogue,
    location,
    setting,
    chars,
    prompt_1,
    prompt_2,
    prompt_3,
    prompt_4,
):
    """Write the edited frame fields and prompts back to ``episode.json`` on S3.

    Args:
        current_episode: Episode index (key into ``details``).
        current_frame: Frame index counted across the whole episode.
        details: Episode index -> {scene number -> cumulative frame count up
            to and including that scene}, in scene order.
        comic_id: Comic whose episode file is updated.
        image_description, narration, character, dialogue, location, setting:
            New values for the frame.
        chars: Accepted for interface compatibility; not used here.
        prompt_1..prompt_4: New prompts for the first four compositions.

    Returns:
        The scene number containing the frame, or ``None`` when the episode
        or scene cannot be found or any error occurs (errors are printed).
    """
    try:
        print(
            f"Saving image components for episode {current_episode}, frame {current_frame}"
        )
        episode_details = details.get(current_episode)
        if not episode_details:
            print(f"Episode {current_episode} not found!")
            return None

        # Walk the scenes in order; the first scene whose cumulative count
        # exceeds the frame index contains the frame.
        scene_num, frame_num_in_scene = None, 0
        frames_before_scene = 0
        for candidate_scene, cumulative_frame_count in episode_details.items():
            if current_frame < cumulative_frame_count:
                scene_num = candidate_scene
                frame_num_in_scene = current_frame - frames_before_scene
                break
            frames_before_scene = cumulative_frame_count
        if scene_num is None:
            print(f"Scene not found for frame {current_frame}.")
            return None

        # NOTE(review): read from a hard-coded bucket but written back to
        # parameters.AWS_BUCKET — confirm both name the same bucket.
        episode_path = f"s3://blix-demo-v0/{comic_id}/episodes/episode-{current_episode}/episode.json"
        print(f"Fetching episode from S3: {episode_path}")
        episode_json = aws_utils.fetch_from_s3(episode_path).decode("utf-8")
        episode = json.loads(episode_json)
        frame_data = episode["scenes"][scene_num]["frames"][frame_num_in_scene]
        print(
            f"Updating compositions for scene {scene_num}, frame {frame_num_in_scene}"
        )

        prompts_list = [prompt_1, prompt_2, prompt_3, prompt_4]
        frame_data["compositions"] = [
            {
                # Compositions beyond the four editable prompts keep their own.
                "prompt": (
                    prompts_list[i] if i < len(prompts_list) else comp.get("prompt", "")
                ),
                "shot_type": comp["shot_type"],
                "seed": comp["seed"],
                "image": comp["image"],
            }
            for i, comp in enumerate(frame_data["compositions"])
        ]

        frame_data.update(
            {
                "description": image_description,
                "narration": narration,
                "audio_cue_text": dialogue,
                "location": location,
                # load_metadata_fn reads the setting from "frame_setting";
                # writing any other key would never be read back.
                "frame_setting": setting,
                "audio_cue_character": character,
            }
        )

        print(f"Saving updated episode to S3 at {episode_path}")
        aws_utils.save_to_s3(
            bucket_name=parameters.AWS_BUCKET,
            folder_name=f"{comic_id}/episodes/episode-{current_episode}",
            content=json.dumps(episode),
            file_name="episode.json",
        )
        gr.Info("Components saved successfully!")
        return scene_num
    except Exception as e:
        print(f"Error in save_image_compositions: {e}")
        return None
def save_images(
    image_data_b64,
    current_episode,
    current_frame,
    current_scene,
    comic_id,
):
    """Upload the regenerated images of one frame to S3 as ``{i}.jpg``.

    Each item of ``image_data_b64`` is stored under
    ``{comic_id}/episodes/episode-{current_episode}/compositions/scene-{current_scene}/frame-{current_frame}``.
    Seekable items are rewound first so the upload starts at the first byte
    even if something has already read from the buffer.

    A success toast is shown only when every image was uploaded; otherwise a
    warning names the failed indices (or says there was nothing to save).
    Errors are printed, never raised; the function returns ``None``.
    """
    try:
        print(
            f"Saving images for scene {current_scene}, episode {current_episode}, frame {current_frame}."
        )
        folder = f"{comic_id}/episodes/episode-{current_episode}/compositions/scene-{current_scene}/frame-{current_frame}"
        failed = []
        for i, image_data in enumerate(image_data_b64):
            try:
                print(f"Saving image {i} to S3")
                if hasattr(image_data, "seek"):
                    image_data.seek(0)
                aws_utils.save_to_s3(
                    parameters.AWS_BUCKET,
                    folder,
                    image_data,
                    f"{i}.jpg",
                )
            except Exception as e:
                print(f"Error saving image {i} to S3: {e}")
                failed.append(i)

        if failed:
            gr.Warning(f"Failed to save image(s): {failed}")
        elif image_data_b64:
            gr.Info("All Images saved successfully!")
        else:
            gr.Warning("No images to save.")
    except Exception as e:
        print(f"Error in save_images: {e}")