Nomnommish commited on
Commit
5efedb6
·
verified ·
1 Parent(s): 453bc8d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +239 -23
app.py CHANGED
@@ -4,6 +4,7 @@ import json
4
  import base64
5
  import mimetypes
6
  import tempfile
 
7
  from pathlib import Path
8
  from urllib.parse import quote
9
 
@@ -22,7 +23,7 @@ IMAGE_RESOLUTIONS = ["1k", "2k"]
22
  VIDEO_ASPECT_RATIOS = ["16:9", "9:16", "1:1", "4:3", "3:4", "3:2", "2:3"]
23
  VIDEO_RESOLUTIONS = ["480p", "720p"]
24
 
25
- APP_TITLE = "xAI Imagine Studio — T2I + I2I + I2V + V2V"
26
 
27
 
28
  def auth_headers(api_key: str) -> dict:
@@ -49,6 +50,25 @@ def file_to_data_uri(filepath: str) -> str:
49
  return f"data:{mime};base64,{b64}"
50
 
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  def download_bytes_to_temp(content: bytes, suffix: str) -> str:
53
  fd, out_path = tempfile.mkstemp(suffix=suffix)
54
  os.close(fd)
@@ -100,6 +120,31 @@ def extract_video_path(video_input):
100
  return str(video_input)
101
 
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  def generate_t2i(api_key, model, prompt, n, aspect_ratio, resolution, progress=gr.Progress(track_tqdm=False)):
104
  headers = auth_headers(api_key)
105
  payload = {
@@ -146,28 +191,12 @@ def generate_t2i(api_key, model, prompt, n, aspect_ratio, resolution, progress=g
146
  return gallery, paths[0], paths, f"Generated {len(paths)} image(s)."
147
 
148
 
149
- def edit_like_i2i(api_key, model, prompt, input_image_path, aspect_ratio, progress=gr.Progress(track_tqdm=False)):
150
  headers = auth_headers(api_key)
 
 
151
 
152
- if not input_image_path:
153
- raise gr.Error("Please upload an image.")
154
- if not (prompt or "").strip():
155
- raise gr.Error("Please enter a prompt.")
156
-
157
- payload = {
158
- "model": model or DEFAULT_IMAGE_MODEL,
159
- "prompt": prompt.strip(),
160
- "image": {
161
- "url": file_to_data_uri(input_image_path),
162
- "type": "image_url",
163
- },
164
- "response_format": "b64_json",
165
- }
166
-
167
- if aspect_ratio and aspect_ratio != "auto":
168
- payload["aspect_ratio"] = aspect_ratio
169
-
170
- progress(0.2, desc="Editing image...")
171
  resp = requests.post(
172
  f"{API_BASE}/images/edits",
173
  headers=headers,
@@ -332,6 +361,98 @@ def generate_v2v(
332
  return out, out, f"V2V complete. Request ID: {request_id}. Duration: {actual_duration}s"
333
 
334
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
335
  css_path = Path("style.css")
336
  css = css_path.read_text(encoding="utf-8") if css_path.exists() else ""
337
 
@@ -343,6 +464,9 @@ with gr.Blocks(title=APP_TITLE, css=css, theme=gr.themes.Soft()) as demo:
343
 
344
  t2i_first_image_state = gr.State(None)
345
  t2i_all_images_state = gr.State([])
 
 
 
346
 
347
  with gr.Row():
348
  with gr.Column(scale=2):
@@ -404,7 +528,12 @@ with gr.Blocks(title=APP_TITLE, css=css, theme=gr.themes.Soft()) as demo:
404
  with gr.Tab("Image → Image"):
405
  with gr.Row():
406
  with gr.Column():
407
- i2i_input = gr.Image(label="Upload Source Image", type="filepath")
 
 
 
 
 
408
  i2i_prompt = gr.Textbox(label="Transformation Prompt", lines=6)
409
  i2i_aspect = gr.Dropdown(label="Aspect Ratio Override", choices=IMAGE_ASPECT_RATIOS, value="auto")
410
  i2i_btn = gr.Button("Generate I2I", variant="primary")
@@ -417,7 +546,12 @@ with gr.Blocks(title=APP_TITLE, css=css, theme=gr.themes.Soft()) as demo:
417
  with gr.Tab("Image Edit"):
418
  with gr.Row():
419
  with gr.Column():
420
- edit_input = gr.Image(label="Upload Image", type="filepath")
 
 
 
 
 
421
  edit_prompt = gr.Textbox(label="Edit Prompt", lines=6)
422
  edit_aspect = gr.Dropdown(label="Aspect Ratio Override", choices=IMAGE_ASPECT_RATIOS, value="auto")
423
  edit_btn = gr.Button("Edit Image", variant="primary")
@@ -456,6 +590,37 @@ with gr.Blocks(title=APP_TITLE, css=css, theme=gr.themes.Soft()) as demo:
456
  v2v_video_out = gr.Video(label="Generated V2V Video")
457
  v2v_download = gr.File(label="Download V2V Video")
458
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
459
  t2i_btn.click(
460
  fn=generate_t2i,
461
  inputs=[api_key, image_model, t2i_prompt, t2i_n, t2i_aspect, t2i_resolution],
@@ -499,6 +664,11 @@ with gr.Blocks(title=APP_TITLE, css=css, theme=gr.themes.Soft()) as demo:
499
  ],
500
  outputs=[i2v_video, i2v_download, i2v_status],
501
  api_name=False,
 
 
 
 
 
502
  )
503
 
504
  v2v_btn.click(
@@ -506,6 +676,52 @@ with gr.Blocks(title=APP_TITLE, css=css, theme=gr.themes.Soft()) as demo:
506
  inputs=[api_key, video_model, v2v_prompt, v2v_video_in, poll_timeout, poll_interval],
507
  outputs=[v2v_video_out, v2v_download, v2v_status],
508
  api_name=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
509
  )
510
 
511
  if __name__ == "__main__":
 
4
  import base64
5
  import mimetypes
6
  import tempfile
7
+ import subprocess
8
  from pathlib import Path
9
  from urllib.parse import quote
10
 
 
23
  VIDEO_ASPECT_RATIOS = ["16:9", "9:16", "1:1", "4:3", "3:4", "3:2", "2:3"]
24
  VIDEO_RESOLUTIONS = ["480p", "720p"]
25
 
26
+ APP_TITLE = "xAI Imagine Studio — T2I + I2I + I2V + V2V + Video Expand"
27
 
28
 
29
  def auth_headers(api_key: str) -> dict:
 
50
  return f"data:{mime};base64,{b64}"
51
 
52
 
53
+ def normalize_uploaded_files(files):
54
+ if not files:
55
+ return []
56
+ if isinstance(files, (str, Path)):
57
+ return [str(files)]
58
+
59
+ normalized = []
60
+ for item in files:
61
+ if isinstance(item, str):
62
+ normalized.append(item)
63
+ elif isinstance(item, dict):
64
+ path = item.get("path") or item.get("name")
65
+ if path:
66
+ normalized.append(path)
67
+ else:
68
+ normalized.append(str(item))
69
+ return normalized
70
+
71
+
72
  def download_bytes_to_temp(content: bytes, suffix: str) -> str:
73
  fd, out_path = tempfile.mkstemp(suffix=suffix)
74
  os.close(fd)
 
120
  return str(video_input)
121
 
122
 
123
+ def build_image_edit_payload(prompt: str, image_paths: list[str], model: str, aspect_ratio: str):
124
+ if not image_paths:
125
+ raise gr.Error("Please upload at least one image.")
126
+ if not (prompt or "").strip():
127
+ raise gr.Error("Please enter a prompt.")
128
+
129
+ images = [{"url": file_to_data_uri(path), "type": "image_url"} for path in image_paths]
130
+
131
+ payload = {
132
+ "model": model or DEFAULT_IMAGE_MODEL,
133
+ "prompt": prompt.strip(),
134
+ "response_format": "b64_json",
135
+ }
136
+
137
+ if len(images) == 1:
138
+ payload["image"] = images[0]
139
+ else:
140
+ payload["images"] = images
141
+
142
+ if aspect_ratio and aspect_ratio != "auto":
143
+ payload["aspect_ratio"] = aspect_ratio
144
+
145
+ return payload
146
+
147
+
148
  def generate_t2i(api_key, model, prompt, n, aspect_ratio, resolution, progress=gr.Progress(track_tqdm=False)):
149
  headers = auth_headers(api_key)
150
  payload = {
 
191
  return gallery, paths[0], paths, f"Generated {len(paths)} image(s)."
192
 
193
 
194
+ def edit_like_i2i(api_key, model, prompt, input_images, aspect_ratio, progress=gr.Progress(track_tqdm=False)):
195
  headers = auth_headers(api_key)
196
+ image_paths = normalize_uploaded_files(input_images)
197
+ payload = build_image_edit_payload(prompt, image_paths, model, aspect_ratio)
198
 
199
+ progress(0.2, desc="Editing image(s)...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  resp = requests.post(
201
  f"{API_BASE}/images/edits",
202
  headers=headers,
 
361
  return out, out, f"V2V complete. Request ID: {request_id}. Duration: {actual_duration}s"
362
 
363
 
364
+ def get_video_duration(video_path: str) -> float:
365
+ cmd = [
366
+ "ffprobe",
367
+ "-v", "error",
368
+ "-show_entries", "format=duration",
369
+ "-of", "default=noprint_wrappers=1:nokey=1",
370
+ video_path,
371
+ ]
372
+ result = subprocess.run(cmd, capture_output=True, text=True)
373
+ if result.returncode != 0:
374
+ raise gr.Error(f"Could not read video duration:\n{result.stderr}")
375
+ try:
376
+ return float(result.stdout.strip())
377
+ except Exception:
378
+ raise gr.Error("Could not parse video duration.")
379
+
380
+
381
+ def extract_frame_from_video(video_path: str, seconds: float) -> str:
382
+ fd, frame_path = tempfile.mkstemp(suffix=".png")
383
+ os.close(fd)
384
+
385
+ cmd = [
386
+ "ffmpeg",
387
+ "-y",
388
+ "-ss", str(seconds),
389
+ "-i", video_path,
390
+ "-frames:v", "1",
391
+ frame_path,
392
+ ]
393
+ result = subprocess.run(cmd, capture_output=True, text=True)
394
+ if result.returncode != 0:
395
+ raise gr.Error(f"Frame extraction failed:\n{result.stderr}")
396
+
397
+ return frame_path
398
+
399
+
400
+ def prepare_expand_video(video_input, use_last_generated_video, last_generated_video_path):
401
+ video_path = last_generated_video_path if use_last_generated_video and last_generated_video_path else extract_video_path(video_input)
402
+ if not video_path:
403
+ raise gr.Error("Upload a video or enable 'Use last generated video'.")
404
+
405
+ duration = get_video_duration(video_path)
406
+ max_time = max(0.1, duration)
407
+ return video_path, gr.update(maximum=max_time, value=min(max_time / 2, max_time)), f"Loaded video. Duration: {duration:.2f}s"
408
+
409
+
410
+ def extract_expand_frame(
411
+ video_input,
412
+ use_last_generated_video,
413
+ last_generated_video_path,
414
+ timestamp_seconds,
415
+ ):
416
+ video_path = last_generated_video_path if use_last_generated_video and last_generated_video_path else extract_video_path(video_input)
417
+ if not video_path:
418
+ raise gr.Error("Upload a video or enable 'Use last generated video'.")
419
+
420
+ duration = get_video_duration(video_path)
421
+ ts = max(0.0, min(float(timestamp_seconds), max(duration - 0.01, 0.0)))
422
+ frame_path = extract_frame_from_video(video_path, ts)
423
+ return frame_path, frame_path, f"Extracted frame at {ts:.2f}s"
424
+
425
+
426
+ def continue_video_from_frame(
427
+ api_key,
428
+ model,
429
+ prompt,
430
+ extracted_frame_path,
431
+ duration,
432
+ aspect_ratio,
433
+ resolution,
434
+ timeout_seconds,
435
+ poll_interval,
436
+ progress=gr.Progress(track_tqdm=False),
437
+ ):
438
+ if not extracted_frame_path:
439
+ raise gr.Error("Extract a frame first.")
440
+ return generate_i2v(
441
+ api_key=api_key,
442
+ model=model,
443
+ prompt=prompt,
444
+ uploaded_image_path=extracted_frame_path,
445
+ use_last_t2i_image=False,
446
+ last_t2i_first_image=None,
447
+ duration=duration,
448
+ aspect_ratio=aspect_ratio,
449
+ resolution=resolution,
450
+ timeout_seconds=timeout_seconds,
451
+ poll_interval=poll_interval,
452
+ progress=progress,
453
+ )
454
+
455
+
456
  css_path = Path("style.css")
457
  css = css_path.read_text(encoding="utf-8") if css_path.exists() else ""
458
 
 
464
 
465
  t2i_first_image_state = gr.State(None)
466
  t2i_all_images_state = gr.State([])
467
+ last_generated_video_state = gr.State(None)
468
+ expand_source_video_state = gr.State(None)
469
+ expand_frame_state = gr.State(None)
470
 
471
  with gr.Row():
472
  with gr.Column(scale=2):
 
528
  with gr.Tab("Image → Image"):
529
  with gr.Row():
530
  with gr.Column():
531
+ i2i_input = gr.File(
532
+ label="Upload Source Image(s)",
533
+ file_count="multiple",
534
+ file_types=["image"],
535
+ type="filepath",
536
+ )
537
  i2i_prompt = gr.Textbox(label="Transformation Prompt", lines=6)
538
  i2i_aspect = gr.Dropdown(label="Aspect Ratio Override", choices=IMAGE_ASPECT_RATIOS, value="auto")
539
  i2i_btn = gr.Button("Generate I2I", variant="primary")
 
546
  with gr.Tab("Image Edit"):
547
  with gr.Row():
548
  with gr.Column():
549
+ edit_input = gr.File(
550
+ label="Upload Image(s)",
551
+ file_count="multiple",
552
+ file_types=["image"],
553
+ type="filepath",
554
+ )
555
  edit_prompt = gr.Textbox(label="Edit Prompt", lines=6)
556
  edit_aspect = gr.Dropdown(label="Aspect Ratio Override", choices=IMAGE_ASPECT_RATIOS, value="auto")
557
  edit_btn = gr.Button("Edit Image", variant="primary")
 
590
  v2v_video_out = gr.Video(label="Generated V2V Video")
591
  v2v_download = gr.File(label="Download V2V Video")
592
 
593
+ with gr.Tab("Video Expand"):
594
+ with gr.Row():
595
+ with gr.Column():
596
+ expand_video_input = gr.Video(label="Upload Source Video")
597
+ use_last_generated_video = gr.Checkbox(label="Use last generated video", value=True)
598
+ expand_load_btn = gr.Button("Load Video", variant="secondary")
599
+ expand_video_status = gr.Textbox(label="Video Status", interactive=False, lines=3)
600
+
601
+ expand_timestamp = gr.Slider(
602
+ label="Frame timestamp (seconds)",
603
+ minimum=0,
604
+ maximum=10,
605
+ step=0.1,
606
+ value=0,
607
+ )
608
+ expand_extract_btn = gr.Button("Extract Frame", variant="secondary")
609
+ expand_frame_status = gr.Textbox(label="Frame Status", interactive=False, lines=3)
610
+
611
+ expand_prompt = gr.Textbox(label="Continuation Prompt", lines=6)
612
+ expand_duration = gr.Slider(label="Next Segment Duration", minimum=1, maximum=15, step=1, value=5)
613
+ expand_aspect = gr.Dropdown(label="Aspect Ratio", choices=VIDEO_ASPECT_RATIOS, value="16:9")
614
+ expand_resolution = gr.Dropdown(label="Resolution", choices=VIDEO_RESOLUTIONS, value="480p")
615
+ expand_btn = gr.Button("Generate Next Video Segment", variant="primary")
616
+ expand_status = gr.Textbox(label="Expand Status", interactive=False, lines=5)
617
+
618
+ with gr.Column():
619
+ expand_frame_preview = gr.Image(label="Extracted Frame", type="filepath")
620
+ expand_frame_download = gr.File(label="Download Extracted Frame")
621
+ expand_video_out = gr.Video(label="Expanded Video")
622
+ expand_video_download = gr.File(label="Download Expanded Video")
623
+
624
  t2i_btn.click(
625
  fn=generate_t2i,
626
  inputs=[api_key, image_model, t2i_prompt, t2i_n, t2i_aspect, t2i_resolution],
 
664
  ],
665
  outputs=[i2v_video, i2v_download, i2v_status],
666
  api_name=False,
667
+ ).then(
668
+ fn=lambda p: p,
669
+ inputs=[i2v_video],
670
+ outputs=[last_generated_video_state],
671
+ api_name=False,
672
  )
673
 
674
  v2v_btn.click(
 
676
  inputs=[api_key, video_model, v2v_prompt, v2v_video_in, poll_timeout, poll_interval],
677
  outputs=[v2v_video_out, v2v_download, v2v_status],
678
  api_name=False,
679
+ ).then(
680
+ fn=lambda p: p,
681
+ inputs=[v2v_video_out],
682
+ outputs=[last_generated_video_state],
683
+ api_name=False,
684
+ )
685
+
686
+ expand_load_btn.click(
687
+ fn=prepare_expand_video,
688
+ inputs=[expand_video_input, use_last_generated_video, last_generated_video_state],
689
+ outputs=[expand_source_video_state, expand_timestamp, expand_video_status],
690
+ api_name=False,
691
+ )
692
+
693
+ expand_extract_btn.click(
694
+ fn=extract_expand_frame,
695
+ inputs=[expand_video_input, use_last_generated_video, last_generated_video_state, expand_timestamp],
696
+ outputs=[expand_frame_preview, expand_frame_download, expand_frame_status],
697
+ api_name=False,
698
+ ).then(
699
+ fn=lambda p: p,
700
+ inputs=[expand_frame_preview],
701
+ outputs=[expand_frame_state],
702
+ api_name=False,
703
+ )
704
+
705
+ expand_btn.click(
706
+ fn=continue_video_from_frame,
707
+ inputs=[
708
+ api_key,
709
+ video_model,
710
+ expand_prompt,
711
+ expand_frame_state,
712
+ expand_duration,
713
+ expand_aspect,
714
+ expand_resolution,
715
+ poll_timeout,
716
+ poll_interval,
717
+ ],
718
+ outputs=[expand_video_out, expand_video_download, expand_status],
719
+ api_name=False,
720
+ ).then(
721
+ fn=lambda p: p,
722
+ inputs=[expand_video_out],
723
+ outputs=[last_generated_video_state],
724
+ api_name=False,
725
  )
726
 
727
  if __name__ == "__main__":