File size: 6,645 Bytes
e90b704
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import gc
import os
import shutil
import sys
import tempfile
from typing import Optional

import torch
from fastapi import FastAPI, HTTPException, Query, UploadFile, File, Form
from fastapi.responses import FileResponse
from loguru import logger
from pydantic import BaseModel

sys.path.append("./")
app = FastAPI()

# Global variable to store the currently loaded Talker model
talker = None

USE_REF_VIDEO = False
REF_VIDEO = None
REF_INFO = 'pose'
USE_IDLE_MODE = False
AUDIO_LENGTH = 5

class TalkerRequest(BaseModel):
    """Generation parameters for a talking-head video request.

    Mirrors the form fields accepted by the /talker_response/ endpoint;
    defaults match the endpoint's Form(...) defaults.
    """

    # Face preprocessing mode passed to SadTalker (e.g. 'crop').
    preprocess_type: str = 'crop'
    # When True, reduces head motion (SadTalker "still" mode).
    is_still_mode: bool = False
    # When True, applies a face enhancer to the rendered frames.
    enhancer: bool = False
    # Inference batch size for the face renderer.
    batch_size: int = 4
    # Output image resolution (pixels) for the rendered face.
    size_of_image: int = 256
    # Index selecting one of SadTalker's pre-defined pose styles.
    pose_style: int = 0
    # Face renderer backend name (e.g. 'facevid2vid').
    facerender: str = 'facevid2vid'
    # Expression intensity weight.
    exp_weight: float = 1.0
    # Whether to insert periodic eye blinks.
    blink_every: bool = True
    # Which loaded model to use; must match the model loaded via
    # /talker_change_model/ ('SadTalker', 'Wav2Lip', 'Wav2Lipv2', 'NeRFTalk').
    talker_method: str = 'SadTalker'
    # Frames per second of the generated video.
    fps: int = 30

async def clear_memory():
    """Release unused memory: run the Python GC, then drop CUDA caches.

    On CUDA-less hosts this is just a garbage-collection pass; the CUDA
    cache/IPC cleanup and the usage log line are skipped entirely.
    """
    logger.info("Clearing GPU memory resources")
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    logger.info(f"GPU memory usage after clearing: {torch.cuda.memory_allocated() / (1024 ** 2):.2f} MB")

@app.post("/talker_change_model/")
async def change_model(model_name: str = Query(..., description="Name of the Talker model to load")):
    """Swap the active digital-human (Talker) model.

    Validates the requested name, releases the previously loaded model's
    memory, then imports and instantiates the new model, storing it in the
    module-level ``talker`` global used by /talker_response/.

    Raises:
        HTTPException 400: unknown ``model_name``.
        HTTPException 500: the model failed to import or initialize.
    """
    global talker

    # Validate BEFORE touching the currently loaded model, so a bad request
    # does not tear down a working model.
    if model_name not in ['SadTalker', 'Wav2Lip', 'Wav2Lipv2', 'NeRFTalk']:
        raise HTTPException(status_code=400, detail="Other models are not integrated yet. Please wait for updates.")

    # Drop the reference to the old model FIRST: gc.collect()/empty_cache()
    # cannot reclaim memory that is still reachable through the global.
    talker = None
    await clear_memory()

    try:
        if model_name == 'SadTalker':
            from TFG import SadTalker
            talker = SadTalker(lazy_load=True)
            logger.info("SadTalker model loaded successfully")
        elif model_name == 'Wav2Lip':
            from TFG import Wav2Lip
            talker = Wav2Lip("checkpoints/wav2lip_gan.pth")
            logger.info("Wav2Lip model loaded successfully")
        elif model_name == 'Wav2Lipv2':
            from TFG import Wav2Lipv2
            talker = Wav2Lipv2('checkpoints/wav2lipv2.pth')
            logger.info("Wav2Lipv2 model loaded successfully, capable of generating higher quality results")
        elif model_name == 'NeRFTalk':
            from TFG import NeRFTalk
            talker = NeRFTalk()
            talker.init_model('checkpoints/Obama_ave.pth', 'checkpoints/Obama.json')
            logger.info("NeRFTalk model loaded successfully")
            logger.warning("NeRFTalk model is trained only for a single person, built-in with the Obama model, uploading other images is ineffective.")
    except Exception as e:
        logger.error(f"Failed to load {model_name} model: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to load {model_name} model: {e}")

    return {"message": f"{model_name} model loaded successfully"}

@app.post("/talker_response/")
async def talker_response(
    preprocess_type: str = Form('crop'),
    is_still_mode: bool = Form(False),
    enhancer: bool = Form(False),
    batch_size: int = Form(4),
    size_of_image: int = Form(256),
    pose_style: int = Form(0),
    facerender: str = Form('facevid2vid'),
    exp_weight: float = Form(1.0),
    blink_every: bool = Form(True),
    talker_method: str = Form('SadTalker'),
    fps: int = Form(30),
    source_image: UploadFile = File(..., description="The source image file"),
    driven_audio: UploadFile = File(..., description="The audio file that will drive the talking head"),
):
    """Generate a talking-head video from an uploaded image and audio file.

    Saves the uploads to per-request temporary files, dispatches to the
    currently loaded model (``talker_method`` must match the model loaded
    via /talker_change_model/), and streams the resulting MP4 back.

    Raises:
        HTTPException 400: no model loaded, or unsupported ``talker_method``.
        HTTPException 404: the generator did not produce an output file.
        HTTPException 500: video generation failed.
    """
    global talker

    if talker is None:
        raise HTTPException(status_code=400, detail="Talker model not loaded. Please load a model first.")

    # Assemble the request data into the TalkerRequest model
    request = TalkerRequest(
        preprocess_type=preprocess_type,
        is_still_mode=is_still_mode,
        enhancer=enhancer,
        batch_size=batch_size,
        size_of_image=size_of_image,
        pose_style=pose_style,
        facerender=facerender,
        exp_weight=exp_weight,
        blink_every=blink_every,
        talker_method=talker_method,
        fps=fps,
    )

    # Unique per-request temp paths: fixed names ("temp_source_image.jpg")
    # would make concurrent requests overwrite each other's uploads.
    image_fd, temp_image_path = tempfile.mkstemp(suffix=".jpg")
    audio_fd, temp_audio_path = tempfile.mkstemp(suffix=".wav")

    try:
        # Save uploaded files temporarily (fdopen closes the mkstemp fds)
        with os.fdopen(image_fd, "wb") as image_file:
            shutil.copyfileobj(source_image.file, image_file)
        with os.fdopen(audio_fd, "wb") as audio_file:
            shutil.copyfileobj(driven_audio.file, audio_file)

        # Video generation
        if request.talker_method == 'SadTalker':
            video_path = talker.test2(
                temp_image_path,
                temp_audio_path,
                request.preprocess_type,
                request.is_still_mode,
                request.enhancer,
                request.batch_size,
                request.size_of_image,
                request.pose_style,
                request.facerender,
                request.exp_weight,
                REF_VIDEO, REF_INFO, USE_IDLE_MODE, AUDIO_LENGTH,
                request.blink_every,
                request.fps,
            )
        elif request.talker_method == 'Wav2Lip':
            video_path = talker.predict(temp_image_path, temp_audio_path, request.batch_size)
        elif request.talker_method == 'Wav2Lipv2':
            video_path = talker.run(temp_image_path, temp_audio_path, request.batch_size)
        elif request.talker_method == 'NeRFTalk':
            video_path = talker.predict(temp_audio_path)
        else:
            raise HTTPException(status_code=400, detail="Unsupported method")

        # Ensure the video file exists and return it
        if os.path.exists(video_path):
            return FileResponse(video_path, media_type='video/mp4', filename=os.path.basename(video_path))
        else:
            raise HTTPException(status_code=404, detail="Video file not found")
    except HTTPException:
        # Re-raise deliberate 400/404 responses unchanged; without this they
        # would fall into the generic handler below and surface as 500s.
        raise
    except Exception as e:
        logger.error(f"Video generation failed: {e}")
        raise HTTPException(status_code=500, detail=f"Video generation failed: {e}")
    finally:
        # Clean up temporary files
        if os.path.exists(temp_image_path):
            os.remove(temp_image_path)
        if os.path.exists(temp_audio_path):
            os.remove(temp_audio_path)

if __name__ == "__main__":
    # Run the service directly (development entry point); listens on all
    # interfaces at port 8003.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8003)