Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| # -*- coding: UTF-8 -*- | |
| ''' | |
| @Project : OmniTalker | |
| @File : app.py | |
| @Author : zhongjian.wzj | |
| @Date : 2025/4/7 19:55 | |
| Copyright (c) 2025, Alibaba Cloud. All rights reserved. | |
| ''' | |
| import os | |
| import json | |
| import time | |
| import random | |
| import requests | |
| import argparse | |
| import uuid | |
| import gradio as gr | |
| from pathlib import Path | |
# Best-effort lookup of this host's public IP, printed for diagnostics only.
# A failure here must not abort start-up, so request errors are caught and
# reported instead of propagating out of module import.
try:
    local_ip = requests.get('http://myip.ipip.net', timeout=5).text
except requests.exceptions.RequestException as e:
    local_ip = f'unavailable ({e})'
print('local_ip: ', local_ip)
# Base URL of the OmniTalker inference backend; overridable via environment.
url = os.getenv('OMNITALKER_URL', "http://localhost:8012")
# JSON content type used for POST requests to the backend.
headers = {"Content-Type": "application/json"}
# Folders next to this script: `static` caches reference videos downloaded
# from the backend, `result` stores generated videos.
script_dir = Path(__file__).parent.absolute()
static_folder = script_dir / "static"
static_folder.mkdir(parents=True, exist_ok=True)
result_folder = script_dir / "result"
result_folder.mkdir(parents=True, exist_ok=True)
def auto_remove(folder, max_files=1000):
    """Prune the oldest files in *folder* so that one new file can be added.

    When *folder* holds ``max_files`` or more regular files, the oldest ones
    (ordered by ``st_ctime``) are deleted until ``max_files - 1`` remain,
    leaving room for exactly one new file. Subdirectories are ignored.
    Deletion is best-effort: failures are printed and skipped, never raised.

    Args:
        folder: Directory to prune (``str`` or ``Path``). Missing paths and
            non-directories are silently ignored.
        max_files: Upper bound on the file count once the caller adds one
            more file. Values <= 0 remove every file.
    """
    folder = Path(folder)
    # is_dir() is False for non-existent paths as well.
    if not folder.is_dir():
        return
    files = [p for p in folder.iterdir() if p.is_file()]
    excess = len(files) - max_files + 1
    if excess <= 0:
        return

    def _ctime(path):
        # Another request may delete a file between iterdir() and stat();
        # treat vanished files as oldest so sorting never raises.
        try:
            return path.stat().st_ctime
        except FileNotFoundError:
            return float('-inf')

    files.sort(key=_ctime)
    # Slicing (rather than indexing via range) can never run past the end of
    # the list, even when max_files <= 0.
    for oldest_file in files[:excess]:
        try:
            oldest_file.unlink()
            print(f"remove file: {oldest_file}")
        except FileNotFoundError:
            # Already removed concurrently; nothing left to do.
            pass
        except PermissionError:
            print(f"permission denied: {oldest_file}")
        except Exception as e:
            print(f"failed: {str(e)}")
def predict(role, content, seed, speed):
    """Ask the backend to render a talking-head video and save it locally.

    Args:
        role: Character id, as listed by the backend's ``/get_examples``.
        content: Text the character should speak.
        seed: Generation seed forwarded to the backend (the UI tips describe
            -1 as "automatic seeding").
        speed: Speech-rate value forwarded to the backend.

    Returns:
        Path of the MP4 written under ``result_folder``.

    Raises:
        gr.Error: If the backend cannot be reached or answers with a status
            other than 200, so Gradio shows a message instead of trying to
            display a file that was never written.
    """
    payload = {
        "role": role,
        "content": content,
        "seed": seed,
        "speed": speed,
    }
    try:
        # Bound only the connect phase: generation itself may legitimately
        # take a long time, so no read timeout is imposed.
        response = requests.post(f'{url}/predict', headers=headers,
                                 data=json.dumps(payload), timeout=(10, None))
    except requests.exceptions.RequestException as e:
        raise gr.Error(f"Failed to reach the generation service: {e}") from e
    if response.status_code != 200:
        raise gr.Error(f"Generation failed (HTTP {response.status_code}).")
    # Keep the result folder bounded before adding the new file.
    auto_remove(result_folder)
    gen_file_path = result_folder / f"result-{uuid.uuid4().hex}.mp4"
    with gen_file_path.open(mode='wb') as vid:
        vid.write(response.content)
    return gen_file_path
def generate_seed():
    """Draw a fresh random seed for the Seed input.

    Returns a Gradio-style update dict (``__type__`` set to ``"update"``)
    whose ``value`` is a uniformly random integer in ``[0, 2**32 - 1]``.
    """
    new_value = random.randint(0, 0xFFFFFFFF)
    return dict(__type__="update", value=new_value)
def _download_reference_video(role_id, dest):
    """Fetch the reference video for *role_id* into *dest*; return True on success.

    Request errors and non-200 answers are printed and yield ``False``.
    """
    try:
        response = requests.get(f"{url}/get_video/", params={'role': role_id},
                                timeout=(5, 60))
    except requests.exceptions.RequestException as e:
        print(f"failed to download reference video for {role_id}: {e}")
        return False
    if response.status_code != 200:
        print(f"failed to download reference video for {role_id}: "
              f"HTTP {response.status_code}")
        return False
    with dest.open(mode='wb') as vid:
        vid.write(response.content)
    return True


def update_examples():
    """Build the Examples table from the roles the backend advertises.

    Queries ``/get_examples`` for a JSON mapping ``role_id -> config`` and
    ensures each role's reference video is cached as
    ``static_folder/<role_id>.mp4``, fetching it from ``/get_video/`` when it
    is missing.

    Returns:
        A Gradio-style update dict whose ``samples`` rows are
        ``[role_id, video_path, *config.values()]``. NOTE(review): the config
        value order must line up with the remaining Examples inputs
        (presumably text, seed, speed) — confirm against the backend.

    Request or JSON-decoding failures on the index yield an empty table
    instead of raising (this function also runs while the UI is built).
    A role whose video cannot be fetched is skipped; the remaining roles are
    still listed.
    """
    examples_dict = {}
    try:
        response = requests.get(f"{url}/get_examples", timeout=(5, 60))
        if response.status_code == 200:
            examples_dict = response.json()
            print(examples_dict.keys())
    except (requests.exceptions.RequestException, ValueError) as e:
        print(f"failed to load examples: {e}")

    examples = []
    for role_id, role_cfg in examples_dict.items():
        ref_video_path = static_folder / f'{role_id}.mp4'
        if not ref_video_path.is_file() and not _download_reference_video(role_id, ref_video_path):
            # Skip only this role rather than dropping every role after it.
            continue
        examples.append([role_id, ref_video_path, *role_cfg.values()])
    return {
        "__type__": "update",
        "samples": examples,
    }
def check_http(url, timeout=5):
    """Probe *url* with a GET request and report whether it answered HTTP 200.

    Prints the outcome (the status code, or the request error) and returns
    ``True`` only for a 200 response. Any ``requests`` exception yields
    ``False`` instead of propagating.
    """
    try:
        status = requests.get(url, timeout=timeout).status_code
    except requests.exceptions.RequestException as err:
        print(f"Error: {err}")
        return False
    is_ok = status == 200
    print(f"Succeed: {status}" if is_ok else f"Faild: {status}")
    return is_ok
# Block start-up until the backend answers, polling every 10 s for up to
# MAX_CONNECT_TIMES attempts. If it never comes up, print a warning and
# continue so the UI still launches.
MAX_CONNECT_TIMES = 100
for try_loop in range(MAX_CONNECT_TIMES):
    print(f'Try: {try_loop}/{MAX_CONNECT_TIMES}')
    if check_http(url):
        break
    if try_loop < MAX_CONNECT_TIMES - 1:
        # No point in sleeping after the final failed attempt.
        time.sleep(10)
else:
    print(f'Warning: {url} still unreachable after {MAX_CONNECT_TIMES} attempts; continuing anyway.')
# Load the page stylesheet. Prefer the copy next to this script so the app
# also works when launched from another working directory; otherwise fall
# back to the CWD-relative path. Read as UTF-8 to avoid locale dependence.
css_path = script_dir / 'style.css'
if not css_path.is_file():
    css_path = Path('style.css')
with open(css_path, 'r', encoding='utf-8') as f:
    custom_css = f.read()
# Gradio UI: header and instructions, reference/output video panes, the
# text/seed/speed controls, the role selector with its Examples table, and
# the event wiring.
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("# <center> OmniTalker </center>")
    gr.Markdown("### <center> 🏠 [Project](https://humanaigc.github.io/omnitalker) 🚀[Paper](https://arxiv.org/abs/2504.02433v1) </center>")
    gr.Markdown('''
### 步骤 Steps:
1. 选择角色, 等待`参考视频`加载完成 (自定义角色开发中)
Select a character in Examples, then **wait for the `Reference Video` to finish loading**. (Custom upload is currently under development.)
2. 输入`文本` (目前只支持中英文, 限制100字左右)
Enter `text` (only **Chinese** and **English** are supported so far; input is limited to about **100** characters for performance.)
3. 生成 (受限于网络和资源推理速度可能达不到1:1, 感谢理解)。
Generate. (Due to limited network speed and GPU resources, generation may not reach a real-time 1:1 ratio. Thanks for your understanding.)
### 技巧 Tips:
1. 中文中数字尽量用汉字
Numbers in Chinese text are best written as Chinese characters.
2. 尝试不同的`seed`来获取最好的结果 (设为-1则每次会自动更改)
Try different `Seed` values to get the best result (use -1 for automatic seeding). God may not play dice, but AI does.
3. 适当调整语速`Speed`, 尤其是当英语参考人物讲中文时最好调慢一些(0.9-0.95)
Adjust `Speed` to control the speech rate. When generating Chinese speech for native English speakers in particular, a setting of 0.9-0.95 is recommended.
''')
    with gr.Group():
        # Side-by-side panes: read-only reference clip and generated output.
        with gr.Row():
            with gr.Column():
                reference_video = gr.Video(label='Reference Video', interactive=False)
            with gr.Column():
                output_video = gr.Video(label='Output Video', streaming=False, autoplay=True)
        with gr.Row(equal_height=True):
            input_text = gr.Textbox(label="Input Text", lines=8, scale=5)
            with gr.Column(scale=1):
                with gr.Row(equal_height=True):
                    seed = gr.Number(value=-1, label="Seed", elem_classes="gradio-number")
                    # Dice button: draws a new random seed via generate_seed.
                    btn_seed = gr.Button(value="\U0001F3B2", elem_classes="gradio-button")
                speed = gr.Slider(0, 2, value=1, step=0.01, label="Speed", scale=2)
                btn_run = gr.Button('Submit', variant='primary')
    with gr.Row(equal_height=True):
        # Role is filled in only by clicking an example (not user-editable).
        role = gr.Textbox(label="Role", lines=1, max_lines=1, elem_classes="gradio-textbox", interactive=False)
        # Refresh button: reloads the example list from the backend.
        btn_refresh = gr.Button(value='\U0001f504', elem_classes="gradio-button")
    # Example rows are fetched once at build time; each row fills these inputs.
    examples = gr.Examples(
        examples=update_examples()['samples'],
        inputs=[role, reference_video, input_text, seed, speed],
        examples_per_page=10,
    )
    btn_seed.click(generate_seed, inputs=[], outputs=seed)
    btn_refresh.click(update_examples, outputs=[examples.dataset])
    btn_run.click(predict, inputs=[role, input_text, seed, speed], outputs=[output_video])
if __name__ == '__main__':
    # Command-line options controlling where the Gradio server listens.
    cli = argparse.ArgumentParser()
    cli.add_argument("-ip", "--server_ip", type=str, default="0.0.0.0")
    cli.add_argument("-p", "--server_port", type=int, default=7860)
    options = cli.parse_args()
    # Serve locally only; no public Gradio share link is created.
    demo.launch(
        server_name=options.server_ip,
        server_port=options.server_port,
        share=False,
    )