Update app.py

app.py CHANGED
@@ -108,11 +108,9 @@ def generate_image(prompt: str):
 # @spaces.GPU(duration=300, gpu_type="l40s")
 def infer(prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, frames=64):
     try:
-        # generate the image
         image_path = generate_image(prompt)
         image = torchvision.io.read_image(image_path).float() / 255.0
 
-        # detect Korean input and translate
         if any('\u3131' <= char <= '\u318E' or '\uAC00' <= char <= '\uD7A3' for char in prompt):
             translated = translator(prompt, max_length=512)
             prompt = translated[0]['translation_text']
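For context, the Unicode ranges in the `if` test cover the Hangul Compatibility Jamo block (U+3131 to U+318E) and the precomposed Hangul Syllables block (U+AC00 to U+D7A3). A minimal standalone sketch of the same check:

    def contains_korean(text: str) -> bool:
        # True if any character falls in the Hangul Jamo or syllable ranges
        return any('\u3131' <= ch <= '\u318E' or '\uAC00' <= ch <= '\uD7A3' for ch in text)

    assert contains_korean('우주인 복장으로 기타를 치는 남자')  # Korean prompt: gets translated
    assert not contains_korean('a man playing guitar')            # English passes through as-is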
@@ -120,9 +118,7 @@ def infer(prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, frames=64):
         resolution = (576, 1024)
         save_fps = 8
         seed_everything(seed)
-        transform = transforms.Compose([
-            transforms.Resize(resolution, antialias=True),
-        ])
+        transform = transforms.Compose([transforms.Resize(resolution, antialias=True)])
 
         print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
         start = time.time()
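The change in this hunk is purely stylistic: the three-line `Compose` collapses to one line. Worth noting that `Resize` with an `(H, W)` tuple scales to exactly 576x1024 regardless of the source aspect ratio; a quick sketch, assuming only that torchvision is installed:

    import torch
    from torchvision import transforms

    transform = transforms.Compose([transforms.Resize((576, 1024), antialias=True)])
    dummy = torch.rand(3, 720, 1280)   # CHW float image in [0, 1]
    print(transform(dummy).shape)      # torch.Size([3, 576, 1024])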
@@ -130,30 +126,30 @@ def infer(prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, frames=64):
         steps = 60
 
         batch_size = 1
-        channels = model.module.diffusion_model.out_channels
+        channels = model.diffusion_model.out_channels  # model.module removed
         h, w = resolution[0] // 8, resolution[1] // 8
         noise_shape = [batch_size, channels, frames, h, w]
 
         with torch.no_grad(), torch.cuda.amp.autocast():
-            text_emb = model.module.get_learned_conditioning([prompt])
+            text_emb = model.get_learned_conditioning([prompt])  # model.module removed
 
             img_tensor = image.to(torch.cuda.current_device())
             img_tensor = (img_tensor - 0.5) * 2
             image_tensor_resized = transform(img_tensor)
             videos = image_tensor_resized.unsqueeze(0)
 
-            z = get_latent_z(model.module, videos.unsqueeze(2))
+            z = get_latent_z(model, videos.unsqueeze(2))  # model.module removed
             img_tensor_repeat = repeat(z, 'b c t h w -> b c (repeat t) h w', repeat=frames)
 
-            cond_images = model.module.embedder(img_tensor.unsqueeze(0))
-            img_emb = model.module.image_proj_model(cond_images)
+            cond_images = model.embedder(img_tensor.unsqueeze(0))  # model.module removed
+            img_emb = model.image_proj_model(cond_images)  # model.module removed
 
             imtext_cond = torch.cat([text_emb, img_emb], dim=1)
 
             fs = torch.tensor([fs], dtype=torch.long, device=torch.cuda.current_device())
             cond = {"c_crossattn": [imtext_cond], "fs": fs, "c_concat": [img_tensor_repeat]}
 
-            batch_samples = batch_ddim_sampling(model.module, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale)
+            batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale)  # model.module removed
 
         video_path = './output.mp4'
         save_videos(batch_samples, './', filenames=['output'], fps=save_fps)
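Every `+` line in this hunk makes the same fix: dropping the `.module` indirection, which exists only while the model is wrapped in `torch.nn.DataParallel` (or DDP). A bare model exposes `diffusion_model`, `embedder`, and the rest directly. A hypothetical guard that tolerates both cases (the `unwrap` helper is not part of this app):

    import torch.nn as nn

    def unwrap(model):
        # DataParallel/DDP hide the real model behind .module; bare models do not
        if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
            return model.module
        return model

    # channels = unwrap(model).diffusion_model.out_channels

As for the sampling shape: with `resolution = (576, 1024)` and the VAE's 8x spatial downsampling, `h` and `w` come out to 72 and 128, so `noise_shape` is `[1, channels, 64, 72, 128]` at the default `frames=64`.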
@@ -168,7 +164,8 @@ def infer(prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, frames=64):
         print(f"Error occurred: {e}")
         return None
     finally:
-        torch.cuda.empty_cache()
+        torch.cuda.empty_cache()
+
 
 i2v_examples = [
     ['우주인 복장으로 기타를 치는 남자', 30, 7.5, 1.0, 6, 123, 64],
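One more detail from the sampling hunk above: the `repeat` call tiles the single-frame image latent `z` along the time axis, so all `frames` positions start from the same image latent. A standalone sketch (the 4-channel latent is an assumed size; 72x128 follows from the 8x downsampling of 576x1024):

    import torch
    from einops import repeat

    z = torch.randn(1, 4, 1, 72, 128)   # b c t h w: one frame's latent (4 channels assumed)
    tiled = repeat(z, 'b c t h w -> b c (repeat t) h w', repeat=64)
    print(tiled.shape)                   # torch.Size([1, 4, 64, 72, 128])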