umang-immersfy commited on
Commit
a77d74a
·
1 Parent(s): 6e0fda9

regeneration part added

Browse files
Files changed (3) hide show
  1. __pycache__/core.cpython-310.pyc +0 -0
  2. app.py +9 -4
  3. core.py +234 -178
__pycache__/core.cpython-310.pyc CHANGED
Binary files a/__pycache__/core.cpython-310.pyc and b/__pycache__/core.cpython-310.pyc differ
 
app.py CHANGED
@@ -13,6 +13,7 @@ with gr.Blocks() as demo:
13
  episodes_data = gr.State({})
14
  character_data = gr.State({})
15
  current_frame_data = gr.State(None)
 
16
 
17
  with gr.Row():
18
  with gr.Column():
@@ -98,6 +99,7 @@ with gr.Blocks() as demo:
98
  frame_dropdown,
99
  episodes_data,
100
  character_data,
 
101
  developer
102
  ],
103
  )
@@ -158,7 +160,7 @@ with gr.Blocks() as demo:
158
 
159
  next_button.click(
160
  core.load_data_next,
161
- inputs=[episodes_data, current_episode, current_frame],
162
  outputs=[
163
  episode_dropdown,
164
  frame_dropdown,
@@ -187,7 +189,7 @@ with gr.Blocks() as demo:
187
 
188
  prev_button.click(
189
  core.load_data_prev,
190
- inputs=[episodes_data, current_episode, current_frame],
191
  outputs=[
192
  episode_dropdown,
193
  frame_dropdown,
@@ -237,7 +239,9 @@ with gr.Blocks() as demo:
237
  current_episode,
238
  current_scene,
239
  current_frame,
240
- episodes_data
 
 
241
  ],
242
  outputs=[
243
  prompt_1,
@@ -263,7 +267,8 @@ with gr.Blocks() as demo:
263
  height,
264
  width
265
  ],
266
- outputs=[]
267
  )
268
 
269
  demo.launch(auth=("admin", "Qrt@12*34#immersfy"), share=True, ssr_mode=False, debug=True)
 
 
13
  episodes_data = gr.State({})
14
  character_data = gr.State({})
15
  current_frame_data = gr.State(None)
16
+ details = gr.State({})
17
 
18
  with gr.Row():
19
  with gr.Column():
 
99
  frame_dropdown,
100
  episodes_data,
101
  character_data,
102
+ details,
103
  developer
104
  ],
105
  )
 
160
 
161
  next_button.click(
162
  core.load_data_next,
163
+ inputs=[episodes_data, current_episode, current_frame, developer],
164
  outputs=[
165
  episode_dropdown,
166
  frame_dropdown,
 
189
 
190
  prev_button.click(
191
  core.load_data_prev,
192
+ inputs=[episodes_data, current_episode, current_frame, developer],
193
  outputs=[
194
  episode_dropdown,
195
  frame_dropdown,
 
239
  current_episode,
240
  current_scene,
241
  current_frame,
242
+ episodes_data,
243
+ details,
244
+ comic_id
245
  ],
246
  outputs=[
247
  prompt_1,
 
267
  height,
268
  width
269
  ],
270
+ outputs=[images]
271
  )
272
 
273
  demo.launch(auth=("admin", "Qrt@12*34#immersfy"), share=True, ssr_mode=False, debug=True)
274
+ # demo.launch(share=True, ssr_mode=False, debug=True)
core.py CHANGED
@@ -1,5 +1,3 @@
1
- """House of all specific functions used to control data flow."""
2
-
3
  from typing import List
4
  from PIL import Image
5
  import gradio as gr
@@ -12,6 +10,8 @@ import parameters
12
  import script_gen
13
  import inout as iowrapper
14
  import openai_wrapper
 
 
15
 
16
  AWS_BUCKET = parameters.AWS_BUCKET
17
  llm = openai_wrapper.GPT_4O_MINI
@@ -31,112 +31,151 @@ class ComicFrame:
31
  narration: str
32
  character_dilouge: str
33
  character: str
34
- location: str # Moved up here
35
  setting: str
36
  all_characters: list
37
- compositions: List[Composition] = dataclasses.field(
38
- default_factory=list
39
- ) # Keep this as the last argument
40
 
41
 
42
  def list_current_dir(bucket_name: str, folder_path: str = "") -> list:
43
- response = aws_utils.S3_CLIENT.list_objects_v2(
44
- Bucket=bucket_name, Prefix=folder_path, Delimiter="/"
45
- )
46
-
47
- # Check if the bucket contains objects
48
- folders = []
49
- if "CommonPrefixes" in response:
50
- for prefix in response["CommonPrefixes"]:
51
- folders.append(prefix["Prefix"])
52
- return folders
 
53
 
54
 
55
  def load_data_inner(
56
  episodes_data: list, current_episode: int, current_frame: int, is_developer: bool
57
  ):
58
- images = []
59
- curr_frame = episodes_data[current_episode][current_frame]
60
- print(episodes_data[current_episode][current_frame])
61
- # Loading the 0th frame of 0th scene in 0th episode.
62
- for comp in curr_frame.compositions:
63
- data = aws_utils.fetch_from_s3(comp.image)
64
- images.append(Image.open(io.BytesIO(data)))
65
- return (
66
- images,
67
- episodes_data,
68
- current_episode,
69
- current_frame,
70
- gr.Textbox(value=curr_frame.description, interactive=is_developer),
71
- gr.Textbox(value=curr_frame.narration, interactive=is_developer),
72
- gr.Textbox(value=curr_frame.character, interactive=is_developer),
73
- gr.Textbox(value=curr_frame.character_dilouge, interactive=is_developer),
74
- gr.Textbox(value=curr_frame.location, interactive=is_developer),
75
- curr_frame.setting,
76
- curr_frame.compositions[0].prompt,
77
- curr_frame.compositions[0].seed,
78
- curr_frame.compositions[1].prompt,
79
- curr_frame.compositions[1].seed,
80
- curr_frame.compositions[2].prompt,
81
- curr_frame.compositions[2].seed,
82
- curr_frame.compositions[3].prompt,
83
- curr_frame.compositions[3].seed,
84
- curr_frame.all_characters,
85
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
 
88
  def load_metadata_fn(comic_id: str):
89
- # Logic to load and return images based on comic_id and episode
90
- # You can replace this with actual image paths or generation logic
91
- print(f"Getting episodes for comic id: {comic_id}")
92
- episodes_data = {}
93
- episode_idx = []
94
- character_data = {}
95
- character_path = f"s3://blix-demo-v0/{comic_id}/characters/characters.json"
96
- char_data = eval(aws_utils.fetch_from_s3(source=character_path).decode("utf-8"))
97
- for name, char in char_data.items():
98
- character_data[name] = char["profile_image"]
99
- print(character_data)
100
- for folder in list_current_dir(AWS_BUCKET, f"{comic_id}/episodes/"):
101
- if "episode" in folder:
102
- json_path = f"s3://{AWS_BUCKET}/{folder}episode.json"
103
- idx = int(folder.split("/")[2].split("-")[-1])
104
- episode_idx.append(idx)
105
- data = eval(aws_utils.fetch_from_s3(source=json_path).decode("utf-8"))
106
- comic_frames = []
107
- for scene in data["scenes"]:
108
- for frame in scene["frames"]:
109
- comic_frames.append(
110
- ComicFrame(
111
- description=frame["description"],
112
- narration=frame["narration"],
113
- character=frame["audio_cue_character"],
114
- character_dilouge=frame["audio_cue_text"],
115
- compositions=[
116
- Composition(**comp) for comp in frame["compositions"]
117
- ],
118
- location=frame["location"],
119
- setting=frame["frame_setting"],
120
- all_characters=[
121
- char["name"] for char in frame["characters"]
122
- ],
 
 
 
 
 
 
 
 
 
123
  )
124
- )
125
- episodes_data[idx] = comic_frames
126
- current_episode, current_frame = min(episode_idx), 0
127
- return (
128
- gr.update(choices=episode_idx, value=episode_idx[0]),
129
- gr.update(
130
- choices=range(len(episodes_data[current_episode])), value=current_frame
131
- ),
132
- episodes_data,
133
- character_data,
134
- gr.Checkbox(visible=True),
135
- )
 
 
 
 
 
 
 
 
 
 
136
 
137
 
138
  def load_data_next(
139
- episodes_data: list, current_episode: int, current_frame: int, is_developer=False
140
  ):
141
  if current_frame + 1 < len(episodes_data[current_episode]):
142
  current_frame += 1
@@ -153,7 +192,7 @@ def load_data_next(
153
 
154
 
155
  def load_data_prev(
156
- episodes_data: list, current_episode: int, current_frame: int, is_developer=False
157
  ):
158
  if current_frame - 1 >= 0:
159
  current_frame -= 1
@@ -180,32 +219,18 @@ def load_from_dropdown(
180
 
181
 
182
  def load_dropdown_fn(selected_episode):
183
- return (
184
- gr.update(value=selected_episode),
185
- gr.update(value=0),
186
- selected_episode,
187
- 0,
188
- )
189
 
190
 
191
  def load_dropdown_fn_v2(selected_frame):
192
  return selected_frame
193
 
194
 
195
- def save_image(
196
- selected_image,
197
- comic_id: str,
198
- current_episode: int,
199
- current_frame: int,
200
- ):
201
- # Implement your AWS S3 save logic here
202
- # print(f"Saving image: {selected_image}")
203
  with Image.open(selected_image[0]) as img:
204
- # Convert and save as JPG
205
  img_bytes = io.BytesIO()
206
  img.convert("RGB").save(img_bytes, "JPEG")
207
  img_bytes.seek(0)
208
-
209
  aws_utils.save_to_s3(
210
  AWS_BUCKET,
211
  f"{comic_id}/episode-{current_episode}/images",
@@ -226,64 +251,92 @@ def regenerate_composition_data(
226
  current_scene: int,
227
  current_frame: int,
228
  episodes_data: dict,
 
 
229
  ):
230
- print(
231
- f"Generating compositions for episode: {current_episode} and scene: {current_scene} and frame: {current_frame}."
232
- )
233
-
234
- # Retrieve the current frame data
235
- frame = episodes_data[current_episode][current_frame]
236
-
237
- # Generate the prompt for compositions
238
- prompt_dict = {
239
- "system": script_gen.generate_image_compositions_instruction,
240
- "user": jinja2.Template(
241
- source=script_gen.generate_image_compositions_user_prompt
242
- ).render(
243
- {
244
- "FRAME": {
245
- "description": frame.description,
246
- "narration": frame.narration,
247
- "character_dilouge": frame.character_dilouge,
248
- "character": frame.character,
249
- "location": frame.location,
250
- "setting": frame.setting,
251
- "all_characters": frame.all_characters,
252
- },
253
- }
254
- ),
255
- }
256
- # Generate composition s using LLM
257
- compositions = llm.generate_valid_json_response(prompt_dict)
258
- # Update frame with new compositions
259
- frame.compositions = [
260
- Composition(
261
- **composition,
262
- seed=frame.compositions[idx].seed if idx < len(frame.compositions) else "",
263
- image=(
264
- frame.compositions[idx].image if idx < len(frame.compositions) else ""
265
- ),
266
  )
267
- for idx, composition in enumerate(compositions["compositions"])
268
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
269
 
270
- # Update the episodes_data dictionary with the modified frame
271
- episodes_data[current_episode][current_frame] = frame
272
- print(
273
- f"Updated frame {current_frame} in episode {current_episode} with new compositions."
274
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275
 
276
- # Return the updated composition values for the UI
277
- return [
278
- frame.compositions[0].prompt,
279
- frame.compositions[0].seed,
280
- frame.compositions[1].prompt,
281
- frame.compositions[1].seed,
282
- frame.compositions[2].prompt,
283
- frame.compositions[2].seed,
284
- frame.compositions[3].prompt,
285
- frame.compositions[3].seed,
286
- ]
 
287
 
288
 
289
  def regenerate_data(
@@ -297,9 +350,10 @@ def regenerate_data(
297
  height,
298
  width,
299
  ):
300
-
301
  frame = episodes_data[current_episode][current_frame]
302
  related_chars = [character_data[ch] for ch in frame.all_characters]
 
303
  for i, composition in enumerate(frame.compositions):
304
  payload = {
305
  "prompt": composition.prompt,
@@ -312,18 +366,20 @@ def regenerate_data(
312
  },
313
  }
314
 
315
- data = iowrapper.get_valid_post_response(
316
- url="http://10.100.111.13:4389/generate_image",
317
- payload=payload,
318
- )
319
- image_data = io.BytesIO(base64.b64decode(data["image"]))
320
- path = aws_utils.save_to_s3(
321
- parameters.AWS_BUCKET,
322
- f"{comic_id}/episodes/episode-{current_episode}/compositions/scene-{0}/frame-{current_frame}",
323
- image_data,
324
- f"{i}.jpg",
325
- )
326
- load_data_inner(episodes_data, current_episode, current_frame, is_developer=True)
327
-
328
-
329
- # pass
 
 
 
 
 
1
  from typing import List
2
  from PIL import Image
3
  import gradio as gr
 
10
  import script_gen
11
  import inout as iowrapper
12
  import openai_wrapper
13
+ import json
14
+ from dataclasses import asdict
15
 
16
  AWS_BUCKET = parameters.AWS_BUCKET
17
  llm = openai_wrapper.GPT_4O_MINI
 
31
  narration: str
32
  character_dilouge: str
33
  character: str
34
+ location: str
35
  setting: str
36
  all_characters: list
37
+ compositions: List[Composition] = dataclasses.field(default_factory=list)
 
 
38
 
39
 
40
  def list_current_dir(bucket_name: str, folder_path: str = "") -> list:
41
+ try:
42
+ response = aws_utils.S3_CLIENT.list_objects_v2(
43
+ Bucket=bucket_name, Prefix=folder_path, Delimiter="/"
44
+ )
45
+ folders = []
46
+ if "CommonPrefixes" in response:
47
+ for prefix in response["CommonPrefixes"]:
48
+ folders.append(prefix["Prefix"])
49
+ return folders
50
+ except Exception as e:
51
+ return []
52
 
53
 
54
  def load_data_inner(
55
  episodes_data: list, current_episode: int, current_frame: int, is_developer: bool
56
  ):
57
+ try:
58
+ images = []
59
+ curr_frame = episodes_data[current_episode][current_frame]
60
+ for comp in curr_frame.compositions:
61
+ data = aws_utils.fetch_from_s3(comp.image)
62
+ images.append(Image.open(io.BytesIO(data)))
63
+ return (
64
+ images,
65
+ episodes_data,
66
+ current_episode,
67
+ current_frame,
68
+ gr.Textbox(value=curr_frame.description, interactive=is_developer),
69
+ gr.Textbox(value=curr_frame.narration, interactive=is_developer),
70
+ gr.Textbox(value=curr_frame.character, interactive=is_developer),
71
+ gr.Textbox(value=curr_frame.character_dilouge, interactive=is_developer),
72
+ gr.Textbox(value=curr_frame.location, interactive=is_developer),
73
+ curr_frame.setting,
74
+ curr_frame.compositions[0].prompt,
75
+ curr_frame.compositions[0].seed,
76
+ curr_frame.compositions[1].prompt,
77
+ curr_frame.compositions[1].seed,
78
+ curr_frame.compositions[2].prompt,
79
+ curr_frame.compositions[2].seed,
80
+ curr_frame.compositions[3].prompt,
81
+ curr_frame.compositions[3].seed,
82
+ curr_frame.all_characters,
83
+ )
84
+ except Exception as e:
85
+ return (
86
+ [],
87
+ episodes_data,
88
+ current_episode,
89
+ current_frame,
90
+ gr.Textbox(),
91
+ gr.Textbox(),
92
+ gr.Textbox(),
93
+ gr.Textbox(),
94
+ gr.Textbox(),
95
+ "",
96
+ "",
97
+ "",
98
+ "",
99
+ "",
100
+ "",
101
+ "",
102
+ "",
103
+ "",
104
+ [],
105
+ )
106
 
107
 
108
  def load_metadata_fn(comic_id: str):
109
+ try:
110
+ episodes_data = {}
111
+ episode_idx = []
112
+ character_data = {}
113
+ details = {}
114
+ character_path = f"s3://blix-demo-v0/{comic_id}/characters/characters.json"
115
+ char_data = eval(aws_utils.fetch_from_s3(source=character_path).decode("utf-8"))
116
+
117
+ for name, char in char_data.items():
118
+ character_data[name] = char["profile_image"]
119
+
120
+ for folder in list_current_dir(AWS_BUCKET, f"{comic_id}/episodes/"):
121
+ if "episode" in folder:
122
+ json_path = f"s3://{AWS_BUCKET}/{folder}episode.json"
123
+ idx = int(folder.split("/")[2].split("-")[-1])
124
+ episode_idx.append(idx)
125
+ data = eval(aws_utils.fetch_from_s3(source=json_path).decode("utf-8"))
126
+ comic_frames = []
127
+ details[idx] = {}
128
+ cumulative_frame_count = 0
129
+
130
+ for scene_num, scene in enumerate(data["scenes"]):
131
+ scene_frame_count = len(scene["frames"])
132
+ cumulative_frame_count += scene_frame_count
133
+ details[idx][scene_num] = cumulative_frame_count
134
+
135
+ for frame in scene["frames"]:
136
+ comic_frames.append(
137
+ ComicFrame(
138
+ description=frame["description"],
139
+ narration=frame["narration"],
140
+ character=frame["audio_cue_character"],
141
+ character_dilouge=frame["audio_cue_text"],
142
+ compositions=[
143
+ Composition(**comp)
144
+ for comp in frame["compositions"]
145
+ ],
146
+ location=frame["location"],
147
+ setting=frame["frame_setting"],
148
+ all_characters=[
149
+ char["name"] for char in frame["characters"]
150
+ ],
151
+ )
152
  )
153
+ episodes_data[idx] = comic_frames
154
+
155
+ current_episode, current_frame = min(episode_idx), 0
156
+ return (
157
+ gr.update(choices=episode_idx, value=episode_idx[0]),
158
+ gr.update(
159
+ choices=range(len(episodes_data[current_episode])), value=current_frame
160
+ ),
161
+ episodes_data,
162
+ character_data,
163
+ details,
164
+ gr.Checkbox(visible=True),
165
+ )
166
+ except Exception as e:
167
+ return (
168
+ gr.update(choices=[]),
169
+ gr.update(choices=[]),
170
+ {},
171
+ {},
172
+ {},
173
+ gr.Checkbox(visible=False),
174
+ )
175
 
176
 
177
  def load_data_next(
178
+ episodes_data: list, current_episode: int, current_frame: int, is_developer: bool
179
  ):
180
  if current_frame + 1 < len(episodes_data[current_episode]):
181
  current_frame += 1
 
192
 
193
 
194
  def load_data_prev(
195
+ episodes_data: list, current_episode: int, current_frame: int, is_developer: bool
196
  ):
197
  if current_frame - 1 >= 0:
198
  current_frame -= 1
 
219
 
220
 
221
  def load_dropdown_fn(selected_episode):
222
+ return (gr.update(value=selected_episode), gr.update(value=0), selected_episode, 0)
 
 
 
 
 
223
 
224
 
225
  def load_dropdown_fn_v2(selected_frame):
226
  return selected_frame
227
 
228
 
229
+ def save_image(selected_image, comic_id: str, current_episode: int, current_frame: int):
 
 
 
 
 
 
 
230
  with Image.open(selected_image[0]) as img:
 
231
  img_bytes = io.BytesIO()
232
  img.convert("RGB").save(img_bytes, "JPEG")
233
  img_bytes.seek(0)
 
234
  aws_utils.save_to_s3(
235
  AWS_BUCKET,
236
  f"{comic_id}/episode-{current_episode}/images",
 
251
  current_scene: int,
252
  current_frame: int,
253
  episodes_data: dict,
254
+ details: dict,
255
+ comic_id: str,
256
  ):
257
+ try:
258
+ episode_details = details.get(current_episode)
259
+ if not episode_details:
260
+ return
261
+
262
+ scene_num, frame_num_in_scene = None, None
263
+ prev_frame_count = 0
264
+
265
+ for scene_idx, cumulative_frame_count in episode_details.items():
266
+ if current_frame < cumulative_frame_count:
267
+ scene_num = scene_idx
268
+ frame_num_in_scene = current_frame - prev_frame_count
269
+ break
270
+ prev_frame_count = cumulative_frame_count
271
+
272
+ if scene_num is None:
273
+ return
274
+
275
+ frame = episodes_data[current_episode][current_frame]
276
+ prompt_template = jinja2.Template(
277
+ script_gen.generate_image_compositions_user_prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
  )
279
+ prompt_dict = {
280
+ "system": script_gen.generate_image_compositions_instruction,
281
+ "user": prompt_template.render(
282
+ {
283
+ "FRAME": {
284
+ "description": frame.description,
285
+ "narration": frame.narration,
286
+ "character_dilouge": frame.character_dilouge,
287
+ "character": frame.character,
288
+ "location": frame.location,
289
+ "setting": frame.setting,
290
+ "all_characters": frame.all_characters,
291
+ }
292
+ }
293
+ ),
294
+ }
295
 
296
+ compositions = llm.generate_valid_json_response(prompt_dict)
297
+ frame.compositions = [
298
+ Composition(
299
+ **comp,
300
+ seed=(
301
+ frame.compositions[idx].seed
302
+ if idx < len(frame.compositions)
303
+ else ""
304
+ ),
305
+ image=(
306
+ frame.compositions[idx].image
307
+ if idx < len(frame.compositions)
308
+ else ""
309
+ ),
310
+ )
311
+ for idx, comp in enumerate(compositions["compositions"])
312
+ ]
313
+
314
+ episode_path = f"s3://blix-demo-v0/{comic_id}/episodes/episode-{current_episode}/episode.json"
315
+ episode = json.loads(aws_utils.fetch_from_s3(episode_path).decode("utf-8"))
316
+ episode["scenes"][scene_num]["frames"][frame_num_in_scene]["compositions"] = [
317
+ asdict(comp) for comp in frame.compositions
318
+ ]
319
+
320
+ episode_json = json.dumps(episode)
321
+ aws_utils.save_to_s3(
322
+ bucket_name=parameters.AWS_BUCKET,
323
+ folder_name=f"{comic_id}/episodes/episode-{current_episode}",
324
+ content=episode_json,
325
+ file_name="episode.json",
326
+ )
327
 
328
+ return [
329
+ frame.compositions[0].prompt,
330
+ frame.compositions[0].seed,
331
+ frame.compositions[1].prompt,
332
+ frame.compositions[1].seed,
333
+ frame.compositions[2].prompt,
334
+ frame.compositions[2].seed,
335
+ frame.compositions[3].prompt,
336
+ frame.compositions[3].seed,
337
+ ]
338
+ except Exception as e:
339
+ return [""] * 8
340
 
341
 
342
  def regenerate_data(
 
350
  height,
351
  width,
352
  ):
353
+ images = []
354
  frame = episodes_data[current_episode][current_frame]
355
  related_chars = [character_data[ch] for ch in frame.all_characters]
356
+
357
  for i, composition in enumerate(frame.compositions):
358
  payload = {
359
  "prompt": composition.prompt,
 
366
  },
367
  }
368
 
369
+ try:
370
+ data = iowrapper.get_valid_post_response(
371
+ url="http://10.100.111.13:4389/generate_image",
372
+ payload=payload,
373
+ )
374
+ image_data = io.BytesIO(base64.b64decode(data["image"]))
375
+ aws_utils.save_to_s3(
376
+ parameters.AWS_BUCKET,
377
+ f"{comic_id}/episodes/episode-{current_episode}/compositions/scene-{0}/frame-{current_frame}",
378
+ image_data,
379
+ f"{i}.jpg",
380
+ )
381
+ images.append(Image.open(image_data))
382
+ except Exception as e:
383
+ continue
384
+
385
+ return images