aiqtech committed on
Commit 8275889 · verified
1 Parent(s): 5e12b00

Update app.py

Files changed (1)
  1. app.py +31 -440
app.py CHANGED
@@ -1,444 +1,35 @@
- import gradio as gr
  import os
- import torch
- import argparse
- import torchvision
-
- # Disable all automatic translation and model downloading BEFORE any imports
- os.environ['TRANSFORMERS_OFFLINE'] = '1'
- os.environ['HF_DATASETS_OFFLINE'] = '1'
- os.environ['TOKENIZERS_PARALLELISM'] = 'false'
- os.environ['GRADIO_ANALYTICS_ENABLED'] = 'false'
- # Disable translation specifically
- os.environ['GRADIO_TRANSLATION_ENABLED'] = 'false'
- os.environ['GRADIO_ALLOW_FLAGGING'] = 'never'
-
- from pipelines.pipeline_videogen import VideoGenPipeline
- from diffusers.schedulers import DDIMScheduler
- from diffusers.models import AutoencoderKL
- from diffusers.models import AutoencoderKLTemporalDecoder
- from transformers import CLIPTokenizer, CLIPTextModel
- from omegaconf import OmegaConf
-
  import sys
- sys.path.append(os.path.split(sys.path[0])[0])
- from models import get_models
- import imageio
- from PIL import Image
- import numpy as np
- from datasets import video_transforms
- from torchvision import transforms
- from einops import rearrange, repeat
- from utils import dct_low_pass_filter, exchanged_mixed_dct_freq
- from copy import deepcopy
- import spaces
- import requests
- from datetime import datetime
- import random
-
- parser = argparse.ArgumentParser()
- parser.add_argument("--config", type=str, default="./configs/sample.yaml")
- args = parser.parse_args()
- args = OmegaConf.load(args.config)
-
- torch.set_grad_enabled(False)
- device = "cuda" if torch.cuda.is_available() else "cpu"
- dtype = torch.float16
-
- # Load models
- unet = get_models(args).to(device, dtype=dtype)
-
- if args.enable_vae_temporal_decoder:
-     if args.use_dct:
-         vae_for_base_content = AutoencoderKLTemporalDecoder.from_pretrained(
-             args.pretrained_model_path,
-             subfolder="vae_temporal_decoder",
-             torch_dtype=torch.float64
-         ).to(device)
-     else:
-         vae_for_base_content = AutoencoderKLTemporalDecoder.from_pretrained(
-             args.pretrained_model_path,
-             subfolder="vae_temporal_decoder",
-             torch_dtype=torch.float16
-         ).to(device)
-     vae = deepcopy(vae_for_base_content).to(dtype=dtype)
- else:
-     vae_for_base_content = AutoencoderKL.from_pretrained(
-         args.pretrained_model_path,
-         subfolder="vae"
-     ).to(device, dtype=torch.float64)
-     vae = deepcopy(vae_for_base_content).to(dtype=dtype)
-
- tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
- text_encoder = CLIPTextModel.from_pretrained(
-     args.pretrained_model_path,
-     subfolder="text_encoder",
-     torch_dtype=dtype
- ).to(device)
-
- # Set eval mode
- unet.eval()
- vae.eval()
- text_encoder.eval()
-
- # Setup directories
- basedir = os.getcwd()
- savedir = os.path.join(basedir, "samples/Gradio", datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))
- savedir_sample = os.path.join(savedir, "sample")
- os.makedirs(savedir, exist_ok=True)
-
- def update_and_resize_image(input_image_path, height_slider, width_slider):
-     """Update and resize input image to match specified dimensions."""
-     if input_image_path.startswith("http://") or input_image_path.startswith("https://"):
-         pil_image = Image.open(requests.get(input_image_path, stream=True).raw).convert('RGB')
-     else:
-         pil_image = Image.open(input_image_path).convert('RGB')
-
-     original_width, original_height = pil_image.size
-
-     if original_height == height_slider and original_width == width_slider:
-         return gr.Image(value=np.array(pil_image))
-
-     ratio1 = height_slider / original_height
-     ratio2 = width_slider / original_width
-
-     if ratio1 > ratio2:
-         new_width = int(original_width * ratio1)
-         new_height = int(original_height * ratio1)
-     else:
-         new_width = int(original_width * ratio2)
-         new_height = int(original_height * ratio2)
-
-     pil_image = pil_image.resize((new_width, new_height), Image.LANCZOS)
-
-     left = (new_width - width_slider) / 2
-     top = (new_height - height_slider) / 2
-     right = left + width_slider
-     bottom = top + height_slider
-
-     pil_image = pil_image.crop((left, top, right, bottom))
-
-     return gr.Image(value=np.array(pil_image))
-
- def update_textbox_and_save_image(input_image, height_slider, width_slider):
-     """Process uploaded image and save to disk."""
-     pil_image = Image.fromarray(input_image.astype(np.uint8)).convert("RGB")
-
-     original_width, original_height = pil_image.size
-
-     ratio1 = height_slider / original_height
-     ratio2 = width_slider / original_width
-
-     if ratio1 > ratio2:
-         new_width = int(original_width * ratio1)
-         new_height = int(original_height * ratio1)
-     else:
-         new_width = int(original_width * ratio2)
-         new_height = int(original_height * ratio2)
-
-     pil_image = pil_image.resize((new_width, new_height), Image.LANCZOS)
-
-     left = (new_width - width_slider) / 2
-     top = (new_height - height_slider) / 2
-     right = left + width_slider
-     bottom = top + height_slider
-
-     pil_image = pil_image.crop((left, top, right, bottom))
-
-     img_path = os.path.join(savedir, "input_image.png")
-     pil_image.save(img_path)
-
-     return gr.Textbox(value=img_path), gr.Image(value=np.array(pil_image))
-
- def prepare_image(image, vae, transform_video, device, dtype=torch.float16):
-     """Prepare image for video generation pipeline."""
-     image = torch.as_tensor(np.array(image, dtype=np.uint8, copy=True)).unsqueeze(0).permute(0, 3, 1, 2)
-     image = transform_video(image)
-     image = vae.encode(image.to(dtype=dtype, device=device)).latent_dist.sample().mul_(vae.config.scaling_factor)
-     image = image.unsqueeze(2)
-     return image
-
- @spaces.GPU
- def gen_video(input_image, prompt, negative_prompt, diffusion_step, height, width, scfg_scale, use_dctinit, dct_coefficients, noise_level, motion_bucket_id, seed):
-     """Generate video from input image and prompt."""
-
-     torch.manual_seed(seed)
-
-     scheduler = DDIMScheduler.from_pretrained(
-         args.pretrained_model_path,
-         subfolder="scheduler",
-         beta_start=args.beta_start,
-         beta_end=args.beta_end,
-         beta_schedule=args.beta_schedule
-     )
-
-     videogen_pipeline = VideoGenPipeline(
-         vae=vae,
-         text_encoder=text_encoder,
-         tokenizer=tokenizer,
-         scheduler=scheduler,
-         unet=unet
-     ).to(device)
-
-     transform_video = transforms.Compose([
-         video_transforms.ToTensorVideo(),
-         transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
-     ])
-
-     if args.use_dct:
-         base_content = prepare_image(input_image, vae_for_base_content, transform_video, device, dtype=torch.float64).to(device)
-     else:
-         base_content = prepare_image(input_image, vae_for_base_content, transform_video, device, dtype=torch.float16).to(device)
-
-     if use_dctinit:
-         # Filter params
-         print("Using DCT!")
-         base_content_repeat = repeat(base_content, 'b c f h w -> b c (f r) h w', r=15).contiguous()
-
-         # Define filter
-         freq_filter = dct_low_pass_filter(dct_coefficients=base_content, percentage=dct_coefficients)
-
-         noise = torch.randn(1, 4, 15, 40, 64).to(device)
-
-         # Add noise to base_content
-         diffuse_timesteps = torch.full((1,), int(noise_level))
-         diffuse_timesteps = diffuse_timesteps.long()
-
-         # 3D content
-         base_content_noise = scheduler.add_noise(
-             original_samples=base_content_repeat.to(device),
-             noise=noise,
-             timesteps=diffuse_timesteps.to(device)
-         )
-
-         # 3D content with DCT
-         latents = exchanged_mixed_dct_freq(
-             noise=noise,
-             base_content=base_content_noise,
-             LPF_3d=freq_filter
-         ).to(dtype=torch.float16)
-     else:
-         latents = None
-
-     base_content = base_content.to(dtype=torch.float16)
-
-     videos = videogen_pipeline(
-         prompt,
-         negative_prompt=negative_prompt,
-         latents=latents,
-         base_content=base_content,
-         video_length=15,
-         height=height,
-         width=width,
-         num_inference_steps=diffusion_step,
-         guidance_scale=scfg_scale,
-         motion_bucket_id=100-motion_bucket_id,
-         enable_vae_temporal_decoder=args.enable_vae_temporal_decoder
-     ).video
-
-     save_path = args.save_img_path + 'temp' + '.mp4'
-     imageio.mimwrite(save_path, videos[0], fps=8, quality=7)
-     return save_path
-
- # Create output directory
- if not os.path.exists(args.save_img_path):
-     os.makedirs(args.save_img_path)
-
- # CSS for interface
- css = """
- footer {
-     visibility: hidden;
- }
- """
-
- # Create Gradio interface with translation disabled
- with gr.Blocks(theme="soft", css=css, analytics_enabled=False) as demo:
-     gr.Markdown("# Video Generation with DCTInit")
-     gr.Markdown("Generate videos from static images. Please use English prompts only.")
-
-     with gr.Column(variant="panel"):
-         with gr.Row():
-             prompt_textbox = gr.Textbox(
-                 label="Prompt (English only)",
-                 lines=1,
-                 placeholder="Describe the motion you want to see..."
-             )
-             negative_prompt_textbox = gr.Textbox(
-                 label="Negative prompt",
-                 lines=1,
-                 placeholder="What to avoid in the generation..."
-             )

-         with gr.Row(equal_height=False):
-             with gr.Column():
-                 with gr.Row():
-                     input_image = gr.Image(label="Input Image", interactive=True)
-                     result_video = gr.Video(label="Generated Animation", interactive=False, autoplay=True)
-
-                 generate_button = gr.Button(value="Generate", variant='primary')
-
-                 with gr.Accordion("Advanced options", open=False):
-                     with gr.Column():
-                         with gr.Row():
-                             input_image_path = gr.Textbox(
-                                 label="Input Image URL",
-                                 lines=1,
-                                 scale=10,
-                                 info="Press Enter or the Preview button to confirm the input image."
-                             )
-                             preview_button = gr.Button(value="Preview")
-
-                         with gr.Row():
-                             sample_step_slider = gr.Slider(
-                                 label="Sampling steps",
-                                 value=50,
-                                 minimum=10,
-                                 maximum=250,
-                                 step=1
-                             )
-
-                         with gr.Row():
-                             seed_textbox = gr.Slider(
-                                 label="Seed",
-                                 value=100,
-                                 minimum=1,
-                                 maximum=int(1e8),
-                                 step=1,
-                                 interactive=True
-                             )
-
-                         with gr.Row():
-                             height = gr.Slider(
-                                 label="Height",
-                                 value=320,
-                                 minimum=0,
-                                 maximum=512,
-                                 step=16,
-                                 interactive=False
-                             )
-                             width = gr.Slider(
-                                 label="Width",
-                                 value=512,
-                                 minimum=0,
-                                 maximum=512,
-                                 step=16,
-                                 interactive=False
-                             )
-
-                         with gr.Row():
-                             txt_cfg_scale = gr.Slider(
-                                 label="CFG Scale",
-                                 value=7.5,
-                                 minimum=1.0,
-                                 maximum=20.0,
-                                 step=0.1,
-                                 interactive=True
-                             )
-                             motion_bucket_id = gr.Slider(
-                                 label="Motion Intensity",
-                                 value=10,
-                                 minimum=1,
-                                 maximum=20,
-                                 step=1,
-                                 interactive=True
-                             )
-
-                         with gr.Row():
-                             use_dctinit = gr.Checkbox(label="Enable DCTInit", value=True)
-                             dct_coefficients = gr.Slider(
-                                 label="DCT Coefficients",
-                                 value=0.23,
-                                 minimum=0,
-                                 maximum=1,
-                                 step=0.01,
-                                 interactive=True
-                             )
-                             noise_level = gr.Slider(
-                                 label="Noise Level",
-                                 value=985,
-                                 minimum=1,
-                                 maximum=999,
-                                 step=1,
-                                 interactive=True
-                             )
-
-     # Event handlers
-     input_image.upload(
-         fn=update_textbox_and_save_image,
-         inputs=[input_image, height, width],
-         outputs=[input_image_path, input_image]
-     )
-
-     preview_button.click(
-         fn=update_and_resize_image,
-         inputs=[input_image_path, height, width],
-         outputs=[input_image]
-     )
-
-     input_image_path.submit(
-         fn=update_and_resize_image,
-         inputs=[input_image_path, height, width],
-         outputs=[input_image]
-     )
-
-     # Examples
-     EXAMPLES = [
-         ["./example/aircrafts_flying/0.jpg", "aircrafts flying", "low quality", 50, 320, 512, 7.5, True, 0.23, 975, 10, 100],
-         ["./example/fireworks/0.jpg", "fireworks", "low quality", 50, 320, 512, 7.5, True, 0.23, 975, 10, 100],
-         ["./example/flowers_swaying/0.jpg", "flowers swaying", "", 50, 320, 512, 7.5, True, 0.23, 975, 10, 100],
-         ["./example/girl_walking_on_the_beach/0.jpg", "girl walking on the beach", "low quality, background changing", 50, 320, 512, 7.5, True, 0.25, 995, 10, 49494220],
-         ["./example/house_rotating/0.jpg", "house rotating", "low quality", 50, 320, 512, 7.5, True, 0.23, 985, 10, 46640174],
-         ["./example/people_runing/0.jpg", "people runing", "low quality, background changing", 50, 320, 512, 7.5, True, 0.23, 975, 10, 100],
-         ["./example/shark_swimming/0.jpg", "shark swimming", "", 50, 320, 512, 7.5, True, 0.23, 975, 10, 32947978],
-         ["./example/car_moving/0.jpg", "car moving", "", 50, 320, 512, 7.5, True, 0.23, 975, 10, 75469653],
-         ["./example/windmill_turning/0.jpg", "windmill turning", "background changing", 50, 320, 512, 7.5, True, 0.21, 975, 10, 89378613],
-     ]
-
-     examples = gr.Examples(
-         examples=EXAMPLES,
-         fn=gen_video,
-         inputs=[
-             input_image,
-             prompt_textbox,
-             negative_prompt_textbox,
-             sample_step_slider,
-             height,
-             width,
-             txt_cfg_scale,
-             use_dctinit,
-             dct_coefficients,
-             noise_level,
-             motion_bucket_id,
-             seed_textbox
-         ],
-         outputs=[result_video],
-         cache_examples=False,  # Changed from "lazy" to False to avoid caching issues
-     )
-
-     generate_button.click(
-         fn=gen_video,
-         inputs=[
-             input_image,
-             prompt_textbox,
-             negative_prompt_textbox,
-             sample_step_slider,
-             height,
-             width,
-             txt_cfg_scale,
-             use_dctinit,
-             dct_coefficients,
-             noise_level,
-             motion_bucket_id,
-             seed_textbox,
-         ],
-         outputs=[result_video]
-     )

- # Launch the interface with analytics disabled
- demo.launch(
-     debug=False,
-     share=True,
-     server_name="127.0.0.1",
-     analytics_enabled=False,
-     enable_queue=True
- )
  import os
  import sys
+ import streamlit as st
+ from tempfile import NamedTemporaryFile
+
+ def main():
+     try:
+         # Get the code from secrets
+         code = os.environ.get("MAIN_CODE")
+
+         if not code:
+             st.error("⚠️ The application code wasn't found in secrets. Please add the MAIN_CODE secret.")
+             return
+
+         # Create a temporary Python file
+         with NamedTemporaryFile(suffix='.py', delete=False, mode='w') as tmp:
+             tmp.write(code)
+             tmp_path = tmp.name
+
+         # Execute the code
+         exec(compile(code, tmp_path, 'exec'), globals())
+
+         # Clean up the temporary file
+         try:
+             os.unlink(tmp_path)
+         except:
+             pass

+     except Exception as e:
+         st.error(f"⚠️ Error loading or executing the application: {str(e)}")
+         import traceback
+         st.code(traceback.format_exc())

+ if __name__ == "__main__":
+     main()
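
A minimal local sketch of how the new loader can be exercised, assuming the file above is saved as app.py and the real application source is supplied through a MAIN_CODE environment variable (on a Hugging Face Space this would come from a repository secret). The file name smoke_test.py and the one-line payload are illustrative only, not part of the commit:

# smoke_test.py — hypothetical local check for the MAIN_CODE loader (not part of this commit)
import os

# Stand-in payload; on the Space the real app code is injected via the MAIN_CODE secret.
os.environ["MAIN_CODE"] = 'import streamlit as st\nst.write("MAIN_CODE payload loaded")'

import app   # imports the loader above; the __main__ guard keeps it from running on import
app.main()   # routes the payload through the same compile()/exec() path used on the Space

Launching it with "streamlit run smoke_test.py" gives the payload's st.* calls a live session to render into; run as a plain Python script, Streamlit typically only logs missing-session warnings instead of drawing the page.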