Linly-Talker / api /talker_api.py
weismart1807's picture
Upload folder using huggingface_hub
e90b704 verified
from fastapi import FastAPI, HTTPException, Query, UploadFile, File, Form
from pydantic import BaseModel
import shutil
import os
from loguru import logger
import sys
import gc
import torch
from typing import Optional
from fastapi.responses import FileResponse
sys.path.append("./")
app = FastAPI()
# Global variable to store the currently loaded Talker model
talker = None
USE_REF_VIDEO = False
REF_VIDEO = None
REF_INFO = 'pose'
USE_IDLE_MODE = False
AUDIO_LENGTH = 5
class TalkerRequest(BaseModel):
preprocess_type: str = 'crop'
is_still_mode: bool = False
enhancer: bool = False
batch_size: int = 4
size_of_image: int = 256
pose_style: int = 0
facerender: str = 'facevid2vid'
exp_weight: float = 1.0
blink_every: bool = True
talker_method: str = 'SadTalker'
fps: int = 30
async def clear_memory():
"""Asynchronous function to clear GPU memory."""
logger.info("Clearing GPU memory resources")
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
logger.info(f"GPU memory usage after clearing: {torch.cuda.memory_allocated() / (1024 ** 2):.2f} MB")
@app.post("/talker_change_model/")
async def change_model(model_name: str = Query(..., description="Name of the Talker model to load")):
"""Change digital human conversation model and load corresponding resources."""
global talker
# Clear memory to free up unnecessary resources before loading a new model
await clear_memory()
if model_name not in ['SadTalker', 'Wav2Lip', 'Wav2Lipv2', 'NeRFTalk']:
raise HTTPException(status_code=400, detail="Other models are not integrated yet. Please wait for updates.")
try:
if model_name == 'SadTalker':
from TFG import SadTalker
talker = SadTalker(lazy_load=True)
logger.info("SadTalker model loaded successfully")
elif model_name == 'Wav2Lip':
from TFG import Wav2Lip
talker = Wav2Lip("checkpoints/wav2lip_gan.pth")
logger.info("Wav2Lip model loaded successfully")
elif model_name == 'Wav2Lipv2':
from TFG import Wav2Lipv2
talker = Wav2Lipv2('checkpoints/wav2lipv2.pth')
logger.info("Wav2Lipv2 model loaded successfully, capable of generating higher quality results")
elif model_name == 'NeRFTalk':
from TFG import NeRFTalk
talker = NeRFTalk()
talker.init_model('checkpoints/Obama_ave.pth', 'checkpoints/Obama.json')
logger.info("NeRFTalk model loaded successfully")
logger.warning("NeRFTalk model is trained only for a single person, built-in with the Obama model, uploading other images is ineffective.")
except Exception as e:
logger.error(f"Failed to load {model_name} model: {e}")
raise HTTPException(status_code=500, detail=f"Failed to load {model_name} model: {e}")
return {"message": f"{model_name} model loaded successfully"}
@app.post("/talker_response/")
async def talker_response(
preprocess_type: str = Form('crop'),
is_still_mode: bool = Form(False),
enhancer: bool = Form(False),
batch_size: int = Form(4),
size_of_image: int = Form(256),
pose_style: int = Form(0),
facerender: str = Form('facevid2vid'),
exp_weight: float = Form(1.0),
blink_every: bool = Form(True),
talker_method: str = Form('SadTalker'),
fps: int = Form(30),
source_image: UploadFile = File(..., description="The source image file"),
driven_audio: UploadFile = File(..., description="The audio file that will drive the talking head"),
):
"""Handle digital human conversation requests and generate video."""
global talker
if talker is None:
raise HTTPException(status_code=400, detail="Talker model not loaded. Please load a model first.")
# Assemble the request data into the TalkerRequest model
request = TalkerRequest(
preprocess_type=preprocess_type,
is_still_mode=is_still_mode,
enhancer=enhancer,
batch_size=batch_size,
size_of_image=size_of_image,
pose_style=pose_style,
facerender=facerender,
exp_weight=exp_weight,
blink_every=blink_every,
talker_method=talker_method,
fps=fps,
)
# print(request)
# Temporary file paths
temp_image_path = "temp_source_image.jpg"
temp_audio_path = "temp_driven_audio.wav"
try:
# Save uploaded files temporarily
with open(temp_image_path, "wb") as image_file:
shutil.copyfileobj(source_image.file, image_file)
with open(temp_audio_path, "wb") as audio_file:
shutil.copyfileobj(driven_audio.file, audio_file)
# Video generation
if request.talker_method == 'SadTalker':
video_path = talker.test2(
temp_image_path,
temp_audio_path,
request.preprocess_type,
request.is_still_mode,
request.enhancer,
request.batch_size,
request.size_of_image,
request.pose_style,
request.facerender,
request.exp_weight,
REF_VIDEO, REF_INFO, USE_IDLE_MODE, AUDIO_LENGTH,
request.blink_every,
request.fps,
)
elif request.talker_method == 'Wav2Lip':
video_path = talker.predict(temp_image_path, temp_audio_path, request.batch_size)
elif request.talker_method == 'Wav2Lipv2':
video_path = talker.run(temp_image_path, temp_audio_path, request.batch_size)
elif request.talker_method == 'NeRFTalk':
video_path = talker.predict(temp_audio_path)
else:
raise HTTPException(status_code=400, detail="Unsupported method")
# Ensure the video file exists and return it
if os.path.exists(video_path):
return FileResponse(video_path, media_type='video/mp4', filename=os.path.basename(video_path))
else:
raise HTTPException(status_code=404, detail="Video file not found")
except Exception as e:
logger.error(f"Video generation failed: {e}")
raise HTTPException(status_code=500, detail=f"Video generation failed: {e}")
finally:
# Clean up temporary files
if os.path.exists(temp_image_path):
os.remove(temp_image_path)
if os.path.exists(temp_audio_path):
os.remove(temp_audio_path)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8003)