Spaces:
Runtime error
Runtime error
TomatoCocotree committed on
Commit ·
8056b16
1
Parent(s): 19400f8
更新server.py
Browse files
server.py
CHANGED
|
@@ -86,13 +86,23 @@ parser.add_argument('--chroma-persist', help="ChromaDB persistence", default=Tru
|
|
| 86 |
parser.add_argument(
|
| 87 |
"--secure", action="store_true", help="Enforces the use of an API key"
|
| 88 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
sd_group = parser.add_mutually_exclusive_group()
|
| 90 |
|
| 91 |
-
local_sd =
|
| 92 |
local_sd.add_argument("--sd-model", help="Load a custom SD image generation model")
|
| 93 |
local_sd.add_argument("--sd-cpu", help="Force the SD pipeline to run on the CPU", action="store_true")
|
| 94 |
|
| 95 |
-
remote_sd =
|
| 96 |
remote_sd.add_argument(
|
| 97 |
"--sd-remote", action="store_true", help="Use a remote backend for SD"
|
| 98 |
)
|
|
@@ -119,7 +129,7 @@ parser.add_argument(
|
|
| 119 |
)
|
| 120 |
|
| 121 |
args = parser.parse_args()
|
| 122 |
-
# [HF, Huggingface] Set port to 7860, set host to remote.
|
| 123 |
port = 7860
|
| 124 |
host = "0.0.0.0"
|
| 125 |
summarization_model = (
|
|
@@ -170,6 +180,28 @@ if not torch.cuda.is_available() and not args.cpu:
|
|
| 170 |
|
| 171 |
print(f"{Fore.GREEN}{Style.BRIGHT}Using torch device: {device_string}{Style.RESET_ALL}")
|
| 172 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
if "caption" in modules:
|
| 174 |
print("Initializing an image captioning model...")
|
| 175 |
captioning_processor = AutoProcessor.from_pretrained(captioning_model)
|
|
@@ -189,16 +221,6 @@ if "summarize" in modules:
|
|
| 189 |
summarization_model, torch_dtype=torch_dtype
|
| 190 |
).to(device)
|
| 191 |
|
| 192 |
-
if "classify" in modules:
|
| 193 |
-
print("Initializing a sentiment classification pipeline...")
|
| 194 |
-
classification_pipe = pipeline(
|
| 195 |
-
"text-classification",
|
| 196 |
-
model=classification_model,
|
| 197 |
-
top_k=None,
|
| 198 |
-
device=device,
|
| 199 |
-
torch_dtype=torch_dtype,
|
| 200 |
-
)
|
| 201 |
-
|
| 202 |
if "sd" in modules and not sd_use_remote:
|
| 203 |
from diffusers import StableDiffusionPipeline
|
| 204 |
from diffusers import EulerAncestralDiscreteScheduler
|
|
@@ -251,7 +273,6 @@ if "silero-tts" in modules:
|
|
| 251 |
tts_service.update_sample_text(SILERO_SAMPLE_TEXT)
|
| 252 |
tts_service.generate_samples()
|
| 253 |
|
| 254 |
-
|
| 255 |
if "edge-tts" in modules:
|
| 256 |
print("Initializing Edge TTS client")
|
| 257 |
import tts_edge as edge
|
|
@@ -295,8 +316,112 @@ if "chromadb" in modules:
|
|
| 295 |
app = Flask(__name__)
|
| 296 |
CORS(app) # allow cross-domain requests
|
| 297 |
Compress(app) # compress responses
|
| 298 |
-
app.config["MAX_CONTENT_LENGTH"] =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 299 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
|
| 301 |
def require_module(name):
|
| 302 |
def wrapper(fn):
|
|
@@ -313,12 +438,7 @@ def require_module(name):
|
|
| 313 |
|
| 314 |
# AI stuff
|
| 315 |
def classify_text(text: str) -> list:
|
| 316 |
-
|
| 317 |
-
text,
|
| 318 |
-
truncation=True,
|
| 319 |
-
max_length=classification_pipe.model.config.max_position_embeddings,
|
| 320 |
-
)[0]
|
| 321 |
-
return sorted(output, key=lambda x: x["score"], reverse=True)
|
| 322 |
|
| 323 |
|
| 324 |
def caption_image(raw_image: Image, max_new_tokens: int = 20) -> str:
|
|
@@ -417,7 +537,7 @@ def image_to_base64(image: Image, quality: int = 75) -> str:
|
|
| 417 |
return img_str
|
| 418 |
|
| 419 |
|
| 420 |
-
ignore_auth = []
|
| 421 |
# [HF, Huggingface] Get password instead of text file.
|
| 422 |
api_key = os.environ.get("password")
|
| 423 |
|
|
@@ -429,6 +549,7 @@ def is_authorize_ignored(request):
|
|
| 429 |
return True
|
| 430 |
return False
|
| 431 |
|
|
|
|
| 432 |
@app.before_request
|
| 433 |
def before_request():
|
| 434 |
# Request time measuring
|
|
@@ -532,6 +653,8 @@ def api_classify():
|
|
| 532 |
classification = classify_text(data["text"])
|
| 533 |
print("Classification output:", classification, sep="\n")
|
| 534 |
gc.collect()
|
|
|
|
|
|
|
| 535 |
return jsonify({"classification": classification})
|
| 536 |
|
| 537 |
|
|
@@ -540,8 +663,31 @@ def api_classify():
|
|
| 540 |
def api_classify_labels():
|
| 541 |
classification = classify_text("")
|
| 542 |
labels = [x["label"] for x in classification]
|
|
|
|
|
|
|
| 543 |
return jsonify({"labels": labels})
|
| 544 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 545 |
|
| 546 |
@app.route("/api/image", methods=["POST"])
|
| 547 |
@require_module("sd")
|
|
@@ -958,7 +1104,8 @@ if args.share:
|
|
| 958 |
cloudflare = _run_cloudflared(port, metrics_port)
|
| 959 |
else:
|
| 960 |
cloudflare = _run_cloudflared(port)
|
| 961 |
-
print("Running on
|
| 962 |
|
| 963 |
ignore_auth.append(tts_play_sample)
|
|
|
|
| 964 |
app.run(host=host, port=port)
|
|
|
|
| 86 |
parser.add_argument(
|
| 87 |
"--secure", action="store_true", help="Enforces the use of an API key"
|
| 88 |
)
|
| 89 |
+
parser.add_argument("--talkinghead-gpu", action="store_true", help="Run the talkinghead animation on the GPU (CPU is default)")
|
| 90 |
+
|
| 91 |
+
parser.add_argument("--coqui-gpu", action="store_true", help="Run the voice models on the GPU (CPU is default)")
|
| 92 |
+
parser.add_argument("--coqui-models", help="Install given Coqui-api TTS model at launch (comma separated list, last one will be loaded at start)")
|
| 93 |
+
|
| 94 |
+
# The value is interpreted in Mb and applied to app.config["MAX_CONTENT_LENGTH"].
parser.add_argument("--max-content-length", help="Set the max content length of incoming requests in Mb (default is 500)")
|
| 95 |
+
parser.add_argument("--rvc-save-file", action="store_true", help="Save the last rvc input/output audio file into data/tmp/ folder (for research)")
|
| 96 |
+
|
| 97 |
+
parser.add_argument("--stt-vosk-model-path", help="Load a custom vosk speech-to-text model")
|
| 98 |
+
# This path is passed to the whisper-stt and streaming-stt model loaders, not to vosk.
parser.add_argument("--stt-whisper-model-path", help="Load a custom whisper speech-to-text model")
|
| 99 |
sd_group = parser.add_mutually_exclusive_group()
|
| 100 |
|
| 101 |
+
local_sd = parser.add_argument_group("sd-local")
|
| 102 |
local_sd.add_argument("--sd-model", help="Load a custom SD image generation model")
|
| 103 |
local_sd.add_argument("--sd-cpu", help="Force the SD pipeline to run on the CPU", action="store_true")
|
| 104 |
|
| 105 |
+
remote_sd = parser.add_argument_group("sd-remote")
|
| 106 |
remote_sd.add_argument(
|
| 107 |
"--sd-remote", action="store_true", help="Use a remote backend for SD"
|
| 108 |
)
|
|
|
|
| 129 |
)
|
| 130 |
|
| 131 |
args = parser.parse_args()
|
| 132 |
+
# [HF, Huggingface] Set port to 7860, set host to remote.
|
| 133 |
port = 7860
|
| 134 |
host = "0.0.0.0"
|
| 135 |
summarization_model = (
|
|
|
|
| 180 |
|
| 181 |
print(f"{Fore.GREEN}{Style.BRIGHT}Using torch device: {device_string}{Style.RESET_ALL}")
|
| 182 |
|
| 183 |
+
if "talkinghead" in modules:
|
| 184 |
+
import sys
|
| 185 |
+
import threading
|
| 186 |
+
mode = "cuda" if args.talkinghead_gpu else "cpu"
|
| 187 |
+
print("Initializing talkinghead pipeline in " + mode + " mode....")
|
| 188 |
+
talkinghead_path = os.path.abspath(os.path.join(os.getcwd(), "talkinghead"))
|
| 189 |
+
sys.path.append(talkinghead_path) # Add the path to the 'tha3' module to the sys.path list
|
| 190 |
+
|
| 191 |
+
try:
|
| 192 |
+
import talkinghead.tha3.app.app as talkinghead
|
| 193 |
+
from talkinghead import *
|
| 194 |
+
def launch_talkinghead_gui():
|
| 195 |
+
talkinghead.launch_gui(mode, "separable_float")
|
| 196 |
+
#choices=['standard_float', 'separable_float', 'standard_half', 'separable_half'],
|
| 197 |
+
#choices='The device to use for PyTorch ("cuda" for GPU, "cpu" for CPU).'
|
| 198 |
+
talkinghead_thread = threading.Thread(target=launch_talkinghead_gui)
|
| 199 |
+
talkinghead_thread.daemon = True # Set the thread as a daemon thread
|
| 200 |
+
talkinghead_thread.start()
|
| 201 |
+
|
| 202 |
+
except ModuleNotFoundError:
|
| 203 |
+
print("Error: Could not import the 'talkinghead' module.")
|
| 204 |
+
|
| 205 |
if "caption" in modules:
|
| 206 |
print("Initializing an image captioning model...")
|
| 207 |
captioning_processor = AutoProcessor.from_pretrained(captioning_model)
|
|
|
|
| 221 |
summarization_model, torch_dtype=torch_dtype
|
| 222 |
).to(device)
|
| 223 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
if "sd" in modules and not sd_use_remote:
|
| 225 |
from diffusers import StableDiffusionPipeline
|
| 226 |
from diffusers import EulerAncestralDiscreteScheduler
|
|
|
|
| 273 |
tts_service.update_sample_text(SILERO_SAMPLE_TEXT)
|
| 274 |
tts_service.generate_samples()
|
| 275 |
|
|
|
|
| 276 |
if "edge-tts" in modules:
|
| 277 |
print("Initializing Edge TTS client")
|
| 278 |
import tts_edge as edge
|
|
|
|
| 316 |
app = Flask(__name__)
|
| 317 |
CORS(app) # allow cross-domain requests
|
| 318 |
Compress(app) # compress responses
|
| 319 |
+
app.config["MAX_CONTENT_LENGTH"] = 500 * 1024 * 1024
|
| 320 |
+
|
| 321 |
+
max_content_length = (
|
| 322 |
+
args.max_content_length
|
| 323 |
+
if args.max_content_length
|
| 324 |
+
else None)
|
| 325 |
+
|
| 326 |
+
if max_content_length is not None:
|
| 327 |
+
print("Setting MAX_CONTENT_LENGTH to",max_content_length,"Mb")
|
| 328 |
+
app.config["MAX_CONTENT_LENGTH"] = int(max_content_length) * 1024 * 1024
|
| 329 |
+
|
| 330 |
+
if "classify" in modules:
|
| 331 |
+
import modules.classify.classify_module as classify_module
|
| 332 |
+
classify_module.init_text_emotion_classifier(classification_model, device, torch_dtype)
|
| 333 |
+
|
| 334 |
+
if "vosk-stt" in modules:
|
| 335 |
+
print("Initializing Vosk speech-recognition (from ST request file)")
|
| 336 |
+
vosk_model_path = (
|
| 337 |
+
args.stt_vosk_model_path
|
| 338 |
+
if args.stt_vosk_model_path
|
| 339 |
+
else None)
|
| 340 |
+
|
| 341 |
+
import modules.speech_recognition.vosk_module as vosk_module
|
| 342 |
+
|
| 343 |
+
vosk_module.model = vosk_module.load_model(file_path=vosk_model_path)
|
| 344 |
+
app.add_url_rule("/api/speech-recognition/vosk/process-audio", view_func=vosk_module.process_audio, methods=["POST"])
|
| 345 |
+
|
| 346 |
+
if "whisper-stt" in modules:
|
| 347 |
+
print("Initializing Whisper speech-recognition (from ST request file)")
|
| 348 |
+
whisper_model_path = (
|
| 349 |
+
args.stt_whisper_model_path
|
| 350 |
+
if args.stt_whisper_model_path
|
| 351 |
+
else None)
|
| 352 |
+
|
| 353 |
+
import modules.speech_recognition.whisper_module as whisper_module
|
| 354 |
+
|
| 355 |
+
whisper_module.model = whisper_module.load_model(file_path=whisper_model_path)
|
| 356 |
+
app.add_url_rule("/api/speech-recognition/whisper/process-audio", view_func=whisper_module.process_audio, methods=["POST"])
|
| 357 |
+
|
| 358 |
+
if "streaming-stt" in modules:
|
| 359 |
+
print("Initializing vosk/whisper speech-recognition (from extras server microphone)")
|
| 360 |
+
whisper_model_path = (
|
| 361 |
+
args.stt_whisper_model_path
|
| 362 |
+
if args.stt_whisper_model_path
|
| 363 |
+
else None)
|
| 364 |
+
|
| 365 |
+
import modules.speech_recognition.streaming_module as streaming_module
|
| 366 |
|
| 367 |
+
streaming_module.whisper_model, streaming_module.vosk_model = streaming_module.load_model(file_path=whisper_model_path)
|
| 368 |
+
app.add_url_rule("/api/speech-recognition/streaming/record-and-transcript", view_func=streaming_module.record_and_transcript, methods=["POST"])
|
| 369 |
+
|
| 370 |
+
if "rvc" in modules:
|
| 371 |
+
print("Initializing RVC voice conversion (from ST request file)")
|
| 372 |
+
print("Increasing server upload limit")
|
| 373 |
+
rvc_save_file = (
|
| 374 |
+
args.rvc_save_file
|
| 375 |
+
if args.rvc_save_file
|
| 376 |
+
else False)
|
| 377 |
+
|
| 378 |
+
if rvc_save_file:
|
| 379 |
+
print("RVC saving file option detected, input/output audio will be saved into data/tmp/ folder")
|
| 380 |
+
|
| 381 |
+
import sys
|
| 382 |
+
sys.path.insert(0,'modules/voice_conversion')
|
| 383 |
+
|
| 384 |
+
import modules.voice_conversion.rvc_module as rvc_module
|
| 385 |
+
rvc_module.save_file = rvc_save_file
|
| 386 |
+
|
| 387 |
+
if "classify" in modules:
|
| 388 |
+
rvc_module.classification_mode = True
|
| 389 |
+
|
| 390 |
+
rvc_module.fix_model_install()
|
| 391 |
+
app.add_url_rule("/api/voice-conversion/rvc/get-models-list", view_func=rvc_module.rvc_get_models_list, methods=["POST"])
|
| 392 |
+
app.add_url_rule("/api/voice-conversion/rvc/upload-models", view_func=rvc_module.rvc_upload_models, methods=["POST"])
|
| 393 |
+
app.add_url_rule("/api/voice-conversion/rvc/process-audio", view_func=rvc_module.rvc_process_audio, methods=["POST"])
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
if "coqui-tts" in modules:
|
| 397 |
+
mode = "GPU" if args.coqui_gpu else "CPU"
|
| 398 |
+
print("Initializing Coqui TTS client in " + mode + " mode")
|
| 399 |
+
import modules.text_to_speech.coqui.coqui_module as coqui_module
|
| 400 |
+
|
| 401 |
+
if mode == "GPU":
|
| 402 |
+
coqui_module.gpu_mode = True
|
| 403 |
+
|
| 404 |
+
coqui_models = (
|
| 405 |
+
args.coqui_models
|
| 406 |
+
if args.coqui_models
|
| 407 |
+
else None
|
| 408 |
+
)
|
| 409 |
+
|
| 410 |
+
if coqui_models is not None:
|
| 411 |
+
coqui_models = coqui_models.split(",")
|
| 412 |
+
for i in coqui_models:
|
| 413 |
+
if not coqui_module.install_model(i):
|
| 414 |
+
raise ValueError("Coqui model loading failed, most likely a wrong model name in --coqui-models argument, check log above to see which one")
|
| 415 |
+
|
| 416 |
+
# Coqui-api models
|
| 417 |
+
app.add_url_rule("/api/text-to-speech/coqui/coqui-api/check-model-state", view_func=coqui_module.coqui_check_model_state, methods=["POST"])
|
| 418 |
+
app.add_url_rule("/api/text-to-speech/coqui/coqui-api/install-model", view_func=coqui_module.coqui_install_model, methods=["POST"])
|
| 419 |
+
|
| 420 |
+
# Users models
|
| 421 |
+
app.add_url_rule("/api/text-to-speech/coqui/local/get-models", view_func=coqui_module.coqui_get_local_models, methods=["POST"])
|
| 422 |
+
|
| 423 |
+
# Handle both coqui-api/users models
|
| 424 |
+
app.add_url_rule("/api/text-to-speech/coqui/generate-tts", view_func=coqui_module.coqui_generate_tts, methods=["POST"])
|
| 425 |
|
| 426 |
def require_module(name):
|
| 427 |
def wrapper(fn):
|
|
|
|
| 438 |
|
| 439 |
# AI stuff
|
| 440 |
def classify_text(text: str) -> list:
|
| 441 |
+
return classify_module.classify_text_emotion(text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 442 |
|
| 443 |
|
| 444 |
def caption_image(raw_image: Image, max_new_tokens: int = 20) -> str:
|
|
|
|
| 537 |
return img_str
|
| 538 |
|
| 539 |
|
| 540 |
+
ignore_auth = []
|
| 541 |
# [HF, Huggingface] Get password instead of text file.
|
| 542 |
api_key = os.environ.get("password")
|
| 543 |
|
|
|
|
| 549 |
return True
|
| 550 |
return False
|
| 551 |
|
| 552 |
+
|
| 553 |
@app.before_request
|
| 554 |
def before_request():
|
| 555 |
# Request time measuring
|
|
|
|
| 653 |
classification = classify_text(data["text"])
|
| 654 |
print("Classification output:", classification, sep="\n")
|
| 655 |
gc.collect()
|
| 656 |
+
if "talkinghead" in modules: #send emotion to talkinghead
|
| 657 |
+
talkinghead.setEmotion(classification)
|
| 658 |
return jsonify({"classification": classification})
|
| 659 |
|
| 660 |
|
|
|
|
| 663 |
def api_classify_labels():
|
| 664 |
classification = classify_text("")
|
| 665 |
labels = [x["label"] for x in classification]
|
| 666 |
+
if "talkinghead" in modules:
|
| 667 |
+
labels.append('talkinghead') # Add 'talkinghead' to the labels list
|
| 668 |
return jsonify({"labels": labels})
|
| 669 |
|
| 670 |
+
@app.route("/api/talkinghead/load", methods=["POST"])
|
| 671 |
+
def live_load():
|
| 672 |
+
file = request.files['file']
|
| 673 |
+
# convert stream to bytes and pass to talkinghead_load
|
| 674 |
+
return talkinghead.talkinghead_load_file(file.stream)
|
| 675 |
+
|
| 676 |
+
@app.route('/api/talkinghead/unload')
|
| 677 |
+
def live_unload():
|
| 678 |
+
return talkinghead.unload()
|
| 679 |
+
|
| 680 |
+
@app.route('/api/talkinghead/start_talking')
|
| 681 |
+
def start_talking():
|
| 682 |
+
return talkinghead.start_talking()
|
| 683 |
+
|
| 684 |
+
@app.route('/api/talkinghead/stop_talking')
|
| 685 |
+
def stop_talking():
|
| 686 |
+
return talkinghead.stop_talking()
|
| 687 |
+
|
| 688 |
+
@app.route('/api/talkinghead/result_feed')
|
| 689 |
+
def result_feed():
|
| 690 |
+
return talkinghead.result_feed()
|
| 691 |
|
| 692 |
@app.route("/api/image", methods=["POST"])
|
| 693 |
@require_module("sd")
|
|
|
|
| 1104 |
cloudflare = _run_cloudflared(port, metrics_port)
|
| 1105 |
else:
|
| 1106 |
cloudflare = _run_cloudflared(port)
|
| 1107 |
+
print(f"{Fore.GREEN}{Style.NORMAL}Running on: {cloudflare}{Style.RESET_ALL}")
|
| 1108 |
|
| 1109 |
ignore_auth.append(tts_play_sample)
|
| 1110 |
+
ignore_auth.append(result_feed)
|
| 1111 |
app.run(host=host, port=port)
|