File size: 2,571 Bytes
1905805
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import asyncio
import base64
import functools
import wave
from io import BytesIO

import requests

from utils.config import Config
from utils.processes import ProcessManager, ProcessType

from .base import STTOperation

class KoboldSTT(STTOperation):
    '''Speech-to-text operation backed by the Kobold process from ProcessManager.

    Audio is wrapped in a WAV container, base64-encoded and POSTed as JSON
    to ``<uri>/api/extra/transcribe``. The ``text`` field of the response
    is yielded as the transcription.
    '''

    # Identifier used when linking/unlinking this operation to the shared
    # Kobold process via ProcessManager.
    KOBOLD_LINK_ID = "kobold_stt"

    def __init__(self):
        super().__init__("kobold")
        # Base URL of the Kobold server; set by start(), cleared by close().
        self.uri = None

        # Forwarded verbatim as "suppress_non_speech" in the transcribe request.
        self.suppress_non_speech: bool = True
        # Forwarded verbatim as "langcode" in the transcribe request.
        self.langcode: str = "en"

    async def start(self) -> None:
        '''Link to the Kobold process and record its local base URL.'''
        await super().start()
        await ProcessManager().link(self.KOBOLD_LINK_ID, ProcessType.KOBOLD)
        port = ProcessManager().get_process(ProcessType.KOBOLD).port
        self.uri = "http://127.0.0.1:{}".format(port)

    async def close(self) -> None:
        '''Clean up resources before unloading.'''
        await super().close()
        await ProcessManager().unlink(self.KOBOLD_LINK_ID, ProcessType.KOBOLD)
        # Forget the URL so a later _generate fails clearly instead of
        # talking to a process this operation no longer holds a link to.
        self.uri = None

    async def configure(self, config_d):
        '''Apply operation-specific settings from ``config_d``.

        Recognised keys:
            suppress_non_speech: coerced with ``bool``.
            langcode: coerced with ``str``; must be non-None and non-empty.
        Other keys are ignored.

        Raises:
            AssertionError: if ``langcode`` is present but None or empty.
                Nothing is modified in that case.
        '''
        # Validate before mutating so a rejected config leaves state intact.
        new_langcode = None
        if "langcode" in config_d:
            raw_langcode = config_d['langcode']
            # Check the raw value: str(None) == "None" would otherwise pass
            # and be sent to the server as a language code.
            if raw_langcode is None or not str(raw_langcode):
                # Raised explicitly (not via `assert`) so the check still runs
                # under `python -O`, while keeping the original exception type.
                raise AssertionError("langcode must be a non-empty string")
            new_langcode = str(raw_langcode)

        if "suppress_non_speech" in config_d:
            self.suppress_non_speech = bool(config_d['suppress_non_speech'])
        if new_langcode is not None:
            self.langcode = new_langcode

    async def get_configuration(self):
        '''Return the current values of the configurable fields.'''
        return {
            "suppress_non_speech": self.suppress_non_speech,
            "langcode": self.langcode
        }

    async def _generate(self, prompt: str = None,  audio_bytes: bytes = None, sr: int = None, sw: int = None, ch: int = None, **kwargs):
        '''Transcribe raw audio frames via the Kobold transcribe endpoint.

        Args:
            prompt: Forwarded as "prompt" in the request body (may be None).
            audio_bytes: Raw frames passed to ``wave``'s ``writeframes``.
            sr: Frame rate written to the WAV header.
            sw: Sample width in bytes written to the WAV header.
            ch: Channel count written to the WAV header.
            **kwargs: Accepted for interface compatibility; not used here.

        Yields:
            ``{"transcription": text}`` where ``text`` is the response's
            ``text`` field.

        Raises:
            RuntimeError: if called before ``start()`` set the server URL.
            Exception: if the server answers with a non-200 status.
        '''
        if self.uri is None:
            raise RuntimeError("KoboldSTT is not started; call start() before generating")

        # Wrap the raw frames in a WAV container in memory.
        audio_data = BytesIO()
        with wave.open(audio_data, 'wb') as f:
            f.setframerate(sr)
            f.setsampwidth(sw)
            f.setnchannels(ch)
            f.writeframes(audio_bytes)

        payload = {
            "prompt": prompt,
            "suppress_non_speech": self.suppress_non_speech,
            "langcode": self.langcode,
            "audio_data": base64.b64encode(audio_data.getvalue()).decode('utf-8')
        }

        # requests is synchronous; run it in the default executor so a slow
        # transcription does not stall the event loop for other tasks.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            functools.partial(
                requests.post,
                "{}/api/extra/transcribe".format(self.uri),
                json=payload,
            ),
        )

        if response.status_code == 200:
            result = response.json()['text']
            yield {"transcription": result}
        else:
            raise Exception(f"Failed to get STT result: {response.status_code} {response.reason}")