Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -7,6 +7,8 @@ import subprocess
|
|
| 7 |
import gc
|
| 8 |
import requests
|
| 9 |
import time
|
|
|
|
|
|
|
| 10 |
from googletrans import Translator
|
| 11 |
import asyncio
|
| 12 |
from flask import Flask, request, jsonify, send_from_directory
|
|
@@ -33,10 +35,62 @@ HEYGEN_API_KEY = "NGM2N2VjNmM4NWM0NGQxMjkyNWFiMjg4OTdlMTI2MDItMTcyNDQ5ODM1MA=="
|
|
| 33 |
HEYGEN_GENERATE_URL = "https://api.heygen.com/v2/video/generate"
|
| 34 |
HEYGEN_STATUS_URL = "https://api.heygen.com/v1/video_status.get"
|
| 35 |
|
|
|
|
|
|
|
|
|
|
| 36 |
def clear_cuda_memory():
|
| 37 |
torch.cuda.empty_cache()
|
| 38 |
gc.collect()
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
def run_inference(video_path, audio_path, video_out_path,
|
| 41 |
inference_ckpt_path, unet_config_path="configs/unet/second_stage.yaml",
|
| 42 |
inference_steps=20, guidance_scale=1.0, seed=1247):
|
|
@@ -391,6 +445,7 @@ def download_heygen_video(video_url):
|
|
| 391 |
def generate_video():
|
| 392 |
global TEMP_DIR
|
| 393 |
TEMP_DIR = create_temp_dir()
|
|
|
|
| 394 |
|
| 395 |
# Get form parameters
|
| 396 |
text_prompt = request.form.get('text_prompt', '').strip()
|
|
@@ -399,18 +454,53 @@ def generate_video():
|
|
| 399 |
|
| 400 |
print('Input text prompt:', text_prompt)
|
| 401 |
|
| 402 |
-
#
|
| 403 |
use_heygen = request.form.get('use_heygen', 'no').lower() == 'yes'
|
| 404 |
voice_cloning = request.form.get('voice_cloning', 'no')
|
| 405 |
target_language = request.form.get('target_language', 'original_text')
|
|
|
|
|
|
|
| 406 |
|
| 407 |
-
#
|
| 408 |
-
if
|
| 409 |
-
|
| 410 |
-
text_prompt = translated_text.strip()
|
| 411 |
-
print('Translated input text prompt:', text_prompt)
|
| 412 |
|
| 413 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 414 |
if use_heygen:
|
| 415 |
print("Using HeyGen API for video generation...")
|
| 416 |
|
|
@@ -476,10 +566,17 @@ def generate_video():
|
|
| 476 |
|
| 477 |
processing_method = "HeyGen API" if use_heygen else "Local AI Avatar"
|
| 478 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 479 |
return jsonify({
|
| 480 |
"message": f"Video processed successfully using {processing_method}.",
|
| 481 |
"output_video": video_url,
|
| 482 |
"processing_method": processing_method,
|
|
|
|
|
|
|
|
|
|
| 483 |
"status": "success"
|
| 484 |
}), 200
|
| 485 |
else:
|
|
|
|
import asyncio
import gc
import json
import os
import random
import re
import time

import requests
from flask import Flask, request, jsonify, send_from_directory
from googletrans import Translator
|
|
|
|
HEYGEN_GENERATE_URL = "https://api.heygen.com/v2/video/generate"
HEYGEN_STATUS_URL = "https://api.heygen.com/v1/video_status.get"

# Initialize OpenAI client.
# SECURITY(review): this API key was committed in plain text and must be
# rotated. The environment variable takes precedence; the literal is kept
# only as a backward-compatible fallback until deployments set OPENAI_API_KEY.
client = OpenAI(api_key=os.environ.get(
    "OPENAI_API_KEY",
    "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",
))
|
def clear_cuda_memory():
    """Collect unreachable Python objects, then release cached CUDA memory.

    gc.collect() runs first so tensors kept alive only by reference cycles
    are freed before empty_cache(); the previous order (empty_cache first)
    could leave those tensors' blocks sitting in the allocator cache.
    """
    gc.collect()
    torch.cuda.empty_cache()
|
| 44 |
|
def openai_chat_avatar(text_prompt):
    """Summarize *text_prompt* into one sentence (<=30 words) via GPT-4o-mini.

    Parameters
    ----------
    text_prompt : str
        Paragraph to summarize (citations/document names are stripped by the
        prompt instructions).

    Returns
    -------
    The raw OpenAI chat-completion response object; callers read
    ``response.choices[0].message.content``.

    NOTE(review): relies on the module-level ``client`` (OpenAI) instance.
    """
    # Token budget fix: the old code passed `max_tokens=len(text_prompt)`,
    # misusing the *character* count as a token count — a short input left
    # too few tokens for the requested 30-word sentence. A 30-word sentence
    # needs well under 100 tokens; scale up for long inputs (~4 chars/token).
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": "Summarize the following paragraph into a complete and accurate single sentence with no more than 30 words. The summary should capture the gist of the paragraph and make sense and remove the citation and document name from the end."},
            {"role": "user", "content": f"Please summarize the following paragraph into one sentence with 30 words or fewer, ensuring it makes sense and captures the gist and remove the citation from the end: {text_prompt}"},
        ],
        max_tokens=max(100, len(text_prompt) // 4),
    )
    return response
|
| 56 |
+
|
def ryzedb_chat_avatar(question, app_id):
    """Query the RyzeDB streaming chat endpoint and return its answer text.

    Parameters
    ----------
    question : str
        User query forwarded to the inference service.
    app_id : str
        RyzeDB application id included in the request payload.

    Returns
    -------
    str
        The ``"content"`` field of the JSON response, or ``""`` when the
        response cannot be parsed (best-effort; callers treat "" as
        "no answer"). Network errors still propagate to the caller, as in
        the original implementation.
    """
    url = "https://inference.dev.ryzeai.ai/v2/chat/stream"
    print("ryze db question", question)

    payload = {
        "input": {
            "app_id": app_id,
            "query": question,
            "chat_history": []
        },
        "config": {
            "thread_id": "123456"
        }
    }

    headers = {
        'Content-Type': 'application/json'
    }

    # timeout added: without one, a stalled server would hang this request
    # (and the Flask worker serving it) indefinitely. (connect, read) pair:
    # generous read timeout because the endpoint streams its answer.
    response = requests.post(url, json=payload, headers=headers,
                             stream=True, timeout=(10, 120))

    try:
        raw_text = response.text.strip()

        # The endpoint emits server-sent events; strip the SSE "data:" prefix
        # before JSON-decoding the body.
        if raw_text.startswith("data:"):
            raw_text = raw_text[len("data:"):].strip()

        json_data = json.loads(raw_text)

        response_content = json_data.get("content", "")
        return response_content

    except Exception as e:
        # Best-effort parse: log and return an empty answer rather than crash.
        print("Error parsing response:", e)
        return ""
|
| 93 |
+
|
| 94 |
def run_inference(video_path, audio_path, video_out_path,
|
| 95 |
inference_ckpt_path, unet_config_path="configs/unet/second_stage.yaml",
|
| 96 |
inference_steps=20, guidance_scale=1.0, seed=1247):
|
|
|
|
| 445 |
def generate_video():
|
| 446 |
global TEMP_DIR
|
| 447 |
TEMP_DIR = create_temp_dir()
|
| 448 |
+
start_time = time.time()
|
| 449 |
|
| 450 |
# Get form parameters
|
| 451 |
text_prompt = request.form.get('text_prompt', '').strip()
|
|
|
|
| 454 |
|
| 455 |
print('Input text prompt:', text_prompt)
|
| 456 |
|
| 457 |
+
# Get processing parameters
|
| 458 |
use_heygen = request.form.get('use_heygen', 'no').lower() == 'yes'
|
| 459 |
voice_cloning = request.form.get('voice_cloning', 'no')
|
| 460 |
target_language = request.form.get('target_language', 'original_text')
|
| 461 |
+
chat_model_used = request.form.get('chat_model_used', 'ryzedb')
|
| 462 |
+
app_id = request.form.get('app_id', '')
|
| 463 |
|
| 464 |
+
# Validate app_id if using RyzeDB
|
| 465 |
+
if chat_model_used == 'ryzedb' and not app_id:
|
| 466 |
+
return jsonify({'error': 'App ID cannot be blank when using RyzeDB'}), 400
|
|
|
|
|
|
|
| 467 |
|
| 468 |
try:
|
| 469 |
+
# Process text prompt based on chat model selection
|
| 470 |
+
if chat_model_used == 'ryzedb':
|
| 471 |
+
start_time_ryze = time.time()
|
| 472 |
+
print("Processing text with RyzeDB...")
|
| 473 |
+
|
| 474 |
+
# Get response from RyzeDB
|
| 475 |
+
ryze_response = ryzedb_chat_avatar(text_prompt, app_id)
|
| 476 |
+
print("Response from RyzeDB inference:", ryze_response)
|
| 477 |
+
|
| 478 |
+
# Clean up response if needed
|
| 479 |
+
if "No information available" in ryze_response:
|
| 480 |
+
ryze_response = re.sub(r'\\+', '', ryze_response)
|
| 481 |
+
|
| 482 |
+
# Summarize with OpenAI
|
| 483 |
+
openai_response = openai_chat_avatar(ryze_response)
|
| 484 |
+
text_prompt = openai_response.choices[0].message.content.strip()
|
| 485 |
+
|
| 486 |
+
end_time_ryze = time.time()
|
| 487 |
+
ryze_processing_time = end_time_ryze - start_time_ryze
|
| 488 |
+
print(f'Final processed text prompt using RyzeDB + OpenAI: {text_prompt}')
|
| 489 |
+
print(f'Time to process with RyzeDB + OpenAI: {ryze_processing_time:.2f} seconds')
|
| 490 |
+
|
| 491 |
+
elif chat_model_used == 'self':
|
| 492 |
+
print("Using original text prompt without processing...")
|
| 493 |
+
text_prompt = text_prompt.strip()
|
| 494 |
+
else:
|
| 495 |
+
print("Unknown chat model specified, using original text...")
|
| 496 |
+
text_prompt = text_prompt.strip()
|
| 497 |
+
|
| 498 |
+
# Translate text if needed
|
| 499 |
+
if target_language != 'original_text':
|
| 500 |
+
translated_text = translate_text(text_prompt, target_language)
|
| 501 |
+
text_prompt = translated_text.strip()
|
| 502 |
+
print('Translated input text prompt:', text_prompt)
|
| 503 |
+
|
| 504 |
if use_heygen:
|
| 505 |
print("Using HeyGen API for video generation...")
|
| 506 |
|
|
|
|
| 566 |
|
| 567 |
processing_method = "HeyGen API" if use_heygen else "Local AI Avatar"
|
| 568 |
|
| 569 |
+
# Calculate total processing time
|
| 570 |
+
end_time = time.time()
|
| 571 |
+
total_time = end_time - start_time
|
| 572 |
+
|
| 573 |
return jsonify({
|
| 574 |
"message": f"Video processed successfully using {processing_method}.",
|
| 575 |
"output_video": video_url,
|
| 576 |
"processing_method": processing_method,
|
| 577 |
+
"text_prompt": text_prompt,
|
| 578 |
+
"chat_model_used": chat_model_used,
|
| 579 |
+
"time_taken": round(total_time, 2),
|
| 580 |
"status": "success"
|
| 581 |
}), 200
|
| 582 |
else:
|