clementBE committed on
Commit
2b96b70
Β·
verified Β·
1 Parent(s): 49df268

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -127
app.py CHANGED
@@ -1,165 +1,158 @@
 
 
 
1
  import os
2
- import tempfile
3
  import datetime
4
  import time
5
- import torch
6
- import gradio as gr
7
- import spaces
8
  from transformers import pipeline
9
  from docx import Document
10
- from pydub import AudioSegment
11
- from sumy.parsers.plaintext import PlaintextParser
12
- from sumy.nlp.tokenizers import Tokenizer
13
- from sumy.summarizers.lex_rank import LexRankSummarizer
14
- import nltk
15
-
16
- # --- Ensure NLTK punkt tokenizer is downloaded ---
17
- try:
18
- nltk.data.find("tokenizers/punkt")
19
- except LookupError:
20
- nltk.download("punkt")
21
 
22
- # --- Model definitions ---
23
  MODEL_SIZES = {
24
  "Tiny (Fastest)": "openai/whisper-tiny",
25
  "Base (Faster)": "openai/whisper-base",
26
  "Small (Balanced)": "openai/whisper-small",
27
  "Distil-Large-v3 (General Purpose)": "distil-whisper/distil-large-v3",
28
- "Distil-Large-v3-FR (French-Specific)": "eustlb/distil-large-v3-fr"
29
  }
30
 
31
- # --- Caches ---
32
  model_cache = {}
33
 
34
- # --- Whisper pipeline loader ---
35
  def get_model_pipeline(model_name, progress):
36
  if model_name not in model_cache:
37
- progress(0, desc="πŸš€ Loading model...")
38
  model_id = MODEL_SIZES[model_name]
39
  device = 0 if torch.cuda.is_available() else "cpu"
 
 
40
  model_cache[model_name] = pipeline(
41
  "automatic-speech-recognition",
42
  model=model_id,
43
  device=device
44
  )
45
- progress(0.5, desc=f"βœ… {model_name} loaded")
46
  return model_cache[model_name]
47
 
48
- # --- Extractive summary ---
49
- def extractive_summary(text, sentences_count=7):
50
- """
51
- Summarize the text using LexRank (extractive summarization)
52
- """
53
- parser = PlaintextParser.from_string(text, Tokenizer("french"))
54
- summarizer = LexRankSummarizer()
55
- summary = summarizer(parser.document, sentences_count)
56
- return " ".join(str(s) for s in summary)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
- # --- Extract audio from video/audio ---
59
- def extract_audio(file_path):
60
- ext = os.path.splitext(file_path)[1].lower()
61
- if ext in [".wav", ".mp3", ".m4a", ".flac"]:
62
- return file_path
63
- temp_audio = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
64
- temp_audio.close()
65
- audio = AudioSegment.from_file(file_path)
66
- audio.export(temp_audio.name, format="wav")
67
- return temp_audio.name
68
-
69
- # --- Split audio into 10-minute chunks ---
70
- def split_audio(audio_path):
71
- audio = AudioSegment.from_file(audio_path)
72
- chunk_length_ms = 10 * 60 * 1000 # 10 minutes
73
- chunks = []
74
- labels = []
75
- for i, start in enumerate(range(0, len(audio), chunk_length_ms)):
76
- end = min(start + chunk_length_ms, len(audio))
77
- chunk = audio[start:end]
78
- temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
79
- chunk.export(temp_file.name, format="wav")
80
- chunks.append(temp_file.name)
81
- labels.append(f"{i*10}-{(i+1)*10} min")
82
- return chunks, labels
83
-
84
- # --- Export transcription to DOCX ---
85
- def export_transcription_docx(text, file_path="transcription_full.docx"):
86
- doc = Document()
87
- doc.add_heading("Full Transcription", 0)
88
- for paragraph in text.split("\n"):
89
- doc.add_paragraph(paragraph.strip())
90
- doc.save(file_path)
91
- return file_path
92
-
93
- # --- Transcribe selected chunks ---
94
  @spaces.GPU
95
- def transcribe_selected(file, model_size, selected_chunks, generate_summary, progress=gr.Progress()):
96
- if file is None:
97
- return None, None, None, "Please upload a file."
98
-
99
- progress(0, desc="🎬 Extracting audio...")
100
- audio_file = extract_audio(file)
101
- chunks, labels = split_audio(audio_file)
102
-
103
- # Select chunks
104
- chosen_files = [chunks[i] for i, label in enumerate(labels) if label in selected_chunks]
105
 
 
 
106
  pipe = get_model_pipeline(model_size, progress)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
- full_text = ""
109
- for idx, chunk_file in enumerate(chosen_files):
110
- progress((idx+1)/len(chosen_files), desc=f"🎀 Transcribing chunk {idx+1}/{len(chosen_files)}...")
111
- if model_size == "Distil-Large-v3-FR (French-Specific)":
112
- output = pipe(chunk_file, return_timestamps=True, generate_kwargs={"language": "fr"})
113
- else:
114
- output = pipe(chunk_file, return_timestamps=True)
115
- full_text += output.get("text", "") + "\n"
116
-
117
- # Export full transcription DOCX
118
- docx_path = export_transcription_docx(full_text)
119
-
120
- # Generate extractive summary (optional, not shown in UI)
121
- summary_text = None
122
- if generate_summary and full_text.strip():
123
- summary_text = extractive_summary(full_text, sentences_count=7)
124
-
125
- return full_text, docx_path, summary_text, f"βœ… Done. Transcribed {len(chosen_files)} parts."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
  # --- Gradio UI ---
128
- with gr.Blocks(title="Whisper Chunked Transcription") as demo:
129
- gr.Markdown("# πŸŽ™οΈ Whisper Chunked Transcription")
130
- gr.Markdown("Upload audio/video, select 10-minute parts to transcribe, generate extractive summary (hidden), and export full transcription as DOCX.")
131
-
132
  with gr.Row():
133
- file_input = gr.Audio(sources=["upload"], type="filepath", label="Upload File")
134
- model_selector = gr.Dropdown(
135
- label="Whisper Model",
136
- choices=list(MODEL_SIZES.keys()),
137
- value="Distil-Large-v3-FR (French-Specific)"
138
- )
139
-
140
- chunk_selector = gr.CheckboxGroup(label="Select 10-minute parts", choices=[])
141
- summary_checkbox = gr.Checkbox(label="Generate Extractive Summary", value=True)
142
- transcribe_btn = gr.Button("Transcribe")
143
-
144
- transcription_output = gr.Textbox(label="Transcription", lines=10)
145
- docx_output = gr.File(label="Download DOCX")
146
- status_text = gr.Textbox(label="Status", interactive=False)
147
-
148
- # Update chunk choices after file upload
149
- def update_chunks(file):
150
- if file is None:
151
- return gr.update(choices=[])
152
- audio_file = extract_audio(file)
153
- _, labels = split_audio(audio_file)
154
- return gr.update(choices=labels, value=[])
155
-
156
- file_input.change(update_chunks, inputs=file_input, outputs=chunk_selector)
157
-
158
  transcribe_btn.click(
159
- fn=transcribe_selected,
160
- inputs=[file_input, model_selector, chunk_selector, summary_checkbox],
161
- outputs=[transcription_output, docx_output, gr.Textbox(visible=False), status_text]
162
  )
163
 
164
  if __name__ == "__main__":
165
- demo.launch()
 
1
+ import gradio as gr
2
+ import spaces
3
+ import torch
4
  import os
 
5
  import datetime
6
  import time
 
 
 
7
  from transformers import pipeline
8
  from docx import Document
 
 
 
 
 
 
 
 
 
 
 
9
 
10
# Human-readable model labels mapped to their Hugging Face checkpoint IDs.
# Labels indicate approximate relative speed/quality trade-offs.
MODEL_SIZES = {
    "Tiny (Fastest)": "openai/whisper-tiny",
    "Base (Faster)": "openai/whisper-base",
    "Small (Balanced)": "openai/whisper-small",
    "Distil-Large-v3 (General Purpose)": "distil-whisper/distil-large-v3",
    "Distil-Large-v3-FR (French-Specific)": "eustlb/distil-large-v3-fr",
}

# Cache of already-constructed ASR pipelines, keyed by model label, so a
# model is only downloaded/loaded once per process.
model_cache = {}
21
 
 
22
def get_model_pipeline(model_name, progress):
    """Return a cached ASR pipeline for *model_name*, loading it on first use.

    Args:
        model_name: A key of MODEL_SIZES selecting the Whisper checkpoint.
        progress: Gradio progress callback used to report loading status.

    Returns:
        A transformers automatic-speech-recognition pipeline.
    """
    cached = model_cache.get(model_name)
    if cached is not None:
        return cached

    progress(0, desc="🚀 Initializing ZeroGPU instance...")
    repo_id = MODEL_SIZES[model_name]
    # Use GPU 0 when CUDA is available, otherwise fall back to CPU.
    target_device = 0 if torch.cuda.is_available() else "cpu"

    progress(0.1, desc=f"⏳ Loading {model_name} model...")
    asr_pipe = pipeline(
        "automatic-speech-recognition",
        model=repo_id,
        device=target_device,
    )
    model_cache[model_name] = asr_pipe
    progress(0.5, desc="✅ Model loaded successfully!")
    return asr_pipe
36
 
37
def create_vtt(segments, file_path):
    """Write transcription segments to *file_path* as a WebVTT subtitle file.

    Each segment is a dict as produced by the transformers ASR pipeline
    (``{"timestamp": (start, end), "text": ...}``); flat ``start``/``end``
    keys are also accepted for backward compatibility.

    Args:
        segments: Iterable of segment dicts.
        file_path: Destination path for the .vtt file.
    """

    def _bounds(segment):
        # transformers pipeline chunks carry a (start, end) tuple under
        # "timestamp" — the previous flat "start"/"end" lookup always
        # returned 0 for such segments.  Support both shapes.
        ts = segment.get("timestamp")
        if ts is not None:
            start, end = ts
        else:
            start, end = segment.get("start", 0), segment.get("end", 0)
        start = start or 0
        # Whisper's final chunk can report end=None; clamp it to start.
        end = start if end is None else end
        return start, end

    def _fmt(seconds):
        # WebVTT cue times must be HH:MM:SS.mmm; str(timedelta) would emit
        # e.g. "0:00:05" (no zero-padding, no milliseconds), which is
        # invalid WebVTT.
        total_ms = int(round(seconds * 1000))
        hours, rest = divmod(total_ms, 3_600_000)
        minutes, rest = divmod(rest, 60_000)
        secs, millis = divmod(rest, 1000)
        return f"{hours:02d}:{minutes:02d}:{secs:02d}.{millis:03d}"

    with open(file_path, "w", encoding="utf-8") as f:
        f.write("WEBVTT\n\n")
        for i, segment in enumerate(segments):
            start, end = _bounds(segment)
            f.write(f"{i + 1}\n")
            f.write(f"{_fmt(start)} --> {_fmt(end)}\n")
            f.write(f"{segment.get('text', '').strip()}\n\n")
48
+
49
def create_docx(segments, file_path, with_timestamps):
    """Write transcription segments to *file_path* as a .docx document.

    Args:
        segments: Iterable of segment dicts from the transformers ASR
            pipeline (``{"timestamp": (start, end), "text": ...}``); flat
            ``start``/``end`` keys are also accepted.
        file_path: Destination path for the .docx file.
        with_timestamps: If True, each segment becomes its own paragraph
            prefixed with "[start - end]"; otherwise all segment texts are
            joined into a single paragraph.
    """

    def _bounds(segment):
        # transformers pipeline chunks store times as a (start, end) tuple
        # under "timestamp" — the previous flat "start"/"end" lookup always
        # yielded 0:00:00 for such segments.  Support both shapes.
        ts = segment.get("timestamp")
        if ts is not None:
            start, end = ts
        else:
            start, end = segment.get("start", 0), segment.get("end", 0)
        start = start or 0
        # Whisper's final chunk can report end=None; clamp it to start.
        end = start if end is None else end
        return start, end

    document = Document()
    document.add_heading("Transcription", 0)

    if with_timestamps:
        for segment in segments:
            text = segment.get('text', '').strip()
            start_seconds, end_seconds = _bounds(segment)
            start = str(datetime.timedelta(seconds=int(start_seconds)))
            end = str(datetime.timedelta(seconds=int(end_seconds)))
            document.add_paragraph(f"[{start} - {end}] {text}")
    else:
        full_text = " ".join(segment.get('text', '').strip() for segment in segments)
        document.add_paragraph(full_text)

    document.save(file_path)
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
@spaces.GPU
def transcribe_and_export(audio_file, model_size, vtt_output, docx_timestamp_output, docx_no_timestamp_output, progress=gr.Progress()):
    """Transcribe an audio file and export the result in selected formats.

    Args:
        audio_file: Filepath from the gr.Audio input (None if nothing
            was uploaded/recorded).
        model_size: Key into MODEL_SIZES selecting the Whisper model.
        vtt_output: If True, write a WebVTT subtitle file.
        docx_timestamp_output: If True, write a DOCX with timestamps.
        docx_no_timestamp_output: If True, write a plain-text DOCX.
        progress: Gradio progress tracker.

    Returns:
        Tuple of (transcribed text, files-component update, audio-input
        clear update, status message).
    """
    if audio_file is None:
        return (None, None, None, "Please upload an audio file.")

    start_time = time.time()

    pipe = get_model_pipeline(model_size, progress)

    progress(0.75, desc="🎤 Transcribing audio...")

    # The French-specific distil model needs the language forced to French;
    # every other model is left to auto-detect the language.
    if model_size == "Distil-Large-v3-FR (French-Specific)":
        raw_output = pipe(
            audio_file,
            return_timestamps=True,
            generate_kwargs={"language": "fr"}
        )
    else:
        raw_output = pipe(
            audio_file,
            return_timestamps=True,
        )

    segments = raw_output.get("chunks", [])
    outputs = {}

    progress(0.85, desc="📝 Generating output files...")

    if vtt_output:
        vtt_path = "transcription.vtt"
        create_vtt(segments, vtt_path)
        outputs["VTT"] = vtt_path

    if docx_timestamp_output:
        docx_ts_path = "transcription_with_timestamps.docx"
        create_docx(segments, docx_ts_path, with_timestamps=True)
        outputs["DOCX (with timestamps)"] = docx_ts_path

    if docx_no_timestamp_output:
        docx_no_ts_path = "transcription_without_timestamps.docx"
        create_docx(segments, docx_no_ts_path, with_timestamps=False)
        outputs["DOCX (without timestamps)"] = docx_no_ts_path

    total_time = time.time() - start_time
    # Use .get for consistency with the "chunks" lookup above and to avoid
    # a KeyError if the pipeline output ever lacks a "text" field.
    transcribed_text = raw_output.get("text", "")
    downloadable_files = list(outputs.values())
    status_message = f"✅ Transcription complete! Total time: {total_time:.2f} seconds."

    return (
        transcribed_text,
        gr.Files(value=downloadable_files, label="Download Transcripts"),
        gr.Audio(value=None),  # clear the audio input after a successful run
        status_message
    )
124
 
125
# --- Gradio UI ---
with gr.Blocks(title="Whisper ZeroGPU Transcription") as demo:
    gr.Markdown("# 🎙️ Whisper ZeroGPU Transcription")
    gr.Markdown("Transcribe audio with timestamps and choose your output format. The first run may take up to a minute due to cold start.")

    with gr.Row():
        # Audio can come from the microphone or a file upload; the handler
        # receives a filesystem path.
        audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Audio File")

        with gr.Column(scale=2):
            model_selector = gr.Dropdown(
                label="Choose Whisper Model Size",
                choices=list(MODEL_SIZES.keys()),
                value="Distil-Large-v3-FR (French-Specific)"  # default: French-specific model
            )
            gr.Markdown("### Choose Output Formats")
            with gr.Row():
                vtt_checkbox = gr.Checkbox(label="VTT", value=True)
                docx_ts_checkbox = gr.Checkbox(label="DOCX (with timestamps)", value=False)
                docx_no_ts_checkbox = gr.Checkbox(label="DOCX (without timestamps)", value=True)

    transcribe_btn = gr.Button("Transcribe", variant="primary")
    status_text = gr.Textbox(label="Status", interactive=False)

    transcription_output = gr.Textbox(label="Full Transcription", lines=10)
    downloadable_files_output = gr.Files(label="Download Transcripts")

    # The audio input is listed among the outputs so the handler can clear
    # it once a transcription finishes.
    transcribe_btn.click(
        fn=transcribe_and_export,
        inputs=[audio_input, model_selector, vtt_checkbox, docx_ts_checkbox, docx_no_ts_checkbox],
        outputs=[transcription_output, downloadable_files_output, audio_input, status_text]
    )

if __name__ == "__main__":
    demo.launch()