Spaces:
Sleeping
Sleeping
Manish Gupta commited on
Commit ·
0926cd3
1
Parent(s): 9d2470e
Modified some part, still WIP
Browse files- app.py +130 -186
- core.py +269 -0
- io.py +50 -0
- openai_wrapper.py +86 -0
- parameters.py +9 -0
- requirements.txt +6 -1
- script_gen.py +53 -0
app.py
CHANGED
|
@@ -1,186 +1,35 @@
|
|
| 1 |
-
|
| 2 |
-
import dataclasses
|
| 3 |
-
import io
|
| 4 |
-
from PIL import Image
|
| 5 |
import gradio as gr
|
|
|
|
|
|
|
| 6 |
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
AWS_BUCKET = os.getenv("AWS_BUCKET")
|
| 10 |
-
os.environ["AWS_ACCESS_KEY_ID"] = os.getenv("AWS_ACCESS_KEY_ID")
|
| 11 |
-
os.environ["AWS_SECRET_ACCESS_KEY"] = os.getenv("AWS_SECRET_ACCESS_KEY")
|
| 12 |
-
os.environ["S3_BUCKET_NAME"] = os.getenv("AWS_BUCKET")
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
@dataclasses.dataclass
|
| 16 |
-
class ComicFrame:
|
| 17 |
-
description: str
|
| 18 |
-
narration: str
|
| 19 |
-
character_dilouge: str
|
| 20 |
-
character: str
|
| 21 |
-
compositions: list
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
def list_current_dir(bucket_name: str, folder_path: str = "") -> list:
|
| 25 |
-
response = aws_utils.S3_CLIENT.list_objects_v2(
|
| 26 |
-
Bucket=bucket_name, Prefix=folder_path, Delimiter="/"
|
| 27 |
-
)
|
| 28 |
-
|
| 29 |
-
# Check if the bucket contains objects
|
| 30 |
-
folders = []
|
| 31 |
-
if "CommonPrefixes" in response:
|
| 32 |
-
for prefix in response["CommonPrefixes"]:
|
| 33 |
-
folders.append(prefix["Prefix"])
|
| 34 |
-
return folders
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
def load_data_inner(episodes_data: list, current_episode: int, current_frame: int):
|
| 38 |
-
images = []
|
| 39 |
-
curr_frame = episodes_data[current_episode][current_frame]
|
| 40 |
-
# Loading the 0th frame of 0th scene in 0th episode.
|
| 41 |
-
for comps in curr_frame.compositions:
|
| 42 |
-
data = aws_utils.fetch_from_s3(comps)
|
| 43 |
-
images.append(Image.open(io.BytesIO(data)))
|
| 44 |
-
|
| 45 |
-
return (
|
| 46 |
-
images,
|
| 47 |
-
episodes_data,
|
| 48 |
-
current_episode,
|
| 49 |
-
current_frame,
|
| 50 |
-
curr_frame.description,
|
| 51 |
-
curr_frame.narration,
|
| 52 |
-
curr_frame.character,
|
| 53 |
-
curr_frame.character_dilouge,
|
| 54 |
-
)
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
def load_metadata_fn(comic_id: str):
|
| 58 |
-
# Logic to load and return images based on comic_id and episode
|
| 59 |
-
# You can replace this with actual image paths or generation logic
|
| 60 |
-
print(f"Getting episodes for comic id: {comic_id}")
|
| 61 |
-
episodes_data = {}
|
| 62 |
-
episode_idx = []
|
| 63 |
-
for folder in list_current_dir(AWS_BUCKET, f"{comic_id}/episodes/"):
|
| 64 |
-
if "episode" in folder:
|
| 65 |
-
json_path = f"s3://{AWS_BUCKET}/{folder}episode.json"
|
| 66 |
-
idx = int(folder.split("/")[2].split("-")[-1])
|
| 67 |
-
episode_idx.append(idx)
|
| 68 |
-
data = eval(aws_utils.fetch_from_s3(source=json_path).decode("utf-8"))
|
| 69 |
-
comic_frames = []
|
| 70 |
-
for scene in data["scenes"]:
|
| 71 |
-
for frame in scene["frames"]:
|
| 72 |
-
comic_frames.append(
|
| 73 |
-
ComicFrame(
|
| 74 |
-
description=frame["description"],
|
| 75 |
-
narration=frame["narration"],
|
| 76 |
-
character=frame["audio_cue_character"],
|
| 77 |
-
character_dilouge=frame["audio_cue_text"],
|
| 78 |
-
compositions=[
|
| 79 |
-
comp["image"] for comp in frame["compositions"]
|
| 80 |
-
],
|
| 81 |
-
)
|
| 82 |
-
)
|
| 83 |
-
episodes_data[idx] = comic_frames
|
| 84 |
-
current_episode, current_frame = min(episode_idx), 0
|
| 85 |
-
return (
|
| 86 |
-
gr.update(choices=episode_idx, value=episode_idx[0]),
|
| 87 |
-
gr.update(
|
| 88 |
-
choices=range(len(episodes_data[current_episode])), value=current_frame
|
| 89 |
-
),
|
| 90 |
-
episodes_data
|
| 91 |
-
)
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
def load_data_next(episodes_data: list, current_episode: int, current_frame: int):
|
| 95 |
-
if current_frame + 1 < len(episodes_data[current_episode]):
|
| 96 |
-
current_frame += 1
|
| 97 |
-
elif current_episode + 1 < len(episodes_data):
|
| 98 |
-
current_episode += 1
|
| 99 |
-
current_frame = 0
|
| 100 |
-
else:
|
| 101 |
-
return [], current_episode, current_frame
|
| 102 |
-
return (
|
| 103 |
-
gr.update(value=current_episode),
|
| 104 |
-
gr.update(value=current_frame),
|
| 105 |
-
*load_data_inner(episodes_data, current_episode, current_frame),
|
| 106 |
-
)
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
def load_data_prev(episodes_data: list, current_episode: int, current_frame: int):
|
| 110 |
-
if current_frame - 1 >= 0:
|
| 111 |
-
current_frame -= 1
|
| 112 |
-
elif current_episode - 1 > min(list(episodes_data.keys())):
|
| 113 |
-
current_episode -= 1
|
| 114 |
-
current_frame = 0
|
| 115 |
-
else:
|
| 116 |
-
return [], current_episode, current_frame
|
| 117 |
-
return (
|
| 118 |
-
gr.update(value=current_episode),
|
| 119 |
-
gr.update(value=current_frame),
|
| 120 |
-
*load_data_inner(episodes_data, current_episode, current_frame),
|
| 121 |
-
)
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
def load_from_dropdown(
|
| 125 |
-
episodes_data: dict, selected_episode, selected_frame,
|
| 126 |
-
):
|
| 127 |
-
return (
|
| 128 |
-
gr.update(value=selected_episode),
|
| 129 |
-
gr.update(value=selected_frame),
|
| 130 |
-
*load_data_inner(episodes_data, selected_episode, selected_frame),
|
| 131 |
-
)
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
def load_dropdown_fn(selected_episode):
|
| 135 |
-
return (
|
| 136 |
-
gr.update(value=selected_episode),
|
| 137 |
-
gr.update(value=0),
|
| 138 |
-
selected_episode,
|
| 139 |
-
0,
|
| 140 |
-
)
|
| 141 |
-
|
| 142 |
-
def load_dropdown_fn_v2(selected_frame):
|
| 143 |
-
return selected_frame
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
def save_image(
|
| 147 |
-
selected_image,
|
| 148 |
-
comic_id: str,
|
| 149 |
-
current_episode: int,
|
| 150 |
-
current_frame: int,
|
| 151 |
-
):
|
| 152 |
-
# Implement your AWS S3 save logic here
|
| 153 |
-
# print(f"Saving image: {selected_image}")
|
| 154 |
-
with Image.open(selected_image[0]) as img:
|
| 155 |
-
# Convert and save as JPG
|
| 156 |
-
img_bytes = io.BytesIO()
|
| 157 |
-
img.convert("RGB").save(img_bytes, "JPEG")
|
| 158 |
-
img_bytes.seek(0)
|
| 159 |
-
|
| 160 |
-
aws_utils.save_to_s3(
|
| 161 |
-
AWS_BUCKET,
|
| 162 |
-
f"{comic_id}/episodes/episode-{current_episode}/images",
|
| 163 |
-
img_bytes,
|
| 164 |
-
f"{current_frame}.jpg",
|
| 165 |
-
)
|
| 166 |
-
gr.Info("Saved Image successfully!")
|
| 167 |
-
|
| 168 |
-
|
| 169 |
with gr.Blocks() as demo:
|
| 170 |
selected_image = gr.State(None)
|
| 171 |
current_episode = gr.State(-1)
|
|
|
|
| 172 |
current_frame = gr.State(-1)
|
| 173 |
episodes_data = gr.State({})
|
|
|
|
| 174 |
|
| 175 |
with gr.Row():
|
| 176 |
-
|
|
|
|
| 177 |
load_metadata = gr.Button("Load Metadata")
|
| 178 |
|
| 179 |
# Display information about current Image
|
| 180 |
with gr.Row():
|
| 181 |
-
episode_dropdown = gr.Dropdown(
|
| 182 |
-
|
| 183 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
|
| 185 |
images = gr.Gallery(
|
| 186 |
label="Select an Image", elem_id="image_select", columns=4, height=300
|
|
@@ -192,7 +41,8 @@ with gr.Blocks() as demo:
|
|
| 192 |
|
| 193 |
with gr.Row():
|
| 194 |
character = gr.Textbox(label="Character", interactive=False)
|
| 195 |
-
dialouge = gr.Textbox(label="
|
|
|
|
| 196 |
|
| 197 |
# buttons to interact with the data
|
| 198 |
with gr.Row():
|
|
@@ -200,38 +50,81 @@ with gr.Blocks() as demo:
|
|
| 200 |
save_button = gr.Button("Save Image")
|
| 201 |
next_button = gr.Button("Next Image")
|
| 202 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
load_metadata.click(
|
| 204 |
-
load_metadata_fn,
|
| 205 |
inputs=[comic_id],
|
| 206 |
outputs=[
|
| 207 |
episode_dropdown,
|
| 208 |
frame_dropdown,
|
| 209 |
-
episodes_data
|
|
|
|
| 210 |
],
|
| 211 |
)
|
| 212 |
|
| 213 |
episode_dropdown.input(
|
| 214 |
-
load_dropdown_fn,
|
| 215 |
inputs=[episode_dropdown],
|
| 216 |
outputs=[
|
| 217 |
episode_dropdown,
|
| 218 |
frame_dropdown,
|
| 219 |
current_episode,
|
| 220 |
current_frame,
|
| 221 |
-
]
|
| 222 |
)
|
| 223 |
|
| 224 |
frame_dropdown.input(
|
| 225 |
-
load_dropdown_fn_v2,
|
| 226 |
inputs=[frame_dropdown],
|
| 227 |
outputs=[
|
| 228 |
current_frame,
|
| 229 |
-
]
|
| 230 |
)
|
| 231 |
|
| 232 |
load_images.click(
|
| 233 |
-
load_from_dropdown,
|
| 234 |
-
inputs=[episodes_data,
|
| 235 |
outputs=[
|
| 236 |
episode_dropdown,
|
| 237 |
frame_dropdown,
|
|
@@ -243,7 +136,18 @@ with gr.Blocks() as demo:
|
|
| 243 |
image_description,
|
| 244 |
narration,
|
| 245 |
character,
|
| 246 |
-
dialouge
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
],
|
| 248 |
)
|
| 249 |
|
|
@@ -254,7 +158,7 @@ with gr.Blocks() as demo:
|
|
| 254 |
images.select(get_select_index, images, selected_image)
|
| 255 |
|
| 256 |
next_button.click(
|
| 257 |
-
load_data_next,
|
| 258 |
inputs=[episodes_data, current_episode, current_frame],
|
| 259 |
outputs=[
|
| 260 |
episode_dropdown,
|
|
@@ -267,12 +171,23 @@ with gr.Blocks() as demo:
|
|
| 267 |
image_description,
|
| 268 |
narration,
|
| 269 |
character,
|
| 270 |
-
dialouge
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 271 |
],
|
| 272 |
)
|
| 273 |
|
| 274 |
prev_button.click(
|
| 275 |
-
load_data_prev,
|
| 276 |
inputs=[episodes_data, current_episode, current_frame],
|
| 277 |
outputs=[
|
| 278 |
episode_dropdown,
|
|
@@ -285,12 +200,23 @@ with gr.Blocks() as demo:
|
|
| 285 |
image_description,
|
| 286 |
narration,
|
| 287 |
character,
|
| 288 |
-
dialouge
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
],
|
| 290 |
)
|
| 291 |
|
| 292 |
save_button.click(
|
| 293 |
-
save_image,
|
| 294 |
inputs=[
|
| 295 |
selected_image,
|
| 296 |
comic_id,
|
|
@@ -300,4 +226,22 @@ with gr.Blocks() as demo:
|
|
| 300 |
outputs=[],
|
| 301 |
)
|
| 302 |
|
| 303 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Driver File."""
|
|
|
|
|
|
|
|
|
|
| 2 |
import gradio as gr
|
| 3 |
+
import core
|
| 4 |
+
import parameters
|
| 5 |
|
| 6 |
+
############################################ LAYOUT ############################################
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
with gr.Blocks() as demo:
|
| 8 |
selected_image = gr.State(None)
|
| 9 |
current_episode = gr.State(-1)
|
| 10 |
+
current_scene = gr.State(-1)
|
| 11 |
current_frame = gr.State(-1)
|
| 12 |
episodes_data = gr.State({})
|
| 13 |
+
current_frame_data = gr.State(None)
|
| 14 |
|
| 15 |
with gr.Row():
|
| 16 |
+
with gr.Column():
|
| 17 |
+
comic_id = gr.Textbox(label="Enter Comic ID:", placeholder="Enter Comic ID")
|
| 18 |
load_metadata = gr.Button("Load Metadata")
|
| 19 |
|
| 20 |
# Display information about current Image
|
| 21 |
with gr.Row():
|
| 22 |
+
episode_dropdown = gr.Dropdown(
|
| 23 |
+
choices=[], label="Current Episode", interactive=True
|
| 24 |
+
)
|
| 25 |
+
frame_dropdown = gr.Dropdown(
|
| 26 |
+
choices=[], label="Current Frame", interactive=True
|
| 27 |
+
)
|
| 28 |
+
with gr.Column():
|
| 29 |
+
load_images = gr.Button("Load Images")
|
| 30 |
+
developer = gr.Checkbox(
|
| 31 |
+
value=False, label="Enable Developer Mode", visible=False
|
| 32 |
+
)
|
| 33 |
|
| 34 |
images = gr.Gallery(
|
| 35 |
label="Select an Image", elem_id="image_select", columns=4, height=300
|
|
|
|
| 41 |
|
| 42 |
with gr.Row():
|
| 43 |
character = gr.Textbox(label="Character", interactive=False)
|
| 44 |
+
dialouge = gr.Textbox(label="Dialouge", interactive=False)
|
| 45 |
+
location = gr.Textbox(label="Location", interactive=False)
|
| 46 |
|
| 47 |
# buttons to interact with the data
|
| 48 |
with gr.Row():
|
|
|
|
| 50 |
save_button = gr.Button("Save Image")
|
| 51 |
next_button = gr.Button("Next Image")
|
| 52 |
|
| 53 |
+
with gr.Column(visible=False) as developer_options:
|
| 54 |
+
with gr.Column():
|
| 55 |
+
setting = gr.Textbox(label="Frame Setting")
|
| 56 |
+
with gr.Row():
|
| 57 |
+
with gr.Column():
|
| 58 |
+
gr.Markdown("## Composition #1")
|
| 59 |
+
prompt_1 = gr.TextArea(label="Image Prompt")
|
| 60 |
+
# shot_1 = gr.Textbox(label="Shot Type")
|
| 61 |
+
seed_1 = gr.Textbox(label="Generation Seed")
|
| 62 |
+
with gr.Column():
|
| 63 |
+
gr.Markdown("## Composition #2")
|
| 64 |
+
prompt_2 = gr.TextArea(label="Image Prompt")
|
| 65 |
+
# shot_2 = gr.Textbox(label="Shot Type")
|
| 66 |
+
seed_2 = gr.Textbox(label="Generation Seed")
|
| 67 |
+
with gr.Row():
|
| 68 |
+
with gr.Column():
|
| 69 |
+
gr.Markdown("## Composition #3")
|
| 70 |
+
prompt_3 = gr.TextArea(label="Image Prompt")
|
| 71 |
+
# shot_3 = gr.Textbox(label="Shot Type")
|
| 72 |
+
seed_3 = gr.Textbox(label="Generation Seed")
|
| 73 |
+
with gr.Column():
|
| 74 |
+
gr.Markdown("## Composition #4")
|
| 75 |
+
prompt_4 = gr.TextArea(label="Image Prompt")
|
| 76 |
+
# shot_4 = gr.Textbox(label="Shot Type")
|
| 77 |
+
seed_4 = gr.Textbox(label="Generation Seed")
|
| 78 |
+
regenerate_comps_btn = gr.Button(value="Regenerate Compositions")
|
| 79 |
+
|
| 80 |
+
with gr.Column():
|
| 81 |
+
negative_prompt = gr.TextArea(value="", label="Negative Prompt", )
|
| 82 |
+
chars = gr.Textbox(value="[]", label="Related Characters")
|
| 83 |
+
with gr.Row():
|
| 84 |
+
height = gr.Textbox(value="1024", label="Image Height")
|
| 85 |
+
width = gr.Textbox(value="1024", label="Image Width")
|
| 86 |
+
visual_style = gr.Dropdown(
|
| 87 |
+
choices=parameters.VISUAL_CHOICES, label="Current Visual Style", interactive=True
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
with gr.Row():
|
| 91 |
+
regenerate_btn = gr.Button("Regenerate")
|
| 92 |
+
save_btn = gr.Button("Save")
|
| 93 |
+
|
| 94 |
+
############################################ EVENTS ############################################
|
| 95 |
load_metadata.click(
|
| 96 |
+
core.load_metadata_fn,
|
| 97 |
inputs=[comic_id],
|
| 98 |
outputs=[
|
| 99 |
episode_dropdown,
|
| 100 |
frame_dropdown,
|
| 101 |
+
episodes_data,
|
| 102 |
+
developer,
|
| 103 |
],
|
| 104 |
)
|
| 105 |
|
| 106 |
episode_dropdown.input(
|
| 107 |
+
core.load_dropdown_fn,
|
| 108 |
inputs=[episode_dropdown],
|
| 109 |
outputs=[
|
| 110 |
episode_dropdown,
|
| 111 |
frame_dropdown,
|
| 112 |
current_episode,
|
| 113 |
current_frame,
|
| 114 |
+
],
|
| 115 |
)
|
| 116 |
|
| 117 |
frame_dropdown.input(
|
| 118 |
+
core.load_dropdown_fn_v2,
|
| 119 |
inputs=[frame_dropdown],
|
| 120 |
outputs=[
|
| 121 |
current_frame,
|
| 122 |
+
],
|
| 123 |
)
|
| 124 |
|
| 125 |
load_images.click(
|
| 126 |
+
core.load_from_dropdown,
|
| 127 |
+
inputs=[episodes_data, current_episode, current_frame, developer],
|
| 128 |
outputs=[
|
| 129 |
episode_dropdown,
|
| 130 |
frame_dropdown,
|
|
|
|
| 136 |
image_description,
|
| 137 |
narration,
|
| 138 |
character,
|
| 139 |
+
dialouge,
|
| 140 |
+
location,
|
| 141 |
+
setting,
|
| 142 |
+
prompt_1,
|
| 143 |
+
seed_1,
|
| 144 |
+
prompt_2,
|
| 145 |
+
seed_2,
|
| 146 |
+
prompt_3,
|
| 147 |
+
seed_3,
|
| 148 |
+
prompt_4,
|
| 149 |
+
seed_4,
|
| 150 |
+
chars,
|
| 151 |
],
|
| 152 |
)
|
| 153 |
|
|
|
|
| 158 |
images.select(get_select_index, images, selected_image)
|
| 159 |
|
| 160 |
next_button.click(
|
| 161 |
+
core.load_data_next,
|
| 162 |
inputs=[episodes_data, current_episode, current_frame],
|
| 163 |
outputs=[
|
| 164 |
episode_dropdown,
|
|
|
|
| 171 |
image_description,
|
| 172 |
narration,
|
| 173 |
character,
|
| 174 |
+
dialouge,
|
| 175 |
+
location,
|
| 176 |
+
setting,
|
| 177 |
+
prompt_1,
|
| 178 |
+
seed_1,
|
| 179 |
+
prompt_2,
|
| 180 |
+
seed_2,
|
| 181 |
+
prompt_3,
|
| 182 |
+
seed_3,
|
| 183 |
+
prompt_4,
|
| 184 |
+
seed_4,
|
| 185 |
+
chars,
|
| 186 |
],
|
| 187 |
)
|
| 188 |
|
| 189 |
prev_button.click(
|
| 190 |
+
core.load_data_prev,
|
| 191 |
inputs=[episodes_data, current_episode, current_frame],
|
| 192 |
outputs=[
|
| 193 |
episode_dropdown,
|
|
|
|
| 200 |
image_description,
|
| 201 |
narration,
|
| 202 |
character,
|
| 203 |
+
dialouge,
|
| 204 |
+
location,
|
| 205 |
+
setting,
|
| 206 |
+
prompt_1,
|
| 207 |
+
seed_1,
|
| 208 |
+
prompt_2,
|
| 209 |
+
seed_2,
|
| 210 |
+
prompt_3,
|
| 211 |
+
seed_3,
|
| 212 |
+
prompt_4,
|
| 213 |
+
seed_4,
|
| 214 |
+
chars,
|
| 215 |
],
|
| 216 |
)
|
| 217 |
|
| 218 |
save_button.click(
|
| 219 |
+
core.save_image,
|
| 220 |
inputs=[
|
| 221 |
selected_image,
|
| 222 |
comic_id,
|
|
|
|
| 226 |
outputs=[],
|
| 227 |
)
|
| 228 |
|
| 229 |
+
developer.change(
|
| 230 |
+
core.toggle_developer_options,
|
| 231 |
+
inputs=[developer],
|
| 232 |
+
outputs=[developer_options],
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
regenerate_comps_btn.click(
|
| 236 |
+
core.regenerate_composition_data,
|
| 237 |
+
inputs=[],
|
| 238 |
+
outputs=[]
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
regenerate_btn.click(
|
| 242 |
+
core.regenerate_data,
|
| 243 |
+
inputs=[],
|
| 244 |
+
outputs=[]
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
import os

# Login credentials are read from the environment so that no secret lives in
# the repository. The password that used to be hard-coded on this line has
# been published with the source and should be rotated.
_auth_username = os.getenv("APP_AUTH_USERNAME", "admin")
_auth_password = os.getenv("APP_AUTH_PASSWORD")
if not _auth_password:
    raise RuntimeError(
        "APP_AUTH_PASSWORD is not set; refusing to launch without a login password."
    )

demo.launch(auth=(_auth_username, _auth_password), share=True, ssr_mode=False)
|
core.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""House of all specific functions used to control data flow."""
|
| 2 |
+
from typing import List
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import gradio as gr
|
| 5 |
+
import dataclasses
|
| 6 |
+
import io
|
| 7 |
+
import jinja2
|
| 8 |
+
|
| 9 |
+
import aws_utils
|
| 10 |
+
import parameters
|
| 11 |
+
import script_gen
|
| 12 |
+
import io as iowrapper
|
| 13 |
+
|
| 14 |
+
AWS_BUCKET = parameters.AWS_BUCKET
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclasses.dataclass
class Composition:
    """One candidate rendering of a comic frame.

    Instances are created from the per-composition dictionaries of an episode
    document via ``Composition(**comp)``, so the field names have to match the
    keys stored there.
    """

    prompt: str  # text shown in the "Image Prompt" box of the developer panel
    shot_type: str  # carried along; the matching UI box is commented out for now
    seed: int  # value shown in the "Generation Seed" box
    image: str  # source handed to aws_utils.fetch_from_s3 to load the picture
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclasses.dataclass
class ComicFrame:
    """Text and image metadata the review UI displays for one comic frame.

    Field order is kept as originally declared. Because ``compositions`` has a
    default, every field after it must have one too; without that the
    dataclass decorator raises ``TypeError`` ("non-default argument follows
    default argument") as soon as the module is imported.
    """

    description: str
    narration: str
    character_dilouge: str  # (sic) the spelling is part of the keyword interface
    character: str
    # Quoted forward reference so the annotation is never evaluated eagerly.
    compositions: List["Composition"] = dataclasses.field(default_factory=list)
    location: str = ""
    setting: str = ""
    # Names of the characters attached to the frame.
    all_characters: list = dataclasses.field(default_factory=list)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def list_current_dir(bucket_name: str, folder_path: str = "") -> list:
    """Return the immediate "sub-folder" prefixes found under ``folder_path``.

    Args:
        bucket_name: Bucket to list.
        folder_path: Key prefix to list under; ``""`` lists the bucket root.

    Returns:
        The ``Prefix`` value of every ``CommonPrefixes`` entry in the listing,
        or an empty list when the listing has none.
    """
    request = {"Bucket": bucket_name, "Prefix": folder_path, "Delimiter": "/"}
    folders = []
    while True:
        response = aws_utils.S3_CLIENT.list_objects_v2(**request)
        folders.extend(
            entry["Prefix"] for entry in response.get("CommonPrefixes", [])
        )
        # A single call returns one page only. Keep going while the service
        # reports more results so long listings are not silently cut short.
        # NOTE(review): assumes S3_CLIENT is a boto3 S3 client (IsTruncated /
        # NextContinuationToken semantics) - confirm in aws_utils.
        if not response.get("IsTruncated"):
            return folders
        request["ContinuationToken"] = response["NextContinuationToken"]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def load_data_inner(
    episodes_data: dict,
    current_episode: int,
    current_frame: int,
    is_developer: bool = False,
):
    """Fetch one frame's images and build the values for the frame widgets.

    Args:
        episodes_data: Mapping of episode index to its list of ``ComicFrame``.
        current_episode: Key of the episode to show.
        current_frame: Position of the frame inside that episode's list.
        is_developer: Whether the text boxes are returned as editable.
            Defaults to ``False`` so the three-argument calls made elsewhere
            in this module keep working.

    Returns:
        A 19-item tuple: images, ``episodes_data``, ``current_episode``,
        ``current_frame``, five ``gr.Textbox`` updates (description,
        narration, character, dialogue, location), the frame setting, four
        prompt/seed pairs and the list of character names.

    Raises:
        KeyError / IndexError: If the episode or frame does not exist.
    """
    curr_frame = episodes_data[current_episode][current_frame]

    images = [
        Image.open(io.BytesIO(aws_utils.fetch_from_s3(comp.image)))
        for comp in curr_frame.compositions
    ]

    # The developer panel has exactly four prompt/seed slots. Frames with
    # fewer compositions get blank slots instead of raising IndexError, and
    # extra compositions are simply not shown there.
    slot_count = 4
    prompt_seed_values = []
    for slot in range(slot_count):
        if slot < len(curr_frame.compositions):
            comp = curr_frame.compositions[slot]
            prompt_seed_values.extend((comp.prompt, comp.seed))
        else:
            prompt_seed_values.extend(("", ""))

    return (
        images,
        episodes_data,
        current_episode,
        current_frame,
        gr.Textbox(value=curr_frame.description, interactive=is_developer),
        gr.Textbox(value=curr_frame.narration, interactive=is_developer),
        gr.Textbox(value=curr_frame.character, interactive=is_developer),
        gr.Textbox(value=curr_frame.character_dilouge, interactive=is_developer),
        gr.Textbox(value=curr_frame.location, interactive=is_developer),
        curr_frame.setting,
        *prompt_seed_values,
        curr_frame.all_characters,
    )
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def load_metadata_fn(comic_id: str):
    """Load every episode document of a comic and prime the two dropdowns.

    Each folder listed under ``<comic_id>/episodes/`` is expected to hold an
    ``episode.json``; its scenes are flattened into one list of
    ``ComicFrame`` per episode.

    Args:
        comic_id: Identifier used as the top-level key prefix in the bucket.

    Returns:
        A 4-tuple: update for the episode dropdown, update for the frame
        dropdown, the ``{episode index: [ComicFrame, ...]}`` mapping, and a
        ``gr.Checkbox`` update that makes the developer toggle visible.

    Raises:
        gr.Error: If no episode folder exists for ``comic_id``.
    """
    print(f"Getting episodes for comic id: {comic_id}")
    episodes_data = {}
    for folder in list_current_dir(AWS_BUCKET, f"{comic_id}/episodes/"):
        if "episode" not in folder:
            continue
        json_path = f"s3://{AWS_BUCKET}/{folder}episode.json"
        # "<comic_id>/episodes/episode-<n>/" -> n. Assumes comic_id itself
        # contains no "/".
        idx = int(folder.split("/")[2].split("-")[-1])
        # SECURITY(review): eval() executes whatever the fetched file
        # contains. If the documents are plain JSON, json.loads (or
        # ast.literal_eval for Python-literal dumps) should replace it; left
        # unchanged here because the on-disk format is not visible.
        data = eval(aws_utils.fetch_from_s3(source=json_path).decode("utf-8"))
        comic_frames = []
        for scene in data["scenes"]:
            for frame in scene["frames"]:
                comic_frames.append(
                    ComicFrame(
                        description=frame["description"],
                        narration=frame["narration"],
                        character=frame["audio_cue_character"],
                        character_dilouge=frame["audio_cue_text"],
                        compositions=[
                            Composition(**comp) for comp in frame["compositions"]
                        ],
                        location=frame["location"],
                        setting=frame["frame_setting"],
                        all_characters=[
                            char["name"] for char in frame["characters"]
                        ],
                    )
                )
        episodes_data[idx] = comic_frames

    if not episodes_data:
        # Previously this fell through to min([]) and died with ValueError.
        raise gr.Error(f"No episodes found for comic id: {comic_id}")

    # Ascending order, and the selected value is the same episode whose
    # frames populate the frame dropdown.
    episode_idx = sorted(episodes_data)
    current_episode, current_frame = episode_idx[0], 0
    return (
        gr.update(choices=episode_idx, value=current_episode),
        gr.update(
            choices=list(range(len(episodes_data[current_episode]))),
            value=current_frame,
        ),
        episodes_data,
        gr.Checkbox(visible=True),
    )
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def load_data_next(
    episodes_data: dict,
    current_episode: int,
    current_frame: int,
    is_developer: bool = False,
):
    """Advance to the next frame, rolling over into the next episode.

    Args:
        episodes_data: Mapping of episode index to its list of frames.
        current_episode: Key of the episode currently shown.
        current_frame: Index of the frame currently shown.
        is_developer: Forwarded to ``load_data_inner``; optional so the
            existing three-input wiring keeps working.

    Returns:
        Updates for the episode and frame dropdowns followed by everything
        ``load_data_inner`` returns. The shape is the same on every path; at
        the very last frame the current frame is simply rendered again.

    Raises:
        gr.Error: If ``current_episode`` is not a loaded episode.
    """
    if current_episode not in episodes_data:
        raise gr.Error("Load metadata and select an episode first.")

    if current_frame + 1 < len(episodes_data[current_episode]):
        current_frame += 1
    else:
        # Episode keys are parsed from folder names, so they need not be
        # 0-based or contiguous; pick the smallest later key that has frames
        # rather than comparing against len(episodes_data).
        later_episodes = [
            ep for ep in episodes_data if ep > current_episode and episodes_data[ep]
        ]
        if later_episodes:
            current_episode = min(later_episodes)
            current_frame = 0
        else:
            gr.Info("Already at the last frame.")

    return (
        gr.update(value=current_episode),
        gr.update(value=current_frame),
        *load_data_inner(episodes_data, current_episode, current_frame, is_developer),
    )
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def load_data_prev(
    episodes_data: dict,
    current_episode: int,
    current_frame: int,
    is_developer: bool = False,
):
    """Step back to the previous frame, rolling into the previous episode.

    Args:
        episodes_data: Mapping of episode index to its list of frames.
        current_episode: Key of the episode currently shown.
        current_frame: Index of the frame currently shown.
        is_developer: Forwarded to ``load_data_inner``; optional so the
            existing three-input wiring keeps working.

    Returns:
        Updates for the episode and frame dropdowns followed by everything
        ``load_data_inner`` returns. The shape is the same on every path; at
        the very first frame the current frame is simply rendered again.

    Raises:
        gr.Error: If ``current_episode`` is not a loaded episode.
    """
    if current_episode not in episodes_data:
        raise gr.Error("Load metadata and select an episode first.")

    if current_frame - 1 >= 0:
        current_frame -= 1
    else:
        # The old test (current_episode - 1 > min(keys)) could never step
        # back onto the first episode. Use the largest earlier key that has
        # frames, and land on its LAST frame so "previous" mirrors "next".
        earlier_episodes = [
            ep for ep in episodes_data if ep < current_episode and episodes_data[ep]
        ]
        if earlier_episodes:
            current_episode = max(earlier_episodes)
            current_frame = len(episodes_data[current_episode]) - 1
        else:
            gr.Info("Already at the first frame.")
            # Guards against the initial state value of -1 being used as a
            # (negative) list index.
            current_frame = max(current_frame, 0)

    return (
        gr.update(value=current_episode),
        gr.update(value=current_frame),
        *load_data_inner(episodes_data, current_episode, current_frame, is_developer),
    )
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def load_from_dropdown(
    episodes_data: dict, selected_episode: int, selected_frame: int, is_developer: bool
):
    """Render the frame chosen through the episode/frame dropdowns.

    Args:
        episodes_data: Mapping of episode index to its list of frames.
        selected_episode: Episode key to show.
        selected_frame: Frame index to show.
        is_developer: Whether the text boxes should come back editable.

    Returns:
        Updates for the episode and frame dropdowns followed by everything
        ``load_data_inner`` returns.
    """
    # is_developer used to be accepted but not forwarded, which left the
    # helper one positional argument short.
    frame_values = load_data_inner(
        episodes_data, selected_episode, selected_frame, is_developer
    )
    return (
        gr.update(value=selected_episode),
        gr.update(value=selected_frame),
        *frame_values,
    )
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def load_dropdown_fn(selected_episode):
    """React to a new episode choice by resetting the frame to 0.

    Args:
        selected_episode: Value picked in the episode dropdown.

    Returns:
        A 4-tuple: update for the episode dropdown, update for the frame
        dropdown, the selected episode and the frame index ``0``.
    """
    # NOTE(review): only the frame dropdown's *value* is reset; its choices
    # are not rebuilt for the newly selected episode here.
    first_frame = 0
    return (
        gr.update(value=selected_episode),
        gr.update(value=first_frame),
        selected_episode,
        first_frame,
    )
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def load_dropdown_fn_v2(selected_frame):
    """Return the frame dropdown's selection unchanged.

    Used as the event handler that copies the dropdown value into the
    ``current_frame`` state.
    """
    return selected_frame
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def save_image(
    selected_image,
    comic_id: str,
    current_episode: int,
    current_frame: int,
):
    """Convert the selected gallery image to JPEG and upload it.

    Args:
        selected_image: Gallery selection; its first element is opened with
            ``Image.open``.
        comic_id: Comic identifier used as the leading part of the key.
        current_episode: Episode number used in the key.
        current_frame: Frame number used as the file name.

    Raises:
        gr.Error: If no image has been selected yet.
    """
    # The selection state starts out as None; indexing it would raise an
    # opaque TypeError.
    if not selected_image:
        raise gr.Error("Select an image before saving.")

    with Image.open(selected_image[0]) as img:
        # Convert and save as JPG
        img_bytes = io.BytesIO()
        img.convert("RGB").save(img_bytes, "JPEG")
    img_bytes.seek(0)

    # NOTE(review): this prefix has no "episodes/" segment, unlike the prefix
    # the metadata loader reads from ("<comic_id>/episodes/") and the earlier
    # version of this function. Left as is - confirm which layout consumers
    # expect.
    aws_utils.save_to_s3(
        AWS_BUCKET,
        f"{comic_id}/episode-{current_episode}/images",
        img_bytes,
        f"{current_frame}.jpg",
    )
    gr.Info("Saved Image successfully!")
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def toggle_developer_options(is_developer: bool):
    """Show or hide the developer options column.

    Args:
        is_developer: Current state of the "Enable Developer Mode" checkbox.

    Returns:
        A ``gr.Column`` update whose visibility follows ``is_developer``.
    """
    # The hide branch used to build the update without returning it, so the
    # column could never be hidden again once shown.
    return gr.Column(visible=bool(is_developer))
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def regenerate_composition_data(
    description: str,
    narration: str,
    character: str,
    dialouge: str,
    location: str,
    frame_setting: str,
    current_episode: int,
    current_scene: int,
    current_frame: int,
    episodes_data: dict,
):
    """Ask the language model for a fresh set of compositions for one frame.

    Args:
        description: Frame description to send to the model.
        narration: Frame narration to send to the model.
        character: Speaking character to send to the model.
        dialouge: Character dialogue to send to the model.
        location: Frame location to send to the model.
        frame_setting: Frame setting to send to the model.
        current_episode: Key of the episode being edited.
        current_scene: Scene number; only used in the log line.
        current_frame: Index of the frame being edited.
        episodes_data: Mapping of episode index to its list of frames.
    """
    # The last placeholder was an empty "{}", which is a syntax error inside
    # an f-string and prevented the whole module from importing.
    print(
        f"Generating compositions for episode: {current_episode} and scene: {current_scene} and frame: {current_frame}."
    )

    # `frame` was never assigned; look it up the same way load_data_inner does.
    frame = episodes_data[current_episode][current_frame]
    # NOTE(review): assumes the text arguments are the current (possibly
    # edited) UI values and should drive the prompt - confirm. The stored
    # frame's text fields are left untouched.
    prompt_frame = dataclasses.replace(
        frame,
        description=description,
        narration=narration,
        character=character,
        character_dilouge=dialouge,
        location=location,
        setting=frame_setting,
    )

    prompt_dict = {
        "system": script_gen.generate_image_compositions_instruction,
        "user": jinja2.Template(
            source=script_gen.generate_image_compositions_user_prompt
        ).render(
            {
                "FRAME": dataclasses.asdict(prompt_frame),
            }
        ),
    }

    # TODO(review): `llm` is not defined or imported in the visible part of
    # this module; this call raises NameError until a model client is wired in.
    compositions = llm.generate_valid_json_response(prompt_dict)
    print(compositions)
    frame.compositions = [
        Composition(**composition)
        for composition in compositions["compositions"]
    ]
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def regenerate_data(
    frame_data: ComicFrame,
):
    """Request a regenerated image from the model server and upload it.

    Work in progress: builds a generation payload, posts it to the
    ``/generate_image`` endpoint, decodes the base64 ``image`` field of the
    reply and stores the bytes through ``aws_utils.save_to_s3``.

    Args:
        frame_data: Frame to regenerate (not read by the body yet).
    """
    # base64 is not among this module's imports; import it here so the
    # decode step does not fail with NameError.
    import base64

    # TODO(review): the following names are not defined anywhere in the
    # visible code and must be supplied before this function can run:
    # composition, related_chars, visual_style, seed_val, self, episode_num,
    # scene_num, frame_num, num.
    payload = {
        "prompt": composition.prompt,
        "characters": related_chars,
        "parameters": {
            "height": parameters.IMG_HEIGHT,
            "width": parameters.IMG_WIDTH,
            "visual_style": visual_style,
            "seed": seed_val,
        },
    }

    # NOTE(review): `import io as iowrapper` most likely binds the
    # standard-library io module (already loaded at interpreter start-up)
    # rather than the project's io.py, in which case this attribute does not
    # exist - confirm; renaming the project module would remove the ambiguity.
    data = iowrapper.get_valid_post_response(
        url=parameters.MODEL_SERVER_URL + "/generate_image",
        payload=payload,
    )
    image_data = io.BytesIO(base64.b64decode(data["image"]))
    # The returned path is not used yet.
    path = aws_utils.save_to_s3(
        parameters.AWS_BUCKET,
        f"{self.id}/episodes/episode-{episode_num}/compositions/scene-{scene_num}/frame-{frame_num}",
        image_data,
        f"{num}.jpg",
    )
|
io.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any
|
| 2 |
+
from typing import Dict
|
| 3 |
+
from typing import List
|
| 4 |
+
from typing import Tuple
|
| 5 |
+
|
| 6 |
+
import parameters
|
| 7 |
+
import requests
|
| 8 |
+
|
| 9 |
+
def send_post_request(
    url: str, payload: dict, timeout: float = None
) -> Tuple[Dict[str, Any], int]:
    """
    Sends a POST request to the given URL with the provided payload.

    Args:
        url (str): The target URL.
        payload (dict): The JSON payload to send.
        timeout (float): Seconds to wait for the server before giving up.
            The default ``None`` keeps the previous behaviour of waiting
            indefinitely.

    Returns:
        Tuple[Dict[str, Any], int]: The JSON response and status code from the request.

    Raises:
        requests.RequestException: On connection problems or timeout.
        ValueError: If the response body is not valid JSON.
    """
    response = requests.post(url, json=payload, timeout=timeout)

    return response.json(), response.status_code
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_valid_post_response(url: str, payload: dict) -> Dict[str, Any]:
    """
    Retries a POST request until it answers with status 200 or the retry budget is spent.

    Args:
        url (str): The target URL.
        payload (dict): The JSON payload to send.

    Returns:
        Dict[str, Any]: The JSON response from the first attempt that returned
        status 200, or a dict with a single "error" key once
        ``parameters.MAX_TRIES`` attempts have failed.
    """
    for _ in range(parameters.MAX_TRIES):
        try:
            response, status_code = send_post_request(url, payload)
        except Exception:
            # Best-effort retry on any request/decoding failure. Unlike the
            # former bare ``except:``, this lets KeyboardInterrupt and
            # SystemExit propagate so the process can still be stopped.
            continue
        if status_code == 200:
            return response

    failure_message = f"Max retries exceeded with POST request for url: {url} with payload:\n{payload}\n"
    print(failure_message)
    return {"error": failure_message}
|
openai_wrapper.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Model wrapper to interact with OpenAI models."""
|
| 2 |
+
import abc
|
| 3 |
+
import ast
|
| 4 |
+
from typing import Mapping
|
| 5 |
+
|
| 6 |
+
import openai
|
| 7 |
+
|
| 8 |
+
import parameters
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class OpenAIModel(abc.ABC):
    """Wrapper around the OpenAI chat-completions client for one model.

    An instance holds an ``openai.OpenAI`` client and the name of the model
    that every request is sent to.
    """

    API_KEY = ""
    # TODO(Maani): Add support for more generation options like:
    # 1. temperature
    # 2. top-p
    # 3. stop sequences
    # 4. num_outputs
    # 5. response_format
    # 6. seed

    def __init__(self, model_name: str, API_KEY: str):
        """Create the OpenAI client.

        Args:
            model_name: Name of the model to send requests to.
            API_KEY: API key handed to ``openai.OpenAI``.

        Raises:
            Exception: If the client cannot be constructed; the original error
                is chained as the cause.
        """
        try:
            self.client = openai.OpenAI(api_key=API_KEY)
            self.model_name = model_name
        except Exception as exc:
            raise Exception(
                "Failed to initialize OpenAI model client. See traceback for more details.",
            ) from exc

    def prepare_input(self, prompt_dict: Mapping[str, str]) -> list:
        """Convert a ``{role: content}`` mapping into a chat message list.

        Args:
            prompt_dict: Mapping from role name to message content. Messages
                are emitted in the mapping's iteration order.

        Returns:
            A list of ``{"role": ..., "content": ...}`` dictionaries.

        Raises:
            Exception: If ``prompt_dict`` cannot be iterated as a mapping.
        """
        try:
            return [
                {"role": role, "content": content}
                for role, content in prompt_dict.items()
            ]
        except Exception as exc:
            raise Exception(
                f"Incomplete Prompt Dictionary Passed. Expected to have atleast a role and it's content.\nPassed dict: {prompt_dict}",
            ) from exc

    def generate_response(
        self,
        prompt_dict: Mapping[str, str],
        max_output_tokens: int = None,
        temperature: float = 0.6,
        response_format: dict = None,
    ) -> str:
        """Request a single chat completion and return its text.

        Args:
            prompt_dict: Mapping from role name to message content.
            max_output_tokens: Upper bound on generated tokens; a falsy value
                (``None`` or ``0``) is forwarded as ``None``.
            temperature: Sampling temperature forwarded to the API.
            response_format: Optional ``response_format`` forwarded to the API.

        Returns:
            The message content of the first returned choice.

        Raises:
            Exception: If the API call fails; the original error is chained.
        """
        conversation = self.prepare_input(prompt_dict)
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=conversation,
                max_tokens=max_output_tokens or None,
                temperature=temperature,
                response_format=response_format,
            )
            return response.choices[0].message.content
        except Exception as exc:
            raise Exception(
                f"Exception in generating model response.\nModel name: {self.model_name}\nInput prompt: {str(conversation)}",
            ) from exc

    def generate_valid_json_response(
        self,
        prompt_dict: Mapping[str, str],
        max_output_tokens: int = None,
        temperature: float = 0.6,
    ) -> dict:
        """Generate a response with retries, returning the parsed JSON.

        The model is asked for a JSON object up to ``parameters.MAX_TRIES``
        times; the first response that parses as JSON is returned.

        Args:
            prompt_dict: Mapping from role name to message content.
            max_output_tokens: Upper bound on generated tokens.
            temperature: Sampling temperature forwarded to the API.

        Returns:
            The decoded JSON value of the first parseable response.

        Raises:
            Exception: If no attempt produced parseable JSON; the error of the
                last attempt is chained as the cause.
        """
        # Imported locally so this method does not depend on the module's
        # import block.
        import json

        last_error = None
        # MAX_TRIES is sourced from the environment and may be a string.
        for _ in range(int(parameters.MAX_TRIES)):
            try:
                model_response = self.generate_response(
                    prompt_dict, max_output_tokens, temperature, {"type": "json_object"}
                )
                # json.loads, not ast.literal_eval: JSON's true/false/null are
                # not Python literals and would be rejected as invalid.
                return json.loads(model_response)
            except Exception as exc:
                last_error = exc
        raise Exception(
            f"Maximum retries met before valid JSON structure was found.\nModel name: {self.model_name}\nInput prompt: {str(prompt_dict)}"
        ) from last_error
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# Module-level OpenAIModel instance for the "gpt-4o-mini" model. It is
# constructed when this module is imported, using parameters.OPEN_AI_API_KEY.
GPT_4O_MINI = OpenAIModel("gpt-4o-mini", parameters.OPEN_AI_API_KEY)
|
parameters.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Environment-driven configuration shared across the application."""
import os

AWS_BUCKET = os.getenv("AWS_BUCKET")

# AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are consumed from the process
# environment as-is; re-assigning them to themselves is a no-op when they are
# set and raises TypeError (os.environ only stores strings) when they are not,
# so no re-export is performed here.

# Expose the bucket name under S3_BUCKET_NAME as well, but only when it is
# actually configured: os.environ rejects None values.
if AWS_BUCKET is not None:
    os.environ["S3_BUCKET_NAME"] = AWS_BUCKET

VISUAL_CHOICES = ["DARK", "FLUX", "GHIBLI_COMIC"]

# Environment values are strings; callers pass MAX_TRIES to range(), so it must
# be an int. Falls back to 3 attempts when the variable is unset or empty.
MAX_TRIES = int(os.getenv("MAX_TRIES") or "3")

OPEN_AI_API_KEY = os.getenv("OPEN_AI_API_KEY")
|
requirements.txt
CHANGED
|
@@ -1,2 +1,7 @@
|
|
| 1 |
boto3==1.35.41
|
| 2 |
-
python-dotenv==1.0.1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
boto3==1.35.41
|
| 2 |
+
python-dotenv==1.0.1
|
| 3 |
+
Jinja2==3.1.4
|
| 4 |
+
openai==1.23.6
|
| 5 |
+
dacite==1.8.1
|
| 6 |
+
colorlog==6.8.2
|
| 7 |
+
httpx==0.27.2
|
script_gen.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
generate_image_compositions_instruction = """\
|
| 2 |
+
As a visual artist and cinematographer with extensive knowledge of different camera angles, visual styles, and aesthetics, your task is to analyze the provided frame details and create four distinct compositions. Each composition should follow a specific narrative structure in its prompt construction:
|
| 3 |
+
|
| 4 |
+
1) Output Requirements in json:
|
| 5 |
+
{
|
| 6 |
+
"compositions": [
|
| 7 |
+
{
|
| 8 |
+
"prompt": "Detailed visual description following the structure below (max 77 tokens)",
|
| 9 |
+
"shot_type": "Optimal cinematographic shot"
|
| 10 |
+
}
|
| 11 |
+
]
|
| 12 |
+
}
|
| 13 |
+
|
| 14 |
+
2) Prompt Structure (in this specific order):
|
| 15 |
+
a) Begin with the environment and setting:
|
| 16 |
+
- Establish the broader landscape/location first
|
| 17 |
+
- Describe key environmental elements
|
| 18 |
+
- Set the atmospheric conditions
|
| 19 |
+
- Define the lighting and mood
|
| 20 |
+
|
| 21 |
+
b) Then layer in the scene elements:
|
| 22 |
+
- How different parts of the environment interact
|
| 23 |
+
- Spatial relationships and depth
|
| 24 |
+
- Textures and materials
|
| 25 |
+
- Any dynamic elements (movement, weather effects)
|
| 26 |
+
|
| 27 |
+
c) Finally, integrate characters (if applicable):
|
| 28 |
+
- Their position within the established environment
|
| 29 |
+
- How they interact with the space
|
| 30 |
+
- Their expressions and actions as part of the scene
|
| 31 |
+
|
| 32 |
+
3) Each composition should:
|
| 33 |
+
- Flow naturally like a single, cohesive description
|
| 34 |
+
- Prioritize environmental storytelling
|
| 35 |
+
- Build the scene progressively from background to foreground
|
| 36 |
+
- Maintain consistent atmosphere throughout
|
| 37 |
+
|
| 38 |
+
4) For NO-CHAR compositions:
|
| 39 |
+
Focus entirely on a and b of the prompt structure, with extra emphasis on:
|
| 40 |
+
- Environmental details and patterns
|
| 41 |
+
- Architectural elements
|
| 42 |
+
- Natural phenomena
|
| 43 |
+
- Atmospheric qualities
|
| 44 |
+
|
| 45 |
+
Note: Avoid jumping between environment and character descriptions. Each element should flow naturally into the next, creating a unified visual narrative.
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
generate_image_compositions_user_prompt = """\
|
| 49 |
+
Here are the details:
|
| 50 |
+
|
| 51 |
+
## Synopsis:
|
| 52 |
+
{{FRAME}}
|
| 53 |
+
"""
|