File size: 3,055 Bytes
b0c2db7
 
239a35e
 
 
 
 
 
 
 
 
b0c2db7
239a35e
 
 
 
 
 
 
 
 
 
 
 
 
 
b0c2db7
 
 
 
 
 
 
239a35e
b0c2db7
 
 
239a35e
 
 
b0c2db7
239a35e
 
b0c2db7
 
 
239a35e
 
b0c2db7
 
 
 
 
 
 
239a35e
b0c2db7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239a35e
 
b0c2db7
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
'''
sudo docker run --gpus all --runtime=nvidia --rm \
    -v /home/ubuntu/dotdemo/third_party:/third_party \
    -v /home/ubuntu/dotdemo-dev:/dotdemo-dev \
    -v /home/ubuntu/dot-demo-assets/ml-logs:/logs \
    -v /home/ubuntu/dotdemo/train_server:/app \
    --network="host" \
    --shm-size 1G \
    -it fantasyfish677/rvc:v0 /bin/bash
pip3 install flask_cors
python3 /app/server.py 2>&1 | tee /logs/train_server.log

export FLASK_APP=server
export FLASK_DEBUG=true
pip3 install gunicorn
gunicorn -b :8080 --timeout=600 server:app

curl -X GET http://3.16.130.199:8080/ping

curl -X POST http://3.16.130.199:8080/train \
-H 'Content-Type: application/json' \
-d '{"expName":"varun124","trainsetDir":"varun124"}'

curl -X GET http://3.16.130.199:8080/check \
-H 'Content-Type: application/json' \
-d '{"expName":"kanye-1"}'
'''
import json
import logging
import os
import shutil
import time
from logging import exception

from flask import Flask, request
from flask_cors import CORS, cross_origin

from server_utils import train_model

print("import successful!")

# Flask application object; "train server" is passed as Flask's import_name.
app = Flask("train server")

# Register flask_cors on the whole app, then record the allowed header name.
# NOTE(review): flask_cors documents CORS_ALLOW_HEADERS as its option key and
# this value is set after CORS(app) runs -- confirm CORS_HEADERS has any effect.
cors = CORS(app)
app.config.update(CORS_HEADERS='Content-Type')

@app.route("/ping", methods=['GET', 'POST'])
@cross_origin()
def healthcheck():
    """Liveness probe for GET or POST requests.

    Returns a fixed JSON document, ``{"code": 200, "message": "responding"}``,
    serialized and encoded as UTF-8 bytes; Flask sends it with its default
    200 status.
    """
    payload = {"code": 200, "message": "responding"}
    serialized = json.dumps(payload)
    return serialized.encode('utf-8')

@app.route("/train", methods=['GET', 'POST'])
@cross_origin()
def train():
    """Train an RVC model for one experiment and report the elapsed time.

    Expects a JSON object body with:
      * ``expName`` -- experiment name; used as a single path component of
        ``/third_party/RVC/logs/<expName>`` and ``/logs/<expName>.log``.
      * ``trainsetDir`` -- training-set path, relative to ``/dotdemo-dev``.

    Any existing ``/third_party/RVC/logs/<expName>`` entry is deleted before
    training. ``train_model`` runs synchronously inside the request with
    ``total_epoch=20``.

    Returns a ``(json_text, status)`` tuple whose JSON has a ``message`` key:
      * 200 -- training finished; message includes the elapsed seconds.
      * 400 -- body is not a JSON object, or a field is missing or unsafe.
      * 404 -- the training-set path does not exist.
      * 415 -- the request is not declared as JSON.
      * 500 -- log cleanup or ``train_model`` raised.
    """
    # is_json tolerates parameters such as "; charset=utf-8" and is simply
    # False (instead of raising KeyError) when the header is absent.
    if not request.is_json:
        logging.error("Header error")
        return json.dumps({"message": "Header error"}), 415

    # silent=True: malformed JSON yields None instead of aborting the request.
    content = request.get_json(silent=True)
    if not isinstance(content, dict):
        logging.error("Request body is not a JSON object")
        return json.dumps({"message": "Request body must be a JSON object"}), 400

    exp_name = content.get('expName')
    # exp_name feeds a recursive delete below, so it must be exactly one plain
    # path component: no separators, no "." / "..", no NUL byte.
    if (not isinstance(exp_name, str) or exp_name in ('', '.', '..')
            or any(ch in exp_name for ch in '/\\\0')):
        logging.error("Invalid expName: %r", exp_name)
        return json.dumps({"message": "Invalid expName"}), 400

    trainset_name = content.get('trainsetDir')
    if not isinstance(trainset_name, str) or not trainset_name or '\0' in trainset_name:
        logging.error("Invalid trainsetDir: %r", trainset_name)
        return json.dumps({"message": "Invalid trainsetDir"}), 400
    data_root = '/dotdemo-dev'
    trainset_dir = os.path.normpath(os.path.join(data_root, trainset_name))
    # os.path.join discards data_root for absolute input and normpath folds
    # "..", so this rejects any path that resolves outside the data root.
    if os.path.commonpath([data_root, trainset_dir]) != data_root:
        logging.error("trainsetDir escapes %s: %r", data_root, trainset_name)
        return json.dumps({"message": "Invalid trainsetDir"}), 400
    if not os.path.exists(trainset_dir):
        logging.error("Training set doesn't exist")
        return json.dumps({"message": "Training set doesn't exist"}), 404

    # Previously "/logs{}.log" (no separator) wrote to /logs<name>.log.
    log_path = os.path.join('/logs', '{}.log'.format(exp_name))
    exp_log_dir = os.path.join('/third_party/RVC/logs', exp_name)
    try:
        # Clear stale experiment state only after validation succeeded, so a
        # rejected request never wipes an existing experiment. Using os/shutil
        # instead of a shell `rm -rf` string removes the injection vector.
        if os.path.islink(exp_log_dir) or os.path.isfile(exp_log_dir):
            os.remove(exp_log_dir)
        elif os.path.isdir(exp_log_dir):
            shutil.rmtree(exp_log_dir)

        start_time = time.perf_counter()
        train_model(exp_name, trainset_dir, log_path, total_epoch=20)
        elapsed = time.perf_counter() - start_time
    except Exception as e:
        exception("Training process failed")
        return json.dumps({"message": "Training process failed due to {}".format(e)}), 500
    return json.dumps({"message": "Training Completed in {} secs.".format(elapsed)}), 200

@app.route("/check", methods=['GET', 'POST'])
@cross_origin()
def check():
    """Report whether trained weights exist for an experiment.

    Expects a JSON object body ``{"expName": str}`` and tests for the file
    ``/third_party/RVC/weights/<expName>.pth``. Both outcomes return HTTP 200;
    the ``message`` field ("Model found." / "Model not found.") carries the
    answer.

    Returns 415 when the request is not declared as JSON and 400 when the body
    is not a JSON object or ``expName`` is missing or not a single plain path
    component.
    """
    # is_json accepts "application/json; charset=utf-8" and does not raise
    # KeyError when the Content-Type header is missing.
    if not request.is_json:
        logging.error("Header error")
        return json.dumps({"message": "Header error"}), 415

    # silent=True: malformed JSON yields None instead of aborting the request.
    content = request.get_json(silent=True)
    exp_name = content.get('expName') if isinstance(content, dict) else None
    # Keep the lookup confined to the weights directory.
    if (not isinstance(exp_name, str) or exp_name in ('', '.', '..')
            or any(ch in exp_name for ch in '/\\\0')):
        logging.error("Invalid expName: %r", exp_name)
        return json.dumps({"message": "Invalid expName"}), 400

    weights_path = os.path.join('/third_party/RVC/weights', '{}.pth'.format(exp_name))
    if os.path.exists(weights_path):
        return json.dumps({"message": "Model found."}), 200
    return json.dumps({"message": "Model not found."}), 200

if __name__ == "__main__":
    # Flask's development server with debug enabled exposes the Werkzeug
    # interactive debugger, which allows arbitrary code execution; never turn
    # it on by default while binding 0.0.0.0. Opt in for local development by
    # exporting FLASK_DEBUG=1 (or true/yes/on).
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
    app.run(host="0.0.0.0", port=8080, debug=debug_enabled)