Fabrice-TIERCELIN committed on
Commit
6d0c1b7
·
verified ·
1 Parent(s): e45b74a

Merge code

Browse files
Files changed (1) hide show
  1. app_v2v.py +279 -0
app_v2v.py CHANGED
@@ -296,6 +296,285 @@ def set_mp4_comments_imageio_ffmpeg(input_file, comments):
296
  print(f"Error saving prompt to video metadata, ffmpeg may be required: "+str(e))
297
  return False
298
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
  # 20250506 pftq: Modified worker to accept video input and clean frame count
300
  @spaces.GPU()
301
  @torch.no_grad()
 
296
  print(f"Error saving prompt to video metadata, ffmpeg may be required: "+str(e))
297
  return False
298
 
@torch.no_grad()
def worker(input_image, prompts, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf):
    """Run the image-to-video generation pipeline in a background thread.

    Stages: text encoding -> input-image preprocessing -> VAE encoding ->
    CLIP-Vision encoding -> per-section latent sampling -> VAE decoding ->
    MP4 export. Progress updates, intermediate video files and the final
    'end' marker are pushed through the module-level `stream.output_queue`;
    user cancellation requests are read from `stream.input_queue`.

    Args:
        input_image: HxWxC image array used as the start frame.
        prompts: list of prompt strings; one entry is consumed per latent
            section, and the last consumed entry stays in effect once the
            list is exhausted.
        n_prompt: negative prompt text (not encoded when cfg == 1).
        seed: seed for the CPU torch.Generator driving sampling.
        total_second_length: requested video length in seconds (30 FPS).
        latent_window_size: latent frames sampled per section.
        steps: diffusion steps per section.
        cfg: real guidance scale (1 disables negative-prompt encoding).
        gs: distilled guidance scale.
        rs: guidance rescale factor.
        gpu_memory_preservation: GB of VRAM to keep free when moving the
            transformer to the GPU (low-VRAM mode only).
        use_teacache: whether to enable the transformer's TeaCache.
        mp4_crf: CRF quality value passed to the MP4 writer.

    Returns:
        None. Results are delivered via `stream.output_queue` as
        ('progress', ...), ('file', path) and a final ('end', None).
    """

    def encode_prompt(prompt, n_prompt):
        # Encode one positive/negative prompt pair with both text encoders,
        # pad/crop to a fixed 512-token length, and cast everything to the
        # transformer's dtype.
        llama_vec, clip_l_pooler = encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)

        if cfg == 1:
            # CFG disabled: negative embeddings are placeholders, so
            # zero tensors of the matching shapes are sufficient.
            llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vec), torch.zeros_like(clip_l_pooler)
        else:
            llama_vec_n, clip_l_pooler_n = encode_prompt_conds(n_prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)

        llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)
        llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n, length=512)

        llama_vec = llama_vec.to(transformer.dtype)
        llama_vec_n = llama_vec_n.to(transformer.dtype)
        clip_l_pooler = clip_l_pooler.to(transformer.dtype)
        clip_l_pooler_n = clip_l_pooler_n.to(transformer.dtype)
        return [llama_vec, clip_l_pooler, llama_vec_n, clip_l_pooler_n, llama_attention_mask, llama_attention_mask_n]

    # Each section contributes roughly latent_window_size * 4 frames at 30 FPS.
    total_latent_sections = (total_second_length * 30) / (latent_window_size * 4)
    total_latent_sections = int(max(round(total_latent_sections), 1))

    job_id = generate_timestamp()

    stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Starting ...'))))

    try:
        # Clean GPU
        if not high_vram:
            unload_complete_models(
                text_encoder, text_encoder_2, image_encoder, vae, transformer
            )

        # Text encoding

        stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Text encoding ...'))))

        if not high_vram:
            fake_diffusers_current_device(text_encoder, gpu)  # since we only encode one text - that is one model move and one encode, offload is same time consumption since it is also one load and one encode.
            load_model_as_complete(text_encoder_2, target_device=gpu)

        # Pre-encode every prompt up front; one entry is popped per section
        # in the sampling loop below.
        prompt_parameters = []

        for prompt_part in prompts:
            prompt_parameters.append(encode_prompt(prompt_part, n_prompt))

        # Processing input image

        stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Image processing ...'))))

        H, W, C = input_image.shape
        # Snap to the nearest supported resolution bucket, then center-crop.
        height, width = find_nearest_bucket(H, W, resolution=640)
        input_image_np = resize_and_center_crop(input_image, target_width=width, target_height=height)

        # Keep a copy of the processed start frame next to the outputs.
        Image.fromarray(input_image_np).save(os.path.join(outputs_folder, f'{job_id}.png'))

        # uint8 [0, 255] -> float [-1, 1]; (H, W, C) -> (1, C, 1, H, W).
        input_image_pt = torch.from_numpy(input_image_np).float() / 127.5 - 1
        input_image_pt = input_image_pt.permute(2, 0, 1)[None, :, None]

        # VAE encoding

        stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'VAE encoding ...'))))

        if not high_vram:
            load_model_as_complete(vae, target_device=gpu)

        start_latent = vae_encode(input_image_pt, vae)

        # CLIP Vision

        stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'CLIP Vision encoding ...'))))

        if not high_vram:
            load_model_as_complete(image_encoder, target_device=gpu)

        image_encoder_output = hf_clip_vision_encode(input_image_np, feature_extractor, image_encoder)
        image_encoder_last_hidden_state = image_encoder_output.last_hidden_state

        # Dtype

        image_encoder_last_hidden_state = image_encoder_last_hidden_state.to(transformer.dtype)

        # Sampling

        stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Start sampling ...'))))

        rnd = torch.Generator("cpu").manual_seed(seed)

        # History starts with 16 + 2 + 1 zero "context" latent frames; these
        # slots are consumed below as the 4x / 2x / 1x clean-latent conditioning.
        history_latents = torch.zeros(size=(1, 16, 16 + 2 + 1, height // 8, width // 8), dtype=torch.float32).cpu()
        history_pixels = None

        history_latents = torch.cat([history_latents, start_latent.to(history_latents)], dim=2)
        total_generated_latent_frames = 1

        for section_index in range(total_latent_sections):
            # Honour a pending user cancellation between sections.
            if stream.input_queue.top() == 'end':
                stream.output_queue.push(('end', None))
                return

            print(f'section_index = {section_index}, total_latent_sections = {total_latent_sections}')

            # Advance to the next pre-encoded prompt if one is left; otherwise
            # the previous section's conditioning tensors stay in effect.
            if len(prompt_parameters) > 0:
                [llama_vec, clip_l_pooler, llama_vec_n, clip_l_pooler_n, llama_attention_mask, llama_attention_mask_n] = prompt_parameters.pop(0)

            if not high_vram:
                unload_complete_models()
                move_model_to_device_with_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=gpu_memory_preservation)

            if use_teacache:
                transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
            else:
                transformer.initialize_teacache(enable_teacache=False)

            def callback(d):
                # Per-step sampling callback: publish a preview strip and
                # progress text; abort sampling via KeyboardInterrupt when
                # the user pressed "end".
                preview = d['denoised']
                preview = vae_decode_fake(preview)

                preview = (preview * 255.0).detach().cpu().numpy().clip(0, 255).astype(np.uint8)
                preview = einops.rearrange(preview, 'b c t h w -> (b h) (t w) c')

                if stream.input_queue.top() == 'end':
                    stream.output_queue.push(('end', None))
                    raise KeyboardInterrupt('User ends the task.')

                current_step = d['i'] + 1
                percentage = int(100.0 * current_step / steps)
                hint = f'Sampling {current_step}/{steps}'
                desc = f'Total generated frames: {int(max(0, total_generated_latent_frames * 4 - 3))}, Video length: {max(0, (total_generated_latent_frames * 4 - 3) / 30) :.2f} seconds (FPS-30). The video is being extended now ...'
                stream.output_queue.push(('progress', (preview, desc, make_progress_bar_html(percentage, hint))))
                return

            # Index layout along the latent time axis:
            # [start frame | 16 x 4x-context | 2 x 2x-context | 1 x 1x-context | new window].
            indices = torch.arange(0, sum([1, 16, 2, 1, latent_window_size])).unsqueeze(0)
            clean_latent_indices_start, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices = indices.split([1, 16, 2, 1, latent_window_size], dim=1)
            clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)

            # Conditioning comes from the most recent 16 + 2 + 1 history frames.
            clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents[:, :, -sum([16, 2, 1]):, :, :].split([16, 2, 1], dim=2)
            clean_latents = torch.cat([start_latent.to(history_latents), clean_latents_1x], dim=2)

            generated_latents = sample_hunyuan(
                transformer=transformer,
                sampler='unipc',
                width=width,
                height=height,
                frames=latent_window_size * 4 - 3,
                real_guidance_scale=cfg,
                distilled_guidance_scale=gs,
                guidance_rescale=rs,
                # shift=3.0,
                num_inference_steps=steps,
                generator=rnd,
                prompt_embeds=llama_vec,
                prompt_embeds_mask=llama_attention_mask,
                prompt_poolers=clip_l_pooler,
                negative_prompt_embeds=llama_vec_n,
                negative_prompt_embeds_mask=llama_attention_mask_n,
                negative_prompt_poolers=clip_l_pooler_n,
                device=gpu,
                dtype=torch.bfloat16,
                image_embeddings=image_encoder_last_hidden_state,
                latent_indices=latent_indices,
                clean_latents=clean_latents,
                clean_latent_indices=clean_latent_indices,
                clean_latents_2x=clean_latents_2x,
                clean_latent_2x_indices=clean_latent_2x_indices,
                clean_latents_4x=clean_latents_4x,
                clean_latent_4x_indices=clean_latent_4x_indices,
                callback=callback,
            )

            total_generated_latent_frames += int(generated_latents.shape[2])
            history_latents = torch.cat([history_latents, generated_latents.to(history_latents)], dim=2)

            if not high_vram:
                offload_model_from_device_for_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=8)
                load_model_as_complete(vae, target_device=gpu)

            real_history_latents = history_latents[:, :, -total_generated_latent_frames:, :, :]

            if history_pixels is None:
                # First section: decode everything generated so far.
                history_pixels = vae_decode(real_history_latents, vae).cpu()
            else:
                # Later sections: decode only the tail of the latents and
                # soft-blend it onto the existing pixels across the overlap.
                section_latent_frames = latent_window_size * 2
                overlapped_frames = latent_window_size * 4 - 3

                current_pixels = vae_decode(real_history_latents[:, :, -section_latent_frames:], vae).cpu()
                history_pixels = soft_append_bcthw(history_pixels, current_pixels, overlapped_frames)

            if not high_vram:
                unload_complete_models()

            # Write a new MP4 per section (filename embeds the frame count)
            # so the UI can display partial results while generation continues.
            output_filename = os.path.join(outputs_folder, f'{job_id}_{total_generated_latent_frames}.mp4')

            save_bcthw_as_mp4(history_pixels, output_filename, fps=30, crf=mp4_crf)

            print(f'Decoded. Current latent shape {real_history_latents.shape}; pixel shape {history_pixels.shape}')

            stream.output_queue.push(('file', output_filename))
    except:  # noqa: E722 -- deliberately broad: also catches the KeyboardInterrupt raised by callback() on user cancel
        traceback.print_exc()

        if not high_vram:
            unload_complete_models(
                text_encoder, text_encoder_2, image_encoder, vae, transformer
            )

    # Always signal completion so the consumer loop in process() terminates.
    stream.output_queue.push(('end', None))
    return
506
+
def get_duration(input_image, prompt, t2v, n_prompt, randomize_seed, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf):
    """Return the GPU time budget, in seconds, for one `process` call.

    Used as the `duration` callback of `@spaces.GPU`, so it receives the
    same arguments as `process`, but only the requested video length
    matters here. A pending debug override takes precedence and is capped
    at 600 seconds.
    """
    global total_second_length_debug_value

    if total_second_length_debug_value is None:
        # Normal path: budget one minute of GPU time per second of video.
        return total_second_length * 60
    return min(total_second_length_debug_value * 60, 600)
513
+
514
+
@spaces.GPU(duration=get_duration)
def process(input_image, prompt,
            t2v=False,
            n_prompt="",
            randomize_seed=True,
            seed=31337,
            total_second_length=5,
            latent_window_size=9,
            steps=25,
            cfg=1.0,
            gs=10.0,
            rs=0.0,
            gpu_memory_preservation=6,
            use_teacache=True,
            mp4_crf=16
            ):
    """Gradio entry point: prepare the inputs, launch `worker` asynchronously
    and stream its events back to the UI.

    Yields 6-tuples matching the wired outputs: (result video, preview image,
    description text, progress HTML, start-button update, end-button update).
    """
    global stream, input_image_debug_value, prompt_debug_value, total_second_length_debug_value

    # No GPU configured on this Space: warn and bail out immediately.
    if torch.cuda.device_count() == 0:
        gr.Warning('Set this space to GPU config to make it work.')
        return None, None, None, None, None, None

    # One-shot debug overrides: consume them and reset them to None.
    debug_active = (input_image_debug_value is not None
                    or prompt_debug_value is not None
                    or total_second_length_debug_value is not None)
    if debug_active:
        print("Debug mode")
        input_image = input_image_debug_value
        prompt = prompt_debug_value
        total_second_length = total_second_length_debug_value
        input_image_debug_value = prompt_debug_value = total_second_length_debug_value = None

    if randomize_seed:
        seed = random.randint(0, np.iinfo(np.int32).max)

    # Multiple prompts are separated by ';' and applied section by section.
    prompts = prompt.split(";")

    # assert input_image is not None, 'No input image!'
    if t2v:
        default_height, default_width = 640, 640
        input_image = np.ones((default_height, default_width, 3), dtype=np.uint8) * 255
        print("No input image provided. Using a blank white image.")

    # Initial UI state: clear outputs, lock the start button, arm the end button.
    yield None, None, '', '', gr.update(interactive=False), gr.update(interactive=True)

    stream = AsyncStream()

    async_run(worker, input_image, prompts, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf)

    output_filename = None

    # Pump worker events until the 'end' sentinel arrives.
    while True:
        flag, payload = stream.output_queue.next()

        if flag == 'file':
            output_filename = payload
            yield output_filename, gr.update(), gr.update(), gr.update(), gr.update(interactive=False), gr.update(interactive=True)
        elif flag == 'progress':
            preview, desc, html = payload
            yield gr.update(), gr.update(visible=True, value=preview), desc, html, gr.update(interactive=False), gr.update(interactive=True)
        elif flag == 'end':
            yield output_filename, gr.update(visible=False), gr.update(), '', gr.update(interactive=True), gr.update(interactive=False)
            break
577
+
578
  # 20250506 pftq: Modified worker to accept video input and clean frame count
579
  @spaces.GPU()
580
  @torch.no_grad()