| import io |
| import torch |
| from transformers import WhisperProcessor, WhisperForConditionalGeneration, AutoTokenizer, AutoModelForCausalLM |
| import requests |
| from bs4 import BeautifulSoup |
| import tempfile |
| import os |
| from pydub import AudioSegment |
| import dash |
| from dash import dcc, html, Input, Output, State |
| import dash_bootstrap_components as dbc |
| from dash.exceptions import PreventUpdate |
| import threading |
| from pytube import YouTube |
|
|
# ---------------------------------------------------------------------------
# Start-up: report progress, choose the compute device and load both models
# once at import time so every request reuses them.
# ---------------------------------------------------------------------------
print("Script started")

# Run on the first CUDA device when torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Whisper: the processor provides both the audio feature extraction and the
# token decoding used by transcribe_audio(); the model does the generation.
whisper_model_name = "openai/whisper-small"
whisper_processor = WhisperProcessor.from_pretrained(whisper_model_name)
whisper_model = WhisperForConditionalGeneration.from_pretrained(
    whisper_model_name
).to(device)

# Qwen2.5 instruct model, used by separate_speakers() to label speaker turns.
# NOTE(review): trust_remote_code=True allows the hub repository to execute
# code at load time; recent transformers releases support Qwen2.5 natively,
# so this flag can probably be dropped — confirm the installed version first.
qwen_model_name = "Qwen/Qwen2.5-3B-Instruct"
qwen_tokenizer = AutoTokenizer.from_pretrained(
    qwen_model_name, trust_remote_code=True
)
qwen_model = AutoModelForCausalLM.from_pretrained(
    qwen_model_name, trust_remote_code=True
).to(device)
|
|
def download_audio_from_url(url, timeout=60):
    """Download the media behind *url* and return its raw bytes.

    Three kinds of URL are handled:

    * YouTube links (``youtube.com`` / ``youtu.be``): the first audio-only
      stream reported by pytube is downloaded.
    * "Share" pages (any URL containing ``share``): the page HTML is fetched
      and the ``src`` of its first ``<video>`` element (or of that element's
      first ``<source>`` child) is downloaded.
    * Anything else is assumed to point directly at a media file.

    Args:
        url: The URL entered by the user.
        timeout: Seconds passed to ``requests`` as its connect/read timeout for
            every HTTP request. It is not applied to the pytube download.

    Returns:
        The downloaded bytes; the container format depends on the source.

    Raises:
        ValueError: If no audio stream or direct video URL can be found.
        requests.HTTPError: If an HTTP request returns an error status.
        Other pytube/requests errors are logged and re-raised unchanged.
    """
    from urllib.parse import urljoin

    def _fetch(target):
        # Fail loudly on 4xx/5xx instead of treating an HTML error page as media.
        response = requests.get(target, timeout=timeout)
        response.raise_for_status()
        return response

    try:
        if "youtube.com" in url or "youtu.be" in url:
            print("Processing YouTube URL...")
            yt = YouTube(url)
            audio_stream = yt.streams.filter(only_audio=True).first()
            if audio_stream is None:
                raise ValueError("No audio-only stream found for this YouTube video.")
            # pytube's output_path is a *directory*: download into a private
            # temporary directory and read back the path download() returns.
            # The directory (and the file) is removed when the block exits.
            with tempfile.TemporaryDirectory() as temp_dir:
                downloaded_path = audio_stream.download(output_path=temp_dir)
                with open(downloaded_path, "rb") as media_file:
                    audio_bytes = media_file.read()
        elif "share" in url:
            print("Processing shareable link...")
            soup = BeautifulSoup(_fetch(url).content, 'html.parser')
            video_tag = soup.find('video')
            src = None
            if video_tag is not None:
                src = video_tag.get('src')
                if not src:
                    # Many pages put the URL on a nested <source> element instead.
                    source_tag = video_tag.find('source')
                    if source_tag is not None:
                        src = source_tag.get('src')
            if not src:
                raise ValueError("Direct video URL not found in the shareable link.")
            # The src attribute may be relative to the share page.
            video_url = urljoin(url, src)
            print(f"Extracted video URL: {video_url}")
            audio_bytes = _fetch(video_url).content
        else:
            print(f"Downloading video from URL: {url}")
            audio_bytes = _fetch(url).content

        print(f"Successfully downloaded {len(audio_bytes)} bytes of data")
        return audio_bytes
    except Exception as e:
        print(f"Error in download_audio_from_url: {str(e)}")
        raise
|
|
| def transcribe_audio(audio_file): |
| try: |
| print("Loading audio file...") |
| audio = AudioSegment.from_file(audio_file) |
| audio = audio.set_channels(1).set_frame_rate(16000) |
| audio_array = torch.tensor(audio.get_array_of_samples()).float() |
| |
| print(f"Audio duration: {len(audio) / 1000:.2f} seconds") |
| print("Starting transcription...") |
| input_features = whisper_processor(audio_array, sampling_rate=16000, return_tensors="pt").input_features.to(device) |
| |
| |
| attention_mask = torch.ones_like(input_features) |
| |
| |
| predicted_ids = whisper_model.generate( |
| input_features, |
| attention_mask=attention_mask, |
| language='en', |
| task='translate' |
| ) |
| transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True) |
| |
| print(f"Transcription complete. Length: {len(transcription[0])} characters") |
| if len(transcription[0]) < 10: |
| raise ValueError(f"Transcription too short: {transcription[0]}") |
| return transcription[0] |
| except Exception as e: |
| print(f"Error in transcribe_audio: {str(e)}") |
| raise |
|
|
| def separate_speakers(transcription): |
| print("Starting speaker separation...") |
| prompt = f"""Analyze the following transcribed text and separate it into different speakers. Identify potential speaker changes based on context, content shifts, or dialogue patterns. Format the output as follows: |
| |
| 1. Label speakers as "Speaker 1", "Speaker 2", etc. |
| 2. Start each speaker's text on a new line beginning with their label. |
| 3. Separate different speakers' contributions with a blank line. |
| 4. If the same speaker continues, do not insert a blank line or repeat the speaker label. |
| |
| Now, please process the following transcribed text: |
| |
| {transcription} |
| """ |
| |
| inputs = qwen_tokenizer(prompt, return_tensors="pt").to(device) |
| with torch.no_grad(): |
| outputs = qwen_model.generate(**inputs, max_new_tokens=4000) |
| result = qwen_tokenizer.decode(outputs[0], skip_special_tokens=True) |
| |
| |
| processed_text = result.split("Now, please process the following transcribed text:")[-1].strip() |
| |
| print("Speaker separation complete.") |
| return processed_text |
|
|
def transcribe_video(url):
    """Download, transcribe and speaker-label the media at *url*.

    Returns:
        The speaker-separated transcript. On any failure it instead returns a
        string starting with ``"An error occurred: "``, which the Dash
        callback checks for. Exceptions are never propagated to the caller.
    """
    try:
        print(f"Attempting to download audio from URL: {url}")
        audio_bytes = download_audio_from_url(url)
        print(f"Successfully downloaded {len(audio_bytes)} bytes of audio data")

        # Reserve a temp WAV path and close the handle right away, so pydub
        # can write the file on every OS (Windows refuses to reopen an open
        # NamedTemporaryFile). The finally clause removes the file even if
        # decoding or transcription fails.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
            wav_path = temp_audio.name
        try:
            # export() returns the file object it opened; close it before
            # reading or deleting the file.
            exported = AudioSegment.from_file(io.BytesIO(audio_bytes)).export(
                wav_path, format="wav"
            )
            exported.close()
            transcript = transcribe_audio(wav_path)
        finally:
            os.unlink(wav_path)

        if len(transcript) < 10:
            raise ValueError("Transcription too short, possibly failed")

        print("Separating speakers...")
        separated_transcript = separate_speakers(transcript)

        return separated_transcript
    except Exception as e:
        error_message = f"An error occurred: {str(e)}"
        print(error_message)
        return error_message
|
|
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])


def _build_layout():
    """Return the page layout: a title plus one card holding the whole UI.

    The app's callbacks look components up by these ids, so keep them stable:
    ``video-url``, ``transcribe-button``, ``transcription-output``,
    ``download-button`` and ``download-transcript``.
    """
    title = html.H1(
        "Video Transcription with Speaker Separation",
        className="text-center mb-4",
    )

    url_box = dbc.Input(id="video-url", type="text", placeholder="Enter video URL")
    run_button = dbc.Button(
        "Transcribe", id="transcribe-button", color="primary", className="mt-3"
    )
    # Result area, wrapped in a dbc.Spinner loading indicator.
    result_area = dbc.Spinner(html.Div(id="transcription-output", className="mt-3"))
    # The download button starts hidden; the transcription callback toggles
    # its "display" style.
    download_area = html.Div([
        dbc.Button(
            "Download Transcript",
            id="download-button",
            color="secondary",
            className="mt-3",
            style={'display': 'none'},
        ),
        dcc.Download(id="download-transcript"),
    ])

    card = dbc.Card([dbc.CardBody([url_box, run_button, result_area, download_area])])
    column = dbc.Col([title, card], width=12)
    return dbc.Container([dbc.Row([column])], fluid=True)


app.layout = _build_layout()
|
|
@app.callback(
    Output("transcription-output", "children"),
    Output("download-button", "style"),
    Input("transcribe-button", "n_clicks"),
    State("video-url", "value"),
    prevent_initial_call=True
)
def update_transcription(n_clicks, url):
    """Run the transcription pipeline for *url* and render the result.

    The pipeline runs on a worker thread that this request waits on for at
    most 10 minutes. Python threads cannot be killed, so a timed-out job keeps
    running in the background until it finishes, and its result is discarded.

    Returns:
        ``(children, style)`` for the output div and the download button:
        a result card with the button shown on success, otherwise a plain
        message with the button hidden.

    Raises:
        PreventUpdate: If no URL was entered.
    """
    if not url:
        raise PreventUpdate

    # threading.Thread discards its target's return value, so the worker
    # stores its result here and this thread reads it after join().
    outcome = {}

    def transcribe():
        try:
            outcome["transcript"] = transcribe_video(url)
        except Exception as e:
            import traceback
            outcome["transcript"] = f"An error occurred: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"

    # daemon=True: an abandoned (timed-out) job must not block interpreter exit.
    thread = threading.Thread(target=transcribe, daemon=True)
    thread.start()
    thread.join(timeout=600)

    if thread.is_alive():
        return "Transcription timed out after 10 minutes", {'display': 'none'}

    transcript = outcome.get("transcript")
    if transcript is None:
        # The worker ended without storing anything (e.g. a BaseException).
        return "Transcription failed", {'display': 'none'}

    if transcript and not transcript.startswith("An error occurred"):
        return dbc.Card([
            dbc.CardBody([
                html.H5("Transcription Result with Speaker Separation"),
                # Dash/React style keys are camelCase.
                html.Pre(transcript, style={"whiteSpace": "pre-wrap", "wordWrap": "break-word"})
            ])
        ]), {'display': 'block'}
    return transcript, {'display': 'none'}
|
|
@app.callback(
    Output("download-transcript", "data"),
    Input("download-button", "n_clicks"),
    State("transcription-output", "children"),
    prevent_initial_call=True
)
def download_transcript(n_clicks, transcription_output):
    """Send the displayed transcript to the browser as ``transcript.txt``.

    ``transcription_output`` is the serialized ``children`` of the output div.
    On success, the transcription callback places a Card -> CardBody ->
    [H5, Pre] tree there, and the transcript is the Pre's text. Anything else
    (a plain error/timeout string, or an unexpected shape) yields no download.

    Raises:
        PreventUpdate: If there is no transcript card to download.
    """
    if not transcription_output:
        raise PreventUpdate

    try:
        card_body = transcription_output['props']['children'][0]
        transcript = card_body['props']['children'][1]['props']['children']
    except (KeyError, IndexError, TypeError):
        # Not the success card, e.g. a plain-text error message.
        raise PreventUpdate from None
    if not isinstance(transcript, str):
        raise PreventUpdate

    return dict(content=transcript, filename="transcript.txt")
|
|
| if __name__ == '__main__': |
| print("Starting the Dash application...") |
| app.run(debug=True, host='0.0.0.0', port=7860) |
| print("Dash application has finished running.") |