# FastAPI service exposing talking-head ("digital human") video-generation endpoints.
import gc
import os
import shutil
import sys
import tempfile
from typing import Optional

import torch
from fastapi import FastAPI, HTTPException, Query, UploadFile, File, Form
from fastapi.responses import FileResponse
from loguru import logger
from pydantic import BaseModel
# Make project-local packages (e.g. TFG) importable when run from the repo root.
sys.path.append("./")

app = FastAPI()

# Global variable to store the currently loaded Talker model
talker = None

# Reference-video settings forwarded to SadTalker's test2() call in talker_response.
USE_REF_VIDEO = False   # whether a reference video drives the animation
REF_VIDEO = None        # path to the reference video (None = disabled)
REF_INFO = 'pose'       # which signal to take from the reference video -- presumably pose; verify against SadTalker docs
USE_IDLE_MODE = False   # generate an idle (no driving audio) animation instead
AUDIO_LENGTH = 5        # idle-mode audio length -- assumes seconds; TODO confirm
class TalkerRequest(BaseModel):
    """Parameters controlling talking-head video generation.

    Mirrors the form fields accepted by ``talker_response``; assembled there
    from the individual ``Form(...)`` parameters.
    """

    preprocess_type: str = 'crop'    # SadTalker face preprocessing mode
    is_still_mode: bool = False      # reduce head motion (SadTalker "still" mode)
    enhancer: bool = False           # apply a face enhancer to the output
    batch_size: int = 4              # frames per inference batch
    size_of_image: int = 256         # face-render resolution (pixels)
    pose_style: int = 0              # SadTalker pose-style index
    facerender: str = 'facevid2vid'  # face-rendering backend name
    exp_weight: float = 1.0          # expression intensity weight
    blink_every: bool = True         # insert eye blinks periodically
    talker_method: str = 'SadTalker' # which loaded model to dispatch to
    fps: int = 30                    # output video frame rate
async def clear_memory():
    """Release cached GPU memory and run the Python garbage collector.

    Called before swapping models so the outgoing model's tensors are
    actually returned to the driver instead of lingering in the CUDA
    caching allocator.
    """
    logger.info("Clearing GPU memory resources")
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        # Query CUDA stats only inside the availability guard: calling
        # memory_allocated() on a CUDA-less build would needlessly touch
        # (and potentially fail to initialize) the CUDA runtime.
        logger.info(f"GPU memory usage after clearing: {torch.cuda.memory_allocated() / (1024 ** 2):.2f} MB")
async def change_model(model_name: str = Query(..., description="Name of the Talker model to load")):
    """Swap the active talking-head model, freeing the old one first.

    NOTE(review): the route decorator (e.g. ``@app.get("/change_model")``)
    appears to have been lost in formatting -- restore it before deploying.

    Args:
        model_name: one of 'SadTalker', 'Wav2Lip', 'Wav2Lipv2', 'NeRFTalk'.

    Raises:
        HTTPException 400: unknown / not-yet-integrated model name.
        HTTPException 500: the requested model failed to load.
    """
    global talker
    # Validate first so a bad request does not tear down the current model.
    if model_name not in ['SadTalker', 'Wav2Lip', 'Wav2Lipv2', 'NeRFTalk']:
        raise HTTPException(status_code=400, detail="Other models are not integrated yet. Please wait for updates.")
    # Drop the reference to the old model BEFORE clearing memory; otherwise
    # gc.collect()/empty_cache() cannot free it because `talker` still holds it.
    talker = None
    await clear_memory()
    try:
        if model_name == 'SadTalker':
            from TFG import SadTalker
            talker = SadTalker(lazy_load=True)
            logger.info("SadTalker model loaded successfully")
        elif model_name == 'Wav2Lip':
            from TFG import Wav2Lip
            talker = Wav2Lip("checkpoints/wav2lip_gan.pth")
            logger.info("Wav2Lip model loaded successfully")
        elif model_name == 'Wav2Lipv2':
            from TFG import Wav2Lipv2
            talker = Wav2Lipv2('checkpoints/wav2lipv2.pth')
            logger.info("Wav2Lipv2 model loaded successfully, capable of generating higher quality results")
        elif model_name == 'NeRFTalk':
            from TFG import NeRFTalk
            talker = NeRFTalk()
            talker.init_model('checkpoints/Obama_ave.pth', 'checkpoints/Obama.json')
            logger.info("NeRFTalk model loaded successfully")
            logger.warning("NeRFTalk model is trained only for a single person, built-in with the Obama model, uploading other images is ineffective.")
    except Exception as e:
        logger.error(f"Failed to load {model_name} model: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to load {model_name} model: {e}")
    return {"message": f"{model_name} model loaded successfully"}
async def talker_response(
    preprocess_type: str = Form('crop'),
    is_still_mode: bool = Form(False),
    enhancer: bool = Form(False),
    batch_size: int = Form(4),
    size_of_image: int = Form(256),
    pose_style: int = Form(0),
    facerender: str = Form('facevid2vid'),
    exp_weight: float = Form(1.0),
    blink_every: bool = Form(True),
    talker_method: str = Form('SadTalker'),
    fps: int = Form(30),
    source_image: UploadFile = File(..., description="The source image file"),
    driven_audio: UploadFile = File(..., description="The audio file that will drive the talking head"),
):
    """Generate a talking-head video from a source image and a driving audio clip.

    NOTE(review): the route decorator (e.g. ``@app.post("/talker_response")``)
    appears to have been lost in formatting -- restore it before deploying.

    Returns:
        FileResponse: the generated MP4 video.

    Raises:
        HTTPException 400: no model loaded, or unsupported talker_method.
        HTTPException 404: generation produced no video file.
        HTTPException 500: any other generation failure.
    """
    global talker
    if talker is None:
        raise HTTPException(status_code=400, detail="Talker model not loaded. Please load a model first.")
    # Assemble the request data into the TalkerRequest model
    request = TalkerRequest(
        preprocess_type=preprocess_type,
        is_still_mode=is_still_mode,
        enhancer=enhancer,
        batch_size=batch_size,
        size_of_image=size_of_image,
        pose_style=pose_style,
        facerender=facerender,
        exp_weight=exp_weight,
        blink_every=blink_every,
        talker_method=talker_method,
        fps=fps,
    )
    # Per-request unique temp files: the previous fixed names
    # ("temp_source_image.jpg" / "temp_driven_audio.wav") raced under
    # concurrent requests, letting one request overwrite another's upload.
    temp_image_path = None
    temp_audio_path = None
    try:
        # Save uploaded files to unique temporary paths
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as image_file:
            temp_image_path = image_file.name
            shutil.copyfileobj(source_image.file, image_file)
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as audio_file:
            temp_audio_path = audio_file.name
            shutil.copyfileobj(driven_audio.file, audio_file)
        # Video generation -- dispatch on the requested method
        if request.talker_method == 'SadTalker':
            video_path = talker.test2(
                temp_image_path,
                temp_audio_path,
                request.preprocess_type,
                request.is_still_mode,
                request.enhancer,
                request.batch_size,
                request.size_of_image,
                request.pose_style,
                request.facerender,
                request.exp_weight,
                REF_VIDEO, REF_INFO, USE_IDLE_MODE, AUDIO_LENGTH,
                request.blink_every,
                request.fps,
            )
        elif request.talker_method == 'Wav2Lip':
            video_path = talker.predict(temp_image_path, temp_audio_path, request.batch_size)
        elif request.talker_method == 'Wav2Lipv2':
            video_path = talker.run(temp_image_path, temp_audio_path, request.batch_size)
        elif request.talker_method == 'NeRFTalk':
            video_path = talker.predict(temp_audio_path)
        else:
            raise HTTPException(status_code=400, detail="Unsupported method")
        # Ensure the video file exists and return it
        if os.path.exists(video_path):
            return FileResponse(video_path, media_type='video/mp4', filename=os.path.basename(video_path))
        raise HTTPException(status_code=404, detail="Video file not found")
    except HTTPException:
        # Let the deliberate HTTP errors above (400/404) propagate unchanged
        # instead of being rewrapped as 500 by the generic handler below.
        raise
    except Exception as e:
        logger.error(f"Video generation failed: {e}")
        raise HTTPException(status_code=500, detail=f"Video generation failed: {e}")
    finally:
        # Clean up temporary files (paths may be None if creation failed early)
        for tmp in (temp_image_path, temp_audio_path):
            if tmp and os.path.exists(tmp):
                os.remove(tmp)
if __name__ == "__main__":
    import uvicorn
    # Serve the API on all interfaces; this service listens on port 8003.
    uvicorn.run(app, host="0.0.0.0", port=8003)