comic-grading / core.py
Manish Gupta
Modified codebase to make the loading times drastically better.
e3460b7
raw
history blame
14.8 kB
from typing import List
from PIL import Image
import json
import gradio as gr
import io
import jinja2
import base64
import aws_utils
import parameters
import script_gen
import inout as iowrapper
import openai_wrapper
import json
import base64
AWS_BUCKET = parameters.AWS_BUCKET
llm = openai_wrapper.GPT_4O_MINI
#### Functions ordered by their order of development.
def toggle_developer_options(is_developer: bool):
    """Show or hide the developer-options panel based on the checkbox state.

    Returns a single Gradio visibility update for the developer section.
    """
    # Both branches of the original returned the same update with a boolean
    # flag, so one call with the coerced flag is equivalent.
    return gr.update(visible=bool(is_developer))
def list_current_dir(bucket_name: str, folder_path: str = "") -> list:
    """List the immediate sub-folder prefixes under `folder_path` in an S3 bucket.

    Uses Delimiter="/" so S3 returns folder-like common prefixes instead of
    every object key. Returns a list of prefix strings (e.g.
    "comic/episodes/episode-1/"), or [] on any failure (best-effort).
    """
    try:
        response = aws_utils.S3_CLIENT.list_objects_v2(
            Bucket=bucket_name, Prefix=folder_path, Delimiter="/"
        )
        # "CommonPrefixes" is only present when at least one sub-prefix exists.
        return [prefix["Prefix"] for prefix in response.get("CommonPrefixes", [])]
    except Exception as e:
        # Keep the best-effort contract (empty list), but log instead of
        # silently swallowing the error as the original did.
        print(f"list_current_dir failed for bucket={bucket_name!r} prefix={folder_path!r}: {e}")
        return []
def load_metadata_fn(comic_id: str):
    """Load character metadata and available episode indices for a comic.

    Returns a 3-tuple:
      - gr.update for the episode dropdown (choices + initial value),
      - the initial (lowest) episode number, or None if none found,
      - the character dict parsed from characters.json.
    On any failure, warns the user and returns empty choices.
    """
    try:
        # Character profiles live in one JSON document per comic. Use the
        # configured bucket for consistency with the episode listing below —
        # the original hard-coded "blix-demo-v0" here while reading episodes
        # of the same comic_id from AWS_BUCKET.
        character_path = f"s3://{AWS_BUCKET}/{comic_id}/characters/characters.json"
        character_data = json.loads(
            aws_utils.fetch_from_s3(character_path).decode("utf-8")
        )
        # Discover episodes by listing the episodes/ "folder" prefixes.
        episode_folders = list_current_dir(AWS_BUCKET, f"{comic_id}/episodes/")
        episode_indices = []
        for folder in episode_folders:
            if "episode" in folder:
                # Folders look like "<comic_id>/episodes/episode-<n>/"; take
                # the last path segment robustly rather than a fixed index.
                idx = int(folder.rstrip("/").split("/")[-1].split("-")[-1])
                episode_indices.append(idx)
        if not episode_indices:
            return (gr.update(choices=[]), None, {})
        min_episode = min(episode_indices)
        return (
            gr.update(choices=episode_indices, value=min_episode),
            min_episode,
            character_data,
        )
    except Exception as e:
        gr.Warning(f"Error loading metadata: {e}")
        return (gr.update(choices=[]), None, {})
def load_episode_data(comic_id: str, episode_num: int):
    """Fetch episode.json for one episode and build a flat frame index.

    Returns (episode_data, frame_hash_map) where frame_hash_map maps a global
    1-based frame counter to {"scene": scene_index, "frame": frame_index}.
    Returns ({}, {}) if the episode JSON cannot be fetched or parsed.
    """
    json_path = (
        f"s3://{AWS_BUCKET}/{comic_id}/episodes/episode-{episode_num}/episode.json"
    )
    try:
        print(f"For episode: {episode_num}")
        raw = aws_utils.fetch_from_s3(json_path).decode("utf-8")
        episode_data = json.loads(raw)
        # Flatten the scenes/frames nesting into a single running counter so
        # the UI can page through frames with one dropdown.
        frame_hash_map = {}
        counter = 1
        for scene_idx, scene in enumerate(episode_data["scenes"]):
            for frame_idx in range(len(scene["frames"])):
                frame_hash_map[counter] = {"scene": scene_idx, "frame": frame_idx}
                counter += 1
        return (episode_data, frame_hash_map)
    except Exception:
        print(
            f"Failed to load json dictionary for episode: {episode_num} at path: {json_path}"
        )
        import traceback
        print(traceback.format_exc())
        return {}, {}
def episode_dropdown_effect(comic_id, selected_episode):
    """Handle an episode dropdown change: load that episode and reset the frame.

    Returns (frame dropdown update, selected episode, first frame number,
    episode data, frame hash map).
    """
    episode_data, frame_hash_map = load_episode_data(comic_id, selected_episode)
    # load_episode_data returns ({}, {}) on failure; the original called
    # min() on the empty key list and crashed with ValueError.
    if not frame_hash_map:
        gr.Warning(f"Could not load frames for episode {selected_episode}.")
        return (
            gr.update(choices=[]),
            selected_episode,
            None,
            episode_data,
            frame_hash_map,
        )
    current_frame = min(frame_hash_map)
    return (
        gr.update(choices=list(frame_hash_map.keys()), value=current_frame),
        selected_episode,
        current_frame,
        episode_data,
        frame_hash_map,
    )
def load_data(episodes_data: dict, current_frame: int, frame_hash_map: dict):
    """Resolve `current_frame` to its scene/frame, fetch the composition
    images from S3 and return all editable frame fields for the UI.

    Returns a 15-tuple: (image list, description, narration,
    audio_cue_character, audio_cue_text, location, frame_setting, then
    prompt/seed pairs for each of the 4 compositions).
    """
    try:
        loc = frame_hash_map[current_frame]
        scene_num, frame_num = loc["scene"], loc["frame"]
        curr_frame = episodes_data["scenes"][scene_num]["frames"][frame_num]
        image_list = []
        for comp in curr_frame["compositions"]:
            # Fetch each composition image from S3; skip (but log) failures
            # so one bad image does not blank the whole gallery.
            data = aws_utils.fetch_from_s3(comp["image"])
            if data:
                image_list.append(Image.open(io.BytesIO(data)))
            else:
                print(f"Failed to load image from: {comp['image']}")
        # NOTE(review): exactly 4 compositions per frame are assumed here,
        # matching the 4 prompt/seed widgets in the UI.
        comps = curr_frame["compositions"]
        return (
            image_list,  # gallery images
            curr_frame["description"],
            curr_frame["narration"],
            curr_frame["audio_cue_character"],
            curr_frame["audio_cue_text"],
            curr_frame["location"],
            curr_frame["frame_setting"],
            comps[0]["prompt"],
            comps[0]["seed"],
            comps[1]["prompt"],
            comps[1]["seed"],
            comps[2]["prompt"],
            comps[2]["seed"],
            comps[3]["prompt"],
            comps[3]["seed"],
        )
    except Exception as e:
        print("Error in load_data:", str(e))  # Debugging the error
        gr.Warning("Failed to load data. Check logs!")
        # Return a tuple matching the success arity: the original implicitly
        # returned None, which breaks Gradio's output mapping.
        return ([], "", "", "", "", "", "", "", "", "", "", "", "", "", "")
def update_characters(character_data: dict, current_frame: int, frame_hash_map: dict, episode_data: dict):
    """Refresh the character checkbox group for the current frame.

    Choices are every known character for the comic; the checked values are
    the characters already attached to this frame.
    """
    loc = frame_hash_map[current_frame]
    frame = episode_data["scenes"][loc["scene"]]["frames"][loc["frame"]]
    selected = [entry["name"] for entry in frame["characters"]]
    return gr.CheckboxGroup(choices=list(character_data.keys()), value=selected)
def load_data_next(
    comic_id: str,
    current_episode: int,
    current_frame: int,
    frame_hash_map: dict,
    episode_data: dict,
):
    """Advance one frame, rolling over to the next episode when exhausted.

    Returns (episode dropdown update, frame dropdown update, episode number,
    frame number, episode data, frame hash map). When no further episode
    exists, warns and returns the current state unchanged.
    """
    # `<=` so the last frame of the episode is reachable — the original used
    # `<` against the last key, which skipped the final frame and jumped
    # straight to the next episode.
    if current_frame + 1 <= max(frame_hash_map):
        current_frame += 1
    else:
        next_data, next_map = load_episode_data(comic_id, current_episode + 1)
        if len(next_data) < 1:
            gr.Warning("All episodes finished.")
            # Return a full, valid tuple (the original bare `return` handed
            # Gradio a single None for six outputs).
            return (
                gr.update(),
                gr.update(),
                current_episode,
                current_frame,
                episode_data,
                frame_hash_map,
            )
        current_episode += 1
        episode_data, frame_hash_map = next_data, next_map
        current_frame = min(frame_hash_map)
    return (
        gr.update(value=current_episode),
        gr.update(choices=list(frame_hash_map.keys()), value=current_frame),
        current_episode,
        current_frame,
        episode_data,
        frame_hash_map,
    )
def load_data_prev(
    comic_id: str,
    current_episode: int,
    current_frame: int,
    frame_hash_map: dict,
    episode_data: dict,
):
    """Step back one frame, falling back to the previous episode if needed.

    Returns (episode dropdown update, frame dropdown update, episode number,
    frame number, episode data, frame hash map).
    """
    first_key = list(frame_hash_map.keys())[0]
    if current_frame > first_key:  # same as current_frame - 1 >= first_key
        current_frame -= 1
    else:
        current_episode -= 1
        episode_data, frame_hash_map = load_episode_data(comic_id, current_episode)
        if not episode_data:
            gr.Warning("No previous episode found.")
            return
        # NOTE(review): lands on the FIRST frame of the previous episode, not
        # the last — this matches the original behavior.
        current_frame = min(frame_hash_map)
    return (
        gr.update(value=current_episode),
        gr.update(choices=list(frame_hash_map.keys()), value=current_frame),
        current_episode,
        current_frame,
        episode_data,
        frame_hash_map,
    )
def save_image(
    selected_image: list, comic_id: str, current_episode: int, current_frame: int
):
    """Persist the first selected gallery image to S3 as a JPEG.

    selected_image[0] is expected to be an openable image path/file —
    presumably the Gradio gallery selection; verify against the caller.
    Converts to RGB (JPEG has no alpha) and uploads as "<frame>.jpg".
    """
    with Image.open(selected_image[0]) as img:
        # Re-encode in memory so we never touch local disk.
        img_bytes = io.BytesIO()
        img.convert("RGB").save(img_bytes, "JPEG")
        img_bytes.seek(0)
        # NOTE(review): this path is "{comic_id}/episode-{n}/images" while
        # other saves use "{comic_id}/episodes/episode-{n}/..." — confirm the
        # flat layout here is intentional.
        aws_utils.save_to_s3(
            AWS_BUCKET,
            f"{comic_id}/episode-{current_episode}/images",
            img_bytes,
            f"{current_frame}.jpg",
        )
    gr.Info("Saved Image successfully!")
def regenerate_compositions(
    image_description: str,
    narration: str,
    character: str,
    dialouge: str,  # NOTE(review): misspelling kept — name is part of the interface
    location: str,
    setting: str,
    rel_chars: list,
    current_episode: int,
    current_frame: int,
    episodes_data: dict,
    frame_hash_map: dict,
    character_data: dict,
):
    """Ask the LLM to regenerate the 4 composition prompts for a frame.

    Builds a prompt from the edited frame fields plus the previous frame's
    location/setting for continuity, then returns a list of 4 prompt strings.
    Returns ["", "", "", ""] on failure.
    """
    try:
        print(
            f"Regenerating composition data for episode {current_episode}, frame {current_frame}"
        )
        loc = frame_hash_map[current_frame]
        scene_num, frame_num = loc["scene"], loc["frame"]
        # Previous frame (if any) supplies continuity context. Fixed
        # off-by-one: the original checked `frame_num - 1 > 0`, which skipped
        # the existing previous frame whenever frame_num was 1.
        prev_frame = {}
        if frame_num > 0:
            prev_frame = episodes_data["scenes"][scene_num]["frames"][frame_num - 1]
        try:
            related_chars = [character_data[char] for char in rel_chars]
            prompt_dict = {
                "system": script_gen.generate_image_compositions_instruction,
                "user": jinja2.Template(
                    script_gen.generate_image_compositions_user_prompt
                ).render(
                    {
                        "FRAME": {
                            "description": image_description,
                            "narration": narration,
                            "audio_cue_text": dialouge,
                            "audio_cue_character": character,
                            "location": location,
                            "frame_setting": setting,
                            "characters": json.dumps(related_chars),
                        },
                        "LOCATION_DESCRIPTION": prev_frame.get("location", ""),
                        "frame_settings": prev_frame.get("frame_setting", ""),
                    }
                ),
            }
            print("Generating compositions using LLM")
            compositions = llm.generate_valid_json_response(prompt_dict)
            comps = compositions["compositions"]
        except Exception as e:
            print(f"Error updating frame compositions: {e}")
            raise
        print("Composition data regenerated successfully.")
        return [
            comps[0]["prompt"],
            comps[1]["prompt"],
            comps[2]["prompt"],
            comps[3]["prompt"],
        ]
    except Exception as e:
        print(f"Error in regenerate_composition_data: {e}")
        # Match the success arity of 4 prompt outputs (original returned 8).
        return [""] * 4
def regenerate_images(
    current_episode: int,
    current_frame: int,
    visual_style: str,
    height: int,
    width: int,
    character_data: dict,
    rel_chars: dict,
    prompt_1: str,
    seed_1: str,
    prompt_2: str,
    seed_2: str,
    prompt_3: str,
    seed_3: str,
    prompt_4: str,
    seed_4: str,
):
    """Generate four new composition images via the model server.

    Sends each (prompt, seed) pair together with the selected characters'
    profile images and the generation parameters, collecting the decoded
    PIL images. A failed composition is logged and skipped; any unexpected
    error returns an empty list after warning the user.
    """
    generated = []
    try:
        print(
            f"Regenerating data for episode {current_episode}, and frame {current_frame}"
        )
        profile_images = [character_data[name]["profile_image"] for name in rel_chars]
        prompt_seed_pairs = [
            (prompt_1, seed_1),
            (prompt_2, seed_2),
            (prompt_3, seed_3),
            (prompt_4, seed_4),
        ]
        for idx, (prompt_text, seed) in enumerate(prompt_seed_pairs):
            try:
                print(f"Generating image for composition {idx}")
                # "NOCHAR" is a sentinel token that must not reach the model.
                prompt_text = prompt_text.replace("NOCHAR", "")
                payload = {
                    "prompt": prompt_text,
                    "characters": profile_images,
                    "parameters": {
                        "height": height,
                        "width": width,
                        "visual_style": visual_style,
                        "seed": seed,
                    },
                }
                data = iowrapper.get_valid_post_response(
                    url=f"{parameters.MODEL_SERVER_URL}generate_image",
                    payload=payload,
                )
                decoded = base64.b64decode(data["image"])
                generated.append(Image.open(io.BytesIO(decoded)))
            except Exception as e:
                print(f"Error processing composition {idx}: {e}")
                continue
        print(f"Generated new images for episode: {current_episode} and frame: {current_frame}")
        print(f"Length of image list: {len(generated)}")
        return generated
    except Exception as e:
        print(f"Error in regenerate_data: {e}")
        gr.Warning("Failed to generate new images!")
        return []
def save_comic_data(
    current_episode: int,
    current_frame: int,
    episode_data: dict,
    comic_id: str,
    image_description: str,
    narration: str,
    character: str,
    dialogue: str,
    location: str,
    setting: str,
    prompt_1: str,
    prompt_2: str,
    prompt_3: str,
    prompt_4: str,
    frame_hash_map: dict,
    rel_chars: list,
    character_data: dict,
    images: list
):
    """Persist the edited frame fields, the current gallery images, and the
    updated episode JSON back to S3.

    Mutates episode_data in place (prompts + frame fields) before uploading
    it as episode.json. Warns the user on any failure.
    """
    try:
        loc = frame_hash_map[current_frame]
        scene_num, frame_num = loc["scene"], loc["frame"]
        curr_frame = episode_data["scenes"][scene_num]["frames"][frame_num]
        print(
            f"Saving comic data for episode {current_episode}, frame {frame_num}"
        )
        # Write each edited prompt back and upload its image as JPEG.
        prompts_list = [prompt_1, prompt_2, prompt_3, prompt_4]
        for i, comp in enumerate(curr_frame["compositions"]):
            comp["prompt"] = prompts_list[i]
            # images[i] is a gallery entry; [0] is the openable image path/file
            # — presumably Gradio's (path, caption) tuple; verify with caller.
            with Image.open(images[i][0]) as img:
                img_bytes = io.BytesIO()
                img.convert("RGB").save(img_bytes, "JPEG")
                img_bytes.seek(0)
                aws_utils.save_to_s3(
                    parameters.AWS_BUCKET,
                    f"{comic_id}/episodes/episode-{current_episode}/compositions/scene-{scene_num}/frame-{frame_num}",
                    img_bytes,
                    f"{i}.jpg",
                )
        # Update the frame's editable fields. Fixed key: the rest of the app
        # reads "frame_setting" (see load_data / regenerate_compositions); the
        # original saved under "setting", so the edit never round-tripped.
        curr_frame.update(
            {
                "description": image_description,
                "narration": narration,
                "audio_cue_text": dialogue,
                "location": location,
                "frame_setting": setting,
                "audio_cue_character": character,
                "characters": [character_data[char] for char in rel_chars],
            }
        )
        # Upload the mutated episode document back to S3.
        print(f"Saving updated episode {current_episode} to S3")
        aws_utils.save_to_s3(
            bucket_name=parameters.AWS_BUCKET,
            folder_name=f"{comic_id}/episodes/episode-{current_episode}",
            content=episode_data,
            file_name="episode.json",
        )
        gr.Info("Comic data saved successfully!")
    except Exception as e:
        print(f"Error in saving comic data: {e}")
        gr.Warning("Failed to save data for the comic!")