File size: 5,547 Bytes
0f07ba7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
"""
A test script to test the gRPC service for Moonshine transcription
"""
import unittest
import subprocess
import time
import os
import tempfile
import shutil
import backend_pb2
import backend_pb2_grpc

import grpc


class TestBackendServicer(unittest.TestCase):
    """
    TestBackendServicer is the class that tests the gRPC service.

    unittest calls setUp() before and tearDown() after every test method, so each
    test runs against its own freshly started server. Test methods must NOT call
    setUp()/tearDown() themselves: doing so would start a second server on the
    same port and leak the first subprocess.
    """

    # Address the backend server is started on and that the tests connect to.
    SERVER_ADDR = "localhost:50051"
    # Maximum seconds setUp() waits for the server to start accepting connections.
    STARTUP_TIMEOUT = 60
    # Seconds tearDown() waits after terminate() before falling back to kill().
    SHUTDOWN_TIMEOUT = 30
    # Sample clip used by the transcription test.
    AUDIO_URL = "https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav"

    def setUp(self):
        """
        This method sets up the gRPC service by starting the server and waiting
        until it accepts connections (instead of sleeping for a fixed time).

        Raises:
            RuntimeError: if backend.py exits before accepting connections.
            TimeoutError: if the server is not reachable within STARTUP_TIMEOUT.
        """
        self.service = subprocess.Popen(["python3", "backend.py", "--addr", self.SERVER_ADDR])
        try:
            self._wait_for_server()
        except BaseException:
            # unittest does not run tearDown() when setUp() raises, so stop the
            # subprocess here to avoid leaving an orphaned server behind.
            self._stop_server()
            raise

    def tearDown(self) -> None:
        """
        This method tears down the gRPC service by terminating the server.
        """
        self._stop_server()

    def _wait_for_server(self):
        """Block until the server accepts gRPC connections, the process exits, or the deadline passes."""
        deadline = time.monotonic() + self.STARTUP_TIMEOUT
        with grpc.insecure_channel(self.SERVER_ADDR) as channel:
            ready = grpc.channel_ready_future(channel)
            try:
                while True:
                    try:
                        # Short timeout so a crashed server is noticed promptly.
                        ready.result(timeout=1)
                        return
                    except grpc.FutureTimeoutError:
                        pass
                    if self.service.poll() is not None:
                        raise RuntimeError(
                            f"backend.py exited with code {self.service.returncode} "
                            "before accepting connections"
                        )
                    if time.monotonic() >= deadline:
                        raise TimeoutError(
                            f"backend.py did not accept connections on {self.SERVER_ADDR} "
                            f"within {self.STARTUP_TIMEOUT} seconds"
                        )
            finally:
                # Unsubscribe from connectivity updates; harmless if already done.
                ready.cancel()

    def _stop_server(self):
        """Terminate the server subprocess, killing it if it does not exit in time."""
        self.service.terminate()
        try:
            self.service.wait(timeout=self.SHUTDOWN_TIMEOUT)
        except subprocess.TimeoutExpired:
            self.service.kill()
            self.service.wait()

    def _download_audio(self, audio_file):
        """Download AUDIO_URL to audio_file with wget, failing the test on any error."""
        print(f"Downloading audio file to {audio_file}...")
        try:
            result = subprocess.run(
                ["wget", "-q", self.AUDIO_URL, "-O", audio_file],
                capture_output=True,
                text=True,
            )
        except FileNotFoundError:
            self.fail("Failed to download audio file: wget is not installed")
        if result.returncode != 0:
            self.fail(f"Failed to download audio file: {result.stderr}")
        if not os.path.exists(audio_file):
            self.fail(f"Audio file was not downloaded to {audio_file}")

    def test_server_startup(self):
        """
        This method tests if the server starts up successfully
        """
        with grpc.insecure_channel(self.SERVER_ADDR) as channel:
            stub = backend_pb2_grpc.BackendStub(channel)
            try:
                response = stub.Health(backend_pb2.HealthMessage())
            except grpc.RpcError as err:
                self.fail(f"Server failed to start: {err}")
        self.assertEqual(response.message, b'OK')

    def test_load_model(self):
        """
        This method tests if the model is loaded successfully
        """
        with grpc.insecure_channel(self.SERVER_ADDR) as channel:
            stub = backend_pb2_grpc.BackendStub(channel)
            try:
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="moonshine/tiny"))
            except grpc.RpcError as err:
                self.fail(f"LoadModel service failed: {err}")
        # Include the server's message so a failure explains itself.
        self.assertTrue(response.success, response.message)
        self.assertEqual(response.message, "Model loaded successfully")

    def test_audio_transcription(self):
        """
        This method tests if audio transcription works successfully
        """
        # The directory (and the downloaded clip) is removed automatically on exit.
        with tempfile.TemporaryDirectory() as temp_dir:
            audio_file = os.path.join(temp_dir, 'audio.wav')
            self._download_audio(audio_file)

            with grpc.insecure_channel(self.SERVER_ADDR) as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                # Load the model first
                try:
                    load_response = stub.LoadModel(backend_pb2.ModelOptions(Model="moonshine/tiny"))
                except grpc.RpcError as err:
                    self.fail(f"AudioTranscription service failed: LoadModel raised {err}")
                self.assertTrue(load_response.success, load_response.message)

                # Perform transcription; the server reads the file at this path.
                try:
                    transcript_response = stub.AudioTranscription(
                        backend_pb2.TranscriptRequest(dst=audio_file)
                    )
                except grpc.RpcError as err:
                    self.fail(f"AudioTranscription service failed: {err}")

        # Print the transcribed text for debugging
        print(f"Transcribed text: {transcript_response.text}")
        print(f"Number of segments: {len(transcript_response.segments)}")

        # Verify the transcription contains the expected text
        expected_text = "This is the micro machine man presenting the most midget miniature"
        self.assertIn(
            expected_text.lower(),
            transcript_response.text.lower(),
            f"Expected text '{expected_text}' not found in transcription: '{transcript_response.text}'"
        )

        # Segments are optional; if any were returned, check the first one's shape.
        if transcript_response.segments:
            segment = transcript_response.segments[0]
            self.assertIsInstance(segment.text, str)
            self.assertIsInstance(segment.id, int)