# NOTE: recovered from a Hugging Face Spaces "Runtime error" page paste;
# table-pipe wrappers removed and indentation reconstructed below.
| import os | |
| import sys | |
| import streamlit as st | |
| from src.gradio_demo import SadTalker | |
| import tempfile | |
| from PIL import Image | |
# ---------------------------------------------------------------------------
# Page configuration — must be the first Streamlit call in the script.
# ---------------------------------------------------------------------------
st.set_page_config(
    page_title="SadTalker - Talking Face Animation",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Custom CSS injected once at startup: gradient header banner, rounded cards
# for media/settings, and a bootstrap-style warning box.
_CUSTOM_CSS = """
<style>
    .header {
        text-align: center;
        padding: 1.5rem 0;
        margin-bottom: 2rem;
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        color: white;
        border-radius: 10px;
        box-shadow: 0 4px 8px rgba(0,0,0,0.1);
    }
    .header h1 {
        margin-bottom: 0.5rem;
        font-size: 2.5rem;
    }
    .header p {
        margin-bottom: 0;
        font-size: 1.1rem;
    }
    .tab-content {
        padding: 1.5rem;
        background: #f8f9fa;
        border-radius: 10px;
        margin-bottom: 1.5rem;
    }
    .stVideo {
        border-radius: 10px;
        box-shadow: 0 4px 8px rgba(0,0,0,0.1);
    }
    .stImage {
        border-radius: 10px;
        box-shadow: 0 4px 8px rgba(0,0,0,0.1);
    }
    .settings-section {
        background: #ffffff;
        padding: 1.5rem;
        border-radius: 10px;
        margin-bottom: 1.5rem;
        box-shadow: 0 2px 4px rgba(0,0,0,0.05);
    }
    .warning-box {
        background-color: #fff3cd;
        color: #856404;
        padding: 0.75rem 1.25rem;
        border-radius: 0.25rem;
        margin-bottom: 1rem;
        border: 1px solid #ffeeba;
    }
    .download-btn {
        display: flex;
        justify-content: center;
        margin-top: 1rem;
    }
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Initialize SadTalker with caching.
# BUG FIX: the original had no caching decorator, so the (heavy) SadTalker
# model wrapper was re-constructed on every Streamlit rerun, despite the
# comment claiming caching. `st.cache_resource` creates it once per server
# process and shares it across sessions/reruns.
@st.cache_resource
def load_sadtalker():
    """Return a process-wide SadTalker instance.

    Returns:
        SadTalker: wrapper configured with the local `checkpoints` directory
        and `src/config`; `lazy_load=True` defers weight loading until the
        first inference call.
    """
    return SadTalker('checkpoints', 'src/config', lazy_load=True)

sad_talker = load_sadtalker()
# Check if running inside the stable-diffusion-webui process: the `webui`
# module is only importable there. Used below to disable text-to-speech.
# BUG FIX: the original bare `except:` swallowed every exception (including
# KeyboardInterrupt/SystemExit and real errors raised while importing webui);
# only a missing module should mean "not in webui".
try:
    import webui  # noqa: F401 — presence check only
    in_webui = True
except ImportError:
    in_webui = False
# ---------------------------------------------------------------------------
# Header banner: title, paper blurb, and external links (rendered as HTML).
# ---------------------------------------------------------------------------
_HEADER_HTML = """
<div class="header">
    <h1>π SadTalker</h1>
    <p>Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation (CVPR 2023)</p>
    <div style="display: flex; justify-content: center; gap: 1.5rem; margin-top: 0.5rem;">
        <a href="https://arxiv.org/abs/2211.12194" style="color: white; text-decoration: none; font-weight: 500;">π Arxiv</a>
        <a href="https://sadtalker.github.io" style="color: white; text-decoration: none; font-weight: 500;">π Homepage</a>
        <a href="https://github.com/Winfredy/SadTalker" style="color: white; text-decoration: none; font-weight: 500;">π» GitHub</a>
    </div>
</div>
"""
st.markdown(_HEADER_HTML, unsafe_allow_html=True)
# Initialize session state: default every app-level key to None on first run
# so later code can test truthiness without KeyError.
for _state_key in ("generated_video", "tts_audio", "source_image", "driven_audio"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None
# ---------------------------------------------------------------------------
# Main two-column layout: input widgets (left) and generation settings (right).
# NOTE(review): widget order and `key=` values are part of app behavior —
# do not reorder these calls.
# ---------------------------------------------------------------------------
col1, col2 = st.columns([1, 1], gap="large")
with col1:
    st.markdown("### Input Settings")
    # Source Image Upload — the still face to be animated.
    with st.expander("π¨ Source Image", expanded=True):
        uploaded_image = st.file_uploader(
            "Upload a clear frontal face image",
            type=["jpg", "jpeg", "png"],
            key="source_image_upload"
        )
        if uploaded_image:
            # Remember the upload across reruns and preview it.
            st.session_state.source_image = uploaded_image
            image = Image.open(uploaded_image)
            st.image(image, caption="Source Image", use_container_width=True)
        elif st.session_state.source_image:
            # No new upload this rerun — re-show the image kept in session state.
            image = Image.open(st.session_state.source_image)
            st.image(image, caption="Source Image (from session)", use_container_width=True)
        else:
            st.warning("Please upload a source image")
    # Audio Input — either an uploaded file or synthesized speech drives the lips.
    with st.expander("π΅ Audio Input", expanded=True):
        input_method = st.radio(
            "Select input method:",
            ["Upload audio file", "Text-to-speech"],
            index=0,
            key="audio_input_method",
            horizontal=True
        )
        if input_method == "Upload audio file":
            audio_file = st.file_uploader(
                "Upload an audio file (WAV, MP3)",
                type=["wav", "mp3"],
                key="audio_file_upload"
            )
            if audio_file:
                st.session_state.driven_audio = audio_file
                st.audio(audio_file)
            elif st.session_state.driven_audio and isinstance(st.session_state.driven_audio, str):
                # A TTS-generated file path from an earlier run — replay it.
                st.audio(st.session_state.driven_audio)
        else:
            # Text-to-speech branch: only offered off-Windows and outside webui
            # (see the in_webui probe above).
            if sys.platform != 'win32' and not in_webui:
                from src.utils.text2speech import TTSTalker
                # NOTE(review): TTSTalker is re-instantiated on every rerun;
                # presumably lightweight — confirm, else cache it.
                tts_talker = TTSTalker()
                input_text = st.text_area(
                    "Enter text for speech synthesis:",
                    height=150,
                    placeholder="Type what you want the face to say...",
                    key="tts_input_text"
                )
                if st.button("Generate Speech", key="tts_generate_button"):
                    if input_text.strip():
                        with st.spinner("Generating audio from text..."):
                            try:
                                # TTSTalker.test() returns a path to the synthesized audio file.
                                audio_path = tts_talker.test(input_text)
                                st.session_state.driven_audio = audio_path
                                st.session_state.tts_audio = audio_path
                                st.audio(audio_path)
                                st.success("Audio generated successfully!")
                            except Exception as e:
                                st.error(f"Error generating audio: {str(e)}")
                    else:
                        st.warning("Please enter some text first")
            else:
                st.markdown("""
                <div class="warning-box">
                    β οΈ Text-to-speech is not available on Windows or in webui mode.
                    Please use audio upload instead.
                </div>
                """, unsafe_allow_html=True)
with col2:
    st.markdown("### Generation Settings")
    with st.container():
        # Opening half of the settings card <div>; closed by the matching
        # st.markdown("</div>") after the sliders below.
        st.markdown("""
        <div class="settings-section">
        <h4>βοΈ Animation Parameters</h4>
        """, unsafe_allow_html=True)
        # First row of settings: preprocessing + resolution (left), toggles (right).
        col_a, col_b = st.columns(2)
        with col_a:
            # Passed through to sad_talker.test() as preprocess_type.
            preprocess_type = st.radio(
                "Preprocessing Method",
                ['crop', 'resize', 'full', 'extcrop', 'extfull'],
                index=0,
                key="preprocess_type",
                help="How to handle the input image before processing"
            )
            size_of_image = st.radio(
                "Face Model Resolution",
                [256, 512],
                index=0,
                key="size_of_image",
                horizontal=True,
                help="Higher resolution (512) may produce better quality but requires more resources"
            )
        with col_b:
            is_still_mode = st.checkbox(
                "Still Mode",
                value=False,
                key="is_still_mode",
                help="Produces fewer head movements (works best with 'full' preprocessing)"
            )
            enhancer = st.checkbox(
                "Use GFPGAN Enhancer",
                value=False,
                key="enhancer",
                help="Improves face quality using GFPGAN (may slow down processing)"
            )
        # Second row of settings: full-width sliders.
        pose_style = st.slider(
            "Pose Style",
            min_value=0,
            max_value=46,
            value=0,
            step=1,
            key="pose_style",
            help="Different head poses and expressions"
        )
        batch_size = st.slider(
            "Batch Size",
            min_value=1,
            max_value=10,
            value=2,
            step=1,
            key="batch_size",
            help="Number of frames processed at once (higher may be faster but uses more memory)"
        )
        # Close the settings-section card opened above.
        st.markdown("</div>", unsafe_allow_html=True)
# ---------------------------------------------------------------------------
# Generate button: validates inputs, stages them as temp files, runs SadTalker,
# and stores the resulting video path in session state.
#
# BUG FIXES vs. original:
#   * UploadedFile.read() after st.audio() had consumed the stream wrote an
#     EMPTY temp audio file — use getvalue(), which ignores the read position.
#   * If driven_audio held a TTS path string while "Upload audio file" was
#     selected, .read() raised AttributeError — branch on isinstance(str).
#   * On inner-generation failure, code fell through to reference the
#     undefined `video_path` (NameError) and then showed a success message —
#     success/assignment now happen only on the success path.
#   * Duplicate session_state.generated_video assignment removed.
#   * Temp files now always cleaned up (finally), and an uploaded .mp3 keeps
#     its real suffix instead of being written as ".wav".
# ---------------------------------------------------------------------------
if st.button(
    "β¨ Generate Talking Face Animation",
    type="primary",
    use_container_width=True,
    key="generate_button"
):
    if not st.session_state.source_image:
        st.error("Please upload a source image first")
    elif input_method == "Upload audio file" and not st.session_state.driven_audio:
        st.error("Please upload an audio file first")
    elif input_method == "Text-to-speech" and not st.session_state.driven_audio:
        st.error("Please generate audio from text first")
    else:
        with st.spinner("Generating talking face animation. This may take a few minutes..."):
            tmp_image_path = None   # temp copy of the source image
            tmp_audio_path = None   # temp copy of uploaded audio (ours to delete)
            try:
                # Stage the source image as an absolute-path temp PNG.
                with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp_image:
                    Image.open(st.session_state.source_image).save(tmp_image.name)
                    tmp_image_path = os.path.abspath(tmp_image.name)

                # Resolve the driving audio to an absolute file path.
                driven = st.session_state.driven_audio
                if isinstance(driven, str):
                    # Already a file path (TTS output).
                    audio_path = os.path.abspath(driven)
                else:
                    # UploadedFile: persist its full contents, keeping the
                    # original extension so downstream decoders see the format.
                    suffix = os.path.splitext(getattr(driven, "name", ""))[1] or ".wav"
                    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_audio:
                        tmp_audio.write(driven.getvalue())
                        tmp_audio_path = tmp_audio.name
                    audio_path = os.path.abspath(tmp_audio_path)

                try:
                    video_path = sad_talker.test(
                        source_image=tmp_image_path,
                        driven_audio=audio_path,
                        preprocess_type=str(preprocess_type),
                        is_still_mode=bool(is_still_mode),
                        enhancer=bool(enhancer),
                        batch_size=int(batch_size),
                        size_of_image=int(size_of_image),
                        pose_style=int(pose_style),
                    )
                    # Verify the backend actually produced a file.
                    if not os.path.exists(video_path):
                        raise FileNotFoundError(f"Output video not created at {video_path}")
                    st.session_state.generated_video = video_path
                    st.success("Generation complete! View your result below.")
                except Exception as e:
                    st.error(f"Generation failed: {str(e)}")
                    # Surface the exact parameters for debugging.
                    st.text("Parameters used:")
                    st.json({
                        "source_image": tmp_image_path,
                        "driven_audio": audio_path,
                        "preprocess_type": preprocess_type,
                        "is_still_mode": is_still_mode,
                        "enhancer": enhancer,
                        "batch_size": batch_size,
                        "size_of_image": size_of_image,
                        "pose_style": pose_style
                    })
            except Exception as e:
                st.error(f"An error occurred during generation: {str(e)}")
                st.error("Please check your inputs and try again")
            finally:
                # Remove only the temp files we created (never TTS output,
                # which session state may replay later).
                for _tmp in (tmp_image_path, tmp_audio_path):
                    if _tmp and os.path.exists(_tmp):
                        os.unlink(_tmp)
# ---------------------------------------------------------------------------
# Results section: shown once a video path is stored in session state.
# BUG FIX: st.experimental_rerun() was removed in modern Streamlit releases
# (this file already uses the modern `use_container_width=` API), so both
# rerun calls would raise AttributeError — replaced with st.rerun().
# ---------------------------------------------------------------------------
if st.session_state.generated_video:
    st.markdown("---")
    st.markdown("### Generated Animation")
    # Video player (left) and action buttons (right).
    col_video, col_download = st.columns([3, 1])
    with col_video:
        st.video(st.session_state.generated_video)
    with col_download:
        # Offer the rendered file for download.
        with open(st.session_state.generated_video, "rb") as f:
            video_bytes = f.read()
        st.download_button(
            label="Download Video",
            data=video_bytes,
            file_name="sadtalker_animation.mp4",
            mime="video/mp4",
            use_container_width=True,
            key="download_button"
        )
        # Re-run the script with inputs/settings left intact.
        if st.button(
            "π Regenerate with Same Settings",
            use_container_width=True,
            key="regenerate_button"
        ):
            st.rerun()
        # Clear all stored inputs/outputs and start over.
        if st.button(
            "π Start New Generation",
            use_container_width=True,
            key="new_generation_button"
        ):
            st.session_state.generated_video = None
            st.session_state.tts_audio = None
            st.session_state.source_image = None
            st.session_state.driven_audio = None
            st.rerun()
# ---------------------------------------------------------------------------
# Footer: paper citation and project links.
# ---------------------------------------------------------------------------
st.markdown("---")
_FOOTER_HTML = """
<div style="text-align: center; color: #666; padding: 1.5rem 0; font-size: 0.9rem;">
    <p>SadTalker: Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation</p>
    <p>CVPR 2023 | <a href="https://github.com/Winfredy/SadTalker" target="_blank">GitHub Repository</a> | <a href="https://sadtalker.github.io" target="_blank">Project Page</a></p>
</div>
"""
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)