VTuberAI / src /utils /operations /stt /kobold.py
Saidie000's picture
Upload 90 files
1905805 verified
import asyncio
import base64
import wave
from io import BytesIO

import requests

from utils.config import Config
from utils.processes import ProcessManager, ProcessType
from .base import STTOperation
class KoboldSTT(STTOperation):
    '''Speech-to-text operation that posts WAV audio to a local Kobold HTTP server.'''

    # Identifier passed to ProcessManager when linking/unlinking the Kobold process.
    KOBOLD_LINK_ID = "kobold_stt"

    def __init__(self):
        super().__init__("kobold")
        # Base URL of the Kobold server; stays None until start() has run.
        self.uri = None
        self.suppress_non_speech: bool = True
        self.langcode: str = "en"

    async def start(self) -> None:
        '''Link to the Kobold process and record the local URL it listens on.'''
        await super().start()
        await ProcessManager().link(self.KOBOLD_LINK_ID, ProcessType.KOBOLD)
        self.uri = "http://127.0.0.1:{}".format(ProcessManager().get_process(ProcessType.KOBOLD).port)

    async def close(self) -> None:
        '''Clean up resources before unloading by unlinking from the Kobold process.'''
        await super().close()
        await ProcessManager().unlink(self.KOBOLD_LINK_ID, ProcessType.KOBOLD)

    async def configure(self, config_d):
        '''
        Configure and validate operation-specific configuration.

        Recognised keys in config_d (both optional):
            suppress_non_speech: coerced with bool().
            langcode: coerced with str(); must not be empty.

        Raises:
            AssertionError: if the resulting langcode is missing or empty.
                Nothing is modified on this instance in that case.
        '''
        suppress_non_speech = self.suppress_non_speech
        langcode = self.langcode
        if "suppress_non_speech" in config_d:
            suppress_non_speech = bool(config_d['suppress_non_speech'])
        if "langcode" in config_d:
            langcode = str(config_d['langcode'])

        # Validate before committing so a rejected configuration cannot leave
        # the operation half-updated. Raised explicitly (rather than via an
        # `assert` statement) so the check still runs under `python -O`;
        # AssertionError is kept so existing handlers continue to match.
        if not langcode:
            raise AssertionError("langcode must be a non-empty string")

        self.suppress_non_speech = suppress_non_speech
        self.langcode = langcode

    async def get_configuration(self):
        '''Returns values of configurable fields'''
        return {
            "suppress_non_speech": self.suppress_non_speech,
            "langcode": self.langcode
        }

    async def _generate(self, prompt: str = None, audio_bytes: bytes = None, sr: int = None, sw: int = None, ch: int = None, **kwargs):
        '''
        Transcribe raw audio frames and yield the result.

        The frames are wrapped in a WAV container, base64-encoded and posted
        as JSON to "{uri}/api/extra/transcribe".

        Args:
            prompt: forwarded unchanged as the "prompt" field of the request.
            audio_bytes: raw audio frames written into the WAV container.
            sr: frame rate for the WAV header.
            sw: sample width for the WAV header.
            ch: channel count for the WAV header.
            **kwargs: accepted and ignored.

        Yields:
            A single dict {"transcription": <"text" field of the JSON response>}.

        Raises:
            Exception: if the server answers with a status code other than 200.
        '''
        wav_buffer = BytesIO()
        with wave.open(wav_buffer, 'wb') as f:
            f.setframerate(sr)
            f.setsampwidth(sw)
            f.setnchannels(ch)
            f.writeframes(audio_bytes)

        url = "{}/api/extra/transcribe".format(self.uri)
        payload = {
            "prompt": prompt,
            "suppress_non_speech": self.suppress_non_speech,
            "langcode": self.langcode,
            "audio_data": base64.b64encode(wav_buffer.getvalue()).decode('utf-8')
        }

        # requests is blocking; run it on the default executor so the event
        # loop keeps servicing other tasks while the transcription is pending.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None, lambda: requests.post(url, json=payload)
        )

        if response.status_code != 200:
            raise Exception(f"Failed to get STT result: {response.status_code} {response.reason}")

        result = response.json()['text']
        yield {"transcription": result}