Spaces:
Running
Running
'''
Train server: usage notes.

Start the container (GPU runtime, host networking, project directories mounted):

    sudo docker run --gpus all --runtime=nvidia --rm \
    -v /home/ubuntu/dotdemo/third_party:/third_party \
    -v /home/ubuntu/dotdemo-dev:/dotdemo-dev \
    -v /home/ubuntu/dot-demo-assets/ml-logs:/logs \
    -v /home/ubuntu/dotdemo/train_server:/app \
    --network="host" \
    --shm-size 1G \
    -it fantasyfish677/rvc:v0 /bin/bash

Run the server directly, copying its output to the log volume:

    pip3 install flask_cors
    python3 /app/server.py 2>&1 | tee /logs/train_server.log

Alternatively, serve it with gunicorn:

    export FLASK_APP=server
    export FLASK_DEBUG=true
    pip3 install gunicorn
    gunicorn -b :8080 --timeout=600 server:app

Example requests:

    curl -X GET http://3.16.130.199:8080/ping

    curl -X POST http://3.16.130.199:8080/train \
    -H 'Content-Type: application/json' \
    -d '{"expName":"varun124","trainsetDir":"varun124"}'

    curl -X GET http://3.16.130.199:8080/check \
    -H 'Content-Type: application/json' \
    -d '{"expName":"kanye-1"}'
'''
| import json | |
| import os | |
| from flask import Flask, request | |
| from logging import exception | |
| import time | |
| from server_utils import train_model | |
| from flask_cors import CORS, cross_origin | |
# Startup marker: printed once every import above has succeeded.
print("import successful!")

# Application object shared by the request handlers and the __main__ entry point.
app = Flask("train server")

# Wrap the app with flask_cors. The config key is assigned *after* CORS(app),
# exactly as before; the statement order is kept on purpose.
# NOTE(review): confirm flask_cors still picks up CORS_HEADERS when it is set
# after the extension has been initialised.
cors = CORS(app)
app.config.update({'CORS_HEADERS': 'Content-Type'})
def healthcheck():
    """Return a fixed liveness payload as UTF-8 encoded JSON bytes.

    The payload is ``{"code": 200, "message": "responding"}``.
    """
    # NOTE(review): no route decorator is visible on this function; the usage
    # notes at the top of the file call it as GET /ping -- confirm that it is
    # registered with the app.
    payload = {"code": 200, "message": "responding"}
    serialized = json.dumps(payload)
    return bytes(serialized, 'utf-8')
def train():
    """Train a model for the experiment described by the JSON request body.

    The body must be a JSON object with two keys:

    * ``expName``     -- experiment name; used for the log file name and for
      the experiment's directory under ``/third_party/RVC/logs``.
    * ``trainsetDir`` -- training-set directory, resolved under
      ``/dotdemo-dev``.

    Returns:
        A ``(body, status)`` tuple. ``body`` is a JSON string with a single
        ``message`` key. ``status`` is 200 when training finished, 400 when
        ``expName`` is not a plain path component, 404 when the training-set
        directory does not exist, and 500 for a non-JSON Content-Type or any
        failure while reading the request or training.
    """
    # Function-scope imports: the module's import block only provides
    # ``logging.exception`` and does not import ``shutil``.
    import logging
    import shutil

    # NOTE(review): no route decorator is visible on this function; the usage
    # notes at the top of the file call it as POST /train -- confirm that it
    # is registered with the app.

    # Compare only the media type, so "application/json; charset=utf-8" is
    # accepted and a missing header is reported instead of raising.
    media_type = request.headers.get('Content-Type', '').split(';')[0].strip().lower()
    if media_type != 'application/json':
        logging.error("Header error")
        return json.dumps({"message": "Header error"}), 500

    try:
        content = request.get_json()
        exp_name = content['expName']
        trainset_dir = os.path.join('/dotdemo-dev', content['trainsetDir'])

        # expName ends up in a path that is deleted below, so it must be a
        # single plain path component (no separators, not "." or "..").
        safe_name = str(exp_name)
        if safe_name in ('', '.', '..') or os.path.basename(safe_name) != safe_name:
            logging.error("Invalid expName")
            return json.dumps({"message": "Invalid expName"}), 400

        # Validate the training set BEFORE touching earlier results, so a
        # request that cannot run does not wipe an existing experiment.
        if not os.path.exists(trainset_dir):
            logging.error("Training set doesn't exist")
            return json.dumps({"message": "Training set doesn't exist"}), 404

        # Remove leftovers of a previous run with the same name. This is done
        # through the filesystem API rather than a shell command so the
        # request value is never interpreted by a shell.
        stale_logs = os.path.join('/third_party/RVC/logs', safe_name)
        if os.path.isdir(stale_logs) and not os.path.islink(stale_logs):
            shutil.rmtree(stale_logs)
        elif os.path.lexists(stale_logs):
            os.remove(stale_logs)

        # One log file per experiment inside the /logs directory.
        log_path = os.path.join('/logs', '{}.log'.format(exp_name))

        start_time = time.time()
        train_model(exp_name, trainset_dir, log_path, total_epoch=20)
        end_time = time.time()
        return json.dumps({"message": "Training Completed in {} secs.".format(end_time - start_time)}), 200
    except Exception as e:
        # Request boundary: log the traceback and report the failure to the
        # caller instead of letting the exception escape.
        logging.exception("Training process failed")
        return json.dumps({"message": "Training process failed due to {}".format(e)}), 500
def check():
    """Report whether a trained model file exists for an experiment.

    The body must be a JSON object with an ``expName`` key. The handler looks
    for ``/third_party/RVC/weights/<expName>.pth``.

    Returns:
        A ``(body, status)`` tuple. ``body`` is a JSON string with a single
        ``message`` key. ``status`` is 200 with "Model found." or
        "Model not found.", 400 when ``expName`` is missing or is not a plain
        path component, and 500 for a non-JSON Content-Type.
    """
    # Function-scope import: the module's import block only provides
    # ``logging.exception``, which is meant for use inside ``except`` blocks.
    import logging

    # NOTE(review): no route decorator is visible on this function; the usage
    # notes at the top of the file call it as GET /check -- confirm that it is
    # registered with the app.

    # Compare only the media type, so "application/json; charset=utf-8" is
    # accepted and a missing header is reported instead of raising.
    media_type = request.headers.get('Content-Type', '').split(';')[0].strip().lower()
    if media_type != 'application/json':
        logging.error("Header error")
        return json.dumps({"message": "Header error"}), 500

    content = request.get_json()
    if not isinstance(content, dict) or 'expName' not in content:
        logging.error("expName is required")
        return json.dumps({"message": "expName is required"}), 400

    # expName is interpolated into a filesystem path, so it must be a single
    # plain path component (no separators, not "." or "..").
    exp_name = str(content['expName'])
    if exp_name in ('', '.', '..') or os.path.basename(exp_name) != exp_name:
        logging.error("Invalid expName")
        return json.dumps({"message": "Invalid expName"}), 400

    weights_path = os.path.join('/third_party/RVC/weights', '{}.pth'.format(exp_name))
    # Both outcomes are reported with HTTP 200; callers tell them apart by
    # the message text.
    if os.path.exists(weights_path):
        return json.dumps({"message": "Model found."}), 200
    return json.dumps({"message": "Model not found."}), 200
if __name__ == "__main__":
    # Direct-run entry point: listen on every interface, port 8080.
    # Debug mode enables Werkzeug's interactive debugger, which must not be
    # reachable on a public interface, so it is no longer hard-coded on.
    # Opt in explicitly with FLASK_DEBUG=true (the variable already used in
    # the usage notes at the top of this file).
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in ("1", "true", "yes", "on")
    app.run(host="0.0.0.0", port=8080, debug=debug_enabled)