Spaces:
Build error
Build error
| import gradio as gr | |
| import spaces | |
| import torch | |
| import gdown | |
| import os | |
| import setup_environment | |
| import zipfile | |
| import sys | |
| from setup_environment import initialize_environment | |
| from download import download_files_from_url | |
# ZeroGPU probe kept from the Spaces template: at module level the device is
# reported as 'cpu'; a real GPU is only attached inside @spaces.GPU calls.
zero = torch.Tensor([0]).cuda()
print(zero.device)  # <-- 'cpu' 🤔

# Project helpers from download.py / setup_environment.py. Presumably they
# fetch the model assets and install the STF packages whose locations are
# added to sys.path below — confirm against those modules.
download_files_from_url()
initialize_environment()

# Extra import roots for the STF packages; stf_utils (imported later in this
# script) is presumably resolved from one of these. The python3.10 segment is
# hard-coded and must match the interpreter the packages were installed for.
_EXTRA_IMPORT_PATHS = (
    '/home/user/.local/lib/python3.10/site-packages',
    '/home/user/.local/lib/python3.10/site-packages/stf_alternative/src/stf_alternative',
    '/home/user/.local/lib/python3.10/site-packages/stf_tools/src/stf_tools',
    '/tmp/',
    '/tmp/stf/',
    '/tmp/stf/stf_alternative/',
    '/tmp/stf/stf_alternative/src/stf_alternative',
)
for _import_path in _EXTRA_IMPORT_PATHS:
    # Appended (lowest priority) in the listed order; entries already on
    # sys.path are skipped so the search path doesn't accumulate duplicates.
    if _import_path not in sys.path:
        sys.path.append(_import_path)


def _prepend_env_path(var_name, entry):
    """Put *entry* at the front of the os.pathsep-separated variable *var_name*.

    When the variable is unset or empty it becomes just *entry*. Writing
    ``entry + ':' + os.environ.get(var_name, '')`` instead would leave a
    trailing empty component, which the shell (PATH) and the dynamic loader
    (LD_LIBRARY_PATH) both interpret as "the current working directory".
    """
    current = os.environ.get(var_name, '')
    os.environ[var_name] = os.pathsep.join((entry, current)) if current else entry


# Put the CUDA toolkit's binaries and libraries on the search paths.
# NOTE: the dynamic loader reads LD_LIBRARY_PATH only when a process starts,
# so this affects child processes launched from here on, not libraries this
# interpreter loads itself.
_prepend_env_path('PATH', '/usr/local/cuda/bin')
_prepend_env_path('LD_LIBRARY_PATH', '/usr/local/cuda/lib64')

# Print the resulting values for verification.
print("PATH:", os.environ['PATH'])
print("LD_LIBRARY_PATH:", os.environ['LD_LIBRARY_PATH'])
# Imported only here, after sys.path has been extended, because stf_utils is
# presumably found on one of those added paths — confirm.
from stf_utils import STFPipeline  # noqa: E402

# One pipeline instance is built at startup and shared by every request.
stf_pipeline = STFPipeline()


@spaces.GPU
def gpu_wrapped_stf_pipeline_execute(audio_path):
    """Run the STF pipeline on an uploaded or recorded audio file.

    Decorated with ``spaces.GPU`` so that, on ZeroGPU hardware, a GPU is
    attached for the duration of the call (without it, CUDA is unavailable
    to request handlers); outside ZeroGPU the decorator has no effect.

    Args:
        audio_path: Filesystem path from the ``gr.Audio(type="filepath")``
            input, or ``None`` when the button is clicked with no audio.

    Returns:
        Whatever ``stf_pipeline.execute`` returns; it is shown in the
        ``gr.Video`` output (presumably a video file path — confirm).

    Raises:
        gr.Error: If no audio was provided.
    """
    if audio_path is None:
        # gr.Error is displayed to the user instead of an opaque stack trace.
        raise gr.Error("Please upload or record an audio file first.")
    return stf_pipeline.execute(audio_path)
# Output component for the generated video. It is created before the layout
# and placed into it explicitly with .render() below.
driving_video_path = gr.Video()

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Input: audio given to the handler as a filesystem path.
    audio_path_component = gr.Audio(
        label="Upload or Record an audio",
        type="filepath",
    )
    stf_button = gr.Button("stf test", variant="primary")
    driving_video_path.render()

    # Button click -> run the pipeline on the audio path -> show the result
    # in the video component (fn, inputs, outputs passed positionally).
    stf_button.click(
        gpu_wrapped_stf_pipeline_execute,
        [audio_path_component],
        [driving_video_path],
    )
# Start the Gradio server for the `demo` Blocks app.
demo.launch()