File size: 5,635 Bytes
3dcada4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
from configs import get_settings
import asyncio
# import librosa
# import numpy as np
from stores.sttremotes import STTRemoteManager
from faster_whisper.audio import decode_audio  # handles webm natively

class TranscriptionController:
    """Dispatch speech-to-text requests to local models or a remote provider.

    The backend is chosen per call from ``settings.INFERENCE_TYPE``
    (``"local"`` or ``"remote"``).

    Args:
        models: Mapping-like object exposing ``.get(key)``. Local models are
            looked up under ``"<LOCAL_INFERENCE_MODEL_SIZE>_english"`` and
            ``"<LOCAL_INFERENCE_MODEL_SIZE>_arabic"`` and must provide
            ``transcribe`` (and, for the English one, ``detect_language``).
        logger: Object exposing ``info`` and ``warning``.
        remotename: Name of the remote STT provider; a falsy value disables
            remote transcription.
    """

    def __init__(self, models, logger, remotename):
        self.settings = get_settings()
        self.models = models
        self.logger = logger
        # Upper bound on remote requests per second; the minimum spacing
        # between two remote calls is 1 / remote_max_request_rate seconds.
        self.remote_max_request_rate = 60
        self.remotename = remotename
        self.remote = STTRemoteManager(default_provider=remotename) if remotename else None
        # Event-loop timestamp of the most recently reserved remote request slot.
        self._last_request_time = 0.0

    async def transcribe_audio(self, audio_path: str):
        """Transcribe ``audio_path`` with the configured inference backend.

        Returns:
            Whatever the selected backend returns (``transcribe_local`` or
            ``transcribe_remote``).

        Raises:
            ValueError: If ``settings.INFERENCE_TYPE`` is neither ``"local"``
                nor ``"remote"``.
        """
        inference_type = self.settings.INFERENCE_TYPE
        if inference_type == "local":
            return await self.transcribe_local(audio_path)
        if inference_type == "remote":
            return await self.transcribe_remote(audio_path)
        raise ValueError(f"Unsupported INFERENCE_TYPE: {inference_type}")

    def _get_local_model(self, suffix):
        """Return ``(model, model_size)`` for the local model ``<size>_<suffix>``.

        The lookup is unconditional: the previous code only bound the model
        when ``INFERENCE_TYPE == "local"`` and crashed with
        ``UnboundLocalError`` otherwise.

        Raises:
            ValueError: If no (truthy) model is registered under that key.
        """
        model_size = self.settings.LOCAL_INFERENCE_MODEL_SIZE
        model_key = f"{model_size}_{suffix}"
        model = self.models.get(model_key)
        if not model:
            raise ValueError(f"Model {model_key} not available")
        return model, model_size

    async def language_detection(self, audio_path: str):
        """Detect the spoken language of ``audio_path``.

        Detection runs on the ``<size>_english`` model in the default
        executor so the event loop is not blocked.

        Returns:
            ``(language, probability)`` as reported by the model's
            ``detect_language``.

        Raises:
            ValueError: If the ``<size>_english`` model is not available.
        """
        model, model_size = self._get_local_model("english")
        self.logger.info(f"Detecting language for {audio_path} with {model_size} model...")

        def process():
            waveform = decode_audio(audio_path)
            # detect_language returns a third element that is not needed here.
            language, probability, _ = model.detect_language(waveform)
            return language, probability

        loop = asyncio.get_running_loop()
        language, language_probability = await loop.run_in_executor(None, process)
        return language, language_probability

    async def transcribe_local(self, audio_path: str):
        """Detect the language, then transcribe with the matching local model.

        Returns:
            ``(text, language)``. ``text`` is ``None`` when the detected
            language is neither ``"ar"`` nor ``"en"``.
        """
        language, probability = await self.language_detection(audio_path)
        if language == "ar":
            self.logger.info(f"Processing Arabic audio with probability {probability:.2f}")
            return await self.transcribe_local_arabic(audio_path)
        if language == "en":
            self.logger.info(f"Processing English audio with probability {probability:.2f}")
            return await self.transcribe_local_english(audio_path)
        self.logger.warning(f"Unsupported language detected: {language}. Skipping transcription.")
        return None, language

    async def _transcribe_local_language(self, audio_path, language_code, model_suffix):
        """Transcribe with the ``<size>_<model_suffix>`` model, forcing ``language_code``.

        Shared implementation behind ``transcribe_local_arabic`` and
        ``transcribe_local_english``.

        Returns:
            ``(text, language)``; ``text`` is ``None`` if the model reports a
            language other than ``language_code``.

        Raises:
            ValueError: If the requested model is not available.
        """
        model, model_size = self._get_local_model(model_suffix)
        self.logger.info(f"Transcribing {audio_path} with {model_size} model...")

        allowed_languages = (language_code,)

        def process_with_filter():
            segments, info = model.transcribe(
                audio_path,
                beam_size=5,
                best_of=5,
                language=language_code,
                vad_filter=True,
                vad_parameters=dict(min_silence_duration_ms=500, threshold=0.3),
            )
            if info.language not in allowed_languages:
                self.logger.info(
                    f"Skipping: Detected {info.language} with prob {info.language_probability:.2f}"
                )
                return None, info.language

            # Segments are consumed here, inside the worker thread.
            full_text = " ".join(segment.text for segment in segments)
            return full_text.strip(), info.language

        loop = asyncio.get_running_loop()
        text, language = await loop.run_in_executor(None, process_with_filter)
        return text, language

    async def transcribe_local_arabic(self, audio_path: str):
        """Transcribe ``audio_path`` as Arabic with the ``<size>_arabic`` model.

        Returns:
            ``(text, language)``; ``text`` is ``None`` if the model reports a
            language other than ``"ar"``.

        Raises:
            ValueError: If the Arabic model is not available.
        """
        return await self._transcribe_local_language(audio_path, "ar", "arabic")

    async def transcribe_local_english(self, audio_path: str):
        """Transcribe ``audio_path`` as English with the ``<size>_english`` model.

        Returns:
            ``(text, language)``; ``text`` is ``None`` if the model reports a
            language other than ``"en"``.

        Raises:
            ValueError: If the English model is not available.
        """
        return await self._transcribe_local_language(audio_path, "en", "english")

    async def transcribe_remote(self, audio_path: str):
        """Transcribe ``audio_path`` through the remote provider, throttled.

        Calls are spaced at least ``1 / remote_max_request_rate`` seconds
        apart.

        Returns:
            The result of ``self.remote.transcribe_remote``.

        Raises:
            ValueError: If no remote provider was configured.
        """
        if not self.remote:
            raise ValueError("Remote STT provider not configured")

        loop = asyncio.get_running_loop()
        min_interval = 1 / self.remote_max_request_rate
        now = loop.time()
        # getattr keeps instances created without __init__ working, as the
        # original lazy hasattr check did.
        previous_slot = getattr(self, "_last_request_time", 0)
        # Reserve the slot before awaiting: there is no suspension point
        # between the read and the write, so concurrent callers get distinct,
        # properly spaced slots instead of all sleeping the same amount.
        reserved_slot = max(now, previous_slot + min_interval)
        self._last_request_time = reserved_slot

        delay = reserved_slot - now
        if delay > 0:
            await asyncio.sleep(delay)

        return await self.remote.transcribe_remote(audio_path, self.remotename)