adiitya29 committed on
Commit
bddec1e
·
1 Parent(s): 8dced3a

Initialized project directories, added requirements, and implemented core Gradio UI with lazy-loaded Wav2Vec2 inference

Browse files
app.py CHANGED
@@ -1,18 +1,18 @@
1
  import gradio as gr
2
  from app.asr_model import load_model, transcribe_audio
3
- from app.language_detection import detect_language
4
  from app.history import save_to_history, export_history
5
 
6
  def process_audio(audio_path):
7
  if audio_path is None:
8
  return "No audio uploaded.", "Unknown"
9
 
10
- # Optional: Detect Language
11
- lang = detect_language(audio_path)
12
-
13
  # Transcribe Speech
14
  transcript = transcribe_audio(audio_path)
15
 
 
 
 
16
  # Save History
17
  save_to_history(audio_path, transcript, lang)
18
 
@@ -42,10 +42,5 @@ def create_ui():
42
  return demo
43
 
44
  if __name__ == "__main__":
45
- # Pre-load model on start
46
- print("Loading model...")
47
- load_model()
48
- print("Model loaded. Starting UI...")
49
-
50
  demo = create_ui()
51
  demo.launch()
 
1
  import gradio as gr
2
  from app.asr_model import load_model, transcribe_audio
3
+ from app.language_detection import detect_language_from_text
4
  from app.history import save_to_history, export_history
5
 
6
  def process_audio(audio_path):
7
  if audio_path is None:
8
  return "No audio uploaded.", "Unknown"
9
 
 
 
 
10
  # Transcribe Speech
11
  transcript = transcribe_audio(audio_path)
12
 
13
+ # Detect Language from transcript
14
+ lang = detect_language_from_text(transcript)
15
+
16
  # Save History
17
  save_to_history(audio_path, transcript, lang)
18
 
 
42
  return demo
43
 
if __name__ == "__main__":
    # Build the UI and start serving it. There is no model warm-up here:
    # transcribe_audio loads the model itself the first time it is called.
    demo = create_ui()
    demo.launch()
app/asr_model.py CHANGED
@@ -1,15 +1,57 @@
1
- # This module handles the loading and inferencing of the Wav2Vec model
 
 
2
 
3
- def load_model():
 
 
 
 
4
  """
5
  Loads the Hugging Face Wav2Vec model and processor.
6
- For Apple Silicon, we can utilize MPS (Metal Performance Shaders) later.
 
 
7
  """
8
- pass
 
 
 
 
 
 
 
 
 
9
 
10
  def transcribe_audio(audio_filepath: str) -> str:
11
  """
12
  Takes an audio filepath, processes it, and runs it through the Wav2Vec model
13
  to return a text transcription.
14
  """
15
- return "This is a placeholder transcription. Model integration is pending."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
3
+ from app.audio_processing import load_and_resample
4
 
5
+ # Global variables to hold the model and processor
6
+ processor = None
7
+ model = None
8
+
9
def load_model(model_name: str = "facebook/wav2vec2-base-960h"):
    """
    Load a pretrained Wav2Vec2 checkpoint into the module-level globals.

    Both the processor and the CTC model are loaded from ``model_name`` and
    stored in the module-level ``processor`` and ``model`` variables. The
    model is then moved to the first available accelerator, trying MPS
    (Apple Silicon) before CUDA; when neither is available it is left on the
    device ``from_pretrained`` placed it on.

    The default checkpoint is an English base model. For multilingual use,
    candidates include 'facebook/mms-1b-all' or an XLSR fine-tune such as
    'jonatasgrosman/wav2vec2-large-xlsr-53-english'.

    Args:
        model_name: Identifier of the checkpoint to load.
    """
    global processor, model
    print(f"Loading model {model_name}...")
    processor = Wav2Vec2Processor.from_pretrained(model_name)
    model = Wav2Vec2ForCTC.from_pretrained(model_name)

    # Probe accelerators in priority order and stop at the first one found.
    accelerator_probes = (
        ("mps", torch.backends.mps.is_available),
        ("cuda", torch.cuda.is_available),
    )
    for device_name, is_available in accelerator_probes:
        if is_available():
            model.to(device_name)
            break
26
 
def transcribe_audio(audio_filepath: str) -> str:
    """
    Transcribe the audio file at ``audio_filepath`` with the Wav2Vec2 model.

    The model is loaded on demand if ``load_model`` has not run yet. Any
    exception raised while loading the audio or running inference is NOT
    propagated: it is returned as a string starting with
    "Error during transcription: ".

    Args:
        audio_filepath: Path to the audio file to transcribe.

    Returns:
        The lower-cased transcription, or the error string described above.
    """
    if processor is None or model is None:
        load_model()

    try:
        # Audio must be at 16 kHz, the rate passed to the processor below.
        waveform = load_and_resample(audio_filepath, target_sr=16000)
        features = processor(waveform, sampling_rate=16000, return_tensors="pt", padding=True)

        # Inputs have to live on whichever device load_model() picked.
        target_device = next(model.parameters()).device
        features = {name: tensor.to(target_device) for name, tensor in features.items()}

        with torch.no_grad():
            output = model(**features)

        # Greedy decoding: take the most likely token at every frame.
        token_ids = torch.argmax(output.logits, dim=-1)
        decoded = processor.batch_decode(token_ids)
        return decoded[0].lower()
    except Exception as e:
        return f"Error during transcription: {str(e)}"
57
+
app/audio_processing.py CHANGED
@@ -1,7 +1,21 @@
1
- # This module handles audio preprocessing using libraries like librosa
 
2
 
3
- def load_and_resample(audio_filepath: str, target_sr: int = 16000):
4
  """
5
  Loads an audio file and resamples it to the target sample rate (default 16kHz for Wav2Vec).
 
 
 
 
 
 
 
6
  """
7
- pass
 
 
 
 
 
 
 
1
+ import librosa
2
+ import numpy as np
3
 
4
def load_and_resample(audio_filepath: str, target_sr: int = 16000) -> np.ndarray:
    """
    Load an audio file and resample it to ``target_sr``.

    Args:
        audio_filepath (str): Path to the audio file.
        target_sr (int): The sample rate required by the model
            (default 16 kHz for Wav2Vec).

    Returns:
        np.ndarray: The audio time series.

    Raises:
        RuntimeError: If the file cannot be loaded. The underlying error is
            attached as ``__cause__`` so the original traceback is kept.
    """
    try:
        # librosa resamples on load when ``sr`` is provided.
        speech, _ = librosa.load(audio_filepath, sr=target_sr)
    except Exception as e:
        # Chain explicitly so callers can inspect the real failure.
        raise RuntimeError(f"Error processing audio file {audio_filepath}: {e}") from e
    return speech
21
+
app/history.py CHANGED
@@ -1,13 +1,57 @@
1
- # This module manages saving transcriptions to history and exporting them
 
 
 
 
 
2
 
3
  def save_to_history(audio_filepath: str, transcript: str, language: str):
4
  """
5
- Saves the transcription data to a local JSON or CSV file in the data/ directory.
6
  """
7
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  def export_history(format: str = "csv"):
10
  """
11
  Exports the saved history into a downloadable format.
 
12
  """
13
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import datetime
4
+ import csv
5
+
6
HISTORY_FILE = "data/history.json"


def _read_history_entries(path: str) -> list:
    """Return the list of entries stored at ``path``, or ``[]`` if unusable."""
    try:
        with open(path, "r", encoding="utf-8") as f:
            loaded = json.load(f)
    except FileNotFoundError:
        # No history yet: start a fresh one.
        return []
    except (OSError, ValueError) as exc:
        # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
        print(f"Could not read history file {path}: {exc}")
        return []
    if not isinstance(loaded, list):
        # Anything but a list cannot be appended to below.
        print(f"Ignoring history file {path}: expected a JSON list, got {type(loaded).__name__}")
        return []
    return loaded


def save_to_history(audio_filepath: str, transcript: str, language: str):
    """
    Append one transcription record to the JSON history file.

    The record holds the current local timestamp (ISO format), the base name
    of the audio file, the language and the transcript. The whole history is
    read, extended and written back to ``HISTORY_FILE``. If the existing file
    is missing, unreadable or does not contain a JSON list, the history
    restarts from an empty list; the unreadable and non-list cases print a
    message instead of failing silently.

    Args:
        audio_filepath: Path of the transcribed audio; only its base name is stored.
        transcript: Transcribed text.
        language: Language label stored alongside the transcript.
    """
    directory = os.path.dirname(HISTORY_FILE)
    if directory:
        # os.makedirs("") raises, so only create a directory when there is one.
        os.makedirs(directory, exist_ok=True)

    entry = {
        "timestamp": datetime.datetime.now().isoformat(),
        "audio_file": os.path.basename(audio_filepath),
        "language": language,
        "transcript": transcript,
    }

    history = _read_history_entries(HISTORY_FILE)
    history.append(entry)

    with open(HISTORY_FILE, "w", encoding="utf-8") as f:
        json.dump(history, f, indent=4)
33
 
def export_history(format: str = "csv"):
    """
    Export the saved history to a timestamped CSV file under ``data/``.

    Args:
        format: Accepted for interface compatibility but not consulted; the
            export is always written as CSV whatever value is passed.

    Returns:
        The path to the exported file, or None when no history file exists
        or the export failed (the failure is printed).
    """
    if not os.path.exists(HISTORY_FILE):
        return None

    export_path = f"data/export_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
    columns = ["timestamp", "audio_file", "language", "transcript"]

    try:
        with open(HISTORY_FILE, "r", encoding="utf-8") as f:
            history = json.load(f)

        # Explicit UTF-8 so non-ASCII transcripts do not depend on the
        # platform's default encoding.
        with open(export_path, "w", newline="", encoding="utf-8") as f:
            # An entry carrying unexpected keys should not abort the export.
            writer = csv.DictWriter(f, fieldnames=columns, extrasaction="ignore")
            writer.writeheader()
            writer.writerows(history)

        return export_path
    except Exception as e:
        # Best-effort: report the failure and signal it with None.
        print(f"Failed to export history: {e}")
        return None
app/language_detection.py CHANGED
@@ -1,8 +1,19 @@
1
- # This module handles language detection logic
2
 
3
- def detect_language(audio_filepath: str) -> str:
 
 
 
4
  """
5
- Optional feature to detect the spoken language in the audio file.
6
- Could use a separate small classification model or an API.
7
  """
8
- return "English (Placeholder)"
 
 
 
 
 
 
 
 
 
1
+ from langdetect import detect, DetectorFactory
2
 
3
+ # Ensure consistent results
4
+ DetectorFactory.seed = 0
5
+
6
def detect_language_from_text(text: str) -> str:
    """
    Guess the language of a transcript.

    Args:
        text: Transcribed text to analyse.

    Returns:
        The code returned by ``detect`` (ISO 639-1, e.g. 'en', 'es', 'fr'),
        or "Unknown" when the text is empty, has fewer than two
        non-whitespace-trimmed characters, or detection raises.
    """
    has_enough_text = bool(text) and len(text.strip()) >= 2
    if not has_enough_text:
        return "Unknown"

    try:
        return detect(text)
    except Exception:
        # Detection failures are reported as an unknown language.
        return "Unknown"
19
+