# comic-grading / core.py
# dev-immersfy — Major Speed update (#1), commit 01a3261 (verified)
from typing import List
from PIL import Image
import json
import gradio as gr
import io
import jinja2
import base64
import aws_utils
import parameters
import script_gen
import inout as iowrapper
import openai_wrapper
import json
import base64
# Shared handles used by every callback below: the configured S3 bucket name
# and the LLM client used for composition-prompt regeneration.
AWS_BUCKET = parameters.AWS_BUCKET
llm = openai_wrapper.GPT_4O_MINI
#### Functions ordered by their order of development.
def toggle_developer_options(is_developer: bool):
    """Show or hide the developer-options panel.

    Args:
        is_developer: True to make the developer options visible.

    Returns:
        A ``gr.update`` whose ``visible`` flag mirrors ``is_developer``.
    """
    # The original if/else branches differed only in the flag value, so a
    # single update driven by the boolean itself is equivalent and clearer.
    return gr.update(visible=bool(is_developer))
def list_current_dir(bucket_name: str, folder_path: str = "") -> list:
    """List the immediate sub-"folders" under ``folder_path`` in an S3 bucket.

    Uses the "/" delimiter so only one level of common prefixes is returned.

    Args:
        bucket_name: Name of the S3 bucket.
        folder_path: Prefix to list under ("" lists the bucket root).

    Returns:
        A list of prefix strings (each ending in "/"); empty on any error.
    """
    try:
        response = aws_utils.S3_CLIENT.list_objects_v2(
            Bucket=bucket_name, Prefix=folder_path, Delimiter="/"
        )
        # "CommonPrefixes" is only present when at least one sub-prefix exists.
        return [p["Prefix"] for p in response.get("CommonPrefixes", [])]
    except Exception as e:
        # Best-effort listing: still return [], but surface the cause in the
        # logs instead of silently swallowing it.
        print(f"Failed to list s3://{bucket_name}/{folder_path}: {e}")
        return []
def load_metadata_fn(comic_id: str):
    """Load character metadata and the list of available episodes for a comic.

    Args:
        comic_id: Root folder of the comic inside the bucket.

    Returns:
        (episode-dropdown update, lowest episode index, character-name ->
        character-dict mapping). On failure or when no episodes exist:
        (emptied dropdown, None, {}).
    """
    try:
        # Load character data. Uses the configured bucket for consistency with
        # every other S3 path in this module (was hard-coded "blix-demo-v0").
        character_path = f"s3://{AWS_BUCKET}/{comic_id}/characters/characters.json"
        character_data = json.loads(
            aws_utils.fetch_from_s3(character_path).decode("utf-8")
        )
        # Discover episode folders, e.g. "{comic_id}/episodes/episode-3/".
        episode_indices = []
        for folder in list_current_dir(AWS_BUCKET, f"{comic_id}/episodes/"):
            if "episode" in folder:
                # Strip the trailing "/" so the last path segment is
                # "episode-N", then parse N. (More robust than indexing a
                # fixed segment position.)
                idx = int(folder.rstrip("/").split("/")[-1].split("-")[-1])
                episode_indices.append(idx)
        if not episode_indices:
            return (gr.update(choices=[]), None, {})
        min_episode = min(episode_indices)
        return (
            gr.update(choices=episode_indices, value=min_episode),
            min_episode,
            character_data,
        )
    except Exception as e:
        gr.Warning(f"Error loading metadata: {e}")
        return (gr.update(choices=[]), None, {})
def load_episode_data(comic_id: str, episode_num: int):
    """Fetch episode.json for one episode and build a flat frame index.

    Args:
        comic_id: Comic root folder in the bucket.
        episode_num: Episode number to load.

    Returns:
        (episode_data, frame_hash_map) where frame_hash_map maps a 1-based
        global frame number to its {"scene": i, "frame": j} position.
        ({}, {}) on any failure.
    """
    # Bound before the try block so the except handler can always reference it
    # (previously it was assigned inside try and could be unbound on failure).
    json_path = (
        f"s3://{AWS_BUCKET}/{comic_id}/episodes/episode-{episode_num}/episode.json"
    )
    try:
        print(f"For episode: {episode_num}")
        episode_data = json.loads(aws_utils.fetch_from_s3(json_path).decode("utf-8"))
        # Flatten the nested scenes/frames structure into a 1-based lookup.
        frame_hash_map = {}
        count = 1
        for scene_idx, scene in enumerate(episode_data["scenes"]):
            for frame_idx in range(len(scene["frames"])):
                frame_hash_map[count] = {
                    "scene": scene_idx,
                    "frame": frame_idx,
                }
                count += 1
        return (episode_data, frame_hash_map)
    except Exception:
        print(
            f"Failed to load json dictionary for episode: {episode_num} at path: {json_path}"
        )
        import traceback as tc
        print(tc.format_exc())
        return {}, {}
def episode_dropdown_effect(comic_id, selected_episode):
    """Gradio callback: load the selected episode and reset the frame picker.

    Returns:
        (frame-dropdown update, selected episode, first frame number,
        episode_data, frame_hash_map).
    """
    episode_data, frame_hash_map = load_episode_data(comic_id, selected_episode)
    # load_episode_data returns ({}, {}) on failure; min() over an empty dict
    # would raise ValueError, so fall back to an emptied dropdown instead.
    if not frame_hash_map:
        gr.Warning(f"No frames found for episode {selected_episode}.")
        return (gr.update(choices=[], value=None), selected_episode, None, {}, {})
    current_frame = min(frame_hash_map.keys())
    return (
        gr.update(choices=list(frame_hash_map.keys()), value=current_frame),
        selected_episode,
        current_frame,
        episode_data,
        frame_hash_map,
    )
def load_data(episodes_data: dict, current_frame: int, frame_hash_map: dict):
    """Load gallery images and all editable fields for the current frame.

    Args:
        episodes_data: Parsed episode.json for the current episode.
        current_frame: 1-based global frame number.
        frame_hash_map: Global frame number -> {"scene", "frame"} lookup.

    Returns:
        A 15-tuple: (image list, description, narration, audio-cue character,
        audio-cue text, location, frame setting, then prompt and seed for
        each of the four compositions). On failure, empty placeholders of
        the same arity so the Gradio outputs still bind.
    """
    try:
        pos = frame_hash_map[current_frame]
        curr_frame = episodes_data["scenes"][pos["scene"]]["frames"][pos["frame"]]
        image_list = []
        for comp in curr_frame["compositions"]:
            # Fetch each composition image from S3; skip (but log) failures
            # so a single missing image does not blank the whole gallery.
            data = aws_utils.fetch_from_s3(comp["image"])
            if data:
                image_list.append(Image.open(io.BytesIO(data)))
            else:
                print(f"Failed to load image from: {comp['image']}")
        comps = curr_frame["compositions"]
        return (
            image_list,  # Displayed in the gallery
            curr_frame["description"],
            curr_frame["narration"],
            curr_frame["audio_cue_character"],
            curr_frame["audio_cue_text"],
            curr_frame["location"],
            curr_frame["frame_setting"],
            comps[0]["prompt"],
            comps[0]["seed"],
            comps[1]["prompt"],
            comps[1]["seed"],
            comps[2]["prompt"],
            comps[2]["seed"],
            comps[3]["prompt"],
            comps[3]["seed"],
        )
    except Exception as e:
        print("Error in load_data:", str(e))  # Debugging the error
        gr.Warning("Failed to load data. Check logs!")
        # Previously fell through returning None, which desynchronises the
        # Gradio output binding; return placeholders of the expected arity.
        return ([], *[""] * 14)
def update_characters(character_data: dict, current_frame: int, frame_hash_map: dict, episode_data: dict):
    """Rebuild the character checkbox group for the current frame.

    All known characters become the choices; the ones already attached to
    this frame are pre-selected.
    """
    position = frame_hash_map[current_frame]
    frame = episode_data["scenes"][position["scene"]]["frames"][position["frame"]]
    selected = [entry["name"] for entry in frame["characters"]]
    return gr.CheckboxGroup(choices=list(character_data.keys()), value=selected)
def load_data_next(
    comic_id: str,
    current_episode: int,
    current_frame: int,
    frame_hash_map: dict,
    episode_data: dict,
):
    """Advance to the next frame, rolling into the next episode at the end.

    Returns:
        (episode-dropdown update, frame-dropdown update, episode number,
        frame number, episode data, frame map), or None when no further
        episode exists.
    """
    # Off-by-one fix: the last frame key itself is a valid target, so test
    # with <=. The old "< last_key" check skipped the final frame of every
    # episode and rolled over one frame too early.
    if current_frame + 1 <= max(frame_hash_map.keys()):
        current_frame += 1
    else:
        current_episode += 1
        episode_data, frame_hash_map = load_episode_data(comic_id, current_episode)
        if len(episode_data) < 1:
            gr.Warning("All episodes finished.")
            return
        current_frame = min(frame_hash_map.keys())
    return (
        gr.update(value=current_episode),
        gr.update(choices=list(frame_hash_map.keys()), value=current_frame),
        current_episode,
        current_frame,
        episode_data,
        frame_hash_map,
    )
def load_data_prev(
    comic_id: str,
    current_episode: int,
    current_frame: int,
    frame_hash_map: dict,
    episode_data: dict,
):
    """Step back one frame, rolling into the previous episode at the start.

    Returns:
        (episode-dropdown update, frame-dropdown update, episode number,
        frame number, episode data, frame map), or None when no previous
        episode exists.
    """
    if current_frame - 1 >= min(frame_hash_map.keys()):
        current_frame -= 1
    else:
        current_episode -= 1
        episode_data, frame_hash_map = load_episode_data(comic_id, current_episode)
        if len(episode_data) < 1:
            gr.Warning("No previous episode found.")
            return
        # Stepping backwards should land on the LAST frame of the previous
        # episode (the old code jumped to its first frame), mirroring
        # load_data_next's forward roll-over.
        current_frame = max(frame_hash_map.keys())
    return (
        gr.update(value=current_episode),
        gr.update(choices=list(frame_hash_map.keys()), value=current_frame),
        current_episode,
        current_frame,
        episode_data,
        frame_hash_map,
    )
def save_image(
    selected_image: list, comic_id: str, current_episode: int, current_frame: int
):
    """Persist the user-selected gallery image for a frame to S3 as a JPEG.

    Args:
        selected_image: Gallery selection; index 0 holds the image path/file.
        comic_id: Comic root folder in the bucket.
        current_episode: Episode number the frame belongs to.
        current_frame: Global frame number; used as the file name.
    """
    with Image.open(selected_image[0]) as img:
        img_bytes = io.BytesIO()
        # Re-encode as RGB JPEG regardless of the source format/mode.
        img.convert("RGB").save(img_bytes, "JPEG")
        img_bytes.seek(0)
        # NOTE(review): this path omits the "episodes/" segment used by every
        # other S3 path in this module — confirm whether that is intentional.
        aws_utils.save_to_s3(
            AWS_BUCKET,
            f"{comic_id}/episode-{current_episode}/images",
            img_bytes,
            f"{current_frame}.jpg",
        )
    gr.Info("Saved Image successfully!")
def regenerate_compositions(
    image_description: str,
    narration: str,
    character: str,
    dialouge: str,
    location: str,
    setting: str,
    rel_chars: list,
    current_episode: int,
    current_frame: int,
    episodes_data: dict,
    frame_hash_map: dict,
    character_data: dict,
):
    """Ask the LLM to regenerate the four composition prompts for a frame.

    The previous frame (if any) is passed as context for location/setting
    continuity. The parameter name "dialouge" (sic) is kept to preserve the
    caller-facing signature.

    Returns:
        A list of four prompt strings, or four empty strings on failure.
    """
    try:
        print(
            f"Regenerating composition data for episode {current_episode}, frame {current_frame}"
        )
        scene_num, frame_num = (
            frame_hash_map[current_frame]["scene"],
            frame_hash_map[current_frame]["frame"],
        )
        prev_frame = {}
        # Off-by-one fix: the previous frame lives at index frame_num - 1,
        # valid whenever that index is >= 0. The old "> 0" test skipped the
        # previous frame when it was the scene's first frame (index 0).
        if frame_num - 1 >= 0:
            prev_frame = episodes_data["scenes"][scene_num]["frames"][frame_num - 1]
        try:
            related_chars = [character_data[char] for char in rel_chars]
            prompt_dict = {
                "system": script_gen.generate_image_compositions_instruction,
                "user": jinja2.Template(
                    script_gen.generate_image_compositions_user_prompt
                ).render(
                    {
                        "FRAME": {
                            "description": image_description,
                            "narration": narration,
                            "audio_cue_text": dialouge,
                            "audio_cue_character": character,
                            "location": location,
                            "frame_setting": setting,
                            "characters": json.dumps(related_chars),
                        },
                        "LOCATION_DESCRIPTION": prev_frame.get("location", ""),
                        "frame_settings": prev_frame.get("frame_setting", ""),
                    }
                ),
            }
            print("Generating compositions using LLM")
            compositions = llm.generate_valid_json_response(prompt_dict)
            comps = compositions["compositions"]
        except Exception as e:
            print(f"Error updating frame compositions: {e}")
            raise
        print("Composition data regenerated successfully.")
        return [comps[i]["prompt"] for i in range(4)]
    except Exception as e:
        print(f"Error in regenerate_composition_data: {e}")
        # Arity fix: the success path returns 4 prompts, so the failure path
        # must too (was [""] * 8, desynchronising the Gradio output binding).
        return [""] * 4
def regenerate_images(
    current_episode: int,
    current_frame: int,
    visual_style: str,
    height: int,
    width: int,
    character_data: dict,
    rel_chars: dict,
    prompt_1: str,
    seed_1: str,
    prompt_2: str,
    seed_2: str,
    prompt_3: str,
    seed_3: str,
    prompt_4: str,
    seed_4: str,
):
    """Generate four fresh composition images via the model server.

    Each (prompt, seed) pair is posted to the image-generation endpoint
    together with the selected characters' profile images. A composition
    that fails is skipped so the remaining ones can still render.

    Returns:
        A list of PIL images (possibly fewer than four); [] on total failure.
    """
    generated = []
    try:
        print(
            f"Regenerating data for episode {current_episode}, and frame {current_frame}"
        )
        profile_images = [character_data[ch]["profile_image"] for ch in rel_chars]
        pairs = [
            (prompt_1, seed_1),
            (prompt_2, seed_2),
            (prompt_3, seed_3),
            (prompt_4, seed_4),
        ]
        for idx, (prompt_text, seed) in enumerate(pairs):
            try:
                print(f"Generating image for composition {idx}")
                # Strip the "NOCHAR" sentinel before the prompt reaches the
                # model server.
                if "NOCHAR" in prompt_text:
                    prompt_text = prompt_text.replace("NOCHAR", "")
                payload = {
                    "prompt": prompt_text,
                    "characters": profile_images,
                    "parameters": {
                        "height": height,
                        "width": width,
                        "visual_style": visual_style,
                        "seed": seed,
                    },
                }
                data = iowrapper.get_valid_post_response(
                    url=f"{parameters.MODEL_SERVER_URL}generate_image",
                    payload=payload,
                )
                raw = base64.b64decode(data["image"])
                generated.append(Image.open(io.BytesIO(raw)))
            except Exception as e:
                print(f"Error processing composition {idx}: {e}")
                continue
        print(f"Generated new images for episode: {current_episode} and frame: {current_frame}")
        print(f"Length of image list: {len(generated)}")
        return generated
    except Exception as e:
        print(f"Error in regenerate_data: {e}")
        gr.Warning("Failed to generate new images!")
        return []
def save_comic_data(
    current_episode: int,
    current_frame: int,
    episode_data: dict,
    comic_id: str,
    image_description: str,
    narration: str,
    character: str,
    dialogue: str,
    location: str,
    setting: str,
    prompt_1: str,
    prompt_2: str,
    prompt_3: str,
    prompt_4: str,
    frame_hash_map: dict,
    rel_chars: list,
    character_data: dict,
    images: list
):
    """Persist edited frame data (text fields, prompts, images) back to S3.

    Writes each composition image as a JPEG under the frame's compositions
    folder, updates the in-memory frame dict, and saves the whole episode
    JSON back to the bucket.

    Args:
        images: Gallery items; images[i][0] holds the file for composition i.
    """
    try:
        scene_num, frame_num = (
            frame_hash_map[current_frame]["scene"],
            frame_hash_map[current_frame]["frame"],
        )
        curr_frame = episode_data["scenes"][scene_num]["frames"][frame_num]
        print(
            f"Saving comic data for episode {current_episode}, frame {frame_num}"
        )
        # Update compositions with prompts and upload the matching images.
        prompts_list = [prompt_1, prompt_2, prompt_3, prompt_4]
        for i, comp in enumerate(curr_frame["compositions"]):
            comp["prompt"] = prompts_list[i]
            with Image.open(images[i][0]) as img:
                img_bytes = io.BytesIO()
                img.convert("RGB").save(img_bytes, "JPEG")
                img_bytes.seek(0)
                aws_utils.save_to_s3(
                    parameters.AWS_BUCKET,
                    f"{comic_id}/episodes/episode-{current_episode}/compositions/scene-{scene_num}/frame-{frame_num}",
                    img_bytes,
                    f"{i}.jpg",
                )
        # Update frame data.
        curr_frame.update(
            {
                "description": image_description,
                "narration": narration,
                "audio_cue_text": dialogue,
                "location": location,
                # Key fix: the rest of the pipeline reads "frame_setting"
                # (see load_data / regenerate_compositions); writing
                # "setting" created a dead duplicate key and the edited
                # value was never read back.
                "frame_setting": setting,
                "audio_cue_character": character,
                "characters": [character_data[char] for char in rel_chars],
            }
        )
        # Save the updated episode back to S3.
        print(f"Saving updated episode {current_episode} to S3")
        aws_utils.save_to_s3(
            bucket_name=parameters.AWS_BUCKET,
            folder_name=f"{comic_id}/episodes/episode-{current_episode}",
            content=episode_data,
            file_name="episode.json",
        )
        gr.Info("Comic data saved successfully!")
    except Exception as e:
        print(f"Error in saving comic data: {e}")
        gr.Warning("Failed to save data for the comic!")