Krish-05 committed on
Commit
a9ca228
·
verified ·
1 Parent(s): e563ca3

Update stt_module.py

Browse files
Files changed (1) hide show
  1. stt_module.py +37 -82
stt_module.py CHANGED
@@ -1,85 +1,40 @@
1
- import os
2
- import logging
3
- import asyncio
4
- from typing import Optional
5
- from faster_whisper import WhisperModel
6
 
7
- logger = logging.getLogger(__name__)
8
-
9
- # Global model variable for singleton pattern
10
- _stt_model: Optional[WhisperModel] = None
11
- _model_initialized = False
12
-
13
- def initialize_stt():
14
- """Initializes the Whisper model globally if not already initialized."""
15
- global _stt_model, _model_initialized
16
- if _model_initialized:
17
- logger.info("STT model already initialized.")
18
- return True
19
-
20
- try:
21
- logger.info("Loading Whisper model (base) on CPU...")
22
- # Explicitly set device to CPU and compute type to int8 for better performance on CPU.
23
- # Consider 'tiny' or 'small' for faster inference on limited CPU resources.
24
- _stt_model = WhisperModel(
25
- "base", # You can try "tiny" or "small" for faster but less accurate results
26
- device="cpu",
27
- compute_type="int8" # For CPU optimization
28
- )
29
- _model_initialized = True
30
- logger.info("STT model initialized successfully on CPU.")
31
- return True
32
- except Exception as e:
33
- logger.error(f"Failed to initialize STT model: {e}")
34
- _model_initialized = False # Mark as failed
35
- return False
36
-
37
- def get_stt_model() -> Optional[WhisperModel]:
38
- """Returns the initialized STT model, initializing it if necessary."""
39
- if not _model_initialized:
40
- initialize_stt()
41
- return _stt_model
42
-
43
- async def transcribe_audio_file(audio_path: str) -> Optional[str]:
44
  """
45
- Asynchronously transcribes an audio file to text using faster_whisper.
46
- Wraps the synchronous faster_whisper transcribe call in an asyncio.to_thread
47
- to prevent blocking the FastAPI event loop.
48
  """
49
- model = get_stt_model()
50
- if model is None:
51
- logger.error("STT model is not loaded. Cannot transcribe audio.")
52
- return None
53
-
54
- if not os.path.exists(audio_path):
55
- logger.error(f"Audio file not found for transcription: {audio_path}")
56
- return None
57
- if os.path.getsize(audio_path) == 0:
58
- logger.warning(f"Audio file is empty: {audio_path}")
59
- return ""
60
-
61
- logger.info(f"Starting transcription of {audio_path}...")
62
- try:
63
- # Run the synchronous transcription in a separate thread
64
- segments, info = await asyncio.to_thread(
65
- model.transcribe,
66
- audio_path,
67
- beam_size=5, # Number of beams for beam search, common value
68
- vad_filter=True # Use Voice Activity Detection to filter out non-speech segments
69
- )
70
-
71
- text_segments = []
72
- for segment in segments:
73
- if segment.text.strip():
74
- text_segments.append(segment.text.strip())
75
-
76
- transcribed_text = " ".join(text_segments)
77
- logger.info(f"Transcription complete. Detected language: {info.language} with probability {info.language_probability:.4f}. Text: {transcribed_text[:100]}...")
78
- return transcribed_text
79
- except Exception as e:
80
- logger.error(f"Error during audio transcription: {e}", exc_info=True)
81
- return None
82
-
83
- def is_model_loaded() -> bool:
84
- """Checks if the STT model is loaded and ready."""
85
- return _stt_model is not None and _model_initialized
 
1
+ import threading
2
+ import pydub
3
+ import av
4
+ import streamlit as st # Only imported for st.session_state access in recv method
5
+ from streamlit_webrtc import AudioProcessorBase
6
 
7
class AudioBufferProcessor(AudioProcessorBase):
    """
    An audio processor that buffers incoming audio frames.

    Audio is accumulated only while `st.session_state.is_recording` is truthy.
    Each buffered chunk is converted to mono at 16 kHz before it is appended,
    so the audio returned by `get_and_clear_buffered_audio` is always mono/16 kHz.
    """

    def __init__(self) -> None:
        # Accumulated recording; starts out as an empty segment.
        self._audio_buffer = pydub.AudioSegment.empty()
        # recv() and get_and_clear_buffered_audio() both touch the buffer,
        # so every access goes through this lock.
        self._lock = threading.Lock()

    def recv(self, frame: av.AudioFrame) -> av.AudioFrame:
        """
        Receive one audio frame from the WebRTC stream.

        If recording is active, the frame is converted to a mono 16 kHz
        `pydub.AudioSegment` and appended to the internal buffer.

        Args:
            frame: The incoming audio frame.

        Returns:
            The same frame, unchanged, so the downstream media pipeline
            always receives a valid frame instead of `None`.
        """
        # Read the flag with a default: a missing `is_recording` key must not
        # raise AttributeError and kill frame processing.
        # NOTE(review): streamlit-webrtc presumably calls recv() outside the
        # Streamlit script thread, where session state may be unavailable --
        # confirm, and consider passing the flag to the processor explicitly.
        if getattr(st.session_state, "is_recording", False):
            # NOTE(review): assumes a packed (interleaved) integer sample
            # format; planar frames would need interleaving first -- confirm.
            sound = pydub.AudioSegment(
                data=frame.to_ndarray().tobytes(),
                sample_width=frame.format.bytes,
                frame_rate=frame.sample_rate,
                channels=len(frame.layout.channels),
            )
            # Normalize outside the lock to keep the critical section short.
            sound = sound.set_channels(1).set_frame_rate(16000)
            with self._lock:
                self._audio_buffer += sound
        return frame

    def get_and_clear_buffered_audio(self) -> pydub.AudioSegment:
        """
        Retrieve the accumulated audio and clear the buffer.

        Intended to be called when recording stops.

        Returns:
            Everything buffered since the last call (an empty segment if
            nothing was recorded).
        """
        with self._lock:
            recorded_audio = self._audio_buffer
            self._audio_buffer = pydub.AudioSegment.empty()  # Clear the buffer
        return recorded_audio