# FastAPI server exposing talking-head ("digital human") generation endpoints:
# model hot-swapping (/talker_change_model/) and video synthesis (/talker_response/).
import gc
import os
import shutil
import sys
import tempfile
from typing import Optional

import torch
from fastapi import FastAPI, HTTPException, Query, UploadFile, File, Form
from fastapi.responses import FileResponse
from loguru import logger
from pydantic import BaseModel
sys.path.append("./")  # make project-local packages (e.g. TFG) importable when run from the repo root
app = FastAPI()
# Global variable to store the currently loaded Talker model
talker = None
# SadTalker reference-video options; fixed defaults here (no ref video, pose info only).
USE_REF_VIDEO = False
REF_VIDEO = None
REF_INFO = 'pose'  # which signal to borrow from the reference video when one is used
USE_IDLE_MODE = False  # idle (audio-less) animation mode is disabled
AUDIO_LENGTH = 5  # presumably seconds of audio for idle mode — TODO confirm against SadTalker.test2
class TalkerRequest(BaseModel):
    """Generation options for a /talker_response/ call.

    Defaults mirror the Form(...) defaults of the endpoint; the model exists
    to bundle and validate the per-request knobs passed to the Talker backend.
    """

    preprocess_type: str = 'crop'  # SadTalker face preprocessing mode
    is_still_mode: bool = False  # keep head pose still (SadTalker)
    enhancer: bool = False  # apply a face enhancer to the output
    batch_size: int = 4  # inference batch size
    size_of_image: int = 256  # face render resolution (SadTalker)
    pose_style: int = 0  # SadTalker pose style index
    facerender: str = 'facevid2vid'  # SadTalker face renderer backend
    exp_weight: float = 1.0  # expression intensity weight
    blink_every: bool = True  # insert eye blinks
    talker_method: str = 'SadTalker'  # which loaded backend to dispatch to
    fps: int = 30  # output video frame rate
async def clear_memory():
    """Run the garbage collector and release cached CUDA memory, logging the result."""
    logger.info("Clearing GPU memory resources")
    gc.collect()
    if not torch.cuda.is_available():
        # Nothing GPU-side to release on CPU-only hosts.
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    allocated_mb = torch.cuda.memory_allocated() / (1024 ** 2)
    logger.info(f"GPU memory usage after clearing: {allocated_mb:.2f} MB")
@app.post("/talker_change_model/")
async def change_model(model_name: str = Query(..., description="Name of the Talker model to load")):
    """Swap the active digital-human model.

    Validates the requested name, releases the previously loaded model (so its
    GPU memory can actually be reclaimed), then lazily imports and constructs
    the new backend.

    Raises:
        HTTPException 400: unknown model name (current model is left untouched).
        HTTPException 500: the selected model failed to load.
    """
    global talker
    # Validate BEFORE releasing anything: a typo in model_name must not
    # destroy the currently working model.
    if model_name not in ['SadTalker', 'Wav2Lip', 'Wav2Lipv2', 'NeRFTalk']:
        raise HTTPException(status_code=400, detail="Other models are not integrated yet. Please wait for updates.")
    # Drop the reference to the old model first; while `talker` still points
    # at it, torch.cuda.empty_cache() cannot free its GPU tensors.
    talker = None
    await clear_memory()
    try:
        if model_name == 'SadTalker':
            from TFG import SadTalker
            talker = SadTalker(lazy_load=True)
            logger.info("SadTalker model loaded successfully")
        elif model_name == 'Wav2Lip':
            from TFG import Wav2Lip
            talker = Wav2Lip("checkpoints/wav2lip_gan.pth")
            logger.info("Wav2Lip model loaded successfully")
        elif model_name == 'Wav2Lipv2':
            from TFG import Wav2Lipv2
            talker = Wav2Lipv2('checkpoints/wav2lipv2.pth')
            logger.info("Wav2Lipv2 model loaded successfully, capable of generating higher quality results")
        elif model_name == 'NeRFTalk':
            from TFG import NeRFTalk
            talker = NeRFTalk()
            talker.init_model('checkpoints/Obama_ave.pth', 'checkpoints/Obama.json')
            logger.info("NeRFTalk model loaded successfully")
            logger.warning("NeRFTalk model is trained only for a single person, built-in with the Obama model, uploading other images is ineffective.")
    except Exception as e:
        # Loading failed: leave talker as None so /talker_response/ reports
        # "model not loaded" instead of using a half-initialized backend.
        talker = None
        logger.error(f"Failed to load {model_name} model: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to load {model_name} model: {e}")
    return {"message": f"{model_name} model loaded successfully"}
@app.post("/talker_response/")
async def talker_response(
    preprocess_type: str = Form('crop'),
    is_still_mode: bool = Form(False),
    enhancer: bool = Form(False),
    batch_size: int = Form(4),
    size_of_image: int = Form(256),
    pose_style: int = Form(0),
    facerender: str = Form('facevid2vid'),
    exp_weight: float = Form(1.0),
    blink_every: bool = Form(True),
    talker_method: str = Form('SadTalker'),
    fps: int = Form(30),
    source_image: UploadFile = File(..., description="The source image file"),
    driven_audio: UploadFile = File(..., description="The audio file that will drive the talking head"),
):
    """Generate a talking-head video from an image + driving audio.

    Dispatches to the currently loaded model according to `talker_method`
    (which must match the model loaded via /talker_change_model/ — mismatches
    surface as attribute errors from the backend — TODO confirm intended
    coupling with the caller).

    Returns:
        FileResponse streaming the generated MP4.

    Raises:
        HTTPException 400: no model loaded / unsupported method.
        HTTPException 404: backend reported a path that does not exist.
        HTTPException 500: any other generation failure.
    """
    global talker
    if talker is None:
        raise HTTPException(status_code=400, detail="Talker model not loaded. Please load a model first.")
    # Assemble the request data into the TalkerRequest model
    request = TalkerRequest(
        preprocess_type=preprocess_type,
        is_still_mode=is_still_mode,
        enhancer=enhancer,
        batch_size=batch_size,
        size_of_image=size_of_image,
        pose_style=pose_style,
        facerender=facerender,
        exp_weight=exp_weight,
        blink_every=blink_every,
        talker_method=talker_method,
        fps=fps,
    )
    # Unique per-request temp paths: the previous fixed filenames raced —
    # two concurrent requests would overwrite each other's uploads.
    image_fd, temp_image_path = tempfile.mkstemp(suffix=".jpg")
    os.close(image_fd)
    audio_fd, temp_audio_path = tempfile.mkstemp(suffix=".wav")
    os.close(audio_fd)
    try:
        # Save uploaded files temporarily
        with open(temp_image_path, "wb") as image_file:
            shutil.copyfileobj(source_image.file, image_file)
        with open(temp_audio_path, "wb") as audio_file:
            shutil.copyfileobj(driven_audio.file, audio_file)
        # Video generation
        if request.talker_method == 'SadTalker':
            video_path = talker.test2(
                temp_image_path,
                temp_audio_path,
                request.preprocess_type,
                request.is_still_mode,
                request.enhancer,
                request.batch_size,
                request.size_of_image,
                request.pose_style,
                request.facerender,
                request.exp_weight,
                REF_VIDEO, REF_INFO, USE_IDLE_MODE, AUDIO_LENGTH,
                request.blink_every,
                request.fps,
            )
        elif request.talker_method == 'Wav2Lip':
            video_path = talker.predict(temp_image_path, temp_audio_path, request.batch_size)
        elif request.talker_method == 'Wav2Lipv2':
            video_path = talker.run(temp_image_path, temp_audio_path, request.batch_size)
        elif request.talker_method == 'NeRFTalk':
            # NeRFTalk is single-identity; the uploaded image is ignored.
            video_path = talker.predict(temp_audio_path)
        else:
            raise HTTPException(status_code=400, detail="Unsupported method")
        # Ensure the video file exists and return it
        if os.path.exists(video_path):
            return FileResponse(video_path, media_type='video/mp4', filename=os.path.basename(video_path))
        else:
            raise HTTPException(status_code=404, detail="Video file not found")
    except HTTPException:
        # Let the 400/404 raised above reach the client unchanged; without this
        # pass-through they were swallowed by the handler below and became 500s.
        raise
    except Exception as e:
        logger.error(f"Video generation failed: {e}")
        raise HTTPException(status_code=500, detail=f"Video generation failed: {e}")
    finally:
        # Clean up temporary files
        if os.path.exists(temp_image_path):
            os.remove(temp_image_path)
        if os.path.exists(temp_audio_path):
            os.remove(temp_audio_path)
if __name__ == "__main__":
    # Local import: uvicorn is only needed when running this module as a script.
    import uvicorn
    # Serve on all interfaces, port 8003 (companion services presumably use other ports — TODO confirm).
    uvicorn.run(app, host="0.0.0.0", port=8003)