Manish Gupta
Path changes.
323c22b
import ast
import io
import json
import os

import gradio as gr
from PIL import Image

import aws_utils
# S3 bucket that holds every comic's assets, taken from the environment.
AWS_BUCKET = os.getenv("AWS_BUCKET")

# All three settings must be present. Assigning a missing (None) value into
# os.environ raises an opaque TypeError, so report every missing name up front.
_REQUIRED_ENV_VARS = ("AWS_BUCKET", "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY")
_missing_env_vars = [var for var in _REQUIRED_ENV_VARS if os.getenv(var) is None]
if _missing_env_vars:
    raise RuntimeError(
        "Missing required environment variable(s): " + ", ".join(_missing_env_vars)
    )

# Mirror the bucket name under S3_BUCKET_NAME. Presumably this is the variable
# aws_utils reads (TODO confirm; aws_utils is not visible here).
os.environ["S3_BUCKET_NAME"] = AWS_BUCKET
def load_text_data(characters: list, current_index: int):
    """Return the character list and index, followed by that character's text fields.

    Args:
        characters: Character profiles; each must provide the keys "name",
            "age", "gender" and "description".
        current_index: Index of the profile to read (normal list indexing).

    Returns:
        ``(characters, current_index, name, age, gender, description)``, the
        order used after the gallery in the UI handlers' output lists.
    """
    profile = characters[current_index]
    text_fields = tuple(
        profile[field] for field in ("name", "age", "gender", "description")
    )
    return (characters, current_index) + text_fields
def load_data(comic_id: str, character_data: list, current_index: int):
    """Advance to the next character and load its candidate images.

    This is the handler for the "Next Image" button. When a following character
    exists, the index is incremented and that character's compositions
    ``1.jpg``..``4.jpg`` are fetched from
    ``s3://{AWS_BUCKET}/{comic_id}/characters/compositions/{name}/``.
    At the last character the index is left unchanged and an empty image list
    is returned for the gallery.

    Args:
        comic_id: Identifier of the comic whose assets are under
            ``s3://{AWS_BUCKET}/{comic_id}/``.
        character_data: Character profiles previously loaded by ``load_data_once``.
        current_index: Index of the character currently displayed.

    Returns:
        ``(images, character_data, current_index, name, age, gender, description)``.
        These are the gallery images followed by the tuple from ``load_text_data``.

    Raises:
        gr.Error: If no characters have been loaded yet.
    """
    # Clicking "Next" before "Load Data" leaves the list empty. Indexing it
    # would otherwise surface as a bare IndexError in the UI.
    if not character_data:
        raise gr.Error(
            "No characters loaded. Enter a Comic ID and click 'Load Data' first."
        )

    if current_index >= len(character_data) - 1:
        # Already at the last character: keep its text, return no images.
        return [], *load_text_data(character_data, current_index)

    current_index += 1
    name = character_data[current_index]["name"]
    images = []
    for idx in range(1, 5):
        url = f"s3://{AWS_BUCKET}/{comic_id}/characters/compositions/{name}/{idx}.jpg"
        image_bytes = aws_utils.fetch_from_s3(url)
        images.append(Image.open(io.BytesIO(image_bytes)))
    return images, *load_text_data(character_data, current_index)
def load_data_once(comic_id: str, current_index: int):
    """Fetch a comic's character profiles from S3 and display one character.

    Reads ``{comic_id}/characters/characters.json`` from ``AWS_BUCKET``. The
    file's top-level values, in file order, become the character list. The
    four candidate compositions ``1.jpg``..``4.jpg`` are then loaded for the
    character at ``current_index``.

    Args:
        comic_id: Identifier of the comic whose assets are under
            ``s3://{AWS_BUCKET}/{comic_id}/``.
        current_index: Index of the character to display. Falls back to 0
            when it is out of range for this comic, e.g. left over from
            browsing a comic with more characters.

    Returns:
        ``(images, characters, current_index, name, age, gender, description)``.
        These are the gallery images followed by the tuple from ``load_text_data``.

    Raises:
        gr.Error: If ``comic_id`` is blank or the comic lists no characters.
    """
    if not comic_id or not comic_id.strip():
        raise gr.Error("Please enter a Comic ID.")
    print(f"Getting characters for comic id: {comic_id}")

    raw_profiles = aws_utils.fetch_from_s3(
        source=f"s3://{AWS_BUCKET}/{comic_id}/characters/characters.json"
    ).decode("utf-8")
    # The content comes from S3, so never eval() it: eval would execute any
    # expression stored in the file. Parse it as JSON. Fall back to
    # ast.literal_eval (literals only, no code execution) in case the file was
    # written as a Python dict repr. TODO confirm the writer's format.
    try:
        profiles = json.loads(raw_profiles)
    except json.JSONDecodeError:
        profiles = ast.literal_eval(raw_profiles)
    characters = list(profiles.values())

    if not characters:
        raise gr.Error(f"No characters found for comic id: {comic_id}")
    if not 0 <= current_index < len(characters):
        current_index = 0

    name = characters[current_index]["name"]
    images = []
    for idx in range(1, 5):
        url = f"s3://{AWS_BUCKET}/{comic_id}/characters/compositions/{name}/{idx}.jpg"
        image_bytes = aws_utils.fetch_from_s3(url)
        images.append(Image.open(io.BytesIO(image_bytes)))
    return images, *load_text_data(characters, current_index)
def save_image(
    selected_image,
    comic_id: str,
    character_data: list,
    current_index: int,
):
    """Upload the selected gallery image to S3 as the current character's image.

    The image is re-encoded as RGB JPEG and stored as ``{name}.jpg`` under
    ``{comic_id}/characters/images`` in ``AWS_BUCKET`` via ``aws_utils.save_to_s3``.

    Args:
        selected_image: The gallery item the user clicked, or ``None`` if
            nothing is selected. Its first element is passed to ``Image.open``;
            presumably this is the file path Gradio keeps for the gallery
            entry (TODO confirm).
        comic_id: Identifier of the comic the image belongs to.
        character_data: Character profiles of the loaded comic.
        current_index: Index of the character the image is saved for.

    Raises:
        gr.Error: If no image is selected or no characters are loaded.
    """
    # Without these guards the handler fails with a TypeError on None[0]
    # or an IndexError on an empty list.
    if not selected_image:
        raise gr.Error("Please select an image before saving.")
    if not character_data:
        raise gr.Error(
            "No characters loaded. Enter a Comic ID and click 'Load Data' first."
        )

    print(f"Saving image: {selected_image}")
    name = character_data[current_index]["name"]

    # Convert to RGB so images with alpha or palette modes can be JPEG-encoded.
    img_bytes = io.BytesIO()
    with Image.open(selected_image[0]) as img:
        img.convert("RGB").save(img_bytes, "JPEG")
    img_bytes.seek(0)

    aws_utils.save_to_s3(
        AWS_BUCKET,
        f"{comic_id}/characters/images",
        img_bytes,
        f"{name}.jpg",
    )
    print("Image saved successfully!")
    gr.Info("Saved Image successfully!")
with gr.Blocks() as demo:
    # Gallery item the user last clicked; None until something is selected.
    selected_image = gr.State()
    # Position of the displayed character within character_data.
    current_index = gr.State(0)
    # Character profiles of the loaded comic.
    character_data = gr.State([])

    with gr.Row():
        comic_id = gr.Textbox(label="Enter Comic ID:", placeholder="Enter Comic ID")
        load_button = gr.Button("Load Data")

    images = gr.Gallery(
        label="Select an Image", elem_id="image_select", columns=4, height=300
    )

    # Display information about current Character
    with gr.Row():
        name = gr.Textbox(label="Name", interactive=False)
        age = gr.Textbox(label="Age", interactive=False)
        gender = gr.Textbox(label="Gender", interactive=False)
        description = gr.Textbox(label="description", interactive=False)

    # buttons to interact with the data
    with gr.Row():
        save_button = gr.Button("Save Image")
        next_button = gr.Button("Next Image")

    character_outputs = [
        images,
        character_data,
        current_index,
        name,
        age,
        gender,
        description,
    ]

    # After (re)loading or advancing, drop the previous selection. Otherwise an
    # image picked for one character could be saved under the next one's name.
    load_button.click(
        load_data_once,
        inputs=[comic_id, current_index],
        outputs=character_outputs,
    ).then(lambda: None, outputs=selected_image)

    # When an image is clicked, remember which gallery item was chosen.
    def get_select_index(evt: gr.SelectData, gallery_items):
        """Return the gallery item at the clicked position ``evt.index``."""
        return gallery_items[evt.index]

    images.select(get_select_index, images, selected_image)

    save_button.click(
        save_image,
        inputs=[
            selected_image,
            comic_id,
            character_data,
            current_index,
        ],
        outputs=[],
    )

    next_button.click(
        load_data,
        inputs=[comic_id, character_data, current_index],
        outputs=character_outputs,
    ).then(lambda: None, outputs=selected_image)
if __name__ == "__main__":
    # SECURITY: the default password below is committed to source control and
    # should be treated as leaked. Rotate it and supply APP_USERNAME /
    # APP_PASSWORD through the environment (e.g. deployment secrets) instead.
    demo.launch(
        auth=(
            os.getenv("APP_USERNAME", "admin"),
            os.getenv("APP_PASSWORD", "Qrt@12*34#immersfy"),
        ),
        share=True,
        ssr_mode=False,
    )