comic-grading / core.py
umang-immersfy's picture
comic grading updated code
c1254e9
raw
history blame
17.5 kB
from typing import List
from PIL import Image
import gradio as gr
import dataclasses
import io
import jinja2
import base64
import aws_utils
import parameters
import script_gen
import inout as iowrapper
import openai_wrapper
import json
from dataclasses import asdict
# S3 bucket this module lists episodes from and saves images/episode files to
# (value supplied by the project's `parameters` module).
AWS_BUCKET = parameters.AWS_BUCKET
# LLM client used by regenerate_composition_data() to generate new composition prompts.
llm = openai_wrapper.GPT_4O_MINI
@dataclasses.dataclass
class Composition:
    """One candidate shot for a comic frame, built from an entry of a frame's
    "compositions" list in episode.json."""

    # Image-generation prompt text for this shot.
    prompt: str
    # Shot type label as stored in the episode JSON / produced by the LLM.
    shot_type: str
    # Seed sent to the image model in the generate_image payload.
    # NOTE(review): regenerate_composition_data stores "" here for compositions
    # that had no predecessor, so this is not always an int.
    seed: int
    # Path/URI of the rendered image, passed to aws_utils.fetch_from_s3 when loading.
    image: str
@dataclasses.dataclass
class ComicFrame:
    """In-memory view of one frame of an episode, loaded from episode.json."""

    # Frame description text (episode.json key "description").
    description: str
    # Narration text (key "narration").
    narration: str
    # Dialogue line (key "audio_cue_text").
    character_dilouge: str
    # Speaking character (key "audio_cue_character").
    character: str
    # Location text (key "location").
    location: str
    # Frame setting (key "frame_setting").
    setting: str
    # Names of all characters in the frame; used as keys into the
    # name -> profile_image mapping when requesting images.
    all_characters: list
    # Candidate compositions; the UI functions read the first four.
    compositions: List[Composition] = dataclasses.field(default_factory=list)
def list_current_dir(bucket_name: str, folder_path: str = "", client=None) -> list:
    """Return the immediate sub-"folders" (common prefixes) under ``folder_path``.

    S3 has no real directories; listing with ``Delimiter="/"`` groups keys into
    ``CommonPrefixes``. ``list_objects_v2`` is paginated, so continuation tokens
    are followed until the listing is complete.

    Args:
        bucket_name: Name of the S3 bucket to list.
        folder_path: Key prefix to list under (e.g. ``"<comic_id>/episodes/"``).
        client: Optional object exposing ``list_objects_v2``; defaults to
            ``aws_utils.S3_CLIENT``.

    Returns:
        The prefix strings in the order S3 returns them, or an empty list if
        the listing fails (the failure is logged).
    """
    request = {"Bucket": bucket_name, "Prefix": folder_path, "Delimiter": "/"}
    folders = []
    try:
        s3 = aws_utils.S3_CLIENT if client is None else client
        while True:
            response = s3.list_objects_v2(**request)
            folders.extend(entry["Prefix"] for entry in response.get("CommonPrefixes", []))
            # Each call returns at most one page (1000 entries by default).
            if not response.get("IsTruncated"):
                break
            request["ContinuationToken"] = response["NextContinuationToken"]
    except Exception as e:
        # Best-effort: callers treat a failed listing as "no folders".
        print(f"Error listing s3://{bucket_name}/{folder_path}: {e}")
        return []
    return folders
def load_data_inner(
    episodes_data: dict, current_episode: int, current_frame: int, is_developer: bool
):
    """Load one frame's images and field values for display.

    Args:
        episodes_data: Mapping of episode index -> list of ComicFrame.
        current_episode: Episode key into ``episodes_data``.
        current_frame: Frame index within that episode.
        is_developer: Whether the text fields should be editable.

    Returns:
        A 19-tuple: (images, episodes_data, current_episode, current_frame,
        description/narration/character/dialogue/location textboxes, setting,
        prompt_1, seed_1, ..., prompt_4, seed_4, all_characters). Prompt/seed
        slots without a composition are "". If the frame cannot be loaded, the
        same shape is returned with blank values.
    """
    try:
        curr_frame = episodes_data[current_episode][current_frame]
        images = []
        for comp in curr_frame.compositions:
            if not comp.image:
                # Compositions added by regeneration may not have an image yet.
                continue
            try:
                data = aws_utils.fetch_from_s3(comp.image)
                images.append(Image.open(io.BytesIO(data)))
            except Exception as e:
                # One bad image should not hide the rest of the frame.
                print(f"Error loading image {comp.image}: {e}")
        # The UI always has four prompt/seed slots; pad with blanks instead of
        # failing the whole load when the frame has fewer compositions.
        prompt_seed_values = []
        for idx in range(4):
            if idx < len(curr_frame.compositions):
                comp = curr_frame.compositions[idx]
                prompt_seed_values.extend([comp.prompt, comp.seed])
            else:
                prompt_seed_values.extend(["", ""])
        return (
            images,
            episodes_data,
            current_episode,
            current_frame,
            gr.Textbox(value=curr_frame.description, interactive=is_developer),
            gr.Textbox(value=curr_frame.narration, interactive=is_developer),
            gr.Textbox(value=curr_frame.character, interactive=is_developer),
            gr.Textbox(value=curr_frame.character_dilouge, interactive=is_developer),
            gr.Textbox(value=curr_frame.location, interactive=is_developer),
            curr_frame.setting,
            *prompt_seed_values,
            curr_frame.all_characters,
        )
    except Exception as e:
        print(f"Error in load_data_inner: {e}")
        return (
            [],
            episodes_data,
            current_episode,
            current_frame,
            gr.Textbox(),
            gr.Textbox(),
            gr.Textbox(),
            gr.Textbox(),
            gr.Textbox(),
            "",
            "",
            "",
            "",
            "",
            "",
            "",
            "",
            "",
            [],
        )
def _frame_from_json(frame: dict) -> "ComicFrame":
    """Convert one ``frames`` entry of an episode.json document into a ComicFrame."""
    return ComicFrame(
        description=frame["description"],
        narration=frame["narration"],
        character=frame["audio_cue_character"],
        character_dilouge=frame["audio_cue_text"],
        compositions=[Composition(**comp) for comp in frame["compositions"]],
        location=frame["location"],
        setting=frame["frame_setting"],
        all_characters=[char["name"] for char in frame["characters"]],
    )


def load_metadata_fn(comic_id: str):
    """Load character and episode metadata for ``comic_id`` from S3.

    Reads ``characters.json`` (name -> profile_image) and every
    ``episodes/episode-<N>/episode.json`` found under the comic's prefix.

    Returns:
        An 8-tuple: (episode dropdown update, frame dropdown update,
        current_episode, current_frame, episodes_data, character_data, details,
        developer checkbox). ``episodes_data`` maps episode index -> list of
        ComicFrame; ``details`` maps episode index -> {scene index: cumulative
        frame count at the end of that scene}. On failure the same 8-tuple shape
        is returned with empty values and a hidden checkbox.
    """
    try:
        episodes_data = {}
        episode_idx = []
        character_data = {}
        details = {}
        # NOTE(review): characters come from a hard-coded bucket while episodes
        # use AWS_BUCKET; confirm both name the same bucket.
        character_path = f"s3://blix-demo-v0/{comic_id}/characters/characters.json"
        # json.loads instead of eval: eval executes arbitrary code from S3 and
        # rejects JSON literals such as true/false/null (episode.json is written
        # back with json.dumps by save_image_compositions).
        char_data = json.loads(
            aws_utils.fetch_from_s3(source=character_path).decode("utf-8")
        )
        for name, char in char_data.items():
            character_data[name] = char["profile_image"]
        for folder in list_current_dir(AWS_BUCKET, f"{comic_id}/episodes/"):
            # folder looks like "<comic_id>/episodes/episode-<N>/". Every folder
            # contains "episodes/", so match on the last path segment instead.
            folder_name = folder.rstrip("/").split("/")[-1]
            number = folder_name[len("episode-"):]
            if not folder_name.startswith("episode-") or not number.isdigit():
                continue
            idx = int(number)
            json_path = f"s3://{AWS_BUCKET}/{folder}episode.json"
            data = json.loads(aws_utils.fetch_from_s3(source=json_path).decode("utf-8"))
            episode_idx.append(idx)
            comic_frames = []
            details[idx] = {}
            cumulative_frame_count = 0
            for scene_num, scene in enumerate(data["scenes"]):
                cumulative_frame_count += len(scene["frames"])
                details[idx][scene_num] = cumulative_frame_count
                comic_frames.extend(_frame_from_json(frame) for frame in scene["frames"])
            episodes_data[idx] = comic_frames
        if not episode_idx:
            raise ValueError(f"No episodes found for comic {comic_id}")
        # S3 lists prefixes lexicographically ("episode-10" before "episode-2"),
        # so sort numerically and select the same episode that is loaded.
        episode_idx.sort()
        current_episode, current_frame = episode_idx[0], 0
        return (
            gr.update(choices=episode_idx, value=current_episode),
            gr.update(
                choices=list(range(len(episodes_data[current_episode]))),
                value=current_frame,
            ),
            current_episode,
            current_frame,
            episodes_data,
            character_data,
            details,
            gr.Checkbox(visible=True),
        )
    except Exception as e:
        print(f"Error in load_metadata_fn: {e}")
        # Must match the 8 outputs of the success path.
        return (
            gr.update(choices=[]),
            gr.update(choices=[]),
            0,
            0,
            {},
            {},
            {},
            gr.Checkbox(visible=False),
        )
def load_data_next(
    episodes_data: dict, current_episode: int, current_frame: int, is_developer: bool
):
    """Advance to the next frame, crossing into the next episode when needed.

    Episodes are ordered by their numeric keys in ``episodes_data``; episodes
    with no frames are skipped. At the very last frame the current frame is
    reloaded and the user is notified.

    Returns:
        The episode/frame selector updates followed by the 19 values of
        ``load_data_inner``. The shape is always the same so Gradio can map
        every output.
    """
    frames = episodes_data.get(current_episode, [])
    if current_frame + 1 < len(frames):
        current_frame += 1
    else:
        # Keys need not be 0-based or contiguous, so pick the next key in
        # sorted order rather than comparing against len(episodes_data).
        later = [ep for ep in sorted(episodes_data) if ep > current_episode and episodes_data[ep]]
        if later:
            current_episode, current_frame = later[0], 0
        else:
            gr.Info("Already at the last frame.")
    return (
        gr.update(value=current_episode),
        gr.update(value=current_frame),
        *load_data_inner(episodes_data, current_episode, current_frame, is_developer),
    )
def load_data_prev(
    episodes_data: dict, current_episode: int, current_frame: int, is_developer: bool
):
    """Step back one frame, crossing into the previous episode when needed.

    From an episode's first frame this moves to the last frame of the nearest
    earlier non-empty episode. At the very first frame the current frame is
    reloaded and the user is notified.

    Returns:
        The episode/frame selector updates followed by the 19 values of
        ``load_data_inner``. The shape is always the same so Gradio can map
        every output.
    """
    if current_frame > 0:
        current_frame -= 1
    else:
        # Includes the smallest episode key (the old "> min(keys)" check
        # never allowed stepping back into the first episode).
        earlier = [ep for ep in sorted(episodes_data) if ep < current_episode and episodes_data[ep]]
        if earlier:
            current_episode = earlier[-1]
            current_frame = len(episodes_data[current_episode]) - 1
        else:
            gr.Info("Already at the first frame.")
    return (
        gr.update(value=current_episode),
        gr.update(value=current_frame),
        *load_data_inner(episodes_data, current_episode, current_frame, is_developer),
    )
def load_from_dropdown(
    episodes_data: dict, selected_episode: int, selected_frame: int, is_developer: bool
):
    """Jump directly to the episode/frame chosen by the user.

    Returns the episode and frame selector updates followed by everything
    ``load_data_inner`` produces for that frame.
    """
    episode_update = gr.update(value=selected_episode)
    frame_update = gr.update(value=selected_frame)
    frame_outputs = load_data_inner(
        episodes_data, selected_episode, selected_frame, is_developer
    )
    return (episode_update, frame_update) + tuple(frame_outputs)
def load_dropdown_fn(selected_episode):
    """Reset the view to frame 0 of a newly selected episode.

    Returns two UI updates (episode value, frame value 0) followed by the new
    episode and frame state values.
    """
    first_frame = 0
    episode_update = gr.update(value=selected_episode)
    frame_update = gr.update(value=first_frame)
    return episode_update, frame_update, selected_episode, first_frame
def load_dropdown_fn_v2(selected_frame):
    """Pass the selected frame value through unchanged."""
    chosen = selected_frame
    return chosen
def save_image(selected_image, comic_id: str, current_episode: int, current_frame: int):
    """Upload the selected image to S3 as a JPEG.

    ``selected_image[0]`` is opened with PIL, converted to RGB, and uploaded to
    AWS_BUCKET under ``<comic_id>/episode-<episode>/images/<frame>.jpg``.
    NOTE(review): this path has no "episodes/" segment, unlike the other S3
    paths in this module; confirm it is intended.
    """
    source = selected_image[0]
    destination_folder = f"{comic_id}/episode-{current_episode}/images"
    destination_name = f"{current_frame}.jpg"
    with Image.open(source) as picture:
        jpeg_buffer = io.BytesIO()
        rgb_picture = picture.convert("RGB")
        rgb_picture.save(jpeg_buffer, "JPEG")
        jpeg_buffer.seek(0)
        aws_utils.save_to_s3(
            AWS_BUCKET, destination_folder, jpeg_buffer, destination_name
        )
    gr.Info("Saved Image successfully!")
def toggle_developer_options(
    is_developer: bool, prompt_1, prompt_2, prompt_3, prompt_4, setting
):
    """Show or hide the developer options and echo the current values back.

    Only the first output (a visibility update) depends on ``is_developer``;
    the four prompts and the setting are returned unchanged.
    """
    visibility = gr.update(visible=bool(is_developer))
    return (visibility, prompt_1, prompt_2, prompt_3, prompt_4, setting)
def regenerate_composition_data(
    image_description,
    narration,
    character,
    dialouge,
    location,
    setting,
    chars,
    current_episode: int,
    current_frame: int,
    episodes_data: dict,
):
    """Ask the LLM for fresh composition prompts for one frame.

    The edited frame fields are rendered into
    ``script_gen.generate_image_compositions_user_prompt`` and sent, with
    ``script_gen.generate_image_compositions_instruction`` as the system
    prompt, to ``llm.generate_valid_json_response``. The returned
    ``"compositions"`` replace ``frame.compositions`` in ``episodes_data`` (in
    memory only). Seed and image of the composition previously at the same
    position are carried over.

    Returns:
        Eight values ``[prompt_1, seed_1, ..., prompt_4, seed_4]``. Slots
        without a composition are ``""``. On any error all eight are ``""`` and
        the frame's compositions are left unchanged.
    """
    try:
        print(
            f"Regenerating composition data for episode {current_episode}, frame {current_frame}"
        )
        frame = episodes_data[current_episode][current_frame]
        prompt_template = jinja2.Template(
            script_gen.generate_image_compositions_user_prompt
        )
        prompt_dict = {
            "system": script_gen.generate_image_compositions_instruction,
            "user": prompt_template.render(
                {
                    "FRAME": {
                        "description": image_description,
                        "narration": narration,
                        "character_dilouge": dialouge,
                        "character": character,
                        "location": location,
                        "setting": setting,
                        "all_characters": chars,
                    }
                }
            ),
        }
        print("Generating compositions using LLM")
        generated = llm.generate_valid_json_response(prompt_dict)["compositions"]
        if not generated:
            # Don't wipe the frame's compositions on an empty answer.
            raise ValueError("LLM returned no compositions")
        previous = frame.compositions
        # Only prompt/shot_type come from the LLM. Passing the whole dict with
        # **comp crashed if it also carried seed/image or any extra key.
        new_compositions = [
            Composition(
                prompt=comp["prompt"],
                shot_type=comp["shot_type"],
                seed=previous[idx].seed if idx < len(previous) else "",
                image=previous[idx].image if idx < len(previous) else "",
            )
            for idx, comp in enumerate(generated)
        ]
        frame.compositions = new_compositions
        print("Composition data regenerated successfully.")
        # The UI has four prompt/seed slots; pad instead of raising IndexError
        # after the frame was already updated.
        values = []
        for idx in range(4):
            if idx < len(new_compositions):
                values.extend([new_compositions[idx].prompt, new_compositions[idx].seed])
            else:
                values.extend(["", ""])
        return values
    except Exception as e:
        print(f"Error in regenerate_composition_data: {e}")
        return [""] * 8
def regenerate_data(
    comic_id,
    current_episode,
    current_scene,
    current_frame,
    episodes_data,
    character_data,
    visual_style,
    height,
    width,
):
    """Render a new image for every composition of the current frame.

    For each composition, a request with its prompt and seed, the profile
    images of the frame's characters, and the size/style parameters is posted
    to ``<MODEL_SERVER_URL>generate_image``. The base64 ``"image"`` field of the
    response is then decoded. ``comic_id`` is not used; ``current_scene`` is
    only logged.

    Returns:
        ``(images, image_data_b64)``: PIL images for display and, index-aligned
        with them, ``io.BytesIO`` buffers of the decoded image bytes for upload.
        Despite the name, these buffers hold decoded bytes, not base64 text.
        Compositions whose request fails are skipped. If setup fails (e.g. a
        character missing from ``character_data``), both lists are empty.
    """
    images = []
    image_data_b64 = []
    try:
        print(
            f"Regenerating data for episode {current_episode}, scene {current_scene}, frame {current_frame}"
        )
        frame = episodes_data[current_episode][current_frame]
        related_chars = [character_data[ch] for ch in frame.all_characters]
        for i, composition in enumerate(frame.compositions):
            print(f"Generating image for composition {i}")
            payload = {
                "prompt": composition.prompt,
                "characters": related_chars,
                "parameters": {
                    "height": height,
                    "width": width,
                    "visual_style": visual_style,
                    "seed": composition.seed,
                },
            }
            try:
                print(f"Sending request to generate image for composition {i}")
                data = iowrapper.get_valid_post_response(
                    url=f"{parameters.MODEL_SERVER_URL}generate_image",
                    payload=payload,
                )
                print(f"Image generated for composition {i}. Decoding image data.")
                raw_bytes = base64.b64decode(data["image"])
                # Give PIL and the uploader separate buffers. Image.open reads
                # from its file object and keeps it for lazy loading, which left
                # the shared buffer positioned past the start for the S3 upload.
                image = Image.open(io.BytesIO(raw_bytes))
                # Append both only after a successful open so the lists stay aligned.
                image_data_b64.append(io.BytesIO(raw_bytes))
                images.append(image)
            except Exception as e:
                print(f"Error generating image for composition {i}: {e}")
                continue
        print("Data regeneration completed.")
        return images, image_data_b64
    except Exception as e:
        print(f"Error in regenerate_data: {e}")
        return [], []
def save_image_compositions(
    current_episode: int,
    current_frame: int,
    details: dict,
    comic_id: str,
    image_description,
    narration,
    character,
    dialogue,
    location,
    setting,
    chars,
    prompt_1,
    prompt_2,
    prompt_3,
    prompt_4,
):
    """Write the edited frame fields and prompts back into episode.json on S3.

    ``current_frame`` is an episode-wide frame index. ``details[current_episode]``
    (scene index -> cumulative frame count at the end of that scene) is used to
    locate the scene and the frame within it. ``chars`` is accepted but not
    saved.

    Returns:
        The scene index that was updated, or ``None`` if the episode/scene was
        not found or saving failed (the error is logged).
    """
    try:
        print(
            f"Saving image components for episode {current_episode}, frame {current_frame}"
        )
        episode_details = details.get(current_episode)
        if not episode_details:
            print(f"Episode {current_episode} not found!")
            return None
        # The first scene whose cumulative end exceeds current_frame contains
        # it; the previous scene's end is that scene's first global index.
        scene_num, frame_num_in_scene = None, 0
        scene_start = 0
        for candidate_scene, scene_end in episode_details.items():
            if current_frame < scene_end:
                scene_num = candidate_scene
                frame_num_in_scene = current_frame - scene_start
                break
            scene_start = scene_end
        if scene_num is None:
            print(f"Scene not found for frame {current_frame}.")
            return None
        # Read from the same bucket the file is written back to below (and
        # that episodes are listed from). Reading a hard-coded bucket made
        # edits invisible whenever the two differed.
        episode_path = (
            f"s3://{parameters.AWS_BUCKET}/{comic_id}/episodes/"
            f"episode-{current_episode}/episode.json"
        )
        print(f"Fetching episode from S3: {episode_path}")
        episode = json.loads(aws_utils.fetch_from_s3(episode_path).decode("utf-8"))
        frame_data = episode["scenes"][scene_num]["frames"][frame_num_in_scene]
        print(
            f"Updating compositions for scene {scene_num}, frame {frame_num_in_scene}"
        )
        prompts_list = [prompt_1, prompt_2, prompt_3, prompt_4]
        # Only prompts are edited: keep every other key of each composition,
        # and leave compositions beyond the four UI slots untouched.
        frame_data["compositions"] = [
            {**comp, "prompt": prompts_list[i]} if i < len(prompts_list) else comp
            for i, comp in enumerate(frame_data["compositions"])
        ]
        frame_data.update(
            {
                "description": image_description,
                "narration": narration,
                "audio_cue_text": dialogue,
                "location": location,
                # Frames are loaded from "frame_setting". Writing "setting"
                # meant the edited value was never read back.
                "frame_setting": setting,
                "audio_cue_character": character,
            }
        )
        print(f"Saving updated episode to S3 at {episode_path}")
        aws_utils.save_to_s3(
            bucket_name=parameters.AWS_BUCKET,
            folder_name=f"{comic_id}/episodes/episode-{current_episode}",
            content=json.dumps(episode),
            file_name="episode.json",
        )
        gr.Info("Components saved successfully!")
        return scene_num
    except Exception as e:
        print(f"Error in save_image_compositions: {e}")
        return None
def save_images(
    image_data_b64,
    current_episode,
    current_frame,
    current_scene,
    comic_id,
):
    """Upload regenerated composition images to S3.

    Image ``i`` is stored in ``parameters.AWS_BUCKET`` as
    ``<comic_id>/episodes/episode-<episode>/compositions/scene-<scene>/frame-<frame>/<i>.jpg``.
    Failures are logged per image. The user gets a success message only when
    every upload succeeded, otherwise a warning listing the failed indices.
    """
    try:
        print(
            f"Saving images for scene {current_scene}, episode {current_episode}, frame {current_frame}."
        )
        folder = (
            f"{comic_id}/episodes/episode-{current_episode}/compositions/"
            f"scene-{current_scene}/frame-{current_frame}"
        )
        failed = []
        for i, image_data in enumerate(image_data_b64):
            try:
                print(f"Saving image {i} to S3")
                # Rewind first: if this buffer was uploaded before (a repeated
                # save), its position is presumably at the end of the data.
                if hasattr(image_data, "seek"):
                    image_data.seek(0)
                aws_utils.save_to_s3(
                    parameters.AWS_BUCKET,
                    folder,
                    image_data,
                    f"{i}.jpg",
                )
            except Exception as e:
                print(f"Error saving image {i} to S3: {e}")
                failed.append(i)
        if failed:
            gr.Warning(f"Failed to save image(s) {failed}; see logs for details.")
        else:
            gr.Info("All Images saved successfully!")
    except Exception as e:
        print(f"Error in save_images: {e}")