Moon11111 committed
Commit e3b1be4 · verified · 1 Parent(s): 4fa3412

Create app.py

Files changed (1): app.py (+134 -0)
app.py ADDED
import tempfile
import os

from flask import Flask, request, jsonify
from omegaconf import OmegaConf
import torch
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.utils.import_utils import is_xformers_available
from accelerate.utils import set_seed

from latentsync.models.unet import UNet3DConditionModel
from latentsync.pipelines.lipsync_pipeline import LipsyncPipeline
from latentsync.whisper.audio2feature import Audio2Feature

# Initialize the Flask app
app = Flask(__name__)

def run_inference(video_path, audio_path, video_out_path,
                  inference_ckpt_path, unet_config_path="configs/unet/second_stage.yaml",
                  inference_steps=20, guidance_scale=1.0, seed=1247):
    # Load the UNet configuration
    config = OmegaConf.load(unet_config_path)

    # Pick the Whisper checkpoint that matches the UNet's cross-attention width
    if config.model.cross_attention_dim == 768:
        whisper_model_path = "checkpoints/whisper/small.pt"
    elif config.model.cross_attention_dim == 384:
        whisper_model_path = "checkpoints/whisper/tiny.pt"
    else:
        raise NotImplementedError("cross_attention_dim must be 768 or 384")

    # Use fp16 only on GPUs with compute capability above 7.x (Ampere or newer)
    is_fp16_supported = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] > 7
    dtype = torch.float16 if is_fp16_supported else torch.float32

    # Set up the DDIM scheduler from the scheduler config in configs/
    scheduler = DDIMScheduler.from_pretrained("configs")

    # Initialize the audio encoder
    audio_encoder = Audio2Feature(model_path=whisper_model_path,
                                  device="cuda", num_frames=config.data.num_frames)

    # Load the VAE and pin the Stable Diffusion latent scaling factors
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=dtype)
    vae.config.scaling_factor = 0.18215
    vae.config.shift_factor = 0

    # Load the UNet weights from the checkpoint
    unet, _ = UNet3DConditionModel.from_pretrained(
        OmegaConf.to_container(config.model),
        inference_ckpt_path,  # load checkpoint
        device="cpu",
    )
    unet = unet.to(dtype=dtype)

    # Enable memory-efficient attention when xformers is installed
    if is_xformers_available():
        unet.enable_xformers_memory_efficient_attention()

    # Assemble the pipeline and move it to the GPU
    pipeline = LipsyncPipeline(
        vae=vae,
        audio_encoder=audio_encoder,
        unet=unet,
        scheduler=scheduler,
    ).to("cuda")

    # Seed for reproducibility; pass -1 to draw a fresh random seed
    if seed != -1:
        set_seed(seed)
    else:
        torch.seed()

    # Run the pipeline
    pipeline(
        video_path=video_path,
        audio_path=audio_path,
        video_out_path=video_out_path,
        video_mask_path=video_out_path.replace(".mp4", "_mask.mp4"),
        num_frames=config.data.num_frames,
        num_inference_steps=inference_steps,
        guidance_scale=guidance_scale,
        weight_dtype=dtype,
        width=config.data.resolution,
        height=config.data.resolution,
    )

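# Note: run_inference assumes the YAML config provides a `model` section
# (passed wholesale to UNet3DConditionModel.from_pretrained) and a `data`
# section. A sketch of the fields it reads directly, with hypothetical
# values rather than the ones shipped with LatentSync:
#
#   model:
#     cross_attention_dim: 384   # 768 selects whisper small.pt, 384 tiny.pt
#     # ... remaining UNet constructor arguments ...
#   data:
#     num_frames: 16
#     resolution: 256
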
@app.route('/lipsync', methods=['POST'])
def lipsync_endpoint():
    # Both a video and an audio file must be present in the request
    if 'video' not in request.files or 'audio' not in request.files:
        return jsonify({'error': 'Both video and audio files are required.'}), 400

    video_file = request.files['video']
    audio_file = request.files['audio']

    # Save the uploads to temporary files
    temp_video = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
    temp_audio = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    video_file.save(temp_video.name)
    audio_file.save(temp_audio.name)

    # Create a temporary file for the output video
    output_video = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4").name

    # Additional parameters can be passed via form data (e.g., checkpoint path)
    inference_ckpt_path = request.form.get('inference_ckpt_path', 'checkpoints/latentsync_unet.pt')
    unet_config_path = request.form.get('unet_config_path', 'configs/unet/second_stage.yaml')

    try:
        run_inference(
            video_path=temp_video.name,
            audio_path=temp_audio.name,
            video_out_path=output_video,
            inference_ckpt_path=inference_ckpt_path,
            unet_config_path=unet_config_path,
            inference_steps=int(request.form.get('inference_steps', 20)),
            guidance_scale=float(request.form.get('guidance_scale', 1.0)),
            seed=int(request.form.get('seed', 1247)),
        )
        # Return the server-side output path; a streaming variant is sketched below
        return jsonify({'output_video': output_video}), 200
    except Exception as e:
        return jsonify({'error': str(e)}), 500
    finally:
        # The uploaded inputs are no longer needed; the output file is kept
        # so the caller can retrieve it
        os.remove(temp_video.name)
        os.remove(temp_audio.name)

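# A variant that streams the finished video back to the caller instead of
# returning a server-side temp path (a sketch using Flask's send_file, which
# would replace the jsonify success return above; download_name needs Flask 2.0+):
#
#   from flask import send_file
#   return send_file(output_video, mimetype="video/mp4",
#                    as_attachment=True, download_name="lipsync_output.mp4")
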
if __name__ == "__main__":
    # Use pyngrok to expose the local server to the internet
    from pyngrok import ngrok
    public_url = ngrok.connect(5000)
    print(" * ngrok tunnel available at:", public_url)

    # Run the Flask app on port 5000
    app.run(port=5000)
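
For reference, a client call against the running service might look like the following. This is a minimal sketch using the requests library; the input file names and the localhost URL are placeholders (when the app is exposed through ngrok, the printed public URL would be used instead):

import requests

# Hypothetical input files; any .mp4 / .wav pair accepted by LatentSync works
with open("face.mp4", "rb") as video, open("speech.wav", "rb") as audio:
    response = requests.post(
        "http://127.0.0.1:5000/lipsync",
        files={"video": video, "audio": audio},
        data={"inference_steps": "20", "guidance_scale": "1.0", "seed": "1247"},
    )

print(response.status_code, response.json())
# e.g. 200 {'output_video': '/tmp/tmpXXXXXXXX.mp4'} on success

Note that run_inference rebuilds the audio encoder, VAE, UNet, and pipeline on every request, so each call pays the full model-loading cost on top of inference; a long-running deployment would typically construct the pipeline once at import time and reuse it across requests.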