File size: 1,883 Bytes
b679b6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from smolagents import SpeechToTextTool, Tool

class EnglishSpeechToTextTool(SpeechToTextTool):
    """Variant of ``SpeechToTextTool`` that requests English output.

    Overrides ``encode``, ``forward`` and ``decode`` so that the pre-processor
    is called with ``sampling_rate=16_000``, generation is called with
    ``language="en"``, and the decoded text is returned under a
    ``## Transcription`` heading.
    """

    def encode(self, audio):
        """Convert ``audio`` to its raw form and run the pre-processor on it.

        Args:
            audio: Anything accepted by ``AgentAudio``.

        Returns:
            The pre-processor output, requested with ``return_tensors="pt"``.
        """
        from smolagents.agent_types import AgentAudio

        raw_audio = AgentAudio(audio).to_raw()
        return self.pre_processor(
            raw_audio,
            sampling_rate=16_000,
            return_tensors="pt",
        )

    def forward(self, inputs):
        """Call ``self.model.generate`` on ``inputs["input_features"]`` with ``language="en"``."""
        features = inputs["input_features"]
        return self.model.generate(features, language="en")

    def decode(self, outputs):
        """Decode ``outputs`` and return the first entry below a Markdown heading."""
        heading = "## Transcription\n\n"
        decoded = self.pre_processor.batch_decode(outputs, skip_special_tokens=True)
        return heading + decoded[0]


class GoogleSTTTool(Tool):
    """Transcribe a local audio file with the Google Cloud Speech-to-Text v2 API.

    The Google Cloud project ID is read from the ``GOOGLE_CLOUD_PROJECT``
    environment variable. The ``SpeechClient`` is created without explicit
    arguments, so it relies on whatever credentials the client library
    resolves by default.
    """

    description = "This is a tool that transcribes an audio into text. It returns the transcribed text."
    name = "transcriber"
    inputs = {
        "audio": {
            "type": "audio",
            # NOTE(review): ``forward`` below only handles a path it can
            # ``open()``; URLs and tensors are not handled here — confirm
            # whether the caller converts them before this tool runs.
            "description": "The audio to transcribe. Can be a local path, an url, or a tensor.",
        }
    }
    output_type = "string"

    # NOTE(review): smolagents appears to validate that ``forward``'s parameter
    # names match the keys of ``inputs`` (here ``audio``) unless validation is
    # skipped, and to call ``forward`` with those keys as keyword arguments.
    # This method instead takes a single mapping named ``inputs``; the
    # signature is kept as-is for existing callers — confirm how this tool is
    # invoked.
    def forward(self, inputs):
        """Send the audio file to Google Speech-to-Text and return its transcript.

        Args:
            inputs: Mapping whose ``"audio"`` entry is the path of the audio
                file to transcribe.

        Returns:
            The top transcript of every recognition result, joined with single
            spaces. Empty string when the response contains no transcripts.

        Raises:
            ValueError: If ``GOOGLE_CLOUD_PROJECT`` is unset or empty.
            OSError: If the audio file cannot be opened or read.
        """
        import os

        from google.cloud.speech_v2 import SpeechClient
        from google.cloud.speech_v2.types import cloud_speech

        # Fail early with a clear message instead of a NameError on an
        # undefined module-level constant.
        project_id = os.environ.get("GOOGLE_CLOUD_PROJECT")
        if not project_id:
            raise ValueError(
                "The GOOGLE_CLOUD_PROJECT environment variable must be set to the "
                "Google Cloud project ID used for speech recognition."
            )

        audio_file = inputs["audio"]
        with open(audio_file, "rb") as f:
            audio_content = f.read()

        client = SpeechClient()

        config = cloud_speech.RecognitionConfig(
            auto_decoding_config=cloud_speech.AutoDetectDecodingConfig(),
            language_codes=["en-US"],
            model="long",
        )

        request = cloud_speech.RecognizeRequest(
            recognizer=f"projects/{project_id}/locations/global/recognizers/_",
            config=config,
            content=audio_content,
        )

        response = client.recognize(request=request)

        # A result may carry no alternatives; skip those rather than raising
        # IndexError on ``alternatives[0]``.
        transcripts = [
            result.alternatives[0].transcript
            for result in response.results
            if result.alternatives
        ]
        # ``output_type`` is "string", so return text rather than the raw
        # response object.
        return " ".join(transcripts)