Spaces: Runtime error
Update main.py
Browse files
main.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
from fastapi import FastAPI, Query, File, UploadFile
|
| 2 |
from fastapi.responses import FileResponse
|
| 3 |
import torch
|
| 4 |
-
from diffusion import Diffusion
|
| 5 |
-
from utils import get_id_frame, get_audio_emb, save_video
|
| 6 |
import shutil
|
| 7 |
from pathlib import Path
|
| 8 |
|
|
@@ -12,18 +12,15 @@ app = FastAPI()
|
|
| 12 |
async def generate_video(
|
| 13 |
id_frame_file: UploadFile = File(...),
|
| 14 |
audio_file: UploadFile = File(...),
|
| 15 |
-
gpu: bool = Query(
|
| 16 |
id_frame_random: bool = Query(False, description="Pick id_frame randomly from video"),
|
| 17 |
inference_steps: int = Query(100, description="Number of inference diffusion steps"),
|
| 18 |
-
output: str = Query("
|
| 19 |
):
|
| 20 |
device = 'cuda' if gpu and torch.cuda.is_available() else 'cpu'
|
| 21 |
|
| 22 |
print('Loading model...')
|
| 23 |
-
|
| 24 |
-
unet = torch.jit.load("your_checkpoint_path_here")
|
| 25 |
-
|
| 26 |
-
# Replace these arguments with the ones from your original args
|
| 27 |
diffusion_args = {
|
| 28 |
"in_channels": 3,
|
| 29 |
"image_size": 128,
|
|
@@ -43,9 +40,15 @@ async def generate_video(
|
|
| 43 |
shutil.copyfileobj(audio_file.file, buffer)
|
| 44 |
|
| 45 |
id_frame = get_id_frame(str(id_frame_path), random=id_frame_random, resize=diffusion_args["image_size"]).to(device)
|
| 46 |
-
audio, audio_emb = get_audio_emb(str(audio_path), "
|
| 47 |
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
save_video(output, samples, audio=audio, fps=25, audio_rate=16000)
|
| 51 |
print(f'Results saved at {output}')
|
|
|
|
| 1 |
from fastapi import FastAPI, Query, File, UploadFile
|
| 2 |
from fastapi.responses import FileResponse
|
| 3 |
import torch
|
| 4 |
+
from diffusion import Diffusion # Make sure you import your own modules correctly
|
| 5 |
+
from utils import get_id_frame, get_audio_emb, save_video # Make sure you import your own modules correctly
|
| 6 |
import shutil
|
| 7 |
from pathlib import Path
|
| 8 |
|
|
|
|
| 12 |
async def generate_video(
|
| 13 |
id_frame_file: UploadFile = File(...),
|
| 14 |
audio_file: UploadFile = File(...),
|
| 15 |
+
gpu: bool = Query(True, description="Use GPU if available"),
|
| 16 |
id_frame_random: bool = Query(False, description="Pick id_frame randomly from video"),
|
| 17 |
inference_steps: int = Query(100, description="Number of inference diffusion steps"),
|
| 18 |
+
output: str = Query("/Users/a/Documents/Automations/git talking heads/output_video.mp4", description="Path to save the output video")
|
| 19 |
):
|
| 20 |
device = 'cuda' if gpu and torch.cuda.is_available() else 'cpu'
|
| 21 |
|
| 22 |
print('Loading model...')
|
| 23 |
+
unet = torch.jit.load("/Users/a/Documents/Automations/git talking heads/checkpoints/crema_script.pt")
|
|
|
|
|
|
|
|
|
|
| 24 |
diffusion_args = {
|
| 25 |
"in_channels": 3,
|
| 26 |
"image_size": 128,
|
|
|
|
| 40 |
shutil.copyfileobj(audio_file.file, buffer)
|
| 41 |
|
| 42 |
id_frame = get_id_frame(str(id_frame_path), random=id_frame_random, resize=diffusion_args["image_size"]).to(device)
|
| 43 |
+
audio, audio_emb = get_audio_emb(str(audio_path), "/Users/a/Documents/Automations/git talking heads/checkpoints/audio_encoder.pt", device)
|
| 44 |
|
| 45 |
+
unet_args = {
|
| 46 |
+
"n_audio_motion_embs": 2,
|
| 47 |
+
"n_motion_frames": 2,
|
| 48 |
+
"motion_channels": 3
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
samples = diffusion.sample(id_frame, audio_emb.unsqueeze(0), **unet_args)
|
| 52 |
|
| 53 |
save_video(output, samples, audio=audio, fps=25, audio_rate=16000)
|
| 54 |
print(f'Results saved at {output}')
|