import gradio as gr
import json
from difflib import Differ
import ffmpeg
import os
import tempfile
from pathlib import Path
import aiohttp
import asyncio
import base64
from dotenv import load_dotenv
import logging
# --- Configuration ---
# Set to True to use the Hugging Face Inference API (https://huggingface.co/inference-api);
# set to False to run the model locally with transformers.
API_BACKEND = True
MODEL = "facebook/wav2vec2-base-960h"
API_URL = f'https://api-inference.huggingface.co/models/{MODEL}'
RETRY_ATTEMPTS = 5
RETRY_DELAY = 5
TIMESTAMP_GROUPING_THRESHOLD = 0.1
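# Consecutive kept characters whose timestamps are closer than
# TIMESTAMP_GROUPING_THRESHOLD seconds are merged into one cut segment: a
# 0.05 s gap between two characters keeps them in the same segment, while a
# 0.5 s pause starts a new one (see the grouping loop in cut_timestamps_to_video).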
# --- Logging Configuration ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s - %(funcName)s')
# --- Initialization ---
if API_BACKEND:
    load_dotenv(Path(".env"))
    HF_TOKEN = os.environ.get("HF_TOKEN")
    if not HF_TOKEN:
        logging.error("HF_TOKEN environment variable not set. Please set it in a .env file.")
        raise ValueError("HF_TOKEN environment variable not set.")
    headers = {"Authorization": f"Bearer {HF_TOKEN}"}
else:
    import torch
    from transformers import pipeline
    device = 0 if torch.cuda.is_available() else -1
    try:
        logging.info(f"Initializing local model: {MODEL} on device: {device}")
        speech_recognizer = pipeline(
            task="automatic-speech-recognition",
            model=MODEL,
            tokenizer=MODEL,
            framework="pt",
            device=device,
        )
        logging.info("Local model initialized successfully.")
    except Exception as e:
        logging.error(f"Error initializing local model {MODEL}: {e}")
        raise RuntimeError(f"Error initializing local model {MODEL}: {e}")

videos_out_path = Path("./videos_out")
videos_out_path.mkdir(parents=True, exist_ok=True)
logging.info(f"Output directory created: {videos_out_path}")
samples_data_files = sorted(Path('examples').glob('*.json'))
SAMPLES = []
for file in samples_data_files:
    try:
        with open(file, 'r') as f:
            sample = json.load(f)
        if 'video' in sample and 'transcription' in sample and 'timestamps' in sample:
            SAMPLES.append(sample)
        else:
            logging.warning(f"Skipping sample file {file} due to missing keys (video, transcription, or timestamps).")
    except (json.JSONDecodeError, FileNotFoundError) as e:
        logging.error(f"Error loading sample file {file}: {e}")
VIDEOS = [[sample['video']] for sample in SAMPLES]
logging.info(f"Loaded {len(SAMPLES)} example samples.")
# --- Helper Functions ---
async def query_api(audio_bytes: bytes):
    """
    Query the Hugging Face Inference API for Automatic Speech Recognition.
    Includes retry logic with exponential backoff.
    """
    payload = json.dumps({
        "inputs": base64.b64encode(audio_bytes).decode("utf-8"),
        "parameters": {
            "return_timestamps": "char",
            "chunk_length_s": 10,
            "stride_length_s": [4, 2]
        },
        "options": {"use_gpu": False}
    }).encode("utf-8")
    async with aiohttp.ClientSession() as session:
        for attempt in range(RETRY_ATTEMPTS):
            logging.info(f'Transcribing from API attempt {attempt + 1}/{RETRY_ATTEMPTS}')
            try:
                async with session.post(API_URL, headers=headers, data=payload) as response:
                    logging.info(f"API Response Status: {response.status}")
                    content_type = response.headers.get('Content-Type', '')
                    if response.status == 200 and 'application/json' in content_type:
                        return await response.json()
                    elif response.status != 200 and 'application/json' in content_type:
                        error_response = await response.json()
                        if 'error' in error_response and 'estimated_time' in error_response:
                            wait_time = error_response['estimated_time']
                            logging.warning(f"Model loading, waiting for {wait_time} seconds...")
                            await asyncio.sleep(wait_time + RETRY_DELAY)
                            continue  # already waited for the model; skip the backoff sleep below
                        elif 'error' in error_response:
                            raise RuntimeError(f"API Error: {error_response['error']}")
                        else:
                            raise RuntimeError(f"Unknown API Error: {error_response}")
                    else:
                        response_text = await response.text()
                        raise RuntimeError(f"Unexpected API response format (Status: {response.status}, Content-Type: {content_type}): {response_text}")
            except aiohttp.ClientError as e:
                logging.error(f"AIOHTTP Client Error during API call (Attempt {attempt + 1}): {e}")
            except RuntimeError as e:
                logging.error(f"Runtime error during API call (Attempt {attempt + 1}): {e}")
            if attempt < RETRY_ATTEMPTS - 1:
                wait_time = RETRY_DELAY * (2 ** attempt)
                logging.info(f"Retrying in {wait_time} seconds...")
                await asyncio.sleep(wait_time)
    raise RuntimeError(f"Failed to get transcription after {RETRY_ATTEMPTS} attempts.")
def ping_telemetry(name: str):
    """
    Send a telemetry ping to Hugging Face Spaces.
    This is fire-and-forget and doesn't affect the main process flow.
    """
    url = f'https://huggingface.co/api/telemetry/spaces/radames/edit-video-by-editing-text/{name}'
    logging.info(f"Pinging telemetry: {url}")

    async def send_ping():
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(url) as response:
                    logging.info(f"Telemetry pong: {response.status}")
        except aiohttp.ClientError as e:
            logging.warning(f"Failed to send telemetry ping: {e}")

    # Requires a running event loop; both callers below are async functions.
    asyncio.create_task(send_ping())
# --- Main Gradio Functions ---
async def speech_to_text(video_file_path, progress=gr.Progress()):
    """
    Takes a video path, extracts the audio track, and transcribes it to text
    plus per-character timestamps. Includes progress reporting.
    """
    if video_file_path is None:
        raise gr.Error("Error: No video input provided.")
    video_path = Path(video_file_path)
    if not video_path.exists():
        raise gr.Error(f"Error: Video file not found at {video_path}")
    temp_audio_file = None
    try:
        progress(0, desc="Converting video to audio...")
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
            temp_audio_file = Path(tmpfile.name)
        loop = asyncio.get_running_loop()
        # Run the blocking ffmpeg call off the event loop; capture stderr so
        # the error handler below has output to report.
        await loop.run_in_executor(
            None, lambda: ffmpeg.input(str(video_path)).output(
                str(temp_audio_file), format="wav", ac=1, ar='16k').overwrite_output().global_args('-loglevel', 'error').run(capture_stderr=True)
        )
        logging.info(f"Video converted to temporary audio file: {temp_audio_file}")
        with open(temp_audio_file, 'rb') as f:
            audio_memory = f.read()
    except ffmpeg.Error as e:
        stderr = e.stderr.decode() if e.stderr else str(e)
        logging.error(f"Error converting video to audio: {stderr}")
        raise gr.Error(f"Error converting video to audio: {stderr}")
    except Exception as e:
        logging.error(f"An unexpected error occurred during audio conversion: {e}")
        raise gr.Error(f"An unexpected error occurred during audio conversion: {e}")
    finally:
        if temp_audio_file and temp_audio_file.exists():
            os.remove(temp_audio_file)
            logging.info(f"Cleaned up temporary audio file: {temp_audio_file}")
| ping_telemetry("speech_to_text") | |
| progress(0.5, desc="Transcribing audio...") | |
| if API_BACKEND: | |
| try: | |
| inference_response = await query_api(audio_memory) | |
| logging.info("Inference Response received from API.") | |
| if not isinstance(inference_response, dict) or 'text' not in inference_response or 'chunks' not in inference_response: | |
| raise RuntimeError(f"Unexpected API response structure: {inference_response}") | |
| transcription = inference_response["text"].lower() | |
| timestamps = [[chunk.get("text", "").lower(), chunk.get("timestamp", [None, None])[0], chunk.get("timestamp", [None, None])[1]] | |
| for chunk in inference_response.get('chunks', []) if isinstance(chunk, dict)] | |
| timestamps = [ts for ts in timestamps if ts[1] is not None and ts[2] is not None] | |
| progress(1.0, desc="Transcription complete.") | |
| return (transcription, transcription, timestamps) | |
| except Exception as e: | |
| logging.error(f"Error fetching transcription from API: {e}") | |
| raise gr.Error(f"Error fetching transcription from API: {e}") | |
| else: | |
| try: | |
| logging.info(f'Transcribing via local model {MODEL}') | |
| loop = asyncio.get_running_loop() | |
| output = await loop.run_in_executor( | |
| None, lambda: speech_recognizer( | |
| audio_memory, return_timestamps="char", chunk_length_s=10, stride_length_s=(4, 2)) | |
| ) | |
| logging.info("Inference complete with local model.") | |
| if not isinstance(output, dict) or 'text' not in output or 'chunks' not in output: | |
| raise RuntimeError(f"Unexpected model output structure: {output}") | |
| transcription = output["text"].lower() | |
| timestamps = [[chunk.get("text", "").lower(), | |
| chunk.get("timestamp", [None, None])[0] if not isinstance(chunk.get("timestamp", [None, None])[0], list) else chunk.get("timestamp", [None, None])[0][0], | |
| chunk.get("timestamp", [None, None])[1] if not isinstance(chunk.get("timestamp", [None, None])[1], list) else chunk.get("timestamp", [None, None])[1][0] | |
| ] | |
| for chunk in output.get('chunks', []) if isinstance(chunk, dict)] | |
| timestamps = [ts for ts in timestamps if ts[1] is not None and ts[2] is not None] | |
| progress(1.0, desc="Transcription complete.") | |
| return (transcription, transcription, timestamps) | |
| except Exception as e: | |
| logging.error(f"Error running inference with local model: {e}") | |
| raise gr.Error(f"Error running inference with local model: {e}") | |
async def cut_timestamps_to_video(video_in, transcription, text_in, timestamps, progress=gr.Progress()):
    """
    Given the original video, the transcript with per-character timestamps,
    and the edited text, cuts the kept segments into a single video.
    Includes progress reporting and improved timestamp alignment.
    """
    if video_in is None or text_in is None or transcription is None or timestamps is None:
        raise gr.Error("Inputs undefined. Please provide video, transcription, and edited text.")
    video_path = Path(video_in)
    if not video_path.exists():
        raise gr.Error(f"Error: Video file not found at {video_path}")
    progress(0, desc="Analyzing text differences...")
    # Differ treats the two strings as sequences of characters, so each diff
    # entry describes one character: '  x' kept, '- x' deleted, '+ x' added.
    d = Differ()
    diff_chars = list(d.compare(transcription, text_in))
    # --- Improved Timestamp Alignment ---
    timestamps_to_keep = []
    timestamp_idx = 0
    diff_idx = 0
    while diff_idx < len(diff_chars) and timestamp_idx < len(timestamps):
        diff_line = diff_chars[diff_idx]
        ts_info = timestamps[timestamp_idx]
        ts_char = ts_info[0]
        if diff_line.startswith(' '):
            if diff_line[2:].lower() == ts_char.lower():
                timestamps_to_keep.append(ts_info)
                timestamp_idx += 1
                diff_idx += 1
            else:
                logging.warning(f"Timestamp alignment mismatch: Diff char '{diff_line[2:]}' vs Timestamp char '{ts_char}'. Skipping diff char.")
                diff_idx += 1
        elif diff_line.startswith('-'):
            if diff_line[2:].lower() == ts_char.lower():
                timestamp_idx += 1
                diff_idx += 1
            else:
                logging.warning(f"Timestamp alignment mismatch for deletion: Diff char '{diff_line[2:]}' vs Timestamp char '{ts_char}'. Skipping diff char.")
                diff_idx += 1
        elif diff_line.startswith('+'):
            diff_idx += 1
        elif diff_line.startswith('?'):
            diff_idx += 1
        else:
            logging.warning(f"Unexpected diff line format: {diff_line}. Skipping.")
            diff_idx += 1
    logging.info(f"Identified {len(timestamps_to_keep)} timestamps to keep after diff alignment.")
    progress(0.2, desc="Grouping timestamps...")
    grouped_segments = []
    if timestamps_to_keep:
        current_segment = [timestamps_to_keep[0]]
        for i in range(1, len(timestamps_to_keep)):
            if timestamps_to_keep[i][1] - current_segment[-1][2] < TIMESTAMP_GROUPING_THRESHOLD:
                current_segment.append(timestamps_to_keep[i])
            else:
                grouped_segments.append(current_segment)
                current_segment = [timestamps_to_keep[i]]
        grouped_segments.append(current_segment)
    logging.info(f"Grouped timestamps into {len(grouped_segments)} segments.")
    cut_intervals = [[segment[0][1], segment[-1][2]] for segment in grouped_segments]
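    # Example: kept chars spanning (0.00-0.10), (0.12-0.20) and (1.50-1.60)
    # with the 0.1 s threshold collapse to cut_intervals == [[0.00, 0.20],
    # [1.50, 1.60]], since 0.12 - 0.10 < 0.1 but 1.50 - 0.20 is not.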
    video_file_name = video_path.stem
    output_video_path = videos_out_path / f"{video_file_name}_cut.mp4"
    if cut_intervals:
        progress(0.4, desc="Cutting video segments...")
        input_video_stream = ffmpeg.input(video_in)
        filter_complex_parts = []
        input_streams = []
        for i, interval in enumerate(cut_intervals):
            start, end = interval
            filter_complex_parts.append(f"[0:v]trim=start={start},end={end},setpts=PTS-STARTPTS[v{i}]")
            filter_complex_parts.append(f"[0:a]atrim=start={start},end={end},asetpts=PTS-STARTPTS[a{i}]")
            input_streams.append(f"[v{i}][a{i}]")
        concat_input_str = ''.join(input_streams)
        concat_filter = f"{concat_input_str}concat=n={len(cut_intervals)}:v=1:a=1[outv][outa]"
        filter_complex_parts.append(concat_filter)
        filter_complex_str = ';'.join(filter_complex_parts)
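        # For two intervals [[0.0, 1.5], [3.0, 4.2]] the filtergraph looks like:
        #   [0:v]trim=start=0.0,end=1.5,setpts=PTS-STARTPTS[v0];
        #   [0:a]atrim=start=0.0,end=1.5,asetpts=PTS-STARTPTS[a0];
        #   [0:v]trim=start=3.0,end=4.2,setpts=PTS-STARTPTS[v1];
        #   [0:a]atrim=start=3.0,end=4.2,asetpts=PTS-STARTPTS[a1];
        #   [v0][a0][v1][a1]concat=n=2:v=1:a=1[outv][outa]
        # setpts/asetpts reset each trimmed clip's timestamps to zero so the
        # concat filter can splice the clips back to back.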
        try:
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(
                None, lambda: ffmpeg.output(
                    input_video_stream,
                    str(output_video_path),
                    filter_complex=filter_complex_str,
                    map=['[outv]', '[outa]'],
                    preset='fast',
                    crf=23
                ).overwrite_output().global_args('-loglevel', 'error').run(capture_stderr=True)
            )
            logging.info(f"Video segments cut and concatenated to: {output_video_path}")
        except ffmpeg.Error as e:
            stderr = e.stderr.decode() if e.stderr else str(e)
            logging.error(f"Error cutting video: {stderr}")
            raise gr.Error(f"Error cutting video: {stderr}")
        except Exception as e:
            logging.error(f"An unexpected error occurred during video cutting: {e}")
            raise gr.Error(f"An unexpected error occurred during video cutting: {e}")
    else:
        logging.warning("No text was kept, creating a short empty video.")
        try:
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(
                None, lambda: ffmpeg.input('color=c=black:s=1280x720:d=0.1', f='lavfi').output(
                    str(output_video_path),
                    format='mp4',
                    vcodec='libx264',
                    pix_fmt='yuv420p',
                    t='0.1'
                ).overwrite_output().global_args('-loglevel', 'error').run(capture_stderr=True)
            )
            logging.info(f"Created short empty video at: {output_video_path}")
        except ffmpeg.Error as e:
            stderr = e.stderr.decode() if e.stderr else str(e)
            logging.error(f"Error creating empty video: {stderr}")
            output_video_path = Path(video_in)
            logging.warning("Failed to create empty video, returning original video path as fallback.")
        except Exception as e:
            logging.error(f"An unexpected error occurred during empty video creation: {e}")
            output_video_path = Path(video_in)
            logging.warning("Failed to create empty video, returning original video path as fallback.")
    # Build HighlightedText tokens: '+' / '-' become highlight categories and
    # unchanged characters get no category; Differ's '?' hint lines are
    # dropped since they are alignment hints, not content.
    diff_output_tokens = [(token[2:], token[0] if token[0] != ' ' else None)
                          for token in diff_chars if not token.startswith('?')]
    ping_telemetry("video_cuts")
    progress(1.0, desc="Video cutting complete.")
    return (diff_output_tokens, str(output_video_path))
def load_example(id):
    """Loads an example video and its transcription."""
    if 0 <= id < len(SAMPLES):
        sample = SAMPLES[id]
        video = sample.get('video')
        transcription = sample.get('transcription', '').lower()
        timestamps = sample.get('timestamps', [])
        if video is None:
            logging.error(f"Example at index {id} is missing video path.")
            raise gr.Error(f"Example at index {id} is missing video path.")
        return (video, transcription, transcription, timestamps)
    else:
        logging.error(f"Invalid example index: {id}")
        raise gr.Error(f"Invalid example index: {id}")
# --- Gradio Layout ---
css = """
#cut_btn, #reset_btn { align-self:stretch; }
/* "\\31 3" is the CSS escape for an element id "13", which starts with a digit */
#\\31 3 { max-width: 540px; }
.output-markdown {max-width: 65ch !important;}
#video-container{
    max-width: 40rem;
}
"""
with gr.Blocks(css=css) as demo:
    transcription_var = gr.State(value="")
    timestamps_var = gr.State(value=[])
    video_in = gr.Video(label="Video file", elem_id="video-container")
    text_in = gr.Textbox(label="Transcription", lines=10, interactive=True)
    video_out = gr.Video(label="Video Out", interactive=False)
    diff_out = gr.HighlightedText(label="Cuts Diffs", combine_adjacent=True, show_legend=True)
| gr.Markdown(""" | |
| # Edit Video By Editing Text | |
| This project is a quick proof of concept of a simple video editor where the edits | |
| are made by editing the audio transcription. | |
| Using the [Huggingface Automatic Speech Recognition Pipeline](https://huggingface.co/tasks/automatic-speech-recognition) | |
| with a fine tuned [Wav2Vec2 model using Connectionist Temporal Classification (CTC)](https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self) | |
| you can predict not only the text transcription but also the [character or word base timestamps](https://huggingface.co/docs/transformers/v4.19.2/en/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline.__call__.return_timestamps) | |
| """) | |
    with gr.Row():
        examples = gr.Dataset(components=[video_in], samples=VIDEOS, type="index", label="Examples")
        examples.click(
            load_example,
            inputs=[examples],
            outputs=[video_in, text_in, transcription_var, timestamps_var],
            queue=False
        )
    with gr.Row():
        with gr.Column():
            # video_in is rendered where it was defined within gr.Blocks
            transcribe_btn = gr.Button("Transcribe Audio")
            transcribe_btn.click(
                speech_to_text,
                inputs=[video_in],
                outputs=[text_in, transcription_var, timestamps_var]
            )
    gr.Markdown("""
    ### Now edit as text
    After running the video transcription, you can make cuts to the text below (only cuts, not additions!)
    """)
    with gr.Row():
        with gr.Column():
            # text_in is rendered where it was defined within gr.Blocks
            with gr.Row():
                cut_btn = gr.Button("Cut to video", elem_id="cut_btn")
                cut_btn.click(
                    cut_timestamps_to_video,
                    inputs=[video_in, transcription_var, text_in, timestamps_var],
                    outputs=[diff_out, video_out]
                )
                reset_transcription = gr.Button(
                    "Reset to last transcription", elem_id="reset_btn")
                reset_transcription.click(
                    lambda transcription: transcription,
                    inputs=[transcription_var],
                    outputs=[text_in],
                    queue=False
                )
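
# Assumed conventional Blocks entry point; the original listing is truncated
# before this point, so the queue()/launch() calls here are a guess.
if __name__ == "__main__":
    demo.queue()
    demo.launch()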