clementBE committed on
Commit
4091834
·
verified ·
1 Parent(s): 3872026

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -17
app.py CHANGED
@@ -13,16 +13,22 @@ MODEL_SIZES = {
13
  "Base (Faster)": "openai/whisper-base",
14
  "Small (Balanced)": "openai/whisper-small",
15
  "Distil-Large-v3 (General Purpose)": "distil-whisper/distil-large-v3",
16
- "Distil-Large-v3-FR (French-Specific)": "distil-whisper/distil-large-v3-fr" # New, French-specific model
 
 
17
  }
18
 
19
  # Use a dictionary to cache loaded models
20
  model_cache = {}
21
 
22
  def get_model_pipeline(model_name, progress):
 
 
 
23
  if model_name not in model_cache:
24
  progress(0, desc="πŸš€ Initializing ZeroGPU instance...")
25
  model_id = MODEL_SIZES[model_name]
 
26
  device = 0 if torch.cuda.is_available() else "cpu"
27
 
28
  progress(0.1, desc=f"⏳ Loading {model_name} model...")
@@ -35,37 +41,54 @@ def get_model_pipeline(model_name, progress):
35
  return model_cache[model_name]
36
 
37
  def create_vtt(segments, file_path):
 
 
 
38
  with open(file_path, "w", encoding="utf-8") as f:
39
  f.write("WEBVTT\n\n")
40
  for i, segment in enumerate(segments):
41
- start_seconds = segment.get('start', 0)
42
- end_seconds = segment.get('end', 0)
43
- start = str(datetime.timedelta(seconds=int(start_seconds)))
44
- end = str(datetime.timedelta(seconds=int(end_seconds)))
 
 
 
 
 
 
 
 
 
45
  f.write(f"{i+1}\n")
46
  f.write(f"{start} --> {end}\n")
47
  f.write(f"{segment.get('text', '').strip()}\n\n")
48
 
49
  def create_docx(segments, file_path, with_timestamps):
 
 
 
50
  document = Document()
51
  document.add_heading("Transcription", 0)
52
 
53
  if with_timestamps:
54
  for segment in segments:
55
  text = segment.get('text', '').strip()
56
- start_seconds = segment.get('start', 0)
57
- end_seconds = segment.get('end', 0)
58
- start = str(datetime.timedelta(seconds=int(start_seconds)))
59
- end = str(datetime.timedelta(seconds=int(end_seconds)))
60
  document.add_paragraph(f"[{start} - {end}] {text}")
61
  else:
62
  full_text = " ".join([segment.get('text', '').strip() for segment in segments])
63
  document.add_paragraph(full_text)
64
-
65
  document.save(file_path)
66
 
67
  @spaces.GPU
68
  def transcribe_and_export(audio_file, model_size, vtt_output, docx_timestamp_output, docx_no_timestamp_output, progress=gr.Progress()):
 
 
 
69
  if audio_file is None:
70
  return (None, None, None, "Please upload an audio file.")
71
 
@@ -75,22 +98,29 @@ def transcribe_and_export(audio_file, model_size, vtt_output, docx_timestamp_out
75
 
76
  progress(0.75, desc="🎀 Transcribing audio...")
77
 
78
- # Forcing French for the new specific model
79
- # Note: If the user picks a different model, the language auto-detection will work as normal.
80
  if model_size == "Distil-Large-v3-FR (French-Specific)":
 
81
  raw_output = pipe(
82
  audio_file,
83
- return_timestamps=True,
84
  generate_kwargs={"language": "fr"}
85
  )
86
  else:
87
- # For other models, let the model auto-detect
88
  raw_output = pipe(
89
  audio_file,
90
- return_timestamps=True,
91
  )
92
 
 
93
  segments = raw_output.get("chunks", [])
 
 
 
 
 
 
94
  outputs = {}
95
 
96
  progress(0.85, desc="πŸ“ Generating output files...")
@@ -119,7 +149,7 @@ def transcribe_and_export(audio_file, model_size, vtt_output, docx_timestamp_out
119
  return (
120
  transcribed_text,
121
  gr.Files(value=downloadable_files, label="Download Transcripts"),
122
- gr.Audio(value=None),
123
  status_message
124
  )
125
 
@@ -135,7 +165,8 @@ with gr.Blocks(title="Whisper ZeroGPU Transcription") as demo:
135
  model_selector = gr.Dropdown(
136
  label="Choose Whisper Model Size",
137
  choices=list(MODEL_SIZES.keys()),
138
- value="Distil-Large-v3-FR (French-Specific)" # Default to the French-specific model
 
139
  )
140
  gr.Markdown("### Choose Output Formats")
141
  with gr.Row():
 
13
  "Base (Faster)": "openai/whisper-base",
14
  "Small (Balanced)": "openai/whisper-small",
15
  "Distil-Large-v3 (General Purpose)": "distil-whisper/distil-large-v3",
16
+ # FIX: The model 'distil-whisper/distil-large-v3-fr' does not exist.
17
+ # We use the general distil-large-v3 and rely on the code below to force French.
18
+ "Distil-Large-v3-FR (French-Specific)": "distil-whisper/distil-large-v3"
19
  }
20
 
21
  # Use a dictionary to cache loaded models
22
  model_cache = {}
23
 
24
  def get_model_pipeline(model_name, progress):
25
+ """
26
+ Initializes and caches the ASR pipeline for a given model name.
27
+ """
28
  if model_name not in model_cache:
29
  progress(0, desc="πŸš€ Initializing ZeroGPU instance...")
30
  model_id = MODEL_SIZES[model_name]
31
+ # Use GPU if available, otherwise fallback to CPU
32
  device = 0 if torch.cuda.is_available() else "cpu"
33
 
34
  progress(0.1, desc=f"⏳ Loading {model_name} model...")
 
41
  return model_cache[model_name]
42
 
43
  def create_vtt(segments, file_path):
44
+ """
45
+ Creates a WebVTT (.vtt) file from transcription segments.
46
+ """
47
  with open(file_path, "w", encoding="utf-8") as f:
48
  f.write("WEBVTT\n\n")
49
  for i, segment in enumerate(segments):
50
+ # Calculate time strings in "HH:MM:SS.mmm" format (though VTT only strictly requires up to milliseconds)
51
+ start_ms = int(segment.get('start', 0) * 1000)
52
+ end_ms = int(segment.get('end', 0) * 1000)
53
+
54
+ def format_time(ms):
55
+ hours, remainder = divmod(ms, 3600000)
56
+ minutes, remainder = divmod(remainder, 60000)
57
+ seconds, milliseconds = divmod(remainder, 1000)
58
+ return f"{int(hours):02}:{int(minutes):02}:{int(seconds):02}.{int(milliseconds):03}"
59
+
60
+ start = format_time(start_ms)
61
+ end = format_time(end_ms)
62
+
63
  f.write(f"{i+1}\n")
64
  f.write(f"{start} --> {end}\n")
65
  f.write(f"{segment.get('text', '').strip()}\n\n")
66
 
67
  def create_docx(segments, file_path, with_timestamps):
68
+ """
69
+ Creates a DOCX (.docx) file from transcription segments.
70
+ """
71
  document = Document()
72
  document.add_heading("Transcription", 0)
73
 
74
  if with_timestamps:
75
  for segment in segments:
76
  text = segment.get('text', '').strip()
77
+ # Format time as HH:MM:SS for DOCX
78
+ start = str(datetime.timedelta(seconds=int(segment.get('start', 0))))
79
+ end = str(datetime.timedelta(seconds=int(segment.get('end', 0))))
 
80
  document.add_paragraph(f"[{start} - {end}] {text}")
81
  else:
82
  full_text = " ".join([segment.get('text', '').strip() for segment in segments])
83
  document.add_paragraph(full_text)
84
+
85
  document.save(file_path)
86
 
87
  @spaces.GPU
88
  def transcribe_and_export(audio_file, model_size, vtt_output, docx_timestamp_output, docx_no_timestamp_output, progress=gr.Progress()):
89
+ """
90
+ Main function to transcribe audio and export to selected formats.
91
+ """
92
  if audio_file is None:
93
  return (None, None, None, "Please upload an audio file.")
94
 
 
98
 
99
  progress(0.75, desc="🎀 Transcribing audio...")
100
 
101
+ # Check if the French-specific model option was selected
 
102
  if model_size == "Distil-Large-v3-FR (French-Specific)":
103
+ # Force French for this specific option
104
  raw_output = pipe(
105
  audio_file,
106
+ return_timestamps="word", # Use word-level timestamps for more detail if needed, but 'True' works for chunk timestamps too
107
  generate_kwargs={"language": "fr"}
108
  )
109
  else:
110
+ # For other models, let the model auto-detect the language
111
  raw_output = pipe(
112
  audio_file,
113
+ return_timestamps="word",
114
  )
115
 
116
+ # Use 'chunks' if available, otherwise default to the whole text
117
  segments = raw_output.get("chunks", [])
118
+
119
+ # If no chunks are returned (e.g., if return_timestamps=False was used, though not in this code),
120
+ # create a single segment from the full text.
121
+ if not segments and 'text' in raw_output:
122
+ segments = [{'text': raw_output['text'].strip(), 'start': 0.0, 'end': 0.0}]
123
+
124
  outputs = {}
125
 
126
  progress(0.85, desc="πŸ“ Generating output files...")
 
149
  return (
150
  transcribed_text,
151
  gr.Files(value=downloadable_files, label="Download Transcripts"),
152
+ gr.Audio(value=None), # Clear the audio input
153
  status_message
154
  )
155
 
 
165
  model_selector = gr.Dropdown(
166
  label="Choose Whisper Model Size",
167
  choices=list(MODEL_SIZES.keys()),
168
+ # Default to the French-specific model, which now uses the correct ID
169
+ value="Distil-Large-v3-FR (French-Specific)"
170
  )
171
  gr.Markdown("### Choose Output Formats")
172
  with gr.Row():