Manish Gupta committed on
Commit
e9ed267
·
1 Parent(s): b3e1c75

Added gradio file.

Browse files
Files changed (3) hide show
  1. .gitignore +1 -0
  2. app.py +228 -0
  3. aws_utils.py +138 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .env
app.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import io
import json
import os

import gradio as gr
from dotenv import load_dotenv
from PIL import Image

import aws_utils

load_dotenv()
9
+
10
+ AWS_BUCKET = os.environ.get("AWS_BUCKET")
11
+
12
+
13
def list_current_dir(bucket_name: str, folder_path: str = "") -> list:
    """List the immediate sub-folders under *folder_path* in an S3 bucket.

    The '/' delimiter makes S3 group keys by their next path segment, so
    only one directory level is returned.
    """
    response = aws_utils.S3_CLIENT.list_objects_v2(
        Bucket=bucket_name, Prefix=folder_path, Delimiter="/"
    )
    # 'CommonPrefixes' is absent when the prefix has no sub-folders.
    return [entry["Prefix"] for entry in response.get("CommonPrefixes", [])]
24
+
25
+
26
def load_text_data(
    episodes: list, current_episode: int, current_scene: int, current_frame: int
):
    """Collect the display values for the frame addressed by the indices.

    Returns the state values (episodes mapping and the three indices)
    followed by 1-based position labels and the frame's text fields, in
    exactly the order the Gradio `outputs` list (minus the gallery) expects.
    """
    curr_frame = episodes[current_episode]["scenes"][current_scene]["frames"][
        current_frame
    ]
    return (
        episodes,
        current_episode,
        current_scene,
        current_frame,
        # Indices are 0-based internally; show 1-based positions to the user.
        str(current_episode + 1),
        str(current_scene + 1),
        str(current_frame + 1),
        curr_frame["description"],
        curr_frame["narration"],
        curr_frame["audio_cue_character"],
        curr_frame["audio_cue_text"],
    )


def load_data(
    episodes: list, current_episode: int, current_scene: int, current_frame: int
):
    """Advance to the next frame (rolling over scene/episode) and load it.

    Returns the gallery images followed by the values from
    ``load_text_data``. When the last frame of the last episode has already
    been shown, returns an empty gallery while keeping the current state.
    """
    # Advance frame -> scene -> episode, resetting the inner counters.
    if current_frame + 1 < len(
        episodes[current_episode]["scenes"][current_scene]["frames"]
    ):
        current_frame += 1
    elif current_scene + 1 < len(episodes[current_episode]["scenes"]):
        current_scene += 1
        current_frame = 0
    elif current_episode + 1 < len(episodes):
        current_episode += 1
        current_scene = 0
        current_frame = 0
    else:
        # BUG FIX: the original returned only 4 values here, but the Gradio
        # callback wires 12 outputs, so the UI crashed once the data was
        # exhausted. Return the full tuple with an empty gallery instead.
        return [], *load_text_data(
            episodes, current_episode, current_scene, current_frame
        )

    images = []
    # Fetch every composition image of the newly selected frame from S3.
    for comps in episodes[current_episode]["scenes"][current_scene]["frames"][
        current_frame
    ]["compositions"]:
        data = aws_utils.fetch_from_s3(comps["image"])
        images.append(Image.open(io.BytesIO(data)))

    return images, *load_text_data(
        episodes, current_episode, current_scene, current_frame
    )
77
+
78
+
79
def load_data_once(
    comic_id: str, current_episode: int, current_scene: int, current_frame: int
):
    """Fetch all episodes of *comic_id* from S3 and load the requested frame.

    Scans the comic's folder for episode sub-folders, parses each folder's
    ``episode.json``, then returns the frame's composition images followed
    by the values from ``load_text_data``.
    """
    print(f"Getting episodes for comic id: {comic_id}")
    episodes = {}
    idx = 0
    for folder in list_current_dir(AWS_BUCKET, f"{comic_id}/"):
        if "episode" in folder:
            json_path = f"s3://{AWS_BUCKET}/{folder}episode.json"
            # SECURITY/BUG FIX: parse the document with json.loads instead
            # of eval(); eval executes arbitrary code pulled from the
            # bucket and rejects valid JSON literals (true/false/null).
            episodes[idx] = json.loads(
                aws_utils.fetch_from_s3(source=json_path).decode("utf-8")
            )
            idx += 1

    images = []
    # Load the requested frame's composition images (0/0/0 on first call).
    for comps in episodes[current_episode]["scenes"][current_scene]["frames"][
        current_frame
    ]["compositions"]:
        data = aws_utils.fetch_from_s3(comps["image"])
        images.append(Image.open(io.BytesIO(data)))

    return images, *load_text_data(
        episodes, current_episode, current_scene, current_frame
    )
106
+
107
+
108
def save_image(
    selected_image,
    comic_id: str,
    current_episode: int,
    current_scene: int,
    current_frame: int,
):
    """Re-encode the selected gallery image as JPEG and upload it to S3."""
    print(f"Saving image: {selected_image}")
    with Image.open(selected_image[0]) as img:
        # Convert to RGB and serialize into an in-memory JPEG buffer.
        buffer = io.BytesIO()
        img.convert("RGB").save(buffer, "JPEG")
        buffer.seek(0)

        # Store the frame under its episode/scene folder in the bucket.
        aws_utils.save_to_s3(
            AWS_BUCKET,
            f"{comic_id}/episode-{current_episode}/images/scene-{current_scene}",
            buffer,
            f"{current_frame}.jpg",
        )
    print("Image saved successfully!")
130
+
131
+
132
+ # Function to handle image selection and enable the save button
133
# Handle a gallery click: remember the chosen image and enable saving.
def select_image(selected_image_index, images):
    """Return an update enabling the save button plus the chosen image."""
    chosen = images[selected_image_index]
    return gr.update(interactive=True), chosen
137
+
138
+
139
with gr.Blocks() as demo:
    # Cross-callback state: the chosen image, the 0-based navigation
    # indices, and the parsed episode.json documents keyed by index.
    selected_image = gr.State(None)
    current_episode = gr.State(0)
    current_scene = gr.State(0)
    current_frame = gr.State(0)
    episodes_data = gr.State({})

    with gr.Row():
        comic_id = gr.Textbox(label="Enter Comic ID:", placeholder="Enter Comic ID")
        load_button = gr.Button("Load Data")

    # Gallery showing the composition images of the current frame.
    images = gr.Gallery(
        label="Select an Image", elem_id="image_select", columns=4, height=300
    )

    # Display information about current Image (read-only position labels).
    with gr.Row():
        episode = gr.Textbox(label="Current Episode", interactive=False)
        scene = gr.Textbox(label="Current Scene", interactive=False)
        frame = gr.Textbox(label="Current Frame", interactive=False)

    # Read-only text fields of the current frame.
    image_description = gr.Textbox(label="Description", interactive=False)
    narration = gr.Textbox(label="narration", interactive=False)
    with gr.Row():
        character = gr.Textbox(label="Character", interactive=False)
        dialouge = gr.Textbox(label="dialouge", interactive=False)

    # buttons to interact with the data
    with gr.Row():
        save_button = gr.Button("Save Image")
        next_button = gr.Button("Next Image")

    # Fetch every episode of the comic and display the first frame.
    # The outputs list must match the 12-value return of load_data_once.
    load_button.click(
        load_data_once,
        inputs=[comic_id, current_episode, current_scene, current_frame],
        outputs=[
            images,
            episodes_data,
            current_episode,
            current_scene,
            current_frame,
            episode,
            scene,
            frame,
            image_description,
            narration,
            character,
            dialouge
        ],
    )

    # When an image is clicked
    # NOTE(review): gr.Number() here instantiates a hidden, always-empty
    # component as the index input; Gallery select handlers normally
    # receive the index via a gr.SelectData event argument instead.
    # Presumably this works in the Gradio version pinned here — verify.
    images.select(
        select_image,
        inputs=[gr.Number(), images],
        outputs=[save_button, selected_image],
    )

    # Upload the currently selected image back to S3; no UI outputs.
    save_button.click(
        save_image,
        inputs=[
            selected_image,
            comic_id,
            current_episode,
            current_scene,
            current_frame,
        ],
        outputs=[],
    )

    # Advance to the next frame/scene/episode and refresh all 12 outputs.
    next_button.click(
        load_data,
        inputs=[episodes_data, current_episode, current_scene, current_frame],
        outputs=[
            images,
            episodes_data,
            current_episode,
            current_scene,
            current_frame,
            episode,
            scene,
            frame,
            image_description,
            narration,
            character,
            dialouge
        ],
    )

demo.launch()
aws_utils.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import json
from io import BytesIO
from typing import Union
from urllib.parse import urlparse

import boto3
from botocore.client import Config
from botocore.exceptions import NoCredentialsError
from dotenv import load_dotenv

# Load AWS_* settings from a local .env file, if present.
load_dotenv()

# Region for the shared client; None lets boto3 fall back to its defaults.
AWS_REGION = os.environ.get("AWS_REGION")

# Initialize the S3 client
# NOTE(review): signature v4 is requested explicitly — presumably needed
# for the target region/features; confirm against deployment config.
S3_CLIENT = boto3.client(
    "s3", region_name=AWS_REGION, config=Config(signature_version="s3v4")
)
19
+
20
+
21
+ def save_to_s3(
22
+ bucket_name: str,
23
+ folder_name: str,
24
+ content: Union[str, dict, BytesIO],
25
+ file_name: str,
26
+ ) -> str:
27
+ """
28
+ Save a file to an S3 bucket, determining the content type based on the input type.
29
+
30
+ Args:
31
+ bucket_name (str): The name of the S3 bucket.
32
+ folder_name (str): The folder path in the S3 bucket.
33
+ content (Union[str, dict, BytesIO]): The content to save, can be a string, dictionary, or BytesIO.
34
+ file_name (str): The file name under which the content should be saved.
35
+
36
+ Returns:
37
+ str: The S3 URL of the uploaded file, or an error message if credentials are not available.
38
+ """
39
+ # Ensure the folder name ends with a '/'
40
+ # if not folder_name.endswith('/'):
41
+ # folder_name += '/'
42
+ # Determine file name and content type based on the input
43
+ if isinstance(content, str):
44
+ file_content = content
45
+ content_type = "text/plain"
46
+ elif isinstance(content, dict):
47
+ file_content = json.dumps(content)
48
+ content_type = "application/json"
49
+ elif isinstance(content, BytesIO):
50
+ file_content = content
51
+ content_type = "image/jpeg"
52
+ else:
53
+ print(
54
+ "Invalid content type. Content must be a string, dictionary, or BytesIO."
55
+ )
56
+ raise ValueError("Content must be either a string, dictionary, or BytesIO.")
57
+
58
+ # Ensure the folder name ends with a '/'
59
+ s3_file_path = f"{folder_name.rstrip('/')}/{file_name}"
60
+
61
+ try:
62
+ # Upload the file to S3
63
+ S3_CLIENT.put_object(
64
+ Bucket=bucket_name,
65
+ Key=s3_file_path,
66
+ Body=file_content,
67
+ ContentType=content_type,
68
+ )
69
+ s3_url = f"s3://{bucket_name}/{s3_file_path}"
70
+ print(f"File successfully uploaded to {s3_url}")
71
+ return s3_url
72
+
73
+ except NoCredentialsError:
74
+ print("AWS credentials not available.")
75
+ return "Error: AWS credentials not available."
76
+
77
+
78
+ def fetch_from_s3(source: Union[str, dict], region_name: str = "ap-south-1") -> bytes:
79
+ """
80
+ Fetch a file's content from S3 given a source URL or dictionary with bucket and key.
81
+
82
+ Args:
83
+ source (Union[str, dict]): The source S3 URL or a dictionary with 'bucket_name' and 'file_key'.
84
+ region_name (str): The AWS region name for the S3 client (default is 'ap-south-1').
85
+
86
+ Returns:
87
+ bytes: The content of the file fetched from S3.
88
+ """
89
+ print(f"Fetching file from S3. Source: {source}")
90
+ s3_client = boto3.client("s3", region_name=region_name)
91
+
92
+ # Parse the source depending on its type
93
+ if isinstance(source, str):
94
+ parsed_url = urlparse(source)
95
+ bucket_name = parsed_url.netloc.split(".")[0]
96
+ file_path = parsed_url.path.lstrip("/")
97
+ elif isinstance(source, dict):
98
+ bucket_name = source.get("bucket_name")
99
+ file_path = source.get("file_key")
100
+ if not bucket_name or not file_path:
101
+ print("Dictionary input must contain 'bucket_name' and 'file_key'.")
102
+ raise ValueError(
103
+ "Dictionary input must contain 'bucket_name' and 'file_key'."
104
+ )
105
+ else:
106
+ print("Source must be a string URL or a dictionary.")
107
+ raise ValueError("Source must be a string URL or a dictionary.")
108
+
109
+ print(f"Attempting to download from bucket: {bucket_name}, path: {file_path}")
110
+ try:
111
+ response = s3_client.get_object(Bucket=bucket_name, Key=file_path)
112
+ file_content = response["Body"].read()
113
+ print(f"File fetched successfully from {bucket_name}/{file_path}")
114
+ return file_content
115
+ except Exception as e:
116
+ print(f"Failed to fetch file from S3: {e}")
117
+ raise
118
+
119
+
120
def list_s3_objects(bucket_name: str, folder_path: str = "") -> list:
    """
    Lists a content of the given a directory URL.

    Args:
        bucket_name (str): The name of the S3 bucket.
        folder_name (str): The folder path in the S3 bucket.

    Returns:
        list: The list of files found inside the given directory URL.
    """
    listing = S3_CLIENT.list_objects_v2(Bucket=bucket_name, Prefix=folder_path)
    # 'Contents' is omitted from the response when nothing matches.
    return [entry["Key"] for entry in listing.get("Contents", [])]